行秩列秩矩阵的秩 2022-05-19 21:17 采纳率: 50%
浏览 143
已结题

使用TF2.4 构建bert-Bilstm-crf,出现这样的warning:“CRF Decoding does not work with KerasTensors in TF2.4”,怎么解决?

这几天在利用TensorFlow2.4构建bert-Bilstm-crf做NER时,出现这样的warning:“CRF Decoding does not work with KerasTensors in TF2.4. The bug has since been fixed in tensorflow/tensorflow##45534”。
虽然模型能够建起来,summary也能看到,但这会对训练有影响吗?怎么解决?
我的TensorFlow版本为2.4,keras2bert版本为0.89.0
CRF用的是tensorflow_addons(版本为0.13.0)的layers里面的,即“from tensorflow_addons.layers import CRF”
代码如下:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
# from tensorflow_addons.layers import CRF
import tensorflow_addons
from tensorflow_addons.layers import CRF

from tf2crf import ModelWithCRFLoss
import os
import keras_bert


class MyModel(object):
    def __init__(self, vocab_size: int, num_class: int, max_len: int = 100,
                 embedding_dim: int = 128, rnn_units: int = 128, drop_rate: float = 0.5):
        self.vocab_size = vocab_size
        self.num_class = num_class
        self.max_len = max_len
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        self.drop_rate = drop_rate
        self.bert_config_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_config.json"
        self.bert_check_point_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_model.ckpt"
        self.vocab_path = "./chinese_bert_wwm_L-12_H-768_A-12/vocab.txt"
        self.crf = CRF(num_class)

    def build_model(self):
        model = keras_bert.load_trained_model_from_checkpoint(
            config_file=self.bert_config_path,
            checkpoint_file=self.bert_check_point_path,
            seq_len=self.max_len,
            trainable=True
        )
        inputs = model.inputs
        embedding = model.output
        x = Bidirectional(LSTM(units=self.rnn_units, return_sequences=True))(embedding)
        x = Dropout(self.drop_rate)(x)
        x = Dense(self.num_class)(x)
        x = CRF(self.num_class)(x)
        model = Model(inputs=inputs, outputs=x)
        return model


if __name__ == '__main__':
    mymodel = MyModel(
        vocab_size=300,
        num_class=5,
        max_len=100
    )
    model = mymodel.build_model()
    model.summary()

warning如下:

img

  • 写回答

3条回答 默认 最新

查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 5月31日
  • 已采纳回答 5月23日
  • 创建了问题 5月19日

悬赏问题

  • ¥15 Stata链式中介效应代码修改
  • ¥15 latex投稿显示click download
  • ¥15 请问读取环境变量文件失败是什么原因?
  • ¥15 在若依框架下实现人脸识别
  • ¥15 网络科学导论,网络控制
  • ¥100 安卓tv程序连接SQLSERVER2008问题
  • ¥15 利用Sentinel-2和Landsat8做一个水库的长时序NDVI的对比,为什么Snetinel-2计算的结果最小值特别小,而Lansat8就很平均
  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错