我的完整代码附在最下方。请问如何只运行下面这一段代码:
if __name__ == '__main__':
    # Fine-tune the meta-trained weights on the full training split.
    # Relies on module-level names defined earlier in the script:
    # meta_model, args, the image lists and the two transforms.
    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    # map_location lets a checkpoint written on a GPU machine be restored
    # on whatever device this run uses (e.g. a CPU-only machine).
    meta_model.load_state_dict(torch.load("cur.pt", map_location=args.device))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    torch.save(train_model.state_dict(), "last.pt")
而不运行下面这一段代码:
# Build the meta-model on the configured device.
meta_model = TrainModel()
meta_model.to(args.device)
criterion = nn.CrossEntropyLoss()
meta_optimizer = torch.optim.SGD(meta_model.parameters(), lr=args.meta_lr)

# Task samplers: one over the training images, one over the validation images.
meta_train = MetaPrinterFolder(
    no_defect=train_no_defect,
    yes_defect=train_yes_defect,
    train_transform=train_transform,
    val_transform=val_transform,
)
meta_test = MetaPrinterFolder(
    no_defect=val_no_defect,
    yes_defect=val_yes_defect,
    train_transform=train_transform,
    val_transform=val_transform,
)

# Start training
meta_train_reptile(
    args, meta_model, meta_train, meta_test, meta_optimizer, criterion)
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision
from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.datasets import ImageFolder
import timm
from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.metrics import f1_score, accuracy_score
# Seed Python's `random` module so the oversampling below is repeatable.
random.seed(0)

