盛季,夏开 2020-02-02 16:20 采纳率: 0%
浏览 1419

mxnet中'DataLoader' object is not callable是什么情况,我是按书这么写的

这是一个高位线性回归实验的代码,具体在动手学深度学习mxnet版的p66-p68页
%matplotlib inline
import d2lzh as d2l
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn

n_train, n_test, num_inputs = 20, 100, 200
true_w, true_b = nd.ones((num_inputs, 1)) * 0.01, 0.05

features = nd.random.normal(shape=(n_train + n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

def init_params():
w = nd.random.normal(scale=1, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))
w.attach_grad()
b.attach_grad()
return [w,b]

def l2_penalty(w):
return (w**2).sun() / 2

batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
train_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True, num_workers=0)

def fit_and_plot(lambd):
w, b = init_params()
train_ls, test_ls = [], []
for _ in range(num_epochs):
for X, y in train_iter():
with autograd.record():
#添加了L2范数惩罚项
l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
l.backward()
d2l.sgd([w, b], lr, batch_size)
train_ls.append(loss(net(train_features, w, b), train_labels).mean().asscalar())
test_ls.append(loss(net(test_features, w, b), test_labels).mean().asscalar())
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', range(1, num_epochs + 1), test_ls, ['train', ' test'])
print('L2 norm of w:', w.norm().asscalar())

fit_and_plot(lambd=0)

这一些代码编译之后

出现是这个

TypeError Traceback (most recent call last)
in
----> 1 fit_and_plot(lambd=0)

in fit_and_plot(lambd)
7 train_ls, test_ls = [], []
8 for _ in range(num_epochs):
----> 9 for X, y in train_iter():
10 with autograd.record():
11 #添加了L2范数惩罚项

TypeError: 'DataLoader' object is not callable

想问问 各位大佬,谢谢了

  • 写回答

1条回答 默认 最新

  • dabocaiqq 2020-02-02 23:32
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题
  • ¥50 Kubernetes&Fission&Eleasticsearch
  • ¥15 報錯:Person is not mapped,如何解決?
  • ¥15 c++头文件不能识别CDialog