RicardoM.Lu1 2023-03-08 09:40 采纳率: 82.1%
浏览 43
已结题

OCR的学习,在探索代码的时候遇到问题

在学习OCR识别时,我拿到一串代码,但是看不出来要怎么让它跑起来拿到结果(只是一个类,原本是通过接口访问的代码)

import re

import numpy as np
from PIL import Image

from utils import order_point, crop_image, calc_distance


class BusinessCerts:

    def __init__(self, img='', dbnet_det='', ocr_recognition=''):
        self.img = img
        self.dbnet_det = dbnet_det
        self.ocr_recognition = ocr_recognition
        self.det_result = []
        self.box_result = []
        self.left_uppers = []
        self.line_space = -1
        self.line_height = -1

        self.box_detect()
        self.ocr_recognize()

        self.left_uppers = np.array([[point['pts'][0], point['pts'][1]] for point in self.box_result])
        self.left_lowers = np.array([[point['pts'][6], point['pts'][7]] for point in self.box_result])
        self.results = {'名称': '', '统一社会信用代码': '', '住所': ''}

        self.ocr_pipline()

    # DBNet文字检测
    def box_detect(self):
        det_result = self.dbnet_det.predict(self.img)
        det_result = det_result.reshape(-1, 8)
        self.det_result = det_result[det_result[:, 1].argsort()]
        return det_result[det_result[:, 1].argsort()]

    # 识别文字
    def ocr_recognize(self):
        box_result = []

        for i in range(self.det_result.shape[0]):
            box_dict = {}
            pts = order_point(self.det_result[i])
            image_crop = crop_image(self.img, pts)
            image_crop = Image.fromarray(image_crop)
            result = self.ocr_recognition.rec(image_crop)
            pts = pts.reshape(-1).astype(int)
            box_dict['pts'] = pts
            box_dict['text'] = result
            box_result.append(box_dict)

        self.box_result = box_result

    # 根据标签索引,通过计算距离确定对应值
    def get_item_value(self, key_idx):
        curr_pts = self.box_result[key_idx]['pts']
        curr_right_upper = np.array([curr_pts[2], curr_pts[3]])

        dists = calc_distance(curr_right_upper, self.left_uppers)
        dists[key_idx] = float('inf')

        value_idx = np.argmin(dists)
        return value_idx, self.box_result[value_idx]['text']

    # 获取名称对应Box索引
    def get_name_index(self):
        index = -1
        for i, temp in enumerate(self.box_result):
            if '名称' == temp['text'] or '称' == temp['text']:
                index = i
        return index

    def get_credit_code_index(self):
        index = -1
        for i, temp in enumerate(self.box_result):
            if '统一社会信用代码' == temp['text'] or '信用代码' == temp['text'] or '代码' == temp['text']:
                index = i
        return index

    def get_address_index(self):
        index = -1
        for i, temp in enumerate(self.box_result):
            if '住所' == temp['text'] or '经营场所' == temp['text']:
                index = i
        return index

    def ocr_pipline(self):
        address_index = self.get_address_index()
        if address_index != -1:
            address_value_idx, address_result = self.get_item_value(address_index)

            address_value_pts = self.box_result[address_value_idx]['pts']
            self.line_height = address_value_pts[7] - address_value_pts[1]

            address_value_left_lower = np.array([address_value_pts[6], address_value_pts[7]])
            # 计算两行之间的距离,住所Box左下角点和其他Box左上角点的距离
            line_space_dists = calc_distance(address_value_left_lower, self.left_uppers)
            line_space_dists[address_value_idx] = float('inf')

            # 计算行距
            left_lowers = [[point['pts'][6], point['pts'][7]] for point in self.box_result]
            left_upper = np.array([address_value_pts[0], address_value_pts[1]])
            line_height_dists = calc_distance(left_upper, left_lowers)
            line_height_dists[address_value_idx] = float('inf')
            self.line_space = min(line_height_dists)

            # 如果最小行距小于正常行距的 1/2 ,认为有换行
            if min(line_space_dists) < (self.line_space / 2):
                min_idx = np.argmin(line_space_dists)
                next_line = self.box_result[min_idx]['text']
                address_result += next_line

            self.results['住所'] = address_result

        name_index = self.get_name_index()

        if name_index != -1:
            name_value_idx, name_result = self.get_item_value(name_index)

            name_value_pts = self.box_result[name_value_idx]['pts']
            name_value_left_lower = np.array([name_value_pts[6], name_value_pts[7]])

            # 计算两行之间的距离,名称Box左下角点和其他Box左上角点的距离
            line_space_dists = calc_distance(name_value_left_lower, self.left_uppers)
            line_space_dists[name_value_idx] = float('inf')

            # 如果最小行距小于正常行距的 1/2 ,认为有换行
            if min(line_space_dists) < (self.line_space / 2):
                min_idx = np.argmin(line_space_dists)
                next_line = self.box_result[min_idx]['text']
                address_result += next_line

            self.results['名称'] = name_result
        credit_code_index = self.get_credit_code_index()
        if credit_code_index != -1:
            _, credit_code_result = self.get_item_value(credit_code_index)
            re_result = re.findall(r'[0-9A-Z]{15,}', credit_code_result)
            if re_result:
                self.results['统一社会信用代码'] = re_result[0]
            else:
                for box in self.box_result:
                    re_result = re.findall(r'[0-9A-Z]{15,}', box['text'])
                    if re_result:
                        self.results['统一社会信用代码'] = re_result[0]
                        break
if __name__ == "__main__":
    

我的IF后要怎么写才能单独把这串代码运行起来?

我想到达成的目的就是把这个类给它实例化,然后运行拿到结果

  • 写回答

5条回答 默认 最新

  • CSDN专家-sinJack 2023-03-08 09:48
    关注

    new BusinessCerts()对象,并传递这三个参数

    if __name__ == "__main__":
      img=''
      ocr_recognition = OCRRecognition()
      dbnet_det = DbNetInfer()
      certs = BusinessCerts(img,dbnet_det,ocr_recognition)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(4条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 3月9日
  • 已采纳回答 3月8日
  • 创建了问题 3月8日

悬赏问题

  • ¥20 java在应用程序里获取不到扬声器设备
  • ¥15 echarts动画效果的问题,请帮我添加一个动画。不要机器人回答。
  • ¥60 许可证msc licensing软件报错显示已有相同版本软件,但是下一步显示无法读取日志目录。
  • ¥15 Attention is all you need 的代码运行
  • ¥15 一个服务器已经有一个系统了如果用usb再装一个系统,原来的系统会被覆盖掉吗
  • ¥15 使用esm_msa1_t12_100M_UR50S蛋白质语言模型进行零样本预测时,终端显示出了sequence handled的进度条,但是并不出结果就自动终止回到命令提示行了是怎么回事:
  • ¥15 前置放大电路与功率放大电路相连放大倍数出现问题
  • ¥30 关于<main>标签页面跳转的问题
  • ¥80 部署运行web自动化项目
  • ¥15 腾讯云如何建立同一个项目中物模型之间的联系