Initial commit

.gitignore · vendored · new file · 7 lines
@@ -0,0 +1,7 @@
**/__pycache__
**/.vscode
**/train_result

**/logs
**/.cache
**/tmp*

LICENSE · new file · 21 lines
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 OleehyO

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

requirements.txt · new file · 9 lines
@@ -0,0 +1,9 @@
transformers
datasets
evaluate
streamlit
opencv-python
ray[serve]
accelerate
tensorboardX
nltk

src/client_demo.py · new file · 11 lines
@@ -0,0 +1,11 @@
import requests

url = "http://127.0.0.1:8000/predict"

img_path = "/your/image/path/"

data = {"img_path": img_path}

response = requests.post(url, json=data)

print(response.text)
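
(Usage note: this demo assumes the Ray Serve endpoint from src/server.py below is already running on its default port 8000, and that img_path is replaced with a real image path on the server's filesystem.)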

src/inference.py · new file · 36 lines
@@ -0,0 +1,36 @@
import os
import argparse

from pathlib import Path
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '-cuda',
        default=False,
        action='store_true',
        help='use cuda or not'
    )

    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = [args.img]
    print('Inference...')
    res = inference(model, tokenizer, img_path, args.cuda)
    print(res[0])
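
Invocation sketch (the image path is hypothetical; the script chdirs to its own directory, so run it from src/):

    python inference.py -img /path/to/formula.png        # CPU
    python inference.py -img /path/to/formula.png -cuda  # GPU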

src/models/globals.py · new file · 26 lines
@@ -0,0 +1,26 @@
# Formula image (grayscale) mean and standard deviation
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445

# Density value for PDF-to-image conversion
TEXTELL_DENSITY = 200

# Vocabulary size for TexTeller
VOCAB_SIZE = 10000

# Fixed input image size for TexTeller
FIXED_IMG_SIZE = 448

# Number of image channels for TexTeller
IMG_CHANNELS = 1  # grayscale image

# Maximum token sequence length for embedding
MAX_TOKEN_SIZE = 512

# Scaling ratio range for random resizing during training
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75

# Minimum height and width for TexTeller input images
MIN_HEIGHT = 12
MIN_WIDTH = 30
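
A hedged sketch of how IMAGE_MEAN / IMAGE_STD could be recomputed for a custom dataset, assuming grayscale pixels scaled to [0, 1]; the directory path is hypothetical:

import numpy as np
from PIL import Image
from pathlib import Path

pixels = np.concatenate([
    np.asarray(Image.open(p).convert('L'), dtype=np.float64).ravel() / 255.0
    for p in Path('/path/to/images').glob('*.png')
])
print(pixels.mean(), pixels.std())  # candidate IMAGE_MEAN / IMAGE_STD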

src/models/ocr_model/model/TexTeller.py · new file · 43 lines
@@ -0,0 +1,43 @@
from pathlib import Path

from models.globals import (
    VOCAB_SIZE,
    FIXED_IMG_SIZE,
    IMG_CHANNELS,
)

from transformers import (
    ViTConfig,
    ViTModel,
    TrOCRConfig,
    TrOCRForCausalLM,
    RobertaTokenizerFast,
    VisionEncoderDecoderModel,
)


class TexTeller(VisionEncoderDecoderModel):
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        encoder = ViTModel(ViTConfig(
            image_size=FIXED_IMG_SIZE,
            num_channels=IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str = None):
        if model_path is None or model_path == cls.REPO_NAME:
            return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        if tokenizer_path is None or tokenizer_path == cls.REPO_NAME:
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
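
A quick usage sketch (run from src/ so the models package is importable); with no arguments, both helpers pull the OleehyO/TexTeller checkpoint from the Hugging Face Hub:

from models.ocr_model.model.TexTeller import TexTeller

model = TexTeller.from_pretrained()    # hub checkpoint, or pass a local checkpoint path
tokenizer = TexTeller.get_tokenizer()  # the matching RobertaTokenizerFast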

src/models/ocr_model/train/dataset/formulas.jsonl · new file · 35 lines
@@ -0,0 +1,35 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

BIN  src/models/ocr_model/train/dataset/images/0.png · new file · 3.1 KiB
BIN  src/models/ocr_model/train/dataset/images/1.png · new file · 8.7 KiB
BIN  src/models/ocr_model/train/dataset/images/10.png · new file · 6.8 KiB
BIN  src/models/ocr_model/train/dataset/images/11.png · new file · 4.1 KiB
BIN  src/models/ocr_model/train/dataset/images/12.png · new file · 5.2 KiB
BIN  src/models/ocr_model/train/dataset/images/13.png · new file · 12 KiB
BIN  src/models/ocr_model/train/dataset/images/14.png · new file · 2.8 KiB
BIN  src/models/ocr_model/train/dataset/images/15.png · new file · 2.2 KiB
BIN  src/models/ocr_model/train/dataset/images/16.png · new file · 2.2 KiB
BIN  src/models/ocr_model/train/dataset/images/17.png · new file · 2.6 KiB
BIN  src/models/ocr_model/train/dataset/images/18.png · new file · 3.1 KiB
BIN  src/models/ocr_model/train/dataset/images/19.png · new file · 2.7 KiB
BIN  src/models/ocr_model/train/dataset/images/2.png · new file · 3.9 KiB
BIN  src/models/ocr_model/train/dataset/images/20.png · new file · 3.9 KiB
BIN  src/models/ocr_model/train/dataset/images/21.png · new file · 2.9 KiB
BIN  src/models/ocr_model/train/dataset/images/22.png · new file · 3.7 KiB
BIN  src/models/ocr_model/train/dataset/images/23.png · new file · 3.5 KiB
BIN  src/models/ocr_model/train/dataset/images/24.png · new file · 3.1 KiB
BIN  src/models/ocr_model/train/dataset/images/25.png · new file · 2.5 KiB
BIN  src/models/ocr_model/train/dataset/images/26.png · new file · 2.2 KiB
BIN  src/models/ocr_model/train/dataset/images/27.png · new file · 3.1 KiB
BIN  src/models/ocr_model/train/dataset/images/28.png · new file · 2.9 KiB
BIN  src/models/ocr_model/train/dataset/images/29.png · new file · 5.3 KiB
BIN  src/models/ocr_model/train/dataset/images/3.png · new file · 4.1 KiB
BIN  src/models/ocr_model/train/dataset/images/30.png · new file · 3.9 KiB
BIN  src/models/ocr_model/train/dataset/images/31.png · new file · 4.9 KiB
BIN  src/models/ocr_model/train/dataset/images/32.png · new file · 2.9 KiB
BIN  src/models/ocr_model/train/dataset/images/33.png · new file · 1.8 KiB
BIN  src/models/ocr_model/train/dataset/images/34.png · new file · 3.2 KiB
BIN  src/models/ocr_model/train/dataset/images/4.png · new file · 5.7 KiB
BIN  src/models/ocr_model/train/dataset/images/5.png · new file · 11 KiB
BIN  src/models/ocr_model/train/dataset/images/6.png · new file · 4.8 KiB
BIN  src/models/ocr_model/train/dataset/images/7.png · new file · 4.5 KiB
BIN  src/models/ocr_model/train/dataset/images/8.png · new file · 2.5 KiB
BIN  src/models/ocr_model/train/dataset/images/9.png · new file · 5.2 KiB

src/models/ocr_model/train/dataset/loader.py · new file · 50 lines
@@ -0,0 +1,50 @@
from PIL import Image
from pathlib import Path
import datasets
import json

DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')


class LatexFormulas(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = []

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({
                "image": datasets.Image(),
                "latex_formula": datasets.Value("string")
            })
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        dir_path = Path(dl_manager.download(str(DIR_URL)))
        assert dir_path.is_dir()

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    'dir_path': dir_path,
                }
            )
        ]

    def _generate_examples(self, dir_path: Path):
        images_path = dir_path / 'images'
        formulas_path = dir_path / 'formulas.jsonl'

        img2formula = {}
        with formulas_path.open('r', encoding='utf-8') as f:
            for line in f:
                single_json = json.loads(line)
                img2formula[single_json['img_name']] = single_json['formula']

        for img_path in images_path.iterdir():
            if img_path.suffix not in ['.jpg', '.png']:
                continue
            yield str(img_path), {
                "image": Image.open(img_path),
                "latex_formula": img2formula[img_path.name]
            }
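
Once DIR_URL is set, the loader can be exercised directly with datasets.load_dataset, exactly as train.py does below (the path here is hypothetical):

from datasets import load_dataset

ds = load_dataset('/abs/path/to/dataset/loader.py')['train']
print(ds[0]['latex_formula'])  # the LaTeX string paired with ds[0]['image']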

src/models/ocr_model/train/train.py · new file · 103 lines
@@ -0,0 +1,103 @@
import os

from functools import partial
from pathlib import Path

from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    GenerationConfig
)

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**CONFIG)
    trainer = Trainer(
        model,
        training_args,

        train_dataset=train_dataset,
        eval_dataset=eval_dataset,

        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


def evaluate(model, tokenizer, eval_dataset, collate_fn):
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    eval_config['generation_config'] = generate_config
    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)

    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,

        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
    )

    eval_res = trainer.evaluate()
    print(eval_res)


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # If you want to use your own tokenizer, modify the path below:
    # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)

    # Split dataset into train and eval, ratio 9:1
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
    model = TexTeller()
    # ...or fine-tune from the TexTeller pre-trained model: model = TexTeller.from_pretrained()
    # To start from your own checkpoint, pass its path, e.g.
    # model = TexTeller.from_pretrained('/path/to/your/model_checkpoint')

    enable_train = True
    enable_evaluate = True
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
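
Because of the relative imports, this script is run as a module from src/ (a hedged note; the accelerate launcher is an assumption based on requirements.txt, not a documented command):

    python -m models.ocr_model.train.train
    # or, for multi-GPU: accelerate launch -m models.ocr_model.train.train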

src/models/ocr_model/train/training_args.py · new file · 38 lines
@@ -0,0 +1,38 @@
CONFIG = {
    "seed": 42,                         # Random seed for reproducibility
    "use_cpu": False,                   # Whether to use CPU (debugging on CPU is easier when first testing the code)
    "learning_rate": 5e-5,              # Learning rate
    "num_train_epochs": 10,             # Total number of training epochs
    "per_device_train_batch_size": 4,   # Batch size per GPU for training
    "per_device_eval_batch_size": 8,    # Batch size per GPU for evaluation

    "output_dir": "train_result",       # Output directory
    "overwrite_output_dir": False,      # If the output directory exists, do not delete its contents
    "report_to": ["tensorboard"],       # Report logs to TensorBoard

    "save_strategy": "steps",           # Strategy for saving checkpoints
    "save_steps": 500,                  # Interval (in steps) between checkpoints; can be an int or a float in (0, 1), in which case it is a ratio of the total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 5,              # Maximum number of checkpoints to keep; the oldest are deleted once this limit is exceeded

    "logging_strategy": "steps",        # Log every fixed number of steps
    "logging_steps": 500,               # Number of steps between log entries
    "logging_nan_inf_filter": False,    # Also log steps where loss is NaN or Inf

    "optim": "adamw_torch",             # Optimizer
    "lr_scheduler_type": "cosine",      # Learning rate scheduler
    "warmup_ratio": 0.1,                # Ratio of warmup steps to total training steps (e.g. with 1000 steps, the first 100 ramp the lr from 0 up to the configured value)
    "max_grad_norm": 1.0,               # Gradient clipping: the gradient norm will not exceed 1.0 (default 1.0)
    "fp16": False,                      # Whether to train in 16-bit floating point (generally not recommended; the loss can easily explode)
    "bf16": False,                      # Whether to train in bfloat16 (recommended if the hardware supports it)
    "gradient_accumulation_steps": 1,   # Gradient accumulation steps; use this to emulate a large batch size when the per-device batch cannot be large
    "jit_mode_eval": False,             # Whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static or it will raise errors)
    "torch_compile": False,             # Whether to compile the model with torch.compile (for better training and inference performance)

    "dataloader_pin_memory": True,      # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,        # Number of data-loading workers; 1 means no extra multiprocessing, a common choice is 4 * the number of GPUs used

    "evaluation_strategy": "steps",     # Evaluation strategy, either "steps" or "epoch"
    "eval_steps": 500,                  # Evaluation interval when evaluation_strategy="steps"

    "remove_unused_columns": False,     # Don't change this unless you really know what you are doing.
}
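
Since CONFIG is a plain dict unpacked into TrainingArguments, experiments can override fields before training; a small sketch with hypothetical overrides:

from models.ocr_model.train.training_args import CONFIG

CONFIG.update({
    "per_device_train_batch_size": 8,
    "bf16": True,
})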

src/models/ocr_model/utils/functional.py · new file · 46 lines
@@ -0,0 +1,46 @@
import torch
import numpy as np

from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform


def left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    tokenized_formula['pixel_values'] = samples['image']
    return tokenized_formula


def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')

    # left-shift the labels by one position, padding the last column with -100
    batch['labels'] = left_move(batch['labels'], -100)

    # convert the list of images to a tensor of shape (B, C, H, W)
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch


def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = train_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples
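
A worked example of left_move, which shifts each row one position left so that position t holds the label for predicting token t+1 (assumes left_move from above is in scope):

import torch

x = torch.tensor([[11, 12, 13, 14]])
print(left_move(x, -100))  # -> tensor([[ 12,  13,  14, -100]])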

src/models/ocr_model/utils/helpers.py · new file · 26 lines
@@ -0,0 +1,26 @@
import cv2
import numpy as np
from typing import List


def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"Image at {path} could not be read.")
            continue
        if image.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        processed_images.append(image)

    return processed_images

src/models/ocr_model/utils/inference.py · new file · 38 lines
@@ -0,0 +1,38 @@
import torch

from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    model.eval()
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res

src/models/ocr_model/utils/metrics.py · new file · 23 lines
@@ -0,0 +1,23 @@
import evaluate
import numpy as np
import os

from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer


def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    metric = evaluate.load('google_bleu')  # Will download the metric from Hugging Face if not already cached
    os.chdir(cur_dir)

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits

    # replace the ignored positions (-100) with the pad token id (1) so they can be decoded
    labels = np.where(labels == -100, 1, labels)

    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)

src/models/ocr_model/utils/transforms.py · new file · 90 lines
@@ -0,0 +1,90 @@
import torch
import random
import numpy as np
import cv2

from torchvision.transforms import v2
from typing import List
from PIL import Image

from models.globals import (
    FIXED_IMG_SIZE,
    IMAGE_MEAN, IMAGE_STD,
    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
)

general_transform_pipeline = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8, scale=True),
    v2.Grayscale(),
    v2.Resize(
        size=FIXED_IMG_SIZE - 1,
        interpolation=v2.InterpolationMode.BICUBIC,
        max_size=FIXED_IMG_SIZE,
        antialias=True
    ),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
])


def trim_white_border(image: np.ndarray):
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        raise ValueError("Image should be stored in uint8")

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), 255, dtype=np.uint8)
    diff = cv2.absdiff(image, bg)

    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
    x, y, w, h = cv2.boundingRect(gray_diff)

    trimmed_image = image[y:y+h, x:x+w]
    return trimmed_image


def padding(images: List[torch.Tensor], required_size: int):
    # pad each image on the right and bottom up to required_size x required_size
    images = [
        v2.functional.pad(
            img,
            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(
    images: List[np.ndarray],
    minr: float,
    maxr: float
) -> List[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4)  # anti-aliasing
        for img, r in zip(images, ratios)
    ]


def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
    images = [trim_white_border(image) for image in images]
    images = general_transform_pipeline(images)
    images = padding(images, FIXED_IMG_SIZE)
    return images


def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    images = [np.array(img.convert('RGB')) for img in images]
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    return general_transform(images)


def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
    return general_transform(images)
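
A sketch of the inference path through these transforms (the image path is hypothetical; OpenCV reads BGR, so convert to RGB first, as convert2rgb in helpers.py does):

import cv2

img = cv2.cvtColor(cv2.imread('formula.png'), cv2.COLOR_BGR2RGB)
tensors = inference_transform([img])  # list of (1, 448, 448) float tensors, borders trimmed and padded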

src/models/tokenizer/train.py · new file · 25 lines
@@ -0,0 +1,25 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    tokenizer = TexTeller.get_tokenizer()

    # Don't forget to configure your dataset path in loader.py
    dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']

    new_tokenizer = tokenizer.train_new_from_iterator(
        text_iterator=dataset['latex_formula'],

        # If you want a different vocab size, **change VOCAB_SIZE in globals.py**
        vocab_size=VOCAB_SIZE
    )

    # Save the new tokenizer for later training and inference
    new_tokenizer.save_pretrained('./your_dir_name')
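
(As with train.py, the relative imports mean this runs as a module from src/, e.g. python -m models.tokenizer.train; that invocation is inferred from the imports, not a documented command.)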

src/server.py · new file · 81 lines
@@ -0,0 +1,81 @@
import argparse
import time

from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


parser = argparse.ArgumentParser()
parser.add_argument(
    '-ckpt', '--checkpoint_dir', type=str
)
parser.add_argument(
    '-tknz', '--tokenizer_dir', type=str
)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)

parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)

args = parser.parse_args()
if args.ngpu_per_replica > 0 and not args.use_cuda:
    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        use_cuda: bool = False,
        num_beam: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.use_cuda = use_cuda
        self.num_beam = num_beam

        self.model = self.model.to('cuda') if use_cuda else self.model

    def predict(self, image_path: str) -> str:
        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]


@serve.deployment()
class Ingress:
    def __init__(self, texteller_server: DeploymentHandle) -> None:
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        msg = await request.json()
        img_path: str = msg['img_path']
        pred = await self.texteller_server.predict.remote(img_path)
        return pred


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"port": args.server_port})
    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
    ingress = Ingress.bind(texteller_server)

    ingress_handle = serve.run(ingress, route_prefix="/predict")

    while True:
        time.sleep(1)
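
A launch sketch using only the flags defined above (with -ckpt/-tknz omitted, the hub checkpoint is used); client_demo.py can then POST to http://127.0.0.1:8000/predict:

    python server.py                                  # CPU inference
    python server.py --use_cuda --ngpu_per_replica 1  # GPU inference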

src/start_web.sh · new executable file · 9 lines
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -exu

export CHECKPOINT_DIR="OleehyO/TexTeller"
export TOKENIZER_DIR="OleehyO/TexTeller"
export USE_CUDA=False  # True or False (case-sensitive)
export NUM_BEAM=1

streamlit run web.py

src/web.py · new file · 93 lines
@@ -0,0 +1,93 @@
import os
import io
import base64
import tempfile
import streamlit as st

from PIL import Image
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


@st.cache_resource
def get_model():
    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])


@st.cache_resource
def get_tokenizer():
    return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])


model = get_model()
tokenizer = get_tokenizer()


# ============================ pages =============================== #
html_string = '''
    <h1 style="color: orange; text-align: center;">
        ✨ TexTeller ✨
    </h1>
'''
st.markdown(html_string, unsafe_allow_html=True)

if "start" not in st.session_state:
    st.balloons()
    st.session_state["start"] = 1

uploaded_file = st.file_uploader("", type=['jpg', 'png'])

if uploaded_file:
    img = Image.open(uploaded_file)

    temp_dir = tempfile.mkdtemp()
    png_file_path = os.path.join(temp_dir, 'image.png')
    img.save(png_file_path, 'PNG')

    def get_image_base64(img_file):
        buffered = io.BytesIO()
        img_file.seek(0)
        img = Image.open(img_file)
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    img_base64 = get_image_base64(uploaded_file)

    st.markdown(f"""
    <style>
    .centered-container {{
        text-align: center;
    }}
    .centered-image {{
        display: block;
        margin-left: auto;
        margin-right: auto;
        max-width: 700px;
    }}
    </style>
    <div class="centered-container">
        <img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
        <p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
    </div>
    """, unsafe_allow_html=True)

    st.write("")
    st.write("")

    with st.spinner("Predicting..."):
        uploaded_file.seek(0)
        TeXTeller_result = inference(
            model,
            tokenizer,
            [png_file_path],
            os.environ['USE_CUDA'] == 'True',
            int(os.environ['NUM_BEAM'])
        )[0]

        # st.subheader(':rainbow[Predict] :sunglasses:', divider='rainbow')
        st.subheader(':sunglasses:', divider='gray')
        st.latex(TeXTeller_result)
        st.code(TeXTeller_result, language='latex')
        st.success('Done!')

# ============================ pages =============================== #