# Approach 1: seaborn (sns.lmplot)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load the salary dataset; the columns used below are YearsExperience and Salary.
income = pd.read_csv('Salary_Data.csv')
# Print the column dtypes: a bare `income.dtypes` expression displays nothing
# when this file runs as a script.
print(income.dtypes)
# Scatter plot with a fitted regression line; ci=None disables the confidence band.
sns.lmplot(x='YearsExperience', y='Salary', data=income, ci=None)
# Render this figure now; otherwise it stays the current figure and the next
# pyplot call draws onto the same axes.
plt.show()
# Approach 2: matplotlib + statsmodels
# Scatter plot of the raw observations.
plt.scatter(x=income["YearsExperience"], y=income["Salary"], color="blue")
# statsmodels supplies the OLS estimator used for the fit.
import statsmodels.api as sm
# Build a single-predictor linear regression (Salary ~ YearsExperience) and fit it.
ols_model = sm.formula.ols("Salary~YearsExperience", data=income)
fit = ols_model.fit()
# Predicted salaries for the observed experience values.
pred = fit.predict(exog=income["YearsExperience"])
# Draw the regression line over the scatter plot, then display the figure.
plt.plot(income["YearsExperience"], pred, color="coral", linewidth=1)
plt.show()
# Approach 3: matplotlib + scikit-learn
from sklearn.linear_model import LinearRegression
import numpy as np
# scikit-learn estimators take a 2-D feature matrix, so reshape the single
# predictor column to (n_samples, 1).
X = np.array(income.YearsExperience).reshape(-1, 1)
y = income.Salary
# Fit the linear regression and compute predictions for the training inputs.
lr = LinearRegression()
lr.fit(X, y)
predict = lr.predict(X)
# Observations as points, predictions as the regression line.
plt.scatter(X, y, c='b', s=60)
plt.plot(X, predict, color="coral", linewidth=1)
# Display the figure; without this call nothing is rendered when the file
# runs as a script.
plt.show()