qq_39089104 2017-11-04 13:09 采纳率: 0%
浏览 2927

运用python读取cifar时出现无法解决的问题

我在 Python 中用以下代码读取 CIFAR-10 数据集时，遇到了无法解决的问题

代码如下:
import cPickle

import numpy as np

import os

class Cifar10DataReader(object):
    """Serve shuffled mini-batches from the CIFAR-10 *python version* batches.

    ``cifar_folder`` must contain the pickled files ``data_batch_1`` ...
    ``data_batch_5`` and ``test_batch`` (as in ``cifar-10-batches-py``).
    Files from the binary archive (``cifar-10-batches-bin``, names ending in
    ``.bin``) are not pickles and cannot be read by this class.

    Each batch file is unpickled and must be a dict with a ``'data'`` entry
    (one flat row of 3 * 1024 values per image) and a ``'labels'`` entry
    (class ids 0-9).
    """

    def __init__(self, cifar_folder, onehot=True):
        """Create a reader.

        Args:
            cifar_folder: directory holding the CIFAR-10 python batch files.
            onehot: if True, labels are returned as length-10 one-hot arrays;
                otherwise as plain ints.
        """
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1           # number N of the next data_batch_N to load
        self.read_next = True         # load a new training file on the next call
        self.data_label_train = None  # shuffled (image, label) pairs of current file
        self.data_label_test = None   # (image, label) pairs of test_batch, lazy
        self.batch_index = 0          # next batch position inside data_label_train

    def unpickle(self, f):
        """Load and return the object pickled in file ``f``.

        Raises:
            cPickle.UnpicklingError: if ``f`` is not a pickle file, e.g. a
                ``.bin`` file from the binary CIFAR-10 archive.
        """
        # NOTE(review): unpickling can execute arbitrary code; only load
        # batch files obtained from a trusted source.
        with open(f, 'rb') as fo:
            try:
                return cPickle.load(fo)
            except cPickle.UnpicklingError as err:
                # The usual cause is pointing the reader at the binary (.bin)
                # archive instead of the pickled python one; say so.
                raise cPickle.UnpicklingError(
                    '%s is not a pickle file (%s). Use the CIFAR-10 "python '
                    'version" batches (cifar-10-batches-py), not the binary '
                    '.bin files.' % (f, err))

    def _load_next_train_file(self):
        """Load and shuffle data_batch_<data_index>, then advance the index 1..5 cyclically."""
        f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
        print('read: %s' % f)
        dic_train = self.unpickle(f)
        # list() so the pairs can be shuffled in place and sliced
        # (zip returns a lazy iterator on Python 3).
        self.data_label_train = list(zip(dic_train['data'], dic_train['labels']))
        np.random.shuffle(self.data_label_train)
        self.read_next = False
        self.data_index = 1 if self.data_index == 5 else self.data_index + 1

    def next_train_data(self, batch_size=100):
        """Return the next ``(images, labels)`` training mini-batch.

        Training files are loaded one at a time, cycling through
        data_batch_1 ... data_batch_5. Each loaded file is shuffled once and
        then consumed in consecutive slices of ``batch_size`` pairs. Leftover
        pairs that do not fill a whole batch are skipped.

        Args:
            batch_size: samples per batch; must divide 10000.

        Returns:
            ``(images, labels)`` as produced by ``_decode``.

        Raises:
            AssertionError: if ``batch_size`` does not divide 10000.
            ValueError: if a freshly loaded file holds fewer than
                ``batch_size`` samples (previously this recursed forever).
        """
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        while True:
            just_loaded = self.read_next
            if just_loaded:
                self._load_next_train_file()
            n_batches = len(self.data_label_train) // batch_size
            if self.batch_index < n_batches:
                start = self.batch_index * batch_size
                datum = self.data_label_train[start:start + batch_size]
                self.batch_index += 1
                return self._decode(datum, self.onehot)
            if just_loaded:
                raise ValueError(
                    'training file holds fewer than %d samples' % batch_size)
            # Current file exhausted: rewind and load the next file.
            self.batch_index = 0
            self.read_next = True

    def _decode(self, datum, onehot):
        """Convert ``(flat_image, label)`` pairs into lists of images and labels.

        Each flat image is reshaped from 3 planes of 1024 values into a
        (32, 32, 3) array, i.e. channels-last.

        Returns:
            ``(rdata, rlabel)``: rdata is a list of (32, 32, 3) arrays; rlabel
            is a list of length-10 one-hot float arrays when ``onehot`` is
            true, otherwise a list of ints.
        """
        rdata = []
        rlabel = []
        for d, l in datum:
            # (3, 1024) planes -> (1024, 3) pixels -> (32, 32, 3)
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return a random ``(images, labels)`` batch from test_batch.

        The test file is loaded on the first call and cached. Each call
        reshuffles the cached pairs and returns the first ``batch_size``, so
        successive calls are independent random samples.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            # list() so the pairs can be shuffled and sliced on Python 2 and 3.
            self.data_label_test = list(zip(dic_test['data'], dic_test['labels']))

        np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]
        return self._decode(datum, self.onehot)

if __name__ == "__main__":
    # Demo: show one random test image, then pull 600 training batches and
    # print their shapes.
    import sys

    import matplotlib.pyplot as plt

    # Folder with the CIFAR-10 *python* batches (cifar-10-batches-py, not the
    # binary .bin archive). Pass it as the first command-line argument to
    # override the placeholder default.
    folder = sys.argv[1] if len(sys.argv) > 1 else "/xxx/cifar-10-batches-py/"
    dr = Cifar10DataReader(cifar_folder=folder)

    d, l = dr.next_test_data()
    # '%s %s' prints the same "(a, b, ...) (c, d)" line the Python 2 print
    # statement produced, on both Python 2 and 3.
    print('%s %s' % (np.shape(d), np.shape(l)))
    plt.imshow(d[0])
    plt.show()

    for _ in range(600):
        d, l = dr.next_train_data(batch_size=100)
        print('%s %s' % (np.shape(d), np.shape(l)))

报错如下
Traceback (most recent call last):
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 86, in <module>
d,l = dr.next_test_data()
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 71, in next_test_data
dic_test = self.unpickle(f)
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 19, in unpickle
d = cPickle.load(fo)
cPickle.UnpicklingError: invalid load key, ''.
read: /home/fujiarun/桌面/cifar-10-batches-bin/test_batch.bin

跪求各位大神帮助解决

  • 写回答

2条回答 默认 最新

  • devmiao 2017-11-05 02:07
    关注
    评论

报告相同问题?

悬赏问题

  • ¥30 这是哪个作者做的宝宝起名网站
  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!