种种粒粒在目 2021-05-16 12:19 采纳率: 0%
浏览 141

pytorch loss一直居高不下 GAT+MLP

class LinearLayer(torch.nn.Module):
    """Two stacked linear layers, each followed by a leaky ReLU.

    Maps the last dimension of the input from ``in_feature`` to
    ``hid_feature`` and then to ``out_feature``. Note that the leaky ReLU
    is also applied to the final output.
    """

    def __init__(self, in_feature, hid_feature, out_feature):
        super().__init__()
        # Creation order is kept (dense1 first) so seeded initialization
        # draws the same random numbers for each layer.
        self.dense1 = torch.nn.Linear(in_feature, hid_feature)
        self.dense2 = torch.nn.Linear(hid_feature, out_feature)

    def forward(self, x):
        """Return ``leaky_relu(dense2(leaky_relu(dense1(x))))``."""
        activate = torch.nn.functional.leaky_relu
        hidden = activate(self.dense1(x))
        return activate(self.dense2(hidden))
class GraphAttentionLayer(torch.nn.Module):
    """Single-head graph attention over a fixed 5-node star graph.

    Node 0 is connected to every node (including itself); nodes 1-4 are
    connected only to node 0. The layer projects every node with a shared
    weight ``W``, computes attention scores with a separate attention
    vector per source node (``a[0, i]``), and returns the aggregated
    feature of node 0 only.

    Args:
        in_feature: Size of each input node feature.
        out_feature: Size of each projected node feature.
        concat: If true, apply ELU to the aggregated features.
        dropout: Drop probability applied to the attention weights while
            the module is in training mode. Defaults to 0.5, which is what
            the bare ``dropout(attention)`` call used previously.
    """

    def __init__(self, in_feature, out_feature, concat=True, dropout=0.5):
        super(GraphAttentionLayer, self).__init__()
        self.concat = concat
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.dropout = dropout

        adjacency = torch.tensor([[1, 1, 1, 1, 1],
                                  [1, 0, 0, 0, 0],
                                  [1, 0, 0, 0, 0],
                                  [1, 0, 0, 0, 0],
                                  [1, 0, 0, 0, 0]])
        # A buffer follows .to(device)/.cuda() together with the parameters;
        # persistent=False keeps it out of state_dict so checkpoints that
        # only contain W and a still load.
        self.register_buffer("adj", adjacency, persistent=False)
        num_nodes = adjacency.size(0)

        # Trainable parameters W and a from the GAT paper. Xavier init is
        # applied to the 2-D slices: on the full 3-D/4-D tensors PyTorch
        # derives fan_in/fan_out from the wrong dimensions and the weights
        # come out mis-scaled.
        weight = torch.empty(1, in_feature, out_feature)
        torch.nn.init.xavier_uniform_(weight[0], gain=1.414)
        self.W = torch.nn.Parameter(weight)

        attn_vectors = torch.empty(1, num_nodes, 2 * out_feature, 1)
        for node in range(num_nodes):
            torch.nn.init.xavier_uniform_(attn_vectors[0, node], gain=1.414)
        self.a = torch.nn.Parameter(attn_vectors)

    def forward(self, x):
        """Aggregate neighbour features into node 0.

        Args:
            x: Tensor of shape ``[batch, 5, in_feature]``.

        Returns:
            Tensor of shape ``[batch, out_feature]`` (the row of node 0).

        Raises:
            ValueError: If ``x`` is not 3-D or its node dimension does not
                match the adjacency matrix.
        """
        num_nodes = self.adj.size(0)
        if x.dim() != 3 or x.size(1) != num_nodes:
            raise ValueError(
                "expected input of shape [batch, %d, in_feature], got %s"
                % (num_nodes, tuple(x.shape))
            )

        # W is [1, in, out]; matmul broadcasts it over the batch, so no
        # explicit repeat is needed.
        h = torch.matmul(x, self.W)  # [batch, N, out_feature]

        # e[b, i, j] = leaky_relu(concat(h_i, h_j) . a_i). The dot product
        # splits into a source half and a destination half, which avoids
        # materialising all N*N concatenated pairs.
        a_src = self.a[0, :, :self.out_feature, 0]  # [N, out_feature]
        a_dst = self.a[0, :, self.out_feature:, 0]  # [N, out_feature]
        src_score = (h * a_src).sum(dim=-1)  # [batch, N]: h_i . a_src[i]
        dst_score = torch.matmul(h, a_dst.t()).transpose(1, 2)  # [batch, i, j]: h_j . a_dst[i]
        e = torch.nn.functional.leaky_relu(src_score.unsqueeze(2) + dst_score)

        # Non-edges get a large negative score so softmax assigns them ~0.
        e = e.masked_fill(self.adj <= 0, -1e12)
        attention = torch.nn.functional.softmax(e, dim=2)  # normalise over neighbours j
        # training=self.training is essential: the functional default is
        # training=True, which would keep dropping attention weights in
        # eval() mode and make inference outputs random.
        attention = torch.nn.functional.dropout(
            attention, p=self.dropout, training=self.training
        )

        h_prime = torch.matmul(attention, h)  # [batch, N, out_feature]
        if self.concat:
            h_prime = torch.nn.functional.elu(h_prime)
        return h_prime[:, 0, :]  # [batch, out_feature]


class Network(torch.nn.Module):
    """Per-node MLP, then graph attention, then an output MLP.

    The input and output widths come from the module-level names
    ``N_STATE`` and ``N_ACTION``; every intermediate width is 64.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.mlp_hid_feature = width
        self.mlp_out_feature = width
        self.att_out_feature = width
        self.out_hid_feature = width
        # Sub-modules are created in the order mlp -> att -> out.
        # [batch, 5, state] -> [batch, 5, mlp_out_feature]
        self.mlp = LinearLayer(
            N_STATE, self.mlp_hid_feature, self.mlp_out_feature
        )
        # [batch, 5, mlp_out_feature] -> [batch, att_out_feature]
        self.att = GraphAttentionLayer(
            self.mlp_out_feature, self.att_out_feature
        )
        # [batch, att_out_feature] -> [batch, action]
        self.out = LinearLayer(
            self.att_out_feature, self.out_hid_feature, N_ACTION
        )

    def forward(self, x):
        """Run ``x`` through ``mlp``, ``att`` and ``out`` in sequence."""
        node_features = self.mlp(x)
        pooled = self.att(node_features)
        return self.out(pooled)

代码如上所示。训练使用的是强化学习 DQN。可以断定是网络结构的问题:之前把网络结构改成简单的全连接时,loss 是可以收敛的;现在中间加了一层 GAT,loss 就一直居高不下。开始时 loss 是零点几,然后慢慢飙升到一千多,并在这个基础上上下震荡……

  • 写回答

1条回答 默认 最新

  • 不是复数 2024-04-09 14:52
    关注

    我也遇到这个问题,想问一下怎么办

    评论

报告相同问题?

悬赏问题

  • ¥15 关于cpci总线的几个问题,有点迷糊
  • ¥15 乳腺癌数据集 相关矩阵 特征选择
  • ¥15 我的游戏账号被盗取,请问我该怎么做
  • ¥15 通关usb3.0.push文件,导致usb频繁断连
  • ¥15 有没有能解决微信公众号只能实时拍照、没有选择相册上传功能的办法?我不懂任何技术,有没有给我发个软件就能搞定的方法
  • ¥15 Python txt 文本可视化
  • ¥15 如何基于Ryu环境下使用scapy包进行数据包构造
  • ¥15 springboot国际化
  • ¥15 搭建QEMU环境运行OP-TEE出现错误
  • ¥15 Minifilter文件保护