「已注销」 2019-05-04 14:25 采纳率: 0%
浏览 3364

TensorFlow的Keras如何使用Dataset作为数据输入?

当我把dataset作为输入数据是总会报出如下错误,尽管我已经在数据解析那里reshape了图片大小为(512,512,1),请问该如何修改?

ValueError: Error when checking input: expected conv2d_input to have 4 dimensions, but got array with shape (None, 1)

图片大小定义

import tensorflow as tf
from  tensorflow import keras

IMG_HEIGHT = 512
IMG_WIDTH = 512
IMG_CHANNELS = 1
IMG_PIXELS = IMG_CHANNELS * IMG_HEIGHT * IMG_WIDTH

解析函数

def parser(record):
    features = tf.parse_single_example(record, features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([23], tf.int64)
    })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)

    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    image = tf.cast(image, tf.float32)

    return image, label

模型构建

dataset = tf.data.TFRecordDataset([TFRECORD_PATH])
dataset.map(parser)
dataset = dataset.repeat(10*10).batch(10)

model = keras.Sequential([
        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same', activation='relu', input_shape=(512, 512, 1)),
        keras.layers.MaxPool2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),
        keras.layers.MaxPool2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(23, activation='softmax')
    ])

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.sparse_categorical_crossentropy,
              metrics=[tf.keras.metrics.categorical_accuracy])

model.fit(dataset.make_one_shot_iterator(), epochs=10, steps_per_epoch=10)
  • 写回答

1条回答

  • hffgggggg 2019-08-03 00:46
    关注

    可能需要将最后一行代码修改成如下:

    model.fit(dataset.make_one_shot_iterator().get_next(), epochs=10, steps_per_epoch=10)
    

    或者这样:

    model.fit(dataset, epochs=10, steps_per_epoch=10)
    
    评论

报告相同问题?

悬赏问题

  • ¥15 素材场景中光线烘焙后灯光失效
  • ¥15 请教一下各位,为什么我这个没有实现模拟点击
  • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
  • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置
  • ¥15 有没有研究水声通信方面的帮我改俩matlab代码
  • ¥15 ubuntu子系统密码忘记
  • ¥15 保护模式-系统加载-段寄存器