我想使用新版本 PyTorch Lightning 的功能,所以升级了版本。但升级之后程序运行速度显著变慢:日志显示确实是在 GPU 上运行的,为什么速度却比升级前慢了 5 倍?下面是我的训练代码和环境中的包配置。我不知道为什么会出现这种情况。
我的显卡是 RTX 4090,只有一块 GPU。升级前运行一个 epoch 需要半小时,升级后需要两个半小时。
关于精度:升级前我的代码就是用 32 位(FP32)精度训练的,所以升级后我也保持 32 位精度训练。
# Allow float32 matrix multiplications to use reduced-precision internal
# computation (TF32 tensor cores on GPUs that support it) for speed.
# NOTE(review): per the PyTorch "CUDA semantics / TensorFloat-32" docs,
# PyTorch 1.7-1.11 allowed TF32 matmuls by default and 1.12+ does not, so this
# line restores the PyTorch 1.10 behaviour for matmuls -- confirm for your version.
torch.set_float32_matmul_precision('high')

parser = argparse.ArgumentParser()
parser.add_argument(
    '--exp_dir',
    default='exp/tmp',
    help='Full path to save best validation model',
)
# ast.literal_eval turns the command-line strings "True"/"False" into real
# booleans (type=bool would treat any non-empty string, including "False", as True).
parser.add_argument(
    '--pretrain',
    required=True,
    type=ast.literal_eval,
    help='whether to pretrain the stage 1 model',
)

# Optional profiler (currently disabled).
# NOTE(review): `dirpath='exp_dir'` and `tensorboard_trace_handler('exp_dir')`
# below use the literal string 'exp_dir', not the --exp_dir value; build the
# profiler inside main() if the traces should land in the experiment directory.
# profiler = PyTorchProfiler(
#     output_filename="profiler.txt",
#     dirpath='exp_dir',
#     activities=[
#         torch.profiler.ProfilerActivity.CPU,
#         torch.profiler.ProfilerActivity.CUDA
#     ],
#     schedule=torch.profiler.schedule(
#         wait=1,
#         warmup=1,
#         active=3,
#         repeat=2
#     ),
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('exp_dir'),
#     record_shapes=True,
#     profile_memory=True,
#     with_stack=True
# )
def _find_latest_checkpoint(ckpt_dir, keep=None):
    """Return the path of the newest checkpoint file in ``ckpt_dir``, or None.

    "Newest" is decided from the integers embedded in each file name, compared
    as a tuple -- e.g. ``epoch=9-step=900.ckpt`` -> ``(9, 900)`` -- so that
    numeric fields of different widths compare correctly (``(10, 5)`` sorts
    after ``(2, 100)``; concatenated digits would give 105 < 2100). Names that
    contain no digit cannot be ordered and are skipped.

    Args:
        ckpt_dir: Directory to scan. A missing directory yields None.
        keep: Optional predicate on the bare file name; names for which it
            returns False are ignored.

    Returns:
        Full path of the selected file, or None if no file qualifies.
    """
    import re  # local import: only needed for checkpoint-name parsing

    if not os.path.isdir(ckpt_dir):
        return None
    candidates = []
    for name in os.listdir(ckpt_dir):
        numbers = tuple(int(digits) for digits in re.findall(r'\d+', name))
        if numbers and (keep is None or keep(name)):
            candidates.append((numbers, name))
    if not candidates:
        return None
    return os.path.join(ckpt_dir, max(candidates)[1])


