yunhen_pei 2022-06-07 10:28
Closed

Error when running train_tripletloss.py from facenet

Symptoms and background

Running train_tripletloss.py from facenet fails with the error below. What is causing it?

Relevant code
import argparse

def parse_arguments(argv):
    parser = argparse.ArgumentParser()

    parser.add_argument('--logs_base_dir', type=str,
        help='Directory where to write event logs.', default='~/logs/facenet')
    parser.add_argument('--models_base_dir', type=str,
        help='Directory where to write trained models and checkpoints.', default='~/models/facenet')
    parser.add_argument('--gpu_memory_fraction', type=float,
        help='Upper bound on the amount of GPU memory that will be used by the process.', default=0.7)
    parser.add_argument('--pretrained_model', type=str,
        help='Load a pretrained model before training starts.',
        default='C:/360Downloads/20200505-085843')
    parser.add_argument('--data_dir', type=str,
        help='Path to the data directory containing aligned face patches.',
        # default='~/datasets/casia/casia_maxpy_mtcnnalign_182_160')
        default='C:/360Downloads/DL/mask')
    parser.add_argument('--model_def', type=str,
        help='Model definition. Points to a module containing the definition of the inference graph.',
        default='models.inception_resnet_v1')
    parser.add_argument('--max_nrof_epochs', type=int,
        help='Number of epochs to run.', default=100)
    parser.add_argument('--batch_size', type=int,
        help='Number of images to process in a batch.', default=32)
    parser.add_argument('--image_size', type=int,
        help='Image size (height, width) in pixels.', default=160)
    parser.add_argument('--people_per_batch', type=int,
        help='Number of people per batch.', default=15)
    parser.add_argument('--images_per_person', type=int,
        help='Number of images per person.', default=20)
    parser.add_argument('--epoch_size', type=int,
        help='Number of batches per epoch.', default=100)
    parser.add_argument('--alpha', type=float,
        help='Positive to negative triplet distance margin.', default=0.2)
    parser.add_argument('--embedding_size', type=int,
        help='Dimensionality of the embedding.', default=128)
    parser.add_argument('--random_crop',
        help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' +
         'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true')
    parser.add_argument('--random_flip',
        help='Performs random horizontal flipping of training images.', action='store_true')
    parser.add_argument('--keep_probability', type=float,
        help='Keep probability of dropout for the fully connected layer(s).', default=0.5)
    parser.add_argument('--weight_decay', type=float,
        help='L2 weight regularization.', default=1e-2)
    parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'],
        help='The optimization algorithm to use', default='ADAGRAD')
    parser.add_argument('--learning_rate', type=float,
        help='Initial learning rate. If set to a negative value a learning rate ' +
        'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.1)
    parser.add_argument('--learning_rate_decay_epochs', type=int,
        help='Number of epochs between learning rate decay.', default=10)
    parser.add_argument('--learning_rate_decay_factor', type=float,
        help='Learning rate decay factor.', default=0.99)
    parser.add_argument('--moving_average_decay', type=float,
        help='Exponential decay for tracking of training parameters.', default=0.9999)
    parser.add_argument('--seed', type=int,
        help='Random seed.', default=666)
    parser.add_argument('--learning_rate_schedule_file', type=str,
        help='File containing the learning rate schedule that is used when learning_rate is set to -1.',
        default='C:/360Downloads/facenet-master/data/learning_rate_schedule_classifier_casia.txt')
    return parser.parse_args(argv)
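The failing line in the traceback below reshapes `embeddings` to `[-1, 3, args.embedding_size]`, so an unsuitable `--batch_size` can be caught right after parsing rather than deep inside `sess.run`. This is a minimal sketch, assuming (as the 4096 = 32 × 128 values in the error suggest) that each full training batch holds `batch_size` images. The helper `check_triplet_batch_size` is hypothetical and not part of facenet, and only the two relevant arguments are re-declared here. It uses `str.format` so it also runs on the older Python 3 versions used with TensorFlow 1.7.

```python
import argparse


def check_triplet_batch_size(batch_size, embedding_size):
    """Hypothetical helper: fail early instead of inside tf.reshape.

    The training graph reshapes the embeddings to [-1, 3, embedding_size],
    which only works when every batch holds a whole number of triplets.
    """
    if batch_size % 3 != 0:
        raise ValueError(
            "--batch_size={} is not a multiple of 3; a batch of {} embedding "
            "values cannot be reshaped to [-1, 3, {}]".format(
                batch_size, batch_size * embedding_size, embedding_size))


def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    # Only the two arguments involved in the reshape are re-declared here;
    # the defaults mirror the values posted in the question.
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--embedding_size', type=int, default=128)
    args = parser.parse_args(argv)
    check_triplet_batch_size(args.batch_size, args.embedding_size)
    return args
```

With `--batch_size 90` (or any other multiple of 3) parsing succeeds; with the posted default of 32 the helper raises a `ValueError` mentioning the same 4096 values before any training starts.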

Output and error message
Traceback (most recent call last):
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1327, in _do_call
    return fn(*args)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1312, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1420, in _call_tf_sessionrun
    status, run_metadata)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 516, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 4096 values, but the requested shape requires a multiple of 384
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embeddings, Reshape/shape)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/云痕/Desktop/understand_facenet-master/understand_facenet/train_tripletloss.py", line 515, in <module>
    main(parse_arguments(sys.argv[1:]))
  File "C:/Users/云痕/Desktop/understand_facenet-master/understand_facenet/train_tripletloss.py", line 184, in main
    args.embedding_size, anchor, positive, negative, triplet_loss)
  File "C:/Users/云痕/Desktop/understand_facenet-master/understand_facenet/train_tripletloss.py", line 262, in train
    err, _, step, emb, lab = sess.run([loss, train_op, global_step, embeddings, labels_batch], feed_dict=feed_dict)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 905, in run
    run_metadata_ptr)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1140, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1321, in _do_run
    run_metadata)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\client\session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 4096 values, but the requested shape requires a multiple of 384
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embeddings, Reshape/shape)]]

Caused by op 'Reshape', defined at:
  File "C:/Users/云痕/Desktop/understand_facenet-master/understand_facenet/train_tripletloss.py", line 515, in <module>
    main(parse_arguments(sys.argv[1:]))
  File "C:/Users/云痕/Desktop/understand_facenet-master/understand_facenet/train_tripletloss.py", line 128, in main
    anchor, positive, negative = tf.unstack(tf.reshape(embeddings, [-1,3,args.embedding_size]), 3, 1)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 5781, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
    op_def=op_def)
  File "C:\Anaconda3\envs\tf1.7\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 4096 values, but the requested shape requires a multiple of 384
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embeddings, Reshape/shape)]]
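In the failing line `anchor, positive, negative = tf.unstack(tf.reshape(embeddings, [-1,3,args.embedding_size]), 3, 1)`, the embedding values are regrouped into blocks of three rows, and each block is split into an anchor, a positive and a negative. The pure-Python sketch below only illustrates that regrouping; it is not TensorFlow code, `split_triplets` is a hypothetical name, and it presumes the input pipeline orders the rows as anchor, positive, negative, anchor, and so on. It raises in exactly the case where `tf.reshape` reports the InvalidArgumentError above.

```python
def split_triplets(embeddings, embedding_size):
    """Illustrate tf.unstack(tf.reshape(embeddings, [-1, 3, embedding_size]), 3, 1).

    `embeddings` is a list of rows of floats. Like tf.reshape, the values
    are flattened first and then regrouped, so only the total count matters.
    Returns (anchors, positives, negatives), each a list of rows.
    """
    flat = [value for row in embeddings for value in row]
    block = 3 * embedding_size  # values needed for one (anchor, positive, negative)
    if len(flat) % block != 0:
        raise ValueError(
            "Input to reshape is a tensor with {} values, but the requested "
            "shape requires a multiple of {}".format(len(flat), block))
    triplets = [flat[i:i + block] for i in range(0, len(flat), block)]
    anchors = [t[:embedding_size] for t in triplets]
    positives = [t[embedding_size:2 * embedding_size] for t in triplets]
    negatives = [t[2 * embedding_size:] for t in triplets]
    return anchors, positives, negatives
```

Feeding it 32 rows of 128 values reproduces the numbers in the error (4096 values versus a multiple of 384), while 30 rows split cleanly into 10 triplets.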


1 answer

  • 爱晚乏客游 2022-06-07 16:49

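    embeddings has 4096 values, while the requested shape needs a multiple of 3 × args.embedding_size = 384. Since 4096 is not divisible by 384, the reshape cannot be done. You have probably changed some parameter; change it back and try again.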


    This answer was selected by the asker as the best answer.
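The numbers in the error can be checked by hand. Assuming `embedding_size = 128`, as in the posted arguments, the 4096 values correspond to 32 embeddings, which matches the posted `batch_size` of 32, and 32 rows cannot be grouped into whole triplets:

```latex
\begin{aligned}
4096 &= 32 \times 128 \quad (\text{batch\_size} \times \text{embedding\_size}) \\
384 &= 3 \times 128 \\
\frac{4096}{384} &= \frac{32}{3} \approx 10.67 \notin \mathbb{Z}
\end{aligned}
```

Under that assumption, any `batch_size` with `batch_size mod 3 = 0` (for example 30 or 33) makes the total a multiple of 384, so the reshape can succeed.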


Question events

  • Closed by the system on June 15
  • Answer accepted on June 7
  • Question created on June 7
