[chore] exclude paddleocr directory from pre-commit hooks
43
texteller/models/ocr_model/model/TexTeller.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
|
||||
|
||||
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
||||
|
||||
|
||||
class TexTeller(VisionEncoderDecoderModel):
    """Vision encoder-decoder OCR model for LaTeX formula recognition.

    Builds a VisionEncoderDecoderModel from the local ``config.json``,
    overriding image size, channel count, vocab size, and max sequence
    length with the project-wide globals so the model always matches the
    preprocessing pipeline.
    """

    # Hugging Face Hub repository holding the pretrained weights and tokenizer.
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        config = VisionEncoderDecoderConfig.from_pretrained(
            Path(__file__).resolve().parent / "config.json"
        )
        # Force the serialized config to agree with the project's fixed
        # preprocessing settings (globals are the single source of truth).
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
        """Load model weights from a local checkpoint or the Hub.

        Args:
            model_path: Local checkpoint directory. ``None`` or ``'default'``
                downloads from ``REPO_NAME`` instead.
            use_onnx: When loading the default checkpoint, return an ONNX
                Runtime model instead of a PyTorch one.
            onnx_provider: ``'cuda'`` selects the CUDA execution provider;
                anything else falls back to CPU. Only used when ``use_onnx``.

        Returns:
            A ``VisionEncoderDecoderModel`` or an ``ORTModelForVision2Seq``.
        """
        if model_path is None or model_path == 'default':
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            # Lazy import so `optimum` is only required for ONNX inference.
            from optimum.onnxruntime import ORTModelForVision2Seq

            use_gpu = onnx_provider == 'cuda'
            return ORTModelForVision2Seq.from_pretrained(
                cls.REPO_NAME,
                provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
            )
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        """Load the tokenizer from a local path, or from the Hub when
        *tokenizer_path* is ``None`` or ``'default'``."""
        if tokenizer_path is None or tokenizer_path == 'default':
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
|
||||
168
texteller/models/ocr_model/model/config.json
Normal file
@@ -0,0 +1,168 @@
|
||||
{
|
||||
"_name_or_path": "OleehyO/TexTeller",
|
||||
"architectures": [
|
||||
"VisionEncoderDecoderModel"
|
||||
],
|
||||
"decoder": {
|
||||
"_name_or_path": "",
|
||||
"activation_dropout": 0.0,
|
||||
"activation_function": "gelu",
|
||||
"add_cross_attention": true,
|
||||
"architectures": null,
|
||||
"attention_dropout": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"begin_suppress_tokens": null,
|
||||
"bos_token_id": 0,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"classifier_dropout": 0.0,
|
||||
"cross_attention_hidden_size": 768,
|
||||
"d_model": 1024,
|
||||
"decoder_attention_heads": 16,
|
||||
"decoder_ffn_dim": 4096,
|
||||
"decoder_layerdrop": 0.0,
|
||||
"decoder_layers": 12,
|
||||
"decoder_start_token_id": 2,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"dropout": 0.1,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"eos_token_id": 2,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"init_std": 0.02,
|
||||
"is_decoder": true,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layernorm_embedding": true,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"max_position_embeddings": 1024,
|
||||
"min_length": 0,
|
||||
"model_type": "trocr",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": 1,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"scale_embedding": false,
|
||||
"sep_token_id": null,
|
||||
"suppress_tokens": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tf_legacy_loss": false,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false,
|
||||
"use_cache": false,
|
||||
"use_learned_position_embeddings": true,
|
||||
"vocab_size": 15000
|
||||
},
|
||||
"encoder": {
|
||||
"_name_or_path": "",
|
||||
"add_cross_attention": false,
|
||||
"architectures": null,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"begin_suppress_tokens": null,
|
||||
"bos_token_id": null,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"cross_attention_hidden_size": null,
|
||||
"decoder_start_token_id": null,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_stride": 16,
|
||||
"eos_token_id": null,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 768,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"image_size": 448,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"is_decoder": false,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layer_norm_eps": 1e-12,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"min_length": 0,
|
||||
"model_type": "vit",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_attention_heads": 12,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_channels": 1,
|
||||
"num_hidden_layers": 12,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": null,
|
||||
"patch_size": 16,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"qkv_bias": false,
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"sep_token_id": null,
|
||||
"suppress_tokens": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tf_legacy_loss": false,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false
|
||||
},
|
||||
"is_encoder_decoder": true,
|
||||
"model_type": "vision-encoder-decoder",
|
||||
"tie_word_embeddings": false,
|
||||
"transformers_version": "4.41.2",
|
||||
"use_cache": true
|
||||
}
|
||||
BIN
texteller/models/ocr_model/train/augraphy_cache/image_0.png
Normal file
|
After Width: | Height: | Size: 14 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_1.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_10.png
Normal file
|
After Width: | Height: | Size: 4.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_11.png
Normal file
|
After Width: | Height: | Size: 8.5 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_12.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_13.png
Normal file
|
After Width: | Height: | Size: 3.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_14.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_15.png
Normal file
|
After Width: | Height: | Size: 7.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_16.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_17.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_18.png
Normal file
|
After Width: | Height: | Size: 5.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_19.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_2.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_20.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_21.png
Normal file
|
After Width: | Height: | Size: 16 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_22.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_23.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_24.png
Normal file
|
After Width: | Height: | Size: 10 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_25.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_26.png
Normal file
|
After Width: | Height: | Size: 9.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_27.png
Normal file
|
After Width: | Height: | Size: 8.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_28.png
Normal file
|
After Width: | Height: | Size: 15 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_29.png
Normal file
|
After Width: | Height: | Size: 7.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_3.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_4.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_5.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_6.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_7.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_8.png
Normal file
|
After Width: | Height: | Size: 6.0 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_9.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/0.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/1.png
Normal file
|
After Width: | Height: | Size: 8.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/10.png
Normal file
|
After Width: | Height: | Size: 6.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/11.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/12.png
Normal file
|
After Width: | Height: | Size: 5.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/13.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/14.png
Normal file
|
After Width: | Height: | Size: 2.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/15.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/16.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/17.png
Normal file
|
After Width: | Height: | Size: 2.6 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/18.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/19.png
Normal file
|
After Width: | Height: | Size: 2.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/2.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/20.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/21.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/22.png
Normal file
|
After Width: | Height: | Size: 3.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/23.png
Normal file
|
After Width: | Height: | Size: 3.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/24.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/25.png
Normal file
|
After Width: | Height: | Size: 2.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/26.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/27.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/28.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/29.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/3.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/30.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/31.png
Normal file
|
After Width: | Height: | Size: 4.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/32.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/33.png
Normal file
|
After Width: | Height: | Size: 1.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/34.png
Normal file
|
After Width: | Height: | Size: 3.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/4.png
Normal file
|
After Width: | Height: | Size: 5.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/5.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/6.png
Normal file
|
After Width: | Height: | Size: 4.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/7.png
Normal file
|
After Width: | Height: | Size: 4.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/8.png
Normal file
|
After Width: | Height: | Size: 2.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/9.png
Normal file
|
After Width: | Height: | Size: 5.2 KiB |
@@ -0,0 +1,35 @@
|
||||
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
|
||||
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
|
||||
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
|
||||
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
|
||||
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
|
||||
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
|
||||
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
|
||||
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
|
||||
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
|
||||
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
|
||||
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
|
||||
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
|
||||
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
|
||||
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
|
||||
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
|
||||
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
|
||||
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
|
||||
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
|
||||
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
|
||||
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
|
||||
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
|
||||
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
|
||||
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
|
||||
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
|
||||
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
|
||||
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
|
||||
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
|
||||
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
|
||||
114
texteller/models/ocr_model/train/train.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import (
|
||||
tokenize_fn,
|
||||
collate_fn,
|
||||
img_train_transform,
|
||||
img_inf_transform,
|
||||
filter_fn,
|
||||
)
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    """Run a standard Hugging Face training loop with the shared CONFIG."""
    hf_args = TrainingArguments(**CONFIG)
    trainer_kwargs = dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )
    trainer = Trainer(model, hf_args, **trainer_kwargs)

    # Always start fresh; pass a checkpoint directory here to resume instead.
    trainer.train(resume_from_checkpoint=None)
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
    """Evaluate *model* on *eval_dataset* with greedy generation and print
    the BLEU-based metrics."""
    # Copy so the shared training CONFIG is never mutated.
    cfg = CONFIG.copy()
    cfg['predict_with_generate'] = True
    cfg['generation_config'] = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )

    trainer = Seq2SeqTrainer(
        model,
        Seq2SeqTrainingArguments(**cfg),
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
    )

    print(trainer.evaluate())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run relative to this script's directory so the relative 'dataset' and
    # output paths resolve regardless of the caller's working directory.
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    # Alternative: load through a custom dataset loader script.
    # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
    # Drop images that are too small to be useful training samples.
    dataset = dataset.filter(
        lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
    )
    dataset = dataset.shuffle(seed=42)
    # Materialize the shuffled order so the later map/filter passes run over
    # contiguous data instead of an index indirection.
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # To use your own tokenizer, pass its path instead, e.g.:
    # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
    # Second filter pass: drop formulas whose token count exceeds the model limit.
    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(
        map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
    )

    # Split dataset into train and eval, ratio 9:1
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    # Augmentations only on the training split; deterministic transform for eval.
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
    model = TexTeller()
    # or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()

    # To fine-tune from a pre-trained checkpoint instead, point this at your
    # checkpoint directory, e.g.:
    # model = TexTeller.from_pretrained(
    #     '/path/to/your/model_checkpoint'
    # )

    enable_train = True
    enable_evaluate = False
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||
31
texteller/models/ocr_model/train/training_args.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Hugging Face TrainingArguments shared by train() and evaluate()
# (unpacked as TrainingArguments(**CONFIG) / Seq2SeqTrainingArguments(**CONFIG)).
CONFIG = {
    "seed": 42,  # Random seed for reproducibility
    "use_cpu": False,  # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
    "learning_rate": 5e-5,  # Learning rate
    "num_train_epochs": 10,  # Total number of training epochs
    "per_device_train_batch_size": 4,  # Batch size per GPU for training
    "per_device_eval_batch_size": 8,  # Batch size per GPU for evaluation
    "output_dir": "train_result",  # Output directory
    "overwrite_output_dir": False,  # If the output directory exists, do not delete its content
    "report_to": ["tensorboard"],  # Report logs to TensorBoard
    "save_strategy": "steps",  # Strategy to save checkpoints
    "save_steps": 500,  # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
    "save_total_limit": 5,  # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
    "logging_strategy": "steps",  # Log every certain number of steps
    "logging_steps": 500,  # Number of steps between each log
    "logging_nan_inf_filter": False,  # Record logs for loss=nan or inf
    "optim": "adamw_torch",  # Optimizer
    "lr_scheduler_type": "cosine",  # Learning rate scheduler
    "warmup_ratio": 0.1,  # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
    "max_grad_norm": 1.0,  # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
    "fp16": False,  # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
    "bf16": False,  # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
    "gradient_accumulation_steps": 1,  # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
    "jit_mode_eval": False,  # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
    "torch_compile": False,  # Whether to use torch.compile to compile the model (for better training and inference performance)
    "dataloader_pin_memory": True,  # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,  # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
    "evaluation_strategy": "steps",  # Evaluation strategy, can be "steps" or "epoch"
    "eval_steps": 500,  # Interval between evaluations when evaluation_strategy="steps"
    "remove_unused_columns": False,  # Don't change this unless you really know what you are doing.
}
|
||||
60
texteller/models/ocr_model/utils/functional.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from typing import List, Dict, Any
|
||||
from .transforms import train_transform, inference_transform
|
||||
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def left_move(x: torch.Tensor, pad_val):
    """Shift each row of a 2-D tensor one position to the left.

    Column ``j`` receives column ``j+1``'s values and the vacated last
    column is filled with ``pad_val`` (e.g. -100 so padding is ignored by
    the loss). Used to align ``labels`` with ``decoder_input_ids`` for
    causal-LM training.

    Args:
        x: 2-D tensor of shape (batch, seq_len).
        pad_val: Fill value for the last column.

    Returns:
        A new tensor of the same shape and dtype; ``x`` is not modified.
    """
    assert x.ndim == 2, 'x should be 2-dimensional'
    # empty_like instead of ones_like: every element is overwritten below,
    # so the fill with ones was wasted work.
    shifted = torch.empty_like(x)
    shifted[:, :-1] = x[:, 1:]
    shifted[:, -1] = pad_val
    return shifted
