Keras练习:线性回归

Keras是高层神经网络API,后端可基于Tensorflow运行。

这里创建一个简单数据集,做线性回归,感受一下Keras的便捷。

Regression
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense


np.random.seed(1327)


def create_data():
    x = np.linspace(-1, 1, 200)
    np.random.shuffle(x)
    y = 0.5 * x + 2 + np.random.normal(0, 0.05, (200,))
    return x, y


def build_model():
    model = Sequential()
    model.add(Dense(input_dim=1, units=1))
    model.compile(loss='mse', optimizer='sgd')
    return model


def train(model, x, y):
    print('Training........')
    for step in range(1001):
        cost = model.train_on_batch(x, y)
        if step % 100 == 0:
            print('COST:', cost)


def test(model, x, y):
    print('\nTesting.......')
    cost = model.evaluate(x_test, y_test, batch_size=40)
    print('TEST COST: ', cost)
    w, b = model.layers[0].get_weights()
    print('Weights=', w, 'biases=', b)


if __name__ == '__main__':
    x, y = create_data()
    # plt.scatter(x, y)
    # plt.show()

    x_train = x[:160]
    y_train = y[:160]
    x_test = x[160:]
    y_test = y[160:]

    model = build_model()
    train(model, x_train, y_train)
    test(model, x_test, y_test)

    y_predict = model.predict(x_test)
    plt.scatter(x_test, y_test)
    plt.plot(x_test, y_predict)
    plt.show()

输出:

Training........
COST: 4.0496254
COST: 0.08321373
COST: 0.0063530593
COST: 0.0031990125
COST: 0.0026882463
COST: 0.002564282
COST: 0.0025330638
COST: 0.0025251824
COST: 0.002523192
COST: 0.0025226888
COST: 0.0025225622

Testing.......
40/40 [==============================] - 0s 1ms/step
TEST COST:  0.0018893185770139098
Weights= [[0.5145617]] biases= [1.9962281]
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • Keras 是提供一些高可用的 Python API ,能帮助你快速的构建和训练自己的深度学习模型,它的后端是 T...
    hugoren阅读 9,190评论 1 22
  • 本章涵盖了神经网络的核心组件Keras概论设置深度学习工作环境使用神经网络来解决基本分类和回归问题 本章旨在让你开...
    凤凰花开那一天阅读 16,387评论 0 57
  • 机器学习术语表 本术语表中列出了一般的机器学习术语和 TensorFlow 专用术语的定义。 A A/B 测试 (...
    yalesaleng阅读 6,044评论 0 11
  • 终究会坠落的流星, 哭泣的向日葵, 伤心是为谁。 你背起手轻跳的离开, 微语的街角, 你是否噙着笑, 不考虑回头。...
    不思中州晚阅读 1,937评论 2 11
  • 2017年12月17日,日拱一卒,积硅步,精进未来的自己。 我相信在这个宇宙当中存在着一股令万事万物向善的“宇宙意...
    凡尘花仙子阅读 3,851评论 1 4