这几天在用TensorFlow 2.4构建BERT-BiLSTM-CRF模型做NER时,出现了这样的warning:“CRF Decoding does not work with KerasTensors in TF2.4. The bug has since been fixed in tensorflow/tensorflow##45534”。
虽然模型能够建起来,summary也能看到,但这会对训练有影响吗?怎么解决?
我的TensorFlow版本为2.4,keras-bert版本为0.89.0。
CRF层来自tensorflow_addons(版本为0.13.0)的layers模块,即“from tensorflow_addons.layers import CRF”。
代码如下:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
# from tensorflow_addons.layers import CRF
import tensorflow_addons
from tensorflow_addons.layers import CRF
from tf2crf import ModelWithCRFLoss
import os
import keras_bert
class MyModel(object):
    """Builder for a BERT + BiLSTM + CRF sequence-labelling (NER) model.

    Args:
        vocab_size: Stored on the instance; not used by ``build_model``.
        num_class: Number of tag classes; used for the final Dense and CRF layers.
        max_len: Sequence length passed to the BERT checkpoint loader.
        embedding_dim: Stored on the instance; not used by ``build_model``.
        rnn_units: Units of the LSTM wrapped in ``Bidirectional``.
        drop_rate: Dropout rate applied after the BiLSTM.
    """

    def __init__(self, vocab_size: int, num_class: int, max_len: int = 100,
                 embedding_dim: int = 128, rnn_units: int = 128, drop_rate: float = 0.5):
        self.vocab_size = vocab_size
        self.num_class = num_class
        self.max_len = max_len
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        self.drop_rate = drop_rate
        # Pretrained BERT files, resolved relative to the current working directory.
        self.bert_config_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_config.json"
        self.bert_check_point_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_model.ckpt"
        # Not used by build_model.
        self.vocab_path = "./chinese_bert_wwm_L-12_H-768_A-12/vocab.txt"
        # Kept so the attribute exists before build_model is called; build_model
        # replaces it with the CRF layer that is actually wired into the model.
        self.crf = CRF(num_class)

    def build_model(self):
        """Build and return the BERT -> BiLSTM -> Dropout -> Dense -> CRF model.

        A new CRF layer is created on every call and stored on ``self.crf``,
        so ``self.crf`` always refers to the CRF layer of the most recently
        built model.

        Returns:
            A ``Model`` whose inputs are the BERT model's inputs and whose
            outputs are whatever the CRF layer returns.
        """
        # NOTE(review): keras_bert chooses between standalone keras and tf.keras
        # via the TF_KERAS environment variable (see its README); confirm it is
        # set to "1" before keras_bert is imported so its layers mix cleanly with
        # the tensorflow.keras layers below.
        bert = keras_bert.load_trained_model_from_checkpoint(
            config_file=self.bert_config_path,
            checkpoint_file=self.bert_check_point_path,
            seq_len=self.max_len,
            trainable=True,
        )
        inputs = bert.inputs
        embedding = bert.output

        x = Bidirectional(LSTM(units=self.rnn_units, return_sequences=True))(embedding)
        x = Dropout(self.drop_rate)(x)
        # NOTE(review): tensorflow_addons' CRF may apply its own dense projection
        # (use_kernel); confirm whether this Dense layer is redundant.
        x = Dense(self.num_class)(x)

        # Previously a second, anonymous CRF was created here while self.crf stayed
        # an unused, unbuilt layer. Keep one fresh layer per build, but hold a
        # handle to it so callers inspecting self.crf see the layer in the model.
        # NOTE(review): on TF 2.4, tensorflow_addons warns that CRF decoding does
        # not work with KerasTensors when this layer is called on symbolic inputs;
        # verify the decoded output, or use a TF release that contains the fix.
        self.crf = CRF(self.num_class)
        x = self.crf(x)

        return Model(inputs=inputs, outputs=x)
if __name__ == '__main__':
    # Smoke run: build the network with small demo settings and print its layers.
    demo_settings = dict(vocab_size=300, num_class=5, max_len=100)
    mymodel = MyModel(**demo_settings)
    model = mymodel.build_model()
    model.summary()
warning如下: