Initial commit
This commit is contained in:
75
requirements.txt
Normal file
75
requirements.txt
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
absl-py==2.0.0
|
||||||
|
accelerate==0.26.0
|
||||||
|
aiohttp==3.9.1
|
||||||
|
aiosignal==1.3.1
|
||||||
|
async-timeout==4.0.3
|
||||||
|
attrs==23.2.0
|
||||||
|
cachetools==5.3.2
|
||||||
|
certifi==2023.11.17
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
datasets==2.16.1
|
||||||
|
dill==0.3.7
|
||||||
|
filelock==3.13.1
|
||||||
|
frozenlist==1.4.1
|
||||||
|
fsspec==2023.10.0
|
||||||
|
google-auth==2.26.2
|
||||||
|
google-auth-oauthlib==1.2.0
|
||||||
|
grpcio==1.60.0
|
||||||
|
huggingface-hub==0.20.2
|
||||||
|
idna==3.6
|
||||||
|
Jinja2==3.1.2
|
||||||
|
Markdown==3.5.2
|
||||||
|
MarkupSafe==2.1.3
|
||||||
|
mpmath==1.3.0
|
||||||
|
multidict==6.0.4
|
||||||
|
multiprocess==0.70.15
|
||||||
|
networkx==3.2.1
|
||||||
|
numpy==1.26.3
|
||||||
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
nvidia-cudnn-cu12==8.9.2.26
|
||||||
|
nvidia-cufft-cu12==11.0.2.54
|
||||||
|
nvidia-curand-cu12==10.3.2.106
|
||||||
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
nvidia-nccl-cu12==2.18.1
|
||||||
|
nvidia-nvjitlink-cu12==12.3.101
|
||||||
|
nvidia-nvtx-cu12==12.1.105
|
||||||
|
oauthlib==3.2.2
|
||||||
|
packaging==23.2
|
||||||
|
pandas==2.1.4
|
||||||
|
pillow==10.2.0
|
||||||
|
protobuf==4.23.4
|
||||||
|
psutil==5.9.7
|
||||||
|
pyarrow==14.0.2
|
||||||
|
pyarrow-hotfix==0.6
|
||||||
|
pyasn1==0.5.1
|
||||||
|
pyasn1-modules==0.3.0
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
pytz==2023.3.post1
|
||||||
|
PyYAML==6.0.1
|
||||||
|
regex==2023.12.25
|
||||||
|
requests==2.31.0
|
||||||
|
requests-oauthlib==1.3.1
|
||||||
|
rsa==4.9
|
||||||
|
safetensors==0.4.1
|
||||||
|
six==1.16.0
|
||||||
|
sympy==1.12
|
||||||
|
tensorboard==2.15.1
|
||||||
|
tensorboard-data-server==0.7.2
|
||||||
|
tensorboardX==2.6.2.2
|
||||||
|
tokenizers==0.15.0
|
||||||
|
torch==2.1.2
|
||||||
|
torchaudio==2.1.2
|
||||||
|
torchvision==0.16.2
|
||||||
|
tqdm==4.66.1
|
||||||
|
transformers==4.36.2
|
||||||
|
triton==2.1.0
|
||||||
|
typing_extensions==4.9.0
|
||||||
|
tzdata==2023.4
|
||||||
|
urllib3==2.1.0
|
||||||
|
Werkzeug==3.0.1
|
||||||
|
xxhash==3.4.1
|
||||||
|
yarl==1.9.4
|
||||||
7
run.sh
Executable file
7
run.sh
Executable file
@@ -0,0 +1,7 @@
|
|||||||
|
#!/usr/bin/env bash

# Select the CUDA devices visible to the training process
export CUDA_VISIBLE_DEVICES=0,1,2,4

# Run the training module in the background, redirecting all output to a log file
nohup python -m src.models.resizer.train.train > train_result_pred_height_v3.log 2>&1 &
|
||||||
33
src/globals.py
Normal file
33
src/globals.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Mean and standard deviation of the (grayscaled) formula images
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445


