weixin_50763826 2022-10-03 09:44
215 views · Closed

Problem deploying a PyTorch model to multiple GPUs

import os
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, HeteroConv, GlobalAttention
from htg_data import HTG_data
import torch.nn.functional as F

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # every conv below uses in_channels=-1, i.e. the input dimension
            # is inferred lazily on the first forward pass
            conv = HeteroConv({
                ('token', 'next_token', 'token'): GCNConv(-1, hidden_channels),
                ('token', 'token_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
                ('token', 'belongs_to', 'property'): GATConv((-1, -1), hidden_channels),
                ('property', 'property_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
                ('property', 'next_property', 'property'): GATConv((-1, -1), hidden_channels),
            }, aggr='sum')
            self.convs.append(conv)
        self.pooling_gate_nn = Linear(hidden_channels, 1)
        self.pooling = GlobalAttention(self.pooling_gate_nn)
        self.lin = Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        self.pooling.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x_dict, edge_index_dict, batch):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        x = torch.cat((x_dict['sink'], x_dict['property'], x_dict['token']), 0)
        x = self.pooling(x, batch)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin(x)
        return x

def run(rank, world_size: int, root: str):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    dataset = HTG_data(root=root)

    print(dataset[0])

    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)


    torch.manual_seed(12345)
    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=2).to(rank)
    # this is the line where the RuntimeError below is raised: the lazy
    # parameters above have not been materialized yet
    model = DistributedDataParallel(model, device_ids=[rank])

    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    criterion = torch.nn.MultiLabelSoftMarginLoss()

    for epoch in range(1, 51):
        model.train()

        total_loss = 0
        for data in train_loader:
            data = data.to(rank)
            optimizer.zero_grad()
            logits = model(data.x_dict, data.edge_index_dict, data.batch)
            loss = criterion(logits, data.y.to(torch.float))
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * logits.size(0)
        loss = total_loss / len(train_loader.dataset)

        dist.barrier()

        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ')

    dist.destroy_process_group()

if __name__ == '__main__':

    root = '/home/ylzqn/HTG_Data/HTG_CAG'
    world_size = torch.cuda.device_count()
    print('Let\'s use', world_size, 'GPUs!')
    args = (world_size, root)
    mp.spawn(run, args=args, nprocs=world_size, join=True)

Traceback (most recent call last):
  File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 127, in <module>
    mp.spawn(run, args=args, nprocs=world_size, join=True)
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 71, in run
    model = DistributedDataParallel(model, device_ids=[rank])
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 435, in __init__
    'Modules with uninitialized parameters can't be used with DistributedDataParallel. '
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules


I want to use DistributedDataParallel to deploy the model to 4 GPUs, but I ran into the error above. I hope someone can point me in the right direction.
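
For reference, the uninitialized parameters come from the in_channels=-1 convolutions in the model above: PyG builds these with placeholder weights that only get a shape on the first forward pass. A minimal standalone sketch of the cause (the .lin attribute is an internal detail of GCNConv and may differ across torch_geometric versions):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(-1, 64)                        # input dimension inferred lazily
print(conv.lin.weight)                        # UninitializedParameter, no shape yet

x = torch.randn(10, 16)                       # 10 nodes, 16 features
edge_index = torch.tensor([[0, 1], [1, 0]])   # two directed edges
conv(x, edge_index)                           # first (dummy) forward pass
print(conv.lin.weight.shape)                  # torch.Size([64, 16]), now materialized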


2 answers

  • 生产队的小刘 (Rising Star Creator, Python) 2022-10-03 11:16

    Recommended reference article: https://blog.csdn.net/qq_40564301/article/details/123694176
    (Hope you'll accept this answer~)
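
The error message itself also points at the fix: run one dummy forward pass before wrapping the model in DistributedDataParallel, so that all lazy parameters are materialized first. A minimal sketch of that change inside run(), reusing the question's own train_loader and forward call (untested against HTG_data):

    torch.manual_seed(12345)
    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=2).to(rank)

    # one dummy forward pass materializes all lazy (-1) parameters ...
    with torch.no_grad():
        data = next(iter(train_loader)).to(rank)
        model(data.x_dict, data.edge_index_dict, data.batch)

    # ... so DDP can now broadcast them from rank 0
    model = DistributedDataParallel(model, device_ids=[rank])

Since the dummy pass runs inside run(), every spawned process materializes its own copy before DDP synchronizes the weights; the torch.no_grad() context just avoids building an unneeded autograd graph.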



Question events

  • Question closed by the system on Oct 11
  • Question created on Oct 3
