我用torchvision提供的resnet152作为图像特征提取器,结合LSTM搭建了一个LRCN视频分类模型,并在UCF101数据集上对模型进行训练。
torchvision预训练模型要求的输入范围是[0, 1],而我需要把输入范围重新调整为[0, 255]。经过几个epoch的训练,训练集的准确率达到了80%以上,但测试集的准确率只有百分之几。请问如何提高测试集的准确率?
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torchvision.models import resnet152, resnext50_32x4d
import pdb
##############################
# Encoder
##############################
class Encoder(nn.Module):
    """Per-image feature encoder built on a pretrained ResNeXt-50 (32x4d).

    The trunk consists of every child module of the torchvision model except
    its last one (the classification layer). The trunk output is flattened per
    sample and projected to ``latent_dim`` features by a Linear layer followed
    by BatchNorm1d.

    Args:
        latent_dim: Number of features produced for each input image.
    """

    def __init__(self, latent_dim):
        super().__init__()
        backbone = resnext50_32x4d(pretrained=True)

        # Keep everything but the final child so the trunk emits pooled
        # features instead of class scores.
        trunk_layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # The projection input width is whatever the dropped classifier consumed.
        projection = nn.Linear(backbone.fc.in_features, latent_dim)
        normalization = nn.BatchNorm1d(latent_dim, momentum=0.01)
        self.final = nn.Sequential(projection, normalization)

    def forward(self, x):
        """Encode a batch of images into ``latent_dim`` features per sample."""
        features = self.feature_extractor(x)
        # Collapse all non-batch dimensions into one feature axis.
        flat = features.view(features.size(0), -1)
        return self.final(flat)
##############################
# LSTM
##############################
class LSTM(nn.Module):
    """Recurrent classification head: LSTM over a feature sequence, then an MLP.

    The LSTM consumes batch-first sequences of ``latent_dim`` features. Only the
    output at the last time step is passed to the classifier, which ends in a
    Softmax, so ``forward`` returns class probabilities rather than raw logits.

    NOTE(review): losses that expect unnormalized logits (for example
    ``nn.CrossEntropyLoss``) should not be fed this output directly -- confirm
    against the training script.

    The state returned by the LSTM is stored on the module and fed back in on
    the next ``forward`` call. Call ``reset_hidden_state()`` before processing
    an unrelated batch; otherwise state leaks between batches and the stored
    state must match the next batch's size.

    Args:
        latent_dim: Size of each input feature vector.
        num_layers: Number of stacked LSTM layers.
        hidden_dim: LSTM hidden size; also the width of the classifier's
            hidden layer.
        bidirectional: Whether the LSTM runs in both directions.
        num_classes: Number of output classes. Defaults to 101, the value that
            used to be hard-coded.
    """

    def __init__(self, latent_dim, num_layers, hidden_dim, bidirectional, num_classes=101):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(latent_dim, hidden_dim, num_layers, batch_first=True, bidirectional=bidirectional)
        # A bidirectional LSTM concatenates both directions, doubling the width.
        lstm_out_dim = 2 * hidden_dim if bidirectional else hidden_dim
        self.final = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes),
            nn.Softmax(dim=-1),
        )
        self.hidden_state = None

    def reset_hidden_state(self):
        """Forget the stored LSTM state so the next forward starts from zeros."""
        self.hidden_state = None

    def forward(self, x):
        """Classify a batch of feature sequences.

        Args:
            x: Batch-first tensor of feature sequences, ``latent_dim`` wide.

        Returns:
            Class probabilities, one row of ``num_classes`` values per sequence.
        """
        x, self.hidden_state = self.lstm(x, self.hidden_state)
        # Summarize each sequence by the LSTM output at its last time step.
        x = x[:, -1]
        return self.final(x)


##############################
# ConvLSTM
##############################
class OrigConvLSTM(nn.Module):
    """Video classifier: per-frame CNN encoder followed by an LSTM head.

    Args:
        num_classes: Number of output classes. This value is now forwarded to
            the LSTM head; previously it was accepted but ignored and the head
            always produced 101 outputs.
        latent_dim: Size of the per-frame feature vector from the encoder.
        lstm_layers: Number of stacked LSTM layers.
        hidden_dim: LSTM hidden size.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(
        self, num_classes, latent_dim=512, lstm_layers=1, hidden_dim=1024, bidirectional=True
    ):
        super(OrigConvLSTM, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.lstm = LSTM(latent_dim, lstm_layers, hidden_dim, bidirectional, num_classes=num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: 5-D tensor unpacked as ``(batch, seq_length, c, h, w)``.

        Returns:
            Tuple ``(probabilities, predictions)``: the head's per-class
            probabilities and the argmax class index for each clip.
        """
        batch_size, seq_length, c, h, w = x.shape
        # Fold time into the batch axis so the 2-D encoder sees single frames.
        # reshape (not view) also accepts non-contiguous clip tensors.
        frames = x.reshape(batch_size * seq_length, c, h, w)
        frame_features = self.encoder(frames)
        # Restore the time axis for the recurrent head.
        sequences = frame_features.view(batch_size, seq_length, -1)
        probabilities = self.lstm(sequences)
        return probabilities, probabilities.argmax(1)