问题与背景
在调用pytorch库的bert encoder报错,报错内容:
TypeError: forward() missing 1 required positional argument: 'attention_mask'
相关代码
from pytorch_transformers import BertModel,BertConfig
class Summarizer(nn.Module):
def __init__(self,args, word_padding_idx, vocab_size, device, checkpoint=None):
self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
config = BertConfig(self.vocab_size)
self.encoder = self.bert.model.encoder(config)
我的解答思路和尝试过的方法
# 替换self.encoder = self.bert.model.encoder(config)
# 方案1
self.encoder = self.bert.model.encoder()
# 方案2
self.encoder = self.bert.model.encoder
# 方案3
self.encoder = self.bert.model.encoder(config)
# 方案3指定固定参数
self.encoder = self.bert.model.encoder(768, [0,1])
思考
bert在实例化self.bert=Bert()时,bert对象已经包含了bert_embedding和bert_encoder等结构。
- 请问此时能否说明模型的编码器已经使用了bert编码器?
- 请问是否需要单独指定模型的编码器self.encoder = self.bert.model.encoder()?
模型结构更直观一些
模型:bert_embedding+transformer_encoder+transformer_decoder
结构:
self{
bert{
embedding()
encoder()
decoder()
}
(encoder)transformer_encoder{}
(decoder)transformer_decoder{}
}
模型:bert_embedding+bert_encoder+transformer_decoder
结构:
self{
bert{
embedding()
encoder() # flag1
decoder()
}
(encoder)bert_encoder{} # flag2 # 这里是否需要单独指定bert_encoder尚不确定,可能会和flag1位置指向同一个地址产生递归
(decoder)transformer_decoder{}
}