213 lines
7.5 KiB
Python
213 lines
7.5 KiB
Python
import os
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
|
|
class BaseRecLabelDecode:
    """Convert between text-label and text-index.

    The decoder holds an ordered character table (``self.character``) and
    the inverse mapping (``self.dict``) and turns per-position class indices
    into strings.

    Args:
        character_dict_path: Path of a UTF-8 dictionary file with one
            character per line. It is used when it points to an existing
            file; when it is ``None`` or the file does not exist, the
            ``ppocr_keys_v1.txt`` file located next to this module is
            loaded instead.
        use_space_char: When true, a single space is appended to the
            character table after the dictionary entries.

    Raises:
        OSError: If the selected dictionary file cannot be read.
    """

    # File name of the dictionary shipped next to this module.
    _BUNDLED_DICT_NAME = "ppocr_keys_v1.txt"

    # Characters that are kept together (in reading order) when a
    # prediction is reversed by ``pred_reverse``.
    _KEEP_ORDER_RE = re.compile("[a-zA-Z0-9 :*./%+-]")
    # Characters classified as 'en&num' by ``get_word_info``.
    _ALNUM_RE = re.compile("[a-zA-Z0-9]")
    # ASCII digits only (``str.isdigit`` would also accept other scripts).
    _DIGIT_RE = re.compile("[0-9]")

    def __init__(self, character_dict_path=None, use_space_char=False):
        # The bundled path is absolute, so the process-wide working
        # directory never has to be changed to find the file.
        bundled_dict = Path(__file__).resolve().parent / self._BUNDLED_DICT_NAME
        if character_dict_path is not None and Path(character_dict_path).is_file():
            dict_path = str(character_dict_path)
        else:
            # Fall back rather than fail so that callers passing a
            # placeholder path keep getting the bundled dictionary.
            dict_path = str(bundled_dict)

        self.beg_str = "sos"
        self.end_str = "eos"

        # Read as bytes and decode explicitly so the result does not depend
        # on the platform's default encoding; only line terminators are
        # stripped because other whitespace can be a dictionary entry.
        with open(dict_path, "rb") as fin:
            self.character_str = [
                raw_line.decode("utf-8").strip("\r\n") for raw_line in fin
            ]
        if use_space_char:
            self.character_str.append(" ")

        # Dictionaries whose path mentions "arabic" get their predictions
        # reversed in ``decode``.
        self.reverse = "arabic" in dict_path

        dict_character = self.add_special_char(list(self.character_str))
        self.dict = {char: index for index, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping Latin/number runs in reading order.

        Consecutive characters matching ``[a-zA-Z0-9 :*./%+-]`` are grouped
        into one chunk; every other character is its own chunk. The chunks
        are then concatenated in reverse order.
        """
        chunks = []
        run = ""
        for char in pred:
            if self._KEEP_ORDER_RE.search(char):
                run += char
                continue
            if run:
                chunks.append(run)
                run = ""
            chunks.append(char)
        if run:
            chunks.append(run)
        return "".join(reversed(chunks))

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; returns the list as is."""
        return dict_character

    def get_word_info(self, text, selection):
        """
        Group the decoded characters and record the corresponding decoded positions.

        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters

        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
                        - 'cn': continuous chinese characters (e.g., 你好啊)
                        - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                        The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        # Column index of every kept character, in text order.
        valid_col = np.where(selection)[0]
        text_len = len(text)

        for c_i, char in enumerate(text):
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            elif self._ALNUM_RE.search(char):
                c_state = "en&num"
            else:
                c_state = "splitter"

            if state == "en&num":
                # Keep floating point numbers such as '1.123' in one group.
                if (
                    char == "."
                    and c_i + 1 < text_len
                    and self._DIGIT_RE.search(text[c_i + 1])
                ):
                    c_state = "en&num"
                # Keep hyphenated words such as 'state-of-the-art' together.
                if char == "-":
                    c_state = "en&num"

            if state is None:
                state = c_state

            if state != c_state:
                # The group type changed: flush the group collected so far.
                if word_content:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            if state != "splitter":
                word_content.append(char)
                word_col_content.append(valid_col[c_i])

        if word_content:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """convert text-index into text-label.

        Args:
            text_index: Batch of index sequences; each row may be a numpy
                array or any sequence accepted by ``numpy.asarray``.
            text_prob: Optional batch of per-position scores aligned with
                ``text_index``. When omitted, every position scores 1.
            is_remove_duplicate: Drop a position whose index equals the
                index of the position right before it.
            return_word_box: Also return the word grouping computed by
                ``get_word_info``.

        Returns:
            A list with one entry per row: ``(text, score)``, or
            ``(text, score, [row_length, word_list, word_col_list,
            state_list])`` when ``return_word_box`` is true. ``score`` is
            the mean of the kept scores, or 0 when there are none.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx, raw_row in enumerate(text_index):
            # Boolean-mask indexing below needs an ndarray, not a list.
            row = np.asarray(raw_row)
            selection = np.ones(len(row), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = row[1:] != row[:-1]
            for ignored_token in ignored_tokens:
                selection &= row != ignored_token

            char_list = [self.character[text_id] for text_id in row[selection]]
            if text_prob is not None:
                conf_list = np.asarray(text_prob[batch_idx])[selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                # Avoid taking the mean of an empty sequence.
                conf_list = [0]
            score = np.mean(conf_list).tolist()

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        score,
                        [len(row), word_list, word_col_list, state_list],
                    )
                )
            else:
                result_list.append((text, score))
        return result_list

    def get_ignored_tokens(self):
        """Return the indices that ``decode`` always drops."""
        return [0]  # for ctc blank
|
|
|
|
|
|
class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        """Build the decoder.

        ``character_dict_path`` and ``use_space_char`` are forwarded
        unchanged to the base class; any extra keyword arguments are
        accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
        """Decode model output (and optionally labels) into text.

        Args:
            preds: A ``numpy.ndarray``, or a tuple/list whose last element
                is one. The best class and its score are taken along
                axis 2 (presumably (batch, time, classes) -- TODO confirm
                against the model that feeds this).
            label: Optional batch of index sequences decoded with the
                default ``decode`` settings.
            return_word_box: When true, each result carries word grouping
                information and ``kwargs`` must contain ``wh_ratio_list``
                (one value per result) and ``max_wh_ratio``.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        assert isinstance(preds, np.ndarray), "preds must be a numpy.ndarray"
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        # Only touch kwargs when there is something to rescale, so an empty
        # batch does not require the ratio arguments.
        if return_word_box and text:
            wh_ratio_list = kwargs["wh_ratio_list"]
            max_wh_ratio = kwargs["max_wh_ratio"]
            for rec_idx, rec in enumerate(text):
                # rec[2][0] holds the sequence length of this result; scale
                # it by this result's ratio relative to the maximum ratio.
                rec[2][0] = rec[2][0] * (wh_ratio_list[rec_idx] / max_wh_ratio)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the ``"blank"`` token so that it occupies index 0."""
        return ["blank"] + dict_character
|