RicardoM.Lu1 2023-04-13 13:17 采纳率: 80.6%
浏览 164
已结题

在给模型传入数据时出现问题(经过调试,数据已经传入了模型处理,但是到了某一步的时候变成了None)

在给模型传入数据时出现问题(经过调试,数据已经传入了模型处理,但是到了某一步的时候变成了None):
代码如下:


import torch
import json
import base64
from seal_bak.seal_recognition import work
import os
import numpy as np
import cv2

import matplotlib.pyplot as plt

# root_path = os.getcwd()

# model_path = root_path + '/models/seal_detect_best.pt'
model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'


class Infer_main:
    def __init__(self, yolo_model_path):
        self.model_path = yolo_model_path
        self.seal_infer = self.yolo_model_load()

    def yolo_model_load(self):
        model = torch.hub.load('D:/test11/yolov5',
                               'custom',
                               path='D:/test11/yolov5/models/seal_detect_best.pt',
                               source='local',
                               force_reload=True)  # local repo

        model.conf = 0.4
        model.eval()
        return model

    def predict(self, images):
        result = {
            "base64": images,
            "debug": False  # debug模式将可视化各环节,否则只输出结果
        }
        # for item, i in enumerate(images):
        yolo_res = self.seal_infer(images)
        print(f'yolo_res:{yolo_res}')
        yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
        if yolo_res_list:
            for i in range(len(yolo_res_list)):
                # print(yolo_res_list[i])
                yolo_res_liststr = json.dumps(yolo_res_list[i])
                # print(type(yolo_res_liststr))
                base64str = base64.b64encode(yolo_res_liststr.encode('utf-8'))
                # img_array = np.fromstring(base64str, np.uint8)
                # print(type(base64str))
                result['base64'] = base64str
                result_rec = work(result)
                print(result_rec)
                return result_rec


if __name__ == '__main__':
    from PIL import Image
#
# model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'
# image = plt.imread('D:/test11/yolov5/8.png')
# image = cv2.imread('D:/test11/yolov5/8.png')
image = Image.open('D:/test11/yolov5/1.png')
a = Infer_main(model_path)
# print()
# a.seal_infer(image).show()
a.predict(image)
#     model = yolo_model_load(model_path)
# yolo_res = model(image, size=640)
#     yolo_res.show()
#     yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
#     print(type(yolo_res_list))
#     print('印章识别结果:', yolo_res_list)

报错信息如下:


Traceback (most recent call last):
  File "D:\test11\yolov5\test1.py", line 66, in <module>
    a.predict(image)
  File "D:\test11\yolov5\test1.py", line 51, in predict
    result_rec = work(result)
  File "D:\test11\seal_bak\seal_recognition.py", line 116, in work
    category, img = model.predict(cfg['base64'])
  File "D:\test11\seal_bak\utils\Model.py", line 38, in predict
    Tensor, img = self.__img2tensor(ImgPath)
  File "D:\test11\seal_bak\utils\Model.py", line 58, in __img2tensor
    Img = Transfm(OImg)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 135, in __call__
    return F.to_tensor(pic)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\functional.py", line 137, in to_tensor
    raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
TypeError: pic should be PIL Image or ndarray. Got <class 'NoneType'>
Process finished with exit code 1
该报错为代码逻辑,数据类型不对或没有则返回这段

模型处理代码如下:

class Model():
    def __init__(self, choice=2):
        if choice == 0:
            model = network.SPP_Net1()
            Dict = SPPDict1
        elif choice == 1:
            model = network.SPP_Net2()
            Dict = SPPDict2
        else:
            model = network.SPP_Net3()
            Dict = SPPDict3
        if self.Device() == "cuda":
            model.load_state_dict(torch.load(Dict))
        else:
            model.load_state_dict(torch.load(Dict, map_location=torch.device('cpu')))
        # print("Model loaded!")
        self.model = model

    def Device(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def predict(self, ImgPath):
        Tensor, img = self.__img2tensor(ImgPath)
        with torch.no_grad():
            pred = self.model(Tensor)
        result = self.__arg2label(pred.argmax())
        # print(result)
        return result, img

    #     def __img2tensor(self,ImgPath):
    #         OImg = cv2.imread(ImgPath) #origin图片
    #         Transfm = Compose([transforms.ToTensor()])
    #         Img = Transfm(OImg)
    #         Img = Img.unsqueeze(0)
    #         return Img,OImg
    def __img2tensor(self, base64_data):
        img_b64decode = base64.b64decode(base64_data)  # base64解码
        # 将base64转换numpy array
        img_array = np.fromstring(img_b64decode, np.uint8)
        # 转换cv2
        OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)
        Transfm = Compose([transforms.ToTensor()])
        Img = Transfm(OImg)
        Img = Img.unsqueeze(0)
        return Img, OImg
    def __arg2label(self, arg):
        return LabelDict[str(arg.item())]
根据debug调试,是到了__img2tensor函数中出现的问题,在OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)这步时 OImg得到的是None

模型逻辑代码和所需数据如下:

import re
import time
import traceback
import base64
from seal_bak.utils.ocr import ellipse_ocr, circle_ocr, rectangle_ocr
from seal_bak.utils.detect import *
from seal_bak.utils.Model import Model
import logging
import copy
import numpy as np

model = Model(0)

CODE_SCORE = 0  # 0.98 0.97
NAME_SCORE = 0  # 0.9


