Msy20070905 2024-07-17 08:12 采纳率: 21.2%
浏览 2

generator` yielded an element of shape (3, 224, 224)


import tensorflow as tf
from matplotlib import pyplot as plt
import os
import numpy as np
import glob
from torch.utils import data
from PIL import Image
from torchvision import transforms

# 数据集内图片的验证
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
dataset_path = "rps_data_sample"

labels = []
for i in os.listdir(dataset_path):
    if os.path.isdir(os.path.join(dataset_path, i)):
        labels.append(i)

NUM_EXAMPLES = 5

for label in labels:
    label_dir = os.path.join(dataset_path, label)
    example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
    fig, axs = plt.subplots(1, 5, figsize=(10, 2))
    for i in range(NUM_EXAMPLES):
        axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
        axs[i].get_xaxis().set_visible(False)
        axs[i].get_yaxis().set_visible(False)
    fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')


# 创建data.Dataset子类Mydataset
class Mydataset(data.Dataset):
    def __init__(self, root):
        self.imgs_path = root

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        return img_path

    def __len__(self):
        return len(self.imgs_path)


# glob遍历数据路径
all_imgs_path = glob.glob('rps_data_sample\\*\\*.jpg')
for var in all_imgs_path:
    print(var)

# 建立gesture_data
gesture_dataset = Mydataset(all_imgs_path)
print(len(gesture_dataset))

# path迭代
species = ['none', 'paper', 'rock', 'scissors']
species_to_id = dict((c, i) for i, c in enumerate(species))
id_to_species = dict((v, k) for k, v in species_to_id.items())
all_labels = []
for img in all_imgs_path:
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 统一图像尺寸
    transforms.ToTensor()
])
index = np.random.permutation(len(all_imgs_path))

all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

p = int(len(all_imgs_path) * 0.8)
x_train_ = all_imgs_path[:p]
y_train_ = all_labels[:p]
x_test_ = all_imgs_path[p:]
y_test_ = all_imgs_path[p:]


class MyDatasetpro(data.Dataset):
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        pil_img = Image.open(img).convert('RGB')  # 确保图像为RGB格式
        data = self.transforms(pil_img)
        return data, label

    def __len__(self):
        return len(self.imgs)


x_train = MyDatasetpro(x_train_, y_train_, transform)
x_test = MyDatasetpro(x_test_, y_test_, transform)


# 使用TensorFlow的数据处理功能
def tf_dataset(dataset, batch_size):
    def generator():
        for img, label in dataset:
            yield img.numpy(), label

    ds = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.int32),
                                        output_shapes=((224, 224,3), ()))
    ds = ds.batch(batch_size)
    return ds


batch_size = 32
train_ds = tf_dataset(x_train, batch_size)
test_ds = tf_dataset(x_test, batch_size)

# tensorflow训练
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same', input_shape=(224, 224, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(filters=96, kernel_size=(5, 5), padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.AveragePooling2D(pool_size=(3, 3), padding='valid'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(5, activation='softmax', kernel_regularizer=tf.keras.regularizers.l1())
])
model.compile(optimizer='nadam', loss="sparse_categorical_crossentropy", metrics=["sparse_categorical_accuracy"])

history = model.fit(train_ds, epochs=10000, validation_data=test_ds, validation_freq=1)
model.summary()

file = open('/weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

# 可视化
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label="Validation Accuracy")
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label="Validation Loss")
plt.title('Training and Validation Loss')
plt.legend()

plt.show()

这段代码中报generator` yielded an element of shape (3, 224, 224) where an element of shape (224, 224, 3) was expected.咋处理啊

  • 写回答

5条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-07-17 08:12
    关注

    引自免费微信小程序:皆我百晓生

    这个错误是因为你的数据生成器返回的图像形状不符合模型的输入要求。模型期望的图像形状是 (224, 224, 3)(高、宽和颜色通道),但生成器返回的图像形状是 (3, 224, 224)。这可能是因为图像数据的颜色通道顺序不正确。在大多数情况下,图像数据应该是 (高度, 宽度, 颜色通道) 的顺序,即 RGB 格式的图像。

    为了解决这个问题,你需要确保图像数据是 RGB 格式。在你的代码中,你已经使用 Image.open(img).convert('RGB') 来确保图像为 RGB 格式,这一步是正确的。但可能之前的数据加载方式导致图像数据的通道顺序不正确。你可以通过检查 all_imgs_path 中图像的实际情况来确认这一点。

    此外,在定义 tf_dataset 函数时,你需要确保 output_shapes 参数正确地反映了图像的形状。由于模型期望的图像形状是 (224, 224, 3),所以你应该将 output_shapes 设置为这个形状。以下是修正后的 tf_dataset 函数:

    def tf_dataset(dataset, batch_size):
        def generator():
            for img, label in dataset:
                yield tf.convert_to_tensor(img, dtype=tf.float32), label
    
        ds = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.int32),
                                           output_shapes=((224, 224, 3)))  # 修改这里以匹配模型期望的形状
        ds = ds.batch(batch_size)
        return ds
    

    请检查你的数据加载和处理过程,确保图像数据在进入模型之前已经是正确的 RGB 格式,并且形状与模型期望的形状相匹配。如果问题仍然存在,你可能需要检查数据加载和处理的每一步,以确保图像数据的处理过程没有出错。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 7月17日

悬赏问题

  • ¥15 35114 SVAC视频验签的问题
  • ¥15 impedancepy
  • ¥15 在虚拟机环境下完成以下,要求截图!
  • ¥15 求往届大挑得奖作品(ppt…)
  • ¥15 如何在vue.config.js中读取到public文件夹下window.APP_CONFIG.API_BASE_URL的值
  • ¥50 浦育平台scratch图形化编程
  • ¥20 求这个的原理图 只要原理图
  • ¥15 vue2项目中,如何配置环境,可以在打完包之后修改请求的服务器地址
  • ¥20 微信的店铺小程序如何修改背景图
  • ¥15 UE5.1局部变量对蓝图不可见