如下图所示,其功能是将代码封装进dataset并转化我可迭代格式,但是在执行预处理map()函数的时候报错:
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype uint8: 'Tensor("arg0:0", shape=(28, 28), dtype=uint8)'
然而在删除map后正常运行,说明不是转换格式的问题,求问各位大神这是为什么呢?
报错代码:
(x,y),(x_val,y_val)=datasets.mnist.load_data()
def trans(x,y):
x=tf.convert_to_tensor(x,dtype=tf.float32)
y=tf.convert_to_tensor(y,dtype=tf.int32)
y=tf.one_hot(y,depth=10)
return x,y
train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.map(trans)
train_db.shuffle(10000).batch(32)
正常运行:
(x,y),(x_val,y_val)=datasets.mnist.load_data()
x=tf.convert_to_tensor(x,dtype=tf.float32)
y=tf.convert_to_tensor(y,dtype=tf.int32)
y=tf.one_hot(y,depth=10)
train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.shuffle(10000).batch(32)