# -*- coding: utf-8 -*-
import tushare as ts
import pandas as pd
from sklearn import linear_model
import matplotlib.pyplot as plt
import numpy as np
# Download the daily price history of stock 601998 via tushare.
data1 = ts.get_hist_data('601998')
# NOTE(review): get_hist_data can apparently return None (bad code / failed
# request) -- confirm against the installed tushare version. Fail here with a
# clear message rather than an AttributeError or an empty regression fit later.
if data1 is None or data1.empty:
    raise RuntimeError("ts.get_hist_data('601998') returned no data")
# Keep the first 300 rows (positional slice). Given the ascending sort that
# follows, the source data is presumably newest-first, so these are the most
# recent 300 trading days -- TODO confirm.
data1 = data1.iloc[:300]
# Sort by index ascending so the rows run in index (presumably date) order.
data1 = data1.sort_index()
#------------------------
# Prepare the data.
# Persist the truncated, sorted history to CSV.
data1.to_csv('/home/abc/program/sklearn/a601998.csv')
xk = data1.index  # index labels of data1
x0 = range(len(data1))  # sample positions 0..n-1, used as the x axis
# Regression inputs as (n, 1) column arrays: ma10 is the feature,
# close is the target.
x1 = np.array(data1[['ma10']])
y1 = np.array(data1[['close']])
#----------------------
# Regression analysis: ordinary least squares fit of close (y1) on ma10 (x1).
reg = linear_model.LinearRegression()
reg.fit(x1, y1)
# Fitted slope (weight) and intercept.
quan, pian = reg.coef_, reg.intercept_
# Fitted close for every ma10 sample; equals x1 * quan + pian here.
yuce = reg.predict(x1)
#--------------------
# Convert the (n, 1) numpy column arrays into 1-D pandas Series with a
# default RangeIndex (0..n-1), so they can be compared element-wise and
# plotted against the sample positions.
x1 = pd.Series(np.asarray(x1).ravel(), name='a')  # ma10
y1 = pd.Series(np.asarray(y1).ravel(), name='b')  # close
yuce = pd.Series(np.asarray(yuce).ravel(), name='yuce')  # fitted close
# Raw columns from the frame, paired the same way as x1 / y1
# (x = ma10, y = close).
x10 = data1['ma10']
y10 = data1['close']
print(type(x1))
print(type(y1))
print(type(yuce))
print(type(x10))
print(type(y10))
# Plotting
plt.figure()
plt.grid()  # grid lines
plt.plot(x0, x1, c='g', linewidth=2, label='ma10')  # line plot, green
plt.plot(x0, y1, c='b', linewidth=1.5, label='close')  # line plot, blue
plt.plot(x0, yuce, c='tan', linewidth=1.5, label='yuce')  # line plot, tan
plt.ylabel('price')  # y-axis label
plt.xlim(0, len(y1))  # x range spans all sample positions
# y range: keep the 5..7 window, but widen it to whole-number bounds whenever
# any plotted value falls outside it, so no data is clipped. When everything
# lies within 5..7 this is exactly the original fixed window.
y_lo = min(5.0, float(np.floor(min(x1.min(), y1.min(), yuce.min()))))
y_hi = max(7.0, float(np.ceil(max(x1.max(), y1.max(), yuce.max()))))
plt.ylim(y_lo, y_hi)
print('-----------------')
print(y1)
print('-----------------')
print(x1)
# Masks: fitted vs ma10 (background bands) and ma10 vs close (area between
# the two lines).
panduan1 = yuce >= x1
panduan2 = yuce < x1
panduan3 = x1 >= y1
panduan4 = x1 < y1
# Area between ma10 and close: yellow where ma10 >= close, blue otherwise.
plt.fill_between(x0, x1, y1, where=panduan3, facecolor='yellow', alpha=0.9)
plt.fill_between(x0, x1, y1, where=panduan4, facecolor='blue', alpha=0.9)
# Full-height background bands over the same y range as the axes:
# green where fitted >= ma10, red otherwise.
plt.fill_between(x0, y_lo, y_hi, where=panduan1, facecolor='green', alpha=0.3)
plt.fill_between(x0, y_lo, y_hi, where=panduan2, facecolor='red', alpha=0.3)
plt.legend()  # legend
plt.title('matplotlab')  # title
plt.savefig('main', dpi=600)  # save the figure into the current directory
plt.show()  # display the figure