Lukas00990 2022-10-04 17:03 采纳率: 40.8%
浏览 391
已结题

python 如何只运行部分代码

我的代码在最下方, 如何只运行这个

if __name__ == '__main__':

    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    meta_model.load_state_dict(torch.load("cur.pt"))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    torch.save(train_model.state_dict(), "last.pt")

而不运行这个

meta_model = TrainModel().to(args.device)
criterion = nn.CrossEntropyLoss()
meta_optimizer = torch.optim.SGD(
    meta_model.parameters(), lr=args.meta_lr)

meta_train = MetaPrinterFolder(
    train_no_defect, train_yes_defect, train_transform, val_transform)
meta_test = MetaPrinterFolder(
    val_no_defect, val_yes_defect, train_transform, val_transform)

# Start training
meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion)
import os
import glob
import random

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import torchvision
from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.datasets import ImageFolder

import timm
from tqdm import tqdm

from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm

from sklearn.metrics import f1_score, accuracy_score

random.seed(0)

# Set up training dataset
train_no_defect = [
    file for file in glob.glob(r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\no_defected\\*.jpg") if "scratch_2" not in file]
train_yes_defect = [
    file for file in glob.glob(r"D:\\OneDrive - The University of Nottingham\\ESR1\\work\\3d printing\\dataset\\3d printing\\defected\\*.jpg") if "no_bottom" not in file
]
train_yes_defect = random.choices(train_yes_defect, k=len(train_no_defect))

# Set up validation dataset
val_no_defect = [
    file for file in glob.glob("no_defected/*.jpg") if "scratch_2" in file]
val_yes_defect = [
    file for file in glob.glob("defected/*.jpg") if "no_bottom" in file]

# Count the number of the the class
# Training
count_train_no_defect = len(train_no_defect)
count_train_defect = len(train_yes_defect)

# Validation
count_val_no_defect = len(val_no_defect)
count_val_defect = len(val_yes_defect)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)

# Set up title and and x label
x_title = ["training", "validation"]
no_defect_score = [count_train_no_defect, count_val_no_defect]
defect_score = [count_train_defect, count_val_defect]
x = np.arange(len(x_title))
width = 0.3

# Plot the data
bar1 = ax.bar(x, no_defect_score, width, color="#D67D3E", label="no defect")
bar2 = ax.bar(x + width, defect_score, width, color="#F9E4D4", label="defect")

# Add heights above the bar plot
for rect, height in zip(bar1 + bar2, no_defect_score + defect_score):
    height = rect.get_height()
    plt.text(
        rect.get_x() + rect.get_width() / 2.0, height + 2,
        f"{height:.0f}", ha="center", va="bottom")

# Beautify the plot (optional)
ax.set_xticks(x + width / 2)
ax.set_xticklabels(x_title)
ax.set_yticks([])
ax.set_title("Distribution of dataset")
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(False)

# Show the annotation to each bar
plt.legend()
plt.savefig("data.pdf", transparent=True)

class FewShotDataset(Dataset):
    def __init__(self, img_list, transform):
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        label, path = self.img_list[idx]
        img = Image.open(path)
        img = self.transform(img)
        return img, label

class MetaPrinterFolder:
    def __init__(self, no_defect, yes_defect,
                 train_transform=None, val_transform=None):
        self.no_defect = no_defect
        self.yes_defect = yes_defect
        self.train_transform = train_transform
        self.val_transform = val_transform

    def get_random_task(self, K=1):
        train_task, _ = self.get_random_task_split(train_K=K, test_K=0)
        return train_task

    def get_random_task_split(self, train_K=1, test_K=1):
        train_samples = []
        test_samples = []

        sample_num = train_K + test_K
        # ====== Good list =======
        for idx, path in enumerate(np.random.choice(self.no_defect, sample_num,
                                                    replace=False)):
            if idx < train_K:
                train_samples.append((0, path))
            else:
                test_samples.append((0, path))

        # ====== Bad list =======
        for i, path in enumerate(np.random.choice(self.yes_defect, sample_num,
                                                  replace=False)):
            if i < train_K:
                train_samples.append((1, path))
            else:
                test_samples.append((1, path))

        train_task = FewShotDataset(train_samples, self.train_transform)
        test_task = FewShotDataset(test_samples, self.val_transform)

        return train_task, test_task

demo_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((400, 400)),
    transforms.CenterCrop((352, 352)),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

meta_3d_printer = MetaPrinterFolder(
    train_no_defect, train_yes_defect, demo_transform, demo_transform)
train_task = meta_3d_printer.get_random_task()

class ReptileModel(nn.Module):
    def __init__(self):
        super().__init__()

    def point_grad_to(self, target):
        self.zero_grad()
        for p, target_p in zip(self.parameters(), target.parameters()):
            if p.grad is None:
                if self.is_cuda():
                    p.grad = torch.zeros(p.size()).cuda()
                else:
                    p.grad = torch.zeros(p.size())
            p.grad.data.add_(p.data - target_p.data)

    def is_cuda(self):
        return next(self.parameters()).is_cuda

class TrainModel(ReptileModel):
    def __init__(self, model_name="resnet34", pretrained=True, num_classes=2):
        super().__init__()

        # Model settings
        self.model_name = model_name
        self.pretrained=pretrained
        self.num_classes = num_classes

        # Check out the doc: https://rwightman.github.io/pytorch-image-models/
        #  for different models
        self.model = timm.create_model(model_name, pretrained=pretrained)

        # Change the output linear layers to fit the output classes
        self.model.fc = nn.Linear(
            self.model.fc.weight.shape[1],
            num_classes
        )

    def forward(self, x):
        return self.model(x)

    def clone(self):
        clone = TrainModel(self.model_name, self.pretrained, self.num_classes)
        clone.load_state_dict(self.state_dict())
        if self.is_cuda():
            clone.cuda()
        return clone

@torch.no_grad()
def evaluate(model, val_loader, args):
    model.eval()

    total_predict = []
    total_ground_truth = []
    for iteration in range(args.iterations):
        data, label = val_loader.__next__()
        data = data.to(args.device)
        label = label.to(args.device)

        output = model(data)
        prediction = output.argmax(dim=-1)

        total_predict.extend(prediction.cpu().tolist())
        total_ground_truth.extend(label.cpu().tolist())

    return accuracy_score(total_ground_truth, total_predict), \
           f1_score(total_ground_truth, total_predict, average="macro")


def train_iter(model, train_loader, criterion, optimizer, args):
    model.train()
    for iteration in range(args.iterations):
        data, label = train_loader.__next__()
        data = data.to(args.device)
        label = label.to(args.device)

        # Send data into the model and compute the loss
        output = model(data)
        loss = criterion(output, label)

        # Update the model with back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item()

def get_optimizer(net, args, state=None):
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    if state is not None:
        optimizer.load_state_dict(state)
    return optimizer

def set_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def make_infinite(dataloader):
    while True:
        for x in dataloader:
            yield x

def meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion):
    for meta_iteration in tqdm(range(args.start_meta_iteration, args.meta_iterations)):
        # Update learning rate
        print('Meta Iteration: {}'.format(meta_iteration))
        meta_lr = args.meta_lr * (1. - meta_iteration / float(args.meta_iterations))
        set_learning_rate(meta_optimizer, meta_lr)

        # Clone model
        net = meta_model.clone()
        optimizer = get_optimizer(net, args)

        # Sample base task from Meta-Train
        train_dataset = meta_train.get_random_task(args.train_shots)
        infinite_train_loader = make_infinite(
            DataLoader(
                train_dataset, args.batch_size, shuffle=True,
                num_workers=2, pin_memory=True))

        # Update fast net
        train_iter(net, infinite_train_loader, criterion, optimizer, args)

        # Update slow net
        meta_model.point_grad_to(net)
        meta_optimizer.step()

        # Meta-Evaluation
        if meta_iteration % args.validate_every == 0:
            for (meta_dataset, mode) in [(meta_test, "val")]:
                train, test = meta_dataset.get_random_task_split(
                    train_K=args.shots, test_K=5)
                infinite_train_loader = make_infinite(
                    DataLoader(
                        train, args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=True))
                infinite_test_loader = make_infinite(
                    DataLoader(
                        test, args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=True))

                # Base-train
                net = meta_model.clone()
                optimizer = get_optimizer(net, args)
                train_iter(
                    net, infinite_train_loader, criterion, optimizer, args)

                # Base-test: compute meta-loss, which is base-validation error
                meta_acc, meta_f1 = evaluate(net, infinite_test_loader, args)
                print(f"\n{mode}: f1-accuracy: {meta_f1:.3f}, acc: {meta_acc:.3f}")

        if meta_iteration % args.check_every == 0:
            torch.save(meta_model.state_dict(), "cur.pt")