|
||||
|
||||
|
||||
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    """Tokenize a batch of LaTeX formulas and carry the images along.

    Returns the tokenizer output (with special-tokens mask) augmented with a
    ``pixel_values`` entry holding the untransformed images.
    """
    assert tokenizer is not None, 'tokenizer should not be None'
    batch = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    batch['pixel_values'] = samples['image']
    return batch
|
||||
|
||||
|
||||
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    """Collate tokenized samples into a batch for the encoder-decoder model.

    Runs the causal-LM collator over the text fields (padding and building
    ``labels``), then renames them to the ``decoder_*`` names expected by
    VisionEncoderDecoderModel and stacks the images into one tensor.

    Args:
        samples: List of per-sample dicts, each holding 'pixel_values' plus
            the tokenizer outputs for one formula.
        tokenizer: Tokenizer used by the collator for padding; must not be None.

    Returns:
        Batch dict with 'pixel_values' of shape (B, C, H, W),
        'decoder_input_ids', 'decoder_attention_mask', and left-shifted
        'labels' (shifted positions filled with -100).
    """
    assert tokenizer is not None, 'tokenizer should not be None'
    # Pull the images out first so the LM collator only sees text fields.
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')

    # Left-shift labels (original comment also mentioned decoder_attention_mask,
    # but only 'labels' is shifted here — NOTE(review): confirm that is intended).
    batch['labels'] = left_move(batch['labels'], -100)

    # Convert the list of images into a single tensor with shape (B, C, H, W).
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch
|
||||
|
||||
|
||||
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply the training-time (augmenting) image pipeline to the batch in place."""
    samples['pixel_values'] = train_transform(samples['pixel_values'])
    return samples
|
||||
|
||||
|
||||
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply the deterministic inference-time image pipeline to the batch in place."""
    samples['pixel_values'] = inference_transform(samples['pixel_values'])
    return samples
