在给模型传入数据时出现问题(经过调试,数据已经传入了模型处理,但是到了某一步的时候变成了None):
代码如下:
import torch
import json
import base64
from seal_bak.seal_recognition import work
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
# root_path = os.getcwd()
# model_path = root_path + '/models/seal_detect_best.pt'
model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'
class Infer_main:
def __init__(self, yolo_model_path):
self.model_path = yolo_model_path
self.seal_infer = self.yolo_model_load()
def yolo_model_load(self):
model = torch.hub.load('D:/test11/yolov5',
'custom',
path='D:/test11/yolov5/models/seal_detect_best.pt',
source='local',
force_reload=True) # local repo
model.conf = 0.4
model.eval()
return model
def predict(self, images):
result = {
"base64": images,
"debug": False # debug模式将可视化各环节,否则只输出结果
}
# for item, i in enumerate(images):
yolo_res = self.seal_infer(images)
print(f'yolo_res:{yolo_res}')
yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
if yolo_res_list:
for i in range(len(yolo_res_list)):
# print(yolo_res_list[i])
yolo_res_liststr = json.dumps(yolo_res_list[i])
# print(type(yolo_res_liststr))
base64str = base64.b64encode(yolo_res_liststr.encode('utf-8'))
# img_array = np.fromstring(base64str, np.uint8)
# print(type(base64str))
result['base64'] = base64str
result_rec = work(result)
print(result_rec)
return result_rec
if __name__ == '__main__':
from PIL import Image
#
# model_path = 'D:/test11/yolov5/models/seal_detect_best.pt'
# image = plt.imread('D:/test11/yolov5/8.png')
# image = cv2.imread('D:/test11/yolov5/8.png')
image = Image.open('D:/test11/yolov5/1.png')
a = Infer_main(model_path)
# print()
# a.seal_infer(image).show()
a.predict(image)
# model = yolo_model_load(model_path)
# yolo_res = model(image, size=640)
# yolo_res.show()
# yolo_res_list = yolo_res.pandas().xyxy[0].values.tolist()
# print(type(yolo_res_list))
# print('印章识别结果:', yolo_res_list)
报错信息如下:
Traceback (most recent call last):
File "D:\test11\yolov5\test1.py", line 66, in <module>
a.predict(image)
File "D:\test11\yolov5\test1.py", line 51, in predict
result_rec = work(result)
File "D:\test11\seal_bak\seal_recognition.py", line 116, in work
category, img = model.predict(cfg['base64'])
File "D:\test11\seal_bak\utils\Model.py", line 38, in predict
Tensor, img = self.__img2tensor(ImgPath)
File "D:\test11\seal_bak\utils\Model.py", line 58, in __img2tensor
Img = Transfm(OImg)
File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
img = t(img)
File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\transforms.py", line 135, in __call__
return F.to_tensor(pic)
File "D:\Anaconda3\envs\cj-env\lib\site-packages\torchvision\transforms\functional.py", line 137, in to_tensor
raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
TypeError: pic should be PIL Image or ndarray. Got <class 'NoneType'>
Process finished with exit code 1
该报错为代码逻辑,数据类型不对或没有则返回这段
模型处理代码如下:
class Model():
def __init__(self, choice=2):
if choice == 0:
model = network.SPP_Net1()
Dict = SPPDict1
elif choice == 1:
model = network.SPP_Net2()
Dict = SPPDict2
else:
model = network.SPP_Net3()
Dict = SPPDict3
if self.Device() == "cuda":
model.load_state_dict(torch.load(Dict))
else:
model.load_state_dict(torch.load(Dict, map_location=torch.device('cpu')))
# print("Model loaded!")
self.model = model
def Device(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
return device
def predict(self, ImgPath):
Tensor, img = self.__img2tensor(ImgPath)
with torch.no_grad():
pred = self.model(Tensor)
result = self.__arg2label(pred.argmax())
# print(result)
return result, img
# def __img2tensor(self,ImgPath):
# OImg = cv2.imread(ImgPath) #origin图片
# Transfm = Compose([transforms.ToTensor()])
# Img = Transfm(OImg)
# Img = Img.unsqueeze(0)
# return Img,OImg
def __img2tensor(self, base64_data):
img_b64decode = base64.b64decode(base64_data) # base64解码
# 将base64转换numpy array
img_array = np.fromstring(img_b64decode, np.uint8)
# 转换cv2
OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)
Transfm = Compose([transforms.ToTensor()])
Img = Transfm(OImg)
Img = Img.unsqueeze(0)
return Img, OImg
def __arg2label(self, arg):
return LabelDict[str(arg.item())]
根据debug调试,是到了__img2tensor函数中出现的问题,在OImg = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)这步时 OImg得到的是None
模型逻辑代码和所需数据如下:
import re
import time
import traceback
import base64
from seal_bak.utils.ocr import ellipse_ocr, circle_ocr, rectangle_ocr
from seal_bak.utils.detect import *
from seal_bak.utils.Model import Model
import logging
import copy
import numpy as np
model = Model(0)
CODE_SCORE = 0 # 0.98 0.97
NAME_SCORE = 0 # 0.9
def get_result(res, category):
en_or_num = re.compile(r"[a-zA-Z0-9]", re.I) # [a-z]|\d
not_en_nor_num = re.compile(r"[^a-zA-Z0-9]", re.I) # 非英文数字
null_text = {"text": "", "confidence": 0}
result = {
"code": null_text, # 印章编号
"name": null_text, # 印章名称
"SealType": "99", # 印章类型代码
"strSealType": "其他类型印章" # 印章类型名称
}
seal_type = {
"法定名称章": "01",
"财务专用章": "02",
# "发票专用章": "03",
"合同专用章": "04",
"未找到印章": "06",
# "法定代表人名章": "05",
"其他类型印章": "99"
}
# logging.info(f"<{[str(item['text']) + '|' + str(item['confidence']) for item in res]}>")
if category == "正方形":
result["SealType"] = "05"
result["strSealType"] = "法定代表人名章"
result["code"] = {
"text": res[0]["text"], "confidence": res[0]["confidence"]
} if res[0]["confidence"] >= CODE_SCORE else null_text # 设定编号
result["name"] = {
"text": res[1]["text"], "confidence": res[1]["confidence"]
} if res[1]["confidence"] >= NAME_SCORE else null_text # 设定name
elif category == "椭圆": # todo 调参
result["SealType"] = "03"
result["strSealType"] = "发票专用章"
for item in res:
text = item["text"]
score = item["confidence"]
from_ = item["from"]
if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
result["code"] = {
"text": text, "confidence": score
} if score >= CODE_SCORE and from_ == "center" else null_text # 设定编号
else:
result["name"] = {
"text": text, "confidence": score
} if score >= NAME_SCORE and from_ == "side" else null_text # 设定name
elif len(res) == 2:
result["SealType"] = "01"
result["strSealType"] = "法定名称章"
for item in res:
text = item["text"]
score = item["confidence"]
from_ = item["from"]
if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
result["code"] = {
"text": text, "confidence": score
} if score >= CODE_SCORE and from_ == "side" else null_text # 设定编号
else:
result["name"] = {
"text": text, "confidence": score
} if score >= NAME_SCORE and from_ == "side" else null_text # 设定name
elif category == False:
result["SealType"] = "06"
result["strSealType"] = "未找到印章"
else:
set_type = True
for item in res:
text = item["text"]
score = item["confidence"]
from_ = item["from"]
if len(en_or_num.findall(text)) > len(not_en_nor_num.findall(text)):
result["code"] = {
"text": text, "confidence": score
} if score >= CODE_SCORE and from_ == "side" else null_text # 设定编号
else:
if set_type and from_ == "center":
# 设定type
if text not in seal_type.keys():
result["SealType"] = "99"
result["strSealType"] = "其他类型印章" # "其他类型印章"
else:
result["SealType"] = seal_type[text]
result["strSealType"] = text
set_type = False
else:
# 设定name
result["name"] = {
"text": text, "confidence": score
} if score >= NAME_SCORE and from_ == "side" else null_text # 设定name
return result
def work(cfg):
t0 = time.time()
# 分类
# category, img = model.predict(cfg['img_path'])
category, img = model.predict(cfg['base64'])
img = enlarge_img(img)
img_ = np.copy(img)
# 提取红色
t1 = time.time()
img_, img_bw = Lazyfilter(img_, cfg)
# 开运算去噪填充
t2 = time.time()
img_ = erode_dilate(img_, category, cfg=cfg)
# 查找最大轮廓
t3 = time.time()
contours, max_idx = find_max(img_)
# 完成检测框
t4 = time.time()
det = get_area(img, contours, max_idx, category, cfg=cfg) # 不进行分类
# 截取目标区域 img:原图 img_bw:二值
t5 = time.time()
img = rotate_cut(img, det, cfg=cfg)
# ============分类处理目标区域
t6 = time.time()
# if category == False:
# empty_ocr()
try:
if category == '圆形':
res = circle_ocr(img, cfg=cfg) # 使用二值图预测
elif category == "正方形":
res = rectangle_ocr(img, cfg=cfg)
else:
res = ellipse_ocr(img, cfg=cfg)
except Exception:
print("未找到印章")
# 将文本列表处理成标准json格式
try:
result = get_result(res, category)
except Exception:
pass
try:
result_string = json.dumps(result, ensure_ascii=False)
except Exception:
pass
if not cfg["debug"]:
# with open(os.path.join(cfg["to_path"], cfg["file_name"] + ".json"), mode="w", encoding="utf-8") as f:
# json.dump(result, f, ensure_ascii=False)
pass
t7 = time.time()
# detail = f"{cfg['file_name']} done in {round(t7 - t0, 2)}s. " \
# f"class={det['class']}, text={result_string}"
# logging.info(detail)
print(t7 - t1)
try:
return result_string
except Exception:
pass
Debug调试的结果如下(其余步骤正常我就省略了):