# ========================= Parameters for the TeXify model ============================= #

# Minimum / maximum width and height of input images
MIN_HEIGHT = 32
MAX_HEIGHT = 512
MIN_WIDTH = 32
MAX_WIDTH = 1280
# In LaTeX-OCR these are 32, 192, 32 and 672 respectively

# Density value used to render the images in the TeXify model's dataset
TEXIFY_INPUT_DENSITY = 80

# ============================================================================= #


# ========================= Parameters for the Resizer model ============================= #

# Density value used to render the images in the Resizer model's dataset
RESIZER_INPUT_DENSITY = 200

# Scale factor mapping a Resizer-density image height to TeXify-density height
LABEL_RATIO = 1.0 * TEXIFY_INPUT_DENSITY / RESIZER_INPUT_DENSITY

NUM_CLASSES = 1  # the model does regression (a sigmoid is appended at the end, predicting 0~1)
NUM_CHANNELS = 1  # single-channel input images (grayscale)

# Fixed image size used by the Resizer during training
RESIZER_IMG_SIZE = 448
# ============================================================================= #
|
||||||
0
src/inference.py
Normal file
0
src/inference.py
Normal file
74
src/models/resizer/inference.py
Normal file
74
src/models/resizer/inference.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
from .model.Resizer import Resizer
|
||||||
|
from .utils import preprocess_fn
|
||||||
|
|
||||||
|
from munch import Munch
|
||||||
|
|
||||||
|
|
||||||
|
# Default checkpoint kept from the original hard-coded call, so existing
# zero-argument callers keep working.
_DEFAULT_RESIZER_CKPT = '/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000'


def load_resizer(checkpoint_path: str = _DEFAULT_RESIZER_CKPT):
    """Load a pretrained Resizer checkpoint and switch it to eval mode.

    Args:
        checkpoint_path: directory of the HuggingFace checkpoint to load.
            Defaults to the path the original implementation hard-coded.

    Returns:
        The Resizer model, in evaluation mode.
    """
    model = Resizer.from_pretrained(checkpoint_path)
    model.eval()
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_teller():
    """Load the pix2tex (LaTeX-OCR) recognition model.

    NOTE(review): unfinished stub — it only builds the argument Munch and
    then falls through the ``...`` placeholder, implicitly returning None.
    """
    arguments = Munch(
        {
            'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml',
            'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth',
            'no_cuda': False,
            'no_resize': False
        }
    )
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def inference_v2(img: Image):
    """End-to-end inference (resize, then recognize) on an in-memory PIL image.

    NOTE(review): unfinished stub — the intended pipeline is only sketched in
    the commented-out code below; nothing is executed and None is returned.
    """
    # img = img.convert('RGB') if img.format == 'PNG' else img
    # processed_img = preprocess_fn({"pixel_values": [img]})

    # resizer = load_resizer(resizer_path)
    # inpu = torch.stack(processed_img['pixel_values'])
    # pred_size = resizer(inpu)

    # teller = load_teller(teller_path)
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def inference(args):
    """Run the Resizer on a single image file and print the raw prediction.

    Args:
        args: argparse namespace with ``image`` (path to the input image)
            and ``checkpoint`` (path to a HuggingFace Resizer checkpoint).
    """
    img = Image.open(args.image)
    # PNGs may carry an alpha channel; normalize them to plain RGB
    img = img.convert('RGB') if img.format == 'PNG' else img
    processed_img = preprocess_fn({"pixel_values": [img]})

    ckt_path = Path(args.checkpoint).resolve()
    model = Resizer.from_pretrained(ckt_path)
    model.eval()
    inpu = torch.stack(processed_img['pixel_values'])
    # Inference only: skip autograd bookkeeping (the original tracked gradients)
    with torch.no_grad():
        pred = model(inpu)
    print(pred)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run from the script's own directory so relative paths resolve,
    # restoring the caller's working directory afterwards
    cur_dirpath = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    parser = argparse.ArgumentParser()
    parser.add_argument('-img', '--image', type=str, required=True)
    parser.add_argument('-ckt', '--checkpoint', type=str, required=True)

    # NOTE(review): hard-coded debug arguments override the real command line;
    # switch to parser.parse_args() for normal CLI use
    args = parser.parse_args([
        '-img', '/home/lhy/code/TeXify/src/models/resizer/foo5_140h.jpg',
        '-ckt', '/home/lhy/code/TeXify/src/models/resizer/train/train_result_pred_height_v5'
    ])
    inference(args)

    os.chdir(cur_dirpath)
