RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x188 and 94x600)
代码如下
import numpy as np
import random
import math
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from transformers import BertModel, BertConfig
from utils import to_gpu
from utils import ReverseLayerF
def masked_mean(tensor, mask, dim):
"""Finding the mean along dim"""
masked = torch.mul(tensor, mask)
return masked.sum(dim=dim) / mask.sum(dim=dim)
def masked_max(tensor, mask, dim):
"""Finding the max along dim"""
masked = torch.mul(tensor, mask)
neg_inf = torch.zeros_like(tensor)
neg_inf[~mask] = -math.inf
return (masked + neg_inf).max(dim=dim)
# let's define a simple model that can deal with multimodal variable length sequence
class MISA(nn.Module):
def __init__(self, config):
super(MISA, self).__init__()
self.config = config
self.text_size = config.embedding_size
self.visual_size = config.visual_size
self.acoustic_size = config.acoustic_size
self.input_sizes = [self.text_size, self.visual_size, self.acoustic_size]
self.hidden_sizes = [int(self.text_size), int(self.visual_size), int(self.acoustic_size)]
self.output_size = config.num_classes
self.dropout_rate = config.dropout
self.activation = self.config.activation()
self.tanh = nn.Tanh()
rnn = nn.LSTM if self.config.rnncell == "lstm" else nn.GRU
# 处理文本模态
if self.config.use_bert:
bertconfig = BertConfig.from_pretrained('/root/autodl-tmp/MISA1/bert-base-uncased',
output_hidden_states=True)
self.bertmodel = BertModel.from_pretrained('/root/autodl-tmp/MISA1/bert-base-uncased', config=bertconfig)
self.project_t = nn.Linear(768, 600) # BERT输出维度是768,需要调整为600
else:
self.embed = nn.Embedding(len(config.word2id), self.text_size)
self.trnn1 = rnn(self.text_size, self.hidden_sizes[0], bidirectional=True)
self.trnn2 = rnn(2 * self.hidden_sizes[0], self.hidden_sizes[0], bidirectional=True)
# 处理视觉模态
self.vrnn1 = rnn(self.visual_size, self.hidden_sizes[1], bidirectional=True)
self.vrnn2 = rnn(2 * self.hidden_sizes[1], self.hidden_sizes[1], bidirectional=True)
self.project_v = nn.Linear(self.hidden_sizes[1] * 2, 600) # 视觉模态的输出大小
# 处理音频模态
self.arnn1 = rnn(self.acoustic_size, self.hidden_sizes[2], bidirectional=True)
self.arnn2 = rnn(2 * self.hidden_sizes[2], self.hidden_sizes[2], bidirectional=True)
self.project_a = nn.Linear(self.hidden_sizes[2] * 2, 600) # 音频模态的输出大小
# LayerNorm处理
self.tlayer_norm = nn.LayerNorm(self.hidden_sizes[0] * 2)
self.vlayer_norm = nn.LayerNorm(self.hidden_sizes[1] * 2)
self.alayer_norm = nn.LayerNorm(self.hidden_sizes[2] * 2)
# 最后的分类层
self.fc = nn.Linear(600, self.output_size)
def forward(self, text_input, visual_input, acoustic_input):
print("Text input shape:", text_input.shape) # 查看 text_input 的形状
print("Visual input shape:", visual_input.shape) # 查看 visual_input 的形状
print("Acoustic input shape:", acoustic_input.shape) # 查看 acoustic_input 的形状
# BERT 模态处理
text_output = self.bert(text_input)[1] # [CLS] token 的输出
text_output = self.project_t(text_output)
print("Text projected shape:", text_output.shape) # 检查形状
# 视觉模态处理
visual_output = self.visual_encoder(visual_input)
print("Visual output before projection:", visual_output.shape)
visual_output = self.project_v(visual_output)
print("Visual output after projection:", visual_output.shape)
# 音频模态处理
acoustic_output = self.acoustic_encoder(acoustic_input)
print("Acoustic output before projection:", acoustic_output.shape)
acoustic_output = self.project_a(acoustic_output)
print("Acoustic output after projection:", acoustic_output.shape)
print("Acoustic projected shape:", acoustic_output.shape) # 检查形状
# 拼接
combined_output = torch.cat((text_output, visual_output, acoustic_output), dim=-1)
print("Combined output shape:", combined_output.shape) # 检查拼接后的形状
# 最后的分类
logits = self.fc(combined_output)
return logits
##########################################
# mapping modalities to same sized space
##########################################
if self.config.use_bert:
self.project_t = nn.Sequential()
self.project_t.add_module('project_t', nn.Linear(in_features=768, out_features=config.hidden_size))
self.project_t.add_module('project_t_activation', self.activation)
self.project_t.add_module('project_t_layer_norm', nn.LayerNorm(config.hidden_size))
else:
self.project_t = nn.Sequential()
self.project_t.add_module('project_t',
nn.Linear(in_features=hidden_sizes[0] * 4, out_features=config.hidden_size))
self.project_t.add_module('project_t_activation', self.activation)
self.project_t.add_module('project_t_layer_norm', nn.LayerNorm(config.hidden_size))
self.project_v = nn.Sequential()
self.project_v.add_module('project_v',
nn.Linear(in_features=hidden_sizes[1] * 4, out_features=config.hidden_size))
self.project_v.add_module('project_v_activation', self.activation)
self.project_v.add_module('project_v_layer_norm', nn.LayerNorm(config.hidden_size))
self.project_a = nn.Sequential()
self.project_a.add_module('project_a',
nn.Linear(in_features=hidden_sizes[2] * 4, out_features=config.hidden_size))
self.project_a.add_module('project_a_activation', self.activation)
self.project_a.add_module('project_a_layer_norm', nn.LayerNorm(config.hidden_size))
##########################################
# private encoders
##########################################
self.private_t = nn.Sequential()
self.private_t.add_module('private_t_1',
nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.private_t.add_module('private_t_activation_1', nn.Sigmoid())
self.private_v = nn.Sequential()
self.private_v.add_module('private_v_1',
nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.private_v.add_module('private_v_activation_1', nn.Sigmoid())
self.private_a = nn.Sequential()
self.private_a.add_module('private_a_3',
nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.private_a.add_module('private_a_activation_3', nn.Sigmoid())
##########################################
# shared encoder
##########################################
self.shared = nn.Sequential()
self.shared.add_module('shared_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.shared.add_module('shared_activation_1', nn.Sigmoid())
##########################################
# reconstruct
##########################################
self.recon_t = nn.Sequential()
self.recon_t.add_module('recon_t_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.recon_v = nn.Sequential()
self.recon_v.add_module('recon_v_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.recon_a = nn.Sequential()
self.recon_a.add_module('recon_a_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
##########################################
# shared space adversarial discriminator
##########################################
if not self.config.use_cmd_sim:
self.discriminator = nn.Sequential()
self.discriminator.add_module('discriminator_layer_1',
nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
self.discriminator.add_module('discriminator_layer_1_activation', self.activation)
self.discriminator.add_module('discriminator_layer_1_dropout', nn.Dropout(dropout_rate))
self.discriminator.add_module('discriminator_layer_2',
nn.Linear(in_features=config.hidden_size, out_features=len(hidden_sizes)))
##########################################
# shared-private collaborative discriminator
##########################################
self.sp_discriminator = nn.Sequential()
self.sp_discriminator.add_module('sp_discriminator_layer_1',
nn.Linear(in_features=config.hidden_size, out_features=4))
# 修正fusion层的输入维度
self.fusion = nn.Sequential(
nn.Linear(600 * 3, self.config.hidden_size * 3), # 输入维度 600 * 3
nn.Dropout(self.dropout_rate),
self.activation,
nn.Linear(self.config.hidden_size * 3, self.output_size)
)
self.fusion.add_module('fusion_layer_1',
nn.Linear(in_features=600 * 3, out_features=self.config.hidden_size * 3))
self.fusion.add_module('fusion_layer_1_dropout', nn.Dropout(self.dropout_rate))
self.fusion.add_module('fusion_layer_1_activation', self.activation)
self.fusion.add_module('fusion_layer_3',
nn.Linear(in_features=self.config.hidden_size * 3, out_features=self.output_size))
self.tlayer_norm = nn.LayerNorm((hidden_sizes[0] * 2,))
self.vlayer_norm = nn.LayerNorm((hidden_sizes[1] * 2,))
self.alayer_norm = nn.LayerNorm((hidden_sizes[2] * 2,))
encoder_layer = nn.TransformerEncoderLayer(d_model=self.config.hidden_size, nhead=2)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
def extract_features(self, sequence, lengths, rnn1, rnn2, layer_norm):
packed_sequence = pack_padded_sequence(sequence, lengths.cpu())
if self.config.rnncell == "lstm":
packed_h1, (final_h1, _) = rnn1(packed_sequence)
_, (final_h2, _) = rnn2(packed_h1)
else:
packed_h1, final_h1 = rnn1(packed_sequence)
_, final_h2 = rnn2(packed_h1)
return final_h1, final_h2
def alignment(self, sentences, visual, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask):
batch_size = lengths.size(0)
# 如果使用BERT
if self.config.use_bert:
bert_output = self.bertmodel(input_ids=bert_sent,
attention_mask=bert_sent_mask,
token_type_ids=bert_sent_type)
bert_output = bert_output[0]
# mask mean处理
masked_output = torch.mul(bert_sent_mask.unsqueeze(2), bert_output)
mask_len = torch.sum(bert_sent_mask, dim=1, keepdim=True)
bert_output = torch.sum(masked_output, dim=1, keepdim=False) / mask_len
utterance_text = bert_output
else:
sentences = self.embed(sentences)
final_h1t, final_h2t = self.extract_features(sentences, lengths, self.trnn1, self.trnn2, self.tlayer_norm)
utterance_text = torch.cat((final_h1t, final_h2t), dim=2).permute(1, 0, 2).contiguous().view(batch_size, -1)
# 提取视觉模态特征
final_h1v, final_h2v = self.extract_features(visual, lengths, self.vrnn1, self.vrnn2, self.vlayer_norm)
utterance_video = torch.cat((final_h1v, final_h2v), dim=2).permute(1, 0, 2).contiguous().view(batch_size, -1)
# 提取音频模态特征
final_h1a, final_h2a = self.extract_features(acoustic, lengths, self.arnn1, self.arnn2, self.alayer_norm)
utterance_audio = torch.cat((final_h1a, final_h2a), dim=2).permute(1, 0, 2).contiguous().view(batch_size, -1)
# 对每个模态进行投影,确保输出的维度一致
utterance_text = self.project_t(utterance_text)
utterance_video = self.project_v(utterance_video)
utterance_audio = self.project_a(utterance_audio)
# 进行共享/私有编码
self.shared_private(utterance_text, utterance_video, utterance_audio)
# 如果不使用命令相似度
if not self.config.use_cmd_sim:
# 判别器
reversed_shared_code_t = ReverseLayerF.apply(self.utt_shared_t, self.config.reverse_grad_weight)
reversed_shared_code_v = ReverseLayerF.apply(self.utt_shared_v, self.config.reverse_grad_weight)
reversed_shared_code_a = ReverseLayerF.apply(self.utt_shared_a, self.config.reverse_grad_weight)
self.domain_label_t = self.discriminator(reversed_shared_code_t)
self.domain_label_v = self.discriminator(reversed_shared_code_v)
self.domain_label_a = self.discriminator(reversed_shared_code_a)
else:
self.domain_label_t = None
self.domain_label_v = None
self.domain_label_a = None
# 共享或私有编码
self.shared_or_private_p_t = self.sp_discriminator(self.utt_private_t)
self.shared_or_private_p_v = self.sp_discriminator(self.utt_private_v)
self.shared_or_private_p_a = self.sp_discriminator(self.utt_private_a)
self.shared_or_private_s = self.sp_discriminator(
(self.utt_shared_t + self.utt_shared_v + self.utt_shared_a) / 3.0)
# 进行重构
self.reconstruct()
# 进行1层Transformer融合
h = torch.stack((self.utt_private_t, self.utt_private_v, self.utt_private_a,
self.utt_shared_t, self.utt_shared_v, self.utt_shared_a), dim=0)
h = self.transformer_encoder(h)
h = torch.cat((h[0], h[1], h[2], h[3], h[4], h[5]), dim=1)
o = self.fusion(h)
return o
def reconstruct(self, ):
self.utt_t = (self.utt_private_t + self.utt_shared_t)
self.utt_v = (self.utt_private_v + self.utt_shared_v)
self.utt_a = (self.utt_private_a + self.utt_shared_a)
self.utt_t_recon = self.recon_t(self.utt_t)
self.utt_v_recon = self.recon_v(self.utt_v)
self.utt_a_recon = self.recon_a(self.utt_a)
def shared_private(self, utterance_t, utterance_v, utterance_a):
# Projecting to same sized space
self.utt_t_orig = utterance_t = self.project_t(utterance_t)
self.utt_v_orig = utterance_v = self.project_v(utterance_v)
self.utt_a_orig = utterance_a = self.project_a(utterance_a)
# Private-shared components
self.utt_private_t = self.private_t(utterance_t)
self.utt_private_v = self.private_v(utterance_v)
self.utt_private_a = self.private_a(utterance_a)
self.utt_shared_t = self.shared(utterance_t)
self.utt_shared_v = self.shared(utterance_v)
self.utt_shared_a = self.shared(utterance_a)
def forward(self, sentences, video, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask):
batch_size = lengths.size(0)
o = self.alignment(sentences, video, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask)
return o