checkpoint
This commit is contained in:
75
src/gradio_web.py
Normal file
75
src/gradio_web.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
from models.ocr_model.utils.inference import inference
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from utils import to_katex
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
|
||||
# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
|
||||
|
||||
|
||||
css = """
|
||||
<style>
|
||||
.container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 20px;
|
||||
font-family: 'Arial';
|
||||
}
|
||||
.container img {
|
||||
height: auto;
|
||||
}
|
||||
.text {
|
||||
margin: 0 15px;
|
||||
}
|
||||
h1 {
|
||||
text-align: center;
|
||||
font-size: 50px !important;
|
||||
}
|
||||
|
||||
.markdown-style {
|
||||
color: #333; /* 调整颜色 */
|
||||
line-height: 1.6; /* 行间距 */
|
||||
font-size: 50px;
|
||||
}
|
||||
.markdown-style h1, .markdown-style h2, .markdown-style h3 {
|
||||
color: #007BFF; /* 为标题元素指定颜色 */
|
||||
}
|
||||
.markdown-style p {
|
||||
margin-bottom: 1em; /* 段落间距 */
|
||||
}
|
||||
</style>
|
||||
"""
|
||||
|
||||
theme=gr.themes.Default(),
|
||||
|
||||
def fn(img):
|
||||
return img
|
||||
|
||||
with gr.Blocks(
|
||||
theme=theme,
|
||||
css=css
|
||||
) as demo:
|
||||
gr.HTML(f'''
|
||||
{css}
|
||||
<div class="container">
|
||||
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
|
||||
<h1> 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 </h1>
|
||||
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
|
||||
</div>
|
||||
''')
|
||||
|
||||
with gr.Row(equal_height=True):
|
||||
input_img = gr.Image(type="pil", label="Input Image")
|
||||
latex_img = gr.Image(label="Predicted Latex", show_label=False)
|
||||
input_img.upload(fn, input_img, latex_img)
|
||||
|
||||
gr.Markdown(r'$$\fcxrac{7}{10349}$$')
|
||||
gr.Markdown('fooooooooooooooooooooooooooooo')
|
||||
|
||||
|
||||
demo.launch()
|
||||
@@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
from models.globals import (
|
||||
from ...globals import (
|
||||
VOCAB_SIZE,
|
||||
FIXED_IMG_SIZE,
|
||||
IMG_CHANNELS,
|
||||
|
||||
@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
||||
)
|
||||
|
||||
# trainer.train(resume_from_checkpoint=None)
|
||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
|
||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
@@ -96,7 +96,7 @@ if __name__ == '__main__':
|
||||
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
|
||||
|
||||
# ================= debug =======================
|
||||
# foo = train_dataset[:50]
|
||||
|
||||
21
src/web.py
21
src/web.py
@@ -35,27 +35,6 @@ fail_gif_html = '''
|
||||
</h1>
|
||||
'''
|
||||
|
||||
tex = r'''
|
||||
\documentclass{{article}}
|
||||
\usepackage[
|
||||
left=1in, % 左边距
|
||||
right=1in, % 右边距
|
||||
top=1in, % 上边距
|
||||
bottom=1in,% 下边距
|
||||
paperwidth=40cm, % 页面宽度
|
||||
paperheight=40cm % 页面高度,这里以A4纸为例
|
||||
]{{geometry}}
|
||||
|
||||
\usepackage[utf8]{{inputenc}}
|
||||
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
|
||||
|
||||
\begin{{document}}
|
||||
|
||||
{formula}
|
||||
|
||||
\pagenumbering{{gobble}}
|
||||
\end{{document}}
|
||||
'''
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
Reference in New Issue
Block a user