PRO-014:使用svm与逻辑回归预测股票涨跌

  本主题使用逻辑回归与svm预测股票涨跌,效果不好,与抛硬币效果差不多。
  主要记录下其中的特征工程对特征的处理方式。
主要内容:
  1. 股票数据的特征处理;
  2. K-线的概念;
  3. 使用matplotlib绘制K-线;

一、数据获取与特征工程

  • 使用tushare提供的深沪股市交易数据;
  • 并使用连续一周的收盘价作为机器学习特征数据;
  1. 使用的package
import pandas as pd
import numpy as np
import tushare as ts
import sklearn
from sklearn.linear_model import LogisticRegression   # 训练模型
from sklearn.preprocessing import scale    # 数据预处理:标准化
  1. 获取k线数据
  • tushare提供如下几个函数获取交易k线数据。
 # 获取指定上市代码的公司的K线交易信息
# 数据返回格式:index(['date', 'open', 'close', 'high', 'low', 'volume', 'code'], dtype='object')
k_data = ts.get_k_data('600848', start='1988-01-01', end='', ktype='D') #训练集数据
# data.columns
k_data[0:4]
  1. 增加股价变动
k_data['result'] = k_data['close'].pct_change()
k_data[0:4]
  1. 删除NaN值
k_data.dropna(inplace=True)   # 在原来数据集上删除
k_data[0:4]
  1. 训练特征数据定义
# 定义训练的数据特征(核心是收盘价)
feature_data = pd.DataFrame()
feature_data['close'] = k_data['close']   # 取收盘价预测
feature_data['result'] = k_data['result']   # 股价涨跌,后面用来生成标签

feature_data[0:4]
  1. 使用连续一周的收盘价作为特征数据
# 特征工程:使用收盘价,连续一周作为一个特征训练输入
for i in range(1,8,1):    
    feature_data['close - ' + str(i) + 'd'] = k_data['close'].shift(i)

feature_data[0:10]
  1. 删除调整数据中NaN值
feature_data.dropna(inplace=True)   # 在原来数据集上删除
feature_data[0:4]
  1. 训练标签
    • 使用下一天的涨跌作为标签
train_label = np.sign(feature_data['result'].shift(-1))     
train_label[0:5]
8    -1.0
9    -1.0
10   -1.0
11   -1.0
12    1.0
Name: result, dtype: float64
  1. 删除股价变动列
feature_data.drop(['result'], axis=1, inplace=True)
feature_data[0:4]
  1. 数据标准化
train_data = sklearn.preprocessing.scale(feature_data) 
train_data[0:2]
array([[-0.26602226, -0.20240527, -0.14993246, -0.18011299, -0.11017521,
        -0.05842557,  0.04566276, -0.0245463 ],
       [-0.30894573, -0.26566062, -0.20198489, -0.14945344, -0.17967086,
        -0.10968345, -0.05790754,  0.04626196]])
train_label[-2:]
5680   -1.0
5681    NaN
Name: result, dtype: float64
  1. 缺失值处理
    • 最后一行数据应该是NaN,处理为0
train_label.replace(to_replace= np.NaN, value = 0, inplace = True)
train_label[-2:]
5680   -1.0
5681    0.0
Name: result, dtype: float64

二、数据训练-逻辑回归

  1. 使用sklearn的线性模型
classifier = LogisticRegression(C=1000, solver='lbfgs', multi_class='auto', penalty='l2', max_iter=100000)
  1. 训练
classifier.fit(train_data, train_label)
LogisticRegression(C=1000, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100000,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)
  1. score评分
classifier.score(train_data, train_label)
0.5082833979555869
  1. 预测
predict = classifier.predict(train_data) 
correct_num = (predict == train_label).sum()
correct_num
2884

三、数据训练-svm

  1. 使用SVM模型
from sklearn.svm import SVC

svc_classifier = SVC(kernel = 'rbf')
  1. 训练
svc_classifier.fit(train_data, train_label)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
    kernel='rbf', max_iter=-1, probability=False, random_state=None,
    shrinking=True, tol=0.001, verbose=False)
  1. score
svc_classifier.score(train_data, train_label)
0.5029961226647868
  1. 预测