class args:
    # Training
    epochs = 30
    batch_size = 32
    train_shots = 10
    shots = 5
    meta_iterations = 1000
    start_meta_iteration = 0
    iterations = 5
    test_iterations = 50
    meta_lr = 0.1
    validate_every = 50
    check_every = 100
    lr = 3e-4
    weight_decay=1e-5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Transform
    size = 400
    crop_size = 352
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]


# Set up train loader and test loader
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.size, args.size)),
    transforms.CenterCrop((args.crop_size, args.crop_size)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=args.mean, std=args.std)
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.size, args.size)),
    transforms.CenterCrop((args.crop_size, args.crop_size)),
    transforms.Normalize(mean=args.mean, std=args.std)
])

# Set up model
meta_model = TrainModel().to(args.device)
criterion = nn.CrossEntropyLoss()
meta_optimizer = torch.optim.SGD(
    meta_model.parameters(), lr=args.meta_lr)

meta_train = MetaPrinterFolder(
    train_no_defect, train_yes_defect, train_transform, val_transform)
meta_test = MetaPrinterFolder(
    val_no_defect, val_yes_defect, train_transform, val_transform)

# Start training
meta_train_reptile(args, meta_model, meta_train, meta_test, meta_optimizer, criterion)


class ListDataset(Dataset):
    def __init__(self, yes_defect, no_defect, transform=None):
        self.img_list = yes_defect + no_defect
        self.label = [1] * len(yes_defect) + [0] * len(no_defect)
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx])
        label = self.label[idx]
        img = self.transform(img)
        return img, label


