行秩列秩矩阵的秩
2022-05-19 21:17

Building a BERT-BiLSTM-CRF model with TF 2.4 produces the warning "CRF Decoding does not work with KerasTensors in TF2.4". How can I fix it?

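Over the past few days I have been building a BERT-BiLSTM-CRF model for NER with TensorFlow 2.4, and I get this warning: "CRF Decoding does not work with KerasTensors in TF2.4. The bug has since been fixed in tensorflow/tensorflow##45534".
The model builds and model.summary() prints normally, but does this warning affect training? How can I fix it?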
My TensorFlow version is 2.4 and my keras_bert version is 0.89.0.
The CRF layer is the one in tensorflow_addons (version 0.13.0) layers, i.e. "from tensorflow_addons.layers import CRF".
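For anyone trying to reproduce this, it can help to confirm which versions are actually installed in the active environment, since the warning text names a specific TensorFlow release. The following is a minimal sketch using only the standard library (Python 3.8+); the distribution names in `PACKAGES` are assumed to match the names on PyPI, and `installed_versions` is a hypothetical helper, not part of any of these libraries.

```python
from importlib import metadata

# Distribution names as published on PyPI (assumed; adjust if your
# environment installed these packages under different names).
PACKAGES = ["tensorflow", "tensorflow-addons", "keras-bert", "tf2crf"]


def installed_versions(packages):
    """Return {package name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Running it inside the same virtual environment as the training script rules out the case where a different TensorFlow build is being picked up than the one expected.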
The code is as follows:

import os

import keras_bert
from tensorflow.keras.layers import Bidirectional, Dense, Dropout, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.layers import CRF
from tf2crf import ModelWithCRFLoss

# Note: os, Adam and ModelWithCRFLoss are imported but not used in this snippet.


class MyModel(object):
    def __init__(self, vocab_size: int, num_class: int, max_len: int = 100,
                 embedding_dim: int = 128, rnn_units: int = 128, drop_rate: float = 0.5):
        self.vocab_size = vocab_size
        self.num_class = num_class
        self.max_len = max_len
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        self.drop_rate = drop_rate
        self.bert_config_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_config.json"
        self.bert_check_point_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_model.ckpt"
        self.vocab_path = "./chinese_bert_wwm_L-12_H-768_A-12/vocab.txt"
        self.crf = CRF(num_class)

    def build_model(self):
        # Load the pretrained BERT encoder with a fixed sequence length.
        bert = keras_bert.load_trained_model_from_checkpoint(
            config_file=self.bert_config_path,
            checkpoint_file=self.bert_check_point_path,
            seq_len=self.max_len,
            trainable=True
        )
        # BiLSTM -> dropout -> per-tag scores on top of the BERT output.
        x = Bidirectional(LSTM(units=self.rnn_units, return_sequences=True))(bert.output)
        x = Dropout(self.drop_rate)(x)
        x = Dense(self.num_class)(x)
        # Reuse the CRF layer created in __init__ instead of building a second one.
        outputs = self.crf(x)
        return Model(inputs=bert.inputs, outputs=outputs)


if __name__ == '__main__':
    mymodel = MyModel(
        vocab_size=300,
        num_class=5,
        max_len=100
    )
    model = mymodel.build_model()
    model.summary()
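Since the warning is about CRF decoding, it may help to spell out what "decoding" means here: given per-position tag scores and tag-to-tag transition scores, find the highest-scoring tag sequence. The sketch below is a plain-Python Viterbi decoder for a single sentence. It is only an illustration of the general technique, not the tensorflow_addons implementation; it ignores batching, masking, and boundary scores, and `viterbi_decode` is a hypothetical name.

```python
def viterbi_decode(emissions, transitions):
    """Return (best tag sequence, its score) for one sentence.

    emissions[t][j]   : score of tag j at position t
    transitions[i][j] : score of moving from tag i to tag j
    """
    num_tags = len(transitions)
    # scores[j] = best score of any path that ends in tag j at the current position
    scores = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_scores, pointers = [], []
        for j in range(num_tags):
            # Best previous tag for arriving at tag j.
            best_prev = max(range(num_tags),
                            key=lambda i: scores[i] + transitions[i][j])
            pointers.append(best_prev)
            new_scores.append(
                scores[best_prev] + transitions[best_prev][j] + emissions[t][j])
        scores = new_scores
        backpointers.append(pointers)

    # Start from the best final tag and follow the backpointers to the start.
    last = max(range(num_tags), key=lambda j: scores[j])
    best_score = scores[last]
    path = [last]
    for pointers in reversed(backpointers):
        last = pointers[last]
        path.append(last)
    path.reverse()
    return path, best_score


if __name__ == "__main__":
    emissions = [[2.0, 0.0], [0.0, 1.0], [0.5, 0.0]]
    transitions = [[0.0, -1.0], [-1.0, 0.0]]
    print(viterbi_decode(emissions, transitions))  # ([0, 0, 0], 2.5)
```

In the model above, the Dense layer plays the role of producing the per-position tag scores, and the CRF layer holds the transition scores.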

The warning is as follows:

[Image: screenshot of the warning]
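Because the screenshot is not readable as text, one way to get the exact message in a copyable form is to record it with Python's standard `warnings` module while the model is being built. The sketch below assumes the message is emitted through `warnings.warn`; `build_model_stub` is a hypothetical stand-in that emits a similar message so the example runs without TensorFlow, and in practice `mymodel.build_model` would be passed in its place. Note that the second helper only hides the message; it does not change how decoding behaves.

```python
import re
import warnings

MESSAGE = "CRF Decoding does not work with KerasTensors in TF2.4"


def build_model_stub():
    """Hypothetical stand-in for a build step that emits the warning."""
    warnings.warn(MESSAGE + ". (stand-in message)")
    return "model"


def build_and_collect_warnings(build_fn):
    """Run build_fn and return (result, list of warning message strings)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure nothing is deduplicated
        result = build_fn()
    return result, [str(w.message) for w in caught]


def build_silently(build_fn):
    """Run build_fn while ignoring warnings whose text starts with MESSAGE."""
    with warnings.catch_warnings():
        # The message argument is a regex matched against the start of the text.
        warnings.filterwarnings("ignore", message=re.escape(MESSAGE))
        return build_fn()


if __name__ == "__main__":
    result, messages = build_and_collect_warnings(build_model_stub)
    for text in messages:
        print("captured:", text)
```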
