import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GraphNeuralNetwork(nn.Module):
    """GNN for predicting exploration value scores."""

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # GNN layers
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
        # Output head (self-attention removed; project straight to scores)
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        # Self-attention module removed; previously:
        # self.feature_attention = nn.MultiheadAttention(...)

    def forward(self, data):
        """Forward pass through GNN."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Input projection
        x = self.input_proj(x)
        x = F.relu(x)
        # GNN layers with residual connections
        for conv, norm in zip(self.convs, self.norms):
            x_new = conv(x, edge_index)
            x_new = F.relu(x_new)
            x_new = norm(x_new)
            x = x + x_new  # Residual connection
        # Self-attention block removed; previously:
        # x_reshaped = x.unsqueeze(0)
        # x_attended, _ = self.feature_attention(x_reshaped, x_reshaped, x_reshaped)
        # x = x_attended.squeeze(0)
        # Node-level predictions: one score in (0, 1) per node
        node_scores = self.output_proj(x).squeeze(-1)
        # Graph-level predictions: mean of node scores within each graph
        graph_scores = global_mean_pool(node_scores, batch)
        return node_scores, graph_scores
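For context, a minimal smoke test of the model above; this is a hypothetical sketch assuming torch_geometric is installed, and the 4-node toy graph and feature dimension are illustrative:

import torch
from torch_geometric.data import Batch, Data

model = GraphNeuralNetwork(input_dim=8)
graph = Data(
    x=torch.randn(4, 8),                                    # 4 nodes, 8 features each
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),  # a directed 4-cycle
)
batch = Batch.from_data_list([graph, graph])
node_scores, graph_scores = model(batch)
print(node_scores.shape)   # torch.Size([8]) -- one score per node
print(graph_scores.shape)  # torch.Size([2]) -- one score per graph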
The above is my model's architecture. Would it be feasible to add multi-head attention after the graph convolution layers?
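For reference, one way such an insertion is often sketched, assuming node features are packed into a dense per-graph batch with torch_geometric.utils.to_dense_batch; the module name AttentionReadout and the head count of 4 are illustrative, not part of the original model:

import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch


class AttentionReadout(nn.Module):
    """Multi-head self-attention over the nodes of each graph (hypothetical sketch)."""

    def __init__(self, hidden_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, batch):
        # Pack variable-sized graphs into [num_graphs, max_nodes, hidden_dim];
        # mask is True at positions holding real nodes.
        x_dense, mask = to_dense_batch(x, batch)
        # key_padding_mask expects True at positions to IGNORE, hence ~mask.
        attended, _ = self.attn(x_dense, x_dense, x_dense, key_padding_mask=~mask)
        # Residual connection + normalization, matching the GCN blocks above.
        x_dense = self.norm(x_dense + attended)
        # Unpack back to the flat [num_nodes, hidden_dim] layout.
        return x_dense[mask]

Under this sketch, the module would be constructed in __init__ and called between the GCN loop and output_proj, e.g. x = self.attention_readout(x, batch), so attention operates within each graph rather than across the whole batch.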