第一次搭建pytorch的cnn网络,在手写数据集mnist上CNN性能良好,但在lfw_people数据集上就出现了预测的结果全部为一类,查阅的资料发现可能是网络优化错误,但在更改网络结构或者降低学习率后情况没有解决,请各位帮帮忙,这种情况是什么问题呢?
lfw_people数据集介绍:
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_lfw_people.html
数据集共7个类别,每个类别采样后有80个图片。
预测结果截图:
** 全部代码:**
from sklearn.datasets import fetch_lfw_people
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch
import torch.utils.data as Data
from torch.autograd import Variable
import numpy as np
BATCH_SIZE = 10
EPOCH = 3
lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=1, data_home='/mnt', download_if_missing=False)
# lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=1)
images = lfw_people.images / 255.
target_name = lfw_people.target
# 采样,每类80张
a = zip(images, target_name)
image = []
target = []
a0, a1, a2, a3, a4, a5, a6 = 0, 0, 0, 0, 0, 0, 0
for x, y in a:
if y == 0:
if a0 <= 80:
image.append(x)
target.append(y)
a0 += 1
elif y == 1:
if a1 <= 80:
image.append(x)
target.append(y)
a1 += 1
elif y == 2:
if a2 <= 80:
image.append(x)
target.append(y)
a2 += 1
elif y == 3:
if a3 <= 80:
image.append(x)
target.append(y)
a3 += 1
elif y == 4:
if a4 <= 80:
image.append(x)
target.append(y)
a4 += 1
elif y == 5:
if a5 <= 80:
image.append(x)
target.append(y)
a5 += 1
else:
if a6 <= 80:
image.append(x)
target.append(y)
a6 += 1
images = np.array(image)
target_name = np.array(target)
# 分测试集、训练集
X_train, X_test, y_train, y_test = train_test_split(images, target_name, train_size=0.8, random_state=42)
X_train, X_test, y_train, y_test = torch.from_numpy(X_train), torch.from_numpy(X_test), \
torch.from_numpy(y_train), torch.from_numpy(y_test)
X_train = torch.unsqueeze(X_train, dim=1).type(torch.FloatTensor)
X_test = Variable(torch.unsqueeze(X_test, dim=1)).type(torch.FloatTensor)
dataset = Data.TensorDataset(X_train, y_train)
loader = Data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# CNN网络结构
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.Conv1 = nn.Sequential(
nn.Conv2d( # ->(1,125,95)
in_channels=1,
out_channels=16,
kernel_size=(5, 5),
stride=1,
padding=2, # if stride = 1,padding_size = (kernel_size-1)/2
),
nn.BatchNorm2d(num_features=16),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
)
self.Conv2 = nn.Sequential(
nn.Conv2d(16, 32, (5, 5), 1, 2),
nn.BatchNorm2d(num_features=32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
)
self.Conv3 = nn.Sequential(
nn.Conv2d(32, 64, (5, 5), 1, 2), # ->(16,14,14)
nn.BatchNorm2d(num_features=64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
)
self.Conv4 = nn.Sequential(
nn.Conv2d(64, 128, (5, 5), 1, 2), # ->(16,14,14)
nn.BatchNorm2d(num_features=128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
)
self.fc = nn.Sequential(
nn.Linear(128*7*5, 2048),
nn.ReLU(inplace=True),
nn.Linear(2048, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 7),
)
def forward(self, x):
x = self.Conv1(x)
x = self.Conv2(x)
x = self.Conv3(x)
x = self.Conv4(x)
x = x.view(x.size(0), -1)
output = self.fc(x)
return output
cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss() # 回归问题用MSEloss,分类用CrossEntropyLoss
for epoch in range(EPOCH):
for step, (x, y) in enumerate(loader):
b_x = Variable(x)
b_y = Variable(y)
if torch.cuda.is_available():
b_x = b_x.cuda()
b_y = b_y.cuda()
cnn = cnn.cuda()
X_test = X_test.cuda()
y_test = y_test.cuda()
output = cnn(b_x)
loss = loss_function(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 5 == 0:
test_output = cnn(X_test)
pred_y = torch.max(test_output, 1)[1].data.squeeze()
print(pred_y)
accuracy = torch.true_divide((pred_y == y_test).sum(), y_test.size(0))
print('Epoch:', epoch, '|train_loss:%.4f' % loss.item(), '|accuracy:', accuracy)