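Hi everyone, I converted SMILES data into graph data and then used a GIN model to extract embeddings, but I ran into the error below. It appears to be a shape problem, but I don't know how to fix it. I tried modifying the model, but that didn't solve it. Could you give me concrete suggestions and code for the fix? Thanks.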
```
RuntimeError Traceback (most recent call last)
Cell In [12], line 11
8 with torch.no_grad():
9 for idx, data in enumerate(graph_data_list):
10 # Get the model's output and embedding
---> 11 output, embedding = model(data) # returns the classification output and the embedding
13 # Print the output and embedding shapes for each batch
14 print(f'Batch {idx} - Output shape: {output.shape}, Embedding shape: {embedding.shape}')
File /shared-libs/python3.9/py/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
Cell In [7], line 30, in GINModel.forward(self, data)
28 def forward(self, data):
29 x, edge_index = data.x, data.edge_index
---> 30 x = self.conv1(x, edge_index)
31 x = self.conv2(x, edge_index)
32 return self.fc(x)
File /shared-libs/python3.9/py/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
Cell In [7], line 19, in GINConv.forward(self, x, edge_index)
17 row, col = edge_index
18 x_out = self.mlp(x)
---> 19 return F.relu(x_out + x[row])
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0
```
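Reading the last frame of the traceback: `self.mlp(x)` has one row per node, while `x[row]` has one row per entry of `edge_index` (one per directed edge). The addition therefore fails whenever the node count (apparently 3 here) differs from the edge count (apparently 4 here). In GIN, each node's neighbor features are summed onto that node first, and the MLP is applied to `(1 + eps) * h_v + sum of neighbors`. Below is a minimal plain-PyTorch sketch along those lines, not a definitive fix. Assumptions: `edge_index` follows the PyTorch Geometric `[2, E]` convention with both directions of each bond listed. The layer sizes, the sum pooling, and the toy graph are hypothetical. `forward` returns `(output, embedding)` because the calling loop unpacks two values. The original `forward` returns only `self.fc(x)`, so that unpacking would typically be a second problem once the shape error is gone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace


class GINConv(nn.Module):
    """GIN layer sketch: h_v' = MLP((1 + eps) * h_v + sum of h_u over neighbors u of v)."""

    def __init__(self, in_dim, out_dim, eps=0.0):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim)
        )
        self.eps = eps

    def forward(self, x, edge_index):
        # x: [N, in_dim], edge_index: [2, E]. x[src] has E rows (one per edge),
        # so it is summed back onto the N target nodes before being added to x.
        src, dst = edge_index
        agg = torch.zeros_like(x).index_add(0, dst, x[src])  # [N, in_dim]
        return F.relu(self.mlp((1 + self.eps) * x + agg))    # [N, out_dim]


class GINModel(nn.Module):
    # Hypothetical sizes; conv1/conv2/fc mirror the names seen in the traceback.
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GINConv(in_dim, hidden_dim)
        self.conv2 = GINConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        # Sum-pool node features into one embedding per graph. Without a
        # `batch` vector, every node is treated as belonging to graph 0.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = int(batch.max()) + 1
        embedding = x.new_zeros((num_graphs, x.size(1))).index_add(0, batch, x)
        output = self.fc(embedding)
        # Return both values, since the calling loop unpacks `output, embedding = model(data)`.
        return output, embedding


# Hypothetical toy graph: 3 nodes and 4 directed edges (0-1 and 1-2 in both
# directions), the same 3-vs-4 sizes that appear in the error message.
data = SimpleNamespace(
    x=torch.randn(3, 5),
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
)
model = GINModel(in_dim=5, hidden_dim=16, num_classes=2)
model.eval()
with torch.no_grad():
    output, embedding = model(data)
print(output.shape, embedding.shape)  # torch.Size([1, 2]) torch.Size([1, 16])
```

If `graph_data_list` holds PyTorch Geometric `Data` objects, they can be passed in place of the `SimpleNamespace`, since this model only reads `x`, `edge_index` and, if present, `batch`. PyTorch Geometric also ships `torch_geometric.nn.GINConv` and `global_add_pool`, which perform the same neighbor aggregation and sum pooling.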