predict = svc_classifier.predict(train_data) 
correct_num = (predict == train_label).sum()
correct_num
2854

四、附录:相关tushare函数说明

4.1. ts包帮助说明

  1. bond (package)
    • 投资参考数据接口:bounds.py
  2. coins (package)
    • 数字货币行情数据:market.py
  3. data (package)
    • (无)
  4. fund (package)
    • 获取基金净值数据接口:nav.py
  5. futures (package)
    • 国内期货:domestic.py
    • 国际期货:intlfutures.py
  6. internet (package)
    • 电影票房:boxoffice.py
    • 财新网新闻数据检索下载:caixinnews.py
  7. pro (package)
    • 新的更好的接口,需要指定token,有的需要积分的:data_pro.py
  8. stock (package)
    • 龙虎榜数据:billboard.py
    • 获取股票分类数据接口 :classifying.py
    • 基本面数据接口:fundamental.py
    • 全球市场:globals.py
    • 股票技术指标接口:indictor.py
    • 宏观经济数据接口:macro.py
    • 新闻事件数据接口:newsevent.py
    • 投资参考数据接口:reference.py
    • 上海银行间同业拆放利率(Shibor)数据接口:shibor.py
    • 交易数据接口:trading.py
  9. trader (package)
    • 股票实盘交易接口:trader.py
  10. util (package)
    • 工具,比如:日期时间工具。
  • 说明:
    • 所有包下的接口都使用别名,在tushare包下直接使用。
"""
for trading data
"""
from tushare.stock.trading import (get_hist_data, get_tick_data,
                                   get_today_all, get_realtime_quotes,
                                   get_h_data, get_today_ticks,
                                   get_index, get_hists,
                                   get_k_data, get_day_all,
                                   get_sina_dd, bar, tick,
                                   get_markets, quotes,
                                   get_instrument, reset_instrument)
help(ts)
Help on package tushare:

NAME
    tushare - # -*- coding:utf-8 -*-

PACKAGE CONTENTS
    bond (package)
    coins (package)
    data (package)
    fund (package)
    futures (package)
    internet (package)
    pro (package)
    stock (package)
    trader (package)
    util (package)

