现在需要在pytorch环境下预测图片中生菜的种类、直径、干重、鲜重的信息
class文件里是不是应该写种类的分类,在预测时,predict里先预测种类然后再去预测其他信息?
我在训练过程中用到深度图,预测时是不是也应该引入深度图,而不是只读取RGB图像,我看其他predict里只读取了RGB图像
自己写predict时遇到的问题
- 写回答
- 好问题 0 提建议
- 追加酬金
- 关注问题
- 邀请回答
-
1条回答 默认 最新
关注 - 这篇博客: 用pytorch搭建简单的语义分割(可训练自己的数据集)中的 4、预测文件predict: 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
把待测图像放入samples文件夹中,输出结果在outputs文件夹中
predict的重点是:图像从model输出后得到的predict图像经过【第70行pr=predict.argmax(axis=-1)】压缩成一层,每个像素值为种类概率最高的该层的index,再遍历全部像素点与种类index进行匹配,匹配成功则涂上上对应的颜色。
from segnet_ import Airplanesnet from PIL import Image import numpy as np import torch import argparse import cv2 import copy import os parser = argparse.ArgumentParser() parser.add_argument('--samples', type=str, default='D:/untitled/.idea/SS_torch/samples', help='samples') parser.add_argument('--outputs', type=str, default='D:/untitled/.idea/SS_torch/outputs', help='outputs') parser.add_argument('--weights', type=str, default='D:/untitled/.idea/SS_torch/weights/SS_weight_3.pth', help='weights') opt = parser.parse_args() print(opt) colors = [[0,0,0],[255,0,0]] NCLASSES = 2 BATCH_SIZE=1 img_way=opt.samples img_save=opt.outputs device=torch.device("cuda:0"if torch.cuda.is_available() else "cpu") #检测是否有GPU加速 model=Airplanesnet(NCLASSES,BATCH_SIZE) #初始化model model.load_state_dict(torch.load(opt.weights)) #加载权重 model.to(device) #放入GPU for jpg in os.listdir(r"%s" %img_way): name = jpg[:-4] with torch.no_grad(): image=cv2.imread("%s" % img_way + "/" + jpg) old_image = copy.deepcopy(image) old_image = np.array(old_image) orininal_h = image.shape[0] #读取的图像的高 orininal_w = image.shape[1] #读取的图像的宽 方便之后还原大小 image = cv2.resize(image, dsize=(416, 416)) #调整大小 image = image / 255.0 #图像归一化 image = torch.from_numpy(image) image = image.permute(2, 0, 1) #显式的调转维度 image = torch.unsqueeze(image, dim=0) #改变维度,使得符合model input size image = image.type(torch.FloatTensor) #数据转换,否则报错 image = image.to(device) #放入GPU中计算 predict = model(image).cpu() # print(predict.shape) predict = torch.squeeze(predict) #[1,1,416,416]---->[1,416,416] predict =predict.permute(1, 2, 0) # print(jpg) predict = predict.numpy() # print(predict.shape) pr=predict.argmax(axis=-1) #把class数量的层压缩为一层,Z轴上的值概率最高的返回该层index seg_img = np.zeros((416, 416,3)) #创造三层0矩阵,方便进行涂色匹配 #进行染色 for c in range(NCLASSES): seg_img[:, :, 0] += ((pr[:, :] == c) * (colors[c][0])).astype('uint8') seg_img[:, :, 1] += ((pr[:, :] == c) * (colors[c][1])).astype('uint8') seg_img[:, :, 2] += ((pr[:, :] == c) * (colors[c][2])).astype('uint8') seg_img = cv2.resize(seg_img,(orininal_w,orininal_h)) seg_img = np.array(seg_img) # 原图和效果图叠加 result = cv2.addWeighted(seg_img, 0.3, old_image, 0.7, 0., old_image, cv2.CV_32F) cv2.imwrite("%s/%s" % (img_save, name) + ".jpg", result) print("%s.jpg ------>done!!!" % name)
预测结果:
解决 无用评论 打赏 举报- 这篇博客: 用pytorch搭建简单的语义分割(可训练自己的数据集)中的 4、预测文件predict: 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
悬赏问题
- ¥15 软件工程用例图的建立(相关搜索:软件工程用例图|画图)
- ¥15 如何在arcgis中导出拓扑关系表
- ¥15 处理数据集文本挖掘代码
- ¥15 matlab2017
- ¥15 在vxWorks下TCP/IP编程,总是connect()报错,连接服务器失败: errno = 0x41
- ¥15 AnolisOs7.9如何安装 Qt_5.14.2的运行库
- ¥20 求:怎么实现qt与pcie通信
- ¥50 前后端数据顺序不一致问题,如何解决?(相关搜索:数据结构)
- ¥15 基于蒙特卡罗法的中介效应点估计代码
- ¥15 罗技G293和UE5.3