scipy 最小二乘法

scipy是和numpy搭配使用的一个科学计算包,这次我们试着用里面的最小二乘法做一下线性回归的操作,也跟sklearn里面的LinearRegression 比对一下

import numpy as np 
from scipy.optimize import leastsq
import matplotlib.pyplot as plt    # 顺便可视化看一下,后面也可以用pyecharts 进行可视化

Xi = np.array([8.19,2.72,6.39,8.71,4.7,2.66,3.78])
Yi = np.array([7,2,6,7,5,4,5])

def func(p,x):   #因为leastsq 的参数是需要把方程的系数变成一个列表,所以要设置 k,b = p
    k,b = p
    return k*x + b

def error(p,x,y,s):
    print(s)
    return func(p,x) - y

def er(p,x,y,s):     # 这里是把两个函数简化了,效果是一样的,s是一个参数,可以看调用了多少次函数
    print(s)
    k,b = p
    return k*x + b - y

p0 = [10,1]  #这里给p 设置一个随机的初始值
s="Test the number of iteration"
para = leastsq(er,p0,(Xi,Yi,s))  #这里会返回两个部分,一个是系数,另一个应该是成功与否的标志,1,2,3,4是成功的意思

k,b = para[0]
print("k=",k,"b=",b)

最后求得 k= 0.6422869122509914 , b= 1.734148745508376

%matplotlib inline

plt.rcParams['font.sans-serif']=['FangSong'] #这两个函数很重要,可以让matplotlib显示中文
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize = (8,6),dpi = 80)
plt.scatter(Xi,Yi,c = 'r',label = 'Sample Point',linewidth = 3)

x = np.linspace(0,10.1000)
y = k*x + b
plt.plot(x,y,c = 'orange',label = 'Fitting Line',linewidth = 2)

plt.xlim(0,10) # 这里是设置x轴长度
plt.ylim(1,9) # 这里是设置y轴长度
plt.title('最小二乘法')
plt.legend()
plt.show()
image.png

然后我们用pyecharts来画一下,这里需要画散点图和折线图

import pyecharts.options as opts
from pyecharts.charts import Line,Scatter

x = np.linspace(0,10.1000)
y = k*x + b
def charts():
    scatter = (
        Scatter()
        .add_xaxis(Xi)
        .add_yaxis("Sample Point",Yi.tolist())
        .set_global_opts(title_opts=opts.TitleOpts(title="最小二乘法")))
    
    line=(
        Line()
        .add_xaxis(x)
        .add_yaxis("Fitting Line",y,is_symbol_show = False)
        .set_series_opts(label_opts=opts.LabelOpts(is_show=False)))

    scatter.overlap(line)
    return scatter

charts = charts()
charts.render_notebook()
image.png

其实sklearn 里面也有线性回归的函数,用的也是最小二乘法,我们也可以试试

from sklearn.linear_model import LinearRegression

regr = LinearRegression()
regr.fit(Xi.reshape(-1,1),Yi) # 这里要求传入的x 是 x * 1 需要进行转置

plt.figure(figsize = (8,6),dpi = 80)
plt.scatter(Xi,Yi,c = 'red',label = 'Sample Point',linewidth = 3)
plt.plot(Xi,regr.predict(Xi.reshape(-1,1)),c = 'orange',label = 'Fitting Line',linewidth = 2)

plt.xlim(0,10)
plt.ylim(1,9)
plt.title('最小二乘法')
plt.legend()
plt.show()
image.png

其实可以看出来,图形都是一样的

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容