科研狗没有house 2021-07-07 09:08 采纳率: 0%
浏览 761
已结题

PyTorch模型重新载入后准确率很低,基本上无法正确预测

pytorch训练模型
训练集和验证集的准确率都接近100%,然后用torch.save保存模型;之后在单独的测试脚本中重新读取模型进行预测,准确率基本不超过50%,有时只有0%或1%。用训练集和验证集测试也是同样的结果,准确率都很低。但如果把预测代码放到train文件的末尾(在同一进程中运行),又能正常预测。
训练代码

import scipy.io as sio
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
# Pin every random-number source this training run touches so repeated runs
# start from the same initial weights and data order.
# NOTE(review): this does not pin Python's per-process string-hash
# randomization (PYTHONHASHSEED). If `datagenerator` derives label indices
# from a set of strings, whose iteration order depends on that hash, the
# label->index mapping could differ between this process and a separate
# test process — worth confirming.
SEED = 1
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn
# Trade cuDNN autotuning speed for reproducible convolution algorithm choice.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
from tqdm import tqdm

from model.conv_ince_resnet import *
from dataloader import *

batch_size = 64
num_epochs = 20

train_data, train_label, test_data, test_label, c = datagenerator(
    'dataset/train.h5', 'dataset/trainlabel.txt',
    'dataset/test.h5', 'dataset/testlabel.txt')


def _as_tensors(data, labels):
    """Convert one split into tensors shaped for the model.

    Returns ``(count, samples, targets)``: ``samples`` is float32 viewed as
    ``(count, 1, -1)`` (one channel, flattened features) and ``targets`` is
    int64 viewed as ``(count, 1)``. ``data`` must be a NumPy array because it
    goes through ``torch.from_numpy``.
    """
    count = len(data)
    samples = torch.from_numpy(data).type(torch.FloatTensor).view(count, 1, -1)
    targets = torch.from_numpy(np.array(labels)).type(torch.LongTensor).view(count, 1)
    return count, samples, targets


num_train_instances, train_data, train_label = _as_tensors(train_data, train_label)
train_dataset = TensorDataset(train_data, train_label)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

num_test_instances, test_data, test_label = _as_tensors(test_data, test_label)
test_dataset = TensorDataset(test_data, test_label)
# The test loader uses a fixed batch size of 64, independent of `batch_size`.
test_data_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)


# Build the network. `MSResNet` comes from the project module
# model.conv_ince_resnet (star-imported). How input_channel=13545 relates to
# the (N, 1, -1) sample layout is not visible here — TODO confirm.
msresnet = MSResNet(input_channel=13545, layers=[1, 1, 1, 1], num_classes=c)
msresnet = msresnet.cuda()

# Compile a TorchScript copy only to print its generated code; training uses
# the eager `msresnet`, not `trace_module`.
trace_module = torch.jit.script(msresnet)
print(trace_module.code)

criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.Adam(msresnet.parameters(), lr=0.001)
# NOTE(review): if the scheduler is stepped once per epoch, none of these
# milestones is reached within num_epochs == 20, so the learning rate stays
# at 0.001 for the whole run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[300, 500, 600, 700, 800, 900], gamma=0.1)

# Per-epoch metric histories, one row per epoch.
train_loss = np.zeros([num_epochs, 1])
test_loss = np.zeros([num_epochs, 1])
train_acc = np.zeros([num_epochs, 1])
test_acc = np.zeros([num_epochs, 1])

for epoch in range(num_epochs):
    print('Epoch:', epoch)

    # ---- training pass ----
    msresnet.train()
    loss_x = 0
    for samples, labels in tqdm(train_data_loader):
        samplesV = samples.cuda()
        # squeeze(1), not squeeze(): labels arrive as (B, 1), and a bare
        # squeeze() turns a final batch of size 1 into a 0-d tensor, which
        # CrossEntropyLoss rejects as a target for (1, C) logits.
        labelsV = labels.squeeze(1).cuda()

        predict_label = msresnet(samplesV)
        optimizer.zero_grad()
        # Presumably element 0 of the model output holds the class logits —
        # confirm against MSResNet.forward.
        loss = criterion(predict_label[0], labelsV)
        # CrossEntropyLoss() averages over the batch; weight by batch size so
        # dividing by the instance count below gives a per-sample mean loss.
        loss_x += loss.item() * samplesV.size(0)

        loss.backward()
        optimizer.step()
    # Step the LR schedule once per epoch, after the optimizer steps. This is
    # the order PyTorch expects; passing `epoch` to step() is deprecated.
    scheduler.step()

    train_loss[epoch] = loss_x / num_train_instances
    print('trainloss:', train_loss[epoch])

    # ---- accuracy on the training split (no autograd) ----
    msresnet.eval()
    correct_train = 0
    with torch.no_grad():
        for samples, labels in train_data_loader:
            labelsV = labels.squeeze(1).cuda()
            predict_label = msresnet(samples.cuda())
            prediction = predict_label[0].max(1)[1]
            correct_train += prediction.eq(labelsV).sum().item()

    train_acc_pct = 100 * float(correct_train) / num_train_instances
    print("Training accuracy:", train_acc_pct)
    train_acc[epoch] = train_acc_pct
    trainacc = str(train_acc_pct)[0:6]

    # ---- loss/accuracy on the test split; forward pass inside no_grad ----
    loss_x = 0
    correct_test = 0
    with torch.no_grad():
        for samples, labels in test_data_loader:
            samplesV = samples.cuda()
            labelsV = labels.squeeze(1).cuda()
            predict_label = msresnet(samplesV)
            prediction = predict_label[0].max(1)[1]
            correct_test += prediction.eq(labelsV).sum().item()

            loss = criterion(predict_label[0], labelsV)
            loss_x += loss.item() * samplesV.size(0)

    test_acc_pct = 100 * float(correct_test) / num_test_instances
    print("Test accuracy:", test_acc_pct)
    test_loss[epoch] = loss_x / num_test_instances
    test_acc[epoch] = test_acc_pct
    print('valloss:', test_loss[epoch])
    testacc = str(test_acc_pct)[0:6]

    # Keep the best-on-test checkpoint. Epoch 0 counts as the first "best",
    # so a best file always exists and matches the temp_* counts that are
    # later used in the result filenames.
    if epoch == 0 or correct_test > temp_test:
        torch.save(msresnet, 'weights/changingResnet/ChaningSpeed_Train' + trainacc + 'Test' + testacc + '.pkl')
        temp_test = correct_test
        temp_train = correct_train
    # Always overwrite the "latest" checkpoint. torch.save(module) pickles the
    # whole module object, so loading it requires the model class to be
    # importable under the same module path.
    torch.save(msresnet, 'weights/changingResnet/ChaningSpeed.pkl')
