def imshow(img, *, mean=0.5, std=0.5):
    """Display a normalized (channels, height, width) image tensor.

    Undoes a ``(x - mean) / std`` normalization and moves the channel axis
    last (CHW -> HWC), which is the layout ``plt.imshow`` expects. It then
    shows the figure.

    Args:
        img: Image tensor laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``. It may live on any
            device and may require grad.
        mean: Mean that was subtracted during normalization. The default
            (0.5), together with the default ``std``, reproduces the
            original ``img / 2 + 0.5`` unnormalization.
        std: Standard deviation that was divided out during normalization
            (default 0.5).
    """
    img = img * std + mean  # unnormalize: inverse of (x - mean) / std
    # Tensor.numpy() raises on non-CPU tensors and on tensors that require
    # grad; detach().cpu() is a no-op for a plain CPU tensor.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Grab one batch of training images (random only if trainloader shuffles).
dataiter = iter(trainloader)
# DataLoader iterators in recent PyTorch releases no longer provide a .next()
# method; the builtin next() works on every version.
images, labels = next(dataiter)

# Show the whole batch as a single image grid.
imshow(torchvision.utils.make_grid(images))
# Print one label per image in the grid. Iterating the labels themselves,
# instead of a hard-coded range(4), keeps the printout in sync with the grid
# whatever the batch size is.
print(' '.join('%5s' % classes[label] for label in labels))