种种粒粒在目 2021-05-16 12:19 采纳率: 0%
浏览 150

pytorch loss一直居高不下 GAT+MLP

class LinearLayer(torch.nn.Module):
    def __init__(self,in_feature,hid_feature,out_feature):
        super(LinearLayer,self).__init__()
        self.dense1=torch.nn.Linear(in_feature,hid_feature)
        self.dense2=torch.nn.Linear(hid_feature,out_feature)
    def forward(self,x):
        x=self.dense1(x)
        x=torch.nn.functional.leaky_relu(x)
        x=self.dense2(x)
        x=torch.nn.functional.leaky_relu(x)
        return x
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self,in_feature,out_feature,concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.concat = concat
        self.in_feature=in_feature
        self.out_feature=out_feature
        self.adj = torch.tensor([[1, 1, 1, 1, 1],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0]
                                 ])
        # 定义可训练参数,即论文中的W和a
        self.W = torch.nn.Parameter(torch.zeros(size=(1,in_feature, out_feature)))
        torch.nn.init.xavier_uniform_(self.W.data, gain=1.414)  # xavier初始化
        self.a = torch.nn.Parameter(torch.zeros(size=(1,5, 2 * out_feature, 1)))
        torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)  # xavier初始化
    def forward(self,x):#x:[batch,5,in_feature]
        batch = x.size()[0]
        N = 5
        W_batch=self.W.repeat(batch,1,1)## [batch,in_feature,out_feature]???????????/repeat存疑
        MLP_result = torch.matmul(x, W_batch)  # [batch,5,in_feature]x[batch,in_feature,out_feature]->[batch,5,out_feature]
        adj = self.adj  # [5,5]
        ## 两种不同的repeat方式
        h_i=MLP_result.repeat_interleave(N, dim=1)#[batch,5,out_feature]->[batch,25,out_feature]
        h_j = MLP_result.repeat(1, N, 1)  # [batch,25, out_feature]
        a_input = torch.cat([h_i, h_j], dim=2).view(batch, N, N, 2 * self.out_feature)  # [batch,5*5,out_feature*2]->[batch,5,5,out_feature*2]
        a_batch=self.a.repeat(batch,1,1,1)#[1,5,out_feature*2,1]->[batch,5,out_feature*2,1]
        e = torch.matmul(a_input,a_batch).squeeze(3)  # [batch,5,5,out_feature*2]x[batch,5,out_feature*2,1]->[batch,5,5,1]->[batch,5,5]
        e = torch.nn.functional.leaky_relu(e)
        zero_vec = -1e12 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # [batch,5,5]
        attention = torch.nn.functional.softmax(attention, dim=2)  # 归一化
        attention = torch.nn.functional.dropout(attention)
        h_prime = torch.matmul(attention, MLP_result)  # [batch,N,N]X[batch,N,out_feature]->[batch,N,out_feature]
        print(h_prime.size())
        if self.concat:
            h_prime = torch.nn.functional.elu(h_prime)
        out = h_prime[:, 0, :]  # [batch,out_feature]
        return out


class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.mlp_hid_feature = 64
        self.mlp_out_feature = 64
        self.att_out_feature=64
        self.out_hid_feature=64
        self.mlp=LinearLayer(N_STATE,self.mlp_hid_feature,self.mlp_out_feature)#[batch,5,state]->[batch,5,mlp_out_feature]
        self.att=GraphAttentionLayer(self.mlp_out_feature,self.att_out_feature)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        self.out=LinearLayer(self.att_out_feature,self.out_hid_feature,N_ACTION)#[batch,out_feature]->[batch,action]
        # 定义可训练参数,即论文中的W和a
    def forward(self, x):
        MLP_result = self.mlp(x)  #[batch,5,state]->[batch,5,mlp_out_feature]
        attention_result = self.att(MLP_result)#[batch,5,mlp_out_feature]->[batch,att_out_feature]
        out = self.out(attention_result)#[batch,out_feature]->[batch,action]
        return out

代码如上所示 训练是强化学习DQN 可以断定是网络结构的问题 因为之前把网络结构改成简单的全连接loss是可以收敛的 现在中间加了层GAT,loss就一直居高不下。开始的loss是0.几几 然后慢慢飙升到一千多,在这个基础上开始上下震荡……

  • 写回答

1条回答 默认 最新

  • 无限虚空 2024-06-20 22:23
    关注

    PyTorch中的Loss值持续居高不下可能是因为模型在训练过程中遇到了一些挑战,这通常涉及到网络结构(如GAT(Graph Attention Networks)和MLP(Multi-Layer Perceptron)),优化策略以及数据本身的特性。以下是几个可能的原因及对应的解决方案:

    1. 模型选择与架构:

      • GAT: 如果GAT的注意力机制设置不合理,可能会导致信息抓取不足或过拟合。检查注意力头的数量、权重分配和注意力函数是否合适。
      • MLP: MLP在处理非线性复杂问题时可能表现不佳。尝试增加隐藏层数量、调整激活函数或使用更复杂的网络结构。
    2. 优化器与学习率:

      • 使用Adam或SGD等优化器时,确认学习率是否适当。如果太高可能导致梯度爆炸,如果太低可能导致收敛慢。
      • 考虑使用学习率衰减策略或学习率调度器来动态调整学习率。
    3. 损失函数:

      • 确保选择的损失函数与任务类型匹配。例如,对于分类任务可能使用交叉熵,回归任务用MSE。如果任务涉及多标签或不平衡数据,可能需要调整权重或使用其他特殊损失。
    4. 数据预处理:

      • 数据清洗、归一化或标准化可能影响模型性能。检查输入数据的分布和质量,可能需要进行数据增强或降噪。
    5. 过拟合与正则化:

      • 添加Dropout、L1或L2正则化,或者使用早停策略来防止过拟合。
    6. 批量大小:

      • 批量大小也可能影响训练过程。适当调整批量大小可能有助于模型更好地收敛。
    7. 验证集监控:

      • 定期检查验证集性能,防止在训练集上过度拟合。如果验证集损失持续上升,可能需要调整模型。
    评论

报告相同问题?

悬赏问题

  • ¥15 在若依框架下实现人脸识别
  • ¥15 网络科学导论,网络控制
  • ¥100 安卓tv程序连接SQLSERVER2008问题
  • ¥15 利用Sentinel-2和Landsat8做一个水库的长时序NDVI的对比,为什么Snetinel-2计算的结果最小值特别小,而Lansat8就很平均
  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?
  • ¥15 vite+vue3+plyr播放本地public文件夹下视频无法加载
  • ¥15 c#逐行读取txt文本,但是每一行里面数据之间空格数量不同