引自免费微信小程序:皆我百晓生
要解决在加载CIFAR-10数据集时遇到的字典键名不匹配的问题,首先需要理解Python处理字节串(bytes)和字符串(str)的区别。在Python 3中,从文件读取的数据默认是字节串类型,而我们通常在编程中使用的变量名和键名是字符串类型。
在你提供的链接中,处理CIFAR-10数据集的代码可能使用了pickle模块来读取数据文件。pickle模块在读取数据时会返回字节类型的键名,例如b'data'
、b'labels'
等。如果你期望的是字符串类型的键名(如'data'
、'labels'
),则需要进行相应的转换。
下面是一个修改后的代码示例,用于从CIFAR-10数据文件中读取数据并转换字节类型的键名为字符串类型:
import pickle
import numpy as np
def load_batch(filename):
""" load single batch of cifar """
with open(filename, 'rb') as f:
datadict = pickle.load(f, encoding='latin1')
# Convert byte keys to string keys
datadict = {key.decode('utf-8'): value for key, value in datadict.items()}
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("uint8")
Y = np.array(Y)
return X, Y
# Example usage
train_data = []
train_labels = []
for i in range(1,6):
f = './cifar-10-batches-py/data_batch_' + str(i)
data, labels = load_batch(f)
train_data.append(data)
train_labels.append(labels)
# Combine all training batches
train_data = np.concatenate(train_data)
train_labels = np.concatenate(train_labels)
这段代码的关键在于使用字典推导式将所有字节类型的键名转换为字符串类型:
datadict = {key.decode('utf-8'): value for key, value in datadict.items()}
这样,无论原始数据中的键名是字节类型还是字符串类型,你都能统一地使用字符串类型的键名来访问数据。
确保你已经正确安装了所需的库,并且CIFAR-10数据集已经被下载并放置在正确的目录下。上述代码假设数据文件位于当前目录下的./cifar-10-batches-py/
目录中。根据实际情况调整路径即可。