问题遇到的现象和发生背景
运行时提示第20行(class Net(nn.Module): 处)出现语法错误(SyntaxError),求问该如何改写
问题相关代码,请勿粘贴截图
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm
from matplotlib import cm
class Net(nn.Module):
    """Fully connected network mapping 3 input features to 1 output.

    The network is an input layer (3 -> NN), followed by NL hidden layers
    (NN -> NN each), followed by an output layer (NN -> 1). The activation
    ``act`` is applied after the input layer and after every hidden layer,
    but not after the output layer.

    Args:
        NL: the number of hidden layers.
        NN: the number of units in each layer.
    """

    def __init__(self, NL, NN):
        super().__init__()
        self.input_layer = nn.Linear(3, NN)
        # ModuleList (rather than a plain list) registers the layers'
        # parameters with this module.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(NN, NN) for _ in range(NL)]
        )
        self.output_layer = nn.Linear(NN, 1)

    def forward(self, x):
        """Run the network on ``x`` and return the output-layer result.

        ``x`` must have 3 features in its last dimension, as required by
        the input layer; the result has 1 feature in its last dimension.
        """
        o = self.act(self.input_layer(x))
        for layer in self.hidden_layers:
            o = self.act(layer(o))
        # No activation on the output layer.
        return self.output_layer(o)

    def act(self, x):
        """Return ``x * sigmoid(x)`` element-wise."""
        return x * torch.sigmoid(x)
运行结果及报错内容
runfile('C:/Users/Tian Yi/Desktop/HEAT2.py', wdir='C:/Users/Tian Yi/Desktop')
File "C:\Users\Tian Yi\Desktop\HEAT2.py", line 20
class Net(nn.Module):
^
SyntaxError: import * only allowed at module level