import os
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, HeteroConv, GlobalAttention
from htg_data import HTG_data
import torch.nn.functional as F
class HeteroGNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, num_layers=2, dropout=0.5):
super().__init__()
self.dropout = dropout
self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv({
('token', 'next_token', 'token'): GCNConv(-1, hidden_channels),
('token', 'token_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
('token', 'belongs_to', 'property'): GATConv((-1, -1), hidden_channels),
('property', 'property_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
('property', 'next_property', 'property'): GATConv((-1, -1), hidden_channels),
}, aggr='sum')
self.convs.append(conv)
self.pooling_gate_nn = Linear(hidden_channels, 1)
self.pooling = GlobalAttention(self.pooling_gate_nn)
self.lin = Linear(hidden_channels, out_channels)
def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
self.pooling.reset_parameters()
self.lin.reset_parameters()
def forward(self, x_dict, edge_index_dict, batch):
for conv in self.convs:
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: x.relu() for key, x in x_dict.items()}
x = torch.cat((x_dict['sink'],x_dict['property'],x_dict['token']), 0)
x = self.pooling(x, batch)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lin(x)
return x
def run(rank, world_size: int, root: str):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
dataset = HTG_data(root = root)
print(dataset[0])
train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
torch.manual_seed(12345)
model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=2).to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = torch.nn.MultiLabelSoftMarginLoss()
for epoch in range(1, 51):
model.train()
total_loss = 0
for data in train_loader:
data = data.to(rank)
optimizer.zero_grad()
logits = model(data.x_dict, data.edge_index_dict, data.batch)
loss = criterion(logits, data.y.to(torch.float))
loss.backward()
optimizer.step()
total_loss += float(loss) * logits.size(0)
loss = total_loss / len(train_loader.dataset)
dist.barrier()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ')
dist.destroy_process_group()
if __name__ == '__main__':
root = '/home/ylzqn/HTG_Data/HTG_CAG'
world_size = torch.cuda.device_count()
print('Let\'s use', world_size, 'GPUs!')
args = (world_size, root)
mp.spawn(run, args=args, nprocs=world_size, join=True)
**Traceback (most recent call last):
File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 127, in
mp.spawn(run, args=args, nprocs=world_size, join=True)
File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, args)
File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 71, in run
model = DistributedDataParallel(model, device_ids=[rank])
File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 435, in init
'Modules with uninitialized parameters can't be used with DistributedDataParallel. '
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules*
我想用DistributedDataParallel将模型部署到4个GPU,出现了以上的问题. 希望哪位能指导一下.