sklearn.linear_model 之 LinearRegression

学习 handson_ml 时,学到 LinearRegression,记录一下。其中用到了python的格式化输出,对于小数的格式化 formatspec,详细介绍了各种简写对应的格式。

# 绘制散点图
country_stats.plot(kind='scatter', x='GDP per capita', y='Life satisfaction')
# 设置坐标轴范围,先x后y
plt.axis([0, 60000, 0, 10])

# 准备数据 
country_stats = prepare_country_stats(oecd_bli, gdp_per_captia)
# 按列堆叠1D数组,使其成为2D
X = np.c_[country_stats['GDP per capita']]
y = np.c_[country_stats['Life satisfaction']]

# 实例化线性模型,训练
model = LinearRegression()
model.fit(X, y)

# 预测新数据
X_new = [[22587]]
print(model.predict(X_new))

# 绘制线性模型
theta_0, theta_1= model.intercept_[0], model.coef_[0][0]
theta_0, theta_1
X = np.linspace(0, 60000, 1000)
plt.plot(X, theta_1 * X + theta_0, 'r-')
# g -> 通用格式,四舍五入,保留p个有效数字,会自动换科学计数法
plt.text(10000, 3.8, r'$y = \theta_0 + \theta_1 x$', fontsize=16, color='r')
plt.text(10000, 3, r'$\theta_0 = {:.4g}$'.format(theta_0), fontsize=14, color='r')
plt.text(10000, 2.3, r'$\theta_1 = {:.4g}$'.format(theta_1), fontsize=14, color='r')
plt.savefig('code_1-1.png')
plt.show()

output:

[[6.28653635]]
code_1-1.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容