import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
def preprocess(x, y):
    """Turn one (image, label) example into model-ready tensors.

    Applied element-wise through ``Dataset.map`` in this script.

    Args:
        x: image tensor; presumably uint8 pixels in [0, 255] (TODO confirm --
            the script prints ``x.min()``/``x.max()`` at load time).
        y: integer class label, possibly with extra size-1 dimensions.

    Returns:
        ``(image, label)``: ``image`` is ``x`` as float32 divided by 255;
        ``label`` is a depth-10 one-hot vector built from ``y``.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    # Drop size-1 dims, cast to int32 (required by one_hot), then encode.
    label = tf.one_hot(tf.cast(tf.squeeze(y), dtype=tf.int32), depth=10)
    return image, label
batchsize = 128

# Keras returns (train, test) splits as NumPy arrays. The labels presumably
# carry a trailing size-1 axis, which is why preprocess squeezes them.
(x, y), (x_test, y_test) = datasets.cifar10.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-example preprocessing, shuffle, then batch.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)

# Test pipeline (no shuffling). This must map over test_db itself: mapping
# over train_db -- already preprocessed and batched -- evaluated on training
# data, re-applied preprocess to whole batches and batched them again, which
# caused the "logits and labels must be broadcastable" error at validation.
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batchsize)
class MyDense(layers.Layer):
    """Custom fully connected layer computing ``inputs @ kernel``.

    Args:
        inp_dim: size of the last dimension of the inputs.
        outp_dim: number of output units.
        use_bias: if True, also create a trainable bias ``b`` of shape
            ``[outp_dim]`` (zero-initialized) and add it in ``call``.
            Defaults to False, which keeps the original bias-free layer.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, inp_dim, outp_dim, use_bias=False, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        # add_weight is the supported API. add_variable is a deprecated alias
        # whose positional signature changed in newer Keras releases.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        self.bias = (
            self.add_weight(name='b', shape=[outp_dim], initializer='zeros')
            if use_bias
            else None
        )

    def call(self, inputs, training=None):
        """Return ``inputs @ kernel``, plus the bias when enabled.

        ``training`` is accepted for Keras compatibility but unused.
        """
        x = inputs @ self.kernel
        if self.bias is not None:
            x = x + self.bias
        return x
class MyNetwork(keras.Model):
    """Five-layer fully connected classifier built from MyDense layers.

    Layer widths: 32*32*3 -> 256 -> 128 -> 64 -> 32 -> 10. ReLU follows each
    of the first four layers. The last layer has no activation, so ``call``
    returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = MyDense(32*32*3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Flatten ``inputs`` to ``[-1, 3072]`` and return 10-way logits.

        ``training`` is accepted for Keras compatibility but unused.
        """
        out = tf.reshape(inputs, [-1, 32*32*3])
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            out = tf.nn.relu(hidden(out))
        return self.fc5(out)
network = MyNetwork()
network.compile(
    # `learning_rate` is the supported keyword. `lr` is a deprecated alias
    # that newer Keras versions no longer accept.
    optimizer=optimizers.Adam(learning_rate=0.001),
    # MyNetwork's last layer has no activation (raw logits), so the loss
    # applies softmax itself.
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
报错信息:
File "E:/python/TF2.0/study_bili/CIFAR10.py", line 66, in <module>
network.fit(train_db,epochs=5,validation_data=test_db,validation_freq=1)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1133, in fit
return_dict=True)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1379, in evaluate
tmp_logs = test_function(iterator)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\def_function.py", line 846, in _call
return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 1848, in _filtered_call
cancellation_manager=cancellation_manager)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 1924, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 550, in call
ctx=ctx)
File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[16384,10] labels_size=[163840,10]
[[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at E:/python/TF2.0/study_bili/CIFAR10.py:66) ]] [Op:__inference_test_function_1742]
Function call stack:
test_function