DeepAR 是 Amazon 于 2017 年提出的基于深度学习的时间序列预测方法,目前已集成到 Amazon SageMaker 和 GluonTS 中。前者是 AWS 的机器学习云平台,后者是 Amazon 开源的时序预测工具库。
传统的时间序列预测方法(ARIMA、Holt-Winters’ 等)往往针对一维时间序列本身建模,难以利用额外特征。此外,传统方法的预测目标通常是序列在每个时间步上的取值。与之相比,基于神经网络的 DeepAR 方法可以很方便地将额外的特征纳入考虑,且其预测目标是序列在每个时间步上取值的概率分布。在特定场景下,概率预测比单点预测更有意义。以零售业为例,若已知商品未来销量的概率分布,则可以利用运筹优化方法推算在不同业务目标下的最优采购量,从而辅助决策。
获取股票信息
liststock =['sz.300807','sz.300789','sz.300771','sz.300546','sz.300479','sz.300462','sz.300455','sz.300449','sz.300386','sz.300368']
#liststock = ['sz.300462']
listdic = []
lg = bs.login()
for ite in liststock:
dd = mygetstockdata(ite)
dic = {"start":dd.date[0],"target":list(dd.close),"cat":int(liststock[0].split('.')[1]),"dynamic_feat":[list(dd.volume),list(dd.turn)]}
#strjon = json.dumps(dic)
listdic.append(dic)
bs.logout()
traindata = ListDataset(
listdic,
freq = "1d"
)
设置estimator,训练保存到tmp文件夹
prediction_length = 30
estimator = DeepAREstimator(
prediction_length=prediction_length,
context_length=30,
freq="1d",
trainer=Trainer(ctx="cpu",
epochs=5,#30
learning_rate=1e-3,
num_batches_per_epoch=50 #100
)
)
predictor = estimator.train(traindata)
predictor.serialize(Path("./tmp/"))
可以直接使用也可以保存下来后续直接用
predictor = Predictor.deserialize(Path("./tmp/"))
import matplotlib.pyplot as plt
from gluonts.dataset.util import to_pandas
for test_entry, forecast in zip(testdata, predictor.predict(testdata)):
to_pandas(test_entry)[-60:].plot(linewidth=2)
forecast.plot(color='g', prediction_intervals=[50.0, 90.0])
plt.grid(which='both')
最后得到得图像,随机选择了10个计算机行业得股票做的训练从19年1月到现在得数据,最后预测得是其中得一支股票
从图片可以看出来日期越靠后偏差越大,用这10只股票来看基本都是先大涨一波然后大跌。太靠后得不太靠谱,这10支股票得数据我试着跑了两天,日期近得没什么变化,但是对长期得预测变化较大,如果用来参考最好还是比较频繁得预测。
这是用10支股票训练得结果,后续打算多选几个股票来训练。
最近正在兼职帮朋友做销售数字的分析,才学的deepar,目前来看涉及到多种类得预测还是要比其它得好一些。有感兴趣得可以交流一下。