m0_72183621
2022-07-02 00:23
采纳率: 33.3%
浏览 74
已结题

有一个需要用GAT实现节点分类的任务,代码已经写好了,就只需要优化一下,因为现在准确率不高,自己不太清楚需要调节哪些参数可以让准确率更高一点。 有数据有代码,只优化

我是在服务器上跑的,所以这几行,地址需要他改一下。看能不能优化到0.7,csv文件发邮箱

img


`

import numpy as np
import pandas as pd
import torch
from torch import Tensor
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from sklearn import metrics
import glob
from scipy import signal
from torch_geometric.nn import GATConv
#--------------------------------------计算节点
node_path = r'/data1/ljh/eeg/fromgithub/save/mean_after_process_features.csv'
x_node = pd.read_csv(node_path)
x_node = x_node.values[:, 1:]
#############################################################此时x读取的是对应的节点。
#---------------------------------------标签
ypath = r'/data1/ljh/eeg/fromgithub/save/mwan_label.csv'
label = pd.read_csv(ypath)
#############################################################此时label读取的是对应的标签。
label = label.values[:, 1:]

#--------------------------------------计算边
# read csv file from the path
xpath = r'/data1/ljh/eeg/fromgithub/save/mean_adjacent_matrix.csv'
x = pd.read_csv(xpath)
print(x)
print('type:\n', type(x))
print('x的大小:\n', x.shape)
y = x.values[:, 1:]

print(y)
#############################################################此时y读取的是相关系数矩阵。
#这里打乱一下顺序
np.random.seed(112)
np.random.shuffle(x_node)
np.random.seed(112)
np.random.shuffle(label)
np.random.seed(112)
np.random.shuffle(y)
print('x_node:\n', x_node)
print('label:\n', label)



print('type:\n', type(y))
print('y的大小:\n', y.shape)
y_tensor = torch.from_numpy(y)
x_node = torch.Tensor(x_node)
label = torch.Tensor(label)

#---------------------------------------------
threshold = 0.5
train_source_node, train_target_node = torch.where(y_tensor > threshold)
train_source_node = train_source_node.unsqueeze(0)
train_target_node = train_target_node.unsqueeze(0)
train_edge_index = torch.cat((train_source_node, train_target_node), dim=0)
# train_edge_index = train_edge_index.cuda()



# 构造网络
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,

1条回答 默认 最新

相关推荐 更多相似问题