杰少123 2020-07-28 10:50 Acceptance rate: 75%
Accepted

InvalidArgumentError during `model.fit` training: logits and labels must have the same first dimension, got logits shape [20,40] and labels shape [800]

Training with `model.fit` raises this error:

```
InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [20,40] and labels shape [800] [[node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at :1) ]] [Op:__inference_train_function_2847]
```

The code:

```python
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Character set used to generate the captchas
# CHAR_SET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'I', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
#             'W', 'X', 'Y', 'Z', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0']
CHAR_SET=['0','1','2','3','4','5','6','7','8','9']
CHAR_SET_LEN = len(CHAR_SET)
# Captcha length (number of characters per image)
CAPTCHA_LEN = 4

# Encode the label text as a flat one-hot vector of length CAPTCHA_LEN * CHAR_SET_LEN
def text2label(text):
    label = np.zeros(CAPTCHA_LEN * CHAR_SET_LEN)
    for i in range(len(text)):
        idx = i * CHAR_SET_LEN + CHAR_SET.index(text[i])
        label[idx] = 1
    return label

def decode_image_and_resize(filename,labels):
    image_string=tf.io.read_file(filename)
    image_decoded=tf.image.decode_jpeg(image_string)
    image_resize=tf.image.resize(image_decoded,[60,160])/255.0
    return image_resize,labels

# Get the image file paths and their labels
def read_images_filename(filetype):
    path='./data/captcha/'
    image_path=path+filetype

    image_name=tf.constant([image_path+ fn for fn in os.listdir(image_path)])
    labels=tf.constant([float(fn.split('_')[0]) for fn in os.listdir(image_path)],tf.float32)
    labels_one_hot=tf.constant([text2label(fn.split('_')[0]) for fn in os.listdir(image_path)],tf.float32)
    return image_name,labels_one_hot,labels


# Build the tf.data dataset
def prepare_dataset(data_dir):
    buffer_size = 2000
    batch_size = 20
    filename, y_train, yuan_train = read_images_filename(data_dir)

    dataset = tf.data.Dataset.from_tensor_slices((filename, y_train))
    dataset = dataset.map(map_func=decode_image_and_resize,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset, yuan_train

model=keras.Sequential()

model.add(keras.layers.Conv2D(filters=32,
                              kernel_size=(3,3),
                              input_shape=(60,160,3),
                              activation='relu',
                              padding='same'))

model.add(keras.layers.Conv2D(filters=32,
                              kernel_size=(3,3),
                              activation='relu',
                              padding='same'))
model.add(keras.layers.MaxPool2D(pool_size=(2,2)))

model.add(keras.layers.Dropout(rate=0.3))
model.add(keras.layers.Conv2D(filters=64,
                              kernel_size=(3,3),
                              activation='relu',
                              padding='same'))

model.add(keras.layers.Conv2D(filters=64,
                              kernel_size=(3,3),
                              activation='relu',
                              padding='same'))
model.add(keras.layers.MaxPool2D(pool_size=(2,2)))
model.add(keras.layers.Dropout(rate=0.3))

model.add(keras.layers.Flatten())

model.add(keras.layers.Dense(128,activation='relu'))

model.add(tf.keras.layers.Dense(40,activation='softmax'))

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

path='train/'
dataset_train,yuan_train=prepare_dataset(path)
history=model.fit(dataset_train,epochs=10,verbose=1)
```

Here are the shapes of the data:

![Data shapes](https://img-ask.csdn.net/upload/202007/28/1595903969_107088.jpg)
`labels.shape` changed from [20, 40] to [800].
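
The jump from [20, 40] to [800] can be reproduced with plain NumPy. The sketch below is a minimal illustration, assuming that `sparse_categorical_crossentropy` expects integer class ids and, when the label rank is not exactly one less than the prediction rank, flattens the labels and collapses the predictions to `(-1, num_classes)` (this is my reading of the TensorFlow 2.x Keras backend, not verified against every version). `sparse_loss_shapes` and `to_char_indices` are hypothetical helper names used only for illustration. With 20 flat one-hot labels of length 40, flattening gives 800 "labels" against 20 prediction rows, which is the mismatch in the error; converting each label to four per-character digit ids and reshaping the predictions to (20, 4, 10) makes the shapes line up.

```python
import numpy as np

# Values mirroring the question's code: 4-digit captchas over the digits 0-9.
CHAR_SET = [str(d) for d in range(10)]
CHAR_SET_LEN = len(CHAR_SET)
CAPTCHA_LEN = 4
BATCH_SIZE = 20


def text2label(text):
    """Flat one-hot label as in the question, shape (CAPTCHA_LEN * CHAR_SET_LEN,)."""
    label = np.zeros(CAPTCHA_LEN * CHAR_SET_LEN, dtype=np.float32)
    for i, ch in enumerate(text):
        label[i * CHAR_SET_LEN + CHAR_SET.index(ch)] = 1.0
    return label


def sparse_loss_shapes(labels, predictions):
    """Hypothetical helper. Assumption: like TF 2.x's sparse categorical
    cross-entropy, flatten the labels (and collapse the predictions to
    (-1, num_classes)) when labels.ndim != predictions.ndim - 1."""
    if labels.ndim != predictions.ndim - 1:
        labels = labels.reshape(-1)
        predictions = predictions.reshape(-1, predictions.shape[-1])
    return labels.shape, predictions.shape


def to_char_indices(flat_one_hot):
    """Hypothetical helper: (batch, 40) flat one-hot -> (batch, 4) digit ids."""
    per_char = flat_one_hot.reshape(-1, CAPTCHA_LEN, CHAR_SET_LEN)
    return per_char.argmax(axis=-1)


# One batch as in the question: 20 flat one-hot labels, 20 softmax rows of 40 units.
texts = ["0123", "4567", "8901", "2345"] * 5
labels = np.stack([text2label(t) for t in texts])                   # (20, 40)
preds = np.full((BATCH_SIZE, CAPTCHA_LEN * CHAR_SET_LEN), 1.0 / 40)  # (20, 40)

# 800 flattened labels vs 20 prediction rows: the first dimensions disagree.
print(sparse_loss_shapes(labels, preds))  # -> ((800,), (20, 40))

# Per-character integer labels (20, 4) vs predictions reshaped to (20, 4, 10):
# the label shape equals the prediction shape minus its class axis.
idx = to_char_indices(labels)
preds_per_char = preds.reshape(BATCH_SIZE, CAPTCHA_LEN, CHAR_SET_LEN)
print(sparse_loss_shapes(idx, preds_per_char))  # -> ((20, 4), (20, 4, 10))
```

In Keras terms, one way to get predictions of shape (batch, 4, 10) would be to end the model with `keras.layers.Dense(CAPTCHA_LEN * CHAR_SET_LEN)`, `keras.layers.Reshape((CAPTCHA_LEN, CHAR_SET_LEN))` and `keras.layers.Softmax()`, and to feed each image an integer label of shape (4,), e.g. `[CHAR_SET.index(c) for c in text]`, instead of the flat one-hot vector. Keeping the 40-unit output and one-hot labels but switching the loss to `categorical_crossentropy` also makes the shapes agree, although a single softmax over all 40 units then has to split its probability mass across the four characters.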



I'd like to ask the experts here for help; I've only just started with this and don't really understand it yet.

  • threenewbee 2020-07-28 14:41
    This answer was selected by the asker as the best answer.