试验LSTM时遇到一个问题,下面的代码报错,TypeError: view() takes from 0 to 2 positional arguments but 3 were given。查了很多网页,都是相同的代码。是我的版本问题吗?

试验LSTM时遇到一个问题,下面的代码报错,TypeError: view() takes from 0 to 2 positional arguments but 3 were given。查了很多网页,都是相同的代码。是我的版本问题吗?

关注从报错信息来看,并不是 PyTorch 自身的 view() 方法不支持多维形状参数——PyTorch 的 tensor.view(x, y, z, …) 一直都是允许的。通常出现 view() takes from 0 to 2 positional arguments but 3 were given 这样的报错,多半是因为:
input_seq 不是一个 PyTorch Tensor,而是一个 Python 列表或 NumPy 数组; view 的函数或变量),覆盖了原本的 tensor.view()。下面给出一些排查思路和解决方案:
input_seq 的类型在调用 input_seq.view(...) 之前,先打印或断言一下类型:
print(type(input_seq))
# 或者
assert isinstance(input_seq, torch.Tensor)
input_seq 的类型是 <class 'numpy.ndarray'>,那就说明你拿的是一个 NumPy 数组。在 NumPy 中,.view() 的含义是“以不同的 dtype/内存视图”来查看数组,而不是用来改 shape。因此会报参数不匹配的错误。 <class 'list'>,那就更不支持 .view() 了。 解决方法:将 input_seq 转为 PyTorch Tensor,再调用 .view()。例如:
import torch
if not isinstance(input_seq, torch.Tensor):
input_seq = torch.tensor(input_seq, dtype=torch.float32)
lstm_out, self.hidden_cell = self.lstm(
input_seq.view(len(input_seq), 1, -1),
self.hidden_cell
)
有时,如果你在某个作用域里写了类似:
view = some_function # 或者
from some_module import view
这样就会导致 input_seq.view(...) 实际上是调用别的 view(),从而产生参数错误。可以用 print(input_seq.view) 看看它到底指向什么对象,或者留意自己的命名空间中是否做了类似导入/变量赋值。
PyTorch 很早就支持 tensor.view(dim1, dim2, ...) 这种写法了,并没有出现 “只能支持一两个维度” 的旧版本。即使是 PyTorch 0.4 时代,也已经是这种 view 的用法。所以“版本过老”导致此问题的可能性并不大(除非你用的是非常非常老的 PyTorch 0.3 甚至之前,但那时候的报错也一般不会是 view() takes from 0 to 2 positional arguments... 这种形式)。
input_seq 并不是一个 PyTorch Tensor,而是 NumPy 数组或 Python 列表,导致 .view(...) 语义不对。 .view() 之前先 torch.tensor(...) 转为 Tensor,或者如果是 NumPy 数据可以用 .reshape(...)(NumPy 的方法)来调整维度。 view 是否被其他同名符号遮蔽,也能避免一些低级的命名冲突。