Msy20070905 2024-07-16 23:31 采纳率: 21.2%
浏览 11

setting an array element with a sequence.


import tensorflow as tf
from matplotlib import pyplot as plt
import os
import numpy as np
import glob
from torch.utils import data
from PIL import Image
from torchvision import transforms


# 数据集内图片的验证
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
dataset_path = "rps_data_sample"

labels = []
for i in os.listdir(dataset_path):
    if os.path.isdir(os.path.join(dataset_path, i)):
        labels.append(i)

NUM_EXAMPLES = 5

for label in labels:
    label_dir = os.path.join(dataset_path, label)
    example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
    fig, axs = plt.subplots(1, 5, figsize=(10,2))
    for i in range(NUM_EXAMPLES):
        axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
        axs[i].get_xaxis().set_visible(False)
        axs[i].get_yaxis().set_visible(False)
    fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')


# 创建data.Dataset子类Mydataset
class Mydataset(data.Dataset):
    def __init__(self,root):
        self.imgs_path = root
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        return img_path
    def __len__(self):
        return len(self.imgs_path)

# glob遍历数据路径
all_imgs_path = glob.glob('rps_data_sample\*\*.jpg')
for var in all_imgs_path:
    print(var)

# 建立gesture_data
gesture_dataset = Mydataset(all_imgs_path)
print(len(gesture_dataset))

# path迭代
species = ['none','paper','rock','scissors']
species_to_id = dict((c,i) for i,c in enumerate(species))
id_to_species = dict((v,k) for k,v in species_to_id.items())
all_labels = []
for img in all_imgs_path:
    for i,c in enumerate(species):
        if c in img:
            all_labels.append(i)

transform = transforms.Compose([transforms.ToTensor()])
index = np.random.permutation(len(all_imgs_path))

all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

p = int(len(all_imgs_path)*0.8)
x_train_ = all_imgs_path[:p]
y_train_ = all_labels[:p]
x_test_ = all_imgs_path[p:]
y_test_ = all_imgs_path[p:]


class MyDatasetpro(data.Dataset):
    def __init__(self,img_paths ,labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        pil_img = Image.open(img)
        data = self.transforms(pil_img)
        return data, label
    def __len__(self):
        return len(self.imgs)

x_train = MyDatasetpro(x_train_, y_train_, transform)
x_test = MyDatasetpro(x_test_, y_test_, transform)
y_train = data.DataLoader(x_train, batch_size=5,shuffle=True)
y_test = data.DataLoader(x_test, batch_size=5,shuffle=True)

# tensorflow训练
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32,kernel_size=(5,5),padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(filters=96,kernel_size=(5,5),padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.AveragePooling2D(pool_size=(3,3),padding='valid'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256,activation='relu',kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(64,activation='relu',kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(5,activation='softmax',kernel_regularizer=tf.keras.regularizers.l1())
])
model.compile(optimizer='nadam',loss="sparse_categoriical_crossentropy", metrics=["sparse_categorical_accuracy"])
# 由于神经层较多,cpu一次可能跑不了,断点续训
# checkpoint_save_path = './rps_data_sample/checkpoint.ckpt'
# if os.path.exists(checkpoint_save_path+'index'):
#   print('load model')

# model.load_weights(checkpoint_save_path)
# cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
#                                                 save_weights_only=True,
#                                                 save_best_only=True)

history = model.fit(np.array(x_train),y_train,batch_size=32, epochs=1000, validation_data=(np.array(x_test),y_test),validation_freq=1)
model.summary()

file = open('/weights.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name)+'\n')
    file.write(str(v.shape)+'\n')
    file.write(str(v.numpy()+'\n'))
file.close()

# 可视化
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label="Validation Accuracy")
plt.title('Training and Validation Accuracy')

结果报setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (400, 2) + inhomogeneous part.
怎么处理啊

  • 写回答

4条回答 默认 最新

  • 阿里嘎多学长 2024-07-16 23:31
    关注

    以下内容由CHATGPT及阿里嘎多学长共同生成、有用望采纳:


    您遇到的错误信息 "setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (400, 2) + inhomogeneous part." 指出您尝试用一个序列来设置数组元素,但是这个数组在两个维度之后的形状是不均匀的。这通常发生在尝试将一个形状不匹配的序列赋值给numpy数组或类似结构时。

    从您提供的代码片段来看,问题可能出在以下几行:

    x_train_ = all_imgs_path[:p]
    y_train_ = all_labels[:p]
    x_test_ = all_imgs_path[p:]
    y_test_ = all_labels[p:]
    

    这里,您将 all_imgs_pathall_labels 分别切片成了训练集和测试集。但是,当您创建 MyDatasetpro 类的实例时,您需要确保 img_pathslabels 的长度是一致的。如果 all_imgs_pathall_labels 的长度不一致,这将导致错误。

    此外,您在创建 DataLoader 对象时,似乎将 DataLoader 直接赋值给了 y_trainy_test,这是不正确的。DataLoader 是一个迭代器,用于批量加载数据,您应该将其与数据集一起使用。例如:

    y_train = data.DataLoader(x_train, batch_size=5, shuffle=True)
    y_test = data.DataLoader(x_test, batch_size=5, shuffle=False)
    

    请注意,y_trainy_test 应该是 DataLoader 对象,而不是标签数组本身。

    最后,您在保存模型权重时的代码存在一些问题。file.write(str(v.numpy()+'\n')) 这一行缺少了一个括号,应该是 file.write(str(v.numpy())+'\n')

    修复这些问题后,您的代码应该能够正常运行。如果问题仍然存在,请提供更多的上下文信息,以便我能够提供更具体的帮助。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 7月16日

悬赏问题

  • ¥15 35114 SVAC视频验签的问题
  • ¥15 impedancepy
  • ¥15 在虚拟机环境下完成以下,要求截图!
  • ¥15 求往届大挑得奖作品(ppt…)
  • ¥15 如何在vue.config.js中读取到public文件夹下window.APP_CONFIG.API_BASE_URL的值
  • ¥50 浦育平台scratch图形化编程
  • ¥20 求这个的原理图 只要原理图
  • ¥15 vue2项目中,如何配置环境,可以在打完包之后修改请求的服务器地址
  • ¥20 微信的店铺小程序如何修改背景图
  • ¥15 UE5.1局部变量对蓝图不可见