T2674519437 2024-04-14 22:42 采纳率: 33.3%
浏览 1

python死循环的退出条件


import os
import pickle

import numpy as np
import torch

from net.config import LinXiaoNetConfig
from net import LinXiaoNet
from net.loss.alpha_loss import AlphaLoss
from net.data.train_data_cache import TrainDataCache
from mcts.monte_tree_v2 import MonteTree, transfer_to_net_input, pos_idx2pos_pair
from utils.log import init_logger, logger


def save_chess_record(file_path, record):
    """Pickle one self-play game record to *file_path*.

    Creates the parent directory first if it does not exist yet.
    """
    parent_dir = os.path.dirname(file_path)
    if not os.path.isdir(parent_dir):
        os.makedirs(parent_dir)
    with open(file_path, 'wb+') as fh:
        pickle.dump(record, fh)


def save_checkpoint(save_dir, ep_num, chess_num, model_dict, optimizer_dict, lr_schedule_dict, data_cache):
    """Persist a full training checkpoint into *save_dir*.

    Writes:
      - model.pth / optimizer.pth / lr_schedule.pth  (torch.save of the state dicts)
      - epoch.txt / chess_num.txt                    (counters as plain text)
      - data_cache.pkl                               (pickled training-data cache)

    Args:
        save_dir: target directory, created if missing.
        ep_num: cumulative training-epoch counter.
        chess_num: cumulative self-play game counter.
        model_dict / optimizer_dict / lr_schedule_dict: state dicts to save.
        data_cache: training data cache object (must be picklable).
    """
    os.makedirs(save_dir, exist_ok=True)
    # BUGFIX: original called torch.save(ddel_dict, ...) — `ddel_dict` is an
    # undefined name (typo for the `model_dict` parameter) and raised NameError.
    torch.save(model_dict, os.path.join(save_dir, 'model.pth'))
    torch.save(optimizer_dict, os.path.join(save_dir, 'optimizer.pth'))
    torch.save(lr_schedule_dict, os.path.join(save_dir, 'lr_schedule.pth'))
    with open(os.path.join(save_dir, 'epoch.txt'), 'w+') as f:
        f.write('{}\n'.format(ep_num))
    with open(os.path.join(save_dir, 'chess_num.txt'), 'w+') as f:
        f.write('{}\n'.format(chess_num))
    with open(os.path.join(save_dir, 'data_cache.pkl'), 'wb+') as f:
        pickle.dump(data_cache, f)
    # Tree persistence intentionally disabled (was: pickle tree to tree.pkl).


def load_checkpoint(checkpoint_path):
    """Scan *checkpoint_path* and load whichever checkpoint pieces exist.

    Files are recognized by a substring of their name ('model', 'optimizer',
    'lr_schedule', 'epoch', 'chess_num', 'data_cache'); any piece whose file
    is absent is returned as None.

    Returns:
        (model_state, optimizer_state, lr_schedule_state,
         data_cache, epoch, chess_num)
    """
    keywords = ('model', 'optimizer', 'lr_schedule', 'epoch', 'chess_num', 'data_cache')
    found = {key: None for key in keywords}
    # A filename may match several keywords; like the per-keyword scan it
    # replaces, the last matching file in listing order wins for each key.
    for name in os.listdir(checkpoint_path):
        for key in keywords:
            if key in name:
                found[key] = name

    def _full(name):
        return os.path.join(checkpoint_path, name)

    model_state = None
    if found['model'] is not None:
        model_state = torch.load(_full(found['model']))
    optimizer_state = None
    if found['optimizer'] is not None:
        optimizer_state = torch.load(_full(found['optimizer']))
    lr_schedule_state = None
    if found['lr_schedule'] is not None:
        lr_schedule_state = torch.load(_full(found['lr_schedule']))
    epoch = None
    if found['epoch'] is not None:
        with open(_full(found['epoch']), 'r') as fh:
            epoch = int(fh.readlines()[0].strip())
    games_played = None
    if found['chess_num'] is not None:
        with open(_full(found['chess_num']), 'r') as fh:
            games_played = int(fh.readlines()[0].strip())
    cache = None
    if found['data_cache'] is not None:
        with open(_full(found['data_cache']), 'rb') as fh:
            cache = pickle.load(fh)
    # Tree reloading intentionally disabled (was: unpickle tree.pkl).
    return model_state, optimizer_state, lr_schedule_state, cache, epoch, games_played


def generate_train_data(chess_size, chess_record):
    """Convert one self-play game record into supervised training samples.

    *chess_record* is a sequence of steps in play order, each indexable as
    step[0] = move probability distribution, step[1] = flat position index.
    Every sample pairs the board state *before* the move with that move's
    distribution and the game outcome seen from the player to move.

    Returns:
        list of dicts with keys 'state', 'distribution', 'value'.
    """
    board = np.zeros((chess_size, chess_size))
    samples = []
    side = 1  # black plays first
    # The side that made the last move won: an even number of moves means
    # white (-1) moved last, an odd number means black (+1) did.
    outcome = -1 if len(chess_record) % 2 == 0 else 1
    for step in chess_record:
        pos_idx = step[1]
        samples.append({
            'state': transfer_to_net_input(board, side, chess_size),
            'distribution': step[0],
            'value': outcome,
        })
        # Apply the recorded move to the board.
        row, col = pos_idx2pos_pair(pos_idx, chess_size)
        board[row, col] = side
        # Flip both the side to move and the value label: the reward must be
        # +1 for every position where the eventual winner moved and -1 for
        # every position where the loser moved, so it alternates with the side.
        side = -side
        outcome = -outcome
    return samples


