概念
Matplotlib将数据绘制为Figure,包含:
- Figure:图形。可含多个axes、特殊Artists(title, legend等)以及canvas(实际绘图的对象,对用户不可见)。
- Axes:坐标系。一个将点表示为坐标的区域。只能属于一个figure。含2~3个axis,每个axis对应的Label,以及一个title。Axes类及其方法是OO(面向对象)风格的主要接口。
- Axis:坐标轴。设定数据的范围,生成ticks(坐标轴上的marks)和ticklabels(ticks对应的标识)。
- Artist:所有可见的组件都属于Artist,包括Figure、Axes、Axis、title、legend等。当渲染figure时,所有artists会在canvas上绘制出来。
子图和子坐标系:
- SubFigure:Figure子类,一个figure可含多个subfigure,有figure相同的方法,但不能单独打印。
- add_subfigure:创建1个SubFigure实例。
Figure.add_subfigure
、SubFigure.add_subfigure
。 - subfigures:创建r行c列个subfigures。
Figure.subfigures
、SubFigure.subfigures
。 - SubplotBase:Axes子类,附加生成和操作Axes的方法。
- subplots(r,c):创建r行c列个SubplotBase实例。
pyplot.subplots
、Figure.subplots
。 - subplot(r,c,n):创建r行c列个SubplotBase实例中第n个SubplotBase实例(展开成1维后第n个)。
pyplot.subplot
、Figure.add_subplot
。
画图
显式隐式
方式1,显式的创建Figure和Axes,使用Axes的接口绘图。
fig, ax = plt.subplots() # Create a figure containing a single axes.
ax.plot([1, 2, 3, 4], [1, 4, 2, 3]) # Plot some data on the axes.
fig.show()
fig = plt.figure() # Create a figure without axes.
ax = Axes(fig, (0,0,0,0)) # Create an axes in the figure.
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
fig.show()
fig = Figure() # Create a figure without axes.
ax = fig.add_subplot() # Ass an axes into the figure.
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
fig.show()
fig = Figure() # Create a figure without axes.
# 等价于fig.add_axes(rect)
ax = Axes(fig, (0,0,0,0)) # Create an axes in the figure.
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
# 这种方式show()无效
ax = plt.axes()
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
plt.show()
方式2,隐式的创建figure和axes(若当前存在一个axes,则使用该axes,否则创建axes以及所需的figure),使用pyplot的接口绘图(Axes的每个绘图接口都有对应的pyplot接口)。
plt.plot([1, 2, 3, 4], [1, 4, 2, 3]) # Similar to MATLAB plot.
plt.show()
风格
OO风格,建议用于大型项目,逻辑清晰:
x = np.linspace(0, 2, 100) # Sample data.
# Note that even in the OO-style, we use `.pyplot.figure` to create the figure.
fig, ax = plt.subplots() # Create a figure and an axes.
ax.plot(x, x, label='linear') # Plot some data on the axes.
ax.plot(x, x**2, label='quadratic') # Plot more data on the axes...
ax.plot(x, x**3, label='cubic') # ... and some more.
ax.set_xlabel('x label') # Add an x-label to the axes.
ax.set_ylabel('y label') # Add a y-label to the axes.
ax.set_title("Simple Plot") # Add a title to the axes.
ax.legend() # Add a legend.
pyplot风格,建议用于交互式编程(如Notebook):
x = np.linspace(0, 2, 100) # Sample data.
plt.plot(x, x, label='linear') # Plot some data on the (implicit) axes.
plt.plot(x, x**2, label='quadratic') # etc.
plt.plot(x, x**3, label='cubic')
plt.xlabel('x label')
plt.ylabel('y label')
plt.title("Simple Plot")
plt.legend()
不推荐混用两种风格,如有需要:
- 获取当前figure:
plt.gcf()
- 获取当前axes:
plt.gca()
。
设置Figure
创建figure的几种方式:
plt.figure(figsize=(6.4,4.8), dpi=100.0)
:
- figsize,(width, height) inches(分辨率单位),一般默认为(6.4,4.8),notebook上默认为(6,4)。
- dpi,dots per inch(每个inch的点数),一般默认为100.0,notebook上默认为72。
示例
plt.figure(figsize=[4*2,4*2],dpi=72.0)#default: [6.4, 4.8] 100.0
plt.subplot(2,2,1),plt.imshow(img0),plt.title("original")
plt.subplot(2,2,2),plt.imshow(crop1),plt.title("train1")
plt.subplot(2,2,3),plt.imshow(crop2),plt.title("train2")
plt.subplot(2,2,4),plt.imshow(crop3),plt.title("test1")
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'random_crop.pdf'))
plt.show()
Scipy.stats绘制方差,当aa.shape[0]=1
时,方差不存在,也不会体现在图上:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
aa = np.array([[0,1,2,3,4],[0.1,1,2.1,3.2,4.1]])
e = stats.sem(aa, axis=0)
m = np.mean(aa, axis=0)
plt.plot(m)
plt.fill_between(range(5), m-e,m+e, alpha=0.3)
plt.show()
旋转坐标轴的标签:
plt.xticks(range(3), ["Zi 1", "Zi 2" ,"Zi 3"], rotation=-45)
# 或
plt.setp(plt.xticks(), rotation=-45, rotation_mode="anchor")
colorbar,支持多个子图共享colorbar:
import matplotlib
import matplotlib.pyplot as plt
cmap = matplotlib.colormaps['jet']
fig, axes = plt.subplots(6, 1)
im = None
for i in range(6):
im = axes[i].imshow(y[i].T, cmap=cmap, interpolation='nearest', aspect='auto')
axes[i].set_ylabel('t')
for i in range(5):
axes[i].set_xticks([])
axes[-1].set_xlabel('x')
fig.tight_layout() # tight_layout放在colorbar之前,否则colorbar会与其他axes重叠
cbar = fig.colorbar(im, ax=axes) # 通过axes列表指定覆盖的axes
cbar.set_label('u(t,x)')
fig.savefig(f'images/result.jpg')
fig.show()
绘制直方图
绘制grad、weight、activation的分布图,横坐标为2的幂次,纵坐标启用了log(10的幂次)
import numpy as np
import matplotlib.pyplot as plt
g = np.random.normal(loc=2**-8, scale=2**-8, size=(8192*8))
w = np.random.normal(loc=0, scale=2**-6, size=(8192*8))
a = np.random.normal(loc=0, scale=2**-0, size=(8192*8))
g = np.log2(g)
w = np.log2(w)
a = np.log2(a)
x = [g, w, a]
colors = ['r', 'y', 'b']
plt.hist(x, bins=50, histtype='bar', color=colors, label=colors, log=True)
# plt.xlim((-45, 5))
plt.xlabel('Exp')
plt.ylabel('Frequency')
plt.legend(['Gradient', 'Weight', 'Activation'])
plt.show()
绘制动画
ArtistAnimation
方式:
cmap = matplotlib.colormaps['jet']
fig, ax = plt.subplots(1,3)
ims = []
for i in range(10):
ax[0].set_title(f'Label')
# 若要固定每帧colorbar的坐标轴数值范围,可设定vmin、vmax:
# imshow(yy[i], cmap=cmap, vmin=0, vmax=1, animated=True)
im0 = ax[0].imshow(yy[i], cmap=cmap, animated=True)
ax[1].set_title(f'Prediction')
im1 = ax[1].imshow(yp[i], cmap=cmap, animated=True)
ax[2].set_title(f'Error')
im2 = ax[2].imshow(ye[i], cmap=cmap, animated=True)
fig.suptitle(f't={i}')
fig.tight_layout()
fig.colorbar(im1, ax=ax[2])
ims.append([im0, im1, im2])
# blit必须为True,否则画出的图为空
ani = animation.ArtistAnimation(fig, ims, interval=200, blit=True, repeat_delay=1000)
ani.save('animation.gif')
plt.show()
FuncAnimation
方式:
cmap = matplotlib.colormaps['jet']
fig, ax = plt.subplots(1,3, figsize=[7, 3])
ax[0].set_title(f'Label')
# 若要固定每帧colorbar的坐标轴数值范围,可设定vmin、vmax:
# imshow(yy[0], cmap=cmap, vmin=0, vmax=1)
im0 = ax[0].imshow(yy[0], cmap=cmap)
ax[1].set_title(f'Prediction')
im1 = ax[1].imshow(yp[0], cmap=cmap)
ax[2].set_title(f'Error')
im2 = ax[2].imshow(ye[0], cmap=cmap)
title = fig.suptitle(f't=0')
fig.tight_layout()
fig.colorbar(im1, ax=ax)
def animate(i):
y, p, e = yy[..., i, :], yp[..., i, :], ye[..., i, :]
im0.set_data(y)
im1.set_data(p)
im2.set_data(e)
# 若要固定每帧colorbar的坐标轴数值范围,结合上方,
# 设定set_clim(0, 1),或不设置。
im0.set_clim(np.min(y), np.max(y))
im1.set_clim(np.min(p), np.max(p))
im2.set_clim(np.min(e), np.max(e))
title.set_text(f't={i}')
ani = animation.FuncAnimation(fig, animate, interval=200, blit=False, frames=10,
repeat_delay=1000)
ani.save('animation.gif')
plt.show()
其他
若训练进程(包括nohup &启动的进程),总是因matplotlib异常中断,则:
import matplotlib.pyplot as plt
matplotlib.use('Agg')