mmmmayi · 2019-07-21 16:52 · acceptance rate: 0%
585 views
Closed

MemoryError when using TensorFlow's tf.data Dataset API

I hit a MemoryError while using TensorFlow's tf.data Dataset API, even though I'm following the placeholder approach the official docs recommend. My system is Ubuntu 18.04, TensorFlow 1.13.1, Python 3.6. Code first:

def main(_):
  if FLAGS.self_test:

    train_data, train_labels = fake_data(256)
    validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
    test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
    num_epochs = 1
  else:

    stft_training, mfcc_training, labels_training = joblib.load(open(FLAGS.input, mode='rb'))
    stft_training = numpy.array(stft_training)
    mfcc_training = numpy.array(mfcc_training)
    labels_training = numpy.array(labels_training)
    stft_shape = stft_training.shape
    stft_shape = (None, stft_shape[1], stft_shape[2])

    mfcc_shape = mfcc_training.shape
    mfcc_shape = (None, mfcc_shape[1], mfcc_shape[2])

    labels_shape = labels_training.shape
    labels_shape = (None,)  # note the comma: (None) without it is just None, not a 1-D shape

    stft_placeholder = tf.placeholder(stft_training.dtype, stft_shape)
    labels_placeholder = tf.placeholder(labels_training.dtype, labels_shape)
    mfcc_placeholder = tf.placeholder(mfcc_training.dtype, mfcc_shape)
    dataset_training = tf.data.Dataset.from_tensor_slices((stft_placeholder, mfcc_placeholder, labels_placeholder))
    dataset_training = dataset_training.apply(
        tf.data.experimental.shuffle_and_repeat(len(stft_training), None))  # shuffle buffer = whole dataset
    dataset_training = dataset_training.batch(BATCH_SIZE)
    dataset_training = dataset_training.prefetch(1)
    iterator_training = dataset_training.make_initializable_iterator()
    next_element_training = iterator_training.get_next()
    num_epochs = NUM_EPOCHS

  train_size = labels_training.shape[0]
  stft = tf.placeholder(
      data_type(),
      shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WEITH, NUM_CHANNELS))
  mfcc = tf.placeholder(
      data_type(),
      shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WEITH, NUM_CHANNELS))
  labels = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))

  model = BRN(stft, mfcc)

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True  

  with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run()
    train_writer = tf.summary.FileWriter(log_dir + 'train', sess.graph)
    # `logits` is the model output, defined elsewhere in the full script
    converter = tf.lite.TFLiteConverter.from_session(sess, [stft, mfcc], [logits])
    tflite_model = converter.convert()
    open("BRN.tflite", "wb").write(tflite_model)
    print('Initialized!')
    sess.run(iterator_training.initializer,
             feed_dict={stft_placeholder: stft_training,
                        mfcc_placeholder: mfcc_training,
                        labels_placeholder: labels_training})

Error message:
[screenshot of the MemoryError traceback; text not transcribed]
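For reference, here is a minimal, self-contained sketch of the same placeholder + initializable-iterator pattern under TensorFlow 1.x. The toy array shapes, the buffer cap of 1000, and the batch size of 8 are hypothetical values chosen only for illustration; note that the posted code sets the shuffle buffer to the full dataset length, which keeps a second copy of every example in host memory.

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x, as in the question

# Hypothetical toy arrays standing in for stft_training, mfcc_training, labels_training.
stft = np.random.rand(100, 32, 32).astype(np.float32)
mfcc = np.random.rand(100, 16, 16).astype(np.float32)
labels = np.random.randint(0, 10, size=100).astype(np.int64)

stft_ph = tf.placeholder(stft.dtype, (None,) + stft.shape[1:])
mfcc_ph = tf.placeholder(mfcc.dtype, (None,) + mfcc.shape[1:])
labels_ph = tf.placeholder(labels.dtype, (None,))

dataset = tf.data.Dataset.from_tensor_slices((stft_ph, mfcc_ph, labels_ph))
# Cap the shuffle buffer instead of using the full dataset length.
dataset = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(min(len(stft), 1000), None))
dataset = dataset.batch(8).prefetch(1)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={stft_ph: stft, mfcc_ph: mfcc, labels_ph: labels})
    batch = sess.run(next_element)
    print([x.shape for x in batch])

If this toy version runs but the real data does not, the arrays themselves plus the full-size shuffle buffer are the likely culprits for exhausting RAM.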


1 answer

  • threenewbee 2019-07-21 17:01

    It's probably insufficient memory: either your system RAM or your GPU's VRAM. I'd suggest 32 GB or more of system memory and 8 GB or more of VRAM; that's roughly what it takes to run the official examples without trouble.
    You can try making the batch a bit smaller, or switching to a more capable machine.
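    A minimal sketch of that advice, assuming TensorFlow 1.x as in the question; the batch size and memory fraction below are hypothetical values to tune, not fixed recommendations.

    import tensorflow as tf  # TF 1.x, matching the question

    BATCH_SIZE = 16  # hypothetical: halve the batch until the error disappears

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # claim GPU memory on demand, as the question already does
    # Optionally hard-cap how much of the GPU TensorFlow may take:
    config.gpu_options.per_process_gpu_memory_fraction = 0.8

    with tf.Session(config=config) as sess:
        print(sess.run(tf.constant(BATCH_SIZE)))

    Note that a Python MemoryError usually points at host RAM exhaustion (GPU out-of-memory in TensorFlow typically surfaces as ResourceExhaustedError instead), so shrinking the shuffle buffer and batch size is the first thing to try.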

