使用python做道路裂纹缺陷检测(语义分割),训练模型时一直没有得到正确的结果
训练了8个周期,从第二个周期开始准确率一直都没有变化
第8个周期时,模型输出的结果:
不知道是我处理模型输出结果的方式有问题,还是模型本身有问题。
这是我的代码
from torch.utils import data
from torch import nn, optim
from torchvision import transforms
import torch
from PIL import Image
from matplotlib import pyplot as plt
import os
# Prefer the first CUDA GPU, but fall back to the CPU so the script still runs
# on machines without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Shared preprocessing pipeline: resize to a fixed 256x256 grid, then convert
# the PIL image to a tensor.
transformer = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)
def booltf(tensor):
    """Binarise ``tensor`` in place with a 0.5 threshold.

    Elements ``>= 0.5`` become 1 and elements ``< 0.5`` become 0.

    :param tensor: tensor to binarise; it is modified in place
    :return: the same tensor object that was passed in
    """
    # Both masks are taken before writing; they are disjoint, so the result
    # is the same as thresholding one side after the other.
    is_high = tensor >= 0.5
    is_low = tensor < 0.5
    tensor[is_high] = 1
    tensor[is_low] = 0
    return tensor
# 定义dataset
class SegmentationDataset(data.Dataset):
    """Paired image / mask dataset for binary segmentation.

    Images and masks are paired by position after sorting both directory
    listings, so the two directories must hold the same number of files and
    their names must sort in the same order.
    """

    def __init__(self, img_file, label_file):
        """Collect the image and mask file paths.

        :param img_file: directory containing the input images
        :param label_file: directory containing the ground-truth masks
        :raises ValueError: if the two directories contain a different number
            of entries, because positional pairing would then be wrong
        """
        # os.listdir() returns entries in arbitrary order; sort both listings
        # so the i-th image is reliably paired with the i-th mask.
        img_names = sorted(os.listdir(img_file))
        label_names = sorted(os.listdir(label_file))
        if len(img_names) != len(label_names):
            raise ValueError(
                "image/label count mismatch: {} images in {!r}, "
                "{} labels in {!r}".format(
                    len(img_names), img_file, len(label_names), label_file
                )
            )
        self.imgs = [os.path.join(img_file, name) for name in img_names]
        self.labels = [os.path.join(label_file, name) for name in label_names]

    def __getitem__(self, index):
        """Load one sample.

        :param index: position of the sample
        :return: ``(img_tensor, label_tensor)``; the image is a float tensor
            produced by ``transformer`` from an RGB image, and the label is a
            long tensor of 0/1 class indices (a 2-D map, assuming the mask
            file is single-channel)
        """
        # Context managers close the underlying files once they are decoded.
        with Image.open(self.imgs[index]) as raw_img:
            # Force three channels so grayscale or RGBA files still match the
            # 3-channel input expected by the network.
            img_tensor = transformer(raw_img.convert("RGB"))
        with Image.open(self.labels[index]) as raw_label:
            label_tensor = transformer(raw_label)
        # NOTE(review): ToTensor scales 8-bit pixels by 1/255. If the masks
        # are stored as 0/1 instead of 0/255, every pixel falls below the 0.5
        # threshold and the label becomes all background - confirm the mask
        # value range.
        label_tensor = booltf(label_tensor)
        label_tensor = label_tensor.squeeze().long()
        return img_tensor, label_tensor

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs)
# 使用matplotlib显示图像
def toview(tensor, is_gray=True):
    """
    Draw a tensor as an image on the current matplotlib axes.

    :param tensor: tensor to display; it is cast to float and converted to a
        PIL image before drawing
    :param is_gray: if True, draw with the gray colormap; otherwise draw
        without specifying a colormap
    """
    to_pil = transforms.ToPILImage()
    picture = to_pil(tensor.float())
    # Only the grayscale path passes an explicit colormap.
    extra = {'cmap': 'gray'} if is_gray else {}
    plt.imshow(picture, **extra)
