checkpoint

This commit is contained in:
三洋三洋
2024-04-16 13:56:56 +00:00
parent 7d1d8ddd77
commit f81a31a8c9
6 changed files with 238 additions and 24 deletions

157
assets/css.css Normal file
View File

@@ -0,0 +1,157 @@
html {
font-family: Inter;
font-size: 16px;
font-weight: 400;
line-height: 1.5;
-webkit-text-size-adjust: 100%;
background: #fff;
color: #323232;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
:root {
--space: 1;
--vspace: calc(var(--space) * 1rem);
--vspace-0: calc(3 * var(--space) * 1rem);
--vspace-1: calc(2 * var(--space) * 1rem);
--vspace-2: calc(1.5 * var(--space) * 1rem);
--vspace-3: calc(0.5 * var(--space) * 1rem);
}
.app {
max-width: 748px !important;
}
.prose p {
margin: var(--vspace) 0;
line-height: var(--vspace * 2);
font-size: 1rem;
}
code {
font-family: "inconsolata", sans-serif;
font-size: 16px;
}
h1,
h1 code {
font-weight: 400;
line-height: calc(2.5 / var(--space) * var(--vspace));
}
h1 code {
background: none;
border: none;
letter-spacing: 0.05em;
padding-bottom: 5px;
position: relative;
padding: 0;
}
h2 {
margin: var(--vspace-1) 0 var(--vspace-2) 0;
line-height: 1em;
}
h3,
h3 code {
margin: var(--vspace-1) 0 var(--vspace-2) 0;
line-height: 1em;
}
h4,
h5,
h6 {
margin: var(--vspace-3) 0 var(--vspace-3) 0;
line-height: var(--vspace);
}
.bigtitle,
h1,
h1 code {
font-size: calc(8px * 4.5);
word-break: break-word;
}
.title,
h2,
h2 code {
font-size: calc(8px * 3.375);
font-weight: lighter;
word-break: break-word;
border: none;
background: none;
}
.subheading1,
h3,
h3 code {
font-size: calc(8px * 1.8);
font-weight: 600;
border: none;
background: none;
letter-spacing: 0.1em;
text-transform: uppercase;
}
h2 code {
padding: 0;
position: relative;
letter-spacing: 0.05em;
}
blockquote {
font-size: calc(8px * 1.1667);
font-style: italic;
line-height: calc(1.1667 * var(--vspace));
margin: var(--vspace-2) var(--vspace-2);
}
.subheading2,
h4 {
font-size: calc(8px * 1.4292);
text-transform: uppercase;
font-weight: 600;
}
.subheading3,
h5 {
font-size: calc(8px * 1.2917);
line-height: calc(1.2917 * var(--vspace));
font-weight: lighter;
text-transform: uppercase;
letter-spacing: 0.15em;
}
h6 {
font-size: calc(8px * 1.1667);
font-size: 1.1667em;
font-weight: normal;
font-style: italic;
font-family: "le-monde-livre-classic-byol", serif !important;
letter-spacing: 0px !important;
}
#start .md > *:first-child {
margin-top: 0;
}
h2 + h3 {
margin-top: 0;
}
.md hr {
border: none;
border-top: 1px solid var(--block-border-color);
margin: var(--vspace-2) 0 var(--vspace-2) 0;
}
.prose ul {
margin: var(--vspace-2) 0 var(--vspace-1) 0;
}
.gap {
gap: 0;
}

View File

@@ -2,6 +2,9 @@ transformers
datasets
evaluate
streamlit
gradio
opencv-python
ray[serve]
accelerate

75
src/gradio_web.py Normal file
View File

@@ -0,0 +1,75 @@
import os
import gradio as gr
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from pathlib import Path
# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
css = """
<style>
.container {
display: flex;
align-items: center;
justify-content: center;
font-size: 20px;
font-family: 'Arial';
}
.container img {
height: auto;
}
.text {
margin: 0 15px;
}
h1 {
text-align: center;
font-size: 50px !important;
}
.markdown-style {
color: #333; /* 调整颜色 */
line-height: 1.6; /* 行间距 */
font-size: 50px;
}
.markdown-style h1, .markdown-style h2, .markdown-style h3 {
color: #007BFF; /* 为标题元素指定颜色 */
}
.markdown-style p {
margin-bottom: 1em; /* 段落间距 */
}
</style>
"""
theme=gr.themes.Default(),
def fn(img):
return img
with gr.Blocks(
theme=theme,
css=css
) as demo:
gr.HTML(f'''
{css}
<div class="container">
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
<h1> 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 </h1>
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
</div>
''')
with gr.Row(equal_height=True):
input_img = gr.Image(type="pil", label="Input Image")
latex_img = gr.Image(label="Predicted Latex", show_label=False)
input_img.upload(fn, input_img, latex_img)
gr.Markdown(r'$$\fcxrac{7}{10349}$$')
gr.Markdown('fooooooooooooooooooooooooooooo')
demo.launch()

View File

@@ -1,6 +1,6 @@
from pathlib import Path
from models.globals import (
from ...globals import (
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,

View File

@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
)
# trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
def evaluate(model, tokenizer, eval_dataset, collate_fn):
@@ -96,7 +96,7 @@ if __name__ == '__main__':
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
# ================= debug =======================
# foo = train_dataset[:50]

View File

@@ -35,27 +35,6 @@ fail_gif_html = '''
</h1>
'''
tex = r'''
\documentclass{{article}}
\usepackage[
left=1in, % 左边距
right=1in, % 右边距
top=1in, % 上边距
bottom=1in,% 下边距
paperwidth=40cm, % 页面宽度
paperheight=40cm % 页面高度这里以A4纸为例
]{{geometry}}
\usepackage[utf8]{{inputenc}}
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
\begin{{document}}
{formula}
\pagenumbering{{gobble}}
\end{{document}}
'''
@st.cache_resource