这段代码基本是照着 TensorFlow 官方教程（TFRecord 和 tf.train.Example）里的代码写的，应该完全一样，但运行时却报错了。
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` in a one-element BytesList.

    If ``value`` is an eager tensor (for example a string element obtained by
    iterating a ``tf.data.Dataset``), it is first unwrapped with ``.numpy()``
    so that the BytesList receives the raw bytes rather than the tensor object.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
def _float_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` in a one-element FloatList.

    ``value`` may be a Python/NumPy float or a scalar eager tensor, such as the
    elements yielded when iterating a ``tf.data.Dataset`` in eager mode.
    """
    if isinstance(value, type(tf.constant(0))):
        # Unwrap eager tensors to a NumPy scalar (same as _bytes_feature).
        # Depending on the installed protobuf implementation, passing the
        # EagerTensor object itself may be rejected with a TypeError.
        value = value.numpy()
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` in a one-element Int64List.

    ``value`` may be a Python/NumPy integer or bool, or a scalar eager tensor
    (as yielded when iterating a ``tf.data.Dataset`` in eager mode). Booleans
    are stored as 0/1.
    """
    if isinstance(value, type(tf.constant(0))):
        # Unwrap eager tensors to a NumPy scalar (same as _bytes_feature);
        # depending on the protobuf version, the tensor object itself may be
        # rejected by Int64List with a TypeError.
        value = value.numpy()
    if isinstance(value, (bool, np.bool_)):
        # NumPy bools are not integral types, so convert explicitly to 0/1.
        value = int(value)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def feature_to_string(feature):
    """Serialize a protocol-buffer message (e.g. a ``tf.train.Feature``) to bytes.

    Args:
        feature: any object exposing the protobuf ``SerializeToString`` method.

    Returns:
        The bytes returned by ``feature.SerializeToString()``.
    """
    # The protobuf method is SerializeToString (not "SerializerToString").
    return feature.SerializeToString()
# Synthetic data: four parallel feature columns of length n_boservations.
n_boservations = 10_000
feature0 = np.random.choice([False, True], n_boservations)  # bool column
feature1 = np.random.randint(0, 5, n_boservations)  # ints in [0, 5)
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]  # bytes label picked by the feature1 index
feature3 = np.random.randn(n_boservations)  # standard-normal floats
# Build a tf.train.Example from one observation.
def serialize_example(feature0, feature1, feature2, feature3):
    """Serialize one observation as a ``tf.train.Example`` byte string.

    Args:
        feature0: stored under key 'feature0' via ``_int64_feature``.
        feature1: stored under key 'feature1' via ``_int64_feature``.
        feature2: stored under key 'feature2' via ``_bytes_feature``.
        feature3: stored under key 'feature3' via ``_float_feature``.

    Returns:
        The result of ``SerializeToString()`` on the assembled Example.
    """
    features = tf.train.Features(
        feature={
            'feature0': _int64_feature(feature0),
            'feature1': _int64_feature(feature1),
            'feature2': _bytes_feature(feature2),
            'feature3': _float_feature(feature3),
        }
    )
    return tf.train.Example(features=features).SerializeToString()
# One dataset element per row: a 4-tuple sliced along the first axis of each
# of the feature arrays.
features_dataset = tf.data.Dataset.from_tensor_slices(
    tensors=(feature0, feature1, feature2, feature3)
)
def generator():
    """Yield ``serialize_example(*row)`` for each row of ``features_dataset``."""
    yield from (serialize_example(*row) for row in features_dataset)
# Turn the Python generator into a tf.data pipeline of scalar tf.string items.
# NOTE(review): newer TensorFlow releases deprecate output_types/output_shapes
# in favour of output_signature=tf.TensorSpec(shape=(), dtype=tf.string);
# confirm the installed version before switching.
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=tf.string,
    output_shapes=(),
)
# Write every serialized example to a TFRecord file.
# NOTE(review): tf.data.experimental.TFRecordWriter is deprecated in newer
# TensorFlow releases (tf.io.TFRecordWriter is the suggested replacement).
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename=filename)
writer.write(dataset=serialized_features_dataset)