我在运用以下代码在python中读取cifar10时,出现了无法解决的问题
代码如下:
import cPickle
import numpy as np
import os
class Cifar10DataReader():
def init(self,cifar_folder,onehot=True):
self.cifar_folder=cifar_folder
self.onehot=onehot
self.data_index=1
self.read_next=True
self.data_label_train=None
self.data_label_test=None
self.batch_index=0
def unpickle(self,f):
fo = open(f, 'rb')
d = cPickle.load(fo)
fo.close()
return d
def next_train_data(self,batch_size=100):
assert 10000%batch_size==0,"10000%batch_size!=0"
rdata=None
rlabel=None
if self.read_next:
f=os.path.join(self.cifar_folder,"data_batch_%s"%(self.data_index))
print 'read: %s'%f
dic_train=self.unpickle(f)
self.data_label_train=zip(dic_train['data'],dic_train['labels'])#label 0~9
np.random.shuffle(self.data_label_train)
self.read_next=False
if self.data_index==5:
self.data_index=1
else:
self.data_index+=1
if self.batch_index<len(self.data_label_train)//batch_size:
#print self.batch_index
datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
self.batch_index+=1
rdata,rlabel=self._decode(datum,self.onehot)
else:
self.batch_index=0
self.read_next=True
return self.next_train_data(batch_size=batch_size)
return rdata,rlabel
def _decode(self,datum,onehot):
rdata=list();rlabel=list()
if onehot:
for d,l in datum:
rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))
hot=np.zeros(10)
hot[int(l)]=1
rlabel.append(hot)
else:
for d,l in datum:
rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))
rlabel.append(int(l))
return rdata,rlabel
def next_test_data(self,batch_size=100):
if self.data_label_test is None:
f=os.path.join(self.cifar_folder,"test_batch")
print 'read: %s'%f
dic_test=self.unpickle(f)
data=dic_test['data']
labels=dic_test['labels']#0~9
self.data_label_test=zip(data,labels)
np.random.shuffle(self.data_label_test)
datum=self.data_label_test[0:batch_size]
return self._decode(datum,self.onehot)
if name=="__main__":
dr=Cifar10DataReader(cifar_folder="/xxx/cifar-10-batches-py/")
import matplotlib.pyplot as plt
d,l=dr.next_test_data()
print np.shape(d),np.shape(l)
plt.imshow(d[0])
plt.show()
for i in xrange(600):
d,l=dr.next_train_data(batch_size=100)
print np.shape(d),np.shape(l)
报错如下
Traceback (most recent call last):
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 86, in
d,l = dr.next_test_data()
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 71, in next_test_data
dic_test = self.unpickle(f)
File "/home/fujiarun/PycharmProjects/untitled2/duqu1.py", line 19, in unpickle
d = cPickle.load(fo)
cPickle.UnpicklingError: invalid load key, ''.
read: /home/fujiarun/桌面/cifar-10-batches-bin/test_batch.bin
跪求各位大神帮助解决