问题:用了网上大佬的方法训练自己的模型识别辛普森人物,但是每当进行到如下训练函数:
training = model.fit(train_gen, steps_per_epoch=len(x_train)//BATCH_SIZE, epochs=EPOCHS, validation_data=(x_val,y_val), validation_steps=len(y_val)//BATCH_SIZE, callbacks=callbacks_list)
就会报错,并提示我X和Y的形状不同。
我观察了一下,是一轮训练完成后报的错,我在此函数之前设置如下监测:print(x_train.shape) print(y_train.shape)
输出为:(11047, 80, 80, 1)和(11047, 10)。
此为本段上下文代码:
datagen = canaro.generators.imageDataGenerator()
x_train = np.array(x_train)
y_train = np.array(y_train)
train_gen = datagen.flow(x_train, y_train, batch_size=BATCH_SIZE)
#创建模型
model = canaro.models.createSimpsonsModel(IMG_SIZE=IMG_SIZE, channels=channels, output_dim=len(characters), loss='binary_crossentropy', decay=1e-6, learning_rate=0.001, momentum=0.9, nesterov=True)
# model.summary()输出模型摘要
#回调列表
from tensorflow.keras.callbacks import LearningRateScheduler
callbacks_list = [LearningRateScheduler(canaro.lr_schedule)]
print(x_train.shape)
print(y_train.shape)
#开始训练!
training = model.fit(train_gen, steps_per_epoch=len(x_train)//BATCH_SIZE, epochs=EPOCHS, validation_data=(x_val,y_val), validation_steps=len(y_val)//BATCH_SIZE, callbacks=callbacks_list)
以下为错误信息:
File "C:\Users\abner\Desktop\文件夹\编程\编程项目\python\opencv\高级课程3:opencv深度学习网络识别辛普森人物.py", line 92, in <module>
training = model.fit(train_gen, steps_per_epoch=len(x_train)//BATCH_SIZE, epochs=EPOCHS, validation_data=(x_val,y_val), validation_steps=len(y_val)//BATCH_SIZE, callbacks=callbacks_list)
File "C:\Users\abner\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1118, in fit
self._eval_data_handler = data_adapter.DataHandler(
File "C:\Users\abner\anaconda3\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 1100, in __init__
self._adapter = adapter_cls(
File "C:\Users\abner\anaconda3\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 274, in __init__
_check_data_cardinality(inputs)
File "C:\Users\abner\anaconda3\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 1529, in _check_data_cardinality
raise ValueError(msg)
初学者看不出来怎么改,不敢乱改,求大佬们帮帮忙!