I'm using a DQN reinforcement-learning model to make treatment recommendations for sepsis patients. I train with the code below and plot the loss curve with matplotlib. As I understand it, the loss should decrease as the number of epochs grows and then level off, but the opposite happens: the loss keeps rising as training proceeds. I've tried adjusting the learning rate and other hyperparameters, but after many attempts the result is roughly the same, and I can't figure out what's wrong or how to fix it. One more thing: I have a nagging feeling that my train() function itself is broken. I don't have much training experience, but tens of thousands of epochs finish in ten-odd minutes, and my GPU is quite weak, so that doesn't seem normal to me.
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# Pick the GPU when available; `device` is used throughout the code below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
patients_features = ["o:gender", "o:Weight_kg", "o:GCS", "o:RR",
                     "o:Glucose", "o:Hb", "o:WBC_count", "o:Arterial_pH", "o:paO2",
                     "o:paCO2", "o:Arterial_lactate", "o:SOFA", "o:SpO2",
                     "o:age1", "o:MeanBP", "o:HR", "o:SysBP", "o:DiaBP"]
feature_num = len(patients_features)  # 18 state features per time step
action_num = 25  # number of discrete treatment actions
class data_pool():
    """Replay pool holding (state, action, reward, next_state, over) tuples."""
    def __init__(self):
        self.sasr = []

    def __len__(self):
        return len(self.sasr)

    def __getitem__(self, i):
        return self.sasr[i]

    def sample(self, number):
        # Uniformly sample a minibatch and move it to the training device.
        data = random.sample(self.sasr, number)
        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, feature_num).to(device)
        action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1).to(device)
        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1).to(device)
        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, feature_num).to(device)
        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1).to(device)
        return state, action, reward, next_state, over

    def load_sasr(self, file_name):
        # Each CSV row is one transition: the s1_* columns hold the current
        # state, the s2_* columns the next state, plus action, reward and
        # the terminal flag `over`.
        data = pd.read_csv(file_name)
        for index, row in tqdm(data.iterrows(), total=len(data)):
            state = row[["s1_" + s for s in patients_features]].tolist()
            action = row['action']
            reward = row['reward']
            next_state = row[["s2_" + s for s in patients_features]].tolist()
            over = row['over']
            self.sasr.append([state, action, reward, next_state, over])
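For context, this is roughly how I fill the pool and draw one batch; the CSV file name here is only a placeholder for my actual data file:

pool = data_pool()
pool.load_sasr("sepsis_sasr.csv")  # placeholder name; columns follow the s1_*/s2_* layout above
state, action, reward, next_state, over = pool.sample(32)
print(state.shape, action.shape)  # torch.Size([32, 18]) torch.Size([32, 1])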
class Model(torch.nn.Module):
    """Dueling DQN: shared trunk with separate state-value and advantage heads."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(feature_num, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
        )
        self.fc_action = torch.nn.Linear(256, action_num)
        self.fc_state = torch.nn.Linear(256, 1)

    def forward(self, state):
        state = self.fc(state)
        # Estimate the value of the state itself
        value_state = self.fc_state(state)
        # Estimate the value of each action in this state
        value_action = self.fc_action(state)
        # Combine the two into the final Q-values; subtracting the mean of the
        # action values keeps the decomposition numerically stable.
        return value_state + value_action - value_action.mean(dim=-1, keepdim=True)
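To illustrate what this head computes (random inputs, purely a sanity check, not part of my training code): the network outputs one Q-value per action, and the mean-centred advantages sum to roughly zero in each row.

m = Model()                      # fresh CPU copy, illustration only
x = torch.randn(4, feature_num)  # 4 fake patient states
print(m(x).shape)                # torch.Size([4, 25]): one Q-value per action
adv = m.fc_action(m.fc(x))
print((adv - adv.mean(dim=-1, keepdim=True)).sum(dim=-1))  # ~0 for every row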
model = Model().to(device)
model_delay = Model().to(device)
# Copy parameters so the target network starts identical to the online network
model_delay.load_state_dict(model.state_dict())
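In case it's relevant: besides this hard copy (which train() below also repeats every 10 steps), I've read that some DQN implementations use a soft "Polyak" update for the target network instead. A minimal sketch of that variant, where tau = 0.005 is an assumed value rather than something from my current code:

with torch.no_grad():
    tau = 0.005  # assumed smoothing factor, not from my setup
    for p, p_delay in zip(model.parameters(), model_delay.parameters()):
        p_delay.mul_(1 - tau).add_(tau * p)  # theta' <- (1-tau)*theta' + tau*theta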
# Training
def train(pool):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    # Multiply the learning rate by 0.1 every 10000 steps
    # (scheduler.step() is currently commented out below)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)
    loss_fn = torch.nn.MSELoss()
    batch_size = 32
    epoch_number = 80000
    # Run epoch_number update iterations in total
    loss_series = []
    for epoch in tqdm(range(epoch_number)):
        loss_epoch = 0
        # Sample one minibatch of transitions
        state, action, reward, next_state, over = pool.sample(batch_size)
        # Q-value of the action actually taken
        value = model(state).gather(dim=1, index=action)
        # Penalize |Q| values above a threshold
        Q_thresh = 20
        lambda_value = 0.01
        loss_value = lambda_value * torch.clamp(torch.abs(value) - Q_thresh, min=0)
        mean_value_loss = torch.mean(loss_value)
        # Compute the TD target from the target network
        with torch.no_grad():
            target = model_delay(next_state)
            target = target.max(dim=1)[0].reshape(-1, 1)
            target = target * 0.99 * (1 - over) + reward
        loss = loss_fn(value, target) + mean_value_loss
        loss_epoch = loss.item()  # loss of this single batch only
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # scheduler.step()
        # Record the current batch loss every 1000 iterations
        if (epoch + 1) % 1000 == 0:
            loss_series.append(loss_epoch)
        # Copy parameters to the target network every 10 iterations
        if (epoch + 1) % 10 == 0:
            model_delay.load_state_dict(model.state_dict())
    # Plot the recorded loss curve once the loop finishes
    # (loss_series and epoch_number are locals, so this stays inside train())
    plt.plot(range(1000, epoch_number + 1, 1000), loss_series)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.show()
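For completeness, this is essentially how I launch the whole thing (again, the CSV name is a placeholder for my real file):

pool = data_pool()
pool.load_sasr("sepsis_sasr.csv")  # placeholder file name
train(pool)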

