jianjiu7
jianjiu17
2019-05-18 15:22
采纳率: 0%
浏览 2.7k

在 PyCharm 中运行 mnist_show.py 时出现如下问题。

这是《深度学习入门》这本书里的一段代码,请问这个报错是什么意思,应该怎样解决?

报错如下(源代码见下文):

Python 3.7.3 (default, Mar 27 2019, 17:13:21) [MSC v.1915 64 bit (AMD64)] on win32
runfile('E:/PycharmProjects/deep-learning-from-scratch-master/ch03/mnist_show.py', wdir='E:/PycharmProjects/deep-learning-from-scratch-master/ch03')
Converting train-images-idx3-ubyte.gz to NumPy Array ...
Traceback (most recent call last):
File "D:\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3296, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input>", line 1, in <module>
runfile('E:/PycharmProjects/deep-learning-from-scratch-master/ch03/mnist_show.py', wdir='E:/PycharmProjects/deep-learning-from-scratch-master/ch03')
File "D:\Program Files\JetBrains\PyCharm 2019.1.1\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "D:\Program Files\JetBrains\PyCharm 2019.1.1\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "E:/PycharmProjects/deep-learning-from-scratch-master/ch03/mnist_show.py", line 13, in <module>
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
File "E:\PycharmProjects\deep-learning-from-scratch-master\dataset\mnist.py", line 106, in load_mnist
init_mnist()
File "E:\PycharmProjects\deep-learning-from-scratch-master\dataset\mnist.py", line 76, in init_mnist
dataset = _convert_numpy()

源代码为:# coding: utf-8

mnist_show.py:

import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image

def img_show(img):
    """Display an array of pixel values as an image.

    The array is cast to ``uint8``, wrapped in a PIL image and opened
    with ``Image.show()``.
    """
    pixels = np.uint8(img)
    Image.fromarray(pixels).show()

# Load the images as flat vectors with their pixel values left un-normalized.
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

# Take the first training sample together with its label.
img, label = x_train[0], t_train[0]
print(label)  # 5

print(img.shape)  # (784,)
img = img.reshape(28, 28)  # restore the image to its original 2-D size
print(img.shape)  # (28, 28)

img_show(img)

mnist.py:

# coding: utf-8

# urllib.request is only importable on Python 3; fail early with a clear
# message otherwise.  ``raise ... from`` is deliberately not used here: it is
# a syntax error on Python 2, which would hide this message.
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np

# Base URL that the archive file names below are appended to.
url_base = 'http://yann.lecun.com/exdb/mnist/'

# Dataset key -> archive file name.
key_file = dict(
    train_img='train-images-idx3-ubyte.gz',
    train_label='train-labels-idx1-ubyte.gz',
    test_img='t10k-images-idx3-ubyte.gz',
    test_label='t10k-labels-idx1-ubyte.gz',
)

# Archives and the pickle cache are stored in the directory of this module.
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = "{}/mnist.pkl".format(dataset_dir)

# NOTE(review): presumably the sample counts of the two splits; they are not
# referenced by the functions visible in this file.
train_num, test_num = 60000, 10000
img_dim = (1, 28, 28)
img_size = 28 * 28  # 784 values per flattened image

def _download(file_name):
    """Download one archive into ``dataset_dir`` unless it is already there.

    The data is first written to ``<file_name>.part`` and only renamed to
    its final name after the transfer finished.  An interrupted download
    therefore never leaves a truncated archive behind that the existence
    check below would later mistake for a complete file.

    Parameters
    ----------
    file_name : name of the archive, appended to ``url_base``

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` raises on a failed transfer.
    """
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    tmp_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(url_base + file_name, tmp_path)
        # Same directory, so the rename is a single replace operation.
        os.replace(tmp_path, file_path)
    finally:
        # After a successful replace the temporary file is gone; this only
        # cleans up the partial file left by a failed or aborted transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Done")

def download_mnist():
    """Call ``_download`` for every archive name listed in ``key_file``."""
    for archive_name in key_file.values():
        _download(archive_name)

