https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py
I want to train the open-source DCGAN implementation from GitHub on my own dataset. How do I load my own dataset?
I downloaded the code from your link and ran it once; it works as is.
On line 109, dcgan.py calls mnist.load_data(), which reads the bundled mnist.npz dataset. The source of mnist.load_data() looks like this:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

  This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
  along with a test set of 10,000 images.
  More info can be found at the
  [MNIST homepage](http://yann.lecun.com/exdb/mnist/).

  Arguments:
      path: path where to cache the dataset locally
          (relative to `~/.keras/datasets`).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
      **x_train, x_test**: uint8 arrays of grayscale image data with shapes
        (num_samples, 28, 28).
      **y_train, y_test**: uint8 arrays of digit labels (integers in range 0-9)
        with shapes (num_samples,).

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      file_hash=
      '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
  with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
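You can write your own data-loading function modeled on this one. On inspection, the samples in mnist.npz are 28*28 grayscale images, so your own images need to be scaled to 28*28 as well.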
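If you want to repeat that inspection yourself: an .npz file is a zip of named NumPy arrays, so listing the shape and dtype of each array takes only a few lines. This is a minimal sketch. describe_npz is a helper name I made up, and the demo writes a tiny stand-in archive rather than reading the real mnist.npz (which, per the docstring above, Keras caches under ~/.keras/datasets).

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype_name)} for every array in an .npz file."""
    with np.load(path) as archive:
        return {name: (archive[name].shape, archive[name].dtype.name)
                for name in archive.files}


if __name__ == "__main__":
    # Stand-in archive with the same key names that load_data() reads.
    demo_path = os.path.join(tempfile.mkdtemp(), "demo.npz")
    np.savez(demo_path,
             x_train=np.zeros((4, 28, 28), dtype=np.uint8),
             y_train=np.zeros(4, dtype=np.uint8))
    for name, (shape, dtype) in describe_npz(demo_path).items():
        print(name, shape, dtype)
```

Pointing describe_npz at your cached mnist.npz should show the (num_samples, 28, 28) uint8 layout that the docstring describes.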
The final function looks like this:
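import os
import random

import cv2
import numpy as np


def get_image(image_index, path=r'C:\Coding\Python\CSDN\Image\bibimbap', img_prefix="hed"):
    '''
    image_index: an integer from 0 to 999.
    path: absolute path of the dataset folder; a relative path also works.
    img_prefix: file-name prefix of the images; change it to match your files.
    '''
    # Zero-pad the index to four digits, e.g. 7 -> "0007"
    image_index = "%04d" % image_index
    image_path = os.path.join(path, img_prefix + image_index + '.png')
    # Read the image as grayscale
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising when a file cannot be read
        raise FileNotFoundError(image_path)
    return Reduce(img)


def Reduce(image):
    # Shrink to 28x28 so the sample matches the MNIST image size
    return cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)


def load_data_mydatabase(path=r'C:\Coding\Python\CSDN\Image\bibimbap'):
    # The samples in mnist.npz are single-channel, 28*28-pixel images,
    # so the preprocessed images under bibimbap (head0000.png and so on)
    # are converted to that format and used as the image and label sets.
    # There are only 1000 images: the first 900 form the training set
    # and the last 100 form the test set.
    x_train, y_train = [], []
    x_test, y_test = [], []
    # I don't know which classes you want for the grilled cold noodles,
    # so ten random classes (0-9) are generated here.
    for train_index in range(900):
        x_train.append(get_image(train_index, path))
        y_train.append(random.randrange(10))
    for test_index in range(900, 1000):
        x_test.append(get_image(test_index, path))
        y_test.append(random.randrange(10))
    # Convert all four lists to arrays (the labels too), as mnist.load_data() returns
    x_train, y_train = np.array(x_train), np.array(y_train)
    x_test, y_test = np.array(x_test), np.array(y_test)
    return (x_train, y_train), (x_test, y_test)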
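Decoding and resizing 1000 PNG files on every run is slow, so one option is to store the result once in the same layout as mnist.npz (the keys x_train, y_train, x_test, y_test that load_data reads above). This is only a sketch: save_dataset_npz and load_dataset_npz are names I made up, and the small generated arrays in the demo stand in for the output of load_data_mydatabase().

```python
import os
import tempfile

import numpy as np


def save_dataset_npz(npz_path, train, test):
    """Write (x_train, y_train), (x_test, y_test) under the key names used by mnist.npz."""
    (x_train, y_train), (x_test, y_test) = train, test
    np.savez_compressed(npz_path, x_train=x_train, y_train=y_train,
                        x_test=x_test, y_test=y_test)


def load_dataset_npz(npz_path):
    """Read the arrays back in the tuple layout that mnist.load_data() returns."""
    with np.load(npz_path) as f:
        return (f['x_train'], f['y_train']), (f['x_test'], f['y_test'])


if __name__ == "__main__":
    # Stand-in data: 10 fake 28x28 grayscale images with labels 0-9.
    images = (np.arange(10 * 28 * 28) % 256).astype(np.uint8).reshape(10, 28, 28)
    labels = np.arange(10)
    cache = os.path.join(tempfile.mkdtemp(), "mydata.npz")
    save_dataset_npz(cache, (images[:9], labels[:9]), (images[9:], labels[9:]))
    (x_train, _), (_, _) = load_dataset_npz(cache)
    print(x_train.shape, x_train.dtype)
```

With a cache like this, the slow image decoding runs once, and later runs only need the load step.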
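I remembered that you were building a cold-noodle dataset earlier, so I downloaded a copy of it. I don't know how many classes you intend to have, though, so I generated random labels for ten classes; replace them with whatever labels you actually need.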
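To use it, paste this code into the original script and replace line 109 of dcgan.py with
(X_train, _), (_, _) = load_data_mydatabase()
and that is all. Note that this line keeps only the training images and discards the labels and the test split, so the random labels have no effect on DCGAN training.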
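Before starting a long training run, it may be worth checking that your array really looks like the MNIST one. The sketch below assumes, from my reading of the linked dcgan.py, that the script rescales pixels to [-1, 1] and adds a channel axis right after loading; check your own copy. check_mnist_like and rescale_for_dcgan are hypothetical helper names, not part of that script.

```python
import numpy as np


def check_mnist_like(x):
    """Raise ValueError unless x looks like MNIST images: uint8, shape (N, 28, 28)."""
    if x.ndim != 3 or x.shape[1:] != (28, 28):
        raise ValueError("expected shape (N, 28, 28), got %r" % (x.shape,))
    if x.dtype != np.uint8:
        raise ValueError("expected dtype uint8, got %s" % x.dtype)


def rescale_for_dcgan(x):
    """Map uint8 pixels from [0, 255] to [-1, 1] and append a channel axis."""
    x = x / 127.5 - 1.0
    return np.expand_dims(x, axis=3)


if __name__ == "__main__":
    # Stand-in for X_train: two black images with one white pixel.
    fake = np.zeros((2, 28, 28), dtype=np.uint8)
    fake[0, 0, 0] = 255
    check_mnist_like(fake)
    scaled = rescale_for_dcgan(fake)
    print(scaled.shape, scaled.min(), scaled.max())
```

If check_mnist_like raises on the output of load_data_mydatabase(), the usual suspects are images that were not read as grayscale or not resized to 28*28.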