qq_53216250 2024-03-28 18:11 采纳率: 0%
浏览 5

t-sne可视化csv数据集

在使用t-sne可视化csv数据集中遇到了以下的问题,报错信息如下:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_1220/4136105992.py in <module>
     56     idx = np.where(labels.values.flatten() == i)[0]
     57     plt.scatter(X_tsne[idx][:, 0], X_tsne[idx][:, 1], color=color_list[i],
---> 58             marker=shape_list[i], s=150, label=label_list[i], alpha=0.5)
     59 
     60 

IndexError: list index out of range

源代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas

# 加载自己的数据集和标签
X = pandas.read_csv(r"/root/autodl-tmp/376data.csv", header=None)  # 替换为你的数据集路径
labels = pandas.read_csv(r"/root/autodl-tmp/376label.csv", header=None) # 替换为你的标签路径

# 定义三个类别的均值和协方差矩阵
#mean1 = [0, 1]
#cov1 = [[1, 0.3], [0.3, 1]]
#mean2 = [3, 3]
#cov2 = [[1, -0.2], [-0.2, 3]]
#mean3 = [-10, 10]
#cov3 = [[1, 0], [0, 0.5]]
#mean4 = [-4, 2]
#cov4 = [[0.5, 0.2], [0.2, 2]]

# 生成三个类别的样本数据
#data1 = np.random.multivariate_normal(mean1, cov1, 100)
#data2 = np.random.multivariate_normal(mean2, cov2, 100)
#data3 = np.random.multivariate_normal(mean3, cov3, 100)
#data4 = np.random.multivariate_normal(mean4, cov4, 100)

#label1 = np.zeros(data1.shape[0]) + 0
#label2 = np.zeros(data1.shape[0]) + 1
#label3 = np.zeros(data1.shape[0]) + 2
#label4 = np.zeros(data1.shape[0]) + 3

# 将三个类别的数据合并
#data = np.concatenate((data1, data2, data3, data4))
#labels = np.concatenate((label1, label2, label3, label4))
#print(data.shape, labels.shape)

# 使用t-SNE进行降维
tsne = TSNE(n_components=2, random_state=42)
#X_tsne = tsne.fit_transform(data)
X_tsne = tsne.fit_transform(X)
# 归一化
x_min, x_max = X_tsne.min(0), X_tsne.max(0)
X_norm = (X_tsne - x_min) / (x_max - x_min)

# 绘制t-SNE可视化图
plt.figure(figsize=(10, 8))
plt.rcParams['font.sans-serif'] = ['Times New Roman']  # 图中文字体设置为Times New Roman

shape_list = ['o', 'D', '^', 'P', 's', 'x', '*', '+']  # 设置不同类别的形状
color_list = ['r', 'g', 'b', 'm', 'c', 'y', 'k', 'orange', 'purple', 'brown', 'pink'] 
  # 设置不同类别的颜色

label_list = ['SiO2', 'TiO2', 'Al2O3', 'FeOT', 'MgO', 'CaO', 'Na2O', 'K2O']
# 遍历所有标签种类
# 遍历所有标签种类
for i in range(len(np.unique(labels))):
    idx = np.where(labels.values.flatten() == i)[0]
    plt.scatter(X_tsne[idx][:, 0], X_tsne[idx][:, 1], color=color_list[i],
            marker=shape_list[i], s=150, label=label_list[i], alpha=0.5)


# # 遍历所有样本
#color_map = {0:'r', 1:'g', 2:'b', 3:'m'}   # 定义类别颜色映射关系
#shape_map = {0:'o', 1:'D', 2:'^', 3:'P'}
#default_color = 'k'  # 默认颜色
#default_shape = 'o'  # 默认形状
#color_map = {label: color_list[label % len(color_list)] if label in color_list else default_color for label in set(labels)}
#shape_map = {label: shape_list[label % len(shape_list)] if label in shape_list else default_shape for label in set(labels)}
#label_map = {0:'Class 1', 1:'Class 2', 2:'Class 3', 3:'Class 4'}
#for data, label in zip(X_norm, labels):
   # print(data, label)  # 打印每个样本数据和其对应标签
   # plt.scatter(data[0], data[1], color=color_map[label], marker=shape_map[label])
   # plt.text(data[0], data[1], label_map[label], ha='center', va='bottom')  # 所有样本都对应写上标签
# 添加图例,并设置字体大小
plt.legend(fontsize=20)

ax = plt.gca()  # gca:get current axis得到当前轴
# ax.spines['right'].set_visible(False)  # 取消右边界
# ax.spines['top'].set_visible(False)    # 取消上边界
ax.spines['right'].set_linewidth('2.0')  # 设置边框线宽为2.0
ax.spines['top'].set_linewidth('2.0')  # 设置边框线宽为2.0
ax.spines['bottom'].set_linewidth('2.0')  # 设置边框线宽为2.0
ax.spines['left'].set_linewidth('2.0')  # 设置边框线宽为2.0

