qq_43008587
qq_43008587
采纳率33.3%
2020-04-30 19:50

不同数据集训练同一个CNN网络报错TypeError,如何解决?

20

最近在自己试着运行这个Flood-filling Networks算法,是全脑图像分割算法,基于CNN,使用Tensorflow,先附上链接:

https://github.com/google/ffn

作者给的用于训练推理的一个样本数据集是FIB-25数据集,具体来说就是果蝇的切片脑图像,用于训练的数据集的维度是520*520*520(灰度矩阵和标签都是这个维度),我使用作者给的这个数据集运行十分顺利,但是在换用另一个数据集:CREMI数据集(也是果蝇的脑切片,维度是125*1250*1250)进行网络的训练的时候,报错如下:

Traceback (most recent call last):
  File "train.py", line 739, in <module>
    app.run(main)
  File "/home/.local/lib/python3.6/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/.local/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "train.py", line 730, in main
    **json.loads(FLAGS.model_args))
  File "train.py", line 624, in train_ffn
    load_data_ops = define_data_input(model, queue_batch=1)
  File "train.py", line 412, in define_data_input
    0]))
  File "/home/.local/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2649, in equal
    "Equal", x=x, y=y, name=name)
  File "/home/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "/home/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'x' has DataType uint64 not in list of allowed values: bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, quint8, qint8, qint32, string, bool, complex128

看上去问题的核心在于train.py中的

    **json.loads(FLAGS.model_args))

这一行,但是这里的参数FLAGS.model_args是在训练模型的时候给定的参数,不管使用哪个数据集都是一样的,我把运行训练文件的代码贴过来:

python train.py \
--train_coords third_party/neuroproof_examples/validation_sample/tf_record_file/tf_record.tfrecords  \ #在之前的运行中生成的tfrecords文件
--data_volumes validation1:third_party/neuroproof_examples/validation_sample/grayscale_maps.h5:raw \ #灰度矩阵
--label_volumes validation1:third_party/neuroproof_examples/validation_sample/groundtruth.h5:stack \ #标签矩阵
--model_name convstack_3d.ConvStack3DFFNModel \
--model_args '{"depth": 12, "fov_size": [33, 33, 33], "deltas": [8, 8, 8]}'  \ #模型参数
--image_mean 128  \
--image_stddev 33

我怎么也想不通为什么对不同的数据集,json在解码相同的model_args时会报错……?是因为CREMI数据集的三个维数不一致,我依然沿用相同的参数导致的吗(FIB-25数据集维度是520*520*520,CREMI数据集维度是125*1250*1250)?

有些茫然,感谢大家的帮助。

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 复制链接分享
  • 邀请回答

1条回答

相关推荐