本篇以代码的形式,展示了matpoltlib的几种常用图表,包括:折线图,柱状图,饼状图,散点图,三维散点图,实时动态图等。
我是在jupyter notebook上写的代码,所以引入模块有些随意。
import matplotlib.pyplot as plt
import numpy as np
# 方法一:
x1 = np.linspace(start = 0,stop = 2*np.pi,num=100)
print(x1.shape)
# 方法二:
x2 = np.arange(start = 0,stop = 2*np.pi,step = 0.1)
print(x2.shape) # (629,)
y1 = np.sin(x1)
y2 = np.cos(x2)
# 折线图
plt.plot(x1,y1,label="SIN") # 输入x和y,和线的名称
plt.plot(x2,y2,label="COS")
plt.title("Sin & Cos function")
plt.xlabel("X") # 横轴
plt.ylabel("Y") # 纵轴
plt.legend() # 显示图例
plt.show()
# 柱状图
x_lable = ["2016","2017","2018","2019","2020"]
number_list = [74.3,83.1,91.4,98.9,101.6]
# plt.bar(x_lable,number_list,width=0.2,color='green',label="China") 上下两种方法一致
plt.bar(range(len(x_lable)),number_list,width=0.3,color='green',label="China",tick_label=x_lable)
plt.xlabel("year")
plt.ylabel("GDP (trillion yuan)")
plt.legend()
plt.show()
# 双柱状图
x_lable = ["2016","2017","2018","2019","2020"]
China_GDP = [11.24,12.32,13.89,14.30,14.73]
US_GDP = [18.24,18.75,19.54,20.6,21.43]
plt.bar(np.array(range(len(x_lable)))-0.2,China_GDP,width=0.2,color='green',label="China")
plt.bar(range(len(x_lable)),np.zeros(len(x_lable)),width=0.2,tick_label=x_lable) # 这是个空白的柱子,只是让年份在两个柱子中间。
plt.bar(np.array(range(len(x_lable)))+0.2,US_GDP,width=0.2,color='orange',label="US")
plt.xlabel("Years")
plt.ylabel("GDP (trillion dollars)")
plt.legend()
plt.show()
# 饼状图
label = ["apple","banana","orange","peach","grapes"]
percent = [10,20,20,40,10]
ex = [0.1,0.1,0,0,0.1] # 这个是用来突出某些块
# 还可以设置startangle= 30来突出立体感
# 最后那个参数用于设置显示百分比
plt.pie(labels=label,x=percent,explode=ex,colors=["blue","orange","green","yellow","red"],shadow=True,autopct='%.1f%%')
plt.title("fruit import pie chart")
#plt.legend()
plt.show()
# 实时画图
x = []
y = []
plt.ion() # 打开实时画图
for i in range(10):
x.append(i)
y.append(i**3)
plt.clf() # 把上一帧画面清除
plt.plot(x,y)
plt.pause(0.01) #不暂停的话,不会出现图像
plt.ioff()# 关闭实时画图
plt.show()
# 画3D图
from mpl_toolkits.mplot3d import Axes3D
x = np.random.normal(0,1,100)
y = np.random.normal(0,1,100)
z = np.random.normal(0,1,100)
location = (x,y,z)
print(x.shape,y.shape,z.shape)
fig = plt.figure() # 创建一个窗口
ax = Axes3D(fig) # 将窗口放到3d空间中
ax.scatter(x,y,z)
plt.show()
# 二维散点图
x1 = np.random.randn(1000)
y1 = np.random.randn(1000)
x2 = np.random.normal(0,0.5,1000)
y2 = np.random.normal(0,0.5,1000)
# s表示点的大小
plt.scatter(x1,y1,s=1,c='blue',label="N(0~1)")
plt.scatter(x2,y2,s=1,c='orange',label="N(0~0.5)")
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
plt.show()