inference.py supports KaTeX
src/inference.py
@@ -1,8 +1,10 @@
 import os
 import argparse
+import cv2 as cv
 
 from pathlib import Path
-from models.ocr_model.utils.inference import inference
+from utils import to_katex
+from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
 
 
@@ -21,16 +23,16 @@ if __name__ == '__main__':
         action='store_true',
         help='use cuda or not'
     )
 
     args = parser.parse_args()
 
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
-    model = TexTeller.from_pretrained()
+    latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
 
-    img_path = [args.img]
+    img = cv.imread(args.img)
     print('Inference...')
-    res = inference(model, tokenizer, img_path, args.cuda)
-    print(res[0])
+    res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+    res = to_katex(res[0])
+    print(res)
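The rewritten tail of inference.py feeds the raw prediction through to_katex before printing. A minimal sketch of the same flow used as a library rather than via the CLI; 'example.png' is a placeholder path, and the interpreter is assumed to be started from the repository's src/ directory with the project dependencies installed:

import cv2 as cv

from utils import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller

# Load the default checkpoint and tokenizer, as inference.py does.
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()

# inference now takes decoded images (not paths); the last argument mirrors args.cuda.
img = cv.imread('example.png')
raw = latex_inference(latex_rec_model, tokenizer, [img], False)

# Convert the first (and only) prediction to KaTeX-compatible syntax.
print(to_katex(raw[0]))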
src/utils.py (new file, 15 additions)
@@ -0,0 +1,15 @@
+import re
+
+
+def to_katex(formula: str) -> str:
+    res = formula
+    res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
+    res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
+    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
+
+    pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
+    replacement = r'\1\2'
+    res = re.sub(pattern, replacement, res)
+    if res.endswith(r'\newline'):
+        res = res[:-8]
+    return res
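to_katex normalises the model's raw LaTeX into something KaTeX will accept: it unwraps \mbox{...}, rewrites boldmath$...$ as bm{...}, turns \[...\] display delimiters into a trailing \newline (which is stripped again if it ends the string), and drops the braces that directly follow sizing commands such as \left or \Big. A small illustrative trace; the input string here is an invented example, not repository output:

from utils import to_katex

raw = r'\[\mbox{E} = \boldmath$m$c^{2}\]'
print(to_katex(raw))
# prints: E = \bm{m}c^{2}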
src/web.py (15 changed lines)
@@ -4,11 +4,11 @@ import base64
 import tempfile
 import shutil
 import streamlit as st
-import re
 
 from PIL import Image
 from models.ocr_model.utils.inference import inference
 from models.ocr_model.model.TexTeller import TexTeller
+from utils import to_katex
 
 
 html_string = '''
@@ -66,19 +66,6 @@ def get_model():
 def get_tokenizer():
     return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
 
-def to_katex(formula: str) -> str:
-    res = formula
-    res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
-    res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
-    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
-
-    pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
-    replacement = r'\1\2'
-    res = re.sub(pattern, replacement, res)
-    if res.endswith(r'\newline'):
-        res = res[:-8]
-    return res
-
 def get_image_base64(img_file):
     buffered = io.BytesIO()
     img_file.seek(0)
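With its local copy removed, web.py now shares the converter from utils.py. The conversion matters for the web UI because Streamlit renders formulas through KaTeX (e.g. via st.latex), which does not handle all of the model's raw output, such as \boldmath or \[...\] delimiters. A standalone sketch of rendering a converted formula; this is not the actual call site in web.py, and the sample string is invented:

import streamlit as st

from utils import to_katex

raw_prediction = r'\[\mbox{speed} = \frac{d}{t}\]'   # stand-in for model output
st.latex(to_katex(raw_prediction))                   # st.latex renders via KaTeX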