brand new 223 2024-06-13 09:55 采纳率: 20%
浏览 45
已结题

为什么我版本升级之后运行速度变慢5倍??

我想利用新版本pytorchlightning的功能,所以进行了版本升级,但是升级之后程序运行速度显著变慢,显示是在gpu上运行的,但为什么速度比版本升级前慢5倍呢??下面是我的train代码和环境中的包配置。显示是在gpu上运行的。不知道为什么会出现这种情况。
我的显卡是4090,我只有一个gpu。之前版本的话运行一个epoch需要半小时,现在版本更新后需要两个半小时。
关于精度的问题,之前我的代码就是32位训练的,所以我这个也保持32位精度训练。

torch.set_float32_matmul_precision('high')

parser = argparse.ArgumentParser()
parser.add_argument('--exp_dir', default='exp/tmp', help='Full path to save best validation model')
parser.add_argument('--pretrain', type=ast.literal_eval, required=True, help='whether to pretrain the stage 1 model')

# profiler = PyTorchProfiler(
#     output_filename="profiler.txt",
#     dirpath='exp_dir',
#     activities=[
#         torch.profiler.ProfilerActivity.CPU,
#         torch.profiler.ProfilerActivity.CUDA
#     ],
#     schedule=torch.profiler.schedule(
#         wait=1,
#         warmup=1,
#         active=3,
#         repeat=2
#     ),
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('exp_dir'),
#     record_shapes=True,
#     profile_memory=True,
#     with_stack=True
# )

def main(conf):
    conf['masknet'].update({'n_src': conf['data']['n_src']})
    model, optimizer = make_model_and_optimizer(conf)
    exp_dir = conf['main_args']['exp_dir']

    if os.path.exists(os.path.join(exp_dir, 'precheckpoints/')):
        all_ckpt = os.listdir(os.path.join(exp_dir, 'precheckpoints/'))
        all_ckpt = [(ckpt, int("".join(filter(str.isdigit, ckpt)))) for ckpt in all_ckpt]
        all_ckpt.sort(key=lambda x: x[1])
        best_model_path = os.path.join(exp_dir, 'precheckpoints', all_ckpt[-1][0])
        checkpoint = torch.load(best_model_path, map_location='cpu')
        orig = checkpoint['state_dict']
        model_statedict = model.state_dict()
        for k in orig.keys():
            model_statedict[k[6:]] = orig[k]
        model.load_state_dict(model_statedict, strict=False)
        for param_group in optimizer.param_groups:
            param_group['lr'] = conf['optim']['lr']

    train_loader, val_loader = make_dataloaders(**conf['data'], **conf['training'], channels=slice(0, 4))

    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,  # 实现动态学习率
                                      patience=5)

    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = BFLoss()
    system = BeamTasNetSystem(pretrain=conf['main_args']['pretrain'], new_lr=conf['optim']['lr'],
                              model=model, loss_func=loss_func, optimizer=optimizer,
                              train_loader=train_loader, val_loader=val_loader,
                              scheduler=scheduler,config=conf)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(exp_dir, 'checkpoints/'),
        monitor='val_loss',
        mode='min',
        save_top_k=-1,
        verbose=1
    )

    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    logger = TensorBoardLogger(exp_dir, name="my_model")

    best_model_path = None
    if os.path.exists(os.path.join(exp_dir, 'checkpoints/')):
        all_ckpt = os.listdir(os.path.join(exp_dir, 'checkpoints/'))
        all_ckpt = [(ckpt, int("".join(filter(str.isdigit, ckpt)))) for ckpt in all_ckpt if
                    'ckpt' in ckpt and 'init' not in ckpt]
        if len(all_ckpt) > 0:
            all_ckpt.sort(key=lambda x: x[1])
            best_model_path = os.path.join(exp_dir, 'checkpoints', all_ckpt[-1][0])
    print("resume from {}".format(best_model_path))

    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         enable_checkpointing=True,
                         callbacks=[LearningRateMonitor(), checkpoint_callback],
                         default_root_dir=exp_dir,
                         devices=1 if torch.cuda.is_available() else 0,
                         strategy='auto',
                         accelerator='gpu',
                         gradient_clip_val=5.,
                         logger=logger,
                         # accumulate_grad_batches=6,
                         precision=32,
                         # limit_train_batches=0.1,
                         # limit_val_batches=0.2,
                         # profiler=profiler,
                         )
    print(trainer.strategy.root_device)  # 输出 'gpu''cpu'

    trainer.fit(system, ckpt_path=best_model_path)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = model
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))


