Model code
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: LR_Model.py
@time: 2022/09/05
@desc: LR model implemented in PyTorch
"""
import torch
import numpy as np
import torch.nn as nn
class LogisticRegression(nn.Module):
    def __init__(self, field_dims, emb_size):
        """
        :param field_dims: list of per-field cardinalities; their sum is the total number of features
        :param emb_size: embedding dimension (1 for plain LR)
        """
        super(LogisticRegression, self).__init__()
        # Embedding layer: one weight per (one-hot) feature value
        self.emb = nn.Embedding(sum(field_dims), emb_size)
        # Xavier initialization, suited to saturating activations such as Sigmoid and Tanh
        nn.init.xavier_uniform_(self.emb.weight.data)
        # Per-field offsets that map each field's local indices into the shared embedding table
        # (np.long was removed in newer NumPy; use np.int64)
        self.offset = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # Bias term, updated by gradient descent
        self.bias = nn.Parameter(torch.zeros((1,)))

    def forward(self, x):
        """
        Forward pass
        :param x: input data, (batch, seq_len)
        :return: predicted click probability, (batch, 1)
        """
        x = x + x.new_tensor(self.offset)
        # (batch, seq_len) => (batch, seq_len, 1) => (batch, 1)
        x = self.emb(x).sum(1) + self.bias
        x = torch.sigmoid(x)
        return x
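With emb_size = 1 the embedding table holds exactly one learnable weight per feature value, so summing the looked-up rows reproduces the linear term w·x of logistic regression over one-hot inputs. Because every field numbers its categories from 0, the cumulative offsets shift each field's local indices into its own slice of the shared table. A tiny worked example with made-up field sizes (illustrative only, not from the post):

import numpy as np
import torch

# Two fields with 2 and 3 categories: field 0 owns rows 0-1 of the
# embedding table, field 1 owns rows 2-4.
field_dims = [2, 3]
offset = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)  # [0, 2]

x = torch.tensor([[1, 2]])       # local indices: category 1 of field 0, category 2 of field 1
print(x + x.new_tensor(offset))  # tensor([[1, 4]]) -> global row ids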
Training and prediction code:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: liujie
@file: main.py
@time: 2022/09/05
@desc:
"""
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch import optim
from LR_Model import LogisticRegression
import matplotlib.pyplot as plt
from dataSet import My_DataSet
from torch.utils.data import DataLoader
from dataProcess import DataProcess
from sklearn.metrics import f1_score, recall_score, roc_auc_score
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criteo_file = "criteo-100k.txt"
nrows = 100000
sizes = [0.75, 0.25]
embedding_size = 1
batch_size = 1024
num_epochs = 200
learning_rate = 1e-5
weight_decay = 1e-6
def train_and_test(train_dataloader, test_dataloader, model):
    # Loss function
    criterion = nn.BCELoss()
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Record per-epoch losses and accuracies for plotting
    train_loss, test_loss, train_acc, test_acc = [], [], [], []
    for epoch in range(num_epochs):
        train_loss_sum = 0.0
        train_len = 0
        train_correct = 0
        # Show training progress (use a local name so the DataLoader is not re-wrapped every epoch)
        train_bar = tqdm.tqdm(train_dataloader)
        train_bar.set_description('[%s%04d/%04d]' % ('Epoch:', epoch + 1, num_epochs))
        # Training mode
        model.train()
        model.to(device)
        for i, data_ in enumerate(train_bar):
            x, y = data_[0].to(device), data_[1].to(device)
            # Zero the gradients at the start of each batch, otherwise they accumulate
            optimizer.zero_grad()
            # output size = (batch, 1)
            output = model(x)
            loss = criterion(output.squeeze(1), y)
            # Backpropagation
            loss.backward()
            # Update the parameters with the optimizer
            optimizer.step()
            # reduction="mean" by default, so multiply by the batch size
            train_loss_sum += loss.detach() * len(y)
            train_len += len(y)
            # output has shape (batch, 1), so torch.max(output, 1) would always
            # return index 0; threshold the predicted probability at 0.5 instead
            predicted = (output.squeeze(1) > 0.5).float()
            train_correct += (predicted == y).sum().item()
            F1 = f1_score(y.cpu(), predicted.cpu(), average="weighted")
            Recall = recall_score(y.cpu(), predicted.cpu(), average="micro")
            # Progress-bar log
            postfix = "train_loss: {:.5f},train_acc:{:.3f}%,F1: {:.3f}%,Recall:{:.3f}%".format(
                train_loss_sum / train_len, 100 * train_correct / train_len, 100 * F1, 100 * Recall)
            train_bar.set_postfix(log=postfix)
        train_loss.append((train_loss_sum / train_len).item())
        train_acc.append(round(train_correct / train_len, 4))

        # Evaluation
        test_bar = tqdm.tqdm(test_dataloader)
        test_bar.set_description('[%s%04d/%04d]' % ('Epoch:', epoch + 1, num_epochs))
        model.eval()
        with torch.no_grad():
            test_loss_sum = 0.0
            test_len = 0
            test_correct = 0
            for i, data_ in enumerate(test_bar):
                x, y = data_[0].to(device), data_[1].to(device)
                output = model(x)
                loss = criterion(output.squeeze(1), y)
                test_loss_sum += loss.detach() * len(y)
                test_len += len(y)
                predicted = (output.squeeze(1) > 0.5).float()
                test_correct += (predicted == y).sum().item()
                F1 = f1_score(y.cpu(), predicted.cpu(), average="weighted")
                Recall = recall_score(y.cpu(), predicted.cpu(), average="micro")
                # Progress-bar log
                postfix = "test_loss: {:.5f},test_acc:{:.3f}%,F1: {:.3f}%,Recall:{:.3f}%".format(
                    test_loss_sum / test_len, 100 * test_correct / test_len, 100 * F1, 100 * Recall)
                test_bar.set_postfix(log=postfix)
            test_loss.append((test_loss_sum / test_len).item())
            test_acc.append(round(test_correct / test_len, 4))
    return train_loss, test_loss, train_acc, test_acc
def main():
    """
    Main entry point
    :return:
    """
    dataProcess = DataProcess(criteo_file, nrows, sizes, device)
    field_dims, (x_train, y_train), (x_test, y_test) \
        = dataProcess.train_valid_test_split(sizes)
    # Build the datasets and dataloaders
    trainDataset = My_DataSet(x_train, y_train)
    train_dataloader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True)
    testDataset = My_DataSet(x_test, y_test)
    test_dataloader = DataLoader(testDataset, batch_size=batch_size)
    # Instantiate the model
    model = LogisticRegression(field_dims, embedding_size)
    # Train and test
    train_loss, test_loss, train_acc, test_acc = train_and_test(train_dataloader, test_dataloader, model)
    # Plot the loss curves
    epochs = np.arange(num_epochs)
    plt.plot(epochs, train_loss, 'b-', label='Training loss')
    plt.plot(epochs, test_loss, 'r--', label='Validation loss')
    plt.title('Training And Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    # Plot the accuracy curves
    epochs = np.arange(num_epochs)
    plt.plot(epochs, train_acc, 'b-', label='Training acc')
    plt.plot(epochs, test_acc, 'r--', label='Validation acc')
    plt.title('Training And Validation Acc')
    plt.xlabel('Epochs')
    plt.ylabel('Acc')
    plt.legend()
    plt.show()

if __name__ == '__main__':
    main()
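The script imports two helper modules, My_DataSet from dataSet and DataProcess from dataProcess, whose source is not listed in this post. The sketch below is a hypothetical reconstruction that matches the call signatures used in main(); it assumes the usual Criteo TSV layout (a label column followed by 13 numeric and 26 categorical columns) and simply label-encodes every column, which is a simplification (numeric fields are normally bucketized first):

# Hypothetical reconstruction of dataSet.py / dataProcess.py, not the original code.
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class My_DataSet(Dataset):
    """Wraps encoded feature indices x and float labels y as tensors."""
    def __init__(self, x, y):
        self.x = torch.as_tensor(x, dtype=torch.long)
        self.y = torch.as_tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

class DataProcess:
    def __init__(self, file, nrows, sizes, device):
        self.file, self.nrows, self.sizes, self.device = file, nrows, sizes, device

    def train_valid_test_split(self, sizes):
        # Column 0 of the raw TSV is the click label
        df = pd.read_csv(self.file, sep='\t', header=None, nrows=self.nrows)
        y = df[0].values.astype(np.float32)
        feats = df.iloc[:, 1:].fillna('0')
        # Label-encode every column so each field's categories start at 0
        x = np.stack([pd.factorize(feats[c])[0] for c in feats.columns], axis=1)
        # Per-field cardinalities, consumed by the model's offset computation
        field_dims = (x.max(axis=0) + 1).tolist()
        n_train = int(len(x) * sizes[0])
        return field_dims, (x[:n_train], y[:n_train]), (x[n_train:], y[n_train:])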
Results (this log is from a run with num_epochs = 20):
D:\softwares\anaconda3\envs\tfpt368\python.exe D:/PycharmProjects/sxlj/Recommendation/CF/LR/main.py
[Epoch:0001/0020]: 100%|██████████| 37/37 [00:00<00:00, 40.72it/s, log={'train_loss: 0.67913,train_acc:77.339%,F1: 63.968%,Recall:74.764%'}]
[Epoch:0001/0020]: 100%|██████████| 13/13 [00:00<00:00, 53.43it/s, log={'test_loss: 0.66259,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0002/0020]: 100%|██████████| 37/37 [00:00<00:00, 44.81it/s, log={'train_loss: 0.64821,train_acc:77.339%,F1: 68.450%,Recall:78.066%'}]
[Epoch:0002/0020]: 100%|██████████| 13/13 [00:00<00:00, 55.70it/s, log={'test_loss: 0.63518,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0003/0020]: 100%|██████████| 37/37 [00:00<00:00, 47.80it/s, log={'train_loss: 0.62221,train_acc:77.339%,F1: 65.347%,Recall:75.786%'}]
[Epoch:0003/0020]: 100%|██████████| 13/13 [00:00<00:00, 55.48it/s, log={'test_loss: 0.61244,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0004/0020]: 100%|██████████| 37/37 [00:00<00:00, 48.78it/s, log={'train_loss: 0.60060,train_acc:77.339%,F1: 69.637%,Recall:78.931%'}]
[Epoch:0004/0020]: 100%|██████████| 13/13 [00:00<00:00, 54.09it/s, log={'test_loss: 0.59368,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0005/0020]: 100%|██████████| 37/37 [00:00<00:00, 49.18it/s, log={'train_loss: 0.58267,train_acc:77.339%,F1: 69.529%,Recall:78.852%'}]
[Epoch:0005/0020]: 100%|██████████| 13/13 [00:00<00:00, 60.35it/s, log={'test_loss: 0.57828,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0006/0020]: 100%|██████████| 37/37 [00:00<00:00, 42.04it/s, log={'train_loss: 0.56785,train_acc:77.339%,F1: 65.773%,Recall:76.101%'}]
[Epoch:0006/0020]: 100%|██████████| 13/13 [00:00<00:00, 56.91it/s, log={'test_loss: 0.56569,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0007/0020]: 100%|██████████| 37/37 [00:01<00:00, 36.04it/s, log={'train_loss: 0.55563,train_acc:77.339%,F1: 65.666%,Recall:76.022%'}]
[Epoch:0007/0020]: 100%|██████████| 13/13 [00:00<00:00, 54.54it/s, log={'test_loss: 0.55540,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0008/0020]: 100%|██████████| 37/37 [00:00<00:00, 38.80it/s, log={'train_loss: 0.54552,train_acc:77.339%,F1: 66.092%,Recall:76.336%'}]
[Epoch:0008/0020]: 100%|██████████| 13/13 [00:00<00:00, 52.56it/s, log={'test_loss: 0.54695,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0009/0020]: 100%|██████████| 37/37 [00:01<00:00, 33.45it/s, log={'train_loss: 0.53709,train_acc:77.339%,F1: 67.698%,Recall:77.516%'}]
[Epoch:0009/0020]: 100%|██████████| 13/13 [00:00<00:00, 49.37it/s, log={'test_loss: 0.54005,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0010/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.34it/s, log={'train_loss: 0.53003,train_acc:77.339%,F1: 66.413%,Recall:76.572%'}]
[Epoch:0010/0020]: 100%|██████████| 13/13 [00:00<00:00, 47.92it/s, log={'test_loss: 0.53433,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0011/0020]: 100%|██████████| 37/37 [00:00<00:00, 37.54it/s, log={'train_loss: 0.52403,train_acc:77.339%,F1: 66.840%,Recall:76.887%'}]
[Epoch:0011/0020]: 100%|██████████| 13/13 [00:00<00:00, 49.94it/s, log={'test_loss: 0.52961,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0012/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.26it/s, log={'train_loss: 0.51891,train_acc:77.339%,F1: 66.840%,Recall:76.887%'}]
[Epoch:0012/0020]: 100%|██████████| 13/13 [00:00<00:00, 47.93it/s, log={'test_loss: 0.52559,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0013/0020]: 100%|██████████| 37/37 [00:01<00:00, 35.10it/s, log={'train_loss: 0.51445,train_acc:77.339%,F1: 66.733%,Recall:76.808%'}]
[Epoch:0013/0020]: 100%|██████████| 13/13 [00:00<00:00, 50.14it/s, log={'test_loss: 0.52217,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0014/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.56it/s, log={'train_loss: 0.51053,train_acc:77.339%,F1: 66.626%,Recall:76.730%'}]
[Epoch:0014/0020]: 100%|██████████| 13/13 [00:00<00:00, 49.19it/s, log={'test_loss: 0.51921,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0015/0020]: 100%|██████████| 37/37 [00:01<00:00, 34.69it/s, log={'train_loss: 0.50701,train_acc:77.339%,F1: 66.840%,Recall:76.887%'}]
[Epoch:0015/0020]: 100%|██████████| 13/13 [00:00<00:00, 50.13it/s, log={'test_loss: 0.51664,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0016/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.56it/s, log={'train_loss: 0.50383,train_acc:77.339%,F1: 68.342%,Recall:77.987%'}]
[Epoch:0016/0020]: 100%|██████████| 13/13 [00:00<00:00, 50.52it/s, log={'test_loss: 0.51434,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0017/0020]: 100%|██████████| 37/37 [00:01<00:00, 35.54it/s, log={'train_loss: 0.50090,train_acc:77.339%,F1: 69.313%,Recall:78.695%'}]
[Epoch:0017/0020]: 100%|██████████| 13/13 [00:00<00:00, 48.46it/s, log={'test_loss: 0.51225,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0018/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.15it/s, log={'train_loss: 0.49816,train_acc:77.339%,F1: 67.698%,Recall:77.516%'}]
[Epoch:0018/0020]: 100%|██████████| 13/13 [00:00<00:00, 48.65it/s, log={'test_loss: 0.51036,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0019/0020]: 100%|██████████| 37/37 [00:01<00:00, 32.66it/s, log={'train_loss: 0.49559,train_acc:77.339%,F1: 68.020%,Recall:77.752%'}]
[Epoch:0019/0020]: 100%|██████████| 13/13 [00:00<00:00, 49.38it/s, log={'test_loss: 0.50860,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
[Epoch:0020/0020]: 100%|██████████| 37/37 [00:00<00:00, 39.33it/s, log={'train_loss: 0.49315,train_acc:77.339%,F1: 67.590%,Recall:77.437%'}]
[Epoch:0020/0020]: 100%|██████████| 13/13 [00:00<00:00, 50.32it/s, log={'test_loss: 0.50695,test_acc:77.332%,F1: 68.127%,Recall:77.830%'}]
Process finished with exit code 0
Adjusting the batch size, learning rate, and so on had no effect: accuracy stayed pinned at about 77.3%. The culprit is not the hyperparameters but the original prediction step, _, predicted = torch.max(output, 1). Since output has shape (batch, 1), the argmax over dim 1 is always index 0, so every sample is classified as the majority (non-click) class and the reported accuracy is simply the class prior. Thresholding the sigmoid output at 0.5, as in the code above, yields meaningful accuracy, F1, and recall.
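Incidentally, main.py imports roc_auc_score but never uses it. On data this skewed (roughly 77% negatives), AUC is the standard CTR metric and far more informative than accuracy. A minimal sketch of computing it over the full test set, assuming the model and dataloader defined above:

# Hypothetical evaluation helper, not part of the original post.
import torch
from sklearn.metrics import roc_auc_score

def evaluate_auc(model, dataloader, device):
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            # Collect raw click probabilities rather than thresholded classes
            scores.append(model(x.to(device)).squeeze(1).cpu())
            labels.append(y)
    return roc_auc_score(torch.cat(labels).numpy(), torch.cat(scores).numpy())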