self attention核心代码的tensorflow实现

各种BERT的核心是self attention, 这是一种基于transformer的结构。在语言模型中,它会尽量避免使用循环,而是使用attention分数总结句子中不同的部分之间的关系。

import numpy as np
import tensorflow as tf

# 1. prepare input
input = [
    [1., 0., 1., 0.],
    [0., 2., 0., 2.],
    [1., 1., 1., 1.]
]

# 2. prepare weights for key, query, value
w_key = [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0]
]

w_query = [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1]
]

w_value = [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0]
]

#3. get keys, querys and values
keys = np.dot(input, w_key)
querys = np.dot(input, w_query)
values = np.dot(input, w_value)

print("keys:", keys)
print("querys:", querys)
print("values:", values)

#4. calculate attention score
attn_scores = np.dot(querys, keys.T)
print("attn_scores:", attn_scores)

#5. calculate softmax
attn_scores_softmax = tf.nn.softmax(attn_scores)
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)
print("attn_scores_softmax", attn_scores_softmax)
attn_scores_softmax = attn_scores_softmax.numpy()

# 6. multiply scores with values
weighted_values = values[:, None] * attn_scores_softmax.T[:,:,None]
print("weighted_values", weighted_values)

# 7. sum weighted values to get output
outputs = weighted_values.sum(axis=0)
print("outputs:", outputs)

输出
keys: [[0. 1. 1.]
[4. 4. 0.]
[2. 3. 1.]]
querys: [[1. 0. 2.]
[2. 2. 2.]
[2. 1. 3.]]
values: [[1. 2. 3.]
[2. 8. 0.]
[2. 6. 3.]]
attn_scores: [[ 2. 4. 4.]
[ 4. 16. 12.]
[ 4. 12. 10.]]

attn_scores_softmax tf.Tensor(
[[0.06 0.47 0.47]
[0. 0.98 0.02]
[0. 0.88 0.12]], shape=(3, 3), dtype=float64)
weighted_values [[[0.06 0.13 0.19]
[0. 0. 0. ]
[0. 0. 0. ]]

[[0.94 3.75 0. ]
[1.96 7.86 0. ]
[1.76 7.04 0. ]]

[[0.94 2.81 1.4 ]
[0.04 0.11 0.05]
[0.24 0.72 0.36]]]
outputs: [[1.94 6.68 1.6 ]
[2. 7.96 0.05]
[2. 7.76 0.36]]

reference:
https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容