if __name__ == '__main__':

    # Training configuration.
    # set_train_info(5, 16, 1e-2): presumably (epoch_num, batch_size, init_lr)
    # — matches the conf.epoch_num / conf.batch_size / conf.init_lr reads below; verify.
    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')

    init_logger(conf.log_file)
    logger()(conf)

    device = 'cuda' if conf.use_cuda else 'cpu'

    # Build the policy/value network.
    model = LinXiaoNet(3)
    model.to(device)

    loss_func = AlphaLoss()
    loss_func.to(device)

    # SGD with momentum 0.9 and weight decay 5e-4; StepLR multiplies the lr
    # by 0.95 after every scheduler step (stepped once per training round).
    optimizer = torch.optim.SGD(model.parameters(), conf.init_lr, 0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.95)

    # Monte-Carlo search tree used for self-play, plus the training-data cache.
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)

    ep_num = 0
    chess_num = 0
    # Run a training round every `train_every_chess` self-play games.
    train_every_chess = 18

    # Resume from a checkpoint when a pretrained path is configured.
    if conf.pretrain_path is not None:
        # NOTE(review): data_cache is replaced wholesale here; if the checkpoint
        # directory lacks a data_cache file, load_checkpoint returns None for it
        # and the data_cache.push(...) calls below would fail — verify.
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully load pretrained : {}'.format(conf.pretrain_path))

    # NOTE(review): this loop has no exit condition — training runs until the
    # process is interrupted externally (e.g. Ctrl+C / kill).
    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # Play one self-play game and collect its move record.
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # Turn the game record into (state, distribution, value) samples.
        train_data = generate_train_data(tree.chess_size, chess_record)
        # Push the samples into the training cache.
        for i in range(len(train_data)):
            data_cache.push(train_data[i])
        # NOTE(review): chess_num is 0 on the first pass, so a training round
        # also triggers right after the very first game.
        if chess_num % train_every_chess == 0:
            logger()(f'train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()

            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                logger()(f'train epoch {ep_num} loss: {sum(loss_record) / float(len(loss_record))}')
                ep_num += 1
                # Periodic checkpointing, keyed on the cumulative epoch counter.
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num, model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(), data_cache
                    )
            lr_schedule.step()
            logger()(f'train end.')
        chess_num += 1
        # Persist the raw record of every finished game.
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
        # break

    pass

这个Python程序运行时,最后的 while True 循环有退出条件吗?

  • 写回答

5条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-04-14 22:42
    关注

    引自免费微信小程序:皆我百晓生

    这个Python程序中的主循环没有显式的退出条件。它会持续不断地进行自我对弈(self-game),生成训练数据,将数据添加到缓存,并在达到一定数量的棋局后进行训练。循环的执行依赖于程序未被人为中断。

    要给这个无限循环添加一个退出条件,你可以考虑以下方法之一:

    1. 设置最大棋局数:定义一个最大棋局数变量,当chess_num达到这个值时,跳出循环。
    max_chess_num = 10000  # 设置最大的棋局数
    while chess_num < max_chess_num:
        ...
    
    2. 键盘中断(KeyboardInterrupt):允许用户通过按下Ctrl+C来中断程序执行。在这种情况下,你需要捕获KeyboardInterrupt异常并适当地清理资源。
    try:
        while True:
            ...
    except KeyboardInterrupt:
        print("Training interrupted by user.")
        # 在这里添加必要的清理操作,例如保存当前状态等
        exit(0)
    
    3. 根据特定条件退出:你还可以基于其他逻辑条件,比如训练效果达到某个阈值或者达到预设的训练轮数等,来决定何时退出循环。

    选择上述任一方式,并相应地修改代码即可为循环添加退出条件。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月14日

悬赏问题

  • ¥15 ad5933的I2C
  • ¥15 请问RTX4060的笔记本电脑可以训练yolov5模型吗?
  • ¥15 数学建模求思路及代码
  • ¥50 silvaco GaN HEMT有栅极场板的击穿电压仿真问题
  • ¥15 谁会P4语言啊,我想请教一下
  • ¥15 哪个tomcat中startup一直一闪而过 找不出问题
  • ¥15 这个怎么改成直流激励源给加热电阻提供5a电流呀
  • ¥50 求解vmware的网络模式问题 别拿AI回答
  • ¥24 EFS加密后,在同一台电脑解密出错,证书界面找不到对应指纹的证书,未备份证书,求在原电脑解密的方法,可行即采纳
  • ¥15 springboot 3.0 实现Security 6.x版本集成