椒盐玛奇朵 2021-03-02 21:27 采纳率: 0%
浏览 333
已结题

flow_from_directory生成图像的通道数如何放到第一维度?

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense,Input
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import VGG16, ResNet50
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.layers import Activation, Dropout, Flatten, Dense
import numpy as np
from keras import backend as K
from keras.models import Model
from keras.callbacks import EarlyStopping
import keras
from keras.models import load_model


# K.set_image_dim_format('th')
keras.backend.set_image_data_format('channels_first')

WEIGHTS_PATH = 'vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
img_width, img_height = 224,224

#model = VGG16(include_top=False, weights='imagenet')

input_tensor = Input(shape=(3, img_width, img_height)) # 当使用不包括top的VGG16时,要指定输入的shape,否则会报错
model = ResNet50(include_top=False, weights=None, input_tensor=input_tensor)
print('Model loaded.')
model.load_weights('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')


x = model.output
x = Flatten()(x)
x = Dense(256,activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(10, activation = 'softmax')(x)

model2 = Model(inputs=model.input, outputs=x)
#
#
# #model2 = load_model('mstar.h5')
#
#
for layer in model2.layers[:45]: # set the first 11 layers(fine tune conv4 and conv5 block can also further improve accuracy
    layer.trainable = False
model2.compile(loss='categorical_crossentropy',
              optimizer = SGD(lr=1e-3, momentum=0.9),#SGD(lr=1e-3,momentum=0.9)
              metrics=['accuracy'])
#
#
#
#
train_data_dir = 'D:\\workspace\\fzp\\SENSORS\\SARData\\MSTAR\MSTAR_Code\\ML-CV-master\\ML-CV-master\\MSTAR_ATR\\data\\train'
validation_data_dir =  'D:\\workspace\\fzp\\SENSORS\\SARData\\MSTAR\MSTAR_Code\\ML-CV-master\\ML-CV-master\\MSTAR_ATR\\data\\test'
#img_width, img_height = 128, 128
nb_train_samples = 2536
nb_validation_samples = 2636
epochs = 200
batch_size = 16
#
#
train_datagen = ImageDataGenerator(  # 图片数据生成器 用来扩充数据集大小,增强模型的泛化能力
        rescale=1./255,       # 将0-255值的图像映射到0-1
        shear_range=0.2,      # 剪切强度(逆时针方向的剪切变换角度
        rotation_range=10.,   # 图片随机转动的角度
        zoom_range=0.2,       # 随机缩放的幅度
        horizontal_flip=True)   # 进行随机水平翻转
#
test_datagen = ImageDataGenerator(rescale=1./255)
#
# # 图片generator
train_generator = train_datagen.flow_from_directory( # 以文件夹路径为参数,生成经过数据提升/归一化后的数据,在一个无限循环中无限产生batch数据
        train_data_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical')   # "categorical"会返回2D的one-hot编码标签
#
#
#
validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical')

early_stopping = EarlyStopping(monitor='val_loss', patience=3)
#
# #model2.load_weights('mstar.h5')
#
model2.fit_generator(
        train_generator,
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=nb_validation_samples // batch_size)
#
#

以上程序为用Res50Net训练识别MSTAR数据。这里使用了预训练+fine tune 只调整最后5层的参数,前45层参数不变。程序报错:ValueError: Error when checking input: expected input_1 to have shape (3, 224, 224) but got array with shape (224, 224, 3)

经查,flow_from_directory中的color参数默认时会转为rgb,但是通道数在最后,但是卷积层要求通道数在第一维,请问如何更改。这里全程看不到train和test的数据,无法使用reshape

  • 写回答

1条回答 默认 最新

  • 歇歇 2021-03-03 10:35
    关注

    img_width, img_height = 224,224 

    input_tensor = Input(shape=(3, img_width, img_height))

    已经指定了形状

     不能变成这样(224, 224, 3)

    评论

报告相同问题?

悬赏问题

  • ¥30 这是哪个作者做的宝宝起名网站
  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!