brand new 223 2024-06-13 09:55 采纳率: 20%
浏览 30
已结题

为什么我版本升级之后运行速度变慢5倍??

我想利用新版本pytorchlightning的功能,所以进行了版本升级,但是升级之后程序运行速度显著变慢,显示是在gpu上运行的,但为什么速度比版本升级前慢5倍呢??下面是我的train代码和环境中的包配置。显示是在gpu上运行的。不知道为什么会出现这种情况。
我的显卡是4090,我只有一个gpu。之前版本的话运行一个epoch需要半小时,现在版本更新后需要两个半小时。
关于精度的问题,之前我的代码就是32位训练的,所以我这个也保持32位精度训练。

torch.set_float32_matmul_precision('high')

parser = argparse.ArgumentParser()
parser.add_argument('--exp_dir', default='exp/tmp', help='Full path to save best validation model')
parser.add_argument('--pretrain', type=ast.literal_eval, required=True, help='whether to pretrain the stage 1 model')

# profiler = PyTorchProfiler(
#     output_filename="profiler.txt",
#     dirpath='exp_dir',
#     activities=[
#         torch.profiler.ProfilerActivity.CPU,
#         torch.profiler.ProfilerActivity.CUDA
#     ],
#     schedule=torch.profiler.schedule(
#         wait=1,
#         warmup=1,
#         active=3,
#         repeat=2
#     ),
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('exp_dir'),
#     record_shapes=True,
#     profile_memory=True,
#     with_stack=True
# )

def main(conf):
    conf['masknet'].update({'n_src': conf['data']['n_src']})
    model, optimizer = make_model_and_optimizer(conf)
    exp_dir = conf['main_args']['exp_dir']

    if os.path.exists(os.path.join(exp_dir, 'precheckpoints/')):
        all_ckpt = os.listdir(os.path.join(exp_dir, 'precheckpoints/'))
        all_ckpt = [(ckpt, int("".join(filter(str.isdigit, ckpt)))) for ckpt in all_ckpt]
        all_ckpt.sort(key=lambda x: x[1])
        best_model_path = os.path.join(exp_dir, 'precheckpoints', all_ckpt[-1][0])
        checkpoint = torch.load(best_model_path, map_location='cpu')
        orig = checkpoint['state_dict']
        model_statedict = model.state_dict()
        for k in orig.keys():
            model_statedict[k[6:]] = orig[k]
        model.load_state_dict(model_statedict, strict=False)
        for param_group in optimizer.param_groups:
            param_group['lr'] = conf['optim']['lr']

    train_loader, val_loader = make_dataloaders(**conf['data'], **conf['training'], channels=slice(0, 4))

    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,  # 实现动态学习率
                                      patience=5)

    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = BFLoss()
    system = BeamTasNetSystem(pretrain=conf['main_args']['pretrain'], new_lr=conf['optim']['lr'],
                              model=model, loss_func=loss_func, optimizer=optimizer,
                              train_loader=train_loader, val_loader=val_loader,
                              scheduler=scheduler,config=conf)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(exp_dir, 'checkpoints/'),
        monitor='val_loss',
        mode='min',
        save_top_k=-1,
        verbose=1
    )

    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    logger = TensorBoardLogger(exp_dir, name="my_model")

    best_model_path = None
    if os.path.exists(os.path.join(exp_dir, 'checkpoints/')):
        all_ckpt = os.listdir(os.path.join(exp_dir, 'checkpoints/'))
        all_ckpt = [(ckpt, int("".join(filter(str.isdigit, ckpt)))) for ckpt in all_ckpt if
                    'ckpt' in ckpt and 'init' not in ckpt]
        if len(all_ckpt) > 0:
            all_ckpt.sort(key=lambda x: x[1])
            best_model_path = os.path.join(exp_dir, 'checkpoints', all_ckpt[-1][0])
    print("resume from {}".format(best_model_path))

    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         enable_checkpointing=True,
                         callbacks=[LearningRateMonitor(), checkpoint_callback],
                         default_root_dir=exp_dir,
                         devices=1 if torch.cuda.is_available() else 0,
                         strategy='auto',
                         accelerator='gpu',
                         gradient_clip_val=5.,
                         logger=logger,
                         # accumulate_grad_batches=6,
                         precision=32,
                         # limit_train_batches=0.1,
                         # limit_val_batches=0.2,
                         # profiler=profiler,
                         )
    print(trainer.strategy.root_device)  # 输出 'gpu''cpu'

    trainer.fit(system, ckpt_path=best_model_path)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = model
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))


