Tensorflow 2.0 seq2seq 预测sin曲线

在本实验中构造了一系列正弦曲线的变形:
X=np.sin(T, dtype=np.float32)0.01np.random.rand()(T+np.random.rand()10)+np.random.rand()*10
样子的话,大概是:

image.png

预测的任务是,已知某条线前N个数(N为400到1600之间的随机数),预测接下来的200个数。也就是我们要处理一个变长的输入,这就需要用到ragedtensor。
这里面的主要知识点可以被概括为:

  1. Raged tensor的使用。这里采用了from_tensor方法构造出了一个Raged tensor来存储一系列长度不一致的序列。
  2. 利用karas处理变长序列。由于karas中的LSTM是更高级的API,因此可以直接对Raged tensor进行处理。
  3. Seq2seq框架。这里采用的是tensorflow_addons中的seq2seq框架,由于此框架本身对文本处理的支持比较多,在应对我们的(相对简单的)实数序列的案例中,反而需要进行一些定制。主要是对sampler的定制。
  4. Earlystop机制。直接采用karas的earlystop即可。

预测效果:


image.png

完整代码

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

T=np.arange(2000)*0.05
X=np.sin(T, dtype=np.float32)*0.01*np.random.rand()*(T+np.random.rand()*10)+np.random.rand()*10
plt.plot(X)

batch_size = 32
max_time = 200
hidden_size = 128

YLEN=max_time

def generateTrain(records=32):
    trainXs=[]
    trainYs=[]
    XLens=[]
    YLens=[]
    for i in range(records):
        X=(np.sin(T)*0.01*np.random.rand()*(T+np.random.rand()*10)+np.random.rand()*10).astype(np.float32)
        xi=X[:1800].copy()
        sepre=np.random.randint(400,1600)
        xi[sepre::]=0
        yi=X[sepre:sepre+YLEN].copy()
        trainXs.append(np.expand_dims(xi,axis=0))
        trainYs.append(np.expand_dims(yi,axis=0))
        XLens.append(sepre)
        YLens.append(YLEN)
    trainX=np.concatenate(trainXs,axis=0)
    trainY=np.concatenate(trainYs,axis=0)
    seqLen=np.array(XLens)
    outLen=np.array(YLens)
    return (trainX,trainY,seqLen,outLen)

def generateTest():
    X=(np.sin(T)*0.01*np.random.rand()*(T+np.random.rand()*10)+np.random.rand()*10).astype(np.float32)
    xi=X[:1800].copy()
    yi=X[1800:1800+YLEN].copy()
    testX=np.expand_dims(xi, axis=0)
    testY=np.expand_dims(yi, axis=0)
    seqLen_test=np.array([1800])
    outLen_test=np.array([YLEN])
    return (testX,testY,seqLen_test,outLen_test)

def gen_raged_train():
    while True:
        trainX,trainY,seqLen,_ = generateTrain()
        raged_train_x = tf.RaggedTensor.from_tensor(trainX, lengths=seqLen)
        raged_train_x = tf.expand_dims(raged_train_x, -1)
        tensor_train_y = tf.convert_to_tensor(trainY)
        yield raged_train_x,tensor_train_y
    
dataset = tf.data.Dataset.from_generator(
     gen_raged_train,
     output_signature=(
         tf.RaggedTensorSpec(shape=(batch_size, None, 1), dtype=tf.float32, ragged_rank=1),
         tf.TensorSpec(shape=(batch_size, 200,), dtype=tf.float32))
)

import tensorflow_addons as tfa
import tensorflow as tf

inputs = tf.keras.layers.Input(shape=[None, 1], ragged=True)

encoding, state_h, state_c = tf.keras.layers.LSTM(hidden_size, return_state=True)(inputs)

encoder_state = [state_h, state_c]

decoder_cell = tf.keras.layers.LSTMCell(hidden_size)

sample_fn = lambda x: x

end_fn = lambda x:False

sampler = tfa.seq2seq.InferenceSampler(sample_fn = sample_fn, sample_shape=[hidden_size],sample_dtype=tf.float32,end_fn=end_fn)

decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, maximum_iterations=200)

input_lengths = tf.fill([batch_size], max_time)

initial_state = decoder_cell.get_initial_state(encoding)

output, state, lengths = decoder(
    tf.convert_to_tensor(encoding), initial_state=initial_state)

logits = output.rnn_output

output_layer = tf.keras.layers.Dense(1)

out_seq = tf.squeeze(output_layer(logits))

print(out_seq.shape)

model = tf.keras.Model(inputs=inputs, outputs=out_seq)

model.compile(optimizer="Adam", loss="mse", metrics=["mse", "mae", "mape"])

early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)

model.fit(dataset.take(128),validation_data=val_data, epochs=1024,callbacks=[early_stop])

model.save('saved_model/my_model')

plt.figure()
plt.plot(X)
plt.plot([np.nan]*2000+model.predict(np.expand_dims(X,0)).tolist())
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,125评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,293评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,054评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,077评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,096评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,062评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,988评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,817评论 0 273
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,266评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,486评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,646评论 1 347
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,375评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,974评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,621评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,796评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,642评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,538评论 2 352

推荐阅读更多精彩内容