def make_loader(yes_defect, no_defect, transform, batch_size,
                shuffle=True, num_workers=2, pin_memory=True,
                train=True):
    dataset = ListDataset(
        yes_defect=yes_defect, no_defect=no_defect, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=pin_memory)

    return loader

@torch.no_grad()
def post_train_evaluate(model, val_loader, args):
    model.eval()

    total_predict = []
    total_ground_truth = []
    for data, label in val_loader:
        data = data.to(args.device)
        label = label.to(args.device)

        output = model(data)
        prediction = output.argmax(dim=-1)

        total_predict.extend(prediction.cpu().tolist())
        total_ground_truth.extend(label.cpu().tolist())

    return accuracy_score(total_ground_truth, total_predict), \
           f1_score(total_ground_truth, total_predict, average="macro")


def post_train(model, train_loader, val_loader, criterion, optimizer, args):
    best_f1 = 0
    for epoch in range(args.epochs):
        train_progress_bar = tqdm(
            train_loader, desc=f"Epochs: {epoch + 1}/{args.epochs}")

        model.train()
        for data, label in train_progress_bar:
            data = data.to(args.device)
            label = label.to(args.device)

            # Send data into the model and compute the loss
            output = model(data)
            loss = criterion(output, label)

            # Update the model with back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Compute the accuracy ans save the best model
        eval_acc, eval_f1 = post_train_evaluate(model, val_loader, args)
        print(f"Validation accuracy: {eval_acc:.8f} f1-score: {eval_f1:.8f}")
        if eval_f1 > best_f1:
            best_f1 = eval_f1
            torch.save(model.state_dict(), "best.pt")

if __name__ == '__main__':

    train_loader = make_loader(
        yes_defect=train_yes_defect, no_defect=train_no_defect,
        batch_size=args.batch_size,
        transform=train_transform)
    val_loader = make_loader(
        yes_defect=val_yes_defect, no_defect=val_no_defect,
        batch_size=args.batch_size,
        transform=val_transform, train=False)

    meta_model.load_state_dict(torch.load("cur.pt"))
    train_model = meta_model.clone()
    criterion = nn.CrossEntropyLoss()
    train_optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)

    # Start training
    post_train(train_model, train_loader, val_loader, criterion, train_optimizer, args)
    torch.save(train_model.state_dict(), "last.pt")

  • 写回答

8条回答 默认 最新

  • Python-ZZY 2022-10-04 18:23
    关注

    把代码删掉。
    或者在这段代码的开头加上三个引号,末尾也加上三个引号,就可以把代码注释掉

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(7条)

报告相同问题?

问题事件

  • 系统已结题 10月15日
  • 已采纳回答 10月7日
  • 修改了问题 10月4日
  • 创建了问题 10月4日

悬赏问题

  • ¥20 c语言写的8051单片机存储器mt29的模块程序
  • ¥60 求直线方程 使平面上n个点在直线同侧并且距离总和最小
  • ¥50 java算法,给定试题的难度数量(简单,普通,困难),和试题类型数量(单选,多选,判断),以及题库中各种类型的题有多少道,求能否随机抽题。
  • ¥50 rk3588板端推理
  • ¥250 opencv怎么去掉 数字0中间的斜杠。
  • ¥15 这种情况的伯德图和奈奎斯特曲线怎么分析?
  • ¥250 paddleocr带斜线的0很容易识别成9
  • ¥15 电子档案元素采集(tiff及PDF扫描图片)
  • ¥15 flink-sql-connector-rabbitmq使用
  • ¥15 zynq7015,PCIE读写延时偏大