RicardoM.Lu1 2023-04-13 13:17 采纳率: 80.6%
浏览 164
已结题

在给模型传入数据时出现问题(经过调试,数据已经传入了模型处理,但是到了某一步的时候变成了None)

在给模型传入数据时出现问题(经过调试,数据已经传入了模型处理,但是到了某一步的时候变成了None):
代码如下:


import torch
import json
import base64
from seal_bak.seal_recognition import work
import os
import numpy as np
import cv2

import matplotlib.pyplot as plt

# root_path = os.getcwd()

# model_path = root_path + '/models/seal_detect_best.pt'
model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'


class Infer_main:
    def __init__(self, yolo_model_path):
        self.model_path = yolo_model_path
        self.seal_infer = self.yolo_model_load()

    def yolo_model_load(self):
        model = torch.hub.load('D:/test11/yolov5',
                               'custom',
                               path='D:/test11/yolov5/models/seal_detect_best.pt',
                               source='local',
                               force_reload=True)  # local repo

        model.conf = 0.4
        model.eval()
        return model

    def predict(self, images):
        result = {
            "base64": images,
            "debug": False  # debug模式将可视化各环节,否则只输出结果
        }
        # for item, i in enumerate(images):
        yolo_res = self.seal_infer(images)
        print(f'yolo_res:{yolo_res}')
        yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
        if yolo_res_list:
            for i in range(len(yolo_res_list)):
                # print(yolo_res_list[i])
                yolo_res_liststr = json.dumps(yolo_res_list[i])
                # print(type(yolo_res_liststr))
                base64str = base64.b64encode(yolo_res_liststr.encode('utf-8'))
                # img_array = np.fromstring(base64str, np.uint8)
                # print(type(base64str))
                result['base64'] = base64str
                result_rec = work(result)
                print(result_rec)
                return result_rec


if __name__ == '__main__':
    from PIL import Image
#
# model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'
# image = plt.imread('D:/test11/yolov5/8.png')
# image = cv2.imread('D:/test11/yolov5/8.png')
image = Image.open('D:/test11/yolov5/1.png')
a = Infer_main(model_path)
# print()
# a.seal_infer(image).show()
a.predict(image)
#     model = yolo_model_load(model_path)
# yolo_res = model(image, size=640)
#     yolo_res.show()
#     yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
#     print(type(yolo_res_list))
#     print('印章识别结果:', yolo_res_list)

报错信息如下:


Traceback (most recent call last):
  File "D:\test11\yolov5\test1.py", line 66, in <module>
    a.predict(image)
  File "D:\test11\yolov5\test1.py", line 51, in predict
    result_rec = work(result)
  File "D:\test11\seal_bak\seal_recognition.py", line 116, in work
    category, img = model.predict(cfg['base64'])
  File "D:\test11\seal_bak\utils\Model.py", line 38, in predict
    Tensor, img = self.__img2tensor(ImgPath)
  File "D:\test11\seal_bak\utils\Model.py", line 58, in __img2tensor
    Img = Transfm(OImg)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 135, in __call__
    return F.to_tensor(pic)
  File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\functional.py", line 137, in to_tensor
    raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
TypeError: pic should be PIL Image or ndarray. Got <class 'NoneType'>
Process finished with exit code 1
该报错为代码逻辑,数据类型不对或没有则返回这段

模型处理代码如下:

class Model():
    def __init__(self, choice=2):
        if choice == 0:
            model = network.SPP_Net1()
            Dict = SPPDict1
        elif choice == 1:
            model = network.SPP_Net2()
            Dict = SPPDict2
        else:
            model = network.SPP_Net3()
            Dict = SPPDict3
        if self.Device() == "cuda":
            model.load_state_dict(torch.load(Dict))
        else:
            model.load_state_dict(torch.load(Dict, map_location=torch.device('cpu')))
        # print("Model loaded!")
        self.model = model

    def Device(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def predict(self, ImgPath):
        Tensor, img = self.__img2tensor(ImgPath)
        with torch.no_grad():
            pred = self.model(Tensor)
        result = self.__arg2label(pred.argmax())
        # print(result)
        return result, img

    #     def __img2tensor(self,ImgPath):
    #         OImg = cv2.imread(ImgPath) #origin图片
    #         Transfm = Compose([transforms.ToTensor()])
    #         Img = Transfm(OImg)
    #         Img = Img.unsqueeze(0)
    #         return Img,OImg
    def __img2tensor(self, base64_data):
        img_b64decode = base64.b64decode(base64_data)  # base64解码
        # 将base64转换numpy array
        img_array = np.fromstring(img_b64decode, np.uint8)
        # 转换cv2
        OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)
        Transfm = Compose([transforms.ToTensor()])
        Img = Transfm(OImg)
        Img = Img.unsqueeze(0)
        return Img, OImg
    def __arg2label(self, arg):
        return LabelDict[str(arg.item())]
根据debug调试,是到了__img2tensor函数中出现的问题,在OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)这步时 OImg得到的是None

模型逻辑代码和所需数据如下:

import re
import time
import traceback
import base64
from seal_bak.utils.ocr import ellipse_ocr, circle_ocr, rectangle_ocr
from seal_bak.utils.detect import *
from seal_bak.utils.Model import Model
import logging
import copy
import numpy as np

model = Model(0)

CODE_SCORE = 0  # 0.98 0.97
NAME_SCORE = 0  # 0.9


