Python使用matplotlib.pyplot画热图和损失图的代码详解
作者:童话ing
众所周知,在完成论文相关工作时画图必不可少,如损失函数图、热力图等是非常常见的图,在本文中,总结了这两个图的画法,下面给出了完整的代码,开箱即用,感兴趣的同学可以自己动手尝试一下
一、损失函数图
import matplotlib.pyplot as plt file = open('E:\\5120154230PythonCode\\PBAN-PyTorch-master\\state_dict\\loss\\PBAN_New_restaurant15_0.001_80_0.2_16.csv') # 打开文档 data = file.readlines() # 读取文档数据 para_1 = [] # 新建列表,用于保存第一列数据 para_2 = [] # 新建列表,用于保存第二列数据 cnt = 0 for num in data: try: temp = num.split(",") cnt += 1 if cnt==700: break except: continue para_1.append(float(num.split(',')[0])) para_2.append(float(num.split(',')[1])) plt.figure() # plt.title('loss') plt.xlabel("iterations") plt.ylabel("loss") #color in cnblogs.com/qccc/p/12795541.html #orange、teal、red、chocolate plt.plot(para_1, para_2) plt.show()
CSV数据格式:第一列为Epoch或者迭代次数等,第二列为损失值。
效果图:
二、热图
import matplotlib.pyplot as plt import pandas as pd import matplotlib.ticker as ticker d = [ [0.43757705, 0.30564879, 0.08757705, 0.013755781, 0.13755781, 0.04080211, 0.03615228], [0.31525328, 0.42328909, 0.04004493, 0.01735733, 0.01755249, 0.02630009, 0.09020273], [0.01546572, 0.09022246, 0.4166335, 0.09773314, 0.10259592, 0.0447391, 0.03261019], [0.01536734, 0.010553601, 0.045800883, 0.39755909, 0.1465714, 0.0408309, 0.03612638], [0.11513351, 0.01193435, 0.051866556, 0.046714543, 0.42510962, 0.03154159, 0.4848393], [0.11544053, 0.0941444, 0.050161916, 0.09768857, 0.11385846, 0.43073818, 0.13351071], [0.01529034, 0.07752335, 0.04121181, 0.01742287, 0.35099512, 0.03777161, 0.38087882] ] variables = ['Great', 'food', 'but', 'the', 'service', 'was', 'dreadful'] labels = ['Great', 'food', 'but', 'the', 'service', 'was', 'dreadful'] df = pd.DataFrame(d, columns=variables, index=labels) fig = plt.figure(figsize=(7, 6)) #宽、高 ax = fig.add_subplot(1, 1, 1) #画布设置为1行1列显示在第一块中 # cmap参考:https://matplotlib.org/2.0.2/users/colormaps.html # hot_r、afmhot_r、plasma_r、ocean_r # interpolation:nearest,None、none cax = ax.matshow(df, interpolation='nearest', cmap='hot_r') fig.colorbar(cax) tick_spacing = 1 ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing)) ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing)) ax.set_xticklabels([''] + list(df.columns)) ax.set_yticklabels([''] + list(df.index)) plt.show()
效果:
另一份代码:
import matplotlib.pylab as plt import numpy as np def samplemat(dims): aa = np.zeros(dims) for i in range(dims[1]): aa[0,i] = i return aa dimlist = [(1, 12)] for d in dimlist: arr = samplemat(d) plt.matshow(arr) plt.show()
以上就是Python使用matplotlib.pyplot画热图和损失图的代码详解的详细内容,更多关于Python matplotlib.pyplot画图的资料请关注脚本之家其它相关文章!