checkpoint

This commit is contained in:
三洋三洋
2024-04-16 13:56:56 +00:00
parent 7d1d8ddd77
commit f81a31a8c9
6 changed files with 238 additions and 24 deletions

75
src/gradio_web.py Normal file
View File

@@ -0,0 +1,75 @@
import os
import gradio as gr
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from pathlib import Path
# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
css = """
<style>
.container {
display: flex;
align-items: center;
justify-content: center;
font-size: 20px;
font-family: 'Arial';
}
.container img {
height: auto;
}
.text {
margin: 0 15px;
}
h1 {
text-align: center;
font-size: 50px !important;
}
.markdown-style {
color: #333; /* 调整颜色 */
line-height: 1.6; /* 行间距 */
font-size: 50px;
}
.markdown-style h1, .markdown-style h2, .markdown-style h3 {
color: #007BFF; /* 为标题元素指定颜色 */
}
.markdown-style p {
margin-bottom: 1em; /* 段落间距 */
}
</style>
"""
theme=gr.themes.Default(),
def fn(img):
return img
with gr.Blocks(
theme=theme,
css=css
) as demo:
gr.HTML(f'''
{css}
<div class="container">
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
<h1> 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 </h1>
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
</div>
''')
with gr.Row(equal_height=True):
input_img = gr.Image(type="pil", label="Input Image")
latex_img = gr.Image(label="Predicted Latex", show_label=False)
input_img.upload(fn, input_img, latex_img)
gr.Markdown(r'$$\fcxrac{7}{10349}$$')
gr.Markdown('fooooooooooooooooooooooooooooo')
demo.launch()

View File

@@ -1,6 +1,6 @@
from pathlib import Path
from models.globals import (
from ...globals import (
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,

View File

@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
)
# trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
def evaluate(model, tokenizer, eval_dataset, collate_fn):
@@ -96,7 +96,7 @@ if __name__ == '__main__':
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
# ================= debug =======================
# foo = train_dataset[:50]

View File

@@ -35,27 +35,6 @@ fail_gif_html = '''
</h1>
'''
tex = r'''
\documentclass{{article}}
\usepackage[
left=1in, % 左边距
right=1in, % 右边距
top=1in, % 上边距
bottom=1in,% 下边距
paperwidth=40cm, % 页面宽度
paperheight=40cm % 页面高度这里以A4纸为例
]{{geometry}}
\usepackage[utf8]{{inputenc}}
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
\begin{{document}}
{formula}
\pagenumbering{{gobble}}
\end{{document}}
'''
@st.cache_resource