def get_result(res, category):
    en_or_num = re.compile(r"[a-zA-Z0-9]", re.I)  # [a-z]|\d
    not_en_nor_num = re.compile(r"[^a-zA-Z0-9]", re.I)  # 非英文数字

    null_text = {"text": "", "confidence": 0}
    result = {
        "code": null_text,  # 印章编号
        "name": null_text,  # 印章名称
        "SealType": "99",  # 印章类型代码
        "strSealType": "其他类型印章"  # 印章类型名称
    }
    seal_type = {
        "法定名称章": "01",
        "财务专用章": "02",
        # "发票专用章": "03",
        "合同专用章": "04",
        "未找到印章": "06",
        # "法定代表人名章": "05",
        "其他类型印章": "99"

    }

    #     logging.info(f"<{[str(item['text']) + '|' + str(item['confidence']) for item in res]}>")

    if category == "正方形":
        result["SealType"] = "05"
        result["strSealType"] = "法定代表人名章"
        result["code"] = {
            "text": res[0]["text"], "confidence": res[0]["confidence"]
        } if res[0]["confidence"] >= CODE_SCORE else null_text  # 设定编号
        result["name"] = {
            "text": res[1]["text"], "confidence": res[1]["confidence"]
        } if res[1]["confidence"] >= NAME_SCORE else null_text  # 设定name
    elif category == "椭圆":  # todo 调参
        result["SealType"] = "03"
        result["strSealType"] = "发票专用章"
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "center" else null_text  # 设定编号
            else:
                result["name"] = {
                    "text": text, "confidence": score
                } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    elif len(res) == 2:
        result["SealType"] = "01"
        result["strSealType"] = "法定名称章"
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "side" else null_text  # 设定编号
            else:
                result["name"] = {
                    "text": text, "confidence": score
                } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    elif category == False:
        result["SealType"] = "06"
        result["strSealType"] = "未找到印章"
    else:
        set_type = True
        for item in res:
            text = item["text"]
            score = item["confidence"]
            from_ = item["from"]
            if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
                result["code"] = {
                    "text": text, "confidence": score
                } if score >= CODE_SCORE and from_ == "side" else null_text  # 设定编号
            else:
                if set_type and from_ == "center":
                    # 设定type
                    if text not in seal_type.keys():
                        result["SealType"] = "99"
                        result["strSealType"] = "其他类型印章"  # "其他类型印章"
                    else:
                        result["SealType"] = seal_type[text]
                        result["strSealType"] = text
                    set_type = False
                else:
                    # 设定name
                    result["name"] = {
                        "text": text, "confidence": score
                    } if score >= NAME_SCORE and from_ == "side" else null_text  # 设定name
    return result


def work(cfg):
    t0 = time.time()
    # 分类
    #     category, img = model.predict(cfg['img_path'])
    category, img = model.predict(cfg['base64'])
    img = enlarge_img(img)
    img_ = np.copy(img)
    # 提取红色
    t1 = time.time()
    img_, img_bw = Lazyfilter(img_, cfg)
    # 开运算去噪填充
    t2 = time.time()
    img_ = erode_dilate(img_, category, cfg=cfg)
    # 查找最大轮廓
    t3 = time.time()
    contours, max_idx = find_max(img_)
    # 完成检测框
    t4 = time.time()
    det = get_area(img, contours, max_idx, category, cfg=cfg)  # 不进行分类
    # 截取目标区域  img:原图 img_bw:二值
    t5 = time.time()
    img = rotate_cut(img, det, cfg=cfg)
    # ============分类处理目标区域
    t6 = time.time()
    # if category == False:
    #     empty_ocr()
    try:
        if category == '圆形':
            res = circle_ocr(img, cfg=cfg)  # 使用二值图预测
        elif category == "正方形":
            res = rectangle_ocr(img, cfg=cfg)
        else:
            res = ellipse_ocr(img, cfg=cfg)
    except Exception:
        print("未找到印章")


    # 将文本列表处理成标准json格式
    try:
        result = get_result(res, category)
    except Exception:
        pass
    try:
        result_string = json.dumps(result, ensure_ascii=False)
    except Exception:
        pass
    if not cfg["debug"]:
        # with open(os.path.join(cfg["to_path"], cfg["file_name"] + ".json"), mode="w", encoding="utf-8") as f:
        #     json.dump(result, f, ensure_ascii=False)
        pass
    t7 = time.time()

    #     detail = f"{cfg['file_name']} done in {round(t7 - t0, 2)}s. " \
    #              f"class={det['class']}, text={result_string}"
    #     logging.info(detail)
    print(t7 - t1)
    try:
        return result_string
    except Exception:
        pass

Debug调试的结果如下(其余步骤正常我就省略了):

img

  • 写回答

6条回答 默认 最新

  • 爱晚乏客游 2023-04-13 13:28
    关注

    你说有问题的这句换成下面的看下

    OImg = cv2.imdecode(img_array,flags=cv2.IMREAD_COLOR)
    OImg =cv2.cvtColor(OImg,cv2.COLOR_BGR2RGB)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(5条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月14日
  • 已采纳回答 4月14日
  • 修改了问题 4月13日
  • 修改了问题 4月13日
  • 展开全部

悬赏问题

  • ¥66 定制开发肯德基自动化网站下单软件
  • ¥20 vscode虚拟环境依赖包未安装
  • ¥15 odoo17关于owl开发js代码问题
  • ¥15 光纤中多普勒频移公式的推导
  • ¥15 怎么制作一个人脸识别门禁系统
  • ¥20 大华dss监控平台网络关闭登不进去
  • ¥15 请使用蚁群算法解决下列问题,并给出我完整的代码
  • ¥20 关于php录入完成后,批量更新数据库
  • ¥15 请教往复密封润滑问题
  • ¥15 cocos creator发布ios包