「已注销」 2019-05-04 14:25 采纳率: 0%
浏览 3364

TensorFlow的Keras如何使用Dataset作为数据输入?

当我把dataset作为输入数据是总会报出如下错误,尽管我已经在数据解析那里reshape了图片大小为(512,512,1),请问该如何修改?

ValueError: Error when checking input: expected conv2d_input to have 4 dimensions, but got array with shape (None, 1)

图片大小定义

import tensorflow as tf
from  tensorflow import keras

IMG_HEIGHT = 512
IMG_WIDTH = 512
IMG_CHANNELS = 1
IMG_PIXELS = IMG_CHANNELS * IMG_HEIGHT * IMG_WIDTH

解析函数

def parser(record):
    features = tf.parse_single_example(record, features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([23], tf.int64)
    })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)

    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    image = tf.cast(image, tf.float32)

    return image, label

模型构建

dataset = tf.data.TFRecordDataset([TFRECORD_PATH])
dataset.map(parser)
dataset = dataset.repeat(10*10).batch(10)

model = keras.Sequential([
        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same', activation='relu', input_shape=(512, 512, 1)),
        keras.layers.MaxPool2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),
        keras.layers.MaxPool2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(23, activation='softmax')
    ])

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.sparse_categorical_crossentropy,
              metrics=[tf.keras.metrics.categorical_accuracy])

model.fit(dataset.make_one_shot_iterator(), epochs=10, steps_per_epoch=10)
  • 写回答

1条回答 默认 最新

  • hffgggggg 2019-08-03 00:46
    关注

    可能需要将最后一行代码修改成如下:

    model.fit(dataset.make_one_shot_iterator().get_next(), epochs=10, steps_per_epoch=10)
    

    或者这样:

    model.fit(dataset, epochs=10, steps_per_epoch=10)
    
    评论

报告相同问题?

悬赏问题

  • ¥88 找成都本地经验丰富懂小程序开发的技术大咖
  • ¥15 如何处理复杂数据表格的除法运算
  • ¥15 如何用stc8h1k08的片子做485数据透传的功能?(关键词-串口)
  • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
  • ¥200 uniapp长期运行卡死问题解决
  • ¥15 请教:如何用postman调用本地虚拟机区块链接上的合约?
  • ¥15 为什么使用javacv转封装rtsp为rtmp时出现如下问题:[h264 @ 000000004faf7500]no frame?
  • ¥15 乘性高斯噪声在深度学习网络中的应用
  • ¥15 关于docker部署flink集成hadoop的yarn,请教个问题 flink启动yarn-session.sh连不上hadoop,这个整了好几天一直不行,求帮忙看一下怎么解决
  • ¥15 深度学习根据CNN网络模型,搭建BP模型并训练MNIST数据集