|
||||
|
||||
|
||||
def filter_fn(sample, tokenizer=None) -> bool:
    """Keep a sample only if its image exceeds the minimum size and its
    formula tokenizes comfortably under the model's max sequence length
    (a margin of 10 tokens is reserved)."""
    img = sample['image']
    if img.height <= MIN_HEIGHT or img.width <= MIN_WIDTH:
        return False
    # Tokenize only after the cheap size check passes.
    token_count = len(tokenizer(sample['latex_formula'])['input_ids'])
    return token_count < MAX_TOKEN_SIZE - 10
|
||||
26
texteller/models/ocr_model/utils/helpers.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
    """Read images from disk and normalize them to 8-bit RGB arrays.

    Unreadable paths are skipped with a message; 16-bit images are
    down-converted to 8-bit (lossy); grayscale/BGR/BGRA inputs are all
    converted to 3-channel RGB.
    """
    rgb_images = []
    for path in image_paths:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print(f"Image at {path} could not be read.")
            continue
        if img.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
            img = cv2.convertScaleAbs(img, alpha=(255.0 / 65535.0))

        # 2-D arrays are single-channel; otherwise the channel count is axis 2.
        n_channels = 1 if len(img.shape) == 2 else img.shape[2]
        if n_channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        elif n_channels == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif n_channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rgb_images.append(img)

    return rgb_images
