machine-think 2021-04-24 22:16 Acceptance rate: 25%

Training CIFAR10 with a custom network in TF 2.0

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

def preprocess(x,y):
    y = tf.squeeze(y)
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    x = tf.cast(x,dtype=tf.float32) / 255.

    return x,y

batchsize = 128
(x,y),(x_test,y_test) = datasets.cifar10.load_data()

print('datasets:',x.shape,y.shape,x.min(),x.max())

train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(batchsize)


class MyDense(layers.Layer):
    # Custom layer
    def __init__(self,inp_dim,outp_dim):
        super(MyDense,self).__init__()

        # add_weight replaces the deprecated add_variable alias
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        # self.bias = self.add_weight(name='b', shape=[outp_dim])

    def call(self,inputs,training=None):
        x = inputs @ self.kernel
        return x

class MyNetwork(keras.Model):
    def __init__(self):
        super(MyNetwork,self).__init__()

        self.fc1 = MyDense(32*32*3,256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self,inputs,training=None):
        x = tf.reshape(inputs,[-1,32*32*3])
        x = self.fc1(x)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        return x

network = MyNetwork()
network.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

network.fit(train_db,epochs=5,validation_data=test_db,validation_freq=1)


network.evaluate(test_db)

Error message:

 File "E:/python/TF2.0/study_bili/CIFAR10.py", line 66, in <module>
    network.fit(train_db,epochs=5,validation_data=test_db,validation_freq=1)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1133, in fit
    return_dict=True)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1379, in evaluate
    tmp_logs = test_function(iterator)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\def_function.py", line 846, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\function.py", line 550, in call
    ctx=ctx)
  File "D:\Anaconda3\envs\TF\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[16384,10] labels_size=[163840,10]
	 [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at E:/python/TF2.0/study_bili/CIFAR10.py:66) ]] [Op:__inference_test_function_1742]

Function call stack:
test_function
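
As a rough aid for reading the two sizes in the error above, the sketch below tracks tensor shapes through the data pipeline with plain Python tuples. The helper names (`preprocess_shape`, `batch_shape`, `flat_rows`) are hypothetical and only model shapes; they do not call TensorFlow, and they assume every batch is a full batch of 128.

```python
from math import prod

def preprocess_shape(x_shape, y_shape, depth=10):
    """Model the shape effect of preprocess(): squeeze y, then one-hot it."""
    y_squeezed = tuple(d for d in y_shape if d != 1)  # tf.squeeze drops size-1 axes
    return x_shape, y_squeezed + (depth,)             # tf.one_hot appends a depth axis

def batch_shape(x_shape, y_shape, n):
    """Model the shape effect of .batch(n): prepend a batch axis."""
    return (n,) + x_shape, (n,) + y_shape

def flat_rows(shape, last):
    """Number of rows left when a tensor is flattened to [-1, last]."""
    return prod(shape) // last

# One CIFAR10 example: image [32, 32, 3], label [1].
once = batch_shape(*preprocess_shape((32, 32, 3), (1,)), 128)
# The question builds test_db by applying map(preprocess).batch(128) to train_db again.
twice = batch_shape(*preprocess_shape(*once), 128)

print(once)   # ((128, 32, 32, 3), (128, 10))
print(twice)  # ((128, 128, 32, 32, 3), (128, 128, 10, 10))
print(flat_rows(twice[0], 32 * 32 * 3), flat_rows(twice[1], 10))  # 16384 163840
```

The last line gives the same two row counts as `logits_size` and `labels_size` in the error, which suggests the validation data went through `preprocess` and `batch` twice.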

1 answer

  • 半调子全栈 2023-04-22 16:48

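    The error message shows that logits_size=[16384,10] and labels_size=[163840,10] do not match, so the two tensors cannot be broadcast against each other, which raises InvalidArgumentError.

    Specifically, logits_size is the shape of the model output and labels_size is the shape of the labels, both flattened to two dimensions. Here 16384 = 128*128 and 163840 = 128*128*10, which means the validation data was preprocessed and batched twice. The cause is that test_db is built from train_db, which has already gone through preprocess() and batch(128). The second map(preprocess) applies tf.one_hot() to labels that are already one-hot, turning [128, 10] into [128, 10, 10], and the second batch(128) stacks 128 of those batches. The images end up with shape [128, 128, 32, 32, 3] and the labels with shape [128, 128, 10, 10]. The model reshapes the images to [16384, 3072] and outputs [16384, 10], while the loss flattens the labels to [163840, 10].

    To fix this, build test_db from the test data. In the data pipeline, change the following line:

    test_db = train_db.map(preprocess).batch(batchsize)

    to:

    test_db = test_db.map(preprocess).batch(batchsize)

    The one-hot encoding in preprocess() should stay, because CategoricalCrossentropy expects one-hot labels. With this change the labels have shape [batch_size, 10], the same as the model output, and the error goes away.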
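
The mismatch in the error message can be reproduced on a small scale. The following is a minimal sketch, assuming TensorFlow 2.x and NumPy are installed. It uses 16 random images as a hypothetical stand-in for CIFAR10 and a batch size of 4 instead of 128; the names `buggy_db`, `fixed_db` and `first_batch_shapes` are illustrative and do not appear in the question.

```python
import numpy as np
import tensorflow as tf

def preprocess(x, y):
    # Same preprocessing as in the question.
    y = tf.squeeze(y)
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    x = tf.cast(x, dtype=tf.float32) / 255.
    return x, y

batchsize = 4  # assumption: small batch so the demo stays tiny
# Hypothetical stand-in data with CIFAR10's shapes: images [N, 32, 32, 3], labels [N, 1].
x = np.random.randint(0, 256, size=(16, 32, 32, 3)).astype(np.uint8)
y = np.random.randint(0, 10, size=(16, 1)).astype(np.uint8)

train_db = tf.data.Dataset.from_tensor_slices((x, y)).map(preprocess).batch(batchsize)

# As in the question: built from train_db, so preprocess and batch run twice.
buggy_db = train_db.map(preprocess).batch(batchsize)
# Built from the raw arrays: preprocess and batch run once.
fixed_db = tf.data.Dataset.from_tensor_slices((x, y)).map(preprocess).batch(batchsize)

def first_batch_shapes(ds):
    # Pull one batch eagerly and return its concrete (image, label) shapes.
    xb, yb = next(iter(ds))
    return tuple(xb.shape.as_list()), tuple(yb.shape.as_list())

buggy_shapes = first_batch_shapes(buggy_db)
fixed_shapes = first_batch_shapes(fixed_db)
print(buggy_shapes)  # ((4, 4, 32, 32, 3), (4, 4, 10, 10))
print(fixed_shapes)  # ((4, 32, 32, 3), (4, 10))
```

With batchsize = 128 and the real dataset, the same stacking gives 128*128 = 16384 logit rows against 128*128*10 = 163840 label rows, which are the two sizes in the error message.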