# ---- Training split ----
# Non-defective images, skipping every file whose path contains "scratch_2".
train_no_defect = [
    path
    for path in glob.glob(r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\no_defected\\*.jpg")
    if "scratch_2" not in path
]
# Defective images, skipping every file whose path contains "no_bottom".
train_yes_defect = [
    path
    for path in glob.glob(r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\defected\\*.jpg")
    if "no_bottom" not in path
]
# Resample the defective images (with replacement) so both training classes
# end up with the same number of entries.
train_yes_defect = random.choices(train_yes_defect, k=len(train_no_defect))

# ---- Validation split ----
# Only the files excluded above ("scratch_2" / "no_bottom") are used here.
# NOTE(review): these patterns are relative to the current working directory,
# whereas the training patterns are absolute -- confirm this is intended.
val_no_defect = [
    path for path in glob.glob("no_defected/*.jpg") if "scratch_2" in path
]
val_yes_defect = [
    path for path in glob.glob("defected/*.jpg") if "no_bottom" in path
]

# ---- Class counts per split ----
count_train_no_defect, count_train_defect = (
    len(train_no_defect), len(train_yes_defect))
count_val_no_defect, count_val_defect = (
    len(val_no_defect), len(val_yes_defect))
# Grouped bar chart of the class counts per split, written to "data.pdf".
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)

# One group per split; inside a group, one bar per class.
x_title = ["training", "validation"]
no_defect_score = [count_train_no_defect, count_val_no_defect]
defect_score = [count_train_defect, count_val_defect]
x = np.arange(len(x_title))
width = 0.3

# Draw the two series side by side.
bar1 = ax.bar(x, no_defect_score, width, color="#D67D3E", label="no defect")
bar2 = ax.bar(x + width, defect_score, width, color="#F9E4D4", label="defect")

# Write each bar's value just above it.
for rect in (*bar1, *bar2):
    height = rect.get_height()
    plt.text(
        rect.get_x() + rect.get_width() / 2.0,
        height + 2,
        f"{height:.0f}",
        ha="center",
        va="bottom",
    )

# Cosmetics: centre the tick labels under each group, hide the y axis and
# every spine except the bottom one.
ax.set_xticks(x + width / 2)
ax.set_xticklabels(x_title)
ax.set_yticks([])
ax.set_title("Distribution of dataset")
for side in ("right", "top", "left"):
    ax.spines[side].set_visible(False)

# Legend, then export.
plt.legend()
plt.savefig("data.pdf", transparent=True)
class FewShotDataset(Dataset):
    """Dataset over an explicit list of ``(label, path)`` image samples.

    Args:
        img_list: Sequence of ``(label, path)`` pairs.
        transform: Callable applied to the loaded PIL image, or ``None`` to
            return the image untransformed.
    """

    def __init__(self, img_list, transform):
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        label, path = self.img_list[idx]
        # Force three channels so grayscale/CMYK/RGBA files do not break a
        # 3-channel Normalize; the context manager releases the file handle
        # once the pixel data has been read.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        # MetaPrinterFolder allows a None transform; do not call it then.
        if self.transform is not None:
            img = self.transform(img)
        return img, label
class MetaPrinterFolder:
    """Sampler of small two-class tasks drawn from two lists of image paths.

    Label 0 is assigned to paths from ``no_defect`` and label 1 to paths
    from ``yes_defect``.
    """

    def __init__(self, no_defect, yes_defect,
                 train_transform=None, val_transform=None):
        self.no_defect = no_defect
        self.yes_defect = yes_defect
        self.train_transform = train_transform
        self.val_transform = val_transform

    def get_random_task(self, K=1):
        """Return a training-only task with ``K`` samples per class."""
        task, _unused = self.get_random_task_split(train_K=K, test_K=0)
        return task

    def get_random_task_split(self, train_K=1, test_K=1):
        """Draw ``train_K + test_K`` distinct entries per class and split them.

        For each class the first ``train_K`` drawn entries go to the training
        task and the remainder to the test task. ``np.random.choice`` is
        called with ``replace=False``, so it raises ``ValueError`` when a
        class list has fewer entries than requested.

        Returns:
            ``(train_task, test_task)`` as ``FewShotDataset`` objects using
            ``train_transform`` and ``val_transform`` respectively.
        """
        per_class = train_K + test_K
        train_samples, test_samples = [], []

        # Class 0 (no defect) is drawn first, then class 1 (defect); the
        # order fixes both the RNG consumption and the sample ordering.
        for label, pool in ((0, self.no_defect), (1, self.yes_defect)):
            drawn = np.random.choice(pool, per_class, replace=False)
            train_samples.extend((label, path) for path in drawn[:train_K])
            test_samples.extend((label, path) for path in drawn[train_K:])

        return (
            FewShotDataset(train_samples, self.train_transform),
            FewShotDataset(test_samples, self.val_transform),
        )
# Demo preprocessing: tensor conversion, 400x400 resize, 352x352 centre crop,
# then per-channel normalisation with mean 0.5 and std 0.5.
demo_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(400, 400)),
    transforms.CenterCrop(size=(352, 352)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Sample one demo task (default K=1 per class) from the training images.
meta_3d_printer = MetaPrinterFolder(
    no_defect=train_no_defect,
    yes_defect=train_yes_defect,
    train_transform=demo_transform,
    val_transform=demo_transform,
)
train_task = meta_3d_printer.get_random_task()
class ReptileModel(nn.Module):
    """Base ``nn.Module`` with the helpers used for the Reptile meta-update."""

    def __init__(self):
        super().__init__()

    def point_grad_to(self, target):
        """Set every parameter's gradient to ``self_param - target_param``.

        A following optimizer step with learning rate ``lr`` therefore moves
        each parameter a fraction ``lr`` of the way towards ``target``.
        Parameters are paired in ``parameters()`` order.

        Args:
            target: Module whose parameters line up with this module's.
        """
        self.zero_grad()
        with torch.no_grad():
            for p, target_p in zip(self.parameters(), target.parameters()):
                if p.grad is None:
                    # zeros_like keeps the buffer on the parameter's own
                    # device and dtype (not merely "the current CUDA device").
                    p.grad = torch.zeros_like(p.detach())
                p.grad.add_(p.detach() - target_p.detach())

    def is_cuda(self):
        """Return True if the first parameter lives on a CUDA device.

        Returns False for a module without parameters instead of raising
        ``StopIteration``.
        """
        first = next(self.parameters(), None)
        return first is not None and first.is_cuda
class TrainModel(ReptileModel):
    """timm backbone whose ``fc`` layer is replaced to emit ``num_classes`` logits.

    Args:
        model_name: Name passed to ``timm.create_model``.
        pretrained: Whether ``timm.create_model`` loads pretrained weights.
        num_classes: Output size of the replacement ``fc`` layer.
    """

    def __init__(self, model_name="resnet34", pretrained=True, num_classes=2):
        super().__init__()
        # Model settings (kept so that clone() can rebuild the architecture)
        self.model_name = model_name
        self.pretrained = pretrained
        self.num_classes = num_classes
        # Check out the doc: https://rwightman.github.io/pytorch-image-models/
        # for different models
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Change the output linear layer to fit the output classes.
        # This requires the backbone to expose its classifier as `fc`.
        self.model.fc = nn.Linear(
            self.model.fc.weight.shape[1],
            num_classes
        )

    def forward(self, x):
        """Run the backbone on ``x`` and return its output."""
        return self.model(x)

    def clone(self):
        """Return an independent copy with the same weights and device.

        The copy is built with ``pretrained=False``: every weight is
        overwritten by ``load_state_dict`` right after, so loading the
        pretrained checkpoint again on each clone is wasted work.
        """
        clone = TrainModel(
            self.model_name, pretrained=False, num_classes=self.num_classes)
        # Keep the recorded setting identical to the source model's.
        clone.pretrained = self.pretrained
        clone.load_state_dict(self.state_dict())
        # Follow the source model's actual device.
        clone.to(next(self.parameters()).device)
        return clone
@torch.no_grad()
def evaluate(model, val_loader, args):
    """Score ``model`` on ``args.iterations`` batches pulled from ``val_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterator yielding ``(data, label)`` batches.
        args: Settings object providing ``iterations`` and ``device``.

    Returns:
        ``(accuracy, macro_f1)`` over all batches that were drawn.
    """
    # NOTE(review): `args.test_iterations` exists but is not used here;
    # confirm that `args.iterations` is the intended batch count.
    model.eval()
    predictions, targets = [], []

    for _ in range(args.iterations):
        batch, batch_labels = next(val_loader)
        batch = batch.to(args.device)
        batch_labels = batch_labels.to(args.device)

        # Predicted class = index of the largest output along the last dim.
        predicted = model(batch).argmax(dim=-1)
        predictions.extend(predicted.cpu().tolist())
        targets.extend(batch_labels.cpu().tolist())

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average="macro")
    return accuracy, macro_f1
