commit f057490bdb8c5b291e06390178cc67134504fcb4 Author: 三洋三洋 <1258009915@qq.com> Date: Sun Feb 11 08:06:50 2024 +0000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..37c7cfd --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +**/__pycache__ +**/.vscode +**/train_result + +**/logs +**/.cache +**/tmp* \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..97be631 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 OleehyO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..41b3b13 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# TexTeller + +## Prerequisites + +python=3.10 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d7ca88e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +transformers +datasets +evaluate +streamlit +opencv-python +ray[serve] +accelerate +tensorboardX +nltk diff --git a/src/client_demo.py b/src/client_demo.py new file mode 100644 index 0000000..5b68e2d --- /dev/null +++ b/src/client_demo.py @@ -0,0 +1,11 @@ +import requests + +url = "http://127.0.0.1:8000/predict" + +img_path = "/your/image/path/" + +data = {"img_path": img_path} + +response = requests.post(url, json=data) + +print(response.text) diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000..c6d6f61 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,36 @@ +import os +import argparse + +from pathlib import Path +from models.ocr_model.utils.inference import inference +from models.ocr_model.model.TexTeller import TexTeller + + +if __name__ == '__main__': + os.chdir(Path(__file__).resolve().parent) + parser = argparse.ArgumentParser() + parser.add_argument( + '-img', + type=str, + required=True, + help='path to the input image' + ) + parser.add_argument( + '-cuda', + default=False, + action='store_true', + help='use cuda or not' + ) + + args = parser.parse_args() + + # You can use your own checkpoint and tokenizer path. 
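+    # e.g. (hypothetical local paths, shown for illustration only — not files in this commit):
+    #   model = TexTeller.from_pretrained('/path/to/your/checkpoint')
+    #   tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')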
+ print('Loading model and tokenizer...') + model = TexTeller.from_pretrained() + tokenizer = TexTeller.get_tokenizer() + print('Model and tokenizer loaded.') + + img_path = [args.img] + print('Inference...') + res = inference(model, tokenizer, img_path, args.cuda) + print(res[0]) diff --git a/src/models/globals.py b/src/models/globals.py new file mode 100644 index 0000000..8ec06d9 --- /dev/null +++ b/src/models/globals.py @@ -0,0 +1,26 @@ +# Formula image(grayscale) mean and variance +IMAGE_MEAN = 0.9545467 +IMAGE_STD = 0.15394445 + +# Density value for pdf to image conversion +TEXTELL_DENSITY = 200 + +# Vocabulary size for TexTeller +VOCAB_SIZE = 10000 + +# Fixed size for input image for TexTeller +FIXED_IMG_SIZE = 448 + +# Image channel for TexTeller +IMG_CHANNELS = 1 # grayscale image + +# Max size of token for embedding +MAX_TOKEN_SIZE = 512 + +# Scaling ratio for random resizing when training +MAX_RESIZE_RATIO = 1.15 +MIN_RESIZE_RATIO = 0.75 + +# Minimum height and width for input image for TexTeller +MIN_HEIGHT = 12 +MIN_WIDTH = 30 diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py new file mode 100644 index 0000000..740ccc5 --- /dev/null +++ b/src/models/ocr_model/model/TexTeller.py @@ -0,0 +1,43 @@ +from pathlib import Path + +from models.globals import ( + VOCAB_SIZE, + FIXED_IMG_SIZE, + IMG_CHANNELS, +) + +from transformers import ( + ViTConfig, + ViTModel, + TrOCRConfig, + TrOCRForCausalLM, + RobertaTokenizerFast, + VisionEncoderDecoderModel, +) + + +class TexTeller(VisionEncoderDecoderModel): + REPO_NAME = 'OleehyO/TexTeller' + def __init__(self, decoder_path=None, tokenizer_path=None): + encoder = ViTModel(ViTConfig( + image_size=FIXED_IMG_SIZE, + num_channels=IMG_CHANNELS + )) + decoder = TrOCRForCausalLM(TrOCRConfig( + vocab_size=VOCAB_SIZE, + )) + super().__init__(encoder=encoder, decoder=decoder) + + @classmethod + def from_pretrained(cls, model_path: str = None): + if model_path is None or model_path == cls.REPO_NAME: + return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME) + model_path = Path(model_path).resolve() + return VisionEncoderDecoderModel.from_pretrained(str(model_path)) + + @classmethod + def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast: + if tokenizer_path is None or tokenizer_path == cls.REPO_NAME: + return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME) + tokenizer_path = Path(tokenizer_path).resolve() + return RobertaTokenizerFast.from_pretrained(str(tokenizer_path)) diff --git a/src/models/ocr_model/train/dataset/formulas.jsonl b/src/models/ocr_model/train/dataset/formulas.jsonl new file mode 100644 index 0000000..5a07425 --- /dev/null +++ b/src/models/ocr_model/train/dataset/formulas.jsonl @@ -0,0 +1,35 @@ +{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"} +{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"} +{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"} +{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"} +{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"} +{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} 
\\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"} +{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"} +{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"} +{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"} +{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"} +{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"} +{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"} +{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"} +{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"} +{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"} +{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"} +{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"} +{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"} +{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"} +{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"} +{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"} +{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"} +{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"} +{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"} +{"img_name": "24.png", "formula": 
"\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"} +{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"} +{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"} +{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"} +{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"} +{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"} +{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"} +{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"} +{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"} +{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"} +{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"} diff --git a/src/models/ocr_model/train/dataset/images/0.png b/src/models/ocr_model/train/dataset/images/0.png new file mode 100644 index 0000000..9f27321 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/0.png differ diff --git a/src/models/ocr_model/train/dataset/images/1.png b/src/models/ocr_model/train/dataset/images/1.png new file mode 100644 index 0000000..bc65c5f Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/1.png differ diff --git a/src/models/ocr_model/train/dataset/images/10.png b/src/models/ocr_model/train/dataset/images/10.png new file mode 100644 index 0000000..b2306ab Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/10.png differ diff --git a/src/models/ocr_model/train/dataset/images/11.png b/src/models/ocr_model/train/dataset/images/11.png new file mode 100644 index 0000000..f8b20a1 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/11.png differ diff --git a/src/models/ocr_model/train/dataset/images/12.png b/src/models/ocr_model/train/dataset/images/12.png new file mode 100644 index 0000000..5b3b285 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/12.png differ diff --git a/src/models/ocr_model/train/dataset/images/13.png b/src/models/ocr_model/train/dataset/images/13.png new file mode 100644 index 0000000..692fcc2 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/13.png differ diff --git a/src/models/ocr_model/train/dataset/images/14.png b/src/models/ocr_model/train/dataset/images/14.png new file mode 100644 index 0000000..e7fe2fd Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/14.png differ diff --git a/src/models/ocr_model/train/dataset/images/15.png b/src/models/ocr_model/train/dataset/images/15.png new file mode 100644 index 0000000..fbbeb82 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/15.png differ diff --git 
a/src/models/ocr_model/train/dataset/images/16.png b/src/models/ocr_model/train/dataset/images/16.png new file mode 100644 index 0000000..be56e99 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/16.png differ diff --git a/src/models/ocr_model/train/dataset/images/17.png b/src/models/ocr_model/train/dataset/images/17.png new file mode 100644 index 0000000..4f30cf1 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/17.png differ diff --git a/src/models/ocr_model/train/dataset/images/18.png b/src/models/ocr_model/train/dataset/images/18.png new file mode 100644 index 0000000..8774d25 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/18.png differ diff --git a/src/models/ocr_model/train/dataset/images/19.png b/src/models/ocr_model/train/dataset/images/19.png new file mode 100644 index 0000000..4d3daa5 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/19.png differ diff --git a/src/models/ocr_model/train/dataset/images/2.png b/src/models/ocr_model/train/dataset/images/2.png new file mode 100644 index 0000000..8fe5dd9 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/2.png differ diff --git a/src/models/ocr_model/train/dataset/images/20.png b/src/models/ocr_model/train/dataset/images/20.png new file mode 100644 index 0000000..45c400d Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/20.png differ diff --git a/src/models/ocr_model/train/dataset/images/21.png b/src/models/ocr_model/train/dataset/images/21.png new file mode 100644 index 0000000..311c1fd Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/21.png differ diff --git a/src/models/ocr_model/train/dataset/images/22.png b/src/models/ocr_model/train/dataset/images/22.png new file mode 100644 index 0000000..6273383 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/22.png differ diff --git a/src/models/ocr_model/train/dataset/images/23.png b/src/models/ocr_model/train/dataset/images/23.png new file mode 100644 index 0000000..06dfcdb Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/23.png differ diff --git a/src/models/ocr_model/train/dataset/images/24.png b/src/models/ocr_model/train/dataset/images/24.png new file mode 100644 index 0000000..c718fd5 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/24.png differ diff --git a/src/models/ocr_model/train/dataset/images/25.png b/src/models/ocr_model/train/dataset/images/25.png new file mode 100644 index 0000000..b90ab45 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/25.png differ diff --git a/src/models/ocr_model/train/dataset/images/26.png b/src/models/ocr_model/train/dataset/images/26.png new file mode 100644 index 0000000..087e6de Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/26.png differ diff --git a/src/models/ocr_model/train/dataset/images/27.png b/src/models/ocr_model/train/dataset/images/27.png new file mode 100644 index 0000000..67f552c Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/27.png differ diff --git a/src/models/ocr_model/train/dataset/images/28.png b/src/models/ocr_model/train/dataset/images/28.png new file mode 100644 index 0000000..3b29359 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/28.png differ diff --git a/src/models/ocr_model/train/dataset/images/29.png b/src/models/ocr_model/train/dataset/images/29.png new file mode 100644 index 0000000..917e0ed Binary files /dev/null 
and b/src/models/ocr_model/train/dataset/images/29.png differ diff --git a/src/models/ocr_model/train/dataset/images/3.png b/src/models/ocr_model/train/dataset/images/3.png new file mode 100644 index 0000000..0354b68 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/3.png differ diff --git a/src/models/ocr_model/train/dataset/images/30.png b/src/models/ocr_model/train/dataset/images/30.png new file mode 100644 index 0000000..cb38168 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/30.png differ diff --git a/src/models/ocr_model/train/dataset/images/31.png b/src/models/ocr_model/train/dataset/images/31.png new file mode 100644 index 0000000..973f951 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/31.png differ diff --git a/src/models/ocr_model/train/dataset/images/32.png b/src/models/ocr_model/train/dataset/images/32.png new file mode 100644 index 0000000..7c019a5 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/32.png differ diff --git a/src/models/ocr_model/train/dataset/images/33.png b/src/models/ocr_model/train/dataset/images/33.png new file mode 100644 index 0000000..172ff55 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/33.png differ diff --git a/src/models/ocr_model/train/dataset/images/34.png b/src/models/ocr_model/train/dataset/images/34.png new file mode 100644 index 0000000..013c1cc Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/34.png differ diff --git a/src/models/ocr_model/train/dataset/images/4.png b/src/models/ocr_model/train/dataset/images/4.png new file mode 100644 index 0000000..b8b0e39 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/4.png differ diff --git a/src/models/ocr_model/train/dataset/images/5.png b/src/models/ocr_model/train/dataset/images/5.png new file mode 100644 index 0000000..db3af1f Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/5.png differ diff --git a/src/models/ocr_model/train/dataset/images/6.png b/src/models/ocr_model/train/dataset/images/6.png new file mode 100644 index 0000000..c171137 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/6.png differ diff --git a/src/models/ocr_model/train/dataset/images/7.png b/src/models/ocr_model/train/dataset/images/7.png new file mode 100644 index 0000000..9c2f9a6 Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/7.png differ diff --git a/src/models/ocr_model/train/dataset/images/8.png b/src/models/ocr_model/train/dataset/images/8.png new file mode 100644 index 0000000..54e300a Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/8.png differ diff --git a/src/models/ocr_model/train/dataset/images/9.png b/src/models/ocr_model/train/dataset/images/9.png new file mode 100644 index 0000000..9bf24fb Binary files /dev/null and b/src/models/ocr_model/train/dataset/images/9.png differ diff --git a/src/models/ocr_model/train/dataset/loader.py b/src/models/ocr_model/train/dataset/loader.py new file mode 100644 index 0000000..f782f36 --- /dev/null +++ b/src/models/ocr_model/train/dataset/loader.py @@ -0,0 +1,50 @@ +from PIL import Image +from pathlib import Path +import datasets +import json + +DIR_URL = Path('absolute/path/to/dataset/directory') +# e.g. 
DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset') + + +class LatexFormulas(datasets.GeneratorBasedBuilder): + BUILDER_CONFIGS = [] + + def _info(self): + return datasets.DatasetInfo( + features=datasets.Features({ + "image": datasets.Image(), + "latex_formula": datasets.Value("string") + }) + ) + + def _split_generators(self, dl_manager: datasets.DownloadManager): + dir_path = Path(dl_manager.download(str(DIR_URL))) + assert dir_path.is_dir() + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={ + 'dir_path': dir_path, + } + ) + ] + + def _generate_examples(self, dir_path: Path): + images_path = dir_path / 'images' + formulas_path = dir_path / 'formulas.jsonl' + + img2formula = {} + with formulas_path.open('r', encoding='utf-8') as f: + for line in f: + single_json = json.loads(line) + img2formula[single_json['img_name']] = single_json['formula'] + + for img_path in images_path.iterdir(): + if img_path.suffix not in ['.jpg', '.png']: + continue + yield str(img_path), { + "image": Image.open(img_path), + "latex_formula": img2formula[img_path.name] + } diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py new file mode 100644 index 0000000..af0c7a1 --- /dev/null +++ b/src/models/ocr_model/train/train.py @@ -0,0 +1,103 @@ +import os + +from functools import partial +from pathlib import Path + +from datasets import load_dataset +from transformers import ( + Trainer, + TrainingArguments, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + GenerationConfig +) + +from .training_args import CONFIG +from ..model.TexTeller import TexTeller +from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn +from ..utils.metrics import bleu_metric +from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT + + +def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): + training_args = TrainingArguments(**CONFIG) + trainer = Trainer( + model, + training_args, + + train_dataset=train_dataset, + eval_dataset=eval_dataset, + + tokenizer=tokenizer, + data_collator=collate_fn_with_tokenizer, + ) + + trainer.train(resume_from_checkpoint=None) + + +def evaluate(model, tokenizer, eval_dataset, collate_fn): + eval_config = CONFIG.copy() + eval_config['predict_with_generate'] = True + generate_config = GenerationConfig( + max_new_tokens=MAX_TOKEN_SIZE, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + ) + eval_config['generation_config'] = generate_config + seq2seq_config = Seq2SeqTrainingArguments(**eval_config) + + trainer = Seq2SeqTrainer( + model, + seq2seq_config, + + eval_dataset=eval_dataset, + tokenizer=tokenizer, + data_collator=collate_fn, + compute_metrics=partial(bleu_metric, tokenizer=tokenizer) + ) + + eval_res = trainer.evaluate() + print(eval_res) + + +if __name__ == '__main__': + script_dirpath = Path(__file__).resolve().parent + os.chdir(script_dirpath) + + dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train'] + dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH) + dataset = dataset.shuffle(seed=42) + dataset = dataset.flatten_indices() + + tokenizer = TexTeller.get_tokenizer() + # If you want use your own tokenizer, please modify the path to your tokenizer + #+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer') + + map_fn = partial(tokenize_fn, tokenizer=tokenizer) + tokenized_dataset = 
dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8) + tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn) + + # Split dataset into train and eval, ratio 9:1 + split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] + collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) + + # Train from scratch + model = TexTeller() + # or train from TexTeller pre-trained model: model = TexTeller.from_pretrained() + + # If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint + #+e.g. + #+model = TexTeller.from_pretrained( + #+ '/path/to/your/model_checkpoint' + #+) + + enable_train = True + enable_evaluate = True + if enable_train: + train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) + if enable_evaluate and len(eval_dataset) > 0: + evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) diff --git a/src/models/ocr_model/train/training_args.py b/src/models/ocr_model/train/training_args.py new file mode 100644 index 0000000..07334fa --- /dev/null +++ b/src/models/ocr_model/train/training_args.py @@ -0,0 +1,38 @@ +CONFIG = { + "seed": 42, # Random seed for reproducibility + "use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code) + "learning_rate": 5e-5, # Learning rate + "num_train_epochs": 10, # Total number of training epochs + "per_device_train_batch_size": 4, # Batch size per GPU for training + "per_device_eval_batch_size": 8, # Batch size per GPU for evaluation + + "output_dir": "train_result", # Output directory + "overwrite_output_dir": False, # If the output directory exists, do not delete its content + "report_to": ["tensorboard"], # Report logs to TensorBoard + + "save_strategy": "steps", # Strategy to save checkpoints + "save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000) + "save_total_limit": 5, # Maximum number of models to save. 
The oldest models will be deleted if this number is exceeded + + "logging_strategy": "steps", # Log every certain number of steps + "logging_steps": 500, # Number of steps between each log + "logging_nan_inf_filter": False, # Record logs for loss=nan or inf + + "optim": "adamw_torch", # Optimizer + "lr_scheduler_type": "cosine", # Learning rate scheduler + "warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr) + "max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0) + "fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode) + "bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it) + "gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large + "jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors) + "torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance) + + "dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU + "dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used + + "evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch" + "eval_steps": 500, # If evaluation_strategy="step" + + "remove_unused_columns": False, # Don't change this unless you really know what you are doing. +} diff --git a/src/models/ocr_model/utils/functional.py b/src/models/ocr_model/utils/functional.py new file mode 100644 index 0000000..7d3947e --- /dev/null +++ b/src/models/ocr_model/utils/functional.py @@ -0,0 +1,46 @@ +import torch +import numpy as np + +from transformers import DataCollatorForLanguageModeling +from typing import List, Dict, Any +from .transforms import train_transform + + +def left_move(x: torch.Tensor, pad_val): + assert len(x.shape) == 2, 'x should be 2-dimensional' + lefted_x = torch.ones_like(x) + lefted_x[:, :-1] = x[:, 1:] + lefted_x[:, -1] = pad_val + return lefted_x + + +def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]: + assert tokenizer is not None, 'tokenizer should not be None' + tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True) + tokenized_formula['pixel_values'] = samples['image'] + return tokenized_formula + + +def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]: + assert tokenizer is not None, 'tokenizer should not be None' + pixel_values = [dic.pop('pixel_values') for dic in samples] + + clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + batch = clm_collator(samples) + batch['pixel_values'] = pixel_values + batch['decoder_input_ids'] = batch.pop('input_ids') + batch['decoder_attention_mask'] = batch.pop('attention_mask') + + # left shift labels and decoder_attention_mask, padding with -100 + batch['labels'] = left_move(batch['labels'], -100) + + # convert list of Image to tensor with (B, C, H, W) + batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0) + return batch + + +def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + processed_img = 
train_transform(samples['pixel_values']) + samples['pixel_values'] = processed_img + return samples diff --git a/src/models/ocr_model/utils/helpers.py b/src/models/ocr_model/utils/helpers.py new file mode 100644 index 0000000..d650556 --- /dev/null +++ b/src/models/ocr_model/utils/helpers.py @@ -0,0 +1,26 @@ +import cv2 +import numpy as np +from typing import List + + +def convert2rgb(image_paths: List[str]) -> List[np.ndarray]: + processed_images = [] + for path in image_paths: + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + print(f"Image at {path} could not be read.") + continue + if image.dtype == np.uint16: + print(f'Converting {path} to 8-bit, image may be lossy.') + image = cv2.convertScaleAbs(image, alpha=(255.0/65535.0)) + + channels = 1 if len(image.shape) == 2 else image.shape[2] + if channels == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) + elif channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif channels == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + processed_images.append(image) + + return processed_images diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py new file mode 100644 index 0000000..eaa4b52 --- /dev/null +++ b/src/models/ocr_model/utils/inference.py @@ -0,0 +1,38 @@ +import torch + +from transformers import RobertaTokenizerFast, GenerationConfig +from typing import List + +from models.ocr_model.model.TexTeller import TexTeller +from models.ocr_model.utils.transforms import inference_transform +from models.ocr_model.utils.helpers import convert2rgb +from models.globals import MAX_TOKEN_SIZE + + +def inference( + model: TexTeller, + tokenizer: RobertaTokenizerFast, + imgs_path: List[str], + use_cuda: bool, + num_beams: int = 1, +) -> List[str]: + model.eval() + imgs = convert2rgb(imgs_path) + imgs = inference_transform(imgs) + pixel_values = torch.stack(imgs) + + if use_cuda: + model = model.to('cuda') + pixel_values = pixel_values.to('cuda') + + generate_config = GenerationConfig( + max_new_tokens=MAX_TOKEN_SIZE, + num_beams=num_beams, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + ) + pred = model.generate(pixel_values, generation_config=generate_config) + res = tokenizer.batch_decode(pred, skip_special_tokens=True) + return res diff --git a/src/models/ocr_model/utils/metrics.py b/src/models/ocr_model/utils/metrics.py new file mode 100644 index 0000000..1dd0702 --- /dev/null +++ b/src/models/ocr_model/utils/metrics.py @@ -0,0 +1,23 @@ +import evaluate +import numpy as np +import os + +from pathlib import Path +from typing import Dict +from transformers import EvalPrediction, RobertaTokenizer + + +def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict: + cur_dir = Path(os.getcwd()) + os.chdir(Path(__file__).resolve().parent) + metric = evaluate.load('google_bleu') # Will download the metric from huggingface if not already downloaded + os.chdir(cur_dir) + + logits, labels = eval_preds.predictions, eval_preds.label_ids + preds = logits + + labels = np.where(labels == -100, 1, labels) + + preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + return metric.compute(predictions=preds, references=labels) diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py new file mode 100644 index 0000000..ba923a2 --- /dev/null +++ 
b/src/models/ocr_model/utils/transforms.py @@ -0,0 +1,90 @@ +import torch +import random +import numpy as np +import cv2 + +from torchvision.transforms import v2 +from typing import List +from PIL import Image + +from models.globals import ( + FIXED_IMG_SIZE, + IMAGE_MEAN, IMAGE_STD, + MAX_RESIZE_RATIO, MIN_RESIZE_RATIO +) + +general_transform_pipeline = v2.Compose([ + v2.ToImage(), + v2.ToDtype(torch.uint8, scale=True), + v2.Grayscale(), + v2.Resize( + size=FIXED_IMG_SIZE - 1, + interpolation=v2.InterpolationMode.BICUBIC, + max_size=FIXED_IMG_SIZE, + antialias=True + ), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]), +]) + + +def trim_white_border(image: np.ndarray): + if len(image.shape) != 3 or image.shape[2] != 3: + raise ValueError("Image is not in RGB format or channel is not in third dimension") + + if image.dtype != np.uint8: + raise ValueError(f"Image should stored in uint8") + + h, w = image.shape[:2] + bg = np.full((h, w, 3), 255, dtype=np.uint8) + diff = cv2.absdiff(image, bg) + + _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY) + gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY) + x, y, w, h = cv2.boundingRect(gray_diff) + + trimmed_image = image[y:y+h, x:x+w] + return trimmed_image + + +def padding(images: List[torch.Tensor], required_size: int): + images = [ + v2.functional.pad( + img, + padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]] + ) + for img in images + ] + return images + + +def random_resize( + images: List[np.ndarray], + minr: float, + maxr: float +) -> List[np.ndarray]: + if len(images[0].shape) != 3 or images[0].shape[2] != 3: + raise ValueError("Image is not in RGB format or channel is not in third dimension") + + ratios = [random.uniform(minr, maxr) for _ in range(len(images))] + return [ + cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4) # 抗锯齿 + for img, r in zip(images, ratios) + ] + + +def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]: + images = [trim_white_border(image) for image in images] + images = general_transform_pipeline(images) + images = padding(images, FIXED_IMG_SIZE) + return images + + +def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: + images = [np.array(img.convert('RGB')) for img in images] + images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO) + return general_transform(images) + + +def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: + return general_transform(images) diff --git a/src/models/tokenizer/train.py b/src/models/tokenizer/train.py new file mode 100644 index 0000000..aa44521 --- /dev/null +++ b/src/models/tokenizer/train.py @@ -0,0 +1,25 @@ +import os +from pathlib import Path +from datasets import load_dataset +from ..ocr_model.model.TexTeller import TexTeller +from ..globals import VOCAB_SIZE + + +if __name__ == '__main__': + script_dirpath = Path(__file__).resolve().parent + os.chdir(script_dirpath) + + tokenizer = TexTeller.get_tokenizer() + + # Don't forget to config your dataset path in loader.py + dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train'] + + new_tokenizer = tokenizer.train_new_from_iterator( + text_iterator=dataset['latex_formula'], + + # If you want to use a different vocab size, **change VOCAB_SIZE from globals.py** + vocab_size=VOCAB_SIZE + ) + + # Save the new tokenizer for later training and inference + new_tokenizer.save_pretrained('./your_dir_name') diff --git a/src/server.py 
b/src/server.py new file mode 100644 index 0000000..e838f27 --- /dev/null +++ b/src/server.py @@ -0,0 +1,81 @@ +import argparse +import time + +from starlette.requests import Request +from ray import serve +from ray.serve.handle import DeploymentHandle + +from models.ocr_model.utils.inference import inference +from models.ocr_model.model.TexTeller import TexTeller + + +parser = argparse.ArgumentParser() +parser.add_argument( + '-ckpt', '--checkpoint_dir', type=str +) +parser.add_argument( + '-tknz', '--tokenizer_dir', type=str +) +parser.add_argument('-port', '--server_port', type=int, default=8000) +parser.add_argument('--num_replicas', type=int, default=1) +parser.add_argument('--ncpu_per_replica', type=float, default=1.0) +parser.add_argument('--ngpu_per_replica', type=float, default=0.0) + +parser.add_argument('--use_cuda', action='store_true', default=False) +parser.add_argument('--num_beam', type=int, default=1) + +args = parser.parse_args() +if args.ngpu_per_replica > 0 and not args.use_cuda: + raise ValueError("use_cuda must be True if ngpu_per_replica > 0") + + +@serve.deployment( + num_replicas=args.num_replicas, + ray_actor_options={ + "num_cpus": args.ncpu_per_replica, + "num_gpus": args.ngpu_per_replica + } +) +class TexTellerServer: + def __init__( + self, + checkpoint_path: str, + tokenizer_path: str, + use_cuda: bool = False, + num_beam: int = 1 + ) -> None: + self.model = TexTeller.from_pretrained(checkpoint_path) + self.tokenizer = TexTeller.get_tokenizer(tokenizer_path) + self.use_cuda = use_cuda + self.num_beam = num_beam + + self.model = self.model.to('cuda') if use_cuda else self.model + + def predict(self, image_path: str) -> str: + return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0] + + +@serve.deployment() +class Ingress: + def __init__(self, texteller_server: DeploymentHandle) -> None: + self.texteller_server = texteller_server + + async def __call__(self, request: Request) -> str: + msg = await request.json() + img_path: str = msg['img_path'] + pred = await self.texteller_server.predict.remote(img_path) + return pred + + +if __name__ == '__main__': + ckpt_dir = args.checkpoint_dir + tknz_dir = args.tokenizer_dir + + serve.start(http_options={"port": args.server_port}) + texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam) + ingress = Ingress.bind(texteller_server) + + ingress_handle = serve.run(ingress, route_prefix="/predict") + + while True: + time.sleep(1) diff --git a/src/start_web.sh b/src/start_web.sh new file mode 100755 index 0000000..2a9538d --- /dev/null +++ b/src/start_web.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +set -exu + +export CHECKPOINT_DIR="OleehyO/TexTeller" +export TOKENIZER_DIR="OleehyO/TexTeller" +export USE_CUDA=False # True or False (case-sensitive) +export NUM_BEAM=1 + +streamlit run web.py diff --git a/src/web.py b/src/web.py new file mode 100644 index 0000000..1ef6fe5 --- /dev/null +++ b/src/web.py @@ -0,0 +1,93 @@ +import os +import io +import base64 +import tempfile +import streamlit as st + +from PIL import Image +from models.ocr_model.utils.inference import inference +from models.ocr_model.model.TexTeller import TexTeller + + +@st.cache_resource +def get_model(): + return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR']) + + +@st.cache_resource +def get_tokenizer(): + return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) + + +model = get_model() +tokenizer = get_tokenizer() + + +# ============================ pages 
=============================== #
+html_string = '''
[... the HTML inside html_string and the remaining lines of web.py are not preserved in this dump; of the stripped Streamlit markup only the image caption "Input image ({img.height}✖️{img.width})" is recoverable ...]
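For orientation, here is a minimal sketch of how the helpers loaded above (model, tokenizer, inference) would typically be wired into the Streamlit page. Everything in it — widget labels, the temporary-file handling, and the output layout — is an assumption made for illustration, not the committed web.py body:

    # Illustrative sketch only; not the committed web.py code.
    uploaded = st.file_uploader("Upload a formula image", type=['png', 'jpg', 'jpeg'])
    if uploaded is not None:
        img = Image.open(uploaded)
        st.image(img, caption=f'Input image ({img.height}✖️{img.width})')

        # inference() takes image *paths*, so persist the upload to a temporary file first
        tmp_path = os.path.join(tempfile.gettempdir(), uploaded.name)
        img.convert('RGB').save(tmp_path)

        pred = inference(model, tokenizer, [tmp_path],
                         use_cuda=os.environ['USE_CUDA'] == 'True',   # exported by start_web.sh
                         num_beams=int(os.environ['NUM_BEAM']))[0]
        os.remove(tmp_path)

        st.latex(pred)                    # rendered formula
        st.code(pred, language='latex')   # raw LaTeX for copy-paste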
+