def get_result(res, category):
    en_or_num = re.compile(r"[a-zA-Z0-9]", re.I)  # [a-z]|\d
    not_en_nor_num = re.compile(r"[^a-zA-Z0-9]", re.I)  # 非英文数字

    null_text = {"text": "", "confidence": 0}
    result = {
        "code": null_text,  # 印章编号
        "name": null_text,  # 印章名称
        "SealType": "99",  # 印章类型代码
        "strSealType": "其他类型印章"  # 印章类型名称
    }
    seal_type = {
        "法定名称章": "01",
        "财务专用章": "02",
        # "发票专用章": "03",
        "合同专用章": "04",
        "未找到印章": "06",
        # "法定代表人名章": "05",
        "其他类型印章": "99"

    }

    #     logging.info(f"<{[str(item['text']) + '|' + str(item['confidence']) for item in res]}>")

    if category == "正方形":
        result["SealType"] = "05"
        result["strSealType"] = "法定代表人名章"
        result["code"] = {
            "text": res[0]["text"], "confidence": res[0]["confidence"]
        } if res[0]["confidence"] >= CODE_SCORE else null_text  # 设定编号
        result["name"] = {
            "text": res[1]["text"], "confidence": res[1]["confidence"]
        } if res[1]["confidence"] >= NAME_SCORE else null_text  # 设定name
    elif category == "椭圆":  # todo 调参
        result["SealType"] = "03"
        result["strSealType"] = "发票专用章"
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "center" else null_text  # 设定编号
            else:
                result["name"] = {
                    "text": text, "confidence": score
                } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    elif len(res) == 2:
        result["SealType"] = "01"
        result["strSealType"] = "法定名称章"
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "side" else null_text  # 设定编号
            else:
                result["name"] = {
                    "text": text, "confidence": score
                } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    elif category == False:
        result["SealType"] = "06"
        result["strSealType"] = "未找到印章"
    else:
        set_type = True
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "side" else null_text  # 设定编号
            else:
                if set_type and from_ == "center":
                    # 设定type
                    if text not in seal_type.keys():
                        result["SealType"] = "99"
                        result["strSealType"] = "其他类型印章"  # "其他类型印章"
                    else:
                        result["SealType"] = seal_type[text]
                        result["strSealType"] = text
                    set_type = False
                else:
                    # 设定name
                    result["name"] = {
                        "text": text, "confidence": score
                    } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    return result


def work(cfg):
    t0 = time.time()
    # 分类
    #     category, img = model.predict(cfg['img_path'])
    category, img = model.predict(cfg['base64'])
    img = enlarge_img(img)
    img_ = np.copy(img)
    # 提取红色
    t1 = time.time()
    img_, img_bw = Lazyfilter(img_, cfg)
    # 开运算去噪填充
    t2 = time.time()
    img_ = erode_dilate(img_, category, cfg=cfg)
    # 查找最大轮廓
    t3 = time.time()
    contours, max_idx = find_max(img_)
    # 完成检测框
    t4 = time.time()
    det = get_area(img, contours, max_idx, category, cfg=cfg)  # 不进行分类
    # 截取目标区域  img:原图 img_bw:二值
    t5 = time.time()
    img = rotate_cut(img, det, cfg=cfg)
    # ============分类处理目标区域
    t6 = time.time()
    # if category == False:
    #     empty_ocr()
    try:
        if category == '圆形':
            res = circle_ocr(img, cfg=cfg)  # 使用二值图预测
        elif category == "正方形":
            res = rectangle_ocr(img, cfg=cfg)
        else:
            res = ellipse_ocr(img, cfg=cfg)
    except Exception:
        print("未找到印章")


    # 将文本列表处理成标准json格式
    try:
        result = get_result(res, category)
    except Exception:
        pass
    try:
        result_string = json.dumps(result, ensure_ascii=False)
    except Exception:
        pass
    if not cfg["debug"]:
        # with open(os.path.join(cfg["to_path"], cfg["file_name"] + ".json"), mode="w", encoding="utf-8") as f:
        #     json.dump(result, f, ensure_ascii=False)
        pass
    t7 = time.time()

    #     detail = f"{cfg['file_name']} done in {round(t7 - t0, 2)}s. " \
    #              f"class={det['class']}, text={result_string}"
    #     logging.info(detail)
    print(t7 - t1)
    try:
        return result_string
    except Exception:
        pass

Debug调试的结果如下(其余步骤正常我就省略了):

img

  • 写回答

6条回答 默认 最新

  • 爱晚乏客游 2023-04-13 13:28
    关注

    你说有问题的这句换成下面的看下

    OImg = cv2.imdecode(img_array,flags=cv2.IMREAD_COLOR)
    OImg =cv2.cvtColor(OImg,cv2.COLOR_BGR2RGB)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(5条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月14日
  • 已采纳回答 4月14日
  • 修改了问题 4月13日
  • 修改了问题 4月13日
  • 展开全部

悬赏问题

  • ¥15 请问Ubuntu要怎么安装chrome呀?
  • ¥15 视频编码 十六进制问题
  • ¥15 Xsheii7我安装这个文件的时候跳出来另一个文件已锁定文件的无一部分进程无法访问。这个该怎么解决
  • ¥15 unity terrain打包后地形错位,跟建筑不在同一个位置,怎么办
  • ¥15 FileNotFoundError 解决方案
  • ¥15 uniapp实现如下图的图表功能
  • ¥15 u-subsection如何修改相邻两个节点样式
  • ¥30 vs2010开发 WFP(windows filtering platform)
  • ¥15 服务端控制goose报文控制块的发布问题
  • ¥15 学习指导与未来导向啊