weixin_51497416 2023-06-12 00:07 采纳率: 90%
浏览 8

自己写predict时遇到的问题

现在需要在pytorch环境下预测图片中生菜的种类、直径、干重、鲜重的信息
class文件里是不是应该写种类的分类,在预测时,predict里先预测种类然后再去预测其他信息?
我在训练过程中用到深度图,预测时是不是也应该引入深度图,而不是只读取RGB图像,我看其他predict里只读取了RGB图像

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-06-12 06:02
    关注
    • 这篇博客: 用pytorch搭建简单的语义分割(可训练自己的数据集)中的 4、预测文件predict: 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
    • 把待测图像放入samples文件夹中,输出结果在outputs文件夹中

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KxaWFVZU-1594133488097)(C:\Users\86152\Desktop\博客\4.png)]

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ucszxAAb-1594133488098)(C:\Users\86152\Desktop\博客\5.png)]

      predict的重点是:图像从model输出后得到的predict图像经过【第70行pr=predict.argmax(axis=-1)】压缩成一层,每个像素值为种类概率最高的该层的index,再遍历全部像素点与种类index进行匹配,匹配成功则涂上上对应的颜色。

      from segnet_ import Airplanesnet
      from PIL import Image
      import numpy as np
      import torch
      import argparse
      import cv2
      import copy
      import os
      
      
      parser = argparse.ArgumentParser()
      parser.add_argument('--samples', type=str, default='D:/untitled/.idea/SS_torch/samples', help='samples')
      parser.add_argument('--outputs', type=str, default='D:/untitled/.idea/SS_torch/outputs', help='outputs')
      parser.add_argument('--weights', type=str, default='D:/untitled/.idea/SS_torch/weights/SS_weight_3.pth', help='weights')
      opt = parser.parse_args()
      print(opt)
      
      
      colors = [[0,0,0],[255,0,0]]
      NCLASSES = 2
      BATCH_SIZE=1
      
      img_way=opt.samples
      img_save=opt.outputs
      
      
      device=torch.device("cuda:0"if torch.cuda.is_available() else "cpu")   #检测是否有GPU加速
      
      model=Airplanesnet(NCLASSES,BATCH_SIZE)             #初始化model
      
      model.load_state_dict(torch.load(opt.weights))     #加载权重
      
      model.to(device)     #放入GPU
      
      for jpg in  os.listdir(r"%s" %img_way):
      
          name = jpg[:-4]
          with torch.no_grad():
      
              image=cv2.imread("%s" % img_way + "/" + jpg)
              old_image = copy.deepcopy(image)
              old_image = np.array(old_image)
              orininal_h = image.shape[0]       #读取的图像的高
              orininal_w = image.shape[1]       #读取的图像的宽   方便之后还原大小
      
              image = cv2.resize(image, dsize=(416, 416))         #调整大小
              image = image / 255.0                   #图像归一化
              image = torch.from_numpy(image)
              image = image.permute(2, 0, 1)                #显式的调转维度
      
      
              image = torch.unsqueeze(image, dim=0)           #改变维度,使得符合model input size
              image = image.type(torch.FloatTensor)         #数据转换,否则报错
              image = image.to(device)                      #放入GPU中计算
      
      
              predict = model(image).cpu()
              # print(predict.shape)
      
              predict = torch.squeeze(predict)            #[1,1,416,416]---->[1,416,416]
              predict =predict.permute(1, 2, 0)
              # print(jpg)
      
              predict = predict.numpy()
              # print(predict.shape)
      
              pr=predict.argmax(axis=-1)                     #把class数量的层压缩为一层,Z轴上的值概率最高的返回该层index
      
      
              seg_img = np.zeros((416, 416,3))        #创造三层0矩阵,方便进行涂色匹配
      
              #进行染色
              for c in range(NCLASSES):
      
                  seg_img[:, :, 0] += ((pr[:, :] == c) * (colors[c][0])).astype('uint8')
                  seg_img[:, :, 1] += ((pr[:, :] == c) * (colors[c][1])).astype('uint8')
                  seg_img[:, :, 2] += ((pr[:, :] == c) * (colors[c][2])).astype('uint8')
      
      
              seg_img = cv2.resize(seg_img,(orininal_w,orininal_h))
              seg_img = np.array(seg_img)
      
              # 原图和效果图叠加
              result = cv2.addWeighted(seg_img, 0.3, old_image, 0.7, 0., old_image, cv2.CV_32F)
              cv2.imwrite("%s/%s" % (img_save, name) + ".jpg", result)
              print("%s.jpg  ------>done!!!" % name)
      

      预测结果:

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-anQ4GWJE-1594133488099)(C:\Users\86152\Desktop\博客\6.jpg)]


    评论

报告相同问题?

问题事件

  • 创建了问题 6月12日

悬赏问题

  • ¥15 软件工程用例图的建立(相关搜索:软件工程用例图|画图)
  • ¥15 如何在arcgis中导出拓扑关系表
  • ¥15 处理数据集文本挖掘代码
  • ¥15 matlab2017
  • ¥15 在vxWorks下TCP/IP编程,总是connect()报错,连接服务器失败: errno = 0x41
  • ¥15 AnolisOs7.9如何安装 Qt_5.14.2的运行库
  • ¥20 求:怎么实现qt与pcie通信
  • ¥50 前后端数据顺序不一致问题,如何解决?(相关搜索:数据结构)
  • ¥15 基于蒙特卡罗法的中介效应点估计代码
  • ¥15 罗技G293和UE5.3