如标题,TensorFlow 2 相较 1.x 在 API 上有比较大的变动,1.x 的很多 API 在 2.0 中已被移除。
本文使用tf2.0的api实现一个简单的线性回归算法。
import tensorflow as tf
print(tf.__version__)
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(199346745)  # fixed seed so the demo is reproducible
# ---- Generate synthetic test data ----
# Ground-truth weights the regression should recover.
w_known = [1.4, 0.4, -0.4, 0.3, -1.9]
DIM = len(w_known)  # number of features
N = 1000            # number of samples
BATCH = 300         # mini-batch size used in the training call below
x = np.random.random((N, DIM))
# Noiseless targets: a single vectorized matvec instead of the original
# per-column Python-level sum (same result, runs in C).
y_ = x @ np.asarray(w_known)
err = 0.01 * np.random.normal(size=N)  # small Gaussian observation noise
y = (y_ + err).reshape((N, 1))         # column vector, shape (N, 1)
# ---- Visualize the test-data distributions ----
plt.hist(err, 30)   # the noise should look Gaussian around 0
plt.show()
# Use ravel() instead of the original hard-coded reshape(1000): stays
# correct if N is ever changed.
plt.hist(y.ravel(), 30)
plt.show()
# 得益于tf2.0的动态图特性,可以在函数中直接循环训练
# Thanks to TF2's eager execution we can train in a plain Python loop.
def lr(x, y, BATCH=None, niter=1000):
    """Fit weights w minimizing MSE of x @ w against y via mini-batch SGD.

    Args:
        x: feature matrix, shape (n_samples, n_features).
        y: targets, shape (n_samples, 1).
        BATCH: mini-batch size; defaults to the full dataset size.
        niter: number of SGD steps.

    Returns:
        (w, losses): the learned tf.Variable of shape (n_features, 1),
        and the per-step training losses as a list of floats.
    """
    # Derive sizes from x instead of relying on module-level globals,
    # so the function works on any dataset passed in.
    n, dim = x.shape
    if BATCH is None:
        BATCH = n
    w = tf.Variable(tf.random.normal(shape=(dim, 1), mean=0))
    # Hoisted out of the loop: the original rebuilt a fresh SGD optimizer
    # and loss object on every iteration for no benefit.
    opt = tf.keras.optimizers.SGD(1e-1)
    mse = tf.losses.MeanSquaredError()
    losses = []
    for _ in range(niter):
        randidx = np.random.choice(n, size=BATCH)
        x2 = tf.constant(x[randidx], dtype='float32')
        y2 = tf.constant(y[randidx], dtype='float32')
        # Callable loss: opt.minimize() re-evaluates it under a tape.
        loss = lambda: mse(tf.matmul(x2, w), y2)
        opt.minimize(loss, var_list=[w])
        losses.append(loss().numpy())
    return w, losses
w, losses = lr(x, y, 300)
# ---- Inspect the loss curve ----
plt.plot(losses)
plt.show()
# ---- Distance between the learned w and the ground-truth w ----
# (-1, 1) instead of the original hard-coded (5, 1), so this still works
# if the number of ground-truth weights changes. Print the value: the
# original expression's result was silently discarded when run as a script.
w_err = tf.losses.MeanSquaredError()(w, tf.reshape(tf.constant(w_known), (-1, 1)))
print("MSE between learned w and true w:", float(w_err))