拾3 2021-01-30 16:21 采纳率: 0%
浏览 69

c++ 调用python代码 深度学习

model = GetModel(opt)
    parameters_name = 'C:/liaomingwei/Seg_prediction/checkpoints/007_LinHuiMin_FocalWDice2000_2_epoch_3.ckpt'
    model_CKPT = torch.load(parameters_name)
    model.load_state_dict(model_CKPT['state_dict'])
    #tifffile.imsave('E:\\test.tif', image)
    print('loading checkpoint!')
    patch_size = np.array([120, 120, 120])
    overlap = np.array([10, 10, 10])
    patch_indices = compute_patch_indices(image.shape, patch_size, overlap, start=0)
    #return image
    Tdataset = GenerateDataset_ForNew(image, image.shape, patch_size, overlap)
    val_loader = DataLoader(Tdataset, batch_size=3, num_workers=2, shuffle=False)
    #print(val_loader)
    prob_patches = test_model(val_loader, model)

    start = time.time()
    patchindices = compute_patch_indices(image.shape, patch_size, overlap, start=0)
    prob_recon = reconstruct_from_patches(prob_patches, patch_indices, image.shape)
    print('run here3!')
    prob_recon = prob_recon * 255
    prob_recon = prob_recon.astype(np.uint8)
    run_time = time.time() - start
    print(run_time)
    print('ok')
    return prob_recon


def test_model(val_loader,model):
    model.eval()

    pred_Patches = []
    prob_patches =[]
    # image_Patches = []

    soft_max = nn.Softmax(dim=1)

    start_time=time.time()


    for batch_ids, (image_patch) in enumerate(val_loader):
        print('run here2!')
        print(batch_ids)

        if opt.use_cuda:
            image_patch=image_patch.cuda()

            output=model(image_patch)

            with torch.no_grad():
                # just 0 and 1
                _,pred_patch = torch.max(output,dim=1)

                # for prob
                prob_patch = soft_max(output)

                del output

                pred_patch=pred_patch.cpu().numpy()
                prob_patch = prob_patch.cpu().numpy()
                prob_patch = prob_patch[:, 1, ...]

                # image_patch=image_patch.cpu().numpy()


                for image_num in range(pred_patch.shape[0]):
                    # 0 and 1
                    pred1=np.array(pred_patch[image_num, :, :, :], dtype=np.float32)
                    pred_Patches.append( pred1 )

                    # prob
                    prob1 = np.array(prob_patch[image_num, :, :, :], dtype=np.float32)
                    prob_patches.append(prob1)
                    #tifffile.imsave(os.path.join('D:\Data\Osten\\6\Osten' + '\prob', str(batch_ids)+'_'+str(image_num) + '_probtest.tif'), np.uint8(prob1*255))

                    # image1=np.array(image_patch[image_num,0,:,:,:], dtype=np.float32)
                    # image_Patches.append(image1)

    # save the images into npy type
    run_time=time.time()-start_time
    print(run_time)

    return prob_patches

c++调用python运行到test_model函数里面的for batch_ids, (image_patch) in enumerate(val_loader):就不运行了,有大佬知道吗

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-09-07 17:11
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 以帮助更多的人 ^-^
    评论

报告相同问题?

悬赏问题

  • ¥15 is not in the mmseg::model registry。报错,模型注册表找不到自定义模块。
  • ¥15 安装quartus II18.1时弹出此error,怎么解决?
  • ¥15 keil官网下载psn序列号在哪
  • ¥15 想用adb命令做一个通话软件,播放录音
  • ¥30 Pytorch深度学习服务器跑不通问题解决?
  • ¥15 部分客户订单定位有误的问题
  • ¥15 如何在maya程序中利用python编写领子和褶裥的模型的方法
  • ¥15 Bug traq 数据包 大概什么价
  • ¥15 在anaconda上pytorch和paddle paddle下载报错
  • ¥25 自动填写QQ腾讯文档收集表