MemoryError when using TensorFlow's tf.data Dataset API

I ran into a MemoryError while using TensorFlow's tf.data Dataset API, even though I am using the officially recommended placeholder approach. My system is Ubuntu 18.04, with TensorFlow 1.13.1 and Python 3.6. Code first:

def main(_):
  if FLAGS.self_test:

    train_data, train_labels = fake_data(256)
    validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
    test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
    num_epochs = 1
  else:

    stft_training, mfcc_training, labels_training = joblib.load(open(FLAGS.input, mode='rb'))
    stft_training = numpy.array(stft_training)
    mfcc_training = numpy.array(mfcc_training)
    labels_training = numpy.array(labels_training)
    stft_shape = stft_training.shape
    stft_shape = (None, stft_shape[1], stft_shape[2])

    mfcc_shape = mfcc_training.shape
    mfcc_shape = (None, mfcc_shape[1], mfcc_shape[2])

    labels_shape = labels_training.shape
    labels_shape = (None,)  # (None) is not a tuple; use (None,) for a 1-D label shape

    stft_placeholder = tf.placeholder(stft_training.dtype, stft_shape)
    labels_placeholder = tf.placeholder(labels_training.dtype, labels_shape)
    mfcc_placeholder = tf.placeholder(mfcc_training.dtype, mfcc_shape)
    dataset_training = tf.data.Dataset.from_tensor_slices((stft_placeholder, mfcc_placeholder, labels_placeholder))
    dataset_training = dataset_training.apply(
        tf.data.experimental.shuffle_and_repeat(len(stft_training), None))
    dataset_training = dataset_training.batch(BATCH_SIZE)
    dataset_training = dataset_training.prefetch(1)
    iterator_training = dataset_training.make_initializable_iterator()
    next_element_training = iterator_training.get_next()
    num_epochs = NUM_EPOCHS

  train_size = labels_training.shape[0]
  stft = tf.placeholder(
      data_type(),
      shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WEITH, NUM_CHANNELS))
  mfcc = tf.placeholder(
      data_type(),
      shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WEITH, NUM_CHANNELS))
  labels = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))

  model = BRN(stft, mfcc)

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True  

  with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run()
    train_writer = tf.summary.FileWriter(log_dir + 'train', sess.graph)
    # NOTE: `logits` is never defined in this snippet; it should be the model's output tensor.
    converter = tf.lite.TFLiteConverter.from_session(sess, [stft, mfcc], [logits])
    tflite_model = converter.convert()
    open("BRN.tflite", "wb").write(tflite_model)
    print('Initialized!')
    # Bug in the original post: stft_training was fed to all three placeholders.
    sess.run(iterator_training.initializer,
             feed_dict={stft_placeholder: stft_training,
                        mfcc_placeholder: mfcc_training,
                        labels_placeholder: labels_training})

Error message: (screenshot in the original post, not reproduced here)

1 answer

It is probably running out of memory, either system RAM or GPU memory. To run the official examples without trouble you generally want 32 GB or more of system RAM and 8 GB or more of GPU memory.
You can try reducing the batch size a bit, or switch to a better machine.
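Besides batch size, note that the posted code sizes the shuffle buffer to the entire training set (`shuffle_and_repeat(len(stft_training), None)`), which makes tf.data hold a full extra copy of the data in host RAM on top of the NumPy arrays already loaded by `joblib.load`. A rough back-of-the-envelope estimate (the dataset size and feature shapes below are hypothetical, just for illustration) shows how quickly that adds up:

```python
import numpy as np

def shuffle_buffer_bytes(num_examples, stft_shape, mfcc_shape, dtype=np.float32):
    """Estimate host memory held by a tf.data shuffle buffer sized to the
    full dataset: one (stft, mfcc, label) element per training example."""
    itemsize = np.dtype(dtype).itemsize
    per_example = (np.prod(stft_shape) + np.prod(mfcc_shape)) * itemsize + 8  # int64 label
    return int(num_examples * per_example)

# Hypothetical sizes: 100k examples, 128x128 single-channel spectrograms.
total = shuffle_buffer_bytes(100_000, (128, 128, 1), (128, 128, 1))
print(f"{total / 2**30:.1f} GiB")  # ≈ 12.2 GiB just for the shuffle buffer
```

A common mitigation is a much smaller fixed buffer, e.g. `dataset_training.shuffle(10000)`, which still shuffles reasonably well at a small fraction of the memory cost.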
