DSSM using BERT

A two-tower DSSM: two weight-independent RoBERTa encoders built with bert4keras, each pooled to a 128-dimensional embedding, with relevance scored by the dot product of the two embeddings.

import tensorflow as tf
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import keras

# let GPU memory grow on demand instead of pre-allocating it all
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.experimental.set_memory_growth(device=gpus[0], enable=True)
 
embed_dim = 64             # reserved embedding size (unused in this snippet)
NEG, batch_size = 20, 128  # NEG negatives per query (see the in-batch sketch below); batch size

config_path = 'chinese_roberta_L-6_H-384_A-12/bert_config.json'
checkpoint_path = 'chinese_roberta_L-6_H-384_A-12/bert_model.ckpt'

# query tower: its own BERT encoder, pooled to the [CLS] vector so the
# tower emits one fixed-size vector per query (the raw BERT output is a
# (batch, seq_len, hidden) sequence, which would break the Dot layer below)
query_bert = build_transformer_model(config_path, checkpoint_path, return_keras_model=False).model
query_cls = keras.layers.Lambda(lambda x: x[:, 0])(query_bert.output)  # (batch, hidden)
query_layer = keras.layers.Dropout(0.1)(query_cls)
query_layer = keras.layers.Dense(128, activation='relu', kernel_regularizer='l2', name="query_tower")(query_layer)

# doc tower: a second, weight-independent BERT encoder
doc_bert = build_transformer_model(config_path, checkpoint_path, return_keras_model=False).model
for layer in doc_bert.layers:
    # rename so both BERT copies (and their input layers) can coexist in one Model
    # (under tf.keras, where `name` is read-only, assign `layer._name` instead)
    layer.name = layer.name + '_doc'
doc_cls = keras.layers.Lambda(lambda x: x[:, 0])(doc_bert.output)  # (batch, hidden)
doc_layer = keras.layers.Dropout(0.1)(doc_cls)
doc_layer = keras.layers.Dense(128, activation='relu', kernel_regularizer='l2', name="doc_tower")(doc_layer)
 
# relevance score: dot product of the two 128-d embeddings -> (batch, 1)
output = keras.layers.Dot(axes=1)([query_layer, doc_layer])

# a 1-unit sigmoid head would work equally well here:
# output = keras.layers.Dense(1, activation='sigmoid')(output)
output = keras.layers.Dense(2, activation='softmax')(output)
model = keras.models.Model(query_bert.input + doc_bert.input, output)

model.compile(loss='categorical_crossentropy', metrics=['acc'], optimizer='rmsprop')
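
The NEG constant above hints at the classic DSSM objective: a softmax over one positive and several negative documents per query, rather than the pointwise 2-way softmax trained here. A minimal sketch of that idea using in-batch negatives, where every other document in the batch serves as a negative; this is an assumption about intent, not part of the original listing:

# alternative training head (sketch): in-batch softmax over query-doc dot products
sim = keras.layers.Lambda(
    lambda qd: tf.matmul(qd[0], qd[1], transpose_b=True))([query_layer, doc_layer])

def in_batch_softmax_loss(y_true, sim_matrix):
    # sim_matrix[i, j] = dot(query_i, doc_j); the diagonal holds the positives.
    # y_true is ignored, so fit() can be fed any dummy labels.
    targets = tf.eye(tf.shape(sim_matrix)[0])
    return tf.reduce_mean(
        keras.losses.categorical_crossentropy(targets, tf.nn.softmax(sim_matrix)))

pair_model = keras.models.Model(query_bert.input + doc_bert.input, sim)
pair_model.compile(loss=in_batch_softmax_loss, optimizer='rmsprop')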
 

# query tower
query_model = keras.Model(inputs=query_bert.input, outputs=query_layer)
# doc tower
doc_model = keras.Model(inputs=doc_bert.input, outputs=doc_layer)
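
model.fit below expects a train_generator, which the listing never defines. A minimal sketch of one, assuming bert4keras's DataGenerator and a train_data list of (query_text, doc_text, 0/1) triples; dict_path, the maxlen values, and train_data are placeholders:

import numpy as np
from bert4keras.snippets import DataGenerator, sequence_padding

dict_path = 'chinese_roberta_L-6_H-384_A-12/vocab.txt'  # assumed to sit next to the config above
tokenizer = Tokenizer(dict_path, do_lower_case=True)

class PairGenerator(DataGenerator):
    """Yields ([q_tokens, q_segments, d_tokens, d_segments], one-hot labels)."""
    def __iter__(self, random=False):
        batch_inputs = [[], [], [], []]
        batch_labels = []
        for is_end, (query, doc, label) in self.sample(random):
            q_tok, q_seg = tokenizer.encode(query, maxlen=32)
            d_tok, d_seg = tokenizer.encode(doc, maxlen=128)
            for bucket, seq in zip(batch_inputs, (q_tok, q_seg, d_tok, d_seg)):
                bucket.append(seq)
            batch_labels.append([1, 0] if label == 0 else [0, 1])
            if len(batch_labels) == self.batch_size or is_end:
                yield [sequence_padding(b) for b in batch_inputs], np.array(batch_labels)
                batch_inputs = [[], [], [], []]
                batch_labels = []

train_generator = PairGenerator(train_data, batch_size)  # train_data prepared elsewhere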

 
print("[INFO] training model...")
model.fit(
    train_generator.forfit(),
    steps_per_epoch=len(train_generator),
    epochs=2, verbose=1)
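
At serving time the towers run separately: document vectors are typically embedded offline, and only the query is encoded per request, with the raw dot product as the relevance score. A usage sketch reusing the tokenizer and sequence_padding from the generator above; the example strings are made up:

def encode(texts, sub_model, maxlen):
    toks, segs = [], []
    for t in texts:
        tok, seg = tokenizer.encode(t, maxlen=maxlen)
        toks.append(tok)
        segs.append(seg)
    return sub_model.predict([sequence_padding(toks), sequence_padding(segs)])

q_vec = encode(['手机壳'], query_model, maxlen=32)                    # (1, 128)
d_vecs = encode(['硅胶手机保护壳', '蓝牙耳机'], doc_model, maxlen=128)  # (2, 128)
scores = q_vec @ d_vecs.T  # dot-product relevance; higher = more relevant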