def _load_label(file_name):
    """Read a gzip-compressed label file from ``dataset_dir``.

    The first 8 bytes of the decompressed data are skipped and the rest is
    interpreted as ``uint8`` values.

    Parameters
    ----------
    file_name : name of the archive inside ``dataset_dir``

    Returns
    -------
    1-D ``uint8`` NumPy array
    """
    path = dataset_dir + "/" + file_name

    print("Converting {} to NumPy Array ...".format(file_name))
    with gzip.open(path, 'rb') as archive:
        raw = archive.read()
    labels = np.frombuffer(raw, np.uint8, offset=8)
    print("Done")

    return labels

def _load_img(file_name):
    """Read a gzip-compressed image file from ``dataset_dir``.

    The first 16 bytes of the decompressed data are skipped; the rest is
    interpreted as ``uint8`` values and reshaped to rows of ``img_size``
    values.

    Parameters
    ----------
    file_name : name of the archive inside ``dataset_dir``

    Returns
    -------
    ``uint8`` NumPy array of shape (-1, img_size)
    """
    path = dataset_dir + "/" + file_name

    print("Converting {} to NumPy Array ...".format(file_name))
    with gzip.open(path, 'rb') as archive:
        raw = archive.read()
    pixels = np.frombuffer(raw, np.uint8, offset=16)
    images = pixels.reshape(-1, img_size)
    print("Done")

    return images

def _convert_numpy():
    """Load all four archives named in ``key_file`` into NumPy arrays.

    Returns
    -------
    dict with the keys 'train_img', 'train_label', 'test_img' and
    'test_label'
    """
    # Dict displays evaluate their values top to bottom, so the files are
    # read in the order train images, train labels, test images, test labels.
    return {
        'train_img': _load_img(key_file['train_img']),
        'train_label': _load_label(key_file['train_label']),
        'test_img': _load_img(key_file['test_img']),
        'test_label': _load_label(key_file['test_label']),
    }

def init_mnist():
    """Download the archives, convert them and pickle the result.

    The converted dataset is written to ``save_file`` using the highest
    pickle protocol available.
    """
    download_mnist()
    arrays = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as out:
        pickle.dump(arrays, out, pickle.HIGHEST_PROTOCOL)
    print("Done!")

def _change_one_hot_label(X):
    """Convert an array of integer class labels into one-hot rows.

    Parameters
    ----------
    X : NumPy integer array of labels, each a valid column index (0-9)

    Returns
    -------
    float64 array of shape (X.size, 10); row ``i`` is all zeros except for
    a 1 in column ``X[i]``

    Raises
    ------
    IndexError if a label is not a valid index into the 10 columns.
    """
    T = np.zeros((X.size, 10))
    # One integer-index assignment sets T[i, X[i]] for every row, replacing
    # a Python-level loop over all samples.  ravel() keeps column-vector
    # input such as shape (n, 1) pairing one label with one row.
    T[np.arange(X.size), X.ravel()] = 1

    return T

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset.

    Parameters
    ----------
    normalize : scale the image pixel values to the range 0.0-1.0
    one_hot_label :
        when True, labels are returned as one-hot arrays,
        i.e. arrays such as [0,0,1,0,0,0,0,0,0,0]
    flatten : whether each image is flattened into a 1-D array

    Returns
    -------
    (train images, train labels), (test images, test labels)
    """
    # Build the pickle cache on first use.
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as cache:
        data = pickle.load(cache)

    image_keys = ('train_img', 'test_img')

    if normalize:
        for name in image_keys:
            data[name] = data[name].astype(np.float32) / 255.0

    if one_hot_label:
        for name in ('train_label', 'test_label'):
            data[name] = _change_one_hot_label(data[name])

    if not flatten:
        for name in image_keys:
            data[name] = data[name].reshape(-1, 1, 28, 28)

    train = (data['train_img'], data['train_label'])
    test = (data['test_img'], data['test_label'])
    return train, test

# Build the dataset cache when this module is executed as a script.
if __name__ == '__main__':
    init_mnist()

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 邀请回答

相关推荐