Cle?fai 2021-03-03 16:51 · Acceptance rate: 100%
257 views
Accepted

Looking for someone experienced in deep learning and image generation (DCGAN)

https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py

I want to train on my own dataset with the open-source DCGAN code from GitHub. How do I load my own dataset?


3 answers

  • ProfSnail 2021-03-04 19:44

    I downloaded the code from the link you posted and ran it once; it works as-is.

    dcgan.py calls mnist.load_data() at line 109, which reads the bundled mnist.npz dataset. The source of the mnist.load_data() function looks like this:

    # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    """MNIST handwritten digits dataset.
    """
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import numpy as np
    
    from tensorflow.python.keras.utils.data_utils import get_file
    from tensorflow.python.util.tf_export import keras_export
    
    
    @keras_export('keras.datasets.mnist.load_data')
    def load_data(path='mnist.npz'):
      """Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
    
      This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
      along with a test set of 10,000 images.
      More info can be found at the
      [MNIST homepage](http://yann.lecun.com/exdb/mnist/).
    
    
      Arguments:
          path: path where to cache the dataset locally
              (relative to `~/.keras/datasets`).
    
      Returns:
          Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    
          **x_train, x_test**: uint8 arrays of grayscale image data with shapes
            (num_samples, 28, 28).
    
          **y_train, y_test**: uint8 arrays of digit labels (integers in range 0-9)
            with shapes (num_samples,).
    
      License:
          Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
          which is a derivative work from original NIST datasets.
          MNIST dataset is made available under the terms of the
          [Creative Commons Attribution-Share Alike 3.0 license.](
          https://creativecommons.org/licenses/by-sa/3.0/)
      """
      origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
      path = get_file(
          path,
          origin=origin_folder + 'mnist.npz',
          file_hash=
          '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
      with np.load(path, allow_pickle=True) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    
        return (x_train, y_train), (x_test, y_test)

    You can model a data-loading function on this one. On inspection, the samples inside mnist.npz are 28*28, so your own samples need to be resized to 28*28 as well.

    The resulting functions are as follows:

    import numpy as np
    import cv2
    import os
    import random

    def get_image(image_index, path=r'C:\Coding\Python\CSDN\Image\bibimbap', img_prefix="hed"):
        '''
        image_index is an integer from 0 to 999.
        path is the absolute path of the dataset; a relative path also works.
        img_prefix is the filename prefix of the dataset images.
        '''
        # Pad the index with leading zeros, e.g. 7 -> "0007"
        image_index = "%04d" % image_index
        image_path = os.path.join(path, img_prefix + image_index + '.png')
        # Read as a grayscale image and shrink it to 28x28
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        img = reduce_to_28(img)
        return img

    def reduce_to_28(image):
        # Resize the image to 28x28, the size the MNIST-shaped pipeline expects
        shrink = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
        return shrink

    def load_data_mydatabase(path=r'C:\Coding\Python\CSDN\Image\bibimbap'):
        # The data in mnist.npz are single-channel 28x28 images, so the
        # preprocessed hed0000.png ... under bibimbap are used as the
        # training and test images here.
        # With only 1000 images, the first 900 are used for training and
        # the last 100 for testing.
        x_train = []
        x_test = []
        y_train = []
        y_test = []
        # I am not sure which classes you want for this dataset, so ten
        # random class labels are generated here.
        for train_index in range(900):
            x_train.append(get_image(train_index))
            y_train.append(int(random.uniform(0, 10)))
        x_train = np.array(x_train)
        y_train = np.array(y_train)
        for test_index in range(900, 1000):
            x_test.append(get_image(test_index))
            y_test.append(int(random.uniform(0, 10)))
        x_test = np.array(x_test)
        y_test = np.array(y_test)
        return (x_train, y_train), (x_test, y_test)
    

    Because I remember you were previously building a grilled-cold-noodle dataset, I also downloaded a copy of it. I don't know how many classes you need, though, so I generated random labels for ten classes; generate your own label values as needed (one way to do that is sketched right below).
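    If you do have real labels, one option is to read them from a file instead of drawing them at random. Below is a minimal sketch; the file labels.txt and its one-integer-per-line format are purely hypothetical and not part of the original repo:

    import numpy as np

    def load_labels(label_path=r'C:\Coding\Python\CSDN\Image\labels.txt'):
        # Hypothetical label file: line i holds the integer class of image i
        # (so hed0000.png corresponds to line 0, and so on).
        with open(label_path) as f:
            labels = [int(line.strip()) for line in f if line.strip()]
        return np.array(labels, dtype=np.uint8)

    # Inside load_data_mydatabase(), the random labels could then be replaced by:
    #     labels = load_labels()
    #     y_train = labels[:900]
    #     y_test = labels[900:1000]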

    To use it, put this code into the original file and replace line 109 of dcgan.py with

            (X_train, _), (_, _) = load_data_mydatabase()

    and that's it.
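    For context, after the change the surrounding lines in the train() method of dcgan.py should look roughly like the sketch below (exact line numbers and comments may differ between versions of the repo). The rescaling to [-1, 1] and the channel-dimension expansion can stay as they are, because load_data_mydatabase() returns uint8 grayscale arrays of shape (num_samples, 28, 28), just like mnist.load_data():

    # Load the dataset (originally: (X_train, _), (_, _) = mnist.load_data())
    (X_train, _), (_, _) = load_data_mydatabase()

    # Rescale pixel values from [0, 255] to [-1, 1], as the original code does
    X_train = X_train / 127.5 - 1.
    # Add a channel dimension: (num_samples, 28, 28) -> (num_samples, 28, 28, 1)
    X_train = np.expand_dims(X_train, axis=3)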
     

