m0_61104322 2024-12-26 11:16

Problem extracting drug embeddings

Hi everyone. I'm using a GIN model to convert SMILES data into graphs, and I hit the following error when extracting embeddings. It's mainly a shape mismatch, but I don't know how to fix it. I tried modifying the model, but that didn't solve it. Please give me concrete suggestions and code. Thanks.

```

RuntimeError Traceback (most recent call last)
Cell In [12], line 11
8 with torch.no_grad():
9 for idx, data in enumerate(graph_data_list):
10 # get the model's output and embedding
---> 11 output, embedding = model(data) # returns classification output and embedding
13 # print the output and embedding shape for each batch
14 print(f'Batch {idx} - Output shape: {output.shape}, Embedding shape: {embedding.shape}')

File /shared-libs/python3.9/py/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

Cell In [7], line 30, in GINModel.forward(self, data)
28 def forward(self, data):
29 x, edge_index = data.x, data.edge_index
---> 30 x = self.conv1(x, edge_index)
31 x = self.conv2(x, edge_index)
32 return self.fc(x)

File /shared-libs/python3.9/py/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

Cell In [7], line 19, in GINConv.forward(self, x, edge_index)
17 row, col = edge_index
18 x_out = self.mlp(x)
---> 19 return F.relu(x_out + x[row])

RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0

```


1 answer

  • 道友老李 2024-12-26 11:30
    (This answer was written with reference to the Tongyi Qianwen large model.)
    Hi! Based on the error message you provided, the problem is in the forward pass of the `GINConv` layer: the shapes of `x_out` and `x[row]` don't match. You can resolve this in two steps:
    1. Check the input data shapes: make sure the data fed into the model has the shapes you expect.
    2. Fix the `GINConv` implementation: make the layer aggregate neighbor features so that the tensors it adds together have matching shapes.

    Below is a detailed solution, with code examples and explanations.

    1. Check the input data shapes

    First, make sure the data fed into the model has the right shapes. You can print the shapes of `data.x` and `data.edge_index` to check whether they match your expectations:

    ```
    for idx, data in enumerate(graph_data_list):
        print(f"Batch {idx} - x shape: {data.x.shape}, edge_index shape: {data.edge_index.shape}")
    ```
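    For reference, here is a minimal sketch of what the shapes should look like, assuming your graphs are `torch_geometric.data.Data` objects (the sizes are made up for illustration: 3 nodes with 3 features each and 4 directed edges, matching the numbers in your error message):

    ```
    import torch
    from torch_geometric.data import Data

    # Hypothetical toy graph: 3 nodes, 3 features per node, 4 directed edges.
    x = torch.randn(3, 3)                     # [num_nodes, in_channels]
    edge_index = torch.tensor([[0, 1, 1, 2],  # target nodes (row)
                               [1, 0, 2, 1]]) # source nodes (col)
    data = Data(x=x, edge_index=edge_index)
    print(data.x.shape, data.edge_index.shape)  # torch.Size([3, 3]) torch.Size([2, 4])
    ```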

    2. Fix the GINConv layer implementation

    Suppose your current `GINConv` layer is implemented like this:

    ```
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GINConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(GINConv, self).__init__()
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.ReLU(),
                nn.Linear(out_channels, out_channels)
            )

        def forward(self, x, edge_index):
            row, col = edge_index
            x_out = self.mlp(x)
            # Bug: x_out is [num_nodes, out_channels], x[row] is [num_edges, in_channels]
            return F.relu(x_out + x[row])
    ```

    The two tensors being added must have the same shape, and here they don't: `x_out = self.mlp(x)` has shape `[num_nodes, out_channels]`, while `x[row]` has shape `[num_edges, in_channels]`. Your graph has 3 nodes and 4 edges, which is exactly the reported mismatch at dimension 0 (3 vs. 4). Note that `torch.index_select(x, 0, row)` is equivalent to `x[row]`, so it would not fix anything. The proper fix is the standard GIN update: first sum each node's neighbor features into a `[num_nodes, in_channels]` tensor, add that to the node's own features, and only then apply the MLP.
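    To see the mismatch concretely, here is a two-line reproduction (the sizes are illustrative, taken from the error message):

    ```
    import torch

    x_out = torch.randn(3, 64)  # mlp output: [num_nodes, out_channels]
    x_row = torch.randn(4, 3)   # x[row]:     [num_edges, in_channels]
    # x_out + x_row  # RuntimeError: The size of tensor a (3) must match the size of tensor b (4) ...
    ```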

    The corrected `GINConv` implementation looks like this:

    ```
    import torch
    import torch.nn as nn

    class GINConv(nn.Module):
        def __init__(self, in_channels, out_channels, eps=0.0):
            super(GINConv, self).__init__()
            self.eps = eps
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.ReLU(),
                nn.Linear(out_channels, out_channels)
            )

        def forward(self, x, edge_index):
            row, col = edge_index
            # Sum neighbor features into each node: for every edge (row[e], col[e]),
            # node row[e] receives x[col[e]]. Assumes both edge directions are stored,
            # as is usual for undirected molecular graphs.
            agg = torch.zeros_like(x)
            agg.index_add_(0, row, x[col])
            # GIN update: MLP((1 + eps) * x_i + neighbor sum) -> [num_nodes, out_channels]
            return self.mlp((1 + self.eps) * x + agg)
    ```
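    A quick sanity check on the toy graph from step 1 (illustrative sizes):

    ```
    conv = GINConv(in_channels=3, out_channels=64)
    out = conv(x, edge_index)  # x: [3, 3], edge_index: [2, 4] from the sketch above
    print(out.shape)           # torch.Size([3, 64]) -- one row per node, no mismatch
    ```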

    3. Complete model implementation

    Putting it all together, the full model and the extraction loop look like this:

    ```
    import torch
    import torch.nn as nn

    class GINConv(nn.Module):
        def __init__(self, in_channels, out_channels, eps=0.0):
            super(GINConv, self).__init__()
            self.eps = eps
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.ReLU(),
                nn.Linear(out_channels, out_channels)
            )

        def forward(self, x, edge_index):
            row, col = edge_index
            # Sum neighbor features, then apply the GIN update (see step 2).
            agg = torch.zeros_like(x)
            agg.index_add_(0, row, x[col])
            return self.mlp((1 + self.eps) * x + agg)

    class GINModel(nn.Module):
        def __init__(self, in_channels, hidden_channels, num_classes):
            super(GINModel, self).__init__()
            self.conv1 = GINConv(in_channels, hidden_channels)
            self.conv2 = GINConv(hidden_channels, hidden_channels)
            self.fc = nn.Linear(hidden_channels, num_classes)

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            x = self.conv1(x, edge_index)
            x = self.conv2(x, edge_index)
            return self.fc(x), x  # classification output and node embeddings

    # Example data
    graph_data_list = [your_graph_data_here]

    # Create the model
    model = GINModel(in_channels=3, hidden_channels=64, num_classes=10)

    # Forward pass
    with torch.no_grad():
        for idx, data in enumerate(graph_data_list):
            output, embedding = model(data)  # classification output and node embeddings
            print(f'Batch {idx} - Output shape: {output.shape}, Embedding shape: {embedding.shape}')
    ```
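    One more note on "drug embeddings": the `embedding` returned above is per node, with shape `[num_nodes, hidden_channels]`. If you want a single vector per drug, a common approach (this is an assumption about your goal, not something in your code) is to pool the node embeddings with a readout such as the mean:

    ```
    # Hypothetical readout: collapse node embeddings into one vector per molecule.
    with torch.no_grad():
        for idx, data in enumerate(graph_data_list):
            _, node_emb = model(data)        # [num_nodes, hidden_channels]
            drug_emb = node_emb.mean(dim=0)  # [hidden_channels], one embedding per drug
            print(f'Drug {idx} - embedding shape: {drug_emb.shape}')
    ```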

    4. Testing and debugging

    Run the code above and confirm the shape-mismatch error is gone. If problems remain, print the shapes of your input data and make sure `edge_index` (shape `[2, num_edges]`, with node indices smaller than `num_nodes`) is consistent with `x`.
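    For example, a couple of sanity checks (again assuming PyG-style `Data` objects) can catch malformed graphs early:

    ```
    for idx, data in enumerate(graph_data_list):
        # edge_index must be [2, num_edges] and reference only existing nodes.
        assert data.edge_index.shape[0] == 2, f"graph {idx}: bad edge_index shape"
        assert data.edge_index.max().item() < data.x.shape[0], f"graph {idx}: edge references missing node"
    ```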

    I hope these changes solve the problem! If you have any further questions, feel free to ask.

    This answer was accepted as the best answer by the asker.


Question timeline

  • Closed by system on January 3
  • Answer accepted on December 26
  • Question created on December 26