import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
class MLP(nn.Module):
    """Three-layer fully connected classifier.

    Layer sizes default to 784 -> 200 -> 200 -> 10, matching the original
    hard-coded architecture; they can be overridden via keyword arguments.

    The network returns raw, unnormalized class scores (logits). No activation
    is applied after the final ``nn.Linear``: ``nn.CrossEntropyLoss`` applies
    ``log_softmax`` internally and expects logits, whereas a LeakyReLU on the
    output (default negative_slope=0.01) would shrink every negative score by
    a factor of 100 and distort both the loss and its gradients.
    """

    def __init__(self, in_features=784, hidden_features=200, num_classes=10):
        super().__init__()
        # nn.Sequential chains the layers so forward() is a single call.
        # The Linear layers keep indices 0, 2 and 4 as in the original, so the
        # state_dict keys (model.0.*, model.2.*, model.4.*) are unchanged and
        # existing checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_features, hidden_features),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Map a batch of flattened inputs to class logits.

        Args:
            x: float tensor whose last dimension is ``in_features``
               (the caller flattens its inputs, e.g. ``x.view(-1, 28 * 28)``).

        Returns:
            Tensor of shape ``(*x.shape[:-1], num_classes)`` holding raw logits.
        """
        return self.model(x)
# Run on the first GPU when one is available; otherwise fall back to CPU so the
# script still works on machines without CUDA (a hard-coded 'cuda:0' device
# makes the .to(device) calls below fail there).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the network and move its parameters to `device`; calling net(x)
# dispatches to MLP.forward.
net = MLP().to(device)

# nn.Module.parameters() collects every weight/bias tensor of the network,
# replacing a hand-maintained list such as [w1, b1, w2, b2, ...].
# NOTE(review): learning_rate is not defined in the visible code; presumably it
# is set earlier in the full script. Confirm.
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

# Evaluation pass: no gradients are needed, so disable autograd tracking.
# Inside torch.no_grad() the resulting logits carry no autograd history, which
# is what the legacy `logits.data` idiom was used to obtain.
# NOTE(review): test_loader is not defined in the visible code; presumably it
# is a DataLoader built earlier in the full script. Confirm.
with torch.no_grad():
    for data, target in test_loader:
        # Reshape the batch into rows of 28*28 = 784 features to match the
        # first Linear layer of MLP.
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)
        logits = net(data)
`logits.data` 和 `logits` 有什么区别？这个 `.data` 有什么用？我用 `torch.all(torch.eq(logits, logits.data))` 比较了两者，
结果是 `True`，数值上看起来没有任何区别。