第3章 多层感知机

本章将使用上一章简单介绍过的Dense全连接层来开始JAX模型的搭建,并实现一个多层感知机案例。

多层感知机(Multilayer Perceptron,MLP),是一种前馈人工神经网络模型,将输入的多个数据集映射到单一输出的数据集上。

图1 多层感知机

多层感知机一般分为三层,

  • 输入层
  • 隐藏层
  • 输出层。

当引入非线性的隐藏层后,理论上只要隐藏节点足够多,就可以拟合任意函数;同时,隐藏层越多,越容易拟合更复杂的函数。

隐藏层两个属性,

  • 节点数
  • 层数

层数越多,每一层需要的节点数越少。

全连接层——多层感知机的隐藏层

多层感知机的核心是隐藏层,隐藏层实际上就是一个全连接层。全连接层的每一个节点都与上一层的所有节点相连,用来把前面提取到的特征综合起来,所以全连接层的的参数也最多的。

图2 全连接层

注意,此处x₁,x₂,x₃看作参数(比如权重weight),w₁₁ ,w₁₂ ,w₁₃ … 则为输入参数,推导过程如下,

w₁₁ x x₁ + w₁₂ x x₂ + w₁₃ x x₃ = a₁
w₂₁ x x₁ + w₂₂ x x₂ + w₂₃ x x₃ = a₂
w₃₁ x x₁ + w₃₂ x x₂ + w₃₃ x x₃ = a₃

采用矩阵乘法公式,如下,

图3 矩阵相乘

其中,@是矩阵相乘,即f(x) = x@w。下面使用Python实现简单矩阵计算,

按公式推导,手动计算如下,

1.1 x 3 + 1.8 x 2 + 0.4 = 7.3
1.2 x 3 + 1.7 x 2 + 0.4 = 7.4

结果是一个新矩阵[7.3, 7.4]。

使用Python代码计算如下,

import jax

def multiplify(matrix, weights, bias):

    result = jax.numpy.matmul(matrix, weights) + bias

    return result

if __name__ == "__main__":

     matrix = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])

     weights = jax.numpy.array([[3], [2]])

     bias = 0.4

     result = multiplify(matrix, weights, bias)

     print(result)

打印输出如下所示,

[[7.3]
 [7.4]]

结果是一个维度为2维、形状为(2, 1)的矩阵。

JAX实现简单全连接层

从上一节的计算过程可知,全连接层的本质就是有一个特征空间线性变换到另外一个特征空间。目标空间的任一维都会受到源空间的每一维的影响。目标向量是源向量的加权和。

全连接层一般是接在特征提取网络之后,用于对特征的分类器。全连接层常出现在最后几层,用于对前面提取的特征做加权和计算。

图4 全连接层用于对特征的分类器

JAX实现全连接层代码如下,

import jax

def Dense(input_shape = (2, 1)):

    key = jax.random.PRNGKey(10)

    weights = jax.random.normal(key = key, shape = input_shape)
    biases = jax.random.normal(key = key, shape = (input_shape[-1],))

    params = [weights, biases]

    def apply_function(inputs):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return apply_function

def test():

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    dense = Dense()(inputs)

    print(dense)

if __name__ == "__main__":

    test()

全连接层Dense依次完成了函数、参数初始化,并使用默认的内置函数apply_function将对传入矩阵进行计算。打印输出如下所示,

[[-3.601719]
 [-4.189919]]

更多功能的全连接函数

使用外部参数的全连接函数。上一小节Dense函数内置的apply_function中,实际上调用了随机函数生成参数。如果要使用外部参数而非Dense函数内部生成的参数,则可以改进如下,

import jax

def Dense(inputs_shape = (2, 1)):

    key = jax.random.PRNGKey(10)

    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))

    params = [weights, biases]

    def init_params_function():

        return params
 
    def apply_function(inputs, params = params):

         weights, biases = params

         dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_params_function, apply_function

def test():

    key = jax.random.PRNGKey(15)

    inputs_shape = (2, 1)

    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))

    params = [weights, biases]

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    init_params_function, apply_function = Dense()
    dense = apply_function(inputs, params)

    print(dense)

if __name__ == "__main__":

    test()

这里使用了外部参数,而不是Dense内部生成的参数。打印输出如下所示,

[[1.5110686 ]
 [0.74590844]]

返回参数的全连接函数

前面学习了全连接函数和使用外部参数的全连接函数的方法,但有时候需要把生成的参数返回。代码如下,

import jax

def Dense(inputs_shape = (2, 1)):

    def init_function(shape = inputs_shape):

        key = jax.random.PRNGKey(10)

        weights, biases = jax.random.normal(key = key, shape = shape), jax.random.normal(key = key, shape = (shape[-1],))

        return (weights, biases)

    def apply_function(inputs, params):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_function, apply_function

def test():

    init_function, apply_function = Dense()

    init_params = init_function()

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    result = apply_function(inputs, init_params)

    print(f"init_params = {init_params}, result = {result}")

if __name__ == "__main__":

    test()

打印输出如下所示,

init_params = (Array([[-0.62187684],
 [-1.2754321 ]], dtype=float32), Array([-1.3445405], dtype=float32)),  result = [[-4.324383 ]
 [-4.2590275]]

不同种子生成不同随机参数的全连接函数

import jax

def Dense(input_shape = (2, 1), seed = 10):

    def init_function(shape = input_shape):

        key = jax.random.PRNGKey(seed)

        weights, biases = jax.random.normal(key = key, shape = shape), jax.random.normal(key = key, shape = (shape[-1],))

        return (weights, biases)

    def apply_function(inputs, params):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_function, apply_function

def test():

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    init_function, apply_function = Dense(seed = 10)
    init_params = init_function()
    
    dense = apply_function(inputs, init_params)

    print(f"dense1 = {dense}")

    print("----------------------------------------")

    init_function, apply_function = Dense(seed = 20)
    init_params = init_function()

    dense = apply_function(inputs, init_params)

    print(f"dense2 = {dense}")

if __name__ == "__main__":

    test()

打印输出如下所示,

dense1 = [[-4.324383 ]
 [-4.2590275]]
----------------------------------------
dense2 = [[2.524846]
 [2.44795 ]]
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

友情链接更多精彩内容