def train_iter(model, train_loader, criterion, optimizer, args):
    """Run ``args.iterations`` optimisation steps on batches from ``train_loader``.

    Args:
        model: Network to train; it is switched to train mode.
        train_loader: Iterator yielding ``(data, label)`` batches.
        criterion: Loss callable taking ``(output, label)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        args: Settings object providing ``iterations`` and ``device``.

    Returns:
        The loss of the last step as a float, or ``nan`` when
        ``args.iterations`` is 0 (previously an ``UnboundLocalError``).
    """
    model.train()
    loss = None
    for _ in range(args.iterations):
        data, label = next(train_loader)
        data = data.to(args.device)
        label = label.to(args.device)

        # Send data into the model and compute the loss
        output = model(data)
        loss = criterion(output, label)

        # Update the model with back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Read the scalar once, after the loop, rather than on every step.
    return float("nan") if loss is None else loss.item()
def get_optimizer(net, args, state=None):
    """Create an Adam optimizer for ``net`` with learning rate ``args.lr``.

    Args:
        net: Module whose parameters are optimised.
        args: Settings object providing ``lr``.
        state: Optional optimizer state dict to restore into the new optimizer.

    Returns:
        The ``torch.optim.Adam`` instance.
    """
    adam = torch.optim.Adam(net.parameters(), lr=args.lr)
    if state is None:
        return adam
    adam.load_state_dict(state)
    return adam