# Shared filename tail: the best train/test accuracy percentages (first six
# characters of their string form) recorded by the checkpoint logic.
_best_train_pct = str(100*float(temp_train)/num_train_instances)[0:6]
_best_test_pct = str(100*float(temp_test)/num_test_instances)[0:6]
_mat_suffix = 'ChangingSpeed_Train' + _best_train_pct + 'Test' + _best_test_pct + '.mat'

# One .mat file per metric history, each holding a single named variable.
for _prefix, _key, _values in (
        ('TrainLoss_', 'train_loss', train_loss),
        ('TestLoss_', 'test_loss', test_loss),
        ('TrainAccuracy_', 'train_acc', train_acc),
        ('TestAccuracy_', 'test_acc', test_acc),
):
    sio.savemat('result/changingResnet/' + _prefix + _mat_suffix, {_key: _values})
print(_best_test_pct)

# One figure per metric history, shown in turn via plt.show().
for _fig_name, _series in (('1', train_loss), ('2', test_loss), ('3', train_acc), ('4', test_acc)):
    plt.figure(_fig_name)
    plt.plot(_series)
    plt.show()


测试代码

import scipy.io as sio
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import time
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import time
from tqdm import tqdm
from dataloader import *
# NOTE(review): accuracy measured here is only meaningful if `datagenerator`
# yields the same label->index encoding it produced inside the training
# process. The training script seeds numpy/torch but not Python's
# per-process string-hash randomization (PYTHONHASHSEED). If `datagenerator`
# derives label indices from a set of strings, the mapping could be permuted
# in this process. That would fit the reported symptom: correct at the end of
# the training file, ~0-50% after reloading. Confirm by comparing the labels
# and `c` produced in both scripts.
train_data, train_label, test_data, test_label, c = datagenerator(
    'dataset/train.h5', 'dataset/trainlabel.txt',
    'dataset/test.h5', 'dataset/testlabel.txt')

batch_size = 100

num_test_instances = len(test_data)
# Same preprocessing as in training: float32 samples viewed as (N, 1, -1) and
# int64 labels viewed as (N, 1).
test_data = torch.from_numpy(test_data).type(torch.FloatTensor).view(num_test_instances, 1, -1)
test_label = torch.from_numpy(np.array(test_label)).type(torch.LongTensor).view(num_test_instances, 1)

test_dataset = TensorDataset(test_data, test_label)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Load the whole pickled module written by torch.save(msresnet, ...).
# SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): PyTorch 2.6+ defaults torch.load to weights_only=True, which
# rejects full-module pickles like this one. On such versions pass
# weights_only=False, or switch both scripts to state_dict checkpoints.
# Unpickling also needs the model class importable under the module path it
# was saved from.
_checkpoint_path = 'weights/changingResnet/ChaningSpeed.pkl'
msresnet = torch.load(_checkpoint_path)
a = msresnet.state_dict()  # unused below; presumably kept to inspect parameters

msresnet = msresnet.cuda().eval()

# Evaluate the reloaded model on the test split and report top-1 accuracy (%).
correct_test = 0
# Keep the forward pass inside no_grad so no autograd graph is built during
# evaluation.
with torch.no_grad():
    for samples, labels in test_data_loader:
        samplesV = samples.cuda()
        # squeeze(1) keeps a 1-D target even for a final batch of size 1.
        labelsV = labels.squeeze(1).cuda()
        predict_label = msresnet(samplesV)
        # Presumably element 0 of the model output holds the class logits —
        # confirm against the model's forward().
        prediction = predict_label[0].max(1)[1]
        correct_test += prediction.eq(labelsV).sum().item()

print("Test accuracy:", (100 * float(correct_test) / num_test_instances))

  • 写回答

1条回答 默认 最新

  • 影醉阏轩窗 2021-07-09 09:39
    关注

    放那么大段代码,是让人帮你debug吗?这是很简单的问题,直接对比训练脚本中的模型参数和测试脚本中重新载入后的模型参数即可。

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 7月9日

悬赏问题

  • ¥15 需要手写数字信号处理Dsp三个简单题 不用太复杂
  • ¥15 数字信号处理考试111
  • ¥100 关于#audobe audition#的问题,如何解决?
  • ¥15 allegro17.2生成bom表是空白的
  • ¥15 请问一下怎么打通CAN通讯
  • ¥20 如何在 rocky9.4 部署 CDH6.3.2?
  • ¥35 navicat将excel中的数据导入mysql出错
  • ¥15 rt-thread线程切换的问题
  • ¥15 高通uboot 打印ubi init err 22
  • ¥15 R语言中lasso回归报错