cinbol 2023-03-19 18:36 采纳率: 50%
浏览 402
已结题

利用stable_baseline3算法库中的PPO算法训练自定义gym环境

用stable_baseline3的PPO训练自定义gym接口环境,目标如下:
输入(observation_space ):一个shape为2*8的矩阵,矩阵上各元素的值满足一定范围要求,由random随机生成得到
输出(action_space ):一个shape为1*3的矩阵,同样矩阵上各元素的值满足一定范围要求
理想的训练效果:无论输入是啥,输出稳定到[[5,20,200]],(或者比较接近,在这附近波动也行)


目前的问题:目前的情况就是训练不出任何效果;训练得到的权重文件,无论输入是啥,输出都会稳定到最大边界上(例如输出始终是[[-15,60,400]],或者始终是[[15,60,0]],或者始终是[[15,-60,0]]等等),请大家帮忙指点一下:
(解答要求:别复制粘贴GPT,能提供实际可行的建议达到上述理想训练效果)
下面是源码,有stable_baseline3、gym环境可直接运行,如果没有环境,可以在anaconda下pytorch+gym+stable_baseline3环境(分享码gb0v)下载

# -*- coding: utf-8 -*-
import gym
import numpy as np
import random
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
import torch


class gymEnv_CSE(gym.Env):
    """
    输入:一个shape为2*8的矩阵,矩阵上各元素的值满足一定范围要求,由random随机生成得到
    输出:一个shape为1*3的矩阵
    理想的训练效果:无论输入是啥,输出稳定为[[5,20,200]]
    """
    def __init__(self):
        self.observation_space = gym.spaces.Box(low=np.array([[200000, -13000000, -10000, 0, -90, -180, 0, 0] for _ in range(2)]),
                                                high=np.array([[600000, -12600000, 1000, 360, 90, 180, 400, 1] for _ in range(2)]),
                                                shape=(2, 8), dtype=np.float64)

        self.action_space = gym.spaces.Box(low=np.array([[-15, -60, 0]]),
                                           high=np.array([[15, 60, 400]]),
                                           shape=(1, 3), dtype=np.float64)

        self.state = None  # 强化学习输入所需状态
        self.done = None  # 本次实例是否完成
        self.actionRecord = None
        self.step_total = 0  # 累计步长
        self.reward_total = 0  # 累计回报

    def seed(self, seed=None):
        pass

    def reset(self):
        self.state = np.zeros((2, 8))
        self.done = False
        self.actionRecord = None
        self.step_total = 0
        self.reward_total = 0

        self.get_State_From_PlayerDataJson()
        return self.state

    def __del__(self):
        pass

    def get_State_From_PlayerDataJson(self):
        self.state[0, 0] = random.random() * 400000 + 200000
        self.state[0, 1] = random.random() * 400000 - 13000000
        self.state[0, 2] = random.random() * 11000 - 10000
        self.state[0, 3] = random.random() * 360
        self.state[0, 4] = random.random() * 180 - 90
        self.state[0, 5] = random.random() * 360 - 180
        self.state[0, 6] = random.random() * 400
        self.state[0, 7] = 1
        self.state[1, 0] = random.random() * 400000 + 200000
        self.state[1, 1] = random.random() * 400000 - 13000000
        self.state[1, 2] = random.random() * 11000 - 10000
        self.state[1, 3] = random.random() * 360
        self.state[1, 4] = random.random() * 180 - 90
        self.state[1, 5] = random.random() * 360 - 180
        self.state[1, 6] = random.random() * 400
        self.state[1, 7] = 0
        pass

    def step(self, action: np.ndarray):
        # # 从下面这几行可以看出,神经网络从训练刚开始产生的决策就没变过
        # if self.actionRecord is None:
        #     pass
        # else:
        #     if not (self.actionRecord == action).all():
        #         print(self.step_total)
        # self.actionRecord = action

        # 步数记录+1
        self.step_total += 1

        # 更新状态
        self.get_State_From_PlayerDataJson()

        # 计算reward,当输出稳定为[[5,20,200]]左右时的reward最大
        Fa = abs(action[0, 0] - 5)
        Fb = abs(action[0, 1] - 20)
        Fc = abs(action[0, 2] - 200)
        reward = (300 - Fa - Fb - Fc)/300

        # 累积reward
        self.reward_total += reward
        if self.step_total > 200:
            self.done = True
            print(self.reward_total)
        return self.state, reward, self.done, {}

    def close(self):
        pass

    def render(self, mode="human"):
        pass


