Source code for pororo.tasks.optical_character_recognition

"""OCR related modeling class"""

from typing import Optional

from pororo.tasks import download_or_load
from pororo.tasks.utils.base import PororoFactoryBase, PororoSimpleBase

class PororoOcrFactory(PororoFactoryBase):
    """
    Recognize optical characters in an image file

    Currently supports English (`en`) and Korean (`ko`) with the `brainocr` model

    English + Korean (`brainocr`)

        - dataset: Internal data + AI hub Font Image dataset
        - metric: TBU

    Examples:
        >>> ocr = Pororo(task="ocr", lang="ko")
        >>> ocr(IMAGE_PATH)
        ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"]

        >>> ocr = Pororo(task="ocr", lang="ko")
        >>> ocr(IMAGE_PATH, detail=True)
        {
            'description': ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"],
            'bounding_poly': [
                {
                    'description': "사이렌'(' 신마'",
                    'vertices': [
                        {'x': 93, 'y': 7},
                        {'x': 164, 'y': 7},
                        {'x': 164, 'y': 21},
                        {'x': 93, 'y': 21}
                    ]
                },
                {
                    'description': "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고",
                    'vertices': [
                        {'x': 0, 'y': 30},
                        {'x': 259, 'y': 30},
                        {'x': 259, 'y': 194},
                        {'x': 0, 'y': 194}
                    ]
                }
            ]
        }
    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)
        # File stems that `load` downloads from "misc/": the detector
        # checkpoint ("craft" -> misc/craft.pt) and the reader option file
        # ("ocr-opt" -> misc/ocr-opt.txt).
        self.detect_model = "craft"
        self.ocr_opt = "ocr-opt"

    @staticmethod
    def get_available_langs():
        """Return the language codes this task supports."""
        return ["en", "ko"]

    @staticmethod
    def get_available_models():
        """Return the available model names, keyed by language code."""
        return {
            "en": ["brainocr"],
            "ko": ["brainocr"],
        }

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information, forwarded to `brainocr.Reader`

        Returns:
            object: User-selected task-specific model (a `PororoOCR` instance)

        Raises:
            ValueError: if the configured model or language is not supported
        """
        if self.config.n_model != "brainocr":
            # Without this guard the method fell through and returned None,
            # which only failed later when the caller tried to use the model.
            raise ValueError(f"Unsupported Model : {self.config.n_model}")

        lang = self.config.lang
        supported_langs = self.get_available_langs()
        if lang not in supported_langs:
            # Build a single message string: passing two positional arguments
            # to ValueError made the message render as a tuple.
            raise ValueError(
                f"Unsupported Language : {lang}. "
                f"Support Languages : {supported_langs}"
            )

        # Imported lazily so the model package is only loaded when needed.
        from pororo.models.brainOCR import brainocr

        # Detector checkpoint, recognizer checkpoint, and reader option file.
        det_model_path = download_or_load(
            f"misc/{self.detect_model}.pt",
            lang,
        )
        rec_model_path = download_or_load(
            f"misc/{self.config.n_model}.pt",
            lang,
        )
        opt_fp = download_or_load(
            f"misc/{self.ocr_opt}.txt",
            lang,
        )

        model = brainocr.Reader(
            lang,
            det_model_ckpt_fp=det_model_path,
            rec_model_ckpt_fp=rec_model_path,
            opt_fp=opt_fp,
            device=device,
        )
        return PororoOCR(model, self.config)
class PororoOCR(PororoSimpleBase):
    """Callable OCR task wrapping a loaded brainOCR reader."""

    def __init__(self, model, config):
        super().__init__(config)
        self._model = model

    def _postprocess(self, ocr_results, detail: bool = False):
        """
        Post-process for OCR result

        Regions are ordered by the (y, x) coordinate of the first vertex of
        their bounding polygon, i.e. top-to-bottom, then left-to-right.

        Args:
            ocr_results (list): list contains result of OCR; each item's
                element 0 is its polygon as (x, y) pairs and element 1 is
                the recognized text
            detail (bool): if True, returned to include details.
                (bounding poly, vertices, etc)

        Returns:
            list: recognized texts in reading order, when `detail` is False
            dict: when `detail` is True, a dict with `description` (texts)
                and `bounding_poly` (per-region text and vertices)
        """
        sorted_ocr_results = sorted(
            ocr_results,
            key=lambda result: (result[0][0][1], result[0][0][0]),
        )

        if not detail:
            # Index 1 is the text field, the same index the detail branch
            # uses. The previous `[-1]` would return any trailing extra field
            # instead of the text.
            return [result[1] for result in sorted_ocr_results]

        result_dict = {
            "description": [],
            "bounding_poly": [],
        }
        for result in sorted_ocr_results:
            polygon, text = result[0], result[1]
            vertices = [{"x": vertex[0], "y": vertex[1]} for vertex in polygon]
            result_dict["description"].append(text)
            result_dict["bounding_poly"].append({
                "description": text,
                "vertices": vertices,
            })
        return result_dict

    def predict(self, image_path: str, **kwargs):
        """
        Conduct Optical Character Recognition (OCR)

        Args:
            image_path (str): the image file path
            detail (bool): if True, returned to include details.
                (bounding poly, vertices, etc)

        Returns:
            list or dict: see `_postprocess`
        """
        detail = kwargs.get("detail", False)
        ocr_results = self._model(
            image_path,
            skip_details=False,
            batch_size=1,
            paragraph=True,
        )
        return self._postprocess(ocr_results, detail)