"""Decoders that turn recognition-model index sequences into text labels."""

import os
import re
from pathlib import Path

import numpy as np

# Patterns are compiled once at import time because they are evaluated
# per character inside the decoding loops.
_LTR_CHAR_RE = re.compile("[a-zA-Z0-9 :*./%+-]")
_ALNUM_RE = re.compile("[a-zA-Z0-9]")
_DIGIT_RE = re.compile("[0-9]")

# Character dictionary shipped next to this module.
_BUNDLED_DICT_NAME = "ppocr_keys_v1.txt"


class BaseRecLabelDecode(object):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False):
        """Load the character dictionary and build the lookup tables.

        Args:
            character_dict_path: Accepted for interface compatibility.
                NOTE(review): the value is not used; the dictionary bundled
                next to this module is always loaded, as in the previous
                implementation. Confirm whether callers expect their own
                path to be honoured before changing this.
            use_space_char: When true, a space character is appended to the
                dictionary after the entries read from the file.

        Raises:
            OSError: If the bundled dictionary file cannot be opened.
        """
        # The path is absolute, so there is no need to change the process
        # working directory (which previously was left modified whenever
        # reading the file failed).
        script_dir = Path(__file__).resolve().parent
        character_dict_path = str(script_dir / _BUNDLED_DICT_NAME)

        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        # Binary mode keeps line splitting on "\n" only; line terminators
        # are stripped after decoding so entries such as " " survive.
        with open(character_dict_path, "rb") as fin:
            for raw_line in fin:
                self.character_str.append(raw_line.decode("utf-8").strip("\r\n"))
        if use_space_char:
            self.character_str.append(" ")
        dict_character = list(self.character_str)

        # NOTE(review): this test runs against the resolved bundled path, so
        # it only triggers when that path itself contains "arabic".
        if "arabic" in character_dict_path:
            self.reverse = True

        dict_character = self.add_special_char(dict_character)
        # Later duplicates overwrite earlier ones, exactly like a plain loop.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse the order of a prediction while keeping Latin runs intact.

        Consecutive characters matching ``[a-zA-Z0-9 :*./%+-]`` are grouped
        into one chunk; every other character is its own chunk. The chunks
        are then joined in reverse order.

        Args:
            pred: The decoded text.

        Returns:
            The text with its chunk order reversed.
        """
        chunks = []
        run = ""
        for c in pred:
            if _LTR_CHAR_RE.search(c):
                run += c
                continue
            if run != "":
                chunks.append(run)
            chunks.append(c)
            run = ""
        if run != "":
            chunks.append(run)
        return "".join(reversed(chunks))

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base adds none."""
        return dict_character

    def get_word_info(self, text, selection):
        """
        Group the decoded characters and record the corresponding decoded positions.

        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters
        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
                        - 'cn': continuous chinese characters (e.g., 你好啊)
                        - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                        The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []

        # Column indices of the positions that survived selection; the i-th
        # character of ``text`` is mapped to the i-th selected column.
        valid_col = np.where(selection)[0]
        text_len = len(text)

        for c_i, char in enumerate(text):
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            elif _ALNUM_RE.search(char):
                c_state = "en&num"
            else:
                c_state = "splitter"

            if state == "en&num":
                if (
                    char == "."
                    and c_i + 1 < text_len
                    and _DIGIT_RE.search(text[c_i + 1])
                ):
                    # grouping floating number
                    c_state = "en&num"
                elif char == "-":
                    # grouping word with '-', such as 'state-of-the-art'
                    c_state = "en&num"

            if state is None:
                state = c_state

            if state != c_state:
                # The group type changed: flush the group collected so far.
                if len(word_content) != 0:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            if state != "splitter":
                word_content.append(char)
                word_col_content.append(valid_col[c_i])

        if len(word_content) != 0:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """convert text-index into text-label.

        Args:
            text_index: Batch of index sequences; each row may be a NumPy
                array or any sequence accepted by ``np.asarray``.
            text_prob: Optional batch of per-position confidences aligned
                with ``text_index``. When omitted, every position counts
                as confidence 1.
            is_remove_duplicate: Drop a position when its index equals the
                index of the position immediately before it.
            return_word_box: Also return the word grouping information
                produced by ``get_word_info``.

        Returns:
            A list with one entry per batch row: ``(text, confidence)``, or
            ``(text, confidence, [row_length, word_list, word_col_list,
            state_list])`` when ``return_word_box`` is true.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            # Normalise the row so that plain lists behave like arrays for
            # the element-wise comparisons and boolean indexing below.
            indices = np.asarray(text_index[batch_idx])
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = np.asarray(text_prob[batch_idx])[selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                # Avoid taking the mean of an empty sequence.
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            confidence = np.mean(conf_list).tolist()
            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        confidence,
                        # Kept as a list: callers rescale the first element
                        # in place.
                        [
                            len(indices),
                            word_list,
                            word_col_list,
                            state_list,
                        ],
                    )
                )
            else:
                result_list.append((text, confidence))
        return result_list

    def get_ignored_tokens(self):
        """Return the indices that are dropped while decoding."""
        return [0]  # for ctc blank


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        """Initialise the decoder; extra keyword arguments are ignored."""
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
        """Decode model output (and optionally labels) into text.

        Args:
            preds: A NumPy array, or a tuple/list whose last element is that
                array. The class axis is axis 2 (argmax/max are taken over
                it); presumably laid out as (batch, step, class).
            label: Optional batch of label index sequences to decode as well.
            return_word_box: When true, word grouping information is added
                and its first element (the row length) is scaled by
                ``wh_ratio / max_wh_ratio``; this requires the keyword
                arguments ``wh_ratio_list`` and ``max_wh_ratio``.

        Returns:
            The decoded predictions, or ``(predictions, decoded_labels)``
            when ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        assert isinstance(preds, np.ndarray)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the CTC ``blank`` token so that it takes index 0."""
        return ["blank"] + dict_character