subscribers 2022-01-06 15:42 采纳率: 57.1%
浏览 48
已结题

pytorch与PyG


self.gat1 = GATConv(self.input_dim, self.hidden_dim, dropout=0.5, heads=self.heads)
self.gat2 = GATConv(self.hidden_dim*self.heads, self.out_dim, dropout=0.5, heads=1, concat=False)

self.inner_attention = nn.Sequential(
            self.gat1,
            nn.ReLU(),
            nn.Dropout(p=0.6),
            self.gat2,
            nn.ReLU()
        )

# 请问我这样调用的时候为什么会报错啊?谢谢
x, edge_index = graph_data.x, graph_data.edge_index
out = self.inner_attention()(x, edge_index)


TypeError: forward() missing 1 required positional argument: 'input'
  • 写回答

2条回答 默认 最新

  • 爱晚乏客游 2022-01-07 10:23
    关注

    我还以为是你自己实现的网络结构呢。。如果是自带的话,那就是输入格式的问题了。具体你可以看看这个,Sequential多输入需要封装一下。
    https://blog.csdn.net/qq_23968185/article/details/108277724
    或者你可以直接写在forward()里面,不用Sequential。

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.gat1 = GATConv(self.input_dim, self.hidden_dim, dropout=0.5, heads=self.heads)
            self.gat2 = GATConv(self.hidden_dim * self.heads, self.out_dim, dropout=0.5, heads=1, concat=False)
        
        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            x = self.gat1(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
            x = self.gat2(x, edge_index)
            x = F.relu(x)
            return x #返回什么自己确定下
    edge_index = torch.tensor([[0, 1, 1, 2],[1, 0, 2, 1]], dtype=torch.long)
    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
    data = Data(x=x, edge_index=edge_index)
    model=Net()
    out=model(data)
    print(out)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 1月15日
  • 已采纳回答 1月7日
  • 创建了问题 1月6日

悬赏问题

  • ¥15 2020长安杯与连接网探
  • ¥15 关于#matlab#的问题:在模糊控制器中选出线路信息,在simulink中根据线路信息生成速度时间目标曲线(初速度为20m/s,15秒后减为0的速度时间图像)我想问线路信息是什么
  • ¥15 banner广告展示设置多少时间不怎么会消耗用户价值
  • ¥16 mybatis的代理对象无法通过@Autowired装填
  • ¥15 可见光定位matlab仿真
  • ¥15 arduino 四自由度机械臂
  • ¥15 wordpress 产品图片 GIF 没法显示
  • ¥15 求三国群英传pl国战时间的修改方法
  • ¥15 matlab代码代写,需写出详细代码,代价私
  • ¥15 ROS系统搭建请教(跨境电商用途)