DATA
    __warningregistry__ = {'version': 2757, ("unclosed file <_io.TextIOWra...

VERSION
    1.2.17

AUTHOR
    Jimmy Liu

FILE
    /Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tushare/__init__.py

4.2. 函数帮助说明

  1. 常用的函数可以直接从官网获取。
  2. 部分函数在官网没有说明的,直接使用help获取帮助。
help(ts.get_k_data)
Help on function get_k_data in module tushare.stock.trading:

get_k_data(code=None, start='', end='', ktype='D', autype='qfq', index=False, retry_count=3, pause=0.001)
    获取k线数据
    ---------
    Parameters:
      code:string
                  股票代码 e.g. 600848
      start:string
                  开始日期 format:YYYY-MM-DD 为空时取上市首日
      end:string
                  结束日期 format:YYYY-MM-DD 为空时取最近一个交易日
      autype:string
                  复权类型,qfq-前复权 hfq-后复权 None-不复权,默认为qfq
      ktype:string
                  数据类型,D=日k线 W=周 M=月 5=5分钟 15=15分钟 30=30分钟 60=60分钟,默认为D
      retry_count : int, 默认 3
                 如遇网络等问题重复执行的次数 
      pause : int, 默认 0
                重复请求数据过程中暂停的秒数,防止请求间隔时间太短出现的问题
    return
    -------
      DataFrame
          date 交易日期 (index)
          open 开盘价
          high  最高价
          close 收盘价
          low 最低价
          volume 成交量
          amount 成交额
          turnoverratio 换手率
          code 股票代码

4.3. k-线

  • 来自百度百科
4.3.1. 来源

  K线图这种图表源处于日本德川幕府时代,被当时日本米市的商人用来记录米市的行情与价格波动,后因其细腻独到的标画方式而被引入到股市及期货市场。目前,这种图表分析法在我国以至整个东南亚地区均尤为流行。由于用这种方法绘制出来的图表形状颇似一根根蜡烛,加上这些蜡烛有黑白之分,因而也叫阴阳线图表。通过K线图,我们能够把每日或某一周期的市况表现完全记录下来,股价经过一段时间的盘档后,在图上即形成一种特殊区域或形态,不同的形态显示出不同意义。我们可以从这些形态的变化中摸索出一些有规律的东西出来。K线图形态可分为反转形态、整理形态及缺口和趋向线等。
  那么,为什么叫“K线”呢?实际上,在日本的“K”并不是写成“K”字,而是写做“罫”(日本音读kei),K线是“罫线”的读音,K线图称为“罫线”,西方以其英文首字母“K”直译为“K”线,由此发展而来。

4.3.2. k-线说明

  首先我们找到该日或某一周期的最高和最低价,垂直地连成一条直线
  然后再找出当日或某一周期的开市和收市价,把这二个价位连接成一条狭长的长方柱体
    |- 1. 假如当日或某一周期的收市价较开市价为高(即低开高收),我们便以红色来表示,或是在柱体上留白,这种柱体就称之为“阳线”。
    |- 2. 如果当日或某一周期的收市价较开市价为低(即高开低收),我们则以绿色表示,又或是在柱上涂黑色,这柱体就是“阴线”了。

  根据K线的计算周期可将其分为日K线,周K线,月K线,年K线。

k-线图示意图
  • 注意:
    • 很多软件都可以用彩色实体来表示阴线和阳线,在国内股票和期货市场 ,通常用红色表示阳线,绿色表示阴线。
    • 但涉及到欧美股票及外汇市场的投资者应该注意:在这些市场上通常用绿色代表阳线,红色代表阴线,和国内习惯刚好相反。
4.3.3. k-绘制
  • 绘制k-线当然使用第三方模块。
1. 安装matplotlib.finance
  • matplotlib2以上版本已经把mpl_finance模块移除了,所以需要先安装才能使用.
  • 安装指令:

    • pip install mpl-finance
  • 下载链接:

    • github下载模块文件: https://github.com/matplotlib/mpl_finance
%matplotlib inline
import mpl_finance as mpf
2. k线绘制函数说明
candlestick_ohlc(ax, quotes, width=0.2, colorup='k', colordown='r', alpha=1.0)
#    |- ax : `Axes`
#       Axes对象
#    |- quotes : sequence of (time, open, high, low, close, ...) sequences
#       K-线数据,其中time必须是float格式。使用date2num函数转换。    
#    |- width : float
#        fraction of a day for the rectangle width
#    |- colorup : color
#        the color of the rectangle where close >= open
#    |- colordown : color
#         the color of the rectangle where close <  open
#    |- alpha : float
#        the rectangle alpha level
    
#返回
#        返回 (lines, patches) 
  • 可以使用 help(mpf.candlestick_ohlc)获取官方帮助
3. 绘制实现
%matplotlib inline
import mpl_finance as mpf
import tushare as ts
import matplotlib.pyplot as plt
from matplotlib.pylab import date2num
import pandas as pd



# 加载数据:index(['date', 'open', 'close', 'high', 'low', 'volume', 'code'], dtype='object')
k_data = ts.get_k_data('000010', start='2019-04-26', end='', ktype='D') 

k_data['date'] = pd.to_datetime(k_data['date'], format="%Y-%m-%d")    # 1994-03-24
# 把k-data转换为candlestick_ohlc函数需要的类型。
mpf_data = []
for _, row in k_data.iterrows():
    date_ = row[0: 1]
    open_, close_, high_, low_ = row[1: 5]
    mpf_data.append((date2num(date_), open_, high_, low_, close_))


# 创建坐标系
figure = plt.figure(figsize=(8, 6))
ax = figure.add_axes([0.1, 0.1, 0.8, 0.8])
ax.xaxis_date()    # x轴自动转换为日期时间

plt.xticks(rotation=45)   # 坐标刻度标签旋转45度
plt.yticks()

plt.title("600848:k-线图(最近两个月)")   # 标题
plt.xlabel("时间")                 # x-轴 
plt.ylabel("股价(元)")        # y-轴
mpf.candlestick_ohlc(ax, mpf_data, width=0.5, colorup=(1, 0, 0, 1), colordown=(0, 1, 0, 1))

# ax.plot(k_data['date'], k_data['close'])
plt.show()

来自tushare数据的k-线图

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容