|
||||
49
texteller/models/ocr_model/utils/inference.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||
from typing import List, Union
|
||||
|
||||
from .transforms import inference_transform
|
||||
from .helpers import convert2rgb
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ...globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs: Union[List[str], List[np.ndarray]],
    accelerator: str = 'cpu',
    num_beams: int = 1,
    max_tokens=None,
) -> List[str]:
    """Run OCR inference and return one decoded LaTeX string per image.

    Args:
        model: a TexTeller torch model, or an ONNX session exposing ``generate``.
        tokenizer: tokenizer used to decode generated token ids.
        imgs: image file paths, or already-loaded RGB numpy arrays.
        accelerator: torch device string (ignored for ONNX sessions).
        num_beams: beam-search width.
        max_tokens: generation cap; defaults to MAX_TOKEN_SIZE.
    """
    if not imgs:
        return []
    # ONNX sessions have no .eval(); use that to tell the two apart once.
    is_torch_model = hasattr(model, 'eval')
    if is_torch_model:
        model.eval()
    if isinstance(imgs[0], str):
        imgs = convert2rgb(imgs)
    else:  # already numpy array(rgb format)
        assert isinstance(imgs[0], np.ndarray)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if is_torch_model:
        # Torch model: move weights and inputs to the requested device.
        model = model.to(accelerator)
        pixel_values = pixel_values.to(accelerator)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values.to(model.device), generation_config=generate_config)
    return tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||