# 定义Unet模型
class Downsample(nn.Module):
    """Encoder stage: optional 2x2 max-pool, then two 3x3 conv + ReLU layers."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: channel count of the incoming feature map
        :param out_channels: channel count produced by both convolutions
        """
        super().__init__()
        stages = []
        width = in_channels
        # Two conv/ReLU pairs; only the first changes the channel count.
        for _ in range(2):
            stages.append(nn.Conv2d(width, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            width = out_channels
        self.conv_relu = nn.Sequential(*stages)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """
        :param x: input feature map
        :param is_pool: when True, max-pool ``x`` before the convolutions
        :return: feature map with ``out_channels`` channels
        """
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class Upsample(nn.Module):
    """Decoder stage: two 3x3 conv + ReLU layers, then a stride-2 transposed conv.

    The input is expected to carry ``2 * channels`` channels; the output has
    ``channels // 2`` channels.
    """

    def __init__(self, channels):
        """
        :param channels: channel count after the first convolution (half the
            input channel count)
        """
        super().__init__()
        fuse_layers = [
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv_relu = nn.Sequential(*fuse_layers)
        # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
        enlarge = nn.ConvTranspose2d(
            channels,
            channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.upconv = nn.Sequential(enlarge)

    def forward(self, x):
        """
        :param x: feature map with ``2 * channels`` channels
        :return: upsampled feature map with ``channels // 2`` channels
        """
        return self.upconv(self.conv_relu(x))
class Unet_model(nn.Module):
    """U-Net style encoder/decoder that outputs two-channel per-pixel logits.

    Four pooled encoder stages are mirrored by stride-2 transposed
    convolutions, and each decoder stage is concatenated with the encoder
    feature map of the same depth.
    """

    def __init__(self):
        super().__init__()
        # Encoder: channel widths 3 -> 64 -> 128 -> 256 -> 512 -> 1024.
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)
        # First upsampling step out of the bottleneck.
        bottleneck_up = nn.ConvTranspose2d(
            1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up = nn.Sequential(bottleneck_up, nn.ReLU())
        # Decoder stages, each fed with a skip-concatenated feature map.
        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)
        # Final double convolution (used without pooling) and 1x1 classifier.
        self.conv_2 = Downsample(128, 64)
        self.last = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, input):
        """
        :param input: batch of 3-channel images
        :return: logits with 2 channels at the input's spatial resolution
        """
        # Encoder path; every stage after the first pools before convolving.
        enc1 = self.down1(input, is_pool=False)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        enc4 = self.down4(enc3)
        bottom = self.down5(enc4)

        # Decoder path: upsample, then concatenate the matching encoder map
        # along the channel axis before the next stage.
        dec = self.up(bottom)
        dec = self.up1(torch.cat([enc4, dec], dim=1))
        dec = self.up2(torch.cat([enc3, dec], dim=1))
        dec = self.up3(torch.cat([enc2, dec], dim=1))
        dec = self.conv_2(torch.cat([enc1, dec], dim=1), is_pool=False)
        return self.last(dec)
# Training script: build the dataset and model, then train for 20 epochs while
# tracking pixel accuracy and showing one sample prediction per epoch.
img_file = r"./data/CrackForest-dataset-master/image"
label_file = r"./data/CrackForest-dataset-master/groundTruthPngImg"
batch_size = 4
dataset = SegmentationDataset(img_file, label_file)
train_data = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
net = Unet_model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())
accuracy = []
for epoch in range(1, 21):
    correct = 0
    total = 0
    for image, label in train_data:
        image, label = image.to(device), label.to(device)
        out = net(image)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Collapse the two-channel logits to a per-pixel class index.
            out = torch.argmax(out, dim=1)
            correct += (out == label).sum().item()
            # Count the pixels actually seen: the last batch of an epoch can
            # hold fewer than `batch_size` samples, so a fixed
            # batch_size * 256 * 256 would under-report accuracy.
            total += label.numel()
    # NOTE(review): pixel accuracy is dominated by the majority class. If
    # crack pixels are rare, an all-background prediction already scores very
    # high and the value stops changing - consider also tracking foreground
    # IoU to see whether the model is learning.
    accuracy.append((correct / total) * 100)
    # Visualise the first sample of the last batch of this epoch.
    plt.subplot(1, 3, 1)
    toview(image[0], is_gray=False)
    plt.xlabel('input')
    plt.subplot(1, 3, 2)
    toview(label[0])
    plt.xlabel('label')
    plt.subplot(1, 3, 3)
    toview(out[0])
    plt.xlabel('out')
    plt.show()
    print("epoch:{}, accuracy:{:.4f}%".format(epoch, accuracy[-1]))
思路:先读入image和label,并转化成tensor。因为是二分类问题,并且使用ToTensor后label_tensor的值会被缩放到[0,1]区间,所以我把label_tensor中大于等于0.5的值统一设为1,小于0.5的值统一设为0。为此我在代码中定义了一个函数:
def booltf(tensor):
    """Binarise ``tensor`` in place with a 0.5 threshold.

    Elements ``>= 0.5`` become 1 and elements ``< 0.5`` become 0.

    :param tensor: tensor to binarise; it is modified in place
    :return: the same tensor object that was passed in
    """
    # Both masks are taken before writing; they are disjoint, so the result
    # is the same as thresholding one side after the other.
    is_high = tensor >= 0.5
    is_low = tensor < 0.5
    tensor[is_high] = 1
    tensor[is_low] = 0
    return tensor
然后使用nn.CrossEntropyLoss()计算损失,再通过反向传播和优化器更新参数。
再使用torch.argmax(out, dim=1)将模型输出out转换成我想要的二分类图像,最后通过ToPILImage和plt将图片显示出来。