第3章 多层感知机

本章将使用上一章简单介绍过的Dense全连接层来开始JAX模型的搭建,并实现一个多层感知机案例。

多层感知机(Multilayer Perceptron,MLP),是一种前馈人工神经网络模型,将输入的多个数据集映射到单一输出的数据集上。

图1 多层感知机

多层感知机一般分为三层,

  • 输入层
  • 隐藏层
  • 输出层。

当引入非线性的隐藏层后,理论上只要隐藏节点足够多,就可以拟合任意函数;同时,隐藏层越多,越容易拟合更复杂的函数。

隐藏层两个属性,

  • 节点数
  • 层数

层数越多,每一层需要的节点数越少。

全连接层——多层感知机的隐藏层

多层感知机的核心是隐藏层,隐藏层实际上就是一个全连接层。全连接层的每一个节点都与上一层的所有节点相连,用来把前面提取到的特征综合起来,所以全连接层的的参数也最多的。

图2 全连接层

注意,此处x₁,x₂,x₃看作参数(比如权重weight),w₁₁ ,w₁₂ ,w₁₃ … 则为输入参数,推导过程如下,

w₁₁ x x₁ + w₁₂ x x₂ + w₁₃ x x₃ = a₁
w₂₁ x x₁ + w₂₂ x x₂ + w₂₃ x x₃ = a₂
w₃₁ x x₁ + w₃₂ x x₂ + w₃₃ x x₃ = a₃

采用矩阵乘法公式,如下,

图3 矩阵相乘

其中,@是矩阵相乘,即f(x) = x@w。下面使用Python实现简单矩阵计算,

按公式推导,手动计算如下,

1.1 x 3 + 1.8 x 2 + 0.4 = 7.3
1.2 x 3 + 1.7 x 2 + 0.4 = 7.4

结果是一个新矩阵[7.3, 7.4]。

使用Python代码计算如下,

import jax

def multiplify(matrix, weights, bias):

    result = jax.numpy.matmul(matrix, weights) + bias

    return result

if __name__ == "__main__":

     matrix = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])

     weights = jax.numpy.array([[3], [2]])

     bias = 0.4

     result = multiplify(matrix, weights, bias)

     print(result)

打印输出如下所示,

[[7.3]
 [7.4]]

结果是一个维度为2维、形状为(2, 1)的矩阵。

JAX实现简单全连接层

从上一节的计算过程可知,全连接层的本质就是有一个特征空间线性变换到另外一个特征空间。目标空间的任一维都会受到源空间的每一维的影响。目标向量是源向量的加权和。

全连接层一般是接在特征提取网络之后,用于对特征的分类器。全连接层常出现在最后几层,用于对前面提取的特征做加权和计算。

图4 全连接层用于对特征的分类器

JAX实现全连接层代码如下,

import jax

def Dense(input_shape = (2, 1)):

    key = jax.random.PRNGKey(10)

    weights = jax.random.normal(key = key, shape = input_shape)
    biases = jax.random.normal(key = key, shape = (input_shape[-1],))

    params = [weights, biases]

    def apply_function(inputs):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return apply_function

def test():

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    dense = Dense()(inputs)

    print(dense)

if __name__ == "__main__":

    test()

全连接层Dense依次完成了函数、参数初始化,并使用默认的内置函数apply_function将对传入矩阵进行计算。打印输出如下所示,

[[-3.601719]
 [-4.189919]]

更多功能的全连接函数

使用外部参数的全连接函数。上一小节Dense函数内置的apply_function中,实际上调用了随机函数生成参数。如果要使用外部参数而非Dense函数内部生成的参数,则可以改进如下,

import jax

def Dense(inputs_shape = (2, 1)):

    key = jax.random.PRNGKey(10)

    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))

    params = [weights, biases]

    def init_params_function():

        return params
 
    def apply_function(inputs, params = params):

         weights, biases = params

         dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_params_function, apply_function

def test():

    key = jax.random.PRNGKey(15)

    inputs_shape = (2, 1)

    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))

    params = [weights, biases]

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    init_params_function, apply_function = Dense()
    dense = apply_function(inputs, params)

    print(dense)

if __name__ == "__main__":

    test()

这里使用了外部参数,而不是Dense内部生成的参数。打印输出如下所示,

[[1.5110686 ]
 [0.74590844]]

返回参数的全连接函数

前面学习了全连接函数和使用外部参数的全连接函数的方法,但有时候需要把生成的参数返回。代码如下,

import jax

def Dense(inputs_shape = (2, 1)):

    def init_function(shape = inputs_shape):

        key = jax.random.PRNGKey(10)

        weights, biases = jax.random.normal(key = key, shape = shape), jax.random.normal(key = key, shape = (shape[-1],))

        return (weights, biases)

    def apply_function(inputs, params):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_function, apply_function

def test():

    init_function, apply_function = Dense()

    init_params = init_function()

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    result = apply_function(inputs, init_params)

    print(f"init_params = {init_params}, result = {result}")

if __name__ == "__main__":

    test()

打印输出如下所示,

init_params = (Array([[-0.62187684],
 [-1.2754321 ]], dtype=float32), Array([-1.3445405], dtype=float32)),  result = [[-4.324383 ]
 [-4.2590275]]

不同种子生成不同随机参数的全连接函数

import jax

def Dense(input_shape = (2, 1), seed = 10):

    def init_function(shape = input_shape):

        key = jax.random.PRNGKey(seed)

        weights, biases = jax.random.normal(key = key, shape = shape), jax.random.normal(key = key, shape = (shape[-1],))

        return (weights, biases)

    def apply_function(inputs, params):

        weights, biases = params

        dotted = jax.numpy.dot(inputs, weights) + biases

        return dotted

    return init_function, apply_function

def test():

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    init_function, apply_function = Dense(seed = 10)
    init_params = init_function()
    
    dense = apply_function(inputs, init_params)

    print(f"dense1 = {dense}")

    print("----------------------------------------")

    init_function, apply_function = Dense(seed = 20)
    init_params = init_function()

    dense = apply_function(inputs, init_params)

    print(f"dense2 = {dense}")

if __name__ == "__main__":

    test()

打印输出如下所示,

dense1 = [[-4.324383 ]
 [-4.2590275]]
----------------------------------------
dense2 = [[2.524846]
 [2.44795 ]]
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,294评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,780评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,001评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,593评论 1 289
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,687评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,679评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,667评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,426评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,872评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,180评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,346评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,019评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,658评论 3 323
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,268评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,495评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,275评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,207评论 2 352

推荐阅读更多精彩内容