25
texteller/models/ocr_model/utils/metrics.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from transformers import EvalPrediction, RobertaTokenizer
|
||||
|
||||
|
||||
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    """Compute the Google-BLEU score for a batch of predictions.

    Temporarily changes the working directory to this file's directory while
    loading the metric (so ``evaluate`` caches it locally; downloads from
    huggingface on first use), and always restores it afterwards.
    """
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    try:
        metric = evaluate.load('google_bleu')
    finally:
        # Restore the cwd even if the metric download/load fails.
        os.chdir(cur_dir)

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits

    # -100 marks ignored label positions; map them to token id 1 (presumably
    # the pad token — TODO confirm against the tokenizer) so batch_decode can
    # handle them, then strip them as special tokens.
    labels = np.where(labels == -100, 1, labels)

    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)
|
||||
152
texteller/models/ocr_model/utils/ocr_aug.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from augraphy import *
|
||||
import random
|
||||
|
||||
|
||||
def ocr_augmentation_pipeline():
    """Build the augraphy document-noise pipeline used for OCR training.

    Four phases: an empty pre-phase, an ink phase (colour swaps, line
    degradation, dithering/bleed, perlin shifting), a paper phase (noise and
    brightness texturing) and a post phase (colour shift, dirty drum,
    lighting/brightness/gamma, subtle noise or JPEG artifacts).

    NOTE(review): arguments built with ``random.choice`` / ``random.uniform``
    / ``random.randint`` below are sampled ONCE, when this pipeline is
    constructed — not per augmented image.  Presumably intentional; confirm.
    """
    pre_phase = []

    ink_phase = [
        InkColorSwap(
            ink_swap_color="random",
            ink_swap_sequence_number_range=(5, 10),
            ink_swap_min_width_range=(2, 3),
            ink_swap_max_width_range=(100, 120),
            ink_swap_min_height_range=(2, 3),
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            # p=0.2
            p=0.4,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
            line_gradient_range=(32, 255),
            line_gradient_direction=(0, 2),
            line_split_probability=(0.2, 0.4),
            line_replacement_value=(250, 255),
            line_min_length=(30, 40),
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            # p=0.2
            p=0.4,
        ),
        # ============================
        OneOf(
            [
                Dithering(
                    dither="floyd-steinberg",
                    order=(3, 5),
                ),
                InkBleed(
                    intensity_range=(0.1, 0.2),
                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
                    severity=(0.4, 0.6),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # ============================
        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
            text_shift_factor_range=(1, 4),
            text_fade_range=(0, 2),
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            # p=0.2
            p=0.4,
        ),
        # ============================
    ]

    paper_phase = [
        NoiseTexturize(  # tested
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            # p=0.2
            p=0.4,
        ),
        BrightnessTexturize(  # tested
            texturize_range=(0.9, 0.99),
            deviation=0.03,
            # p=0.2
            p=0.4,
        ),
    ]

    post_phase = [
        ColorShift(  # tested
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            # p=0.2
            p=0.4,
        ),
        DirtyDrum(  # tested
            line_width_range=(1, 6),
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
            noise_intensity=random.uniform(0.6, 0.95),
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            # p=0.2
            p=0.4,
        ),
        # =====================================
        OneOf(
            [
                LightingGradient(
                    light_position=None,
                    direction=None,
                    max_brightness=255,
                    min_brightness=0,
                    mode="gaussian",
                    linear_decay_rate=None,
                    transparency=None,
                ),
                Brightness(
                    brightness_range=(0.9, 1.1),
                    min_brightness=0,
                    min_brightness_value=(120, 150),
                ),
                Gamma(
                    gamma_range=(0.9, 1.1),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # =====================================
        # =====================================
        OneOf(
            [
                SubtleNoise(
                    subtle_range=random.randint(5, 10),
                ),
                Jpeg(
                    quality_range=(70, 95),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # =====================================
    ]

    pipeline = AugraphyPipeline(
        ink_phase=ink_phase,
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False,
    )

    return pipeline
|
||||
184
texteller/models/ocr_model/utils/to_katex.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import re
|
||||
|
||||
|
||||
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Rewrite every ``old_inst + old_surr_l ... old_surr_r`` span in *input_str*.

    Each occurrence of *old_inst* immediately followed by the opening
    delimiter *old_surr_l* is replaced by *new_inst*, with the (unchanged)
    inner content re-wrapped in *new_surr_l*/*new_surr_r*.  Delimiters are
    matched with nesting, and a backslash escapes the character after it.
    An *old_inst* not followed by *old_surr_l* is copied through verbatim.

    Recurses on the result while *old_inst* differs from *new_inst* and the
    pattern still occurs (replacements can expose nested matches).
    """
    result = ""
    i = 0
    n = len(input_str)

    while i < n:
        if input_str[i : i + len(old_inst)] == old_inst:
            # check if the old_inst is followed by old_surr_l
            start = i + len(old_inst)
        else:
            result += input_str[i]
            i += 1
            continue

        if start < n and input_str[start] == old_surr_l:
            # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                if input_str[j] == '\\' and not escaped:
                    # backslash: the next character is literal, not a delimiter
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1

            if count == 0:
                # balanced: j sits on the matching closing delimiter
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1 : j]
                # Replace the content with new pattern
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                # unbalanced: the opener was never closed; emit the rewritten
                # opener and resume scanning just after it
                assert count >= 1
                assert j == n
                print("Warning: unbalanced surrogate pair in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            # old_inst without its opening delimiter: keep it as-is
            result += input_str[i:start]
            i = start

    if old_inst != new_inst and (old_inst + old_surr_l) in result:
        return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result
|
||||
|
||||
|
||||
def find_substring_positions(string, substring):
    """Return the start index of every non-overlapping occurrence of *substring* in *string*."""
    return [m.start() for m in re.finditer(re.escape(substring), string)]
|
||||
|
||||
|
||||
def rm_dollar_surr(content):
    """Strip ``$...$`` delimiters from inline math not attached to a command.

    Spans like ``\\cmd$...$`` (a LaTeX command immediately followed by math)
    are left untouched; bare ``$...$`` spans are replaced by their inner
    content padded with a single space on each side.
    """
    inline_math = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
    for span in inline_math.findall(content):
        if re.match(r'\\[a-zA-Z]+', span):
            continue  # attached to a command: keep the dollars
        content = content.replace(span, ' ' + span.strip('$') + ' ')
    return content
|
||||
|
||||
|
||||
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Apply :func:`change` at every occurrence of ``old_inst + old_surr_l``.

    Occurrences are processed right-to-left so that earlier rewrites do not
    invalidate the positions recorded up front.
    """
    positions = find_substring_positions(input_str, old_inst + old_surr_l)
    chars = list(input_str)
    for pos in reversed(positions):
        rewritten = change(
            ''.join(chars[pos:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
        )
        chars[pos:] = list(rewritten)
    return ''.join(chars)
|
||||
|
||||
|
||||
def to_katex(formula: str) -> str:
    """Normalize raw model LaTeX output into KaTeX-renderable markup.

    Strips box/size wrappers (``\\mbox``, ``\\hbox``, ``\\raisebox``, ...),
    rewrites unsupported commands to KaTeX equivalents (``\\boldmath`` ->
    ``\\bm``, ``\\emph`` -> ``\\textit``), merges consecutive ``\\text``
    blocks, drops redundant spacing commands and collapses extra whitespace.
    """
    res = formula
    # remove mbox surrounding
    res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
    res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
    # remove hbox surrounding
    res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
    res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
    # remove raise surrounding
    res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
    # remove makebox
    res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
    res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
    # remove vbox surrounding, scalebox surrounding
    res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
    res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
    res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')

    # Size commands: convert their $-delimited argument to a braced one.
    # (The command itself is unchanged — the previous zip of the list with
    # itself was redundant, so iterate it directly.)
    size_commands = [
        r'\Huge',
        r'\huge',
        r'\LARGE',
        r'\Large',
        r'\large',
        r'\normalsize',
        r'\small',
        r'\footnotesize',
        r'\tiny',
    ]
    for ins in size_commands:
        res = change_all(res, ins, ins, r'$', r'$', '{', '}')
    res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')

    # Delimiter-size commands take a bare argument, not a braced one.
    delimiter_commands = [
        r'\left',
        r'\middle',
        r'\right',
        r'\big',
        r'\Big',
        r'\bigg',
        r'\Bigg',
        r'\bigl',
        r'\Bigl',
        r'\biggl',
        r'\Biggl',
        r'\bigm',
        r'\Bigm',
        r'\biggm',
        r'\Biggm',
        r'\bigr',
        r'\Bigr',
        r'\biggr',
        r'\Biggr',
    ]
    for origin_ins in delimiter_commands:
        res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')

    # Display-math blocks become their content followed by a line break.
    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)

    if res.endswith(r'\newline'):
        res = res[: -len(r'\newline')]

    # remove multiple spaces
    res = re.sub(r'(\\,){1,}', ' ', res)
    res = re.sub(r'(\\!){1,}', ' ', res)
    res = re.sub(r'(\\;){1,}', ' ', res)
    res = re.sub(r'(\\:){1,}', ' ', res)
    res = re.sub(r'\\vspace\{.*?}', '', res)

    # merge consecutive text
    def merge_texts(match):
        texts = match.group(0)
        merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
        return f'\\text{{{merged_content}}}'

    res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)

    res = res.replace(r'\bf ', '')
    res = rm_dollar_surr(res)

    # remove extra spaces (keeping only one)
    res = re.sub(r' +', ' ', res)

    return res.strip()
|
||||
177
texteller/models/ocr_model/utils/transforms.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
from collections import Counter
|
||||
|
||||
from ...globals import (
|
||||
IMG_CHANNELS,
|
||||
FIXED_IMG_SIZE,
|
||||
IMAGE_MEAN,
|
||||
IMAGE_STD,
|
||||
MAX_RESIZE_RATIO,
|
||||
MIN_RESIZE_RATIO,
|
||||
)
|
||||
from .ocr_aug import ocr_augmentation_pipeline
|
||||
|
||||
# train_pipeline = default_augraphy_pipeline(scan_only=True)
# Document-noise augmentation applied only at training time.
train_pipeline = ocr_augmentation_pipeline()

# Shared deterministic preprocessing: grayscale, bicubic resize (shorter edge
# to FIXED_IMG_SIZE - 1, longer edge capped at FIXED_IMG_SIZE), then
# float conversion and mean/std normalization.
general_transform_pipeline = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
        v2.Grayscale(),
        v2.Resize(
            size=FIXED_IMG_SIZE - 1,
            interpolation=v2.InterpolationMode.BICUBIC,
            max_size=FIXED_IMG_SIZE,
            antialias=True,
        ),
        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
        # v2.ToPILImage()
    ]
)
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray):
    """Crop the uniform background border from an RGB image.

    The background colour is the most common of the four corner pixels;
    pixels within ``threshold`` of it count as background.

    Args:
        image: H x W x 3 uint8 array.

    Returns:
        The cropped image (a view into *image*); the original image when it
        is entirely background.

    Raises:
        ValueError: if the image is not 3-channel or not uint8.
    """
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        # was: f-string with no placeholder and a grammar slip
        raise ValueError("Image should be stored in uint8")

    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)
    if w == 0 or h == 0:
        # Nothing differs from the background: boundingRect found no content.
        # Return the image unchanged instead of an empty (0-sized) crop.
        return image

    return image[y : y + h, x : x + w]
|
||||
|
||||
|
||||
def add_white_border(image: np.ndarray, max_size: int) -> torch.Tensor:
    """Pad *image* with a random white border on each side.

    Each side receives a random pad in ``[0, max_size]``; if the padded
    height or width would still be under 30 px, both opposing sides are
    compensated so the result reaches at least 30 px.

    Args:
        image: H x W x C uint8 array.
        max_size: maximum random padding per side, in pixels.

    Returns:
        A C x H x W tensor.  (The original annotation said ``np.ndarray``,
        but ``v2.functional.pad`` on a tensor input returns a tensor — the
        caller converts back with ``.permute(...).numpy()``.)
    """
    # v2.functional.pad padding order: [left, top, right, bottom]
    randi = [random.randint(0, max_size) for _ in range(4)]
    pad_height_size = randi[1] + randi[3]
    pad_width_size = randi[0] + randi[2]
    if pad_height_size + image.shape[0] < 30:
        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
        randi[1] += compensate_height
        randi[3] += compensate_height
    if pad_width_size + image.shape[1] < 30:
        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
        randi[0] += compensate_width
        randi[2] += compensate_width
    return v2.functional.pad(
        torch.from_numpy(image).permute(2, 0, 1),
        padding=randi,
        padding_mode='constant',
        fill=(255, 255, 255),
    )
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    """Pad each C x H x W tensor on the right/bottom to a square of *required_size*."""
    padded = []
    for img in images:
        pad_right = required_size - img.shape[2]
        pad_bottom = required_size - img.shape[1]
        padded.append(v2.functional.pad(img, padding=[0, 0, pad_right, pad_bottom]))
    return padded
|
||||
|
||||
|
||||
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
    """Rescale each RGB image by an independent random factor in [minr, maxr]."""
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    resized = []
    for img in images:
        ratio = random.uniform(minr, maxr)
        new_size = (int(img.shape[1] * ratio), int(img.shape[0] * ratio))
        # LANCZOS4 interpolation for anti-aliasing
        resized.append(cv2.resize(img, new_size, interpolation=cv2.INTER_LANCZOS4))
    return resized
|
||||
|
||||
|
||||
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    """Rotate *image* by a random angle in [min_angle, max_angle] degrees.

    The output canvas is enlarged so no content is clipped; the revealed
    border is filled with white.
    """
    # Rotation pivot: the image centre (note shape[1::-1] gives (w, h)).
    center = tuple(np.array(image.shape[1::-1]) / 2)
    angle = random.randint(min_angle, max_angle)
    rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)

    # Bounding box of the rotated image.
    cos = np.abs(rot_mat[0, 0])
    sin = np.abs(rot_mat[0, 1])
    new_w = int((image.shape[0] * sin) + (image.shape[1] * cos))
    new_h = int((image.shape[0] * cos) + (image.shape[1] * sin))

    # Translate so the rotation stays centred in the enlarged canvas.
    rot_mat[0, 2] += (new_w / 2) - center[0]
    rot_mat[1, 2] += (new_h / 2) - center[1]

    return cv2.warpAffine(image, rot_mat, (new_w, new_h), borderValue=(255, 255, 255))
|
||||
|
||||
|
||||
def ocr_aug(image: np.ndarray) -> np.ndarray:
    """Training-time augmentation: occasional small rotation, a random white border, then the augraphy noise pipeline."""
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
    return train_pipeline(image)
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    """Full training-time preprocessing: random resize, border trim, OCR augmentation, general transform, fixed-size padding."""
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    rgb_arrays = [np.array(img.convert('RGB')) for img in images]
    # random resize first
    rgb_arrays = random_resize(rgb_arrays, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    rgb_arrays = [trim_white_border(arr) for arr in rgb_arrays]

    # OCR augmentation
    augmented = [ocr_aug(arr) for arr in rgb_arrays]

    # general transform pipeline, then padding to fixed size
    tensors = [general_transform_pipeline(arr) for arr in augmented]
    return padding(tensors, FIXED_IMG_SIZE)
|
||||
|
||||
|
||||
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    """Deterministic preprocessing used at inference time (no augmentation)."""
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"
    arrays = [
        np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
    ]
    arrays = [trim_white_border(arr) for arr in arrays]
    # general transform pipeline
    tensors = [general_transform_pipeline(arr) for arr in arrays]
    # padding to fixed size
    return padding(tensors, FIXED_IMG_SIZE)
|
||||