plt.xticks(fontsize=20)  # 定义坐标轴刻度
plt.yticks(fontsize=20)

plt.xlabel('t-SNE Dimension 1', fontsize=20)  # 定义坐标轴标题
plt.ylabel('t-SNE Dimension 2', fontsize=20)
plt.title('t-SNE Visualization', fontsize=24)  # 定义图题

plt.show()  # 显示图形
plt.savefig('./visualization.png', dpi=600)  # 保存图为png格式
# plt.savefig('./visualization.svg', format='svg')  # 保存图为svg格式


  • 写回答

1条回答 默认 最新

  • Kwan的解忧杂货铺 Java领域优质创作者 2024-03-28 18:11
    关注

    qq_53216250 晚上好🌙🌙🌙
    本答案参考ChatGPT-3.5

    问题分析: 在使用t-sne可视化csv数据集时,报错信息为IndexError: list index out of range,原因是在for循环中访问了超出列表范围的索引。

    解决方案:

    1. 检查label_list列表的长度是否与np.unique(labels)的长度相等,如果不相等,需要重新定义label_list,确保与标签种类数量一致。
    2. 确保shape_list和color_list的长度大于或等于标签种类数量,在此处可以使用循环列表的方式避免越界错误。
    3. 检查labels的维度是否为二维,如果是一维,请转换成二维。
    4. 检查labels的数据类型是否为整数类型,如果不是,请转换为整数类型。

    修改后的代码如下:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE
    import pandas as pd
    
    # 加载自己的数据集和标签
    X = pd.read_csv("/root/autodl-tmp/376data.csv", header=None)  # 替换为你的数据集路径
    labels = pd.read_csv("/root/autodl-tmp/376label.csv", header=None) # 替换为你的标签路径
    
    # 使用t-SNE进行降维
    tsne = TSNE(n_components=2, random_state=42)
    X_tsne = tsne.fit_transform(X)
    
    # 归一化
    x_min, x_max = X_tsne.min(0), X_tsne.max(0)
    X_norm = (X_tsne - x_min) / (x_max - x_min)
    
    # 绘制t-SNE可视化图
    plt.figure(figsize=(10, 8))
    plt.rcParams['font.sans-serif'] = ['Times New Roman']  # 图中文字体设置为Times New Roman
    
    # 设置不同类别的形状和颜色
    shape_list = ['o', 'D', '^', 'P', 's', 'x', '*', '+']
    color_list = ['r', 'g', 'b', 'm', 'c', 'y', 'k', 'orange', 'purple', 'brown', 'pink']
    
    # 遍历所有标签种类
    label_list = ['SiO2', 'TiO2', 'Al2O3', 'FeOT', 'MgO', 'CaO', 'Na2O', 'K2O']
    for i in range(len(np.unique(labels.values))):
        idx = np.where(labels.values.flatten() == i)[0]
        plt.scatter(X_norm[idx][:, 0], X_norm[idx][:, 1], color=color_list[i % len(color_list)],
                    marker=shape_list[i % len(shape_list)], s=150, label=label_list[i], alpha=0.5)
    
    # 添加图例,并设置字体大小
    plt.legend(fontsize=20)
    
    ax = plt.gca()
    ax.spines['right'].set_linewidth('2.0') 
    ax.spines['top'].set_linewidth('2.0')  
    ax.spines['bottom'].set_linewidth('2.0')  
    ax.spines['left'].set_linewidth('2.0')  
    
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    
    plt.xlabel('t-SNE Dimension 1', fontsize=20)
    plt.ylabel('t-SNE Dimension 2', fontsize=20)
    plt.title('t-SNE Visualization', fontsize=24)
    
    plt.show()
    plt.savefig('./visualization.png', dpi=600)
    

    注意事项:

    1. 在绘制图像之前,一定要确保label_list、shape_list和color_list的长度应适配和符合数据集的实际情况。
    2. 如果仍有错误,请检查数据集路径和标签路径是否正确,并确认数据集和标签的内容是否符合要求。
    评论

报告相同问题?

问题事件

  • 创建了问题 3月28日

悬赏问题

  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度
  • ¥30 关于#r语言#的问题:如何对R语言中mfgarch包中构建的garch-midas模型进行样本内长期波动率预测和样本外长期波动率预测
  • ¥15 ETLCloud 处理json多层级问题
  • ¥15 matlab中使用gurobi时报错
  • ¥15 这个主板怎么能扩出一两个sata口
  • ¥15 不是,这到底错哪儿了😭
  • ¥15 2020长安杯与连接网探
  • ¥15 关于#matlab#的问题:在模糊控制器中选出线路信息,在simulink中根据线路信息生成速度时间目标曲线(初速度为20m/s,15秒后减为0的速度时间图像)我想问线路信息是什么