weixin_58969230 2024-02-12 23:41 采纳率: 0%
浏览 12

基于强化学习的脓毒血症治疗推荐

我使用DQN强化学习模型来做脓毒症患者的治疗推荐,我使用下面的代码进行训练并使用matplotlib绘制损失变化曲线,按照预期的情况下,损失值应当随着epoch次数的增加而逐渐减少并逐渐趋于平稳,可事实却相反,损失值随着epoch增加确实在逐渐升高,我也尝试过调整学习率和一些超参数,可是经过多次尝试结果大致相同,我不知道是哪方面出问题了?应该怎么解决?还有一点,我总感觉我的train()函数是不是有问题,我没什么训练经验,但是几万个epoch十几分钟就训练完了?我的显卡很拉,感觉不正常

import torch
import random
import pandas as pd
from tqdm import  tqdm
import numpy as np

patients_features=["o:gender","o:Weight_kg","o:GCS","o:RR"
    ,"o:Glucose","o:Hb","o:WBC_count","o:Arterial_pH","o:paO2"
    ,"o:paCO2","o:Arterial_lactate","o:SOFA","o:SpO2"
    ,"o:age1","o:MeanBP","o:HR","o:SysBP","o:DiaBP"]

feature_num=len(patients_features)
action_num=25

class data_pool():
    def __init__(self):
        self.sasr=[]
    def __len__(self):
        return len(self.sasr)
    def __getitem__(self,i):
        return self.sasr[i]
    def sample(self,number):
        data=random.sample(self.sasr,number)
        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, feature_num).to(device)
        action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1).to(device)
        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1).to(device)
        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, feature_num).to(device)
        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1).to(device)
        return state, action, reward, next_state,over

    def load_sasr(self,file_name):
        data=pd.read_csv(file_name)
        for index,row in tqdm(data.iterrows()):
            state=row[["s1_" + s for s in patients_features]].tolist()
            action=row['action']
            reward=row['reward']
            next_state=row[["s2_" + s for s in patients_features]].tolist()
            over=row['over']
            self.sasr.append([state,action,reward,next_state,over])


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(feature_num, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
        )

        self.fc_action = torch.nn.Linear(256, action_num)
        self.fc_state = torch.nn.Linear(256, 1)

    def forward(self, state):
        state = self.fc(state)

        #评估state的价值
        value_state = self.fc_state(state)

        #每个state下每个action的价值
        value_action = self.fc_action(state)

        #综合以上两者计算最终的价值,action去均值是为了数值稳定
        return value_state + value_action - value_action.mean(dim=-1,
                                                              keepdim=True)

model = Model().to(device)
model_delay = Model().to(device)

#复制参数
model_delay.load_state_dict(model.state_dict())

#训练
def train(pool):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)  # 设置学习率每10000步衰减0.1倍
    loss_fn = torch.nn.MSELoss()

    batch_size=32
    epoch_number=80000
    #共更新N轮数据

    loss_series=[]

    for epoch in tqdm(range(epoch_number)):
        loss_epoch=0

        #采样N条数据
        state, action, reward, next_state, over = pool.sample(batch_size)

        #计算value
        value = model(state).gather(dim=1, index=action)

        #对value进行限制
        Q_thresh=20
        lambda_value=0.01
        loss_value=lambda_value*torch.clamp(torch.abs(value)-Q_thresh,min=0)
        mean_value_loss=torch.mean(loss_value)

        #计算target
        with torch.no_grad():
            target = model_delay(next_state)
        target = target.max(dim=1)[0].reshape(-1, 1)
        target = target * 0.99 * (1 - over) + reward

        loss = loss_fn(value, target)+mean_value_loss
        loss_epoch=loss.item()
        loss.backward()
        optimizer.step()
        # scheduler.step()
        optimizer.zero_grad()

        if (epoch+1) % 1000==0:
            loss_series.append(loss_epoch)

        #复制参数
        if (epoch + 1) % 10 == 0:
            model_delay.load_state_dict(model.state_dict())

    import matplotlib.pyplot as plt
    plt.plot(range(1000,epoch_number+1,1000),loss_series)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.show()


img

img

  • 写回答

6条回答

  • 叫兽-郭老师 新星创作者: Java技术领域 2024-02-12 23:43
    关注

    参考通义千问和郭老师的小迷弟雅思莫了-编写提供,主要是为了帮助您快速解决问题,如下回答,您可以参考下,如有问题,评论区留言。

    从您提供的代码片段来看,您正在使用DQN(Deep Q Network)强化学习模型对脓毒症治疗进行推荐。您遇到的问题是训练过程中损失函数值随着epoch增加而升高,而不是减少并趋于平稳。

    有几个可能的原因:

    1. 目标网络更新频率:在DQN中,通常需要保持目标网络的参数相对稳定,以确保目标值的稳定性。您的代码中每10个epoch才更新一次目标网络参数,这可能导致目标值剧烈变化,进而影响训练过程。尝试更频繁地同步主网络和目标网络的参数,例如每固定步数或每隔一定时间间隔同步。

    2. 奖励设计:如果奖励函数设计不合理,可能会导致损失函数无法有效收敛。请检查您的奖励函数是否能够正确反映治疗效果的好坏,并且对于改善状态的动作应给予正向奖励,对于恶化状态的动作给予负向奖励。

    3. 探索与利用的平衡:虽然代码片段没有体现epsilon-greedy策略或者其它探索策略,但在实际DQN训练中,一个合适的探索策略也是至关重要的。过高的探索率可能会导致agent随机选择动作,从而影响loss的有效下降。

    4. 学习率调整:尽管您已经尝试了调整学习率,但可能还需要进一步优化。过低的学习率会导致训练速度慢、难以收敛;过高的学习率则可能导致损失震荡。您可以尝试使用学习率衰减策略,比如在代码中的scheduler.step()并未在每个epoch后执行,可以将其移入for循环内部。

    5. Q值剪裁:您在计算loss时加入了对Q值的绝对值剪裁部分的损失,这可能是导致问题的一个因素。这种做法并不常见于标准的DQN算法实现中,它可能会干扰Q值的学习过程。建议先去掉这部分看看训练结果如何。

    6. 模型架构和超参数:除了上述几点,还应该检查模型架构是否合理,以及其它超参数设置,如批次大小、gamma值等是否合适。

    解决这类问题通常需要耐心地调试和实验,逐步排查各个可能的影响因素。

    评论

报告相同问题?

问题事件

  • 修改了问题 2月13日
  • 创建了问题 2月12日