27 Commits
prev ... v1.0.1

Author SHA1 Message Date
OleehyO
a329453873 [deps] Grouped deps & setup vcs 2025-04-21 04:48:29 +00:00
OleehyO
42db737ae6 [chore] Update README.md 2025-04-21 04:47:53 +00:00
OleehyO
29f6f8960d [CI/CD] Setup complete workflow 2025-04-21 03:00:06 +00:00
OleehyO
6a67017962 [chore] Setup vcs and deps 2025-04-21 02:41:46 +00:00
OleehyO
d4a32094a7 [chore] Ignore images 2025-04-21 02:41:06 +00:00
OleehyO
3eceedd96b [docs] Set up documentation structure with API reference 2025-04-21 02:38:36 +00:00
OleehyO
8aad53e079 Upload logo.svg 2025-04-21 02:37:28 +00:00
OleehyO
df03525fd7 [feat] Support dynamic package vcs 2025-04-21 02:36:13 +00:00
OleehyO
8a000edb7b [docs] Add comprehensive function documentation 2025-04-21 02:34:56 +00:00
OleehyO
375ad4a4cb Add globals test 2025-04-21 02:32:05 +00:00
OleehyO
59df576c85 [test] Init 2025-04-19 16:36:48 +00:00
OleehyO
d4cef5135f [feat] Add texteller training script 2025-04-19 16:36:43 +00:00
OleehyO
25142c34f1 [CI] Update ruff hook 2025-04-19 14:32:32 +00:00
OleehyO
0cba17d9ce [refactor] Init 2025-04-19 14:32:28 +00:00
OleehyO
e0cbf2c99f [chore] Cleanup 2025-04-17 07:08:47 +00:00
OleehyO
4e2740ada0 [feat] Support n-gram stop criteria 2025-04-02 03:23:27 +00:00
OleehyO
509cb75dfa [deps] Change onnx-gpu to manually install 2025-04-02 02:48:23 +00:00
三洋三洋
5673adecff [feat][formatter] Integrate LaTeX formatter for improved formula readability
- Add latex_formatter.py based on tex-fmt (https://github.com/WGUNDERWOOD/tex-fmt)
- Update to_katex.py to use the new formatter
- Enhance LaTeX formula output with better formatting and readability

This integration helps make generated LaTeX formulas more readable and
maintainable by applying consistent formatting rules.
2025-03-01 00:55:41 +08:00
三洋三洋
5cda58e8fc [chore] Ignore ruff lint E741 2025-03-01 00:54:57 +08:00
三洋三洋
c35b0e9f53 [fix] Add project prefix 2025-02-28 23:38:12 +08:00
三洋三洋
f023c6741e [feat] Remove bold style 2025-02-28 23:38:12 +08:00
三洋三洋
be9e32b439 [deps] Add ray serve & python-multipart 2025-02-28 23:37:53 +08:00
三洋三洋
cd9e4146e0 [chore] Add build system and pakage location 2025-02-28 23:18:06 +08:00
三洋三洋
3e9c6c00b8 [chore] Add python related rules 2025-02-28 23:18:03 +08:00
三洋三洋
0e5c5fd706 [chore] Remove unsed files 2025-02-28 20:54:51 +08:00
三洋三洋
4d3714bb4b [chore] exclude paddleocr directory from pre-commit hooks 2025-02-28 20:01:54 +08:00
三洋三洋
3296077461 [chore] Setup project infrastructure 2025-02-28 20:01:52 +08:00
133 changed files with 4222 additions and 3108 deletions

36
.github/workflows/deploy-doc.yml vendored Normal file
View File

@@ -0,0 +1,36 @@
name: "Sphinx: Render docs"
on: push
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install uv
run: pip install uv
- name: Install docs dependencies
run: uv pip install --system -e ".[docs]"
- name: Build HTML
run: |
cd docs
make html
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: html-docs
path: docs/build/html/
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if: github.ref == 'refs/heads/main'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/build/html/

37
.github/workflows/pr-welcome.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
name: PR Welcome Bot
on:
pull_request:
types: [opened]
jobs:
welcome:
runs-on: ubuntu-latest
steps:
- name: Post Welcome Comment
uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const prNumber = context.issue.number;
const prAuthor = context.payload.pull_request.user.login;
const welcomeMessage = `
👋 Hello @${prAuthor}, thank you for contributing to this project! 🎉
We've received your Pull Request and the team will review it as soon as possible.
In the meantime, please ensure:
- [ ] Your code follows the project's coding style
- [ ] Relevant tests have been added and are passing
- [ ] Documentation has been updated if needed
If you have any questions, feel free to ask here. Happy coding! 😊
`;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: welcomeMessage
});

36
.github/workflows/publish.yml vendored Normal file
View File

@@ -0,0 +1,36 @@
name: Publish to PyPI
on:
push:
branches:
- 'main'
tags:
- 'v*'
jobs:
publish:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Build package with uv
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish --token $UV_PUBLISH_TOKEN

27
.github/workflows/python-lint.yml vendored Normal file
View File

@@ -0,0 +1,27 @@
name: Python Linting
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit
- name: Run pre-commit
run: pre-commit run --all-files

35
.github/workflows/test.yaml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Run Tests with Pytest
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install dependencies
run: |
uv sync --group test
- name: Run tests with pytest
run: |
uv run pytest tests/

369
.gitignore vendored
View File

@@ -1,28 +1,349 @@
**/.DS_Store
**/__pycache__
# Created by https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm,python
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,visualstudiocode,pycharm,python
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### macOS Patch ###
# iCloud generated files
*.icloud
### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr
# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/
# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml
# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/
# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$
# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml
# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### VisualStudioCode ###
**/.vscode
**/pyrightconfig.json
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
**/dist
**/build
*.egg-info
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm,python
uv.lock
**/train_result
**/ckpt
**/ckpts
**/*.safetensor
**/trocr-*
**/large*.onnx
**/rtdetr_r50vd_6x_coco.onnx
**/*cache
**/.cache
**/tmp
**/tmp*
**/log
**/logs
**/data
**/*.bin
**/*.onnx
**/*.png
**/*.jpg
**/augraphy_cache

22
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,22 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.6
hooks:
- id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml]
exclude: ^texteller/models/thrid_party/paddleocr/
- id: ruff-format
args: [--config=pyproject.toml]
exclude: ^texteller/models/thrid_party/paddleocr/
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.10

View File

@@ -1 +0,0 @@
include README.md

213
README.md
View File

@@ -2,51 +2,27 @@
<div align="center">
<h1>
<img src="./assets/fire.svg" width=30, height=30>
<img src="./assets/fire.svg" width=30, height=30>
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="./assets/fire.svg" width=30, height=30>
</h1>
<!-- <p align="center">
🤗 <a href="https://huggingface.co/OleehyO/TexTeller"> Hugging Face </a>
</p> -->
[![](https://img.shields.io/badge/License-Apache_2.0-blue.svg?logo=github)](https://opensource.org/licenses/Apache-2.0)
[![](https://img.shields.io/badge/API-Docs-orange.svg?logo=read-the-docs)](https://oleehyo.github.io/TexTeller/)
[![](https://img.shields.io/badge/docker-pull-green.svg?logo=docker)](https://hub.docker.com/r/oleehyo/texteller)
[![](https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg?logo=huggingface)](https://huggingface.co/datasets/OleehyO/latex-formulas)
[![](https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg?logo=huggingface)](https://huggingface.co/OleehyO/TexTeller)
[![](https://img.shields.io/badge/License-Apache_2.0-blue.svg?logo=github)](https://opensource.org/licenses/Apache-2.0)
</div>
<!-- <p align="center">
<a href="https://opensource.org/licenses/Apache-2.0">
<img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License">
</a>
<a href="https://github.com/OleehyO/TexTeller/issues">
<img src="https://img.shields.io/badge/Maintained%3F-yes-green.svg" alt="Maintenance">
</a>
<a href="https://github.com/OleehyO/TexTeller/pulls">
<img src="https://img.shields.io/badge/Contributions-welcome-brightgreen.svg?style=flat" alt="Contributions welcome">
</a>
<a href="https://huggingface.co/datasets/OleehyO/latex-formulas">
<img src="https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg" alt="Data">
</a>
<a href="https://huggingface.co/OleehyO/TexTeller">
<img src="https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg" alt="Weights">
</a>
</p> -->
https://github.com/OleehyO/TexTeller/assets/56267907/532d1471-a72e-4960-9677-ec6c19db289f
TexTeller is an end-to-end formula recognition model based on [TrOCR](https://arxiv.org/abs/2109.10282), capable of converting images into corresponding LaTeX formulas.
TexTeller is an end-to-end formula recognition model, capable of converting images into corresponding LaTeX formulas.
TexTeller was trained with **80M image-formula pairs** (the previous dataset can be obtained [here](https://huggingface.co/datasets/OleehyO/latex-formulas)). Compared to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR), which used a 100K dataset, TexTeller has **stronger generalization abilities** and **higher accuracy**, covering most use cases.
>[!NOTE]
> If you would like to provide feedback or suggestions for this project, feel free to start a discussion in the [Discussions section](https://github.com/OleehyO/TexTeller/discussions).
>
> Additionally, if you find this project helpful, please don't forget to give it a star⭐🙏
---
@@ -55,15 +31,12 @@ TexTeller was trained with **80M image-formula pairs** (previous dataset can be
<td>
## 🔖 Table of Contents
- [Change Log](#-change-log)
- [Getting Started](#-getting-started)
- [Web Demo](#-web-demo)
- [Server](#-server)
- [Python API](#-python-api)
- [Formula Detection](#-formula-detection)
- [API Usage](#-api-usage)
- [Training](#-training)
- [Plans](#-plans)
- [Stargazers over time](#-stargazers-over-time)
- [Contributors](#-contributors)
</td>
<td>
@@ -76,18 +49,9 @@ TexTeller was trained with **80M image-formula pairs** (previous dataset can be
</figcaption>
</figure>
<div>
<p>
Thanks to the
<i>
Super Computing Platform of Beijing University of Posts and Telecommunications
</i>
for supporting this work😘
</p>
<!-- <img src="assets/scss.png" width="200"> -->
</div>
</div>
</td>
</tr>
</table>
@@ -110,153 +74,118 @@ TexTeller was trained with **80M image-formula pairs** (previous dataset can be
## 🚀 Getting Started
1. Clone the repository:
```bash
git clone https://github.com/OleehyO/TexTeller
```
2. Install the project's dependencies:
1. Install the project's dependencies:
```bash
pip install texteller
```
3. Enter the `src/` directory and run the following command in the terminal to start inference:
2. If you are using a CUDA backend, you may need to install `onnxruntime-gpu`:
```bash
python inference.py -img "/path/to/image.{jpg,png}"
# use --inference-mode option to enable GPU(cuda or mps) inference
#+e.g. python inference.py -img "img.jpg" --inference-mode cuda
pip install texteller[onnxruntime-gpu]
```
> The first time you run it, the required checkpoints will be downloaded from Hugging Face.
### Paragraph Recognition
As demonstrated in the video, TexTeller is also capable of recognizing entire text paragraphs. Although TexTeller has general text OCR capabilities, we still recommend using paragraph recognition for better results:
1. [Download the weights](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true) of the formula detection model to the`src/models/det_model/model/`directory
2. Run `inference.py` in the `src/` directory and add the `-mix` option, the results will be output in markdown format.
3. Run the following command to start inference:
```bash
python inference.py -img "/path/to/image.{jpg,png}" -mix
texteller inference "/path/to/image.{jpg,png}"
```
TexTeller uses the lightweight [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) model by default for recognizing both Chinese and English text. You can try using a larger model to achieve better recognition results for both Chinese and English:
| Checkpoints | Model Description | Size |
|-------------|-------------------| ---- |
| [ch_PP-OCRv4_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx?download=true) | **Default detection model**, supports Chinese-English text detection | 4.70M |
| [ch_PP-OCRv4_server_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_det.onnx?download=true) | High accuracy model, supports Chinese-English text detection | 115M |
| [ch_PP-OCRv4_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_rec.onnx?download=true) | **Default recognition model**, supports Chinese-English text recognition | 10.80M |
| [ch_PP-OCRv4_server_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx?download=true) | High accuracy model, supports Chinese-English text recognition | 90.60M |
Place the weights of the recognition/detection model in the `det/` or `rec/` directories within `src/models/third_party/paddleocr/checkpoints/`, and rename them to `default_model.onnx`.
> [!NOTE]
> Paragraph recognition cannot restore the structure of a document; it can only recognize its content.
> See `texteller inference --help` for more details
## 🌐 Web Demo
Go to the `src/` directory and run the following command:
Run the following command:
```bash
./start_web.sh
texteller web
```
Enter `http://localhost:8501` in a browser to view the web demo.
> [!NOTE]
> 1. For Windows users, please run the `start_web.bat` file.
> 2. When using onnxruntime + GPU for inference, you need to install onnxruntime-gpu.
> Paragraph recognition cannot restore the structure of a document; it can only recognize its content.
## 🔍 Formula Detection
## 🖥️ Server
TexTeller's formula detection model is trained on 3,415 images of Chinese educational materials (with over 130 layouts) and 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865), and it supports formula detection across entire images.
<div align="center">
<img src="./assets/det_rec.png" width=250>
</div>
1. Download the model weights and place them in `src/models/det_model/model/` [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)].
2. Run the following command in the `src/` directory, and the results will be saved in `src/subimages/`
<details>
<summary>Advanced: batch formula recognition</summary>
After **formula detection**, run the following command in the `src/` directory:
```shell
python rec_infer_from_crop_imgs.py
```
This will use the results of the previous formula detection to perform batch recognition on all cropped formulas, saving the recognition results as txt files in `src/results/`.
</details>
## 📡 API Usage
We use [ray serve](https://github.com/ray-project/ray) to provide an API interface for TexTeller, allowing you to integrate TexTeller into your own projects. To start the server, you first need to enter the `src/` directory and then run the following command:
We use [ray serve](https://github.com/ray-project/ray) to provide an API server for TexTeller. To start the server, run the following command:
```bash
python server.py
texteller launch
```
| Parameter | Description |
| --------- | -------- |
| `-ckpt` | The path to the weights file,*default is TexTeller's pretrained weights*. |
| `-tknz` | The path to the tokenizer,*default is TexTeller's tokenizer*. |
| `-port` | The server's service port,*default is 8000*. |
| `--inference-mode` | Whether to use "cuda" or "mps" for inference,*default is "cpu"*. |
| `--num_beams` | The number of beams for beam search,*default is 1*. |
| `--num_replicas` | The number of service replicas to run on the server,*default is 1 replica*. You can use more replicas to achieve greater throughput.|
| `--ncpu_per_replica` | The number of CPU cores used per service replica,*default is 1*.|
| `--ngpu_per_replica` | The number of GPUs used per service replica,*default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU to share the GPU, thereby improving GPU utilization. (Note, if --num_replicas is 2, --ngpu_per_replica is 0.7, then 2 GPUs must be available) |
| `-onnx` | Perform inference using Onnx Runtime, *disabled by default* |
| `-p` | The server's service port,*default is 8000*. |
| `--num-replicas` | The number of service replicas to run on the server,*default is 1 replica*. You can use more replicas to achieve greater throughput.|
| `--ncpu-per-replica` | The number of CPU cores used per service replica,*default is 1*.|
| `--ngpu-per-replica` | The number of GPUs used per service replica,*default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU to share the GPU, thereby improving GPU utilization. (Note, if --num_replicas is 2, --ngpu_per_replica is 0.7, then 2 GPUs must be available) |
| `--num-beams` | The number of beams for beam search,*default is 1*. |
| `--use-onnx` | Perform inference using Onnx Runtime, *disabled by default* |
> [!NOTE]
> A client demo can be found at `src/client/demo.py`, you can refer to `demo.py` to send requests to the server
To send requests to the server:
```python
# client_demo.py
import requests
server_url = "http://127.0.0.1:8000/predict"
img_path = "/path/to/your/image"
with open(img_path, 'rb') as img:
files = {'img': img}
response = requests.post(server_url, files=files)
print(response.text)
```
## 🐍 Python API
We provide several easy-to-use Python APIs for formula OCR scenarios. Please refer to our [documentation](https://oleehyo.github.io/TexTeller/) to learn about the corresponding API interfaces and usage.
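For example, a minimal end-to-end sketch (mirroring the Quick Start in the documentation; the image path is a placeholder):
```python
from texteller import load_model, load_tokenizer, img2latex

# Load the recognition model and its tokenizer
# (checkpoints are downloaded from Hugging Face on first use)
model = load_model(use_onnx=False)
tokenizer = load_tokenizer()

# img2latex takes a list of image paths and returns a list of LaTeX strings
latex = img2latex(model, tokenizer, ["/path/to/formula.png"])[0]
print(latex)
```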
## 🔍 Formula Detection
TexTeller's formula detection model is trained on 3,415 images of Chinese materials and 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865).
<div align="center">
<img src="./assets/det_rec.png" width=250>
</div>
We provide a formula detection interface in the Python API. Please refer to our [API documentation](https://oleehyo.github.io/TexTeller/) for more details.
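A minimal detection sketch is shown below; the argument order and return value of `latex_detect` are assumptions here, so check the API documentation for the exact signature:
```python
from texteller import load_latexdet_model
from texteller.api.detection import latex_detect

# Load the formula detection model
detector = load_latexdet_model()

# Hypothetical call: detect formula regions in a full-page image.
# The structure of the returned detections (e.g. bounding boxes) is an assumption.
detections = latex_detect("/path/to/page.png", detector)
print(detections)
```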
## 🏋️‍♂️ Training
### Dataset
Please setup your environment before training:
We provide an example dataset in the `src/models/ocr_model/train/dataset/` directory, you can place your own images in the `images/` directory and annotate each image with its corresponding formula in `formulas.jsonl`.
After preparing your dataset, you need to **change the `DIR_URL` variable to your own dataset's path** in `**/train/dataset/loader.py`
### Retraining the Tokenizer
If you are using a different dataset, you might need to retrain the tokenizer to obtain a different vocabulary. After configuring your dataset, you can train your own tokenizer with the following command:
1. In `src/models/tokenizer/train.py`, change `new_tokenizer.save_pretrained('./your_dir_name')` to your custom output directory
> If you want to use a different vocabulary size (default 15K), you need to change the `VOCAB_SIZE` variable in `src/models/globals.py`
>
2. **In the `src/` directory**, run the following command:
1. Install the dependencies for training:
```bash
python -m models.tokenizer.train
pip install texteller[train]
```
2. Clone the repository:
```bash
git clone https://github.com/OleehyO/TexTeller.git
```
### Dataset
We provide an example dataset in the `examples/train_texteller/dataset/train` directory; you can place your own training data there, following the format of the example dataset.
### Training the Model
1. Modify `num_processes` in `src/train_config.yaml` to match the number of GPUs available for training (default is 1).
2. In the `src/` directory, run the following command:
In the `examples/train_texteller/` directory, run the following command:
```bash
accelerate launch --config_file ./train_config.yaml -m models.ocr_model.train.train
accelerate launch train.py
```
You can set your own tokenizer and checkpoint paths in `src/models/ocr_model/train/train.py` (refer to `train.py` for more information). If you are using the same architecture and vocabulary as TexTeller, you can also fine-tune TexTeller's default weights with your own dataset.
In `src/globals.py` and `src/models/ocr_model/train/train_args.py`, you can change the model's architecture and training hyperparameters.
> [!NOTE]
> Our training scripts use the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, so you can refer to their [documentation](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments) for more details and configurations on training parameters.
Training arguments can be adjusted in [`train_config.yaml`](./examples/train_texteller/train_config.yaml).
## 📅 Plans
@@ -266,13 +195,11 @@ In `src/globals.py` and `src/models/ocr_model/train/train_args.py`, you can chan
- [X] ~~Handwritten formulas support~~
- [ ] PDF document recognition
- [ ] Inference acceleration
- [ ] ...
## ⭐️ Stargazers over time
[![Stargazers over time](https://starchart.cc/OleehyO/TexTeller.svg?variant=adaptive)](https://starchart.cc/OleehyO/TexTeller)
## 👥 Contributors
<a href="https://github.com/OleehyO/TexTeller/graphs/contributors">

View File

@@ -1,52 +1,28 @@
📄 <a href="../README.md">English</a> | 中文
📄 中文 | [English](./README.md)
<div align="center">
<h1>
<img src="./fire.svg" width=30, height=30>
<img src="./fire.svg" width=30, height=30>
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="./fire.svg" width=30, height=30>
</h1>
<!-- <p align="center">
🤗 <a href="https://huggingface.co/OleehyO/TexTeller"> Hugging Face </a>
</p> -->
[![](https://img.shields.io/badge/License-Apache_2.0-blue.svg?logo=github)](https://opensource.org/licenses/Apache-2.0)
[![](https://img.shields.io/badge/docker-pull-green.svg?logo=docker)](https://hub.docker.com/r/oleehyo/texteller)
[![](https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg?logo=huggingface)](https://huggingface.co/datasets/OleehyO/latex-formulas)
[![](https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg?logo=huggingface)](https://huggingface.co/OleehyO/TexTeller)
[![](https://img.shields.io/badge/API-文档-orange.svg?logo=read-the-docs)](https://oleehyo.github.io/TexTeller/)
[![](https://img.shields.io/badge/docker-镜像-green.svg?logo=docker)](https://hub.docker.com/r/oleehyo/texteller)
[![](https://img.shields.io/badge/数据-Texteller1.0-brightgreen.svg?logo=huggingface)](https://huggingface.co/datasets/OleehyO/latex-formulas)
[![](https://img.shields.io/badge/权重-Texteller3.0-yellow.svg?logo=huggingface)](https://huggingface.co/OleehyO/TexTeller)
[![](https://img.shields.io/badge/协议-Apache_2.0-blue.svg?logo=github)](https://opensource.org/licenses/Apache-2.0)
</div>
<!-- <p align="center">
<a href="https://opensource.org/licenses/Apache-2.0">
<img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License">
</a>
<a href="https://github.com/OleehyO/TexTeller/issues">
<img src="https://img.shields.io/badge/Maintained%3F-yes-green.svg" alt="Maintenance">
</a>
<a href="https://github.com/OleehyO/TexTeller/pulls">
<img src="https://img.shields.io/badge/Contributions-welcome-brightgreen.svg?style=flat" alt="Contributions welcome">
</a>
<a href="https://huggingface.co/datasets/OleehyO/latex-formulas">
<img src="https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg" alt="Data">
</a>
<a href="https://huggingface.co/OleehyO/TexTeller">
<img src="https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg" alt="Weights">
</a>
</p> -->
https://github.com/OleehyO/TexTeller/assets/56267907/532d1471-a72e-4960-9677-ec6c19db289f
TexTeller是一个基于[TrOCR](https://arxiv.org/abs/2109.10282)的端到端公式识别模型,可以把图片转换为对应的latex公式
TexTeller 是一个端到端公式识别模型,能够将图像转换为对应的 LaTeX 公式
TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以在[这里](https://huggingface.co/datasets/OleehyO/latex-formulas)获取),相比于[LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)(使用了一个100K的数据集),TexTeller具有**更强的泛化能力**以及**更高的准确率**,可以覆盖大部分的使用场景。
TexTeller 使用 **8千万图像-公式对** 进行训练(前代数据集可在此[获取](https://huggingface.co/datasets/OleehyO/latex-formulas)),相较 [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) 使用的 10 万量级数据集,TexTeller 具有**更强的泛化能力**和**更高的准确率**,覆盖绝大多数使用场景。
> [!NOTE]
> 如果您想为本项目提供一些反馈或建议,欢迎在[Discussions版块](https://github.com/OleehyO/TexTeller/discussions)发起讨论。
>
> 另外如果您觉得这个项目对您有帮助请不要忘记点亮上方的Star⭐🙏
>[!NOTE]
> 如果您想为本项目提供反馈或建议,欢迎前往 [讨论区](https://github.com/OleehyO/TexTeller/discussions) 发起讨论。
---
@@ -55,17 +31,12 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
<td>
## 🔖 目录
- [变更信息](#-变更信息)
- [开搞](#-开搞)
- [常见问题无法连接到Hugging Face](#-常见问题无法连接到hugging-face)
- [快速开始](#-快速开始)
- [网页演示](#-网页演示)
- [服务部署](#-服务部署)
- [Python接口](#-python接口)
- [公式检测](#-公式检测)
- [API调用](#-api调用)
- [训练](#-训练)
- [计划](#-计划)
- [观星曲线](#-观星曲线)
- [贡献者](#-贡献者)
- [模型训练](#-模型训练)
</td>
<td>
@@ -74,17 +45,10 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
<figure>
<img src="cover.png" width="800">
<figcaption>
<p>可以被TexTeller识别的图</p>
<p>TexTeller识别的图像示例</p>
</figcaption>
</figure>
<div>
<p>
感谢
<i>
北京邮电大学超算平台
</i>
为本项工作提供支持😘
</p>
</div>
</div>
@@ -92,221 +56,149 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
</tr>
</table>
## 🔄 变更信息
## 📮 更新日志
- 📮[2024-06-06] **TexTeller3.0**发布! 训练数据集增加到了**80M**(相较于TexTeller2.0增加了**10倍**,并且改善了数据多样性)。新版的TexTeller具有以下新的特性:
- 支持扫描图片、手写公式以及中英文混合的公式。
- 在打印图片上具有通用的中英文识别能力。
- [2024-06-06] **TexTeller3.0 发布!** 训练数据增至 **8千万**(是 TexTeller2.0 的 **10倍**,并提升了数据多样性)。TexTeller3.0 新特性:
- 📮[2024-05-02] 支持**段落识别**。
- 支持扫描件、手写公式、中英文混合公式识别
- 📮[2024-04-12] **公式检测模型**发布!
- 支持印刷体中英文混排公式的OCR识别
- 📮[2024-03-25] TexTeller2.0发布!TexTeller2.0的训练数据增大到了7.5M(相较于TexTeller1.0增加了~15倍,并且数据质量也有所改善)。训练后的TexTeller2.0在测试集中展现出了更加优越的性能,尤其在生僻符号、复杂多行、矩阵的识别场景中。
- [2024-05-02] 支持**段落识别**功能
> 在[这里](./test.pdf)有更多的测试图片以及各家识别模型的横向对比。
- [2024-04-12] **公式检测模型**发布!
## 🚀 开搞
- [2024-03-25] TexTeller2.0 发布!TexTeller2.0 的训练数据增至750万(是前代的15倍),并提升了数据质量。训练后的 TexTeller2.0 在测试集中展现了**更优性能**,特别是在识别罕见符号、复杂多行公式和矩阵方面表现突出。
1. 克隆本仓库:
> [此处](./assets/test.pdf) 展示了更多测试图像及各类识别模型的横向对比。
```bash
git clone https://github.com/OleehyO/TexTeller
```
## 🚀 快速开始
2. 安装项目依赖包:
1. 安装项目依赖
```bash
pip install texteller
```
3. 进入`src/`目录,在终端运行以下命令进行推理:
2. 若使用 CUDA 后端,可能需要安装 `onnxruntime-gpu`
```bash
python inference.py -img "/path/to/image.{jpg,png}"
# use --inference-mode option to enable GPU(cuda or mps) inference
#+e.g. python inference.py -img "img.jpg" --inference-mode cuda
pip install texteller[onnxruntime-gpu]
```
> 第一次运行时会在Hugging Face上下载所需要的权重
### 段落识别
如演示视频所示,TexTeller还可以识别整个文本段落。尽管TexTeller具备通用的文本OCR能力,但我们仍然建议使用段落识别来获得更好的效果:
1. 下载公式检测模型的权重到`src/models/det_model/model/`目录 [[链接](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)]
2. 在`src/`目录下运行`inference.py`并添加`-mix`选项,结果会以markdown的格式进行输出。
3. 运行以下命令开始推理:
```bash
python inference.py -img "/path/to/image.{jpg,png}" -mix
texteller inference "/path/to/image.{jpg,png}"
```
TexTeller默认使用轻量的[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)模型来识别中英文,可以尝试使用更大的模型来获取更好的中英文识别效果:
| 权重 | 描述 | 尺寸 |
|-------------|-------------------| ---- |
| [ch_PP-OCRv4_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx?download=true) | **默认的检测模型**,支持中英文检测 | 4.70M |
| [ch_PP-OCRv4_server_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_det.onnx?download=true) | 高精度模型,支持中英文检测 | 115M |
| [ch_PP-OCRv4_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_rec.onnx?download=true) | **默认的识别模型**,支持中英文识别 | 10.80M |
| [ch_PP-OCRv4_server_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx?download=true) | 高精度模型,支持中英文识别 | 90.60M |
把识别/检测模型的权重放在`src/models/third_party/paddleocr/checkpoints/`
下的`det/`或`rec/`目录中,然后重命名为`default_model.onnx`。
> [!NOTE]
> 段落识别只能识别文档内容,无法还原文档的结构。
## ❓ 常见问题无法连接到Hugging Face
默认情况下会在Hugging Face中下载模型权重,**如果你的远端服务器无法连接到Hugging Face**,你可以通过以下命令进行加载:
1. 安装huggingface hub包
```bash
pip install -U "huggingface_hub[cli]"
```
2. 在能连接Hugging Face的机器上下载模型权重:
```bash
huggingface-cli download \
OleehyO/TexTeller \
--repo-type model \
--local-dir "your/dir/path" \
--local-dir-use-symlinks False
```
3. 把包含权重的目录上传远端服务器,然后把 `src/models/ocr_model/model/TexTeller.py`中的 `REPO_NAME = 'OleehyO/TexTeller'`修改为 `REPO_NAME = 'your/dir/path'`
<!-- 如果你还想在训练模型时开启evaluate你需要提前下载metric脚本并上传远端服务器
1. 在能连接Hugging Face的机器上下载metric脚本
```bash
huggingface-cli download \
evaluate-metric/google_bleu \
--repo-type space \
--local-dir "your/dir/path" \
--local-dir-use-symlinks False
```
2. 把这个目录上传远端服务器,并在 `TexTeller/src/models/ocr_model/utils/metrics.py`中把 `evaluate.load('google_bleu')`改为 `evaluate.load('your/dir/path/google_bleu.py')` -->
> 更多参数请查看 `texteller inference --help`
## 🌐 网页演示
进入 `src/` 目录,运行以下命令
运行命令
```bash
./start_web.sh
texteller web
```
在浏览器输入 `http://localhost:8501`就可以看到web demo
在浏览器输入 `http://localhost:8501` 查看网页演示。
> [!NOTE]
> 1. 对于Windows用户, 请运行 `start_web.bat`文件
> 2. 使用onnxruntime + gpu 推理时需要安装onnxruntime-gpu
> 段落识别无法还原文档结构,仅能识别其内容
## 🖥️ 服务部署
我们使用 [ray serve](https://github.com/ray-project/ray) 为 TexTeller 提供 API 服务。启动服务:
```bash
texteller launch
```
| 参数 | 说明 |
| --------- | -------- |
| `-ckpt` | 权重文件路径,*默认为 TexTeller 预训练权重* |
| `-tknz` | 分词器路径,*默认为 TexTeller 分词器* |
| `-p` | 服务端口,*默认 8000* |
| `--num-replicas` | 服务副本数,*默认 1*。可使用更多副本来提升吞吐量 |
| `--ncpu-per-replica` | 单个副本使用的CPU核数*默认 1* |
| `--ngpu-per-replica` | 单个副本使用的GPU数,*默认 1*。可设置为0~1之间的值来在单卡上运行多个服务副本共享GPU,提升GPU利用率(注意,若--num_replicas为2,--ngpu_per_replica为0.7,则需有2块可用GPU) |
| `--num-beams` | beam search的束宽*默认 1* |
| `--use-onnx` | 使用Onnx Runtime进行推理*默认关闭* |
向服务发送请求:
```python
# client_demo.py
import requests
server_url = "http://127.0.0.1:8000/predict"
img_path = "/path/to/your/image"
with open(img_path, 'rb') as img:
files = {'img': img}
response = requests.post(server_url, files=files)
print(response.text)
```
## 🐍 Python接口
我们为公式OCR场景提供了多个易用的Python API接口请参考[接口文档](https://oleehyo.github.io/TexTeller/)了解对应的API接口及使用方法。
## 🔍 公式检测
TexTeller的公式检测模型在3415张中文教材数据(130+版式)和8272张[IBEM数据集](https://zenodo.org/records/4757865)上训练得到,支持对整张图片进行**公式检测**
TexTeller的公式检测模型在3415张中文资料图像和8272张[IBEM数据集](https://zenodo.org/records/4757865)图像上训练。
<div align="center">
<img src="det_rec.png" width=250>
<img src="./assets/det_rec.png" width=250>
</div>
1. 下载公式检测模型的权重到`src/models/det_model/model/`目录 [[链接](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)]
我们在Python接口中提供了公式检测接口详见[接口文档](https://oleehyo.github.io/TexTeller/)。
2. `src/`目录下运行以下命令,结果保存在`src/subimages/`
## 🏋️‍♂️ 模型训练
请按以下步骤配置训练环境:
1. 安装训练依赖:
```bash
python infer_det.py
pip install texteller[train]
```
<details>
<summary>更进一步:公式批识别</summary>
在进行**公式检测后**`src/`目录下运行以下命令
```shell
python rec_infer_from_crop_imgs.py
```
会基于上一步公式检测的结果,对裁剪出的所有公式进行批量识别,将识别结果在 `src/results/`中保存为txt文件。
</details>
## 📡 API调用
我们使用[ray serve](https://github.com/ray-project/ray)来对外提供一个TexTeller的API接口,通过使用这个接口,你可以把TexTeller整合到自己的项目里。要想启动server,你需要先进入 `src/`目录,然后运行以下命令:
```bash
python server.py
```
| 参数 | 描述 |
| --- | --- |
| `-ckpt` | 权重文件的路径,*默认为TexTeller的预训练权重*。|
| `-tknz` | 分词器的路径,*默认为TexTeller的分词器*。|
| `-port` | 服务器的服务端口,*默认是8000*。|
| `--inference-mode` | 使用"cuda"或"mps"推理,*默认为"cpu"*。|
| `--num_beams` | beam search的beam数量*默认是1*。|
| `--num_replicas` | 在服务器上运行的服务副本数量,*默认1个副本*。你可以使用更多的副本来获取更大的吞吐量。|
| `--ncpu_per_replica` | 每个服务副本所用的CPU核心数*默认为1*。|
| `--ngpu_per_replica` | 每个服务副本所用的GPU数量,*默认为1*。你可以把这个值设置成 0~1之间的数,这样会在一个GPU上运行多个服务副本来共享GPU,从而提高GPU的利用率。(注意,如果 --num_replicas 2, --ngpu_per_replica 0.7, 那么就必须要有2个GPU可用) |
| `-onnx` | 使用Onnx Runtime进行推理*默认不使用*。|
> [!NOTE]
> 一个客户端demo可以在 `TexTeller/client/demo.py`找到,你可以参考 `demo.py`来给server发送请求
## 🏋️‍♂️ 训练
### 数据集
我们在 `src/models/ocr_model/train/dataset/`目录中提供了一个数据集的例子,你可以把自己的图片放在 `images`目录然后在 `formulas.jsonl`中为每张图片标注对应的公式。
准备好数据集后,你需要在 `**/train/dataset/loader.py`中把 **`DIR_URL`变量改成你自己数据集的路径**
### 重新训练分词器
如果你使用了不一样的数据集你可能需要重新训练tokenizer来得到一个不一样的词典。配置好数据集后可以通过以下命令来训练自己的tokenizer
1. 在`src/models/tokenizer/train.py`中,修改`new_tokenizer.save_pretrained('./your_dir_name')`为你自定义的输出目录
> 注意:如果要用一个不一样大小的词典(默认1.5W个token),你需要在`src/models/globals.py`中修改`VOCAB_SIZE`变量
2. **在`src/`目录下**运行以下命令:
2. 克隆仓库:
```bash
python -m models.tokenizer.train
git clone https://github.com/OleehyO/TexTeller.git
```
### 训练模型
### 数据集准备
1. 修改`src/train_config.yaml`中的`num_processes`为训练用的显卡数(默认为1)
我们在`examples/train_texteller/dataset/train`目录中提供了示例数据集,您可按照示例数据集的格式放置自己的训练数据。
2. 在`src/`目录下运行以下命令:
### 开始训练
在`examples/train_texteller/`目录下运行:
```bash
accelerate launch --config_file ./train_config.yaml -m models.ocr_model.train.train
accelerate launch train.py
```
你可以在`src/models/ocr_model/train/train.py`中设置自己的tokenizer和checkpoint路径请参考`train.py`。如果你使用了与TexTeller一样的架构和相同的词典你还可以用自己的数据集来微调TexTeller的默认权重
训练参数可通过[`train_config.yaml`](./examples/train_texteller/train_config.yaml)调整
> [!NOTE]
> 我们的训练脚本使用了[Hugging Face Transformers](https://github.com/huggingface/transformers)库, 所以你可以参考他们提供的[文档](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments)来获取更多训练参数的细节以及配置。
## 📅 计划列表
## 📅 计划
- [X] ~~使用更大的数据集来训练模型~~
- [X] ~~扫描图片识别~~
- [X] ~~使用更大规模数据集训练模型~~
- [X] ~~扫描件识别支持~~
- [X] ~~中英文场景支持~~
- [X] ~~手写公式识别~~
- [X] ~~手写公式支持~~
- [ ] PDF文档识别
- [ ] 推理加速
## ⭐️ 观星曲线
## ⭐️ 项目星标
[![Stargazers over time](https://starchart.cc/OleehyO/TexTeller.svg?variant=adaptive)](https://starchart.cc/OleehyO/TexTeller)
[![Star增长曲线](https://starchart.cc/OleehyO/TexTeller.svg?variant=adaptive)](https://starchart.cc/OleehyO/TexTeller)
## 👥 贡献者

13
assets/logo.svg Normal file
View File

@@ -0,0 +1,13 @@
<svg xmlns="http://www.w3.org/2000/svg" width="354" height="100" viewBox="0 0 354 100">
<text
x="50%"
y="50%"
font-family="Arial, sans-serif"
font-size="55"
text-anchor="middle"
dominant-baseline="middle">
<tspan fill="#F37726">{</tspan><tspan fill="#616161">Tex</tspan><tspan fill="#F37726">}</tspan><tspan fill="#616161">Teller</tspan>
</text>
</svg>

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

0
docs/requirements.txt Normal file
View File

39
docs/source/api.rst Normal file
View File

@@ -0,0 +1,39 @@
API Reference
=============
This section provides detailed API documentation for the TexTeller package. TexTeller is a tool for detecting and recognizing LaTeX formulas in images and converting mixed text and formula images to markdown.
.. contents:: Table of Contents
:local:
:depth: 2
Image to LaTeX Conversion
-------------------------
.. autofunction:: texteller.api.img2latex
Paragraph to Markdown Conversion
--------------------------------
.. autofunction:: texteller.api.paragraph2md
LaTeX Detection
---------------
.. autofunction:: texteller.api.detection.latex_detect
Model Loading
-------------
.. autofunction:: texteller.api.load_model
.. autofunction:: texteller.api.load_tokenizer
.. autofunction:: texteller.api.load_latexdet_model
.. autofunction:: texteller.api.load_textdet_model
.. autofunction:: texteller.api.load_textrec_model
KaTeX Conversion
----------------
.. autofunction:: texteller.api.to_katex
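A brief usage sketch (the composition below assumes ``to_katex`` maps a recognized LaTeX string to a KaTeX-compatible string; see the generated signature above for the exact interface):

.. code-block:: python

   from texteller import load_model, load_tokenizer, img2latex
   from texteller.api import to_katex

   # Recognize a formula image, then normalize the output for KaTeX rendering
   model = load_model()
   tokenizer = load_tokenizer()
   latex = img2latex(model, tokenizer, ["formula.png"])[0]
   print(to_katex(latex))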

75
docs/source/conf.py Normal file
View File

@@ -0,0 +1,75 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute.
import os
import sys
sys.path.insert(0, os.path.abspath('../..'))
# -- Project information -----------------------------------------------------
project = 'TexTeller'
copyright = '2025, TexTeller Team'
author = 'TexTeller Team'
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
'myst_parser',
'sphinx.ext.duration',
'sphinx.ext.intersphinx',
'sphinx.ext.autosectionlabel',
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'sphinx.ext.autosummary',
'sphinx_copybutton',
# 'sphinx.ext.linkcode',
# 'sphinxarg.ext',
'sphinx_design',
'nbsphinx',
]
templates_path = ['_templates']
exclude_patterns = []
# Autodoc settings
autodoc_member_order = 'bysource'
add_module_names = False
autoclass_content = 'both'
autodoc_default_options = {
'members': True,
'member-order': 'bysource',
'undoc-members': True,
'show-inheritance': True,
'imported-members': True,
}
# Intersphinx settings
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'numpy': ('https://numpy.org/doc/stable', None),
'torch': ('https://pytorch.org/docs/stable', None),
'transformers': ('https://huggingface.co/docs/transformers/main/en', None),
}
html_theme = 'sphinx_book_theme'
html_theme_options = {
'repository_url': 'https://github.com/OleehyO/TexTeller',
'use_repository_button': True,
'use_issues_button': True,
'use_edit_page_button': True,
'use_download_button': True,
}
html_logo = "../../assets/logo.svg"

76
docs/source/index.rst Normal file
View File

@@ -0,0 +1,76 @@
.. TexTeller documentation master file, created by
sphinx-quickstart on Sun Apr 20 13:05:53 2025.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
TexTeller Documentation
===========================================
Features
--------
- **Image to LaTeX Conversion**: Convert images containing LaTeX formulas to LaTeX code
- **LaTeX Detection**: Detect and locate LaTeX formulas in mixed text/formula images
- **Paragraph to Markdown**: Convert mixed text and formula images to Markdown format
Installation
------------
You can install TexTeller using pip:
.. code-block:: bash
pip install texteller
Quick Start
-----------
Converting an image to LaTeX:
.. code-block:: python
from texteller import load_model, load_tokenizer, img2latex
# Load models
model = load_model(use_onnx=False)
tokenizer = load_tokenizer()
# Convert image to LaTeX
latex = img2latex(model, tokenizer, ["path/to/image.png"])[0]
Processing a mixed text/formula image:
.. code-block:: python
from texteller import (
load_model, load_tokenizer, load_latexdet_model,
load_textdet_model, load_textrec_model, paragraph2md
)
# Load all required models
latex_model = load_model()
tokenizer = load_tokenizer()
latex_detector = load_latexdet_model()
text_detector = load_textdet_model()
text_recognizer = load_textrec_model()
# Convert to markdown
markdown = paragraph2md(
"path/to/mixed_image.png",
latex_detector,
text_detector,
text_recognizer,
latex_model,
tokenizer
)
API Documentation
-----------------
For detailed API documentation, please see :doc:`./api`.
.. toctree::
:maxdepth: 2
:hidden:
api

10
examples/client_demo.py Normal file
View File

@@ -0,0 +1,10 @@
import requests
server_url = "http://127.0.0.1:8000/predict"
img_path = "/path/to/your/image"
with open(img_path, 'rb') as img:
files = {'img': img}
response = requests.post(server_url, files=files)
print(response.text)

(35 binary image file entries; previews omitted, before/after file sizes identical)
View File

@@ -0,0 +1,35 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -0,0 +1,71 @@
from functools import partial
import yaml
from datasets import load_dataset
from transformers import (
Trainer,
TrainingArguments,
)
from texteller import load_model, load_tokenizer
from texteller.constants import MIN_HEIGHT, MIN_WIDTH
from examples.train_texteller.utils import (
collate_fn,
filter_fn,
img_inf_transform,
img_train_transform,
tokenize_fn,
)
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**training_config)
trainer = Trainer(
model,
training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn_with_tokenizer,
)
trainer.train(resume_from_checkpoint=None)
if __name__ == "__main__":
dataset = load_dataset("imagefolder", data_dir="dataset")["train"]
dataset = dataset.filter(
lambda x: x["image"].height > MIN_HEIGHT and x["image"].width > MIN_WIDTH
)
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()
tokenizer = load_tokenizer()
# If you want to use your own tokenizer, modify the path below to point to it
# tokenizer = load_tokenizer("/path/to/your/tokenizer")
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
)
# Split dataset into train and eval, ratio 9:1
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# Train from scratch
model = load_model()
# If you want to train from a pre-trained model, modify the path below to point to your checkpoint
# model = load_model("/path/to/your/model_checkpoint")
enable_train = True
training_config = yaml.safe_load(open("train_config.yaml"))
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
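If a run is interrupted, the Trainer above can resume from a saved checkpoint instead of starting over. A minimal sketch of the one-line change inside train() (the explicit checkpoint path shown is illustrative, not a path this script necessarily produces):

# In train(), instead of trainer.train(resume_from_checkpoint=None):
trainer.train(resume_from_checkpoint=True)  # resume from the latest checkpoint in output_dir
# trainer.train(resume_from_checkpoint="train_result/checkpoint-1500")  # or an explicit (illustrative) path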

View File

@@ -0,0 +1,32 @@
# For more information, please refer to the official documentation: https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
seed: 42 # Random seed for reproducibility
use_cpu: false # Whether to train on CPU (debugging is easier on CPU when first testing the code)
learning_rate: 5.0e-5 # Learning rate
num_train_epochs: 10 # Total number of training epochs
per_device_train_batch_size: 4 # Batch size per GPU for training
per_device_eval_batch_size: 8 # Batch size per GPU for evaluation
output_dir: "train_result" # Output directory
overwrite_output_dir: false # If the output directory exists, do not delete its content
report_to:
- tensorboard # Report logs to TensorBoard
save_strategy: "steps" # Strategy to save checkpoints
save_steps: 500 # Checkpoint interval; an int means a number of steps, a float in (0, 1) means a fraction of total training steps (e.g. 1.0 / 2000)
save_total_limit: 5 # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
logging_strategy: "steps" # Log every certain number of steps
logging_steps: 500 # Number of steps between each log
logging_nan_inf_filter: false # Keep loss=nan or inf values in the logs instead of filtering them out
optim: "adamw_torch" # Optimizer
lr_scheduler_type: "cosine" # Learning rate scheduler
warmup_ratio: 0.1 # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
max_grad_norm: 1.0 # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
fp16: false # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
bf16: false # Whether to use Brain Floating Point (bfloat16) for training (recommended if your GPU architecture supports it)
gradient_accumulation_steps: 1 # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
jit_mode_eval: false # Whether to use PyTorch JIT tracing during eval (can speed up the model, but the model must be static, otherwise errors will be thrown)
torch_compile: false # Whether to use torch.compile to compile the model (for better training and inference performance)
dataloader_pin_memory: true # Can speed up data transfer between CPU and GPU
dataloader_num_workers: 1 # Number of data-loading worker processes (0 disables multiprocessing); a common choice is 4 x number of GPUs
evaluation_strategy: "steps" # Evaluation strategy, can be "steps" or "epoch"
eval_steps: 500 # Number of steps between evaluations (used when evaluation_strategy="steps")
remove_unused_columns: false # Don't change this unless you really know what you are doing.
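A quick way to sanity-check the schedule implied by these values is to compute the effective batch size and the resulting warmup steps. A minimal sketch, where the GPU count and dataset size are assumptions you should replace with your own numbers:

per_device_train_batch_size = 4
gradient_accumulation_steps = 1
num_gpus = 1                      # assumption: single-GPU run
dataset_size = 100_000            # assumption: replace with your dataset size
num_train_epochs = 10
warmup_ratio = 0.1

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
total_steps = (dataset_size // effective_batch) * num_train_epochs
print(effective_batch, total_steps, int(total_steps * warmup_ratio))  # e.g. 4, 250000, 25000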

View File

@@ -0,0 +1,17 @@
from .functional import (
collate_fn,
filter_fn,
tokenize_fn,
)
from .transforms import (
img_train_transform,
img_inf_transform,
)
__all__ = [
"collate_fn",
"filter_fn",
"tokenize_fn",
"img_train_transform",
"img_inf_transform",
]

View File

@@ -1,9 +1,36 @@
from augraphy import *
"""
Custom augraphy pipeline for training
This file implements a custom augraphy data augmentation pipeline. We found that using augraphy's
default pipeline can cause significant degradation to formula images, potentially losing semantic
information. Therefore, we carefully selected several common augmentation effects,
adjusting their parameters and combination methods to preserve the original semantic information
of the images as much as possible.
"""
from augraphy import (
InkColorSwap,
LinesDegradation,
OneOf,
Dithering,
InkBleed,
InkShifter,
NoiseTexturize,
BrightnessTexturize,
ColorShift,
DirtyDrum,
LightingGradient,
Brightness,
Gamma,
SubtleNoise,
Jpeg,
AugraphyPipeline,
)
import random
def ocr_augmentation_pipeline():
pre_phase = [
]
def get_custom_augraphy():
pre_phase = []
ink_phase = [
InkColorSwap(
@@ -15,8 +42,7 @@ def ocr_augmentation_pipeline():
ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500),
# p=0.2
p=0.4
p=0.2,
),
LinesDegradation(
line_roi=(0.0, 0.0, 1.0, 1.0),
@@ -28,10 +54,8 @@ def ocr_augmentation_pipeline():
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
# p=0.2
p=0.4
p=0.2,
),
# ============================
OneOf(
[
@@ -45,11 +69,9 @@ def ocr_augmentation_pipeline():
severity=(0.4, 0.6),
),
],
# p=0.2
p=0.4
p=0.2,
),
# ============================
# ============================
InkShifter(
text_shift_scale_range=(18, 27),
@@ -58,42 +80,32 @@ def ocr_augmentation_pipeline():
blur_kernel_size=(5, 5),
blur_sigma=0,
noise_type="perlin",
# p=0.2
p=0.4
p=0.2,
),
# ============================
]
paper_phase = [
NoiseTexturize( # tested
NoiseTexturize(
sigma_range=(3, 10),
turbulence_range=(2, 5),
texture_width_range=(300, 500),
texture_height_range=(300, 500),
# p=0.2
p=0.4
p=0.2,
),
BrightnessTexturize( # tested
texturize_range=(0.9, 0.99),
deviation=0.03,
# p=0.2
p=0.4
)
BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
]
post_phase = [
ColorShift( # tested
ColorShift(
color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3),
# p=0.2
p=0.4
p=0.2,
),
DirtyDrum( # tested
DirtyDrum(
line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2),
@@ -101,10 +113,8 @@ def ocr_augmentation_pipeline():
noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0,
# p=0.2
p=0.4
p=0.2,
),
# =====================================
OneOf(
[
@@ -126,11 +136,9 @@ def ocr_augmentation_pipeline():
gamma_range=(0.9, 1.1),
),
],
# p=0.2
p=0.4
p=0.2,
),
# =====================================
# =====================================
OneOf(
[
@@ -141,8 +149,7 @@ def ocr_augmentation_pipeline():
quality_range=(70, 95),
),
],
# p=0.2
p=0.4
p=0.2,
),
# =====================================
]
@@ -152,7 +159,7 @@ def ocr_augmentation_pipeline():
paper_phase=paper_phase,
post_phase=post_phase,
pre_phase=pre_phase,
log=False
log=False,
)
return pipeline
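The pipeline returned by get_custom_augraphy is applied directly to an HWC uint8 RGB numpy array, which is how the training transforms use it. A minimal sketch of standalone usage; the image path and the module import path are assumptions for illustration:

import cv2
from examples.train_texteller.utils.augraphy_pipe import get_custom_augraphy  # assumed module path

pipe = get_custom_augraphy()
img = cv2.cvtColor(cv2.imread("formula.png"), cv2.COLOR_BGR2RGB)  # HWC uint8, as the pipeline expects
augmented = pipe(img)  # AugraphyPipeline is callable and returns the augmented image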

View File

@@ -0,0 +1,47 @@
from typing import Any
import torch
from transformers import DataCollatorForLanguageModeling
from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH
def _left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, "x should be 2-dimensional"
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
assert tokenizer is not None, "tokenizer should not be None"
tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
tokenized_formula["pixel_values"] = samples["image"]
return tokenized_formula
def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
assert tokenizer is not None, "tokenizer should not be None"
pixel_values = [dic.pop("pixel_values") for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch["pixel_values"] = pixel_values
batch["decoder_input_ids"] = batch.pop("input_ids")
batch["decoder_attention_mask"] = batch.pop("attention_mask")
batch["labels"] = _left_move(batch["labels"], -100)
# convert list of Image to a tensor with (B, C, H, W)
batch["pixel_values"] = torch.stack(batch["pixel_values"], dim=0)
return batch
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample["image"].height > MIN_HEIGHT
and sample["image"].width > MIN_WIDTH
and len(tokenizer(sample["latex_formula"])["input_ids"]) < MAX_TOKEN_SIZE - 10
)
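_left_move (used in collate_fn above) shifts the label sequence one position to the left and pads the end with -100, the index ignored by the cross-entropy loss, so each label lines up with the token the decoder should predict next. A small illustrative sketch with made-up token ids:

import torch

labels = torch.tensor([[2, 15, 37, 3]])  # <bos>, tok, tok, <eos> (illustrative ids)
shifted = torch.ones_like(labels)
shifted[:, :-1] = labels[:, 1:]
shifted[:, -1] = -100                    # ignored by the loss
print(shifted)                           # tensor([[  15,   37,    3, -100]])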

View File

@@ -4,38 +4,19 @@ import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List, Union
from typing import Any
from PIL import Image
from collections import Counter
from ...globals import (
from texteller.constants import (
IMG_CHANNELS,
FIXED_IMG_SIZE,
IMAGE_MEAN, IMAGE_STD,
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
MAX_RESIZE_RATIO,
MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline
from texteller.utils import transform as inference_transform
from .augraphy_pipe import get_custom_augraphy
# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()
general_transform_pipeline = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage()
])
augraphy_pipeline = get_custom_augraphy()
def trim_white_border(image: np.ndarray):
@@ -45,11 +26,10 @@ def trim_white_border(image: np.ndarray):
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]),
tuple(image[-1, 0]), tuple(image[-1, -1])]
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
@@ -59,9 +39,9 @@ def trim_white_border(image: np.ndarray):
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y:y+h, x:x+w]
trimmed_image = image[y : y + h, x : x + w]
return trimmed_image
@@ -69,45 +49,43 @@ def trim_white_border(image: np.ndarray):
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if (pad_height_size + image.shape[0] < 30):
pad_width_size = randi[0] + randi[2]
if pad_height_size + image.shape[0] < 30:
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if (pad_width_size + image.shape[1] < 30):
if pad_width_size + image.shape[1] < 30:
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255)
padding_mode="constant",
fill=(255, 255, 255),
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
images = [
v2.functional.pad(
img,
padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
def random_resize(
images: List[np.ndarray],
minr: float,
maxr: float
) -> List[np.ndarray]:
def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
return [
cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4) # anti-aliasing
# Anti-aliasing
cv2.resize(
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
)
for img, r in zip(images, ratios)
]
@@ -133,7 +111,9 @@ def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
rotated_image = cv2.warpAffine(
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
)
return rotated_image
@@ -142,14 +122,14 @@ def ocr_aug(image: np.ndarray) -> np.ndarray:
if random.random() < 0.2:
image = rotate(image, -5, 5)
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
image = train_pipeline(image)
image = augraphy_pipeline(image)
return image
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
images = [np.array(img.convert("RGB")) for img in images]
# random resize first
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
images = [trim_white_border(image) for image in images]
@@ -158,19 +138,17 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
images = [ocr_aug(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
images = inference_transform(images)
return images
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
processed_img = train_transform(samples["pixel_values"])
samples["pixel_values"] = processed_img
return samples
return images
def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
processed_img = inference_transform(samples["pixel_values"])
samples["pixel_values"] = processed_img
return samples
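trim_white_border crops an image to the bounding box of its non-background content, taking the most common corner color as the background. A minimal sketch on a synthetic image; the import path is an assumption for illustration:

import numpy as np
from examples.train_texteller.utils.transforms import trim_white_border  # assumed module path

canvas = np.full((100, 200, 3), 255, dtype=np.uint8)  # white page
canvas[40:60, 80:120] = 0                             # a black "formula" block
print(trim_white_border(canvas).shape)                # (20, 40, 3): the content bounding box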

87
pyproject.toml Normal file
View File

@@ -0,0 +1,87 @@
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
[project]
name = "texteller"
authors = [
{ name="OleehyO", email="leehy0357@gmail.com" }
]
dynamic = ["version"]
description = "Texteller is a tool for converting rendered image to original latex code"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
dependencies = [
"click>=8.1.8",
"colorama>=0.4.6",
"opencv-python-headless>=4.11.0.86",
"pyclipper>=1.3.0.post6",
"shapely>=2.1.0",
"streamlit>=1.44.1",
"streamlit-paste-button>=0.1.2",
"torch>=2.6.0",
"torchvision>=0.21.0",
"transformers==4.45.2",
"wget>=3.2",
"optimum[onnxruntime]>=1.24.0",
"python-multipart>=0.0.20",
"ray[serve]>=2.44.1",
]
[tool.hatch.version]
source = "vcs"
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py310"
line-length = 100
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
[tool.ruff.lint]
select = ["E", "W"]
ignore = [
"E999",
"EXE001",
"UP009",
"F401",
"TID252",
"F403",
"F841",
"E501",
"W291",
"W293",
"E741",
"E712",
]
[tool.hatch.build.targets.wheel]
packages = ["texteller"]
[project.scripts]
texteller = "texteller.cli:cli"
[project.optional-dependencies]
onnxruntime-gpu = [
"onnxruntime-gpu>=1.21.0",
]
test = [
"pytest>=8.3.5",
]
train = [
"accelerate>=1.6.0",
"augraphy>=8.2.6",
"datasets>=3.5.0",
"tensorboardx>=2.6.2.2",
]
docs = [
"myst-parser>=4.0.1",
"nbsphinx>=0.9.7",
"sphinx>=8.1.3",
"sphinx-book-theme>=1.1.4",
"sphinx-copybutton>=0.5.2",
"sphinx-design>=0.6.1",
]
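Because the version is dynamic (derived from git tags via hatch-vcs), it is not hard-coded anywhere in the source. A minimal sketch of reading it from an installed distribution at runtime (the printed value is illustrative):

from importlib.metadata import version

print(version("texteller"))  # e.g. "1.0.1" for a tagged release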

View File

@@ -1,19 +0,0 @@
transformers==4.45.2
sentence-transformers==3.1.1
datasets
evaluate
opencv-python
ray[serve]
accelerate
tensorboardX
nltk
python-multipart
augraphy
streamlit==1.30
streamlit-paste-button
shapely
pyclipper
onnxruntime-gpu

View File

@@ -1,42 +0,0 @@
from setuptools import setup, find_packages
# Define the base dependencies
install_requires = [
"torch",
"torchvision",
"transformers",
"datasets",
"evaluate",
"opencv-python",
"ray[serve]",
"accelerate",
"tensorboardX",
"nltk",
"python-multipart",
"augraphy",
"streamlit==1.30",
"streamlit-paste-button",
"shapely",
"pyclipper",
"optimum[exporters]",
]
setup(
name="texteller",
version="0.1.2",
author="OleehyO",
author_email="1258009915@qq.com",
description="A meta-package for installing dependencies",
long_description=open('README.md').read(),
long_description_content_type="text/markdown",
url="https://github.com/OleehyO/TexTeller",
packages=find_packages(),
install_requires=install_requires,
classifiers=[
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
python_requires='>=3.10',
)

View File

@@ -1,12 +0,0 @@
import requests
rec_server_url = "http://127.0.0.1:8000/frec"
det_server_url = "http://127.0.0.1:8000/fdet"
img_path = "/your/image/path/"
with open(img_path, 'rb') as img:
files = {'img': img}
response = requests.post(rec_server_url, files=files)
# response = requests.post(det_server_url, files=files)
print(response.text)

View File

@@ -1,85 +0,0 @@
import os
import argparse
import glob
import subprocess
import onnxruntime
from pathlib import Path
from models.det_model.inference import PredictConfig, predict_image
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
default="./models/det_model/model/infer_cfg.yml")
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True)
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--image_file or --image_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
return [infer_img]
images = set()
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
images = list(images)
assert len(images) > 0, "no image found in {}".format(infer_dir)
print("Found {} inference images in total.".format(len(images)))
return images
def download_file(url, filename):
print(f"Downloading {filename}...")
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
print("Download complete.")
if __name__ == '__main__':
cur_path = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
FLAGS = parser.parse_args()
if not os.path.exists(FLAGS.infer_cfg):
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
download_file(infer_cfg_url, FLAGS.infer_cfg)
if not os.path.exists(FLAGS.onnx_file):
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
download_file(onnx_file_url, FLAGS.onnx_file)
# load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
if FLAGS.use_gpu:
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider'])
else:
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
os.chdir(cur_path)

View File

@@ -1,85 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser()
parser.add_argument(
'-img',
type=str,
required=True,
help='path to the input image'
)
parser.add_argument(
'--inference-mode',
type=str,
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps'
)
parser.add_argument(
'--num-beam',
type=int,
default=1,
help='number of beam search for decoding'
)
parser.add_argument(
'-mix',
action='store_true',
help='use mix mode'
)
args = parser.parse_args()
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
img_path = args.img
img = cv.imread(img_path)
print('Inference...')
if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
res = to_katex(res[0])
print(res)
else:
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
use_gpu = args.inference_mode == 'cuda'
SIZE_LIMIT = 20 * 1024 * 1024
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
# The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
det_use_gpu = False
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
paddleocr_args = utility.parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.det_model_dir = det_model_dir
paddleocr_args.rec_model_dir = rec_model_dir
paddleocr_args.use_gpu = det_use_gpu
detector = predict_det.TextDetector(paddleocr_args)
paddleocr_args.use_gpu = rec_use_gpu
recognizer = predict_rec.TextRecognizer(paddleocr_args)
lang_ocr_models = [detector, recognizer]
latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
print(res)

View File

@@ -1,195 +0,0 @@
import os
import time
import yaml
import numpy as np
import cv2
from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
'DETR'
}
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(image, "{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return image
def predict_image(imgsave_dir, infer_config, predictor, img_list):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
errImgList = []
# Check and create subimg_save_dir if not exist
subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True)
first_image_skipped = False
total_time = 0
num_images = 0
# predict image
for img_path in tqdm(img_list):
img = cv2.imread(img_path)
if img is None:
print(f"Warning: Could not read image {img_path}. Skipping...")
errImgList.append(img_path)
continue
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name}
# Start timing
start_time = time.time()
outputs = predictor.run(output_names=None, input_feed=inputs)
# Stop timing
end_time = time.time()
inference_time = end_time - start_time
if not first_image_skipped:
first_image_skipped = True
else:
total_time += inference_time
num_images += 1
print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds")
print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0]))
else:
bboxes = np.array(outputs[0])
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} "
f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image)
subimg_counter = 1
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
if len(subimg) == 0:
continue
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
cv2.imwrite(subimg_path, subimg)
subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes
img_with_mask = img.copy()
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1) # mask the box with white
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir
os.makedirs(output_dir, exist_ok=True)
draw_box_dir = os.path.join(output_dir, 'draw_box')
mask_white_dir = os.path.join(output_dir, 'mask_white')
os.makedirs(draw_box_dir, exist_ok=True)
os.makedirs(mask_white_dir, exist_ok=True)
output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
cv2.imwrite(output_file_mask, img_with_mask)
cv2.imwrite(output_file_bbox, img_with_bbox)
avg_time_per_image = total_time / num_images if num_images > 0 else 0
print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
print(f"Average time per image: {avg_time_per_image:.4f} seconds")
print("ErrorImgs:")
print(errImgList)
def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
transforms = Compose(infer_config.preprocess_infos)
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = infer_config.label_list[int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > infer_config.draw_threshold:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res

View File

@@ -1,27 +0,0 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 1600
- 1600
type: Resize
- mean:
- 0.0
- 0.0
- 0.0
norm_type: none
std:
- 1.0
- 1.0
- 1.0
type: NormalizeImage
- type: Permute
label_list:
- isolated
- embedding

View File

@@ -1,499 +0,0 @@
import numpy as np
import cv2
import copy
def decode_image(img_path):
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(
im.shape[:2], dtype=np.float32),
"scale_factor": np.array(
[1., 1.], dtype=np.float32)
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether to keep the aspect ratio, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether to convert RGB to BGR
channel_first (bool): whether to convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio),
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=color) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
class Pad(object):
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.fill_value = fill_value
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
h, w = self.size
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
im = canvas
return im, im_info
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def get_affine_transform(center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
class WarpAffine(object):
"""Warp affine the image
"""
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info
# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
np.sin(theta) + 0.5 * size_target[0])
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
np.cos(theta) + 0.5 * size_target[1])
return matrix
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records (dict): the dict containing the image and coords
Returns:
records (dict): the image and coords after the transform
"""
def __init__(self, trainsize, use_udp=False):
self.trainsize = trainsize
self.use_udp = use_udp
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.
scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp:
trans = get_warp_matrix(
rot, center * 2.0,
[self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
else:
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
return image, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs['image'] = img
return inputs

View File

@@ -1,45 +0,0 @@
from pathlib import Path
from ...globals import (
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE
)
from transformers import (
RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderConfig
)
class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller'
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
if model_path is None or model_path == 'default':
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = True if onnx_provider == 'cuda' else False
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
if tokenizer_path is None or tokenizer_path == 'default':
return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))

View File

@@ -1,168 +0,0 @@
{
"_name_or_path": "OleehyO/TexTeller",
"architectures": [
"VisionEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": 768,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.1,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layernorm_embedding": true,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 1024,
"min_length": 0,
"model_type": "trocr",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": false,
"use_learned_position_embeddings": true,
"vocab_size": 15000
},
"encoder": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_probs_dropout_prob": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"encoder_stride": 16,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 448,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-12,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "vit",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 16,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"qkv_bias": false,
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false
},
"is_encoder_decoder": true,
"model_type": "vision-encoder-decoder",
"tie_word_embeddings": false,
"transformers_version": "4.41.2",
"use_cache": true
}

View File

@@ -1,35 +0,0 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -1,50 +0,0 @@
from PIL import Image
from pathlib import Path
import datasets
import json
DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')
class LatexFormulas(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = []
def _info(self):
return datasets.DatasetInfo(
features=datasets.Features({
"image": datasets.Image(),
"latex_formula": datasets.Value("string")
})
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
dir_path = Path(dl_manager.download(str(DIR_URL)))
assert dir_path.is_dir()
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
'dir_path': dir_path,
}
)
]
def _generate_examples(self, dir_path: Path):
images_path = dir_path / 'images'
formulas_path = dir_path / 'formulas.jsonl'
img2formula = {}
with formulas_path.open('r', encoding='utf-8') as f:
for line in f:
single_json = json.loads(line)
img2formula[single_json['img_name']] = single_json['formula']
for img_path in images_path.iterdir():
if img_path.suffix not in ['.jpg', '.png']:
continue
yield str(img_path), {
"image": Image.open(img_path),
"latex_formula": img2formula[img_path.name]
}
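# A minimal usage sketch for this loader (assuming DIR_URL has been set to a real
# dataset directory containing images/ and formulas.jsonl); this mirrors how
# train.py further below consumes it via datasets.load_dataset:
from pathlib import Path
from datasets import load_dataset

dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
print(dataset[0]['latex_formula'])  # the LaTeX string paired with the first image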

View File

@@ -1,109 +0,0 @@
import os
from functools import partial
from pathlib import Path
from datasets import load_dataset
from transformers import (
Trainer,
TrainingArguments,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
GenerationConfig
)
from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG)
trainer = Trainer(
model,
training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn_with_tokenizer,
)
trainer.train(resume_from_checkpoint=None)
def evaluate(model, tokenizer, eval_dataset, collate_fn):
eval_config = CONFIG.copy()
eval_config['predict_with_generate'] = True
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
eval_config['generation_config'] = generate_config
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
trainer = Seq2SeqTrainer(
model,
seq2seq_config,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
)
eval_res = trainer.evaluate()
print(eval_res)
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()
tokenizer = TexTeller.get_tokenizer()
# If you want to use your own tokenizer, modify the path below, e.g.
# tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
dataset = dataset.filter(
filter_fn_with_tokenizer,
num_proc=8
)
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
# Split dataset into train and eval, ratio 9:1
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# Train from scratch
model = TexTeller()
# Or start from the released TexTeller checkpoint: model = TexTeller.from_pretrained()
# To fine-tune from your own checkpoint, pass its path instead, e.g.
# model = TexTeller.from_pretrained('/path/to/your/model_checkpoint')
enable_train = True
enable_evaluate = False
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate and len(eval_dataset) > 0:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)

View File

@@ -1,38 +0,0 @@
CONFIG = {
"seed": 42, # Random seed for reproducibility
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
"learning_rate": 5e-5, # Learning rate
"num_train_epochs": 10, # Total number of training epochs
"per_device_train_batch_size": 4, # Batch size per GPU for training
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
"output_dir": "train_result", # Output directory
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
"report_to": ["tensorboard"], # Report logs to TensorBoard
"save_strategy": "steps", # Strategy to save checkpoints
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
"logging_strategy": "steps", # Log every certain number of steps
"logging_steps": 500, # Number of steps between each log
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
"optim": "adamw_torch", # Optimizer
"lr_scheduler_type": "cosine", # Learning rate scheduler
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
"eval_steps": 500, # If evaluation_strategy="step"
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
}
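# A quick smoke-test sketch (the overrides below are illustrative, not shipped defaults):
# CONFIG is unpacked into TrainingArguments exactly as train.py does, so dict-merging a
# few overrides on top is an easy way to get a tiny CPU-only debug run.
from transformers import TrainingArguments

debug_config = {
    **CONFIG,
    "use_cpu": True,
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "report_to": [],
}
training_args = TrainingArguments(**debug_config)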

View File

@@ -1,59 +0,0 @@
import torch
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
def left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, 'x should be 2-dimensional'
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
tokenized_formula['pixel_values'] = samples['image']
return tokenized_formula
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids')
batch['decoder_attention_mask'] = batch.pop('attention_mask')
# Left-shift the labels by one position (pad the tail with -100)
batch['labels'] = left_move(batch['labels'], -100)
# Stack the list of images into a single tensor of shape (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
return batch
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = inference_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)
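# A tiny worked example of the label shift above (values chosen for illustration):
# left_move drops the first token and pads the end, which aligns `labels` with
# next-token prediction after collate_fn applies it.
import torch

x = torch.tensor([[101, 7, 8, 9]])
print(left_move(x, -100))  # tensor([[   7,    8,    9, -100]])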

View File

@@ -1,26 +0,0 @@
import cv2
import numpy as np
from typing import List
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
processed_images = []
for path in image_paths:
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
if image.dtype == np.uint16:
print(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0/65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(image)
return processed_images

View File

@@ -1,49 +0,0 @@
import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens = None
) -> List[str]:
if imgs == []:
return []
if hasattr(model, 'eval'):
# not an ONNX session, so switch the model to eval mode
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already a numpy array (RGB format)
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if hasattr(model, 'eval'):
# not an ONNX session, so move the weights to the target device
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
pred = model.generate(pixel_values, generation_config=generate_config)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res
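# A minimal sketch of calling this helper directly (the image path is a placeholder);
# the same pattern is used by inference.py and server.py later in this diff:
from models.ocr_model.model.TexTeller import TexTeller

model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
latex = inference(model, tokenizer, ['formula.png'], accelerator='cpu', num_beams=1)[0]
print(latex)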

View File

@@ -1,23 +0,0 @@
import evaluate
import numpy as np
import os
from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
cur_dir = Path(os.getcwd())
os.chdir(Path(__file__).resolve().parent)
metric = evaluate.load('google_bleu') # Will download the metric from huggingface if not already downloaded
os.chdir(cur_dir)
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits
labels = np.where(labels == -100, 1, labels)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
return metric.compute(predictions=preds, references=labels)

View File

@@ -1,25 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
tokenizer = TexTeller.get_tokenizer()
# Don't forget to configure your dataset path in loader.py
dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']
new_tokenizer = tokenizer.train_new_from_iterator(
text_iterator=dataset['latex_formula'],
# If you want to use a different vocab size, change VOCAB_SIZE in globals.py
vocab_size=VOCAB_SIZE
)
# Save the new tokenizer for later training and inference
new_tokenizer.save_pretrained('./your_dir_name')
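# The saved directory can then be handed to get_tokenizer in the training and
# inference scripts (a short sketch; the path matches the save call above):
tokenizer = TexTeller.get_tokenizer('./your_dir_name')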

View File

@@ -1 +0,0 @@
from .mix_inference import mix_inference

View File

@@ -1,264 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np
from collections import Counter
from typing import List
from PIL import Image
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
if (len(sorted_bboxes) == 0):
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
# log results
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_conflict.png")
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while (len(bboxes) > 0):
idx += 1
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None
)
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
# log results
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_conflict.png")
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y:y+h, x:x+w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def mix_inference(
img_path: str,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
accelerator="cpu",
num_beams=1
) -> str:
'''
Take an image that mixes text and formulas, and return a string in Markdown syntax.
'''
global img
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]),
tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
label="text",
confidence=None,
content=None
)
for p in det_prediction
]
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs =[]
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if (
prev.label == "isolated"
and curr.label == "text"
and prev.same_row(curr)
):
curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tags
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr
return md.strip()
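# A hedged sketch of how mix_inference is wired up elsewhere in this repo (see web.py
# further down in this diff); model paths follow the defaults used there, everything
# runs CPU-only for simplicity, and the input image path is a placeholder:
from onnxruntime import InferenceSession
from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.thrid_party.paddleocr.infer import predict_det, predict_rec, utility

infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession(
    "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
    providers=['CPUExecutionProvider'],
)
paddleocr_args = utility.parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.use_gpu = False
paddleocr_args.det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
paddleocr_args.rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
lang_ocr_models = [predict_det.TextDetector(paddleocr_args), predict_rec.TextRecognizer(paddleocr_args)]
latex_rec_models = [TexTeller.from_pretrained(), TexTeller.get_tokenizer()]

md = mix_inference("page.png", infer_config, latex_det_model, lang_ocr_models,
                   latex_rec_models, accelerator="cpu", num_beams=1)
print(md)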

View File

@@ -1,65 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser()
parser.add_argument(
'-img_dir',
type=str,
help='path to the directory of input images',
default='./detect_results/subimages'
)
parser.add_argument(
'-output_dir',
type=str,
help='path to the output dir',
default='./rec_results'
)
parser.add_argument(
'--inference-mode',
type=str,
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps'
)
parser.add_argument(
'--num-beam',
type=int,
default=1,
help='number of beam search for decoding'
)
args = parser.parse_args()
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
# Create the output directory if it doesn't exist
os.makedirs(args.output_dir, exist_ok=True)
# Loop through all images in the input directory
for filename in os.listdir(args.img_dir):
img_path = os.path.join(args.img_dir, filename)
img = cv.imread(img_path)
if img is not None:
print(f'Inference for {filename}...')
res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
# Save the recognition result to a text file
output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
with open(output_file, 'w') as f:
f.write(res)
print(f'Result saved to {output_file}')
else:
print(f"Warning: Could not read image {img_path}. Skipping...")

View File

@@ -1,157 +0,0 @@
import sys
import argparse
import tempfile
import time
import numpy as np
import cv2
from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession
from models.ocr_model.utils.inference import inference as rec_inference
from models.det_model.inference import predict as det_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
from models.ocr_model.utils.to_katex import to_katex
PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'
parser = argparse.ArgumentParser()
parser.add_argument(
'-ckpt', '--checkpoint_dir', type=str
)
parser.add_argument(
'-tknz', '--tokenizer_dir', type=str
)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('-onnx', action='store_true', help='using onnx runtime')
args = parser.parse_args()
if args.ngpu_per_replica > 0 and args.inference_mode != 'cuda':
raise ValueError("--inference-mode must be cuda when ngpu_per_replica > 0")
@serve.deployment(
num_replicas=args.num_replicas,
ray_actor_options={
"num_cpus": args.ncpu_per_replica,
"num_gpus": args.ngpu_per_replica * 1.0 / 2
}
)
class TexTellerRecServer:
def __init__(
self,
checkpoint_path: str,
tokenizer_path: str,
inf_mode: str = 'cpu',
use_onnx: bool = False,
num_beams: int = 1
) -> None:
self.model = TexTeller.from_pretrained(checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode)
self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
self.inf_mode = inf_mode
self.num_beams = num_beams
if not use_onnx:
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
return to_katex(rec_inference(
self.model, self.tokenizer, [image_nparray],
accelerator=self.inf_mode, num_beams=self.num_beams
)[0])
@serve.deployment(
num_replicas=args.num_replicas,
ray_actor_options={
"num_cpus": args.ncpu_per_replica,
"num_gpus": args.ngpu_per_replica * 1.0 / 2,
"runtime_env": {
"env_vars": {
"LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"
}
}
},
)
class TexTellerDetServer:
def __init__(
self,
inf_mode='cpu'
):
self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
self.latex_det_model = InferenceSession(
"./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider']
)
async def predict(self, image_nparray) -> str:
with tempfile.TemporaryDirectory() as temp_dir:
img_path = f"{temp_dir}/temp_image.jpg"
cv2.imwrite(img_path, image_nparray)
latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
return latex_bboxes
@serve.deployment()
class Ingress:
def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
self.det_server = det_server
self.texteller_server = rec_server
async def __call__(self, request: Request) -> str:
request_path = request.url.path
form = await request.form()
img_rb = await form['img'].read()
img_nparray = np.frombuffer(img_rb, np.uint8)
img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
if request_path.startswith("/fdet"):
if self.det_server is None:
return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
pred = await self.det_server.predict.remote(img_nparray)
return pred
elif request_path.startswith("/frec"):
pred = await self.texteller_server.predict.remote(img_nparray)
return pred
else:
return "[ERROR] Invalid request path"
if __name__ == '__main__':
ckpt_dir = args.checkpoint_dir
tknz_dir = args.tokenizer_dir
serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
rec_server = TexTellerRecServer.bind(
ckpt_dir, tknz_dir,
inf_mode=args.inference_mode,
use_onnx=args.onnx,
num_beams=args.num_beams
)
det_server = None
if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
det_server = TexTellerDetServer.bind(args.inference_mode)
ingress = Ingress.bind(det_server, rec_server)
# ingress_handle = serve.run(ingress, route_prefix="/predict")
ingress_handle = serve.run(ingress, route_prefix="/")
while True:
time.sleep(1)
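# A hedged client sketch for the endpoints defined above (the server listens on
# localhost:8000 by default and the image path is a placeholder). /frec returns the
# recognized formula, /fdet the detected bounding boxes:
import requests

with open("formula.png", "rb") as f:
    resp = requests.post("http://localhost:8000/frec", files={"img": f})
print(resp.text)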

View File

@@ -1,9 +0,0 @@
@echo off
SETLOCAL ENABLEEXTENSIONS
set CHECKPOINT_DIR=default
set TOKENIZER_DIR=default
streamlit run web.py
ENDLOCAL

View File

@@ -1,7 +0,0 @@
#!/usr/bin/env bash
set -exu
export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"
streamlit run web.py

View File

@@ -1,14 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
gpu_ids: all
num_processes: 1
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,272 +0,0 @@
import os
import io
import re
import base64
import tempfile
import shutil
import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference
from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_recognition
from models.ocr_model.utils.to_katex import to_katex
st.set_page_config(
page_title="TexTeller",
page_icon="🧮"
)
html_string = '''
<h1 style="color: black; text-align: center;">
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
</h1>
'''
suc_gif_html = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
</h1>
'''
fail_gif_html = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
</h1>
'''
@st.cache_resource
def get_texteller(use_onnx, accelerator):
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'], use_onnx=use_onnx, onnx_provider=accelerator)
@st.cache_resource
def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
@st.cache_resource
def get_det_models(accelerator):
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession(
"./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
providers=['CUDAExecutionProvider'] if accelerator == 'cuda' else ['CPUExecutionProvider']
)
return infer_config, latex_det_model
@st.cache_resource()
def get_ocr_models(accelerator):
use_gpu = accelerator == 'cuda'
SIZE_LIMIT = 20 * 1024 * 1024
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
# For the detection model, CPU inference is faster than GPU inference under onnxruntime
det_use_gpu = False
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
paddleocr_args = utility.parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.det_model_dir = det_model_dir
paddleocr_args.rec_model_dir = rec_model_dir
paddleocr_args.use_gpu = det_use_gpu
detector = predict_det.TextDetector(paddleocr_args)
paddleocr_args.use_gpu = rec_use_gpu
recognizer = predict_rec.TextRecognizer(paddleocr_args)
return [detector, recognizer]
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
img = Image.open(img_file)
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def on_file_upload():
st.session_state["UPLOADED_FILE_CHANGED"] = True
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
if "UPLOADED_FILE_CHANGED" not in st.session_state:
st.session_state["UPLOADED_FILE_CHANGED"] = False
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
if "INF_MODE" not in st.session_state:
st.session_state["INF_MODE"] = "Formula recognition"
############################## <sidebar> ##############################
with st.sidebar:
num_beams = 1
st.markdown("# 🔨️ Config")
st.markdown("")
inf_mode = st.selectbox(
"Inference mode",
("Formula recognition", "Paragraph recognition"),
on_change=change_side_bar
)
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
accelerator = st.radio(
"Accelerator",
("cpu", "cuda", "mps"),
on_change=change_side_bar
)
st.markdown("## Seedup")
use_onnx = st.toggle("ONNX Runtime ")
############################## </sidebar> ##############################
################################ <page> ################################
texteller = get_texteller(use_onnx, accelerator)
tokenizer = get_tokenizer()
latex_rec_models = [texteller, tokenizer]
if inf_mode == "Paragraph recognition":
infer_config, latex_det_model = get_det_models(accelerator)
lang_ocr_models = get_ocr_models(accelerator)
st.markdown(html_string, unsafe_allow_html=True)
uploaded_file = st.file_uploader(
" ",
type=['jpg', 'png'],
on_change=on_file_upload
)
paste_result = pbutton(
label="📋 Paste an image",
background_color="#5BBCFF",
hover_background_color="#3498db",
)
st.write("")
if st.session_state["CHANGE_SIDEBAR_FLAG"] == True:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
if st.session_state["UPLOADED_FILE_CHANGED"] == False and paste_result.image_data is not None:
uploaded_file = io.BytesIO()
paste_result.image_data.save(uploaded_file, format='PNG')
uploaded_file.seek(0)
if st.session_state["UPLOADED_FILE_CHANGED"] == True:
st.session_state["UPLOADED_FILE_CHANGED"] = False
img = Image.open(uploaded_file)
temp_dir = tempfile.mkdtemp()
png_file_path = os.path.join(temp_dir, 'image.png')
img.save(png_file_path, 'PNG')
with st.container(height=300):
img_base64 = get_image_base64(uploaded_file)
st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-height: 350px;
max-width: 100%;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
</div>
""", unsafe_allow_html=True)
st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
</style>
<div class="centered-container">
<p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
</div>
""", unsafe_allow_html=True)
st.write("")
with st.spinner("Predicting..."):
if inf_mode == "Formula recognition":
TexTeller_result = latex_recognition(
texteller,
tokenizer,
[png_file_path],
accelerator=accelerator,
num_beams=num_beams
)[0]
katex_res = to_katex(TexTeller_result)
else:
katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
st.success('Completed!', icon="")
st.markdown(suc_gif_html, unsafe_allow_html=True)
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
if inf_mode == "Formula recognition":
st.latex(katex_res)
elif inf_mode == "Paragraph recognition":
mixed_res = re.split(r'(\$\$.*?\$\$)', katex_res)
for text in mixed_res:
if text.startswith('$$') and text.endswith('$$'):
st.latex(text[2:-2])
else:
st.markdown(text)
st.write("")
st.write("")
with st.expander(":star2: :gray[Tips for better results]"):
st.markdown('''
* :mag_right: Use a clear and high-resolution image.
* :scissors: Crop images as accurately as possible.
* :jigsaw: Split large multi-line formulas into smaller ones.
* :page_facing_up: Use images with **white background and black text** as much as possible.
* :book: Use a font with good readability.
''')
shutil.rmtree(temp_dir)
paste_result.image_data = None
################################ </page> ################################

67
tests/test_globals.py Normal file
View File

@@ -0,0 +1,67 @@
import logging
from texteller.globals import Globals
def test_singleton_pattern():
"""Test that Globals uses the singleton pattern correctly."""
# Create two instances
globals1 = Globals()
globals2 = Globals()
# Both variables should reference the same object
assert globals1 is globals2
# Modifying one should affect the other
globals1.test_attr = "test_value"
assert globals2.test_attr == "test_value"
# Clean up after test
delattr(globals1, "test_attr")
def test_predefined_attributes():
"""Test predefined attributes have correct default values."""
globals_instance = Globals()
assert globals_instance.repo_name == "OleehyO/TexTeller"
assert globals_instance.logging_level == logging.INFO
def test_attribute_modification():
"""Test that attributes can be modified."""
globals_instance = Globals()
# Modify existing attribute
original_repo_name = globals_instance.repo_name
globals_instance.repo_name = "NewRepo/NewName"
assert globals_instance.repo_name == "NewRepo/NewName"
assert Globals().logging_level == logging.INFO
Globals().logging_level = logging.DEBUG
assert Globals().logging_level == logging.DEBUG
# Reset for other tests
globals_instance.repo_name = original_repo_name
globals_instance.logging_level = logging.INFO
def test_dynamic_attributes():
"""Test that new attributes can be added dynamically."""
globals_instance = Globals()
# Add new attribute
globals_instance.new_attribute = "new_value"
assert globals_instance.new_attribute == "new_value"
# Clean up after test
delattr(globals_instance, "new_attribute")
def test_representation():
"""Test the string representation of Globals."""
globals_instance = Globals()
repr_string = repr(globals_instance)
# Check that repr contains class name and is formatted as expected
assert repr_string.startswith("<Globals:")
assert "repo_name" in repr_string
assert "logging_level" in repr_string

8
tests/test_to_katex.py Normal file
View File

@@ -0,0 +1,8 @@
from texteller import to_katex
def test_to_katex():
# Test complex mathematical equations with vectors and symbols
complex_input = "\\[\\begin{split}&\\mathbb{E}_{\\bm{\\omega}}[h(\\mathbf{x})h(\\mathbf{y})^{ *}]\\\\ &=\\mathbb{E}_{\\bm{\\omega}}[\\exp(i\\bm{\\omega}^{\\top}\\mathbf{x})\\exp (-i\\bm{\\omega}^{\\top}\\mathbf{y}))]\\\\ &=\\mathbb{E}_{\\bm{\\omega}}[\\exp(i\\bm{\\omega}^{\\top}\\bm{\\delta})] \\\\ &=\\int_{\\mathbb{R}^{D}}p(\\bm{\\omega})\\exp(i\\bm{\\omega}^{\\top}\\bm{ \\delta})\\mathrm{d}\\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{ \\omega}^{\\top}\\bm{\\omega}\\big{)}\\exp(i\\bm{\\omega}^{\\top}\\bm{\\delta})\\mathrm{d} \\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{ \\omega}^{\\top}\\bm{\\omega}-i\\bm{\\omega}^{\\top}\\bm{\\delta}\\big{)}\\mathrm{d}\\bm{ \\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\big{(} \\bm{\\omega}^{\\top}\\bm{\\omega}-2i\\bm{\\omega}^{\\top}\\bm{\\delta}-\\bm{\\delta}^{ \\top}\\bm{\\delta}\\big{)}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{\\delta}\\big{)} \\mathrm{d}\\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{ \\delta}\\big{)}\\!\\underbrace{\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\big{(} \\bm{\\omega}-i\\bm{\\delta}\\big{)}^{\\top}\\big{(}\\bm{\\omega}-i\\bm{\\delta}\\big{)} \\big{)}\\mathrm{d}\\bm{\\omega}}_{(2\\pi)^{D/2}}\\\\ &=\\exp\\!\\big{(}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{\\delta}\\big{)} \\\\ &=k(\\bm{\\delta}).\\end{split}\\]"
expected_output = "\\begin{split}&\\mathbb{E}_{ \\omega}[h( x)h( y)^{ *}]\\\\ &=\\mathbb{E}_{ \\omega}[\\exp(i \\omega^{\\top} x)\\exp (-i \\omega^{\\top} y))]\\\\ &=\\mathbb{E}_{ \\omega}[\\exp(i \\omega^{\\top} \\delta)] \\\\ &=\\int_{\\mathbb{R}^{D}}p( \\omega)\\exp(i \\omega^{\\top} \\delta)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2} \\omega^{\\top} \\omega\\big)\\exp(i \\omega^{\\top} \\delta)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2} \\omega^{\\top} \\omega-i \\omega^{\\top} \\delta\\big)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2}\\big( \\omega^{\\top} \\omega-2i \\omega^{\\top} \\delta- \\delta^{ \\top} \\delta\\big)-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\exp \\big(-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\underbrace{\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2}\\big( \\omega-i \\delta\\big)^{\\top}\\big( \\omega-i \\delta\\big) \\big)\\mathrm{d} \\omega}_{(2\\pi)^{D/2}}\\\\ &=\\exp \\big(-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\\\ &=k( \\delta).\n\\end{split}\n"
assert to_katex(complex_input).strip() == expected_output.strip()

5
texteller/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from importlib.metadata import version
from texteller.api import *
__version__ = version("texteller")

24
texteller/api/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
from .detection import latex_detect
from .format import format_latex
from .inference import img2latex, paragraph2md
from .katex import to_katex
from .load import (
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
)
__all__ = [
"to_katex",
"format_latex",
"img2latex",
"paragraph2md",
"load_model",
"load_tokenizer",
"load_latexdet_model",
"load_textrec_model",
"load_textdet_model",
"latex_detect",
]
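# A short sketch of the package-level API re-exported here (it mirrors the docstring
# example in detection/detect.py later in this diff; the image path is a placeholder):
from texteller import latex_detect, load_latexdet_model

model = load_latexdet_model()
bboxes = latex_detect("path/to/image.png", model)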

View File

@@ -0,0 +1,4 @@
from .ngram import DetectRepeatingNgramCriteria
__all__ = ["DetectRepeatingNgramCriteria"]

View File

@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteria
class DetectRepeatingNgramCriteria(StoppingCriteria):
"""
Stops generation efficiently if any n-gram repeats.
This criteria maintains a set of encountered n-grams.
At each step, it checks if the *latest* n-gram is already in the set.
If yes, it stops generation. If no, it adds the n-gram to the set.
"""
def __init__(self, n: int):
"""
Args:
n (int): The size of the n-gram to check for repetition.
"""
if n <= 0:
raise ValueError("n-gram size 'n' must be positive.")
self.n = n
# Stores tuples of token IDs representing seen n-grams
self.seen_ngrams = set()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores.
Return:
`bool`: `True` if generation should stop, `False` otherwise.
"""
batch_size, seq_length = input_ids.shape
# Need at least n tokens to form the first n-gram
if seq_length < self.n:
return False
# --- Efficient Check ---
# Consider only the first sequence in the batch for simplicity
if batch_size > 1:
# If handling batch_size > 1, you'd need a list of sets, one per batch item.
# Or decide on a stopping policy (e.g., stop if *any* sequence repeats).
# For now, we'll focus on the first sequence.
pass # No warning needed every step, maybe once in __init__ if needed.
sequence = input_ids[0] # Get the first sequence
# Get the latest n-gram (the one ending at the last token)
last_ngram_tensor = sequence[-self.n :]
# Convert to a hashable tuple for set storage and lookup
last_ngram_tuple = tuple(last_ngram_tensor.tolist())
# Check if this n-gram has been seen before *at any prior step*
if last_ngram_tuple in self.seen_ngrams:
return True # Stop generation
else:
# It's a new n-gram, add it to the set and continue
self.seen_ngrams.add(last_ngram_tuple)
return False # Continue generation
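# A hedged usage sketch: wrap the criteria in transformers' StoppingCriteriaList and
# pass it to generate(); n=10 is an arbitrary illustrative choice. Because seen_ngrams
# persists on the instance, create a fresh instance for each generation call.
from transformers import StoppingCriteriaList

stopping = StoppingCriteriaList([DetectRepeatingNgramCriteria(n=10)])
# outputs = model.generate(pixel_values, stopping_criteria=stopping)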

View File

@@ -0,0 +1,3 @@
from .detect import latex_detect
__all__ = ["latex_detect"]

View File

@@ -0,0 +1,69 @@
from typing import List
from onnxruntime import InferenceSession
from texteller.types import Bbox
from .preprocess import Compose
_config = {
"mode": "paddle",
"draw_threshold": 0.5,
"metric": "COCO",
"use_dynamic_shape": False,
"arch": "DETR",
"min_subgraph_size": 3,
"preprocess": [
{"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
{
"mean": [0.0, 0.0, 0.0],
"norm_type": "none",
"std": [1.0, 1.0, 1.0],
"type": "NormalizeImage",
},
{"type": "Permute"},
],
"label_list": ["isolated", "embedding"],
}
def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
"""
Detect LaTeX formulas in an image and classify them as isolated or embedded.
This function uses an ONNX model to detect LaTeX formulas in images. The model
identifies two types of LaTeX formulas:
- 'isolated': Standalone LaTeX formulas (typically displayed equations)
- 'embedding': Inline LaTeX formulas embedded within text
Args:
img_path: Path to the input image file
predictor: ONNX InferenceSession model for LaTeX detection
Returns:
List of Bbox objects representing the detected LaTeX formulas with their
positions, classifications, and confidence scores
Example:
>>> from texteller.api import load_latexdet_model, latex_detect
>>> model = load_latexdet_model()
>>> bboxes = latex_detect("path/to/image.png", model)
"""
transforms = Compose(_config["preprocess"])
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = _config["label_list"][int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > 0.5:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res

View File

@@ -0,0 +1,161 @@
import copy
import cv2
import numpy as np
def decode_image(img_path):
if isinstance(img_path, str):
with open(img_path, "rb") as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype="uint8")
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(im.shape[:2], dtype=np.float32),
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether to keep the aspect ratio, default True
interp (int): interpolation method used for resizing
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type="mean_std"):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == "mean_std":
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(
self,
):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop("type")
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs["image"] = img
return inputs
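# A minimal sketch of running the preprocessing pipeline on its own (it mirrors its
# use in detect.py above; the image path is a placeholder):
pipeline = Compose([
    {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
    {"mean": [0.0, 0.0, 0.0], "norm_type": "none", "std": [1.0, 1.0, 1.0], "type": "NormalizeImage"},
    {"type": "Permute"},
])
inputs = pipeline("page.png")
print(inputs["image"].shape)  # (3, 1600, 1600): CHW float32, scaled to [0, 1]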

Some files were not shown because too many files have changed in this diff.