种种粒粒在目 2021-05-16 12:19

PyTorch loss stays stubbornly high with GAT + MLP

import torch


class LinearLayer(torch.nn.Module):
    def __init__(self,in_feature,hid_feature,out_feature):
        super(LinearLayer,self).__init__()
        self.dense1=torch.nn.Linear(in_feature,hid_feature)
        self.dense2=torch.nn.Linear(hid_feature,out_feature)
    def forward(self,x):
        x=self.dense1(x)
        x=torch.nn.functional.leaky_relu(x)
        x=self.dense2(x)
        x=torch.nn.functional.leaky_relu(x)
        return x
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self,in_feature,out_feature,concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.concat = concat
        self.in_feature=in_feature
        self.out_feature=out_feature
        # fixed 5-node adjacency: node 0 is connected to every node (including itself),
        # nodes 1-4 are connected only to node 0
        self.adj = torch.tensor([[1, 1, 1, 1, 1],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0]
                                 ])
        # trainable parameters: W and a from the GAT paper
        self.W = torch.nn.Parameter(torch.zeros(size=(1, in_feature, out_feature)))
        torch.nn.init.xavier_uniform_(self.W.data, gain=1.414)  # Xavier initialization
        self.a = torch.nn.Parameter(torch.zeros(size=(1, 5, 2 * out_feature, 1)))
        torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)  # Xavier initialization
    def forward(self,x):#x:[batch,5,in_feature]
        batch = x.size()[0]
        N = 5
        W_batch = self.W.repeat(batch, 1, 1)  # [batch,in_feature,out_feature] (not sure repeat is the right choice here)
        MLP_result = torch.matmul(x, W_batch)  # [batch,5,in_feature]x[batch,in_feature,out_feature]->[batch,5,out_feature]
        adj = self.adj  # [5,5]
        # two different repeat patterns, to enumerate all (i, j) node pairs
        h_i = MLP_result.repeat_interleave(N, dim=1)  # [batch,5,out_feature]->[batch,25,out_feature]
        h_j = MLP_result.repeat(1, N, 1)  # [batch,25,out_feature]
        a_input = torch.cat([h_i, h_j], dim=2).view(batch, N, N, 2 * self.out_feature)  # [batch,5*5,out_feature*2]->[batch,5,5,out_feature*2]
        a_batch=self.a.repeat(batch,1,1,1)#[1,5,out_feature*2,1]->[batch,5,out_feature*2,1]
        e = torch.matmul(a_input,a_batch).squeeze(3)  # [batch,5,5,out_feature*2]x[batch,5,out_feature*2,1]->[batch,5,5,1]->[batch,5,5]
        e = torch.nn.functional.leaky_relu(e)
        zero_vec = -1e12 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # [batch,5,5]
        attention = torch.nn.functional.softmax(attention, dim=2)  # normalize over neighbours
        attention = torch.nn.functional.dropout(attention)  # note: called with defaults (p=0.5, training=True)
        h_prime = torch.matmul(attention, MLP_result)  # [batch,N,N]X[batch,N,out_feature]->[batch,N,out_feature]
        print(h_prime.size())
        if self.concat:
            h_prime = torch.nn.functional.elu(h_prime)
        out = h_prime[:, 0, :]  # [batch,out_feature]
        return out


class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.mlp_hid_feature = 64
        self.mlp_out_feature = 64
        self.att_out_feature=64
        self.out_hid_feature=64
        self.mlp=LinearLayer(N_STATE,self.mlp_hid_feature,self.mlp_out_feature)#[batch,5,state]->[batch,5,mlp_out_feature]
        self.att=GraphAttentionLayer(self.mlp_out_feature,self.att_out_feature)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        self.out=LinearLayer(self.att_out_feature,self.out_hid_feature,N_ACTION)#[batch,out_feature]->[batch,action]
    def forward(self, x):
        MLP_result = self.mlp(x)  #[batch,5,state]->[batch,5,mlp_out_feature]
        attention_result = self.att(MLP_result)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        out = self.out(attention_result)#[batch,out_feature]->[batch,action]
        return out
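
As a quick sanity check, the classes above can be smoke-tested with a random input. This is only a sketch: N_STATE = 16, N_ACTION = 4 and the batch size of 8 are assumed example values, since the post does not state them.

import torch

N_STATE, N_ACTION = 16, 4           # assumed example values, not from the post
net = Network()
x = torch.randn(8, 5, N_STATE)      # [batch, 5, N_STATE]
q = net(x)                          # forward pass through MLP -> GAT -> MLP
print(q.shape)                      # expected: torch.Size([8, 4])
loss = q.pow(2).mean()              # dummy loss, only to check gradient flow
loss.backward()
print(net.att.W.grad is not None)   # True if gradients reach the GAT parameters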

The code is shown above. Training is reinforcement learning with DQN. I am fairly confident the problem lies in the network structure, because when the network is replaced by a simple fully connected model the loss does converge; after inserting the GAT layer in the middle, the loss stays high. It starts at around 0.x, then slowly climbs to over a thousand and keeps oscillating around that level...
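
One detail in the code that may be relevant: torch.nn.functional.dropout is applied to the attention weights without training=self.training, and the functional form defaults to training=True, so the dropout stays active even after model.eval(), for example in the target network of a DQN. This is only a guess at a contributing factor, not a confirmed diagnosis; a minimal sketch of the difference:

import torch
import torch.nn.functional as F

class Demo(torch.nn.Module):
    def forward(self, x):
        # tying dropout to the module's train()/eval() state
        return F.dropout(x, p=0.5, training=self.training)

x = torch.ones(2, 5)
m = Demo().eval()
print(m(x))          # unchanged: dropout is disabled in eval mode
print(F.dropout(x))  # bare functional call: dropout applied regardless (default training=True)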


1 answer

不是复数 2024-04-09 14:52

    I am running into the same problem and would like to ask how to deal with it.

