坚决不加班 2020-12-09 17:40 采纳率: 0%
浏览 205

Bert的输出,作为textCNN的输入报错

def call(self, inputs):
        input_ids, input_mask, token_type_ids = inputs
        
        embedding, _ = self.bert_model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        
        textcnn_out = self.textcnn(embedding)
        
        linear_out = self.linear(textcnn_out)
        add_norm = self.layer_norm(add([textcnn_out, linear_out]))
        add_norm = self.dropout(add_norm)
        output = self.out(add_norm)

        return output

bert的输出,直接放入textCNN报错

if x.shape.ndims is None:

    AttributeError: 'str' object has no attribute 'shape'

textCNN代码如下:

class TextCNN(tf.keras.layers.Layer):
    def __init__(self, kernel_sizes=(8, 16, 32), filter_size=256, strides=1, activation='elu', dropout_rate=0.15,
                 **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_sizes
        self.filter_size = filter_size
        self.strides = strides
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.convs = [Conv1D(filters=self.filter_size, kernel_size=kernel_size, strides=self.strides,padding='valid', activation=self.activation)for kernel_size in self.kernel_size]
        self.dropout = Dropout(self.dropout_rate)
        self.globalavgpooling = GlobalAveragePooling1D()

    def call(self, inputs):
        grams = []

        for conv in self.convs:
            gram = conv(inputs)
            gram = self.globalavgpooling(gram)
            gram = self.dropout(gram)
            grams.append(gram)

        out = concatenate(grams, axis=-1)

        return out

    def compute_output_shape(self, input_shape):
        return (None, len(self.kernel_size) * self.filter_size)

求各位大佬帮帮忙!!!救救孩子吧!

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-09-09 18:37
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 以帮助更多的人 ^-^
    评论

报告相同问题?

悬赏问题

  • ¥15 在Starccm中相变材料的物理模型该如何选择?
  • ¥15 关于#android#的问题,请各位专家解答!
  • ¥200 PDF使用虚拟列表技术做渲染和加载带来的问题
  • ¥15 出现报错Debug Assertion Failed,如何解决?
  • ¥50 mcf中怎么实现导入的图片变成透明
  • ¥15 ruoyi-flowable流程设计配置的表单时,级联选择如何配置
  • ¥20 金属玻璃的剪切局部化程度怎么用ovito表示出来
  • ¥15 自定义控件在中文模式下不能输入数字
  • ¥15 关于#运维#的问题:用mail.abc.com 端口9696的方式同样能访问hr.abc.com 端口:6080 的页面两个网址都指向同一个外网ip(相关搜索:服务器)
  • ¥15 xgboost建模输出结果由三分类变成四分类