於菟601 2022-05-12 16:02 采纳率: 50%
浏览 289
已结题

训练模型时遇到以下问题:[WinError 10061] 由于目标计算机积极拒绝,无法连接

所有的数据集都在本地,不明白为什么会有连接问题

问题遇到的现象和发生背景

img

问题相关代码,请勿粘贴截图

from findplate.config import opt
import os
import torch as t
from findplate import models
from findplate.data.dataset import MyDataset
from torch.utils.data import DataLoader
from torchnet import meter
from findplate.utils.visualize import Visualizer
from tqdm import tqdm
from torchvision import transforms as T

def write_csv(results, file_name, col1_name, col2_name):
    """Write a two-column CSV file: a header row followed by every row of ``results``.

    Args:
        results: iterable of row sequences, e.g. ``[[0, 'name'], [1, 'other']]``.
        file_name: destination path; an existing file is overwritten.
        col1_name: header text for the first column.
        col2_name: header text for the second column.
    """
    import csv

    header = (col1_name, col2_name)
    # newline='' lets the csv module control line endings itself.
    with open(file_name, 'w', newline='') as out:
        csv_writer = csv.writer(out)
        csv_writer.writerow(header)
        for row in results:
            csv_writer.writerow(row)

def train(**kwargs):
    """Train ``opt.model`` on ``opt.train_data_root`` and validate after every epoch.

    ``kwargs`` are handed to ``opt._parse`` (presumably overriding config
    attributes, e.g. ``python main.py train --lr=0.01``).

    Side effects:
        * writes the label-index -> class-name mapping to ``opt.id_file``;
        * calls ``model.save()`` once per epoch;
        * reports loss/accuracy through a ``Visualizer`` on ``opt.vis_port``.

    NOTE(review): the Visualizer presumably talks to a visdom server over the
    network. If nothing is listening on ``opt.vis_port`` the run fails with
    "[WinError 10061] ... actively refused" even though all data is local.
    Start a server first, e.g. ``python -m visdom.server``.
    """
    opt._parse(kwargs)
    vis = Visualizer(opt.env, port=opt.vis_port)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    # step2: data (the train/val split is selected by MyDataset's `train` flag)
    train_data = MyDataset(opt.train_data_root, train=True)
    val_data = MyDataset(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data, opt.batch_size,
                                  shuffle=True, num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size,
                                shuffle=False, num_workers=opt.num_workers)

    # Persist the label-index -> class-name mapping so predictions can be
    # decoded later; indices follow the order of train_data.data_classes.
    data_id_to_class = [[label_idx, label_name]
                        for label_idx, label_name in enumerate(train_data.data_classes)]
    print(data_id_to_class)
    write_csv(data_id_to_class, opt.id_file, 'label_idx', 'label_name')

    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)

    # step4: meters
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(opt.classifier_num)
    previous_loss = 1e10

    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()

        for ii, (data, label) in tqdm(enumerate(train_dataloader),
                                      total=len(train_dataloader)):
            # `inputs` rather than `input` to avoid shadowing the builtin.
            inputs = data.to(opt.device)
            target = label.to(opt.device)

            optimizer.zero_grad()
            score = model(inputs)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            # Meters only need values, so detach from the autograd graph.
            loss_meter.add(loss.item())
            confusion_matrix.add(score.detach(), target.detach())

            if (ii + 1) % opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

                # Drop into the debugger whenever the debug flag file exists.
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        model.save()

        # validate and visualize
        val_cm, val_accuracy = val(model, val_dataloader)
        epoch_loss = loss_meter.value()[0]

        vis.plot('val_accuracy', val_accuracy)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
            epoch=epoch, loss=epoch_loss, val_cm=str(val_cm.value()),
            train_cm=str(confusion_matrix.value()), lr=lr))

        # Decay lr when this epoch's mean *training* loss exceeds the previous
        # epoch's. Editing param_groups in place (instead of building a new
        # optimizer) keeps per-parameter optimizer state such as momentum.
        if epoch_loss > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = epoch_loss