def set_learning_rate(optimizer, lr):
    """Overwrite the ``"lr"`` entry of every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
def make_infinite(dataloader):
    """Yield the items of ``dataloader`` forever, re-iterating it when exhausted.

    Each pass starts a fresh iteration over ``dataloader``.
    """
    while True:
        yield from dataloader
def meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion):
    """Meta-train ``meta_model`` with the Reptile update.

    Every meta-iteration clones the meta-model, trains the clone on one
    sampled task, points the meta-model's gradients at the trained clone and
    takes one ``meta_optimizer`` step. Every ``args.validate_every``
    iterations a fresh clone is trained and scored on a task split sampled
    from ``meta_test``; every ``args.check_every`` iterations the meta-model
    weights are saved to "cur.pt".

    Args:
        args: Settings object (meta_lr, meta_iterations, start_meta_iteration,
            train_shots, shots, batch_size, validate_every, check_every, ...).
        meta_model: Model exposing ``clone()`` and ``point_grad_to()``.
        meta_train: Task sampler used for meta-training.
        meta_test: Task sampler used for meta-evaluation.
        meta_optimizer: Optimizer over ``meta_model``'s parameters.
        criterion: Loss callable handed to ``train_iter``.
    """

    def endless_batches(dataset):
        # Shuffled loader wrapped so that batches can be drawn indefinitely.
        return make_infinite(
            DataLoader(
                dataset, args.batch_size, shuffle=True,
                num_workers=2, pin_memory=True))

    for meta_iteration in tqdm(range(args.start_meta_iteration, args.meta_iterations)):
        print('Meta Iteration: {}'.format(meta_iteration))

        # Meta learning rate decays linearly with the iteration index.
        meta_lr = args.meta_lr * (1. - meta_iteration / float(args.meta_iterations))
        set_learning_rate(meta_optimizer, meta_lr)

        # Fast weights: a fresh clone with its own optimizer.
        fast_net = meta_model.clone()
        fast_optimizer = get_optimizer(fast_net, args)

        # Sample a base task from Meta-Train and adapt the clone to it.
        task = meta_train.get_random_task(args.train_shots)
        task_batches = endless_batches(task)
        train_iter(fast_net, task_batches, criterion, fast_optimizer, args)

        # Slow weights: step the meta-model towards the adapted clone.
        meta_model.point_grad_to(fast_net)
        meta_optimizer.step()

        # Meta-Evaluation
        if meta_iteration % args.validate_every == 0:
            for (meta_dataset, mode) in [(meta_test, "val")]:
                support, query = meta_dataset.get_random_task_split(
                    train_K=args.shots, test_K=5)
                support_batches = endless_batches(support)
                query_batches = endless_batches(query)

                # Base-train a fresh clone on the support samples.
                eval_net = meta_model.clone()
                eval_optimizer = get_optimizer(eval_net, args)
                train_iter(
                    eval_net, support_batches, criterion, eval_optimizer, args)

                # Base-test: score the adapted clone on the query samples.
                meta_acc, meta_f1 = evaluate(eval_net, query_batches, args)
                print(f"\n{mode}: f1-accuracy: {meta_f1:.3f}, acc: {meta_acc:.3f}")

        # Periodic checkpoint of the meta-model.
        if meta_iteration % args.check_every == 0:
            torch.save(meta_model.state_dict(), "cur.pt")
class args:
    """Plain namespace holding the hyper-parameters used by this script."""

    # ---- Training ----
    epochs = 30                  # epochs run by post_train
    batch_size = 32
    train_shots = 10             # samples per class for a meta-train task
    shots = 5                    # samples per class for a meta-eval train split
    meta_iterations = 1000
    start_meta_iteration = 0
    iterations = 5               # inner steps in train_iter / evaluate
    test_iterations = 50
    meta_lr = 0.1                # initial learning rate of the meta optimizer
    validate_every = 50
    check_every = 100
    lr = 3e-4                    # learning rate of the Adam optimizers
    weight_decay = 1e-5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Transform ----
    size = 400                   # resize target (square)
    crop_size = 352              # centre-crop size (square)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
# Preprocessing pipelines for training and validation images.
# Both convert to a tensor, resize to args.size, centre-crop to
# args.crop_size and normalise with args.mean / args.std; the training
# pipeline additionally applies a random horizontal flip before normalising.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(args.size, args.size)),
    transforms.CenterCrop(size=(args.crop_size, args.crop_size)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=args.mean, std=args.std),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(args.size, args.size)),
    transforms.CenterCrop(size=(args.crop_size, args.crop_size)),
    transforms.Normalize(mean=args.mean, std=args.std),
])
# Set up model
# NOTE(review): these statements appear to sit at module level, ahead of the
# `if __name__ == '__main__':` guard, so meta-training would start whenever
# the file is executed or imported -- confirm this is intended.
meta_model = TrainModel()
meta_model.to(args.device)
criterion = nn.CrossEntropyLoss()
meta_optimizer = torch.optim.SGD(meta_model.parameters(), lr=args.meta_lr)

# Task samplers: one over the training images, one over the validation images.
meta_train = MetaPrinterFolder(
    no_defect=train_no_defect,
    yes_defect=train_yes_defect,
    train_transform=train_transform,
    val_transform=val_transform,
)
meta_test = MetaPrinterFolder(
    no_defect=val_no_defect,
    yes_defect=val_yes_defect,
    train_transform=train_transform,
    val_transform=val_transform,
)

# Start training
meta_train_reptile(
    args, meta_model, meta_train, meta_test, meta_optimizer, criterion)
class ListDataset(Dataset):
    """Binary image dataset built from two sequences of file paths.

    Paths in ``yes_defect`` get label 1 and paths in ``no_defect`` label 0;
    the defective samples come first.

    Args:
        yes_defect: Paths of defective images.
        no_defect: Paths of non-defective images.
        transform: Callable applied to the loaded PIL image, or ``None``
            (the default) to return the image untransformed.
    """

    def __init__(self, yes_defect, no_defect, transform=None):
        # list(...) lets callers mix lists, tuples or other sequences.
        yes_defect = list(yes_defect)
        no_defect = list(no_defect)
        self.img_list = yes_defect + no_defect
        self.label = [1] * len(yes_defect) + [0] * len(no_defect)
        self.transform = transform

    def __len__(self):
        """Return the total number of images."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, label)``."""
        # Force three channels so grayscale/CMYK/RGBA files do not break a
        # 3-channel Normalize; the context manager releases the file handle
        # once the pixel data has been read.
        with Image.open(self.img_list[idx]) as handle:
            img = handle.convert("RGB")
        label = self.label[idx]
        # `transform` defaults to None, so it must not be called blindly.
        if self.transform is not None:
            img = self.transform(img)
        return img, label