|
||||||
5
src/models/resizer/model/Resizer.py
Normal file
5
src/models/resizer/model/Resizer.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from transformers import ResNetForImageClassification
|
||||||
|
|
||||||
|
class Resizer(ResNetForImageClassification):
    """ResNet image-classification model reused as an image-size regressor.

    Thin subclass: all behavior is inherited from
    ResNetForImageClassification; the distinct class name mainly serves to
    tag checkpoints as Resizer checkpoints.
    """

    def __init__(self, config):
        super().__init__(config)
|
||||||
122
src/models/resizer/train/train.py
Normal file
122
src/models/resizer/train/train.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import os
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from transformers import (
|
||||||
|
ResNetConfig,
|
||||||
|
TrainingArguments,
|
||||||
|
Trainer
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..utils import preprocess_fn
|
||||||
|
from ..model.Resizer import Resizer
|
||||||
|
from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def train():
    """Fine-tune the Resizer model on the local dataset with the HF Trainer.

    Side effects: temporarily changes the working directory to this script's
    directory (restored at the end) and writes checkpoints/logs under
    ``output_dir``.
    """
    # Run relative to this script's directory so "./dataset" etc. resolve
    cur_dirpath = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    data = datasets.load_dataset("./dataset").shuffle(seed=42)
    data = data.rename_column("images", "pixel_values")
    # flatten_indices() returns a new dataset; the original call discarded
    # the result, making it a no-op — keep the returned dataset
    data = data.flatten_indices()
    data = data.with_transform(preprocess_fn)
    train_data, test_data = data['train'], data['test']

    # Training resumes from an existing checkpoint. A from-scratch run would
    # instead build the model from a config:
    #   config = ResNetConfig(num_channels=NUM_CHANNELS, num_labels=NUM_CLASSES,
    #                         img_size=RESIZER_IMG_SIZE)
    #   model = Resizer(config)
    # (The original built that model and immediately overwrote it — dead code.)
    model = Resizer.from_pretrained("/home/lhy/code/TeXify/src/models/resizer/train/train_result_pred_height_v4/checkpoint-213000")

    training_args = TrainingArguments(
        # resume_from_checkpoint="/home/lhy/code/TeXify/src/models/resizer/train/train_result_pred_height_v3/checkpoint-94500",
        max_grad_norm=1.0,
        # use_cpu=True,
        seed=42,  # random seed, for experiment reproducibility
        # data_seed=42,  # also fix the data sampler's sampling
        # full_determinism=True,  # make training fully deterministic (hurts training; debug only)

        output_dir='./train_result_pred_height_v5',  # output directory
        overwrite_output_dir=False,  # if the output dir exists, keep its previous contents
        report_to=["tensorboard"],  # log to TensorBoard,
        #+view with: tensorboard --logdir ./logs

        logging_dir=None,  # directory for TensorBoard log files
        log_level="info",
        logging_strategy="steps",  # log every fixed number of steps
        logging_steps=500,  # step interval between log records
        logging_nan_inf_filter=False,  # still record steps where loss is nan or inf

        num_train_epochs=50,  # total number of training epochs
        # max_steps=3,  # maximum number of training steps; if set,
        #+num_train_epochs is ignored (usually for debugging)

        # label_names = ['your_label_name'],  # label key in the data loader; defaults to 'labels'

        per_device_train_batch_size=55,  # batch size per GPU
        per_device_eval_batch_size=48*2,  # evaluation batch size per GPU
        auto_find_batch_size=False,  # auto-search a fitting batch size (exponential decay)

        optim = 'adamw_torch',  # many AdamW variants are available (more efficient than classic AdamW);
        #+when optim is set, no optimizer needs to be passed to Trainer
        lr_scheduler_type="cosine",  # learning-rate scheduler
        warmup_ratio=0.1,  # warmup fraction of the total training steps
        # warmup_steps=500,  # warmup steps
        weight_decay=0,  # weight decay
        learning_rate=5e-5,  # learning rate
        fp16=False,  # whether to train in 16-bit floating point
        gradient_accumulation_steps=1,  # gradient accumulation steps; emulates a large batch size when memory is tight
        gradient_checkpointing=False,  # if True, drop some forward intermediates (recomputed in backward) to save memory at the cost of forward time
        label_smoothing_factor=0.0,  # soft labels; 0 means disabled
        # debug='underflow_overflow',  # warn on numeric under/overflow during training (debug only)
        torch_compile=True,  # compile the model with torch.compile (better train/infer performance);
        #+ requires torch > 2.0 and is still not very stable
        # deepspeed='your_json_path',  # train with DeepSpeed; needs the path to ds_config.json
        #+ make sure ds_config.json matches the Trainer settings (learning rate, batch size, grad accumulation, ...)
        #+ otherwise very strange, hard-to-find bugs appear

        dataloader_pin_memory=True,  # speeds up data transfer between CPU and GPU
        dataloader_num_workers=16,  # multiprocess data loading (off by default)
        dataloader_drop_last=True,  # drop the last incomplete minibatch

        evaluation_strategy="steps",  # evaluation strategy: "steps" or "epoch"
        eval_steps=500,  # if evaluation_strategy="steps"
        # eval_steps=10,  # if evaluation_strategy="steps"

        save_strategy="steps",  # checkpoint-saving strategy
        save_steps=1500,  # step interval between checkpoints
        save_total_limit=5,  # maximum number of kept checkpoints; the oldest is deleted first

        load_best_model_at_end=True,  # load the best model when training finishes
        metric_for_best_model="eval_loss",  # metric used to pick the best model
        greater_is_better=False,  # a smaller metric value is better

        do_train=True,  # whether to train (mostly for debugging)
        do_eval=True,  # whether to evaluate (mostly for debugging)

        remove_unused_columns=True,  # drop unused columns/features (default True),
        #+making it easier to unpack inputs into the model's call function

        push_to_hub=False,  # upload to the Hub after training; run `huggingface-cli login` first (credentials are cached)
        hub_model_id="a_different_name",  # model name on the Hub;
        #+every checkpoint save is uploaded;
        #+after training, call trainer.push_to_hub() to upload the run's args and eval results
    )

    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_data,
        eval_dataset=test_data,
    )
    trainer.train()

    os.chdir(cur_dirpath)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    train()
