·炭烤小肥 2022-05-20 10:41 采纳率: 100%
浏览 43
已结题

DGCNN.pytorch在S3DIS上运行错误,如何解决?

报错如下:

> Namespace(batch_size=32, dataset='S3DIS', dropout=0.5, emb_dims=1024, epochs=100, eval=True, exp_name='semseg_eval', k=20, lr=0.001, model='dgcnn', model_root='outputs/semseg_6/models/', momentum=0.9, no_cuda=False, num_points=4096, scheduler='cos',
 seed=1, test_area='all', test_batch_size=16, use_sgd=True, visu='', visu_format='ply')
Using GPU : 0 from 1 devices
Test :: test area: 1, test acc: 0.893883, test avg acc: 0.805912, test iou: 0.702250
Overall Test :: test acc: 0.893883, test avg acc: 0.805912, test iou: 0.702250
Test :: test area: 2, test acc: 0.833375, test avg acc: 0.557771, test iou: 0.444156
Traceback (most recent call last):
  File "main_semseg.py", line 454, in <module>
    test(args, io)
  File "main_semseg.py", line 369, in test
    all_true_cls.append(test_true_cls)
AttributeError: 'numpy.ndarray' object has no attribute 'append'

代码如下:

def test(args, io):
    """Evaluate saved semantic-segmentation models on the S3DIS test areas.

    For every area 1..6 selected by ``args.test_area`` (a single area given
    as a string, or ``'all'``), the checkpoint ``model_<area>.t7`` under
    ``args.model_root`` is loaded, run over that area's test split, and the
    accuracy, balanced accuracy and mean IoU are reported through
    ``io.cprint``. When ``args.test_area == 'all'`` the predictions of all
    areas are additionally pooled and one overall summary line is printed
    after the last area has been processed.

    Args:
        args: Parsed command-line namespace. Reads ``test_area``,
            ``num_points``, ``test_batch_size``, ``cuda``, ``model``,
            ``model_root``, ``visu`` and ``visu_format``; the whole namespace
            is also passed to ``DGCNN_semseg``.
        io: Logger-like object exposing ``cprint(str)``.

    Raises:
        Exception: If ``args.model`` is not ``'dgcnn'``.
    """
    # Per-area results are collected in plain lists and only concatenated
    # once, after the loop over areas has finished.
    all_true_cls = []
    all_pred_cls = []
    all_true_seg = []
    all_pred_seg = []
    for test_area in range(1, 7):
        visual_file_index = 0
        test_area = str(test_area)
        # Count the lines preceding the first entry whose sixth character
        # equals this area number; that count is the starting index handed
        # to visualization(). NOTE(review): presumably lines look like
        # "Area_<n>_..." -- confirm against the room file list format.
        if os.path.exists("data/indoor3d_sem_seg_hdf5_data_test/room_filelist.txt"):
            with open("data/indoor3d_sem_seg_hdf5_data_test/room_filelist.txt") as f:
                for line in f:
                    if (line[5]) == test_area:
                        break
                    visual_file_index = visual_file_index + 1
        if (args.test_area == 'all') or (test_area == args.test_area):
            test_loader = DataLoader(S3DIS(partition='test', num_points=args.num_points, test_area=test_area),
                                     batch_size=args.test_batch_size, shuffle=False, drop_last=False)

            device = torch.device("cuda" if args.cuda else "cpu")

            # Try to load models
            semseg_colors = test_loader.dataset.semseg_colors
            if args.model == 'dgcnn':
                model = DGCNN_semseg(args).to(device)
            else:
                raise Exception("Not implemented")

            model = nn.DataParallel(model)
            model.load_state_dict(torch.load(os.path.join(args.model_root, 'model_%s.t7' % test_area)))
            model = model.eval()
            test_true_cls = []
            test_pred_cls = []
            test_true_seg = []
            test_pred_seg = []
            with torch.no_grad():
                for data, seg in test_loader:
                    data, seg = data.to(device), seg.to(device)
                    # Swap the last two axes before the forward pass and swap
                    # the prediction back so the argmax runs over the last axis.
                    data = data.permute(0, 2, 1)
                    seg_pred = model(data)
                    seg_pred = seg_pred.permute(0, 2, 1).contiguous()
                    pred = seg_pred.max(dim=2)[1]
                    seg_np = seg.cpu().numpy()
                    pred_np = pred.detach().cpu().numpy()
                    # Flattened copies feed the accuracy metrics; the
                    # unflattened ones feed the IoU computation.
                    test_true_cls.append(seg_np.reshape(-1))
                    test_pred_cls.append(pred_np.reshape(-1))
                    test_true_seg.append(seg_np)
                    test_pred_seg.append(pred_np)
                    # visualization
                    visualization(args.visu, args.visu_format, args.test_area, data, seg, pred, visual_file_index,
                                  semseg_colors)
                    visual_file_index = visual_file_index + data.shape[0]
                # `visual_warning` is a name defined outside this function.
                if visual_warning and args.visu != '':
                    print(
                        'Visualization Failed: You can only choose a room to visualize within the scope of the test area')
                test_true_cls = np.concatenate(test_true_cls)
                test_pred_cls = np.concatenate(test_pred_cls)
                test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
                avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls)
                test_true_seg = np.concatenate(test_true_seg, axis=0)
                test_pred_seg = np.concatenate(test_pred_seg, axis=0)
                test_ious = calculate_sem_IoU(test_pred_seg, test_true_seg)
                outstr = 'Test :: test area: %s, test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (test_area,
                                                                                                        test_acc,
                                                                                                        avg_per_class_acc,
                                                                                                        np.mean(
                                                                                                            test_ious))
                io.cprint(outstr)
                all_true_cls.append(test_true_cls)
                all_pred_cls.append(test_pred_cls)
                all_true_seg.append(test_true_seg)
                all_pred_seg.append(test_pred_seg)

    # The overall summary must run exactly once, AFTER every area has been
    # evaluated. Doing it inside the loop rebinds the `all_*` lists to numpy
    # arrays after the first area, so the next area's `.append` raises
    # AttributeError: 'numpy.ndarray' object has no attribute 'append'.
    if args.test_area == 'all':
        all_true_cls = np.concatenate(all_true_cls)
        all_pred_cls = np.concatenate(all_pred_cls)
        all_acc = metrics.accuracy_score(all_true_cls, all_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(all_true_cls, all_pred_cls)
        all_true_seg = np.concatenate(all_true_seg, axis=0)
        all_pred_seg = np.concatenate(all_pred_seg, axis=0)
        all_ious = calculate_sem_IoU(all_pred_seg, all_true_seg)
        outstr = 'Overall Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (all_acc,
                                                                                         avg_per_class_acc,
                                                                                         np.mean(all_ious))
        io.cprint(outstr)

请问该如何修改呢?

  • 写回答

3条回答

      报告相同问题?

      相关推荐 更多相似问题

      问题事件

      • 系统已结题 6月4日
      • 已采纳回答 5月27日
      • 创建了问题 5月20日

      悬赏问题

      • ¥15 postman测试正常,在代码运行报错
      • ¥15 关于#C语言#的问题,如何解决?
      • ¥20 Vs2017 Help Viewer2.3 问题
      • ¥35 基于嵌入式linux的日程管理软件
      • ¥50 如何将list字符串添加到CSV文件表头?
      • ¥15 关于#javascript#的问题:通过ajax实现的局部刷新 如何将项目打包
      • ¥15 海思uboot USB3.0无法识别
      • ¥15 无法调用库文件,自己可以找到,但编译时显示没有
      • ¥15 安装PyQt5的时候这里创建虚拟环境是哪里?具体是怎么的?能录个视频吗
      • ¥20 php程序设计题不会!求解答!