遇到报错:
line 57, in forward
input_data = input_data.permute(2, 1, 0)
RuntimeError: number of dims don't match in permute
报错部分代码:
def forward(self, input_data):
# input_data: Point Cloud having shape input_shape.
# output: PointNet features (Batch x emb_dims)
if self.input_shape == "bnc":
num_points = input_data.shape[1]
input_data = input_data.permute(0, 2, 1) #报错地方
else:
num_points = input_data.shape[2]
if input_data.shape[1] != 3:
raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
output = input_data
for idx, layer in enumerate(self.layers):
output = layer(output)
return output
查看了input_data的维数,得到torch.Size([1, 3]) 确实是三维的
读取数据部分代码
if args.user_data:
source_path = os.path.join(r'E:\bunny\data\bun000.ply') # The source point cloud is a rotated and offset defect
# source=s3.float()
source_data = o3d.io.read_point_cloud(source_path)
points1 = np.array(source_data.points)
idx1 = np.arange(points1.shape[0])
np.random.shuffle(idx1)
source = points1[idx1[:args.num_points]]
template_path = os.path.join(r'E:\bunny\data\bun045.ply') # Template point cloud is complete
# template=s5.float()
template_data = o3d.io.read_point_cloud(template_path)
points2 = np.array(template_data.points)
idx2 = np.arange(points2.shape[0])
np.random.shuffle(idx2)
template = points2[idx2[:args.num_points]]
testset = UserData(template=template, source=source, tpcc=None, igt=None)
source和template都是三维的(如下所示)
[ 0.0585 0.0808363 0.0858177]
[ 0.0465 0.0650293 0.0904855]
[-0.00725 0.126023 0.0352551]
if __name__ == '__main__':
# Test the code.
x = torch.rand((10,1024,3))
pn = PointNet(use_bn=True)
y = pn(x)
print("Network Architecture: ")
print(pn)
print("Input Shape of PointNet: ", x.shape, "\nOutput Shape of PointNet: ", y.shape)