[chore] exclude paddleocr directory from pre-commit hooks
BIN
texteller/models/ocr_model/train/augraphy_cache/image_0.png
Normal file
|
After Width: | Height: | Size: 14 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_1.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_10.png
Normal file
|
After Width: | Height: | Size: 4.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_11.png
Normal file
|
After Width: | Height: | Size: 8.5 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_12.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_13.png
Normal file
|
After Width: | Height: | Size: 3.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_14.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_15.png
Normal file
|
After Width: | Height: | Size: 7.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_16.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_17.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_18.png
Normal file
|
After Width: | Height: | Size: 5.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_19.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_2.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_20.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_21.png
Normal file
|
After Width: | Height: | Size: 16 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_22.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_23.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_24.png
Normal file
|
After Width: | Height: | Size: 10 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_25.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_26.png
Normal file
|
After Width: | Height: | Size: 9.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_27.png
Normal file
|
After Width: | Height: | Size: 8.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_28.png
Normal file
|
After Width: | Height: | Size: 15 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_29.png
Normal file
|
After Width: | Height: | Size: 7.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_3.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_4.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_5.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_6.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_7.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_8.png
Normal file
|
After Width: | Height: | Size: 6.0 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_9.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/0.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/1.png
Normal file
|
After Width: | Height: | Size: 8.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/10.png
Normal file
|
After Width: | Height: | Size: 6.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/11.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/12.png
Normal file
|
After Width: | Height: | Size: 5.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/13.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/14.png
Normal file
|
After Width: | Height: | Size: 2.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/15.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/16.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/17.png
Normal file
|
After Width: | Height: | Size: 2.6 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/18.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/19.png
Normal file
|
After Width: | Height: | Size: 2.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/2.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/20.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/21.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/22.png
Normal file
|
After Width: | Height: | Size: 3.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/23.png
Normal file
|
After Width: | Height: | Size: 3.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/24.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/25.png
Normal file
|
After Width: | Height: | Size: 2.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/26.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/27.png
Normal file
|
After Width: | Height: | Size: 3.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/28.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/29.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/3.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/30.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/31.png
Normal file
|
After Width: | Height: | Size: 4.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/32.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/33.png
Normal file
|
After Width: | Height: | Size: 1.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/34.png
Normal file
|
After Width: | Height: | Size: 3.2 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/4.png
Normal file
|
After Width: | Height: | Size: 5.7 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/5.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/6.png
Normal file
|
After Width: | Height: | Size: 4.8 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/7.png
Normal file
|
After Width: | Height: | Size: 4.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/8.png
Normal file
|
After Width: | Height: | Size: 2.5 KiB |
BIN
texteller/models/ocr_model/train/dataset/train/9.png
Normal file
|
After Width: | Height: | Size: 5.2 KiB |
@@ -0,0 +1,35 @@
|
||||
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
|
||||
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
|
||||
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
|
||||
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
|
||||
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
|
||||
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
|
||||
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
|
||||
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
|
||||
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
|
||||
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
|
||||
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
|
||||
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
|
||||
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
|
||||
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
|
||||
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
|
||||
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
|
||||
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
|
||||
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
|
||||
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
|
||||
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
|
||||
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
|
||||
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
|
||||
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
|
||||
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
|
||||
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
|
||||
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
|
||||
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
|
||||
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
|
||||
114
texteller/models/ocr_model/train/train.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import (
|
||||
tokenize_fn,
|
||||
collate_fn,
|
||||
img_train_transform,
|
||||
img_inf_transform,
|
||||
filter_fn,
|
||||
)
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    """Run a full training loop over the tokenized dataset with the HF Trainer.

    Args:
        model: TexTeller model instance (fresh or loaded from a checkpoint).
        tokenizer: tokenizer saved alongside checkpoints by the Trainer.
        train_dataset: tokenized training split (with image transforms attached).
        eval_dataset: tokenized evaluation split.
        collate_fn_with_tokenizer: data collator already bound to the tokenizer.
    """
    args = TrainingArguments(**CONFIG)
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )
    # None -> start from scratch; pass a checkpoint path here to resume.
    trainer.train(resume_from_checkpoint=None)
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
    """Evaluate the model with autoregressive generation and report BLEU.

    Uses Seq2SeqTrainer with predict_with_generate so metrics are computed on
    generated sequences (greedy decoding, single beam) rather than teacher-forced
    logits.

    Args:
        model: TexTeller model to evaluate.
        tokenizer: tokenizer providing pad/eos/bos token ids for generation.
        eval_dataset: tokenized evaluation split.
        collate_fn: data collator already bound to the tokenizer.
    """
    # Greedy decoding: one beam, no sampling; cap generation length.
    gen_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    # Extend the shared training CONFIG without mutating it.
    eval_config = dict(
        CONFIG,
        predict_with_generate=True,
        generation_config=gen_config,
    )
    trainer = Seq2SeqTrainer(
        model,
        Seq2SeqTrainingArguments(**eval_config),
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
    )
    print(trainer.evaluate())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Anchor all relative paths (dataset dir, output dir) at this script's folder.
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    # Alternative: load via a custom loader script instead of the imagefolder builder.
    # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']

    # Drop images too small to be useful training samples.
    dataset = dataset.filter(
        lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
    )
    dataset = dataset.shuffle(seed=42).flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # To use your own tokenizer, pass its path instead:
    #   tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')

    # Drop samples whose formulas the tokenizer cannot handle.
    dataset = dataset.filter(partial(filter_fn, tokenizer=tokenizer), num_proc=8)

    tokenized_dataset = dataset.map(
        partial(tokenize_fn, tokenizer=tokenizer),
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=8,
    )

    # Split dataset into train and eval, ratio 9:1.
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset['train'].with_transform(img_train_transform)
    eval_dataset = split_dataset['test'].with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch.
    model = TexTeller()
    # Or start from a pre-trained TexTeller checkpoint instead:
    #   model = TexTeller.from_pretrained('/path/to/your/model_checkpoint')

    enable_train = True
    enable_evaluate = False
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||
31
texteller/models/ocr_model/train/training_args.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Keyword arguments unpacked into transformers.TrainingArguments (and, with
# generation options added, Seq2SeqTrainingArguments) by train.py.
CONFIG = {
    # --- Reproducibility & hardware ---
    "seed": 42,                          # Random seed for reproducibility
    "use_cpu": False,                    # Run on CPU (handy for debugging before moving to GPU)

    # --- Optimization ---
    "learning_rate": 5e-5,               # Peak learning rate
    "num_train_epochs": 10,              # Total number of training epochs
    "per_device_train_batch_size": 4,    # Training batch size per GPU
    "per_device_eval_batch_size": 8,     # Evaluation batch size per GPU
    "optim": "adamw_torch",              # Optimizer implementation
    "lr_scheduler_type": "cosine",       # Learning-rate schedule
    "warmup_ratio": 0.1,                 # Fraction of total steps spent warming lr up from 0
    "max_grad_norm": 1.0,                # Gradient clipping threshold (default 1.0)
    "gradient_accumulation_steps": 1,    # Accumulate gradients to emulate a larger batch size

    # --- Precision & compilation ---
    "fp16": False,                       # float16 training (loss can explode; generally not recommended)
    "bf16": False,                       # bfloat16 training (recommended where hardware supports it)
    "jit_mode_eval": False,              # PyTorch JIT trace for eval (model must be static)
    "torch_compile": False,              # torch.compile the model for train/inference speedups

    # --- Checkpointing & logging ---
    "output_dir": "train_result",        # Output directory
    "overwrite_output_dir": False,       # Keep existing contents of the output directory
    "report_to": ["tensorboard"],        # Send logs to TensorBoard
    "save_strategy": "steps",            # Checkpoint by step count
    "save_steps": 500,                   # Steps between checkpoints; a float in (0, 1) means a ratio of total steps
    "save_total_limit": 5,               # Keep at most this many checkpoints (oldest deleted first)
    "logging_strategy": "steps",         # Log by step count
    "logging_steps": 500,                # Steps between log records
    "logging_nan_inf_filter": False,     # Still record steps where loss is nan/inf

    # --- Data loading ---
    "dataloader_pin_memory": True,       # Speeds up host-to-GPU transfers
    "dataloader_num_workers": 1,         # Loader workers; a common rule of thumb is 4 * num_gpus

    # --- Evaluation ---
    "evaluation_strategy": "steps",      # Evaluate by "steps" or "epoch"
    "eval_steps": 500,                   # Steps between evaluations when evaluation_strategy="steps"

    # Required so image columns survive until the collator; do not change
    # unless you really know what you are doing.
    "remove_unused_columns": False,
}
|
||||