if __name__ == '__main__':
    seed_everything(seed=0)

    with open('local/conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    print(arg_dic)
    main(arg_dic)



下面我新增我所有包的版本,
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
asteroid==0.7.0
asteroid-filterbanks==0.4.0
async-timeout==4.0.3
attrs==23.2.0
audioread==3.0.1
beautifulsoup4==4.12.3
Brotli @ file:///C:/b/abs_3d36mno480/croot/brotli-split_1714483178642/work
cached-property==1.5.2
cachetools==5.3.3
certifi @ file:///C:/b/abs_55jxithrm3/croot/certifi_1717618077850/work/certifi
cffi @ file:///C:/b/abs_78eb1_vq6z/croot/cffi_1714483206096/work
charset-normalizer==2.0.12
colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
contourpy==1.1.1
cycler==0.12.1
Cython==3.0.10
decorator==5.1.1
Deprecated==1.2.14
einops==0.8.0
filelock==3.14.0
fonttools==4.53.0
frozenlist==1.4.1
fsspec @ file:///C:/b/abs_d5jzmndkba/croot/fsspec_1714461594761/work
future==1.0.0
google==3.0.0
google-auth==2.29.0
google-auth-oauthlib==1.0.0
grpcio==1.64.1
huggingface-hub==0.23.2
idna @ file:///C:/b/abs_aad84bnnw5/croot/idna_1714398896795/work
importlib_metadata==7.1.0
importlib_resources==6.4.0
intel-openmp==2021.4.0
Jinja2==3.1.3
joblib==1.4.2
julius==0.2.7
kiwisolver==1.4.5
lazy_loader==0.4
librosa==0.10.2.post1
lightning-utilities==0.11.2
llvmlite==0.41.1
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.7.5
mir_eval==0.7
mkl==2021.4.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
networkx==3.0
neural_compressor==2.5.1
numba==0.58.1
numpy==1.24.4
oauthlib==3.2.2
opencv-python-headless==4.10.0.82
packaging==24.1
pandas==2.0.3
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==10.2.0
platformdirs==4.2.2
pooch==1.8.1
prettytable==3.10.0
protobuf==5.27.1
psutil==5.9.8
py-cpuinfo==9.0.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools-windows==2.0.0.2
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pyDeprecate==0.3.1
pyparsing==3.1.2
PySocks @ file:///C:/ci/pysocks_1605287845585/work
pystoi==0.4.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.3.0
pytorch-ranger==0.1.1
pytz==2024.1
PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
requests==2.26.0
requests-oauthlib==2.0.0
rsa==4.9
schema==0.7.7
scikit-learn==1.3.2
scipy==1.10.1
shapely==2.0.4
six==1.16.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
sympy==1.12.1
tbb==2021.12.0
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
thop==0.1.1.post2209072238
threadpoolctl==3.5.0
torch==2.2.0+cu118
torch-complex==0.4.3
torch-optimizer==0.1.0
torch-stoi==0.2.1
torch-tb-profiler==0.4.3
torchaudio==2.2.0+cu118
torchmetrics==0.11.4
torchvision==0.17.0+cu118
tqdm @ file:///C:/b/abs_a7h7n45ccq/croot/tqdm_1716395987133/work
typing_extensions==4.12.2
tzdata==2024.1
urllib3==1.26.18
wcwidth==0.2.13
Werkzeug==3.0.3
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
wrapt==1.16.0
yarl==1.9.4
zipp==3.19.2

我知道我可能有包版本的冲突,但我怎么知道具体是哪几个包有冲突呢?安装包的话,我是该用conda还是pip安装呢??
我之前版本的包是pytoch-lightning==0.7.6,pytorch==1.10.0.

最后再附上我的profiler的分析图。但是不知道为什么,没有gpu的分析,只有cpu,这个问题我也不知道该如何解决。

img

  • 写回答

24条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-06-13 09:57
    关注
    获得1.00元问题酬金

    引自免费微信小程序:皆我百晓生

    面对PyTorch Lightning升级后运行速度明显下降的问题,有几种可能的原因及对应的排查与解决策略。根据提供的信息,我们可以从以下几个方面入手:

    1. 版本兼容性

    确保所有依赖库(如PyTorch、CUDA等)与新版本的PyTorch Lightning兼容。不兼容的版本组合可能会导致性能下降。

    2. 性能分析工具

    你已经尝试在代码中添加了Profiler注释,但似乎没有实际启用。启用性能分析可以帮助识别瓶颈所在。将以下代码段加入到你的训练循环中,确保正确配置Profiler并运行它:

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
        ],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=3,
            repeat=2
        ),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('exp_dir'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for epoch in range(num_epochs):
            for batch in train_loader:
                # Your training step here
                # ...
                prof.step()
    

    3. 检查并优化模型

    • 模型复杂度:检查新版本是否引入了更复杂的模型结构或计算密集型操作。
    • 数据加载:数据预处理和加载时间可能随着版本更新而变化。确保数据加载效率。
    • GPU利用率:确认GPU是否被充分利用。有时,模型的某些部分可能没有在GPU上运行,导致效率低下。

    4. 梯度累积

    你提到的accumulate_grad_batches参数被注释掉了。这个参数可以显著影响内存使用和训练速度。如果模型太大,尝试重新启用它,并调整其值以找到最佳平衡点。

    5. 精度设置

    你设置了precision=32。尝试将精度设置为16,这通常会提高GPU上的训练速度,但要确保模型的数值稳定性不受影响。

    6. 并行策略

    检查strategy='auto'是否选择了最优的并行策略。可能需要手动指定一个更合适的策略,比如strategy='ddp',特别是如果你有多个GPU可用时。

    7. 硬件和驱动

    确保你的GPU驱动是最新的,有时候旧的驱动与新版本的库不兼容,导致性能下降。

    8. 代码审查

    仔细审查你的代码,尤其是与新功能相关的部分,确保没有引入低效的操作或不必要的计算。

    结论

    通过上述步骤,你应该能够定位性能下降的原因,并采取相应的措施进行优化。如果问题仍然存在,考虑回退到之前的版本,直到找到一个稳定的解决方案。同时,也可以在PyTorch Lightning的官方论坛或GitHub上寻求社区的帮助,提供详细的错误报告和性能分析结果。

    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 6月21日
  • 修改了问题 6月17日
  • 修改了问题 6月17日
  • 修改了问题 6月13日
  • 展开全部

悬赏问题

  • ¥15 关于#python#的问题
  • ¥15 主机可以ping通路由器但是连不上网怎么办
  • ¥15 数据库一张以时间排好序的表中,找出多次相邻的那些行
  • ¥50 关于DynamoRIO处理多线程程序时候的问题
  • ¥15 kubeadm部署k8s出错
  • ¥15 Abaqus打不开cae文件怎么办?
  • ¥20 双系统开机引导中windows系统消失问题?
  • ¥15 小程序准备上线,软件开发公司需要提供哪些资料给甲方
  • ¥15 关于生产日期批次退货退款,库存回退的问题
  • ¥15 手机应用的时间可以修改吗