def linear_schedule(progress_remaining: float):
    return progress_remaining * 0.0005


def stepCallBack(a, b):
    """
    用于保存过程权重
    """
    weightSaveInterval = 200000
    if a["self"].num_timesteps % weightSaveInterval == 0:
        a["self"].save("./PPO_processWeight_IO1/W_" + str(a["self"].num_timesteps))


def train_PPO():
    # Parallel environments
    num_process = 8
    # 这里是多进程并行训练环境
    envList = [gymEnv_CSE for _ in range(num_process)]
    env = SubprocVecEnv(envList)
    # # 单进程环境
    # env = gymEnv_CSE()

    policy_kwargs = dict(activation_fn=torch.nn.ReLU,
                         net_arch=[128, 128, 256, dict(pi=[128, 64], vf=[128, 32])])

    model = PPO(policy="MlpPolicy",  # 选择网络类型,可选MlpPolicy,CnnPolicy,MultiInputPolicy
                env=env,  # Gym中的环境
                learning_rate=linear_schedule,  # 学习率,默认为0.0003
                batch_size=128,  # batch的大小,默认为64
                tensorboard_log="./CSE-TSNR_PPO_tensorboard/",  # tensorboard 的日志文件夹(如果没有,则不记录),默认为None
                policy_kwargs=policy_kwargs,  # 在创建时传递给策略的附加参数,默认为None
                verbose=0,  # 详细级别:0 无输出,1 信息,2 调试,默认为0
                )

    model.learn(total_timesteps=1000000,  # 要训练的环境步数
                callback=stepCallBack,  # 在每一步调用的回调,可以用CheckpointCallback来创建一个存档点和规定存档间隔
                )

    model.save("CSE-TSNR_PPO_IO1")


def run_PPO():
    # Parallel environments
    env = gymEnv_CSE()

    model = PPO.load("CSE-TSNR_PPO_IO1.zip")

    obs = env.reset()
    dones = False
    while not dones:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        print("--------------")
        print(action, obs, rewards, dones, info)
        env.render()

    print("run-PPO:success fin ^ V ^!")


if __name__ == '__main__':
    train_PPO()
    print("========================================================")
    run_PPO()


  • 写回答

10条回答 默认 最新

  • 追cium 2023-03-19 18:42
    关注

    参考GPT和自己的思路:

    根据您提供的代码和问题描述,有几个可能导致训练不成功的原因:

    1. action_space 的范围设置过大,可能导致训练不稳定。建议尝试缩小 action_space 的范围,看看训练效果是否有所改善。

    2. 神经网络结构、learning_rate、batch_size 等参数可能也会对训练效果有很大的影响。建议尝试调整这些参数,看看训练效果是否有所改善。

    3. 还有可能是您的训练数据不够充分,或者训练时间不够长,导致模型没有达到稳定状态。建议增加训练时间,并尝试增加训练数据的数量。

    另外,建议在训练过程中使用一些常用的训练技巧,比如经验回放、dropout、批归一化等,这些技巧可以有效提高训练效果和稳定性。

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 3月21日
  • 修改了问题 3月19日
  • 修改了问题 3月19日
  • 创建了问题 3月19日

悬赏问题

  • ¥100 需要跳转番茄畅听app的adb命令
  • ¥50 寻找一位有逆向游戏盾sdk 应用程序经验的技术
  • ¥15 请问有用MZmine处理 “Waters SYNAPT G2-Si QTOF质谱仪在MSE模式下采集的非靶向数据” 的分析教程吗
  • ¥50 opencv4nodejs 如何安装
  • ¥15 adb push异常 adb: error: 1409-byte write failed: Invalid argument
  • ¥15 nginx反向代理获取ip,java获取真实ip
  • ¥15 eda:门禁系统设计
  • ¥50 如何使用js去调用vscode-js-debugger的方法去调试网页
  • ¥15 376.1电表主站通信协议下发指令全被否认问题
  • ¥15 物体双站RCS和其组成阵列后的双站RCS关系验证