问题遇到的现象和发生背景
import torch
from torch.utils.data import Dataset, DataLoader
#自定义数据集类,torch.utils.data.random_split() 划分训练集、验证集、测试集。
class MyDataSet(Dataset):
def __init__(self, loaded_data):
self.data = loaded_data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
Data_path = "/content/drive/MyDrive/train_data.csv"
Totle_data = pd.read_csv(Data_path)
custom_dataset = MyDataSet(Totle_data)
#按照比例划分
train_size = int(len(custom_dataset) * 0.8)
validate_size = int(len(custom_dataset) * 0.1)
test_size = len(custom_dataset) - validate_size - train_size
train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(custom_dataset, [train_size, validate_size, test_size])
#设置保存路径
train_data_path="/content/drive/MyDrive/Data_Mining_Training.csv"
dev_data_path = "/content/drive/MyDrive/Data_Mining_Dev.csv"
test_data_path="/content/drive/MyDrive/Data_Mining_Test.csv"
#index参数设置为False表示不保存行索引,header设置为False表示不保存列索引
train_dataset.to_csv(train_data_path,index=False,header=True)
validate_dataset.to_csv(dev_data_path ,index=False,header=True)
validate_dataset.to_csv(test_data_path,index=False,header=True)
倒数第三行报错