def _load_stage1_weights(model, ckpt_path, prefix='model.'):
    """Copy weights from a Lightning checkpoint file into ``model`` (non-strict).

    Only entries of the checkpoint's ``state_dict`` whose key starts with
    ``prefix`` are used; the prefix is stripped so the key matches
    ``model.state_dict()``. Other entries are skipped.
    NOTE(review): assumes the stage-1 Lightning system stored the network as
    ``self.model`` (hence the default prefix ``'model.'``) -- confirm.

    Args:
        model: ``torch.nn.Module`` to update in place.
        ckpt_path: Path of a file written by Lightning (has a ``'state_dict'`` key).
        prefix: Key prefix to strip.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = model.state_dict()
    for key, value in checkpoint['state_dict'].items():
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = value
    # strict=False: tolerate keys present on only one side.
    model.load_state_dict(state_dict, strict=False)


def main(conf):
    """Train the Beam-TasNet system described by ``conf`` and export the best model.

    Steps:
      1. Build model/optimizer; if ``<exp_dir>/precheckpoints/`` holds a
         checkpoint, warm-start from the newest one and reset the learning rate.
      2. Build dataloaders, the optional ReduceLROnPlateau scheduler and the
         Lightning system; write the effective config to ``<exp_dir>/conf.yml``.
      3. Fit with PyTorch Lightning, resuming from the newest ``*ckpt*`` file
         (excluding ``*init*``) in ``<exp_dir>/checkpoints/`` when one exists.
      4. Write ``best_k_models.json``, reload the checkpoint with the lowest
         ``val_loss`` and save the model object to ``<exp_dir>/best_model.pth``.

    Args:
        conf: Nested configuration dict. Keys read here: ``masknet``, ``data``,
            ``training`` (``half_lr``, ``early_stop``, ``epochs`` and the
            dataloader options), ``optim`` (``lr``) and ``main_args``
            (``exp_dir``, ``pretrain``).
    """
    conf['masknet'].update({'n_src': conf['data']['n_src']})
    model, optimizer = make_model_and_optimizer(conf)
    exp_dir = conf['main_args']['exp_dir']

    # Warm start from the newest stage-1 checkpoint, if any.
    pre_ckpt_path = _find_latest_checkpoint(os.path.join(exp_dir, 'precheckpoints'))
    if pre_ckpt_path is not None:
        print("warm start from {}".format(pre_ckpt_path))
        _load_stage1_weights(model, pre_ckpt_path)
        for param_group in optimizer.param_groups:
            param_group['lr'] = conf['optim']['lr']

    train_loader, val_loader = make_dataloaders(**conf['data'], **conf['training'],
                                                channels=slice(0, 4))

    # Optional dynamic LR: halve it after 5 epochs without improvement.
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Persist the effective configuration in the experiment directory.
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = BFLoss()
    system = BeamTasNetSystem(pretrain=conf['main_args']['pretrain'], new_lr=conf['optim']['lr'],
                              model=model, loss_func=loss_func, optimizer=optimizer,
                              train_loader=train_loader, val_loader=val_loader,
                              scheduler=scheduler, config=conf)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(exp_dir, 'checkpoints/'),
        monitor='val_loss',
        mode='min',
        save_top_k=-1,  # keep a checkpoint for every validation run
        verbose=1,
    )
    callbacks = [LearningRateMonitor(), checkpoint_callback]
    if conf['training']['early_stop']:
        # The callback only takes effect once it is passed to the Trainer.
        callbacks.append(EarlyStopping(monitor='val_loss', patience=10, verbose=2))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    logger = TensorBoardLogger(exp_dir, name="my_model")

    # Resume from the newest Lightning checkpoint, ignoring '*init*' files.
    best_model_path = _find_latest_checkpoint(
        os.path.join(exp_dir, 'checkpoints'),
        keep=lambda name: 'ckpt' in name and 'init' not in name,
    )
    if best_model_path is not None:
        print("resume from {}".format(best_model_path))

    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        enable_checkpointing=True,
        callbacks=callbacks,
        default_root_dir=exp_dir,
        # Accelerator and device count derive from the same check, so a
        # CPU-only machine never gets accelerator='gpu' with devices=0.
        accelerator='gpu' if use_cuda else 'cpu',
        devices=1,
        strategy='auto',
        gradient_clip_val=5.,
        logger=logger,
        precision=32,  # full fp32 training, as before the upgrade
        # Optional debugging knobs: accumulate_grad_batches=6,
        # limit_train_batches=0.1, limit_val_batches=0.2, profiler=profiler
    )
    print(trainer.strategy.root_device)  # e.g. cuda:0 or cpu
    trainer.fit(system, ckpt_path=best_model_path)

    # val_loss of every kept checkpoint, as tracked by the ModelCheckpoint callback.
    best_k = {k: v.item() for k, v in checkpoint_callback.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    if best_k:
        # mode='min': the best checkpoint is the one with the lowest val_loss.
        best_path = min(best_k, key=best_k.get)
        state_dict = torch.load(best_path, map_location='cpu')
        system.load_state_dict(state_dict=state_dict['state_dict'])
    else:
        print("no scored checkpoint found; saving the final weights instead")
    system.cpu()

    # NOTE(review): assumes `system` holds `model` as a submodule, so `model`
    # reflects the weights reloaded above -- confirm in BeamTasNetSystem.
    to_save = model
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
if __name__ == '__main__':
    # Fix the global random seed before anything stochastic is created.
    seed_everything(seed=0)
    # Read the default config explicitly as UTF-8: without `encoding`, open()
    # uses the locale codec (e.g. GBK on a Chinese Windows install), which can
    # fail on non-ASCII content in the YAML file.
    with open('local/conf.yml', encoding='utf-8') as f:
        def_conf = yaml.safe_load(f)
    # Expose the config keys as command-line flags, then parse into a nested dict.
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    arg_dic, _plain_args = parse_args_as_dict(parser, return_plain_args=True)
    print(arg_dic)
    main(arg_dic)
下面补充我环境中所有包的版本:
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
asteroid==0.7.0
asteroid-filterbanks==0.4.0
async-timeout==4.0.3
attrs==23.2.0
audioread==3.0.1
beautifulsoup4==4.12.3
Brotli @ file:///C:/b/abs_3d36mno480/croot/brotli-split_1714483178642/work
cached-property==1.5.2
cachetools==5.3.3
certifi @ file:///C:/b/abs_55jxithrm3/croot/certifi_1717618077850/work/certifi
cffi @ file:///C:/b/abs_78eb1_vq6z/croot/cffi_1714483206096/work
charset-normalizer==2.0.12
colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
contourpy==1.1.1
cycler==0.12.1
Cython==3.0.10
decorator==5.1.1
Deprecated==1.2.14
einops==0.8.0
filelock==3.14.0
fonttools==4.53.0
frozenlist==1.4.1
fsspec @ file:///C:/b/abs_d5jzmndkba/croot/fsspec_1714461594761/work
future==1.0.0
google==3.0.0
google-auth==2.29.0
google-auth-oauthlib==1.0.0
grpcio==1.64.1
huggingface-hub==0.23.2
idna @ file:///C:/b/abs_aad84bnnw5/croot/idna_1714398896795/work
importlib_metadata==7.1.0
importlib_resources==6.4.0
intel-openmp==2021.4.0
Jinja2==3.1.3
joblib==1.4.2
julius==0.2.7
kiwisolver==1.4.5
lazy_loader==0.4
librosa==0.10.2.post1
lightning-utilities==0.11.2
llvmlite==0.41.1
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.7.5
mir_eval==0.7
mkl==2021.4.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
networkx==3.0
neural_compressor==2.5.1
numba==0.58.1
numpy==1.24.4
oauthlib==3.2.2
opencv-python-headless==4.10.0.82
packaging==24.1
pandas==2.0.3
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==10.2.0
platformdirs==4.2.2
pooch==1.8.1
prettytable==3.10.0
protobuf==5.27.1
psutil==5.9.8
py-cpuinfo==9.0.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools-windows==2.0.0.2
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pyDeprecate==0.3.1
pyparsing==3.1.2
PySocks @ file:///C:/ci/pysocks_1605287845585/work
pystoi==0.4.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.3.0
pytorch-ranger==0.1.1
pytz==2024.1
PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
requests==2.26.0
requests-oauthlib==2.0.0
rsa==4.9
schema==0.7.7
scikit-learn==1.3.2
scipy==1.10.1
shapely==2.0.4
six==1.16.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
sympy==1.12.1
tbb==2021.12.0
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
thop==0.1.1.post2209072238
threadpoolctl==3.5.0
torch==2.2.0+cu118
torch-complex==0.4.3
torch-optimizer==0.1.0
torch-stoi==0.2.1
torch-tb-profiler==0.4.3
torchaudio==2.2.0+cu118
torchmetrics==0.11.4
torchvision==0.17.0+cu118
tqdm @ file:///C:/b/abs_a7h7n45ccq/croot/tqdm_1716395987133/work
typing_extensions==4.12.2
tzdata==2024.1
urllib3==1.26.18
wcwidth==0.2.13
Werkzeug==3.0.3
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
wrapt==1.16.0
yarl==1.9.4
zipp==3.19.2
我知道可能存在包版本冲突,但怎样才能知道具体是哪几个包有冲突呢?另外,安装包时我应该用 conda 还是 pip?
升级前我使用的版本是 pytorch-lightning==0.7.6、PyTorch 1.10.0。
最后附上 profiler 的分析图。但不知为何,分析结果中只有 CPU 部分,没有 GPU 部分,这个问题我也不知道该如何解决。