2022-07-02 00:23

有一个需要用GAT实现节点分类的任务，代码已经写好了，就只需要优化一下，因为现在准确率不高，自己不太清楚需要调节哪些参数可以让准确率更高一点。 有数据有代码，只优化

`

import numpy as np
import pandas as pd
import torch
from torch import Tensor
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from sklearn import metrics
import glob
from scipy import signal
from torch_geometric.nn import GATConv
#--------------------------------------计算节点
node_path = r'/data1/ljh/eeg/fromgithub/save/mean_after_process_features.csv'
x_node = x_node.values[:, 1:]
#############################################################此时x读取的是对应的节点。
#---------------------------------------标签
ypath = r'/data1/ljh/eeg/fromgithub/save/mwan_label.csv'
#############################################################此时label读取的是对应的标签。
label = label.values[:, 1:]

#--------------------------------------计算边
# read csv file from the path
print(x)
print('type:\n', type(x))
print('x的大小：\n', x.shape)
y = x.values[:, 1:]

print(y)
#############################################################此时y读取的是相关系数矩阵。
#这里打乱一下顺序
np.random.seed(112)
np.random.shuffle(x_node)
np.random.seed(112)
np.random.shuffle(label)
np.random.seed(112)
np.random.shuffle(y)
print('x_node:\n', x_node)
print('label:\n', label)

print('type:\n', type(y))
print('y的大小：\n', y.shape)
y_tensor = torch.from_numpy(y)
x_node = torch.Tensor(x_node)
label = torch.Tensor(label)

#---------------------------------------------
threshold = 0.5
train_source_node, train_target_node = torch.where(y_tensor > threshold)
train_source_node = train_source_node.unsqueeze(0)
train_target_node = train_target_node.unsqueeze(0)
train_edge_index = torch.cat((train_source_node, train_target_node), dim=0)
# train_edge_index = train_edge_index.cuda()

# 构造网络
class Net(torch.nn.Module):
def __init__(self):
super(Net,