PPO training does not converge
import os
import glob
import time
from datetime import datetime
import torch
import numpy as np
import gym
from PPODemo import PPO
from gym import spaces
from OrbitalTransfer import *
# Modified to use normalized inputs
import random
from OrbitCore import *

class Environment(gym.Env):
    def __init__(self):
        self.min_action = -1
        self.max_action = 1
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(12,),
                                            dtype=float)  # chaser r, v and evader r, v
        self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(3,),
                                       dtype=float)  # the action gives the velocity-increment components along the three axes (x, y, z)
        self.chase = None   # normalized chaser state
        self.escape = None  # normalized evader state
        self.time = 1
        # torch.set_default_dtype(torch.float64)

    def reset(self):
        self._initialize_positions()           # values in [0, 1]
        observation = self._get_observation()  # values in [0, 1]
        return observation
    def _initialize_positions(self):  # initialize the spacecraft states
        current_time = int(time.time())
        # use the current timestamp as the seed for the random number generator
        random.seed(current_time)
        '''
        Initialization: set a, e, i, omega, w, M0.
        a, i and omega are kept (roughly) the same for both spacecraft.
        '''
        a = random.random()
        i = random.random()
        omega = random.random()
        self.chase = torch.tensor([a, random.random(), i, omega, random.random(), random.random()])  # normalized to [0, 1]
        self.escape = torch.tensor([a, random.random(), i, omega, random.random(), random.random()])
    def step(self, action):  # takes an action (which should also be a tensor) and returns the observation, reward and done flag
        # action = torch.clamp(torch.tensor(action)*0.0001, min=-0.0001, max=0.0001)  # should be changed so the policy outputs a tensor directly; limits the action magnitude
        action = torch.clamp(action * 0.0001, min=-0.0001, max=0.0001)
        # action = torch.clamp(action, min=-1, max=1)  # limit the magnitude
        self.chase = self.update_state(self.chase, action)
        action = torch.tensor([0, 0, 0])
        self.escape = self.update_state(self.escape, action)
        observation = self._get_observation()
        reward = self.reward()
        done = self.check_termination()
        return observation, reward, done, None
    def update_state(self, state, action):  # takes a state normalized to [0, 1] and returns the updated state
        S = self.inverseTrans(state)
        Core = OrbitCore()
        Transfer = OrbitalTransfer()
        r, v = Core.Orbit_Element_2_State_rv(S)
        v = v + action
        Predict = OrbitPredict()
        rv = torch.cat((r.unsqueeze(0), v.unsqueeze(0)), dim=1).view(-1)
        # print("rv", rv)
        # r, v = Predict.J2OrbitRV(rv, 50)  # propagation time set to 50 s for now; returns r, v
        # coe = Core.State_rv_2_Orbit_Element(r, v)  # this function has issues and needs fixing
        # coe2 = Core.State_rv_2_Orbit_Element(r.numpy(), v.numpy())
        # state = self.trans(coe)  # normalize to [0, 1]
        coe = Core.State_rv_2_Orbit_Element(r, v)
        if torch.isnan(coe[5]):
            print("error in main update_state State_rv_2_Orbit_Element")
        # print("coe", coe)  # e > 1 can occur; keep the thrust below 0.0001
        coe = Predict.J2Orbit(coe, 50)
        if torch.isnan(coe[5]):
            print("error in main update_state J2Orbit")
        state = self.trans(coe)
        return state
    def _get_observation(self):
        observation = torch.cat((self.chase.unsqueeze(0), self.escape.unsqueeze(0)), dim=1)
        return observation

    def check_termination(self):
        terminate = False
        chase = self.inverseTrans(self.chase)    # convert from [0, 1] back to the six orbital elements
        escape = self.inverseTrans(self.escape)
        distance = self.distance(chase, escape)
        if distance <= 0.5:
            terminate = True
        return terminate
    def trans(self, state):  # normalize the six orbital elements to [0, 1]
        # maps a from [30000, 45000] to [0, 1] and the angles from degrees to [0, 1]
        a = (state[0] - 30000) / 15000
        e = state[1]
        i = state[2] / 180
        omega = state[3] / 180
        w = state[4] / 360
        M0 = state[5] / 360
        return torch.tensor([a, e, i, omega, w, M0])

    def inverseTrans(self, state):  # convert from [0, 1] back to the six orbital elements
        a = state[0] * 15000 + 30000
        e = state[1]
        i = state[2] * 180
        omega = state[3] * 180
        w = state[4] * 360
        M0 = state[5] * 360
        return torch.tensor([a, e, i, omega, w, M0], dtype=torch.float64)
    def distance(self, state1, state2):
        Core = OrbitCore()
        r1, v1 = Core.Orbit_Element_2_State_rv(state1)  # modified to support tensors; state holds the six orbital elements
        r2, v2 = Core.Orbit_Element_2_State_rv(state2)
        distance = torch.norm(r1 - r2)
        return distance

    def reward(self):
        coe_chase = self.inverseTrans(self.chase)
        coe_escape = self.inverseTrans(self.escape)
        reward = -1 * self.distance(coe_chase, coe_escape)
        return reward
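
# --- Not part of the original script: a minimal sanity-check sketch. ---
# It assumes the repo's OrbitCore / OrbitalTransfer / OrbitPredict modules are
# importable and simply rolls the environment forward with random actions, to
# inspect the observation shape and the reward scale (the reward is the negative
# chaser-evader distance in the same length unit as the orbital elements,
# presumably km, so its magnitude can be large). The helper name
# _rollout_sanity_check is hypothetical.
def _rollout_sanity_check(num_steps=5):
    env = Environment()
    obs = env.reset()
    for step in range(num_steps):
        action = torch.rand(3) * 2 - 1  # uniform in [-1, 1], matching the action space
        obs, reward, done, _ = env.step(action)
        print(step, tuple(obs.shape), float(reward), done)
        if done:
            break
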
################################### Training ###################################
def train():
    # Predict = OrbitPredict()
    print("============================================================================================")
    ####### initialize environment hyperparameters ######
    env_name = "train-orbit-cw"
    has_continuous_action_space = True  # continuous action space; else discrete
    max_ep_len = 1000  # max timesteps in one episode
    max_training_timesteps = int(3e6)  # break training loop if timesteps > max_training_timesteps
    print_freq = max_ep_len * 10  # print avg reward in the interval (in num timesteps)
    log_freq = max_ep_len * 2  # log avg reward in the interval (in num timesteps)
    save_model_freq = int(1e5)  # save model frequency (in num timesteps)
    action_std = 0.6  # starting std for action distribution (Multivariate Normal)
    action_std_decay_rate = 0.05  # linearly decay action_std (action_std = action_std - action_std_decay_rate)
    min_action_std = 0.1  # minimum action_std (stop decay after action_std <= min_action_std)
    action_std_decay_freq = int(2.5e5)  # action_std decay frequency (in num timesteps)
    #####################################################
    ## Note : print/log frequencies should be > than max_ep_len
    ################ PPO hyperparameters ################
    update_timestep = max_ep_len * 4  # update policy every n timesteps
    K_epochs = 80  # update policy for K epochs in one PPO update
    eps_clip = 0.2  # clip parameter for PPO
    gamma = 0.99  # discount factor
    lr_actor = 0.0003  # learning rate for actor network
    lr_critic = 0.001  # learning rate for critic network
    random_seed = 0  # set random seed if required (0 = no random seed)
    #####################################################
print("training environment name : " + env_name)
# env = gym.make(env_name)
env = Environment()
# state space dimension
state_dim = env.observation_space.shape[0]
# action space dimension
if has_continuous_action_space:
action_dim = env.action_space.shape[0]
else:
action_dim = env.action_space.n
###################### logging ######################
#### log files for multiple runs are NOT overwritten
log_dir = "PPO_logs"
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_dir = log_dir + '/' + env_name + '/'
if not os.path.exists(log_dir):
os.makedirs(log_dir)
#### get number of log files in log directory
run_num = 0
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)
#### create new log file for each run
log_f_name = log_dir + '/PPO_' + env_name + "_log_" + str(run_num) + ".csv"
print("current logging run number for " + env_name + " : ", run_num)
print("logging at : " + log_f_name)
    #####################################################
    ################### checkpointing ###################
    log_dir = "PPO_preTrained"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_dir = log_dir + '/' + env_name + '/'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder
    current_num_files = next(os.walk(log_dir))[2]
    run_num_pretrained = len(current_num_files)
    directory = "PPO_preTrained"
    if not os.path.exists(directory):
        os.makedirs(directory)
    directory = directory + '/' + env_name + '/'
    if not os.path.exists(directory):
        os.makedirs(directory)
    checkpoint_path = directory + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num_pretrained)
    print("save checkpoint path : " + checkpoint_path)
    #####################################################
    ############# print all hyperparameters #############
    print("--------------------------------------------------------------------------------------------")
    print("max training timesteps : ", max_training_timesteps)
    print("max timesteps per episode : ", max_ep_len)
    print("model saving frequency : " + str(save_model_freq) + " timesteps")
    print("log frequency : " + str(log_freq) + " timesteps")
    print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")
    print("--------------------------------------------------------------------------------------------")
    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)
    print("--------------------------------------------------------------------------------------------")
    if has_continuous_action_space:
        print("Initializing a continuous action space policy")
        print("--------------------------------------------------------------------------------------------")
        print("starting std of action distribution : ", action_std)
        print("decay rate of std of action distribution : ", action_std_decay_rate)
        print("minimum std of action distribution : ", min_action_std)
        print("decay frequency of std of action distribution : " + str(action_std_decay_freq) + " timesteps")
    else:
        print("Initializing a discrete action space policy")
    print("--------------------------------------------------------------------------------------------")
    print("PPO update frequency : " + str(update_timestep) + " timesteps")
    print("PPO K epochs : ", K_epochs)
    print("PPO epsilon clip : ", eps_clip)
    print("discount factor (gamma) : ", gamma)
    print("--------------------------------------------------------------------------------------------")
    print("optimizer learning rate actor : ", lr_actor)
    print("optimizer learning rate critic : ", lr_critic)
    if random_seed:
        print("--------------------------------------------------------------------------------------------")
        print("setting random seed to ", random_seed)
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)
    #####################################################
    print("============================================================================================")
    ################# training procedure ################
    # initialize a PPO agent
    ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space,
                    action_std)
    # track total training time
    start_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("============================================================================================")
    # logging file
    log_f = open(log_f_name, "w+")
    log_f.write('episode,timestep,reward\n')
    # printing and logging variables
    print_running_reward = 0
    print_running_episodes = 0
    log_running_reward = 0
    log_running_episodes = 0
    time_step = 0
    i_episode = 0
    # training loop
    while time_step <= max_training_timesteps:
        state = env.reset()
        current_ep_reward = 0
        for t in range(1, max_ep_len + 1):
            # select action with policy
            action = ppo_agent.select_action(state.float())
            # print(action)
            state, reward, done, _ = env.step(action)
            # saving reward and is_terminals
            ppo_agent.buffer.rewards.append(reward)
            ppo_agent.buffer.is_terminals.append(done)
            time_step += 1
            current_ep_reward += reward
            # update PPO agent
            if time_step % update_timestep == 0:
                ppo_agent.update()
            # if continuous action space; then decay action std of output action distribution
            if has_continuous_action_space and time_step % action_std_decay_freq == 0:
                ppo_agent.decay_action_std(action_std_decay_rate, min_action_std)
            # log in logging file
            if time_step % log_freq == 0:
                # log average reward till last episode
                log_avg_reward = log_running_reward / log_running_episodes
                # log_avg_reward = round(log_avg_reward, 4)
                log_avg_reward = torch.round(log_avg_reward)
                # write a plain number rather than the tensor repr, whose string contains commas
                log_f.write('{},{},{}\n'.format(i_episode, time_step, float(log_avg_reward)))
                log_f.flush()
                log_running_reward = 0
                log_running_episodes = 0
            # printing average reward
            if time_step % print_freq == 0:
                # print average reward till last episode
                print_avg_reward = print_running_reward / print_running_episodes
                # print_avg_reward = round(print_avg_reward, 2)
                print_avg_reward = torch.round(print_avg_reward)
                print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step,
                                                                                        print_avg_reward))
                print_running_reward = 0
                print_running_episodes = 0
            # save model weights
            if time_step % save_model_freq == 0:
                print("--------------------------------------------------------------------------------------------")
                print("saving model at : " + checkpoint_path)
                ppo_agent.save(checkpoint_path)
                print("model saved")
                print("Elapsed Time : ", datetime.now().replace(microsecond=0) - start_time)
                print("--------------------------------------------------------------------------------------------")
            # break; if the episode is over
            if done:
                break
        print_running_reward += current_ep_reward
        print_running_episodes += 1
        log_running_reward += current_ep_reward
        log_running_episodes += 1
        i_episode += 1
    log_f.close()
    env.close()
    # print total training time
    print("============================================================================================")
    end_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time : ", end_time - start_time)
    print("============================================================================================")


if __name__ == '__main__':
    train()
The above is the training code (it also calls other modules in the project).
The full project is here: https://github.com/YRQhit/PPOTrain/tree/main
The goal is to train a network that outputs velocity increments along three directions so that the chaser approaches a target, but training does not converge.
The training curve is shown below.
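(For reference, the curve can be reproduced from the CSV log the script writes, with header episode,timestep,reward, under PPO_logs/train-orbit-cw/. A minimal plotting sketch, assuming pandas and matplotlib are available and taking run number 0 as an example:

import pandas as pd
import matplotlib.pyplot as plt

# The file name follows the run-number pattern from the logging section; adjust the run number.
log = pd.read_csv("PPO_logs/train-orbit-cw/PPO_train-orbit-cw_log_0.csv")
plt.plot(log["timestep"], log["reward"])
plt.xlabel("timestep")
plt.ylabel("average reward (negative chaser-evader distance)")
plt.show()
)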