种种粒粒在目 2021-05-16 12:19 采纳率: 0%
浏览 123

pytorch loss一直居高不下 GAT+MLP

class LinearLayer(torch.nn.Module):
    def __init__(self,in_feature,hid_feature,out_feature):
        super(LinearLayer,self).__init__()
        self.dense1=torch.nn.Linear(in_feature,hid_feature)
        self.dense2=torch.nn.Linear(hid_feature,out_feature)
    def forward(self,x):
        x=self.dense1(x)
        x=torch.nn.functional.leaky_relu(x)
        x=self.dense2(x)
        x=torch.nn.functional.leaky_relu(x)
        return x
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self,in_feature,out_feature,concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.concat = concat
        self.in_feature=in_feature
        self.out_feature=out_feature
        self.adj = torch.tensor([[1, 1, 1, 1, 1],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0]
                                 ])
        # 定义可训练参数,即论文中的W和a
        self.W = torch.nn.Parameter(torch.zeros(size=(1,in_feature, out_feature)))
        torch.nn.init.xavier_uniform_(self.W.data, gain=1.414)  # xavier初始化
        self.a = torch.nn.Parameter(torch.zeros(size=(1,5, 2 * out_feature, 1)))
        torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)  # xavier初始化
    def forward(self,x):#x:[batch,5,in_feature]
        batch = x.size()[0]
        N = 5
        W_batch=self.W.repeat(batch,1,1)## [batch,in_feature,out_feature]???????????/repeat存疑
        MLP_result = torch.matmul(x, W_batch)  # [batch,5,in_feature]x[batch,in_feature,out_feature]->[batch,5,out_feature]
        adj = self.adj  # [5,5]
        ## 两种不同的repeat方式
        h_i=MLP_result.repeat_interleave(N, dim=1)#[batch,5,out_feature]->[batch,25,out_feature]
        h_j = MLP_result.repeat(1, N, 1)  # [batch,25, out_feature]
        a_input = torch.cat([h_i, h_j], dim=2).view(batch, N, N, 2 * self.out_feature)  # [batch,5*5,out_feature*2]->[batch,5,5,out_feature*2]
        a_batch=self.a.repeat(batch,1,1,1)#[1,5,out_feature*2,1]->[batch,5,out_feature*2,1]
        e = torch.matmul(a_input,a_batch).squeeze(3)  # [batch,5,5,out_feature*2]x[batch,5,out_feature*2,1]->[batch,5,5,1]->[batch,5,5]
        e = torch.nn.functional.leaky_relu(e)
        zero_vec = -1e12 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # [batch,5,5]
        attention = torch.nn.functional.softmax(attention, dim=2)  # 归一化
        attention = torch.nn.functional.dropout(attention)
        h_prime = torch.matmul(attention, MLP_result)  # [batch,N,N]X[batch,N,out_feature]->[batch,N,out_feature]
        print(h_prime.size())
        if self.concat:
            h_prime = torch.nn.functional.elu(h_prime)
        out = h_prime[:, 0, :]  # [batch,out_feature]
        return out


class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.mlp_hid_feature = 64
        self.mlp_out_feature = 64
        self.att_out_feature=64
        self.out_hid_feature=64
        self.mlp=LinearLayer(N_STATE,self.mlp_hid_feature,self.mlp_out_feature)#[batch,5,state]->[batch,5,mlp_out_feature]
        self.att=GraphAttentionLayer(self.mlp_out_feature,self.att_out_feature)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        self.out=LinearLayer(self.att_out_feature,self.out_hid_feature,N_ACTION)#[batch,out_feature]->[batch,action]
        # 定义可训练参数,即论文中的W和a
    def forward(self, x):
        MLP_result = self.mlp(x)  #[batch,5,state]->[batch,5,mlp_out_feature]
        attention_result = self.att(MLP_result)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        out = self.out(attention_result)#[batch,out_feature]->[batch,action]
        return out

代码如上所示 训练是强化学习DQN 可以断定是网络结构的问题 因为之前把网络结构改成简单的全连接loss是可以收敛的 现在中间加了层GAT,loss就一直居高不下。开始的loss是0.几几 然后慢慢飙升到一千多,在这个基础上开始上下震荡……

  • 写回答

0条回答 默认 最新

      报告相同问题?

      相关推荐 更多相似问题

      悬赏问题

      • ¥20 新闻小程序6万人在线
      • ¥15 Fluent轴流风扇模拟
      • ¥15 基于GPS的自行车定位系统设计
      • ¥15 idea中安装matplotlib模块完成,运行还是显示无安装
      • ¥15 robotframework 运行报错
      • ¥60 C# (VS2015) 用HttpWebRequest get 方式 与 post 方式
      • ¥30 yolo侦测mammogram总是没有好结果,求经验
      • ¥380 網頁顯示MT4後台數據
      • ¥20 Pyqt5如何实现对指定窗口调用显示视频信号
      • ¥15 ResNET50修改参数