Lukas00990 2022-10-04 17:03 采纳率: 40.8%
浏览 386
已结题

python 如何只运行部分代码

我的完整代码在最下方。请问如何只运行下面这段代码：

# Fine-tuning entry point (quoted from the full script): runs only when the
# file is executed directly, not when it is imported.
if __name__ == '__main__':

    # Loaders over the full training / validation path lists.
    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    # Start from the meta-trained weights that meta_train_reptile checkpoints to cur.pt.
    meta_model.load_state_dict(torch.load("cur.pt"))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    # Final-epoch weights; post_train itself writes the best-F1 weights to best.pt.
    torch.save(train_model.state_dict(), "last.pt")

而不运行下面这段代码：

# Meta-training setup (quoted from the full script): builds the Reptile
# "slow" model that meta_train_reptile updates.
meta_model = TrainModel().to(args.device)
criterion = nn.CrossEntropyLoss()
# Outer-loop (meta) optimizer; meta_train_reptile rewrites its lr every iteration.
meta_optimizer = torch.optim.SGD(
    meta_model.parameters(), lr=args.meta_lr)

# Task samplers: meta_train draws tasks from the training paths,
# meta_test from the validation paths.
meta_train = MetaPrinterFolder(
    train_no_defect, train_yes_defect, train_transform, val_transform)
meta_test = MetaPrinterFolder(
    val_no_defect, val_yes_defect, train_transform, val_transform)

# Start training
# (This is the long-running call the question wants to skip; it writes cur.pt.)
meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion)
import os
import glob
import random

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import torchvision
from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.datasets import ImageFolder

import timm
from tqdm import tqdm

from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm

from sklearn.metrics import f1_score, accuracy_score

random.seed(0)


def _select(pattern, token, keep_matches):
    """List the files matching glob *pattern*, filtered on *token*.

    A path is kept when ``token in path`` equals *keep_matches*, so
    ``keep_matches=False`` drops every path containing *token*.
    """
    return [path for path in glob.glob(pattern) if (token in path) == keep_matches]


# Training set: every image except the held-out types.
train_no_defect = _select(
    r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\no_defected\\*.jpg",
    "scratch_2", keep_matches=False)
train_yes_defect = _select(
    r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\defected\\*.jpg",
    "no_bottom", keep_matches=False)
# Resample the defect list (with replacement) to the size of the no-defect list
# so that both training classes have the same count.
train_yes_defect = random.choices(train_yes_defect, k=len(train_no_defect))

# Validation set: only the held-out types, from paths relative to the CWD.
val_no_defect = _select("no_defected/*.jpg", "scratch_2", keep_matches=True)
val_yes_defect = _select("defected/*.jpg", "no_bottom", keep_matches=True)

# Per-split class counts (used by the dataset-distribution plot).
count_train_no_defect, count_train_defect = len(train_no_defect), len(train_yes_defect)
count_val_no_defect, count_val_defect = len(val_no_defect), len(val_yes_defect)

# Bar chart of the per-split class counts, saved to data.pdf.
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)

# Set up title and x label
x_title = ["training", "validation"]
no_defect_score = [count_train_no_defect, count_val_no_defect]
defect_score = [count_train_defect, count_val_defect]
x = np.arange(len(x_title))
width = 0.3

# Plot the data: the two classes side by side for each split
bar1 = ax.bar(x, no_defect_score, width, color="#D67D3E", label="no defect")
bar2 = ax.bar(x + width, defect_score, width, color="#F9E4D4", label="defect")

# Write each bar's height just above it. The value is read back from the bar
# itself, so there is no separate value list to zip against.
for rect in (*bar1, *bar2):
    height = rect.get_height()
    ax.text(
        rect.get_x() + rect.get_width() / 2.0, height + 2,
        f"{height:.0f}", ha="center", va="bottom")

# Beautify the plot (optional)
ax.set_xticks(x + width / 2)
ax.set_xticklabels(x_title)
ax.set_yticks([])
ax.set_title("Distribution of dataset")
for side in ("right", "top", "left"):
    ax.spines[side].set_visible(False)

# Legend identifying the two classes; save with a transparent background
ax.legend()
fig.savefig("data.pdf", transparent=True)

class FewShotDataset(Dataset):
    """Map-style dataset over an explicit list of ``(label, image_path)`` pairs.

    Used by ``MetaPrinterFolder`` to wrap the support/query sets of one task.

    Args:
        img_list: sequence of ``(label, path)`` tuples.
        transform: callable applied to the loaded PIL image, or ``None`` to
            return the (RGB) PIL image unchanged.
    """

    def __init__(self, img_list, transform):
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry *idx*."""
        label, path = self.img_list[idx]
        # Decode eagerly to 3-channel RGB (as torchvision's ImageFolder loader
        # does). Grayscale/CMYK JPEGs then match the 3-channel Normalize, and
        # the file handle is closed when the ``with`` block exits.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        # transform may be None (MetaPrinterFolder's default); don't call it then.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class MetaPrinterFolder:
    """Sampler of balanced two-class few-shot tasks from lists of image paths.

    Paths from ``no_defect`` get label 0 and paths from ``yes_defect`` get
    label 1. Sampling uses NumPy's global random state.
    """

    def __init__(self, no_defect, yes_defect,
                 train_transform=None, val_transform=None):
        self.no_defect = no_defect
        self.yes_defect = yes_defect
        self.train_transform = train_transform
        self.val_transform = val_transform

    def get_random_task(self, K=1):
        """Return only the support dataset of a task with *K* images per class."""
        support, _unused_query = self.get_random_task_split(train_K=K, test_K=0)
        return support

    def get_random_task_split(self, train_K=1, test_K=1):
        """Draw one task and split it into support and query datasets.

        For each class, ``train_K + test_K`` distinct paths are drawn without
        replacement; the first ``train_K`` go to the support set and the rest
        go to the query set.

        Returns:
            ``(train_task, test_task)`` as ``FewShotDataset`` objects that use
            ``train_transform`` and ``val_transform`` respectively.
        """
        support, query = [], []
        per_class = train_K + test_K
        # Class order fixes the order of the NumPy RNG draws: no-defect first.
        for label, pool in ((0, self.no_defect), (1, self.yes_defect)):
            drawn = np.random.choice(pool, per_class, replace=False)
            for position, path in enumerate(drawn):
                bucket = support if position < train_K else query
                bucket.append((label, path))

        return (FewShotDataset(support, self.train_transform),
                FewShotDataset(query, self.val_transform))

# Preview pipeline without random augmentation: resize to 400x400, center-crop
# to 352x352, then normalize every channel with mean 0.5 and std 0.5.
demo_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((400, 400)),
        transforms.CenterCrop((352, 352)),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
)

# Draw one example 1-shot task (one image per class) from the training paths.
meta_3d_printer = MetaPrinterFolder(
    train_no_defect,
    train_yes_defect,
    train_transform=demo_transform,
    val_transform=demo_transform,
)
train_task = meta_3d_printer.get_random_task()

class ReptileModel(nn.Module):
    """Base ``nn.Module`` that adds the Reptile meta-update helper.

    Subclasses define the layers; this class only provides
    :meth:`point_grad_to` and :meth:`is_cuda`.
    """

    def __init__(self):
        super().__init__()

    def point_grad_to(self, target):
        """Set each parameter's gradient to ``self_param - target_param``.

        A subsequent gradient-descent ``step()`` on ``self`` therefore moves its
        weights toward *target*'s (the Reptile outer update). Parameters are
        paired positionally, so *target* must have the same architecture.
        """
        self.zero_grad()
        for p, target_p in zip(self.parameters(), target.parameters()):
            if p.grad is None:
                # zeros_like matches the parameter's device *and* dtype; a
                # float32 tensor on the default CUDA device would break for
                # e.g. cuda:1 or non-float32 parameters.
                p.grad = torch.zeros_like(p)
            # detach() gives the same values as the legacy .data, without autograd.
            p.grad.add_(p.detach() - target_p.detach())

    def is_cuda(self):
        """Return whether the first parameter lives on a CUDA device."""
        return next(self.parameters()).is_cuda

class TrainModel(ReptileModel):
    """Image classifier built on a ``timm`` backbone (2 output classes by default).

    Args:
        model_name: ``timm`` model identifier. The code replaces
            ``self.model.fc``, so the model must expose a linear head named
            ``fc`` (the default is ``"resnet34"``).
        pretrained: whether ``timm`` loads pretrained backbone weights.
        num_classes: number of outputs of the replacement head.
    """

    def __init__(self, model_name="resnet34", pretrained=True, num_classes=2):
        super().__init__()

        # Model settings (kept so clone() can rebuild the same architecture)
        self.model_name = model_name
        self.pretrained = pretrained
        self.num_classes = num_classes

        # Check out the doc: https://rwightman.github.io/pytorch-image-models/
        #  for different models
        self.model = timm.create_model(model_name, pretrained=pretrained)

        # Replace the output linear layer so it produces num_classes logits
        self.model.fc = nn.Linear(
            self.model.fc.weight.shape[1],
            num_classes
        )

    def forward(self, x):
        return self.model(x)

    def clone(self):
        """Return an independent copy with identical weights on the same device."""
        # Build the skeleton without pretrained weights: load_state_dict below
        # overwrites every parameter and buffer, so loading the pretrained
        # checkpoint here (once per meta-iteration) would be wasted work.
        clone = TrainModel(self.model_name, pretrained=False,
                           num_classes=self.num_classes)
        clone.pretrained = self.pretrained
        clone.load_state_dict(self.state_dict())
        # Use this model's exact device (e.g. cuda:1), not just the default CUDA device.
        return clone.to(next(self.parameters()).device)

@torch.no_grad()
def evaluate(model, val_loader, args):
    """Score *model* on ``args.iterations`` batches drawn from *val_loader*.

    *val_loader* is advanced with ``next()``, so it must be an iterator
    (e.g. the generator returned by ``make_infinite``).

    Returns:
        ``(accuracy, macro_f1)`` over all drawn samples.
    """
    model.eval()

    predicted, expected = [], []
    for _ in range(args.iterations):
        inputs, targets = next(val_loader)
        inputs = inputs.to(args.device)
        targets = targets.to(args.device)

        logits = model(inputs)
        predicted += logits.argmax(dim=-1).cpu().tolist()
        expected += targets.cpu().tolist()

    return (accuracy_score(expected, predicted),
            f1_score(expected, predicted, average="macro"))


def train_iter(model, train_loader, criterion, optimizer, args):
    """Run ``args.iterations`` optimization steps on batches from *train_loader*.

    *train_loader* is advanced with ``next()``, so it must be an iterator such
    as the generator returned by ``make_infinite``.

    Returns:
        The loss of the last step as a Python float, or ``float("nan")`` when
        ``args.iterations`` is 0 and no step was taken.
    """
    model.train()
    loss = None  # stays None if the loop body never runs
    for _ in range(args.iterations):
        data, label = next(train_loader)
        data = data.to(args.device)
        label = label.to(args.device)

        # Send data into the model and compute the loss
        output = model(data)
        loss = criterion(output, label)

        # Update the model with back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # .item() only once, after the loop, to avoid a device sync per step.
    return loss.item() if loss is not None else float("nan")

def get_optimizer(net, args, state=None):
    """Create an Adam optimizer over *net*'s parameters using ``args.lr``.

    If *state* (an ``optimizer.state_dict()``) is provided, it is restored
    into the new optimizer before returning.
    """
    opt = torch.optim.Adam(net.parameters(), lr=args.lr)
    if state is None:
        return opt
    opt.load_state_dict(state)
    return opt

def set_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def make_infinite(dataloader):
    """Yield items from *dataloader* forever, restarting it after each pass.

    Each pass iterates *dataloader* afresh, so a shuffling ``DataLoader`` is
    reshuffled every pass. ``itertools.cycle`` would instead replay the first
    pass's cached items.

    If a full pass produces nothing, the generator stops rather than spinning
    forever; the caller's ``next()`` then raises ``StopIteration``.
    """
    while True:
        produced = False
        for x in dataloader:
            produced = True
            yield x
        if not produced:
            return

def _infinite_loader(dataset, batch_size):
    """Wrap *dataset* in a shuffling 2-worker DataLoader and repeat it forever."""
    return make_infinite(
        DataLoader(
            dataset, batch_size, shuffle=True,
            num_workers=2, pin_memory=True))


def _meta_validate(meta_model, meta_dataset, mode, criterion, args):
    """Fine-tune a clone of *meta_model* on one task and print its query scores."""
    train, test = meta_dataset.get_random_task_split(
        train_K=args.shots, test_K=5)
    infinite_train_loader = _infinite_loader(train, args.batch_size)
    infinite_test_loader = _infinite_loader(test, args.batch_size)

    # Base-train
    net = meta_model.clone()
    optimizer = get_optimizer(net, args)
    train_iter(
        net, infinite_train_loader, criterion, optimizer, args)

    # Base-test: compute meta-loss, which is base-validation error
    meta_acc, meta_f1 = evaluate(net, infinite_test_loader, args)
    print(f"\n{mode}: f1-accuracy: {meta_f1:.3f}, acc: {meta_acc:.3f}")


def meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion):
    """Meta-train *meta_model* with Reptile.

    Each meta-iteration clones the model, trains the clone for
    ``args.iterations`` steps on a random ``args.train_shots``-shot task from
    *meta_train*, then moves *meta_model* toward the clone via
    ``point_grad_to`` followed by ``meta_optimizer.step()``. The meta learning
    rate decays linearly from ``args.meta_lr`` toward 0.

    Every ``args.validate_every`` iterations, a fresh clone is fine-tuned on an
    ``args.shots``-shot task from *meta_test* and scored on its 5-per-class
    query set. The model is saved to ``cur.pt`` every ``args.check_every``
    iterations and once more after the last iteration.
    """
    last_iteration = None
    for meta_iteration in tqdm(range(args.start_meta_iteration, args.meta_iterations)):
        # Update learning rate
        print('Meta Iteration: {}'.format(meta_iteration))
        meta_lr = args.meta_lr * (1. - meta_iteration / float(args.meta_iterations))
        set_learning_rate(meta_optimizer, meta_lr)

        # Clone model
        net = meta_model.clone()
        optimizer = get_optimizer(net, args)

        # Sample base task from Meta-Train
        train_dataset = meta_train.get_random_task(args.train_shots)
        infinite_train_loader = _infinite_loader(train_dataset, args.batch_size)

        # Update fast net
        train_iter(net, infinite_train_loader, criterion, optimizer, args)

        # Update slow net
        meta_model.point_grad_to(net)
        meta_optimizer.step()

        # Meta-Evaluation
        if meta_iteration % args.validate_every == 0:
            _meta_validate(meta_model, meta_test, "val", criterion, args)

        if meta_iteration % args.check_every == 0:
            torch.save(meta_model.state_dict(), "cur.pt")
        last_iteration = meta_iteration

    # The loop only checkpoints on multiples of check_every. Without this
    # final save, the iterations after the last multiple would never reach cur.pt.
    if last_iteration is not None and last_iteration % args.check_every != 0:
        torch.save(meta_model.state_dict(), "cur.pt")

class args:
    """Global configuration namespace, read as class attributes (``args.lr``)."""

    # ---- Training --------------------------------------------------------
    epochs = 30                 # post_train fine-tuning epochs
    batch_size = 32             # batch size of every DataLoader
    train_shots = 10            # images per class in a meta-training task
    shots = 5                   # support images per class in a meta-validation task
    meta_iterations = 1000      # exclusive end of the meta-iteration range
    start_meta_iteration = 0    # start of the meta-iteration range
    iterations = 5              # inner-loop steps per task (and eval batches)
    test_iterations = 50        # NOTE(review): not read by any code visible here
    meta_lr = 0.1               # initial outer-loop (SGD) learning rate
    validate_every = 50         # meta-validate every N meta-iterations
    check_every = 100           # write cur.pt every N meta-iterations
    lr = 3e-4                   # Adam learning rate (inner loop and fine-tuning)
    weight_decay = 1e-5         # NOTE(review): not read by any code visible here
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Transform -------------------------------------------------------
    size = 400                      # square resize edge before cropping
    crop_size = 352                 # center-crop edge
    mean = [0.485, 0.456, 0.406]    # per-channel normalization mean
    std = [0.229, 0.224, 0.225]     # per-channel normalization std


# Set up the image pipelines used by the train and validation loaders. Both
# resize to args.size, center-crop to args.crop_size and normalize with
# args.mean/args.std; only the training pipeline adds a random horizontal
# flip, placed just before normalization.
def _make_transform(augment):
    """Build a fresh preprocessing pipeline; *augment* adds a random h-flip."""
    steps = [
        transforms.ToTensor(),
        transforms.Resize((args.size, args.size)),
        transforms.CenterCrop((args.crop_size, args.crop_size)),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.Normalize(mean=args.mean, std=args.std))
    return transforms.Compose(steps)


train_transform = _make_transform(augment=True)
val_transform = _make_transform(augment=False)

# Set up model
meta_model = TrainModel().to(args.device)
criterion = nn.CrossEntropyLoss()
meta_optimizer = torch.optim.SGD(
    meta_model.parameters(), lr=args.meta_lr)

meta_train = MetaPrinterFolder(
    train_no_defect, train_yes_defect, train_transform, val_transform)
meta_test = MetaPrinterFolder(
    val_no_defect, val_yes_defect, train_transform, val_transform)

# Reptile meta-training is the expensive step, so it is guarded twice:
#  * `__name__ == "__main__"` stops it from running when this file is imported.
#    That includes DataLoader worker processes (num_workers > 0) re-importing
#    the main module under the "spawn" start method used on Windows.
#  * Setting RUN_META_TRAINING=0 in the environment skips it entirely, e.g. to
#    run only the `if __name__ == '__main__':` fine-tuning block at the end of
#    the file on an existing cur.pt.
RUN_META_TRAINING = os.environ.get("RUN_META_TRAINING", "1") != "0"

if __name__ == "__main__" and RUN_META_TRAINING:
    # Start training
    meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion)


class ListDataset(Dataset):
    """Map-style dataset over two lists of image paths.

    Paths in *yes_defect* get label 1 and paths in *no_defect* get label 0.
    Defect images come first in index order.

    Args:
        yes_defect: list of image paths labelled 1.
        no_defect: list of image paths labelled 0.
        transform: callable applied to each loaded PIL image. ``None`` (the
            default) returns the RGB PIL image itself instead of crashing.
    """

    def __init__(self, yes_defect, no_defect, transform=None):
        self.img_list = yes_defect + no_defect
        self.label = [1] * len(yes_defect) + [0] * len(no_defect)
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry *idx*."""
        # Decode eagerly to 3-channel RGB (as torchvision's ImageFolder loader
        # does). Grayscale/CMYK JPEGs then match the 3-channel Normalize, and
        # the file handle is closed when the ``with`` block exits.
        with Image.open(self.img_list[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[idx]


def make_loader(yes_defect, no_defect, transform, batch_size,
                shuffle=True, num_workers=2, pin_memory=True,
                train=True):
    """Build a ``DataLoader`` over a ``ListDataset`` of the given image paths.

    Args:
        yes_defect: paths of images labelled 1.
        no_defect: paths of images labelled 0.
        transform: per-image transform passed to ``ListDataset``.
        batch_size: loader batch size.
        shuffle: reshuffle every epoch; honored only when *train* is true.
        num_workers: number of DataLoader worker processes.
        pin_memory: forwarded to ``DataLoader``.
        train: pass ``False`` for evaluation loaders, which are never shuffled.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ListDataset(
        yes_defect=yes_defect, no_defect=no_defect, transform=transform)
    # Honor both `shuffle` and `train`: evaluation loaders (train=False) read
    # the data in a fixed order, which whole-set metrics do not depend on.
    return DataLoader(
        dataset, batch_size=batch_size,
        num_workers=num_workers,
        shuffle=bool(shuffle and train),
        pin_memory=pin_memory)

@torch.no_grad()
def post_train_evaluate(model, val_loader, args):
    """Evaluate *model* on one full pass over *val_loader*.

    Returns:
        ``(accuracy, macro_f1)`` computed with scikit-learn.
    """
    model.eval()

    predictions, targets = [], []
    for batch, labels in val_loader:
        batch = batch.to(args.device)
        labels = labels.to(args.device)

        scores = model(batch)
        predictions.extend(scores.argmax(dim=-1).cpu().tolist())
        targets.extend(labels.cpu().tolist())

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average="macro")
    return accuracy, macro_f1


def post_train(model, train_loader, val_loader, criterion, optimizer, args):
    """Fine-tune *model* for ``args.epochs`` epochs, keeping the best checkpoint.

    After every epoch the model is scored with ``post_train_evaluate``. Whenever
    the macro F1 improves, the weights are written to ``best.pt``.
    """
    # Start below any achievable F1 so best.pt is always written after the
    # first epoch; with 0 it would never be saved if F1 stayed at 0.0.
    best_f1 = float("-inf")
    for epoch in range(args.epochs):
        train_progress_bar = tqdm(
            train_loader, desc=f"Epochs: {epoch + 1}/{args.epochs}")

        model.train()
        for data, label in train_progress_bar:
            data = data.to(args.device)
            label = label.to(args.device)

            # Send data into the model and compute the loss
            output = model(data)
            loss = criterion(output, label)

            # Update the model with back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Compute the accuracy and save the best model
        eval_acc, eval_f1 = post_train_evaluate(model, val_loader, args)
        print(f"Validation accuracy: {eval_acc:.8f} f1-score: {eval_f1:.8f}")
        if eval_f1 > best_f1:
            best_f1 = eval_f1
            torch.save(model.state_dict(), "best.pt")

if __name__ == '__main__':
    # Fine-tuning stage: start from the meta-learned weights in cur.pt and
    # train on the full training lists, validating after every epoch.
    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    # map_location remaps the stored tensors onto args.device, so a checkpoint
    # written on a GPU machine still loads on a CPU-only one (and vice versa).
    meta_model.load_state_dict(torch.load("cur.pt", map_location=args.device))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training (post_train also keeps the best-F1 weights in best.pt)
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    # Final-epoch weights
    torch.save(train_model.state_dict(), "last.pt")

  • 写回答

8条回答 默认 最新

  • Python-ZZY 2022-10-04 18:23
    关注

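    把代码删掉。
    或者在这段代码的开头和末尾各加上三个引号（"""），把它变成一个不会被执行的字符串，从而起到"注释掉"的效果。注意这只是字符串字面量，并不是真正的注释；更稳妥的做法是用一个开关变量（例如 `if RUN_META_TRAINING:`）把 meta_train_reptile(...) 这一行包起来，需要时再打开。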

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(7条)

报告相同问题?

问题事件

  • 系统已结题 10月15日
  • 已采纳回答 10月7日
  • 修改了问题 10月4日
  • 创建了问题 10月4日

悬赏问题

  • ¥50 postgresql15备份问题
  • ¥15 Mac系统vs code使用phpstudy如何配置debug来调试php
  • ¥15 目前主流的音乐软件,像网易云音乐,QQ音乐他们的前端和后台部分是用的什么技术实现的?求解!
  • ¥60 pb数据库修改与连接
  • ¥15 spss统计中二分类变量和有序变量的相关性分析可以用kendall相关分析吗?
  • ¥15 拟通过pc下指令到安卓系统,如果追求响应速度,尽可能无延迟,是不是用安卓模拟器会优于实体的安卓手机?如果是,可以快多少毫秒?
  • ¥20 神经网络Sequential name=sequential, built=False
  • ¥16 QPython 用xlrd读取excel报错
  • ¥15 单片机学习顺序问题!!
  • ¥15 ikuai客户端多拨vpn,重启总是有个别重拨不上