@t.no_grad()
def val(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    The model is put in eval mode for the pass and switched back to train
    mode afterwards, even if evaluation raises.

    Returns:
        (confusion_matrix, accuracy): the filled ``meter.ConfusionMeter`` and
        the overall accuracy in percent (0.0 when the loader yielded no samples).
    """
    model.eval()
    confusion_matrix = meter.ConfusionMeter(opt.classifier_num)
    try:
        for val_input, label in tqdm(dataloader):
            score = model(val_input.to(opt.device))
            # Flatten to (batch, classes) instead of squeeze(): squeeze() also
            # removes the batch dimension when the final batch holds a single
            # sample, so the meter would no longer get one row per target.
            confusion_matrix.add(score.detach().reshape(score.size(0), -1),
                                 label.type(t.LongTensor))
    finally:
        model.train()

    cm_value = confusion_matrix.value()
    total = cm_value.sum()
    if total == 0:
        return confusion_matrix, 0.0
    # Correct predictions lie on the diagonal of the confusion matrix.
    accuracy = 100. * cm_value.trace() / total
    return confusion_matrix, accuracy

def help():  # shadows the builtin on purpose: `help` is a CLI sub-command via fire
    """Print command-line usage followed by the source of the config class.

    Invoked as ``python file.py help``. Printing the config class source shows
    every overridable ``--arg`` together with its default value.
    """
    from inspect import getsource

    print("""
usage : python file.py <function> [--args=value]
<function> := train | test | help
example:
        python {0} train --env='env0701' --lr=0.01
        python {0} test --dataset='path/to/dataset/root/'
        python {0} help
available args:""".format(__file__))

    print(getsource(opt.__class__))

if __name__ == '__main__':
    # fire exposes each module-level function (train, val, help, ...) as a
    # sub-command, e.g. `python main.py train --lr=0.01`.
    import fire
    fire.Fire()

以下是配置信息:

class DefaultConfig(object):
    """Default hyper-parameters and paths; override them at runtime via ``_parse``."""

    env = 'default'  # visdom environment name
    vis_port = 8097  # visdom server port; a server (e.g. `python -m visdom.server`) must be reachable here
    model = 'SqueezeNet'  # model class name; must match a name exported by models/__init__.py
    classifier_num = 2  # number of output classes of the classifier
    gray = False  # whether images are read as grayscale

    train_data_root = './imgs/images/cnn_plate_train/'  # training-set directory
    test_data_root = './data/test/plate/'  # test-set directory
    load_model_path = None  # path of a pretrained model to load; None means don't load

    batch_size = 16  # batch size
    use_gpu = True  # use the GPU when CUDA is available
    num_workers = 0  # how many workers for loading data
    print_freq = 20  # print info every N batch

    debug_file = '/tmp/debug'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'
    id_file = './findplate/plate.csv'  # CSV mapping label index -> class name

    max_epoch = 100
    lr = 0.001  # initial learning rate
    lr_decay = 0.5  # when the epoch's training loss increases, lr = lr * lr_decay
    weight_decay = 0e-5  # L2 weight decay handed to the optimizer

    def _parse(self, kwargs):
        """Override attributes from ``kwargs`` and derive ``self.device``.

        Unknown keys emit a warning but are still set, so a mistyped CLI flag
        is visible without aborting the run. ``device`` is CUDA only when
        ``use_gpu`` is true *and* CUDA is available; otherwise it is CPU.
        """
        import warnings

        import torch

        for key, value in kwargs.items():
            if not hasattr(self, key):
                warnings.warn("opt has no attribute %s" % key)
            setattr(self, key, value)
        use_cuda = self.use_gpu and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')


opt = DefaultConfig()

  • 写回答

3条回答 默认 最新

  • 不会长胖的斜杠 后端领域新星创作者 2022-05-12 16:28
    关注

    这种一般都是在本地训练,然后通过 visdom 在网页中给你做可视化的,我也踩过坑。你启动 visdom 服务(visdom.server)了吗?先在命令行运行 python -m visdom.server,再开始训练。
    https://blog.csdn.net/Dummy_/article/details/106873857

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 5月20日
  • 已采纳回答 5月12日
  • 创建了问题 5月12日

悬赏问题

  • ¥66 定制开发肯德基自动化网站下单软件
  • ¥20 vscode虚拟环境依赖包未安装
  • ¥15 odoo17关于owl开发js代码问题
  • ¥15 光纤中多普勒频移公式的推导
  • ¥15 怎么制作一个人脸识别门禁系统
  • ¥20 大华dss监控平台网络关闭登不进去
  • ¥15 请使用蚁群算法解决下列问题,并给出我完整的代码
  • ¥20 关于php录入完成后,批量更新数据库
  • ¥15 请教往复密封润滑问题
  • ¥15 cocos creator发布ios包