if __name__ == '__main__':
    seed_everything(seed=0)

    with open('local/conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    print(arg_dic)
    main(arg_dic)



下面我新增我所有包的版本,
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
asteroid==0.7.0
asteroid-filterbanks==0.4.0
async-timeout==4.0.3
attrs==23.2.0
audioread==3.0.1
beautifulsoup4==4.12.3
Brotli @ file:///C:/b/abs_3d36mno480/croot/brotli-split_1714483178642/work
cached-property==1.5.2
cachetools==5.3.3
certifi @ file:///C:/b/abs_55jxithrm3/croot/certifi_1717618077850/work/certifi
cffi @ file:///C:/b/abs_78eb1_vq6z/croot/cffi_1714483206096/work
charset-normalizer==2.0.12
colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
contourpy==1.1.1
cycler==0.12.1
Cython==3.0.10
decorator==5.1.1
Deprecated==1.2.14
einops==0.8.0
filelock==3.14.0
fonttools==4.53.0
frozenlist==1.4.1
fsspec @ file:///C:/b/abs_d5jzmndkba/croot/fsspec_1714461594761/work
future==1.0.0
google==3.0.0
google-auth==2.29.0
google-auth-oauthlib==1.0.0
grpcio==1.64.1
huggingface-hub==0.23.2
idna @ file:///C:/b/abs_aad84bnnw5/croot/idna_1714398896795/work
importlib_metadata==7.1.0
importlib_resources==6.4.0
intel-openmp==2021.4.0
Jinja2==3.1.3
joblib==1.4.2
julius==0.2.7
kiwisolver==1.4.5
lazy_loader==0.4
librosa==0.10.2.post1
lightning-utilities==0.11.2
llvmlite==0.41.1
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.7.5
mir_eval==0.7
mkl==2021.4.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
networkx==3.0
neural_compressor==2.5.1
numba==0.58.1
numpy==1.24.4
oauthlib==3.2.2
opencv-python-headless==4.10.0.82
packaging==24.1
pandas==2.0.3
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==10.2.0
platformdirs==4.2.2
pooch==1.8.1
prettytable==3.10.0
protobuf==5.27.1
psutil==5.9.8
py-cpuinfo==9.0.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools-windows==2.0.0.2
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pyDeprecate==0.3.1
pyparsing==3.1.2
PySocks @ file:///C:/ci/pysocks_1605287845585/work
pystoi==0.4.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.3.0
pytorch-ranger==0.1.1
pytz==2024.1
PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
requests==2.26.0
requests-oauthlib==2.0.0
rsa==4.9
schema==0.7.7
scikit-learn==1.3.2
scipy==1.10.1
shapely==2.0.4
six==1.16.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
sympy==1.12.1
tbb==2021.12.0
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
thop==0.1.1.post2209072238
threadpoolctl==3.5.0
torch==2.2.0+cu118
torch-complex==0.4.3
torch-optimizer==0.1.0
torch-stoi==0.2.1
torch-tb-profiler==0.4.3
torchaudio==2.2.0+cu118
torchmetrics==0.11.4
torchvision==0.17.0+cu118
tqdm @ file:///C:/b/abs_a7h7n45ccq/croot/tqdm_1716395987133/work
typing_extensions==4.12.2
tzdata==2024.1
urllib3==1.26.18
wcwidth==0.2.13
Werkzeug==3.0.3
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
wrapt==1.16.0
yarl==1.9.4
zipp==3.19.2

我知道我可能有包版本的冲突,但我怎么知道具体是哪几个包有冲突呢?安装包的话,我是该用conda还是pip安装呢??
我之前版本的包是pytoch-lightning==0.7.6,pytorch==1.10.0.

最后再附上我的profiler的分析图。但是不知道为什么,没有gpu的分析,只有cpu,这个问题我也不知道该如何解决。

img

  • 写回答

24条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-06-13 09:57
    关注
    获得1.00元问题酬金

    引自免费微信小程序:皆我百晓生

    面对PyTorch Lightning升级后运行速度明显下降的问题,有几种可能的原因及对应的排查与解决策略。根据提供的信息,我们可以从以下几个方面入手:

    1. 版本兼容性

    确保所有依赖库(如PyTorch、CUDA等)与新版本的PyTorch Lightning兼容。不兼容的版本组合可能会导致性能下降。

    2. 性能分析工具

    你已经尝试在代码中添加了Profiler注释,但似乎没有实际启用。启用性能分析可以帮助识别瓶颈所在。将以下代码段加入到你的训练循环中,确保正确配置Profiler并运行它:

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
        ],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=3,
            repeat=2
        ),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('exp_dir'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for epoch in range(num_epochs):
            for batch in train_loader:
                # Your training step here
                # ...
                prof.step()
    

    3. 检查并优化模型

    • 模型复杂度:检查新版本是否引入了更复杂的模型结构或计算密集型操作。
    • 数据加载:数据预处理和加载时间可能随着版本更新而变化。确保数据加载效率。
    • GPU利用率:确认GPU是否被充分利用。有时,模型的某些部分可能没有在GPU上运行,导致效率低下。

    4. 梯度累积

    你提到的accumulate_grad_batches参数被注释掉了。这个参数可以显著影响内存使用和训练速度。如果模型太大,尝试重新启用它,并调整其值以找到最佳平衡点。

    5. 精度设置

    你设置了precision=32。尝试将精度设置为16,这通常会提高GPU上的训练速度,但要确保模型的数值稳定性不受影响。

    6. 并行策略

    检查strategy='auto'是否选择了最优的并行策略。可能需要手动指定一个更合适的策略,比如strategy='ddp',特别是如果你有多个GPU可用时。

    7. 硬件和驱动

    确保你的GPU驱动是最新的,有时候旧的驱动与新版本的库不兼容,导致性能下降。

    8. 代码审查

    仔细审查你的代码,尤其是与新功能相关的部分,确保没有引入低效的操作或不必要的计算。

    结论

    通过上述步骤,你应该能够定位性能下降的原因,并采取相应的措施进行优化。如果问题仍然存在,考虑回退到之前的版本,直到找到一个稳定的解决方案。同时,也可以在PyTorch Lightning的官方论坛或GitHub上寻求社区的帮助,提供详细的错误报告和性能分析结果。

    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 6月21日
  • 修改了问题 6月17日
  • 修改了问题 6月17日
  • 修改了问题 6月13日
  • 展开全部

悬赏问题

  • ¥300 寻抓云闪付tn组成网页付款链接
  • ¥15 请问Ubuntu要怎么安装chrome呀?
  • ¥15 视频编码 十六进制问题
  • ¥15 Xsheii7我安装这个文件的时候跳出来另一个文件已锁定文件的无一部分进程无法访问。这个该怎么解决
  • ¥15 unity terrain打包后地形错位,跟建筑不在同一个位置,怎么办
  • ¥15 FileNotFoundError 解决方案
  • ¥15 uniapp实现如下图的图表功能
  • ¥15 u-subsection如何修改相邻两个节点样式
  • ¥30 vs2010开发 WFP(windows filtering platform)
  • ¥15 服务端控制goose报文控制块的发布问题