Error description:
When loading a pretrained neural network that uses several custom losses with tensorflow.keras, Visual Studio Code reports the following error:
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "D:\ProgramFiles\Anaconda3\envs\MatlabEngine\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "D:\ProgramFiles\Anaconda3\envs\MatlabEngine\lib\site-packages\keras\dtensor\utils.py", line 144, in _wrap_function
    init_method(instance, *args, **kwargs)
TypeError: __init__() got an unexpected keyword argument 'reduction'
The code used to load the model:
M = tf.keras.models.load_model('D:/MODEL/my_model.h5', {'user_loss1': user_loss1, 'user_loss2': user_loss2})
The custom loss functions are defined as follows:
import tensorflow as tf

# Define the loss functions
def user_loss1(y_true, y_pred):
    loss = tf.keras.backend.mean(tf.keras.backend.std(y_true - y_pred)) / tf.keras.backend.std(tf.keras.backend.square(y_true))
    return loss

def user_loss2(y_true, y_pred):
    loss = tf.keras.backend.std(y_true - y_pred)
    return loss
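For reference, here is a plain-Python sketch of what the two losses compute, assuming y_true and y_pred are flattened to equal-length lists (the `_ref` function names are hypothetical). `tf.keras.backend.std` defaults to `axis=None` and returns the population standard deviation over every element, so each loss is a single scalar for the whole batch and the outer `tf.keras.backend.mean` in user_loss1 only ever sees a scalar. `statistics.pstdev` computes the same population statistic.

```python
import statistics


def user_loss1_ref(y_true, y_pred):
    # std of the residual over all elements (axis=None in Keras),
    # divided by the std of the squared targets; the Keras mean() is a no-op here.
    residual = [t - p for t, p in zip(y_true, y_pred)]
    return statistics.pstdev(residual) / statistics.pstdev([t * t for t in y_true])


def user_loss2_ref(y_true, y_pred):
    # Population std of the residual over all elements.
    return statistics.pstdev([t - p for t, p in zip(y_true, y_pred)])
```

Because both values depend only on the spread of the residual, a constant offset between y_true and y_pred yields a loss of 0.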
Environment:
# Input:
Python --version
conda list tensorflow
# Output:
Python 3.9.16
# packages in environment at D:\ProgramFiles\Anaconda3\envs\MatlabEngine:
#
# Name Version Build Channel
tensorflow 2.11.1 pypi_0 pypi
tensorflow-estimator 2.11.0 pypi_0 pypi
tensorflow-intel 2.11.1 pypi_0 pypi
tensorflow-io-gcs-filesystem 0.31.0 pypi_0 pypi
Testing:
When only user_loss1 (or only user_loss2) is passed in custom_objects, this error goes away, but the following error appears instead:
M = tf.keras.models.load_model('D:/MODEL/my_model.h5', {'user_loss1': user_loss1})
# Output
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "D:\ProgramFiles\Anaconda3\envs\MatlabEngine\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "D:\ProgramFiles\Anaconda3\envs\MatlabEngine\lib\site-packages\keras\saving\legacy\serialization.py", line 557, in deserialize_keras_object
    raise ValueError(
ValueError: Unknown loss function: 'user_loss2'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
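Both errors are raised while Keras deserializes the saved loss configuration. One variant that could be tried (a sketch, untested against this particular my_model.h5) is to load with `compile=False`, which restores the architecture and weights but skips restoring the saved training configuration, so the stored loss entries are never deserialized, and then compile again with the loss functions passed directly. Assumptions, all hypothetical: the helper name `load_and_recompile`, that the model has two outputs trained with user_loss1 and user_loss2 in that order, and that `"adam"` stands in for the optimizer actually used.

```python
import tensorflow as tf


# Custom losses, copied from above.
def user_loss1(y_true, y_pred):
    return tf.keras.backend.mean(tf.keras.backend.std(y_true - y_pred)) / tf.keras.backend.std(
        tf.keras.backend.square(y_true)
    )


def user_loss2(y_true, y_pred):
    return tf.keras.backend.std(y_true - y_pred)


def load_and_recompile(path, optimizer="adam"):
    """Hypothetical helper: load without the saved training config, then recompile.

    Assumes two model outputs trained with user_loss1 and user_loss2, in that
    order, and uses "adam" as a stand-in for the original optimizer.
    """
    # compile=False: architecture + weights only; saved losses are not deserialized.
    model = tf.keras.models.load_model(path, compile=False)
    # Pass the functions themselves, so no name lookup is needed.
    model.compile(optimizer=optimizer, loss=[user_loss1, user_loss2])
    return model
```

For the model in the question this would be `load_and_recompile('D:/MODEL/my_model.h5')`. The optimizer state is not restored this way, which matters only when resuming training; if the outputs are named, a dict keyed by output name can replace the list.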
Has anyone run into a similar error? How did you eventually solve it?