|
||||||
1
src/models/resizer/utils/__init__.py
Normal file
1
src/models/resizer/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .preprocess import preprocess_fn
|
||||||
73
src/models/resizer/utils/preprocess.py
Normal file
73
src/models/resizer/utils/preprocess.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import torch
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
|
from PIL import Image, ImageChops
|
||||||
|
from ....globals import (
|
||||||
|
IMAGE_MEAN, IMAGE_STD,
|
||||||
|
LABEL_RATIO,
|
||||||
|
RESIZER_IMG_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_white_border(image: Image.Image) -> Image.Image:
    """Crop away the uniform white border surrounding *image*.

    Args:
        image: a PIL image in 'RGB', 'RGBA' or 'L' mode.

    Returns:
        The cropped image, or the original image unchanged when no non-white
        content is found. (The original implicitly returned None in the
        all-white case, which crashed callers that accessed ``.height``.)

    Raises:
        ValueError: for any other image mode.
    """
    if image.mode == 'RGB':
        bg_color = (255, 255, 255)
    elif image.mode == 'RGBA':
        bg_color = (255, 255, 255, 255)
    elif image.mode == 'L':
        bg_color = 255
    else:
        raise ValueError("Unsupported image mode")
    bg = Image.new(image.mode, image.size, bg_color)
    diff = ImageChops.difference(image, bg)
    # Amplify differences (scale 2.0, offset -100) so faint anti-aliased
    # pixels near the content survive the bounding-box detection
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        return image.crop(bbox)
    # Entirely white image: nothing to trim — return it as-is instead of None
    return image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Transform a batch of raw PIL images into normalized, padded tensors.

    Pipeline: trim white borders, derive the regression label from the
    trimmed height, then grayscale / resize / normalize each image and pad
    it on the right/bottom to a fixed RESIZER_IMG_SIZE square.

    Args:
        samples: batch dict with 'pixel_values' holding a list of PIL images.

    Returns:
        Dict with 'pixel_values' (list of float32 tensors, each
        1 x RESIZER_IMG_SIZE x RESIZER_IMG_SIZE) and 'labels' (list of
        floats: trimmed height scaled by LABEL_RATIO).
    """
    imgs = samples['pixel_values']
    imgs = [trim_white_border(img) for img in imgs]
    # Label = trimmed image height rescaled to the TeXify rendering density
    labels = [float(img.height * LABEL_RATIO) for img in imgs]

    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),
        v2.Grayscale(),
        v2.Resize(
            size=RESIZER_IMG_SIZE - 1,  # size must be strictly smaller than max_size
            interpolation=v2.InterpolationMode.BICUBIC,
            max_size=RESIZER_IMG_SIZE,
            antialias=True
        ),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
    ])
    imgs = transform(imgs)
    # Pad right/bottom so every tensor is exactly RESIZER_IMG_SIZE square
    # (img.shape is C x H x W after the transform)
    imgs = [
        v2.functional.pad(
            img,
            padding=[0, 0, RESIZER_IMG_SIZE - img.shape[2], RESIZER_IMG_SIZE - img.shape[1]]
        )
        for img in imgs
    ]

    res = {'pixel_values': imgs, 'labels': labels}
    return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":  # unit test
    import datasets
    data = datasets.load_dataset("/home/lhy/code/TeXify/src/models/resizer/train/dataset/dataset.py").shuffle(seed=42)
    data = data.with_transform(preprocess_fn)
    train_data, test_data = data['train'], data['test']

    # Trigger the transform on a small slice; 'pause' is a breakpoint anchor
    # for inspecting the result in a debugger
    inpu = train_data[:10]
    pause = 1
|
||||||
33
src/web.py
Normal file
33
src/web.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import streamlit as st
import time

from stqdm import stqdm

# Center the title by placing it in the middle of three columns
with st.columns(3)[1]:
    st.title(":rainbow[TexTeller] :sparkles:")

# Celebrate with balloons only once per browser session
if "start" not in st.session_state:
    st.balloons()
    st.session_state["start"] = 1

uploaded_file = st.file_uploader("",type=['jpg', 'png'])
st.divider()

if uploaded_file:
    st.image(uploaded_file, caption="Input image")

    # Demo progress bar rendered in the sidebar
    for _ in stqdm(range(10), st_container=st.sidebar):
        time.sleep(0.1)

    with st.spinner('Wait for it...'):
        time.sleep(5)

    st.success('Done!')


# Live-updating placeholder: counts 60 seconds, then shows a final message
with st.empty():
    for seconds in range(60):
        st.write(f"⏳ {seconds} seconds have passed")
        time.sleep(1)
    st.write("✔️ 1 minute over!")
|
||||||
Reference in New Issue
Block a user