def make_loader(yes_defect, no_defect, transform, batch_size,
                shuffle=True, num_workers=2, pin_memory=True,
                train=True):
    """Build a ``DataLoader`` over a ``ListDataset`` of the given image paths.

    Args:
        yes_defect: Paths of defective images (label 1).
        no_defect: Paths of non-defective images (label 0).
        transform: Transform applied to every loaded image.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle; only honoured for training loaders.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        train: ``False`` marks an evaluation loader, which is never shuffled.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ListDataset(
        yes_defect=yes_defect, no_defect=no_defect, transform=transform)
    # Previously `shuffle=True` was hard-coded, so both `shuffle` and `train`
    # were silently ignored.
    loader = DataLoader(
        dataset, batch_size=batch_size,
        num_workers=num_workers,
        shuffle=bool(shuffle and train),
        pin_memory=pin_memory)
    return loader
@torch.no_grad()
def post_train_evaluate(model, val_loader, args):
    """Score ``model`` on one full pass over ``val_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable yielding ``(data, label)`` batches.
        args: Settings object providing ``device``.

    Returns:
        ``(accuracy, macro_f1)`` over every sample seen.
    """
    model.eval()
    predictions, targets = [], []

    for batch, batch_labels in val_loader:
        batch = batch.to(args.device)
        batch_labels = batch_labels.to(args.device)

        # Predicted class = index of the largest output along the last dim.
        predicted = model(batch).argmax(dim=-1)
        predictions.extend(predicted.cpu().tolist())
        targets.extend(batch_labels.cpu().tolist())

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average="macro")
    return accuracy, macro_f1
def post_train(model, train_loader, val_loader, criterion, optimizer, args):
    """Train ``model`` for ``args.epochs`` epochs, keeping the best checkpoint.

    After each epoch the model is scored on ``val_loader``; whenever the
    macro F1 exceeds the best value so far, the weights go to "best.pt".

    Args:
        model: Network to train.
        train_loader: Iterable of ``(data, label)`` training batches.
        val_loader: Iterable of ``(data, label)`` validation batches.
        criterion: Loss callable taking ``(output, label)``.
        optimizer: Optimizer over ``model``'s parameters.
        args: Settings object providing ``epochs`` and ``device``.
    """
    best_f1 = 0
    for epoch in range(args.epochs):
        batches = tqdm(
            train_loader, desc=f"Epochs: {epoch + 1}/{args.epochs}")
        model.train()

        for inputs, targets in batches:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            # Forward pass and loss.
            loss = criterion(model(inputs), targets)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Compute the accuracy and save the best model.
        eval_acc, eval_f1 = post_train_evaluate(model, val_loader, args)
        print(f"Validation accuracy: {eval_acc:.8f} f1-score: {eval_f1:.8f}")
        if eval_f1 > best_f1:
            best_f1 = eval_f1
            torch.save(model.state_dict(), "best.pt")
if __name__ == '__main__':
    # Fine-tune the meta-trained weights on the full training split.
    # Relies on module-level names defined earlier in the script:
    # meta_model, args, the image lists and the two transforms.
    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    # map_location lets a checkpoint written on a GPU machine be restored
    # on whatever device this run uses (e.g. a CPU-only machine).
    meta_model.load_state_dict(torch.load("cur.pt", map_location=args.device))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    torch.save(train_model.state_dict(), "last.pt")