Compare commits

38 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 12e6bb4312 | |
| | 4292be86f2 | |
| | 3e0d236f6b | |
| | b653b9e784 | |
| | b85979b258 | |
| | 05e494af4b | |
| | 4b1b8d10de | |
| | c8e08a22aa | |
| | 4d3be22956 | |
| | 4e92a38682 | |
| | 3e5272a476 | |
| | a329453873 | |
| | 42db737ae6 | |
| | 29f6f8960d | |
| | 6a67017962 | |
| | d4a32094a7 | |
| | 3eceedd96b | |
| | 8aad53e079 | |
| | df03525fd7 | |
| | 8a000edb7b | |
| | 375ad4a4cb | |
| | 59df576c85 | |
| | d4cef5135f | |
| | 25142c34f1 | |
| | 0cba17d9ce | |
| | e0cbf2c99f | |
| | 4e2740ada0 | |
| | 509cb75dfa | |
| | 5673adecff | |
| | 5cda58e8fc | |
| | c35b0e9f53 | |
| | f023c6741e | |
| | be9e32b439 | |
| | cd9e4146e0 | |
| | 3e9c6c00b8 | |
| | 0e5c5fd706 | |
| | 4d3714bb4b | |
| | 3296077461 | |
.github/workflows/deploy-doc.yml (vendored, new file, 31 lines)

```yaml
name: "Sphinx: Render docs"

on: push

jobs:
  build:
    runs-on: ubuntu-latest
    permissions:
      contents: write
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Build HTML
        uses: ammaraskar/sphinx-action@7.0.0
        with:
          pre-build-command: |
            apt-get update && apt-get install -y git
            pip install uv
            uv pip install --system . .[docs]
      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: html-docs
          path: docs/build/html/
      - name: Deploy
        uses: peaceiris/actions-gh-pages@v3
        if: github.ref == 'refs/heads/main'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: docs/build/html
```
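The same documentation build can be reproduced locally; a minimal sketch, assuming a checkout of the repo root and the `docs` extra used by the workflow above (the `docs/Makefile` appears later in this diff):

```bash
# mirror the CI pre-build steps locally
pip install uv
uv pip install --system . .[docs]
cd docs && make html   # output lands in docs/build/html/, the path the workflow uploads
```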
.github/workflows/pr-welcome.yml (vendored, new file, 41 lines)

```yaml
name: PR Welcome Bot

on:
  pull_request:
    types: [opened]

permissions:
  pull-requests: write
  issues: write

jobs:
  welcome:
    runs-on: ubuntu-latest
    steps:
      - name: Post Welcome Comment
        uses: actions/github-script@v6
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const prNumber = context.issue.number;
            const prAuthor = context.payload.pull_request.user.login;

            const welcomeMessage = `
            👋 Hello @${prAuthor}, thank you for contributing to this project! 🎉

            We've received your Pull Request and the team will review it as soon as possible.

            In the meantime, please ensure:
            - [ ] Your code follows the project's coding style
            - [ ] Relevant tests have been added and are passing
            - [ ] Documentation has been updated if needed

            If you have any questions, feel free to ask here. Happy coding! 😊
            `;

            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: prNumber,
              body: welcomeMessage
            });
```
.github/workflows/publish.yml (vendored, new file, 34 lines)

```yaml
name: Publish to PyPI

on:
  push:
    tags:
      - 'v*'

jobs:
  publish:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Build package with uv
        run: |
          uv build

      - name: Publish to PyPI
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: |
          uv publish --token $UV_PUBLISH_TOKEN
```
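Given the `tags: 'v*'` trigger above, a release would be cut by pushing a version tag; the tag name below is a hypothetical example:

```bash
git tag v1.0.0            # any tag matching 'v*' triggers the publish job
git push origin v1.0.0
```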
.github/workflows/python-lint.yml (vendored, new file, 27 lines)

```yaml
name: Python Linting

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit

      - name: Run pre-commit
        run: pre-commit run --all-files
```
.github/workflows/test.yaml (vendored, new file, 35 lines)

```yaml
name: Run Tests with Pytest

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Install dependencies
        run: |
          uv sync --extra test

      - name: Run tests with pytest
        run: |
          uv run pytest -v tests/
```
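The test job can be mirrored locally with the same uv commands the workflow runs:

```bash
uv sync --extra test      # install the package plus the 'test' extra
uv run pytest -v tests/
```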
.gitignore (vendored) @@ -1,28 +1,349 @@

```gitignore
**/.DS_Store
**/__pycache__
# Created by https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm,python
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,visualstudiocode,pycharm,python

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### VisualStudioCode ###
**/.vscode
**/pyrightconfig.json
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

**/dist
**/build
*.egg-info
# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm,python

uv.lock
**/train_result
**/ckpt
**/ckpts
**/*.safetensor
**/trocr-*
**/large*.onnx
**/rtdetr_r50vd_6x_coco.onnx

**/*cache
**/.cache

**/tmp
**/tmp*
**/log
**/logs

**/data

**/*.bin
**/*.onnx
**/*.png
**/*.jpg
**/augraphy_cache
```
.pre-commit-config.yaml (new file, 22 lines)

```yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.6
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
        exclude: ^texteller/models/thrid_party/paddleocr/
      - id: ruff-format
        args: [--config=pyproject.toml]
        exclude: ^texteller/models/thrid_party/paddleocr/

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements
```
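These hooks can be run locally before pushing; a minimal mirror of what the lint workflow executes in CI:

```bash
pip install pre-commit
pre-commit install            # run the hooks automatically on each commit
pre-commit run --all-files    # or run them once across the whole repo
```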
.python-version (new file, 1 line)

```
3.10
```
```diff
@@ -1 +0,0 @@
-include README.md
```
README.md (219 changed lines)

````diff
@@ -6,47 +6,23 @@
     𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
     <img src="./assets/fire.svg" width=30, height=30>
 </h1>
 <!-- <p align="center">
     🤗 <a href="https://huggingface.co/OleehyO/TexTeller"> Hugging Face </a>
 </p> -->

-[](https://opensource.org/licenses/Apache-2.0)
+[](https://oleehyo.github.io/TexTeller/)
 [](https://hub.docker.com/r/oleehyo/texteller)
 [](https://huggingface.co/datasets/OleehyO/latex-formulas)
 [](https://huggingface.co/OleehyO/TexTeller)
+[](https://opensource.org/licenses/Apache-2.0)

 </div>

 <!-- <p align="center">
     <a href="https://opensource.org/licenses/Apache-2.0">
         <img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License">
     </a>
     <a href="https://github.com/OleehyO/TexTeller/issues">
         <img src="https://img.shields.io/badge/Maintained%3F-yes-green.svg" alt="Maintenance">
     </a>
     <a href="https://github.com/OleehyO/TexTeller/pulls">
         <img src="https://img.shields.io/badge/Contributions-welcome-brightgreen.svg?style=flat" alt="Contributions welcome">
     </a>
     <a href="https://huggingface.co/datasets/OleehyO/latex-formulas">
         <img src="https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg" alt="Data">
     </a>
     <a href="https://huggingface.co/OleehyO/TexTeller">
         <img src="https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg" alt="Weights">
     </a>
 </p> -->

 https://github.com/OleehyO/TexTeller/assets/56267907/532d1471-a72e-4960-9677-ec6c19db289f

-TexTeller is an end-to-end formula recognition model based on [TrOCR](https://arxiv.org/abs/2109.10282), capable of converting images into corresponding LaTeX formulas.
+TexTeller is an end-to-end formula recognition model, capable of converting images into corresponding LaTeX formulas.

 TexTeller was trained with **80M image-formula pairs** (previous dataset can be obtained [here](https://huggingface.co/datasets/OleehyO/latex-formulas)), compared to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) which used a 100K dataset, TexTeller has **stronger generalization abilities** and **higher accuracy**, covering most use cases.

 >[!NOTE]
 > If you would like to provide feedback or suggestions for this project, feel free to start a discussion in the [Discussions section](https://github.com/OleehyO/TexTeller/discussions).
 >
 > Additionally, if you find this project helpful, please don't forget to give it a star⭐️🙏️

 ---

@@ -55,15 +31,12 @@ TexTeller was trained with **80M image-formula pairs** (previous dataset can be
 <td>

 ## 🔖 Table of Contents
 - [Change Log](#-change-log)
 - [Getting Started](#-getting-started)
+- [Web Demo](#-web-demo)
+- [Server](#-server)
+- [Python API](#-python-api)
 - [Formula Detection](#-formula-detection)
-- [API Usage](#-api-usage)
 - [Training](#️️-training)
 - [Plans](#-plans)
-- [Stargazers over time](#️-stargazers-over-time)
 - [Contributors](#-contributors)

 </td>
 <td>

@@ -76,187 +49,149 @@ TexTeller was trained with **80M image-formula pairs** (previous dataset can be
 </figcaption>
 </figure>
 <div>
 <p>
 Thanks to the
 <i>
 Super Computing Platform of Beijing University of Posts and Telecommunications
 </i>
 for supporting this work😘
 </p>
 <!-- <img src="assets/scss.png" width="200"> -->
 </div>
 </div>

 </td>
 </tr>
 </table>

-## 🔄 Change Log
+## 📮 Change Log

-- 📮[2024-06-06] **TexTeller3.0 released!** The training data has been increased to **80M** (**10x more than** TexTeller2.0 and also improved in data diversity). TexTeller3.0's new features:
+- [2024-06-06] **TexTeller3.0 released!** The training data has been increased to **80M** (**10x more than** TexTeller2.0 and also improved in data diversity). TexTeller3.0's new features:

   - Support scanned image, handwritten formulas, English(Chinese) mixed formulas.

   - OCR abilities in both Chinese and English for printed images.

-- 📮[2024-05-02] Support **paragraph recognition**.
+- [2024-05-02] Support **paragraph recognition**.

-- 📮[2024-04-12] **Formula detection model** released!
+- [2024-04-12] **Formula detection model** released!

-- 📮[2024-03-25] TexTeller2.0 released! The training data for TexTeller2.0 has been increased to 7.5M (15x more than TexTeller1.0 and also improved in data quality). The trained TexTeller2.0 demonstrated **superior performance** in the test set, especially in recognizing rare symbols, complex multi-line formulas, and matrices.
+- [2024-03-25] TexTeller2.0 released! The training data for TexTeller2.0 has been increased to 7.5M (15x more than TexTeller1.0 and also improved in data quality). The trained TexTeller2.0 demonstrated **superior performance** in the test set, especially in recognizing rare symbols, complex multi-line formulas, and matrices.

 > [Here](./assets/test.pdf) are more test images and a horizontal comparison of various recognition models.

 ## 🚀 Getting Started

-1. Clone the repository:
+1. Install uv:

    ```bash
-   git clone https://github.com/OleehyO/TexTeller
+   pip install uv
    ```

 2. Install the project's dependencies:

    ```bash
-   pip install texteller
+   uv pip install texteller
    ```

-3. Enter the `src/` directory and run the following command in the terminal to start inference:
+3. If you are using CUDA backend, you may need to install `onnxruntime-gpu`:

    ```bash
-   python inference.py -img "/path/to/image.{jpg,png}"
-   # use --inference-mode option to enable GPU(cuda or mps) inference
-   #+e.g. python inference.py -img "img.jpg" --inference-mode cuda
+   uv pip install texteller[onnxruntime-gpu]
    ```

-> The first time you run it, the required checkpoints will be downloaded from Hugging Face.
-
-### Paragraph Recognition
-
-As demonstrated in the video, TexTeller is also capable of recognizing entire text paragraphs. Although TexTeller has general text OCR capabilities, we still recommend using paragraph recognition for better results:
-
-1. [Download the weights](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true) of the formula detection model to the `src/models/det_model/model/` directory
-
-2. Run `inference.py` in the `src/` directory and add the `-mix` option, the results will be output in markdown format.
+4. Run the following command to start inference:

    ```bash
-   python inference.py -img "/path/to/image.{jpg,png}" -mix
+   texteller inference "/path/to/image.{jpg,png}"
    ```

-TexTeller uses the lightweight [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) model by default for recognizing both Chinese and English text. You can try using a larger model to achieve better recognition results for both Chinese and English:
-
-| Checkpoints | Model Description | Size |
-|-------------|-------------------|------|
-| [ch_PP-OCRv4_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx?download=true) | **Default detection model**, supports Chinese-English text detection | 4.70M |
-| [ch_PP-OCRv4_server_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_det.onnx?download=true) | High accuracy model, supports Chinese-English text detection | 115M |
-| [ch_PP-OCRv4_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_rec.onnx?download=true) | **Default recognition model**, supports Chinese-English text recognition | 10.80M |
-| [ch_PP-OCRv4_server_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx?download=true) | High accuracy model, supports Chinese-English text recognition | 90.60M |
-
-Place the weights of the recognition/detection model in the `det/` or `rec/` directories within `src/models/third_party/paddleocr/checkpoints/`, and rename them to `default_model.onnx`.

 > [!NOTE]
-> Paragraph recognition cannot restore the structure of a document, it can only recognize its content.
+> See `texteller inference --help` for more details
````
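The new CLI maps onto the Python quick-start documented in `docs/source/index.rst` (added later in this diff); a minimal sketch of the same inference call from Python:

```python
from texteller import img2latex, load_model, load_tokenizer

# load the pretrained recognition model and its tokenizer
model = load_model(use_onnx=False)
tokenizer = load_tokenizer()

# img2latex takes a list of image paths and returns one LaTeX string per image
latex = img2latex(model, tokenizer, ["/path/to/image.png"])[0]
print(latex)
```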
````diff
 ## 🌐 Web Demo

-Go to the `src/` directory and run the following command:
+Run the following command:

 ```bash
-./start_web.sh
+texteller web
 ```

 Enter `http://localhost:8501` in a browser to view the web demo.

 > [!NOTE]
-> 1. For Windows users, please run the `start_web.bat` file.
-> 2. When using onnxruntime + GPU for inference, you need to install onnxruntime-gpu.
+> Paragraph recognition cannot restore the structure of a document, it can only recognize its content.

-## 🔍 Formula Detection
+## 🖥️ Server

-TexTeller’s formula detection model is trained on 3,415 images of Chinese educational materials (with over 130 layouts) and 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865), and it supports formula detection across entire images.
-
-<div align="center">
-<img src="./assets/det_rec.png" width=250>
-</div>
-
-1. Download the model weights and place them in `src/models/det_model/model/` [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)].
-
-2. Run the following command in the `src/` directory, and the results will be saved in `src/subimages/`
-
-<details>
-<summary>Advanced: batch formula recognition</summary>
-
-After **formula detection**, run the following command in the `src/` directory:
-
-```shell
-python rec_infer_from_crop_imgs.py
-```
-
-This will use the results of the previous formula detection to perform batch recognition on all cropped formulas, saving the recognition results as txt files in `src/results/`.
-
-</details>
-
-## 📡 API Usage
-
-We use [ray serve](https://github.com/ray-project/ray) to provide an API interface for TexTeller, allowing you to integrate TexTeller into your own projects. To start the server, you first need to enter the `src/` directory and then run the following command:
+We use [ray serve](https://github.com/ray-project/ray) to provide an API server for TexTeller. To start the server, run the following command:

 ```bash
-python server.py
+texteller launch
 ```

 | Parameter | Description |
 | --------- | -------- |
 | `-ckpt` | The path to the weights file, *default is TexTeller's pretrained weights*. |
 | `-tknz` | The path to the tokenizer, *default is TexTeller's tokenizer*. |
-| `-port` | The server's service port, *default is 8000*. |
-| `--inference-mode` | Whether to use "cuda" or "mps" for inference, *default is "cpu"*. |
-| `--num_beams` | The number of beams for beam search, *default is 1*. |
-| `--num_replicas` | The number of service replicas to run on the server, *default is 1 replica*. You can use more replicas to achieve greater throughput. |
-| `--ncpu_per_replica` | The number of CPU cores used per service replica, *default is 1*. |
-| `--ngpu_per_replica` | The number of GPUs used per service replica, *default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU to share the GPU, thereby improving GPU utilization. (Note, if --num_replicas is 2, --ngpu_per_replica is 0.7, then 2 GPUs must be available) |
-| `-onnx` | Perform inference using Onnx Runtime, *disabled by default* |
+| `-p` | The server's service port, *default is 8000*. |
+| `--num-replicas` | The number of service replicas to run on the server, *default is 1 replica*. You can use more replicas to achieve greater throughput. |
+| `--ncpu-per-replica` | The number of CPU cores used per service replica, *default is 1*. |
+| `--ngpu-per-replica` | The number of GPUs used per service replica, *default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU to share the GPU, thereby improving GPU utilization. (Note, if --num_replicas is 2, --ngpu_per_replica is 0.7, then 2 GPUs must be available) |
+| `--num-beams` | The number of beams for beam search, *default is 1*. |
+| `--use-onnx` | Perform inference using Onnx Runtime, *disabled by default* |

-> [!NOTE]
-> A client demo can be found at `src/client/demo.py`, you can refer to `demo.py` to send requests to the server
+To send requests to the server:
+
+```python
+# client_demo.py
+
+import requests
+
+server_url = "http://127.0.0.1:8000/predict"
+
+img_path = "/path/to/your/image"
+with open(img_path, 'rb') as img:
+    files = {'img': img}
+    response = requests.post(server_url, files=files)
+
+print(response.text)
+```
````
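For quick testing without Python, an equivalent request can be sent with curl (assuming the default port and the `img` form field used by the demo above):

```bash
curl -X POST -F "img=@/path/to/your/image.png" http://127.0.0.1:8000/predict
```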
````diff
+## 🐍 Python API
+
+We provide several easy-to-use Python APIs for formula OCR scenarios. Please refer to our [documentation](https://oleehyo.github.io/TexTeller/) to learn about the corresponding API interfaces and usage.
````
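One of these APIs, `to_katex`, is listed in `docs/source/api.rst` later in this diff; its exact signature is not shown there, so the sketch below assumes it takes the raw LaTeX string and returns a KaTeX-compatible one:

```python
from texteller import img2latex, load_model, load_tokenizer
from texteller.api import to_katex  # module path per docs/source/api.rst; signature assumed

model = load_model()
tokenizer = load_tokenizer()
latex = img2latex(model, tokenizer, ["/path/to/image.png"])[0]

# assumption: to_katex(latex: str) -> str, normalizing output for KaTeX rendering
print(to_katex(latex))
```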
````diff
+## 🔍 Formula Detection
+
+TexTeller's formula detection model is trained on 3,415 images of Chinese materials and 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865).
+
+<div align="center">
+<img src="./assets/det_rec.png" width=250>
+</div>
+
+We provide a formula detection interface in the Python API. Please refer to our [API documentation](https://oleehyo.github.io/TexTeller/) for more details.

 ## 🏋️♂️ Training

-### Dataset
+Please set up your environment before training:

-We provide an example dataset in the `src/models/ocr_model/train/dataset/` directory, you can place your own images in the `images/` directory and annotate each image with its corresponding formula in `formulas.jsonl`.
-
-After preparing your dataset, you need to **change the `DIR_URL` variable to your own dataset's path** in `**/train/dataset/loader.py`
-
-### Retraining the Tokenizer
-
-If you are using a different dataset, you might need to retrain the tokenizer to obtain a different vocabulary. After configuring your dataset, you can train your own tokenizer with the following command:
-
-1. In `src/models/tokenizer/train.py`, change `new_tokenizer.save_pretrained('./your_dir_name')` to your custom output directory
-
-> If you want to use a different vocabulary size (default 15K), you need to change the `VOCAB_SIZE` variable in `src/models/globals.py`
->
-2. **In the `src/` directory**, run the following command:
+1. Install the dependencies for training:

    ```bash
-   python -m models.tokenizer.train
+   uv pip install texteller[train]
    ```

+2. Clone the repository:
+
+   ```bash
+   git clone https://github.com/OleehyO/TexTeller.git
+   ```
+
+### Dataset
+
+We provide an example dataset in the `examples/train_texteller/dataset/train` directory, you can place your own training data according to the format of the example dataset.

 ### Training the Model

-1. Modify `num_processes` in `src/train_config.yaml` to match the number of GPUs available for training (default is 1).
-2. In the `src/` directory, run the following command:
+In the `examples/train_texteller/` directory, run the following command:

    ```bash
-   accelerate launch --config_file ./train_config.yaml -m models.ocr_model.train.train
+   accelerate launch train.py
    ```

-You can set your own tokenizer and checkpoint paths in `src/models/ocr_model/train/train.py` (refer to `train.py` for more information). If you are using the same architecture and vocabulary as TexTeller, you can also fine-tune TexTeller's default weights with your own dataset.
-
-In `src/globals.py` and `src/models/ocr_model/train/train_args.py`, you can change the model's architecture and training hyperparameters.
-
-> [!NOTE]
-> Our training scripts use the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, so you can refer to their [documentation](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments) for more details and configurations on training parameters.
+Training arguments can be adjusted in [`train_config.yaml`](./examples/train_texteller/train_config.yaml).
````
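For multi-GPU runs, the knob carried over from the old instructions is `num_processes`; a sketch of the corresponding `train_config.yaml` fragment (other fields follow the standard accelerate config and are not shown here):

```yaml
# train_config.yaml (fragment, sketch)
num_processes: 4   # set to the number of GPUs available for training
```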
````diff
@@ -266,13 +201,11 @@ In `src/globals.py` and `src/models/ocr_model/train/train_args.py`, you can chan
 ## 📅 Plans

 - [X] ~~Handwritten formulas support~~
 - [ ] PDF document recognition
 - [ ] Inference acceleration
-- [ ] ...

 ## ⭐️ Stargazers over time

 [](https://starchart.cc/OleehyO/TexTeller)

 ## 👥 Contributors

 <a href="https://github.com/OleehyO/TexTeller/graphs/contributors">
````
````diff
@@ -1,4 +1,4 @@
-📄 <a href="../README.md">English</a> | 中文
+📄 中文 | [English](./README.md)

 <div align="center">
 <h1>
@@ -6,47 +6,23 @@
     𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
     <img src="./fire.svg" width=30, height=30>
 </h1>
 <!-- <p align="center">
     🤗 <a href="https://huggingface.co/OleehyO/TexTeller"> Hugging Face </a>
 </p> -->

-[](https://opensource.org/licenses/Apache-2.0)
-[](https://hub.docker.com/r/oleehyo/texteller)
-[](https://huggingface.co/datasets/OleehyO/latex-formulas)
-[](https://huggingface.co/OleehyO/TexTeller)
+[](https://oleehyo.github.io/TexTeller/)
+[](https://hub.docker.com/r/oleehyo/texteller)
+[](https://huggingface.co/datasets/OleehyO/latex-formulas)
+[](https://huggingface.co/OleehyO/TexTeller)
+[](https://opensource.org/licenses/Apache-2.0)

 </div>

 <!-- <p align="center">
     <a href="https://opensource.org/licenses/Apache-2.0">
         <img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License">
     </a>
     <a href="https://github.com/OleehyO/TexTeller/issues">
         <img src="https://img.shields.io/badge/Maintained%3F-yes-green.svg" alt="Maintenance">
     </a>
     <a href="https://github.com/OleehyO/TexTeller/pulls">
         <img src="https://img.shields.io/badge/Contributions-welcome-brightgreen.svg?style=flat" alt="Contributions welcome">
     </a>
     <a href="https://huggingface.co/datasets/OleehyO/latex-formulas">
         <img src="https://img.shields.io/badge/Data-Texteller1.0-brightgreen.svg" alt="Data">
     </a>
     <a href="https://huggingface.co/OleehyO/TexTeller">
         <img src="https://img.shields.io/badge/Weights-Texteller3.0-yellow.svg" alt="Weights">
     </a>
 </p> -->

 https://github.com/OleehyO/TexTeller/assets/56267907/532d1471-a72e-4960-9677-ec6c19db289f

-TexTeller是一个基于[TrOCR](https://arxiv.org/abs/2109.10282)的端到端公式识别模型,可以把图片转换为对应的latex公式
+TexTeller 是一个端到端的公式识别模型,能够将图像转换为对应的 LaTeX 公式。

-TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以在[这里](https://huggingface.co/datasets/OleehyO/latex-formulas)获取),相比于[LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)(使用了一个100K的数据集),TexTeller具有**更强的泛化能力**以及**更高的准确率**,可以覆盖大部分的使用场景。
+TexTeller 使用 **8千万图像-公式对** 进行训练(前代数据集可在此[获取](https://huggingface.co/datasets/OleehyO/latex-formulas)),相较 [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) 使用的 10 万量级数据集,TexTeller 具有**更强的泛化能力**和**更高的准确率**,覆盖绝大多数使用场景。

-> [!NOTE]
-> 如果您想为本项目提供一些反馈、建议等,欢迎在[Discussions版块](https://github.com/OleehyO/TexTeller/discussions)发起讨论。
->
-> 另外,如果您觉得这个项目对您有帮助,请不要忘记点亮上方的Star⭐️🙏
+>[!NOTE]
+> 如果您想对本项目提出反馈或建议,欢迎前往 [讨论区](https://github.com/OleehyO/TexTeller/discussions) 发起讨论。

 ---

@@ -55,17 +31,12 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
 <td>

 ## 🔖 目录

 - [变更信息](#-变更信息)
-- [开搞](#-开搞)
-- [常见问题:无法连接到Hugging Face](#-常见问题无法连接到hugging-face)
+- [快速开始](#-快速开始)
+- [网页演示](#-网页演示)
+- [服务部署](#-服务部署)
+- [Python接口](#-python接口)
 - [公式检测](#-公式检测)
-- [API调用](#-api调用)
-- [训练](#️️-训练)
 - [计划](#-计划)
-- [观星曲线](#️-观星曲线)
 - [贡献者](#-贡献者)
+- [模型训练](#️️-模型训练)

 </td>
 <td>

@@ -74,17 +45,10 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
 <figure>
 <img src="cover.png" width="800">
 <figcaption>
-<p>可以被TexTeller识别出的图片</p>
+<p>TexTeller 可识别的图像示例</p>
 </figcaption>
 </figure>
 <div>
 <p>
 感谢
 <i>
 北京邮电大学超算平台
 </i>
 为本项工作提供支持😘
 </p>
 </div>
 </div>

@@ -92,221 +56,155 @@ TexTeller用了**80M**个图片-公式对进行训练(过去的数据集可以
 </tr>
 </table>

-## 🔄 变更信息
+## 📮 更新日志

-- 📮[2024-06-06] **TexTeller3.0**发布! 训练数据集增加到了**80M**(相较于TexTeller2.0增加了**10倍**,并且改善了数据的多样性)。新版的TexTeller具有以下新的特性:
-  - 支持扫描图片、手写公式以及中英文混合的公式。
-  - 在打印图片上具有通用的中英文识别能力。
+- [2024-06-06] **TexTeller3.0 发布!** 训练数据增至 **8千万**(是 TexTeller2.0 的 **10倍** 并提升了数据多样性)。TexTeller3.0 新特性:

-- 📮[2024-05-02] 支持**段落识别**。
+  - 支持扫描件、手写公式、中英文混合公式识别

-- 📮[2024-04-12] **公式检测模型**发布!
+  - 支持印刷体中英文混排公式的OCR识别

-- 📮[2024-03-25] TexTeller2.0发布!TexTeller2.0的训练数据增大到了7.5M(相较于TexTeller1.0增加了~15倍并且数据质量也有所改善)。训练后的TexTeller2.0在测试集中展现出了更加优越的性能,尤其在生僻符号、复杂多行、矩阵的识别场景中。
+- [2024-05-02] 支持**段落识别**功能

-> 在[这里](./test.pdf)有更多的测试图片以及各家识别模型的横向对比。
+- [2024-04-12] **公式检测模型**发布!

-## 🚀 开搞
+- [2024-03-25] TexTeller2.0 发布!TexTeller2.0 的训练数据增至750万(是前代的15倍并提升了数据质量)。训练后的 TexTeller2.0 在测试集中展现了**更优性能**,特别是在识别罕见符号、复杂多行公式和矩阵方面表现突出。

-1. 克隆本仓库:
+> [此处](./assets/test.pdf) 展示了更多测试图像及各类识别模型的横向对比。
+
+## 🚀 快速开始
+
+1. 安装uv:

    ```bash
-   git clone https://github.com/OleehyO/TexTeller
+   pip install uv
    ```

-2. 安装本项目的依赖包:
+2. 安装项目依赖:

    ```bash
-   pip install texteller
+   uv pip install texteller
    ```

-3. 进入`src/`目录,在终端运行以下命令进行推理:
+3. 若使用 CUDA 后端,可能需要安装 `onnxruntime-gpu`:

    ```bash
-   python inference.py -img "/path/to/image.{jpg,png}"
-   # use --inference-mode option to enable GPU(cuda or mps) inference
-   #+e.g. python inference.py -img "img.jpg" --inference-mode cuda
+   uv pip install texteller[onnxruntime-gpu]
    ```

-> 第一次运行时会在Hugging Face上下载所需要的权重
-
-### 段落识别
-
-如演示视频所示,TexTeller还可以识别整个文本段落。尽管TexTeller具备通用的文本OCR能力,但我们仍然建议使用段落识别来获得更好的效果:
-
-1. 下载公式检测模型的权重到`src/models/det_model/model/`目录 [[链接](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)]
-
-2. `src/`目录下运行`inference.py`并添加`-mix`选项,结果会以markdown的格式进行输出。
+4. 运行以下命令开始推理:

    ```bash
-   python inference.py -img "/path/to/image.{jpg,png}" -mix
+   texteller inference "/path/to/image.{jpg,png}"
    ```

-TexTeller默认使用轻量的[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)模型来识别中英文,可以尝试使用更大的模型来获取更好的中英文识别效果:
-
-| 权重 | 描述 | 尺寸 |
-|-------------|-------------------|------|
-| [ch_PP-OCRv4_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx?download=true) | **默认的检测模型**,支持中英文检测 | 4.70M |
-| [ch_PP-OCRv4_server_det.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_det.onnx?download=true) | 高精度模型,支持中英文检测 | 115M |
-| [ch_PP-OCRv4_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_rec.onnx?download=true) | **默认的识别模型**,支持中英文识别 | 10.80M |
-| [ch_PP-OCRv4_server_rec.onnx](https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx?download=true) | 高精度模型,支持中英文识别 | 90.60M |
-
-把识别/检测模型的权重放在`src/models/third_party/paddleocr/checkpoints/`
-下的`det/`或`rec/`目录中,然后重命名为`default_model.onnx`。
-
-> [!NOTE]
-> 段落识别只能识别文档内容,无法还原文档的结构。
-
-## ❓ 常见问题:无法连接到Hugging Face
-
-默认情况下,会在Hugging Face中下载模型权重,**如果你的远端服务器无法连接到Hugging Face**,你可以通过以下命令进行加载:
-
-1. 安装huggingface hub包
-
-   ```bash
-   pip install -U "huggingface_hub[cli]"
-   ```
-
-2. 在能连接Hugging Face的机器上下载模型权重:
-
-   ```bash
-   huggingface-cli download \
-     OleehyO/TexTeller \
-     --repo-type model \
-     --local-dir "your/dir/path" \
-     --local-dir-use-symlinks False
-   ```
-
-3. 把包含权重的目录上传远端服务器,然后把 `src/models/ocr_model/model/TexTeller.py`中的 `REPO_NAME = 'OleehyO/TexTeller'`修改为 `REPO_NAME = 'your/dir/path'`
-
-<!-- 如果你还想在训练模型时开启evaluate,你需要提前下载metric脚本并上传远端服务器:
-
-1. 在能连接Hugging Face的机器上下载metric脚本
-
-   ```bash
-   huggingface-cli download \
-     evaluate-metric/google_bleu \
-     --repo-type space \
-     --local-dir "your/dir/path" \
-     --local-dir-use-symlinks False
-   ```
-
-2. 把这个目录上传远端服务器,并在 `TexTeller/src/models/ocr_model/utils/metrics.py`中把 `evaluate.load('google_bleu')`改为 `evaluate.load('your/dir/path/google_bleu.py')` -->
+> 更多参数请查看 `texteller inference --help`

 ## 🌐 网页演示

-进入 `src/` 目录,运行以下命令
+命令行运行:

 ```bash
-./start_web.sh
+texteller web
 ```

-在浏览器里输入 `http://localhost:8501`就可以看到web demo
+在浏览器中输入 `http://localhost:8501` 查看网页演示。

 > [!NOTE]
-> 1. 对于Windows用户, 请运行 `start_web.bat`文件。
-> 2. 使用onnxruntime + gpu 推理时,需要安装onnxruntime-gpu
+> 段落识别无法还原文档结构,仅能识别其内容。

+## 🖥️ 服务部署
+
+我们使用 [ray serve](https://github.com/ray-project/ray) 为 TexTeller 提供 API 服务。启动服务:
+
+```bash
+texteller launch
+```
+
+| 参数 | 说明 |
+| --------- | -------- |
+| `-ckpt` | 权重文件路径,*默认为 TexTeller 预训练权重* |
+| `-tknz` | 分词器路径,*默认为 TexTeller 分词器* |
+| `-p` | 服务端口,*默认 8000* |
+| `--num-replicas` | 服务副本数,*默认 1*。可使用更多副本来提升吞吐量 |
+| `--ncpu-per-replica` | 单个副本使用的CPU核数,*默认 1* |
+| `--ngpu-per-replica` | 单个副本使用的GPU数,*默认 1*。可设置为0~1之间的值来在单卡上运行多个服务副本共享GPU,提升GPU利用率(注意,若--num_replicas为2,--ngpu_per_replica为0.7,则需有2块可用GPU) |
+| `--num-beams` | beam search的束宽,*默认 1* |
+| `--use-onnx` | 使用Onnx Runtime进行推理,*默认关闭* |
+
+向服务发送请求:
+
+```python
+# client_demo.py
+
+import requests
+
+server_url = "http://127.0.0.1:8000/predict"
+
+img_path = "/path/to/your/image"
+with open(img_path, 'rb') as img:
+    files = {'img': img}
+    response = requests.post(server_url, files=files)
+
+print(response.text)
+```
+
+## 🐍 Python接口
+
+我们为公式OCR场景提供了多个易用的Python API接口,请参考[接口文档](https://oleehyo.github.io/TexTeller/)了解对应的API接口及使用方法。

 ## 🔍 公式检测

-TexTeller的公式检测模型在3415张中文教材数据(130+版式)和8272张[IBEM数据集](https://zenodo.org/records/4757865)上训练得到,支持对整张图片进行**公式检测**。
+TexTeller的公式检测模型在3415张中文资料图像和8272张[IBEM数据集](https://zenodo.org/records/4757865)图像上训练。

 <div align="center">
-<img src="det_rec.png" width=250>
+<img src="./det_rec.png" width=250>
 </div>

-1. 下载公式检测模型的权重到`src/models/det_model/model/`目录 [[链接](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true)]
+我们在Python接口中提供了公式检测接口,详见[接口文档](https://oleehyo.github.io/TexTeller/)。

-2. `src/`目录下运行以下命令,结果保存在`src/subimages/`
+## 🏋️♂️ 模型训练
+
+请按以下步骤配置训练环境:
+
+1. 安装训练依赖:

    ```bash
-   python infer_det.py
+   uv pip install texteller[train]
    ```

-<details>
-<summary>更进一步:公式批识别</summary>
-
-在进行**公式检测后**,`src/`目录下运行以下命令
-
-```shell
-python rec_infer_from_crop_imgs.py
-```
-
-会基于上一步公式检测的结果,对裁剪出的所有公式进行批量识别,将识别结果在 `src/results/`中保存为txt文件。
-</details>
-
-## 📡 API调用
-
-我们使用[ray serve](https://github.com/ray-project/ray)来对外提供一个TexTeller的API接口,通过使用这个接口,你可以把TexTeller整合到自己的项目里。要想启动server,你需要先进入 `src/`目录然后运行以下命令:
-
-```bash
-python server.py
-```
-
-| 参数 | 描述 |
-| --- | --- |
-| `-ckpt` | 权重文件的路径,*默认为TexTeller的预训练权重*。|
-| `-tknz` | 分词器的路径,*默认为TexTeller的分词器*。|
-| `-port` | 服务器的服务端口,*默认是8000*。|
-| `--inference-mode` | 使用"cuda"或"mps"推理,*默认为"cpu"*。|
-| `--num_beams` | beam search的beam数量,*默认是1*。|
-| `--num_replicas` | 在服务器上运行的服务副本数量,*默认1个副本*。你可以使用更多的副本来获取更大的吞吐量。|
-| `--ncpu_per_replica` | 每个服务副本所用的CPU核心数,*默认为1*。|
-| `--ngpu_per_replica` | 每个服务副本所用的GPU数量,*默认为1*。你可以把这个值设置成 0~1之间的数,这样会在一个GPU上运行多个服务副本来共享GPU,从而提高GPU的利用率。(注意,如果 --num_replicas 2, --ngpu_per_replica 0.7, 那么就必须要有2个GPU可用) |
-| `-onnx` | 使用Onnx Runtime进行推理,*默认不使用*。|
-
-> [!NOTE]
-> 一个客户端demo可以在 `TexTeller/client/demo.py`找到,你可以参考 `demo.py`来给server发送请求
-
-## 🏋️♂️ 训练
-
-### 数据集
-
-我们在 `src/models/ocr_model/train/dataset/`目录中提供了一个数据集的例子,你可以把自己的图片放在 `images`目录然后在 `formulas.jsonl`中为每张图片标注对应的公式。
-
-准备好数据集后,你需要在 `**/train/dataset/loader.py`中把 **`DIR_URL`变量改成你自己数据集的路径**
-
-### 重新训练分词器
-
-如果你使用了不一样的数据集,你可能需要重新训练tokenizer来得到一个不一样的词典。配置好数据集后,可以通过以下命令来训练自己的tokenizer:
-
-1. 在`src/models/tokenizer/train.py`中,修改`new_tokenizer.save_pretrained('./your_dir_name')`为你自定义的输出目录
-
-> 注意:如果要用一个不一样大小的词典(默认1.5W个token),你需要在`src/models/globals.py`中修改`VOCAB_SIZE`变量
-
-2. **在`src/`目录下**运行以下命令:
+2. 克隆仓库:

    ```bash
-   python -m models.tokenizer.train
+   git clone https://github.com/OleehyO/TexTeller.git
    ```

-### 训练模型
+### 数据集准备

-1. 修改`src/train_config.yaml`中的`num_processes`为训练用的显卡数(默认为1)
+我们在`examples/train_texteller/dataset/train`目录中提供了示例数据集,您可按照示例数据集的格式放置自己的训练数据。

-2. 在`src/`目录下运行以下命令:
+### 开始训练
+
+在`examples/train_texteller/`目录下运行:

    ```bash
-   accelerate launch --config_file ./train_config.yaml -m models.ocr_model.train.train
+   accelerate launch train.py
    ```

-你可以在`src/models/ocr_model/train/train.py`中设置自己的tokenizer和checkpoint路径(请参考`train.py`)。如果你使用了与TexTeller一样的架构和相同的词典,你还可以用自己的数据集来微调TexTeller的默认权重。
+训练参数可通过[`train_config.yaml`](./examples/train_texteller/train_config.yaml)调整。

-> [!NOTE]
-> 我们的训练脚本使用了[Hugging Face Transformers](https://github.com/huggingface/transformers)库, 所以你可以参考他们提供的[文档](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments)来获取更多训练参数的细节以及配置。
-
-## 📅 计划列表
+## 📅 计划

-- [X] ~~使用更大的数据集来训练模型~~
-- [X] ~~扫描图片识别~~
+- [X] ~~使用更大规模数据集训练模型~~
+- [X] ~~扫描件识别支持~~
 - [X] ~~中英文场景支持~~
-- [X] ~~手写公式识别~~
+- [X] ~~手写公式支持~~
 - [ ] PDF文档识别
 - [ ] 推理加速

-## ⭐️ 观星曲线
+## ⭐️ 项目星标

-[](https://starchart.cc/OleehyO/TexTeller)
+[](https://starchart.cc/OleehyO/TexTeller)

 ## 👥 贡献者
````
assets/logo.svg (new file, 14 lines, 377 B)

```svg
<svg xmlns="http://www.w3.org/2000/svg" width="430" height="80" viewBox="0 0 430 80">
  <text
    x="50%"
    y="50%"
    font-family="monaco"
    font-size="55"
    text-anchor="middle"
    dominant-baseline="middle">
    <tspan fill="#F37726">{</tspan><tspan fill="#616161">Tex</tspan><tspan fill="#F37726">}</tspan><tspan fill="#616161">Teller</tspan>
  </text>
</svg>
```
docs/Makefile (new file, 20 lines)

```make
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
```
docs/make.bat (new file, 35 lines)

```bat
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
```
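On Windows, the equivalent of `make html` goes through the batch file above:

```bat
cd docs
make.bat html
```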
docs/requirements.txt (new file, empty)
docs/source/api.rst (new file, 39 lines)

```rst
API Reference
=============

This section provides detailed API documentation for the TexTeller package. TexTeller is a tool for detecting and recognizing LaTeX formulas in images and converting mixed text and formula images to markdown.

.. contents:: Table of Contents
   :local:
   :depth: 2


Image to LaTeX Conversion
-------------------------

.. autofunction:: texteller.api.img2latex


Paragraph to Markdown Conversion
--------------------------------

.. autofunction:: texteller.api.paragraph2md


LaTeX Detection
---------------

.. autofunction:: texteller.api.detection.latex_detect


Model Loading
-------------

.. autofunction:: texteller.api.load_model
.. autofunction:: texteller.api.load_tokenizer
.. autofunction:: texteller.api.load_latexdet_model
.. autofunction:: texteller.api.load_textdet_model
.. autofunction:: texteller.api.load_textrec_model


KaTeX Conversion
----------------

.. autofunction:: texteller.api.to_katex
```
docs/source/conf.py (new file, 75 lines)

```python
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute.
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))

# -- Project information -----------------------------------------------------

project = "TexTeller"
copyright = "2025, TexTeller Team"
author = "TexTeller Team"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "myst_parser",
    "sphinx.ext.duration",
    "sphinx.ext.intersphinx",
    "sphinx.ext.autosectionlabel",
    "sphinx.ext.autodoc",
    "sphinx.ext.viewcode",
    "sphinx.ext.napoleon",
    "sphinx.ext.autosummary",
    "sphinx_copybutton",
    # 'sphinx.ext.linkcode',
    # 'sphinxarg.ext',
    "sphinx_design",
    "nbsphinx",
]

templates_path = ["_templates"]
exclude_patterns = []

# Autodoc settings
autodoc_member_order = "bysource"
add_module_names = False
autoclass_content = "both"
autodoc_default_options = {
    "members": True,
    "member-order": "bysource",
    "undoc-members": True,
    "show-inheritance": True,
    "imported-members": True,
}

# Intersphinx settings
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable", None),
    "torch": ("https://pytorch.org/docs/stable", None),
    "transformers": ("https://huggingface.co/docs/transformers/main/en", None),
}

html_theme = "sphinx_book_theme"

html_theme_options = {
    "repository_url": "https://github.com/OleehyO/TexTeller",
    "use_repository_button": True,
    "use_issues_button": True,
    "use_edit_page_button": True,
    "use_download_button": True,
}

html_logo = "../../assets/logo.svg"
```
docs/source/index.rst (new file, 76 lines)

```rst
.. TexTeller documentation master file, created by
   sphinx-quickstart on Sun Apr 20 13:05:53 2025.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

TexTeller Documentation
===========================================

Features
--------

- **Image to LaTeX Conversion**: Convert images containing LaTeX formulas to LaTeX code
- **LaTeX Detection**: Detect and locate LaTeX formulas in mixed text/formula images
- **Paragraph to Markdown**: Convert mixed text and formula images to Markdown format

Installation
------------

You can install TexTeller using pip:

.. code-block:: bash

   pip install texteller

Quick Start
-----------

Converting an image to LaTeX:

.. code-block:: python

   from texteller import load_model, load_tokenizer, img2latex

   # Load models
   model = load_model(use_onnx=False)
   tokenizer = load_tokenizer()

   # Convert image to LaTeX
   latex = img2latex(model, tokenizer, ["path/to/image.png"])[0]

Processing a mixed text/formula image:

.. code-block:: python

   from texteller import (
       load_model, load_tokenizer, load_latexdet_model,
       load_textdet_model, load_textrec_model, paragraph2md
   )

   # Load all required models
   latex_model = load_model()
   tokenizer = load_tokenizer()
   latex_detector = load_latexdet_model()
   text_detector = load_textdet_model()
   text_recognizer = load_textrec_model()

   # Convert to markdown
   markdown = paragraph2md(
       "path/to/mixed_image.png",
       latex_detector,
       text_detector,
       text_recognizer,
       latex_model,
       tokenizer
   )

API Documentation
-----------------

For detailed API documentation, please see :doc:`./api`.

.. toctree::
   :maxdepth: 2
   :hidden:

   api
```
10
examples/client_demo.py
Normal file
@@ -0,0 +1,10 @@
import requests

server_url = "http://127.0.0.1:8000/predict"

img_path = "/path/to/your/image"
with open(img_path, "rb") as img:
    files = {"img": img}
    response = requests.post(server_url, files=files)

print(response.text)
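For a slightly more defensive version of the same request, you could check the HTTP status before using the body (a small sketch building on the demo above; the timeout value is an assumption):

    import requests

    server_url = "http://127.0.0.1:8000/predict"
    img_path = "/path/to/your/image"

    with open(img_path, "rb") as img:
        response = requests.post(server_url, files={"img": img}, timeout=60)

    response.raise_for_status()  # fail loudly on 4xx/5xx instead of printing an error page
    print(response.text)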
(35 image assets relocated in this commit range; width, height, and file size are identical before and after for every file.)
35
examples/train_texteller/dataset/train/metadata.jsonl
Normal file
@@ -0,0 +1,35 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
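Each line of metadata.jsonl pairs an image file with its ground-truth formula. A minimal sketch for sanity-checking the file before training (field names taken from the entries above):

    import json

    with open("examples/train_texteller/dataset/train/metadata.jsonl") as f:
        for i, line in enumerate(f, start=1):
            record = json.loads(line)
            assert {"file_name", "latex_formula"} <= record.keys(), f"bad record on line {i}"
    print(f"{i} records OK")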
71
examples/train_texteller/train.py
Normal file
@@ -0,0 +1,71 @@
from functools import partial

import yaml
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
)

from texteller import load_model, load_tokenizer
from texteller.constants import MIN_HEIGHT, MIN_WIDTH

from examples.train_texteller.utils import (
    collate_fn,
    filter_fn,
    img_inf_transform,
    img_train_transform,
    tokenize_fn,
)


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**training_config)
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


if __name__ == "__main__":
    dataset = load_dataset("imagefolder", data_dir="dataset")["train"]
    dataset = dataset.filter(
        lambda x: x["image"].height > MIN_HEIGHT and x["image"].width > MIN_WIDTH
    )
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = load_tokenizer()
    # If you want to use your own tokenizer, change the path below accordingly
    # tokenizer = load_tokenizer("/path/to/your/tokenizer")
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(
        map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
    )

    # Split dataset into train and eval with a 9:1 ratio
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
    model = load_model()

    # To train from a pre-trained model, point load_model at your checkpoint
    # model = load_model("/path/to/your/model_checkpoint")

    enable_train = True
    training_config = yaml.safe_load(open("train_config.yaml"))
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
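Note that trainer.train(resume_from_checkpoint=None) always starts from step zero. To resume an interrupted run, you could point it at a checkpoint directory saved under output_dir (a sketch of the one-line change inside train(); the path is hypothetical, based on the train_result/save_steps settings in train_config.yaml below):

    # Resume from a checkpoint written by save_strategy="steps" / save_steps=500.
    # "train_result/checkpoint-500" is an assumed path, not shipped with the repo.
    trainer.train(resume_from_checkpoint="train_result/checkpoint-500")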
32
examples/train_texteller/train_config.yaml
Normal file
@@ -0,0 +1,32 @@
# For more information, please refer to the official documentation: https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

seed: 42  # Random seed for reproducibility
use_cpu: false  # Whether to run on CPU (CPU runs are easier to debug when first testing the code)
learning_rate: 5.0e-5  # Learning rate
num_train_epochs: 10  # Total number of training epochs
per_device_train_batch_size: 4  # Batch size per GPU for training
per_device_eval_batch_size: 8  # Batch size per GPU for evaluation
output_dir: "train_result"  # Output directory
overwrite_output_dir: false  # If the output directory exists, do not delete its contents
report_to:
  - tensorboard  # Report logs to TensorBoard
save_strategy: "steps"  # Strategy for saving checkpoints
save_steps: 500  # Interval (in steps) between checkpoints; can be an int, or a float in (0, 1) interpreted as a ratio of total training steps (e.g. 1.0 / 2000)
save_total_limit: 5  # Maximum number of checkpoints to keep; the oldest are deleted once this is exceeded
logging_strategy: "steps"  # Log every fixed number of steps
logging_steps: 500  # Number of steps between logs
logging_nan_inf_filter: false  # Keep log entries even when loss is NaN or inf
optim: "adamw_torch"  # Optimizer
lr_scheduler_type: "cosine"  # Learning rate scheduler
warmup_ratio: 0.1  # Fraction of total training steps used for warmup (e.g. for 1000 steps, lr ramps from 0 to the set lr over the first 100)
max_grad_norm: 1.0  # Gradient clipping: the gradient norm is capped at 1.0 (default 1.0)
fp16: false  # Whether to train in 16-bit floating point (generally not recommended, as the loss can easily explode)
bf16: false  # Whether to train in bfloat16 (recommended if the hardware supports it)
gradient_accumulation_steps: 1  # Gradient accumulation steps; raise this to emulate a large batch size when the per-device batch size cannot be large
jit_mode_eval: false  # Whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static or it will throw errors)
torch_compile: false  # Whether to compile the model with torch.compile (for better training and inference performance)
dataloader_pin_memory: true  # Can speed up data transfer between CPU and GPU
dataloader_num_workers: 1  # Data-loading worker processes; usually set to 4x the number of GPUs (1 means effectively no multiprocessing)
evaluation_strategy: "steps"  # Evaluation strategy, either "steps" or "epoch"
eval_steps: 500  # Interval (in steps) between evaluations when evaluation_strategy="steps"
remove_unused_columns: false  # Don't change this unless you really know what you are doing
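Since train.py feeds this file directly into TrainingArguments(**training_config), you can sanity-check the config without starting a run (a minimal sketch):

    import yaml
    from transformers import TrainingArguments

    with open("train_config.yaml") as f:
        training_config = yaml.safe_load(f)

    # Constructing TrainingArguments validates field names and value types.
    args = TrainingArguments(**training_config)
    print(args.learning_rate, args.num_train_epochs)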
17
examples/train_texteller/utils/__init__.py
Normal file
@@ -0,0 +1,17 @@
from .functional import (
    collate_fn,
    filter_fn,
    tokenize_fn,
)
from .transforms import (
    img_train_transform,
    img_inf_transform,
)

__all__ = [
    "collate_fn",
    "filter_fn",
    "tokenize_fn",
    "img_train_transform",
    "img_inf_transform",
]
@@ -1,9 +1,36 @@
from augraphy import *
"""
Custom augraphy pipeline for training

This file implements a custom augraphy data augmentation pipeline. We found that using augraphy's
default pipeline can cause significant degradation to formula images, potentially losing semantic
information. Therefore, we carefully selected several common augmentation effects,
adjusting their parameters and combination methods to preserve the original semantic information
of the images as much as possible.
"""

from augraphy import (
    InkColorSwap,
    LinesDegradation,
    OneOf,
    Dithering,
    InkBleed,
    InkShifter,
    NoiseTexturize,
    BrightnessTexturize,
    ColorShift,
    DirtyDrum,
    LightingGradient,
    Brightness,
    Gamma,
    SubtleNoise,
    Jpeg,
    AugraphyPipeline,
)
import random

def ocr_augmentation_pipeline():
    pre_phase = [
    ]

def get_custom_augraphy():
    pre_phase = []

    ink_phase = [
        InkColorSwap(
@@ -15,8 +42,7 @@ def ocr_augmentation_pipeline():
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            # p=0.2
            p=0.4
            p=0.2,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
@@ -28,10 +54,8 @@ def ocr_augmentation_pipeline():
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            # p=0.2
            p=0.4
            p=0.2,
        ),

        # ============================
        OneOf(
            [
@@ -45,11 +69,9 @@ def ocr_augmentation_pipeline():
                    severity=(0.4, 0.6),
                ),
            ],
            # p=0.2
            p=0.4
            p=0.2,
        ),
        # ============================

        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
@@ -58,42 +80,32 @@ def ocr_augmentation_pipeline():
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            # p=0.2
            p=0.4
            p=0.2,
        ),
        # ============================

    ]

    paper_phase = [
        NoiseTexturize(  # tested
        NoiseTexturize(
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            # p=0.2
            p=0.4
            p=0.2,
        ),
        BrightnessTexturize(  # tested
            texturize_range=(0.9, 0.99),
            deviation=0.03,
            # p=0.2
            p=0.4
        )
        BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
    ]

    post_phase = [
        ColorShift(  # tested
        ColorShift(
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            # p=0.2
            p=0.4
            p=0.2,
        ),

        DirtyDrum(  # tested
        DirtyDrum(
            line_width_range=(1, 6),
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
@@ -101,10 +113,8 @@ def ocr_augmentation_pipeline():
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            # p=0.2
            p=0.4
            p=0.2,
        ),

        # =====================================
        OneOf(
            [
@@ -126,11 +136,9 @@ def ocr_augmentation_pipeline():
                    gamma_range=(0.9, 1.1),
                ),
            ],
            # p=0.2
            p=0.4
            p=0.2,
        ),
        # =====================================

        # =====================================
        OneOf(
            [
@@ -141,8 +149,7 @@ def ocr_augmentation_pipeline():
                    quality_range=(70, 95),
                ),
            ],
            # p=0.2
            p=0.4
            p=0.2,
        ),
        # =====================================
    ]
@@ -152,7 +159,7 @@ def ocr_augmentation_pipeline():
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False
        log=False,
    )

    return pipeline
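A minimal sketch of applying the custom pipeline to a single image (augraphy pipelines are callable, as ocr_aug in transforms.py relies on; the dummy image and the flat import are assumptions for a standalone run):

    import numpy as np
    from augraphy_pipe import get_custom_augraphy  # module name as in this diff

    pipeline = get_custom_augraphy()

    # A dummy white 64x256 RGB uint8 image standing in for a rendered formula crop.
    img = np.full((64, 256, 3), 255, dtype=np.uint8)
    augmented = pipeline(img)
    print(augmented.shape, augmented.dtype)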
47
examples/train_texteller/utils/functional.py
Normal file
@@ -0,0 +1,47 @@
from typing import Any

import torch
from transformers import DataCollatorForLanguageModeling

from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH


def _left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, "x should be 2-dimensional"
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
    tokenized_formula["pixel_values"] = samples["image"]
    return tokenized_formula


def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    pixel_values = [dic.pop("pixel_values") for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch["pixel_values"] = pixel_values
    batch["decoder_input_ids"] = batch.pop("input_ids")
    batch["decoder_attention_mask"] = batch.pop("attention_mask")

    batch["labels"] = _left_move(batch["labels"], -100)

    # Convert the list of images to a tensor of shape (B, C, H, W)
    batch["pixel_values"] = torch.stack(batch["pixel_values"], dim=0)
    return batch


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample["image"].height > MIN_HEIGHT
        and sample["image"].width > MIN_WIDTH
        and len(tokenizer(sample["latex_formula"])["input_ids"]) < MAX_TOKEN_SIZE - 10
    )
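To see what _left_move does to the label tensor, here is a tiny worked example: the collator shifts labels one position to the left so that position t is supervised by token t+1, with -100 masking the final slot (ignored by the loss):

    import torch

    labels = torch.tensor([[1, 2, 3]])
    shifted = torch.ones_like(labels)
    shifted[:, :-1] = labels[:, 1:]
    shifted[:, -1] = -100
    print(shifted)  # [[2, 3, -100]]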
@@ -4,38 +4,19 @@ import numpy as np
import cv2

from torchvision.transforms import v2
from typing import List, Union
from typing import Any
from PIL import Image
from collections import Counter

from ...globals import (
from texteller.constants import (
    IMG_CHANNELS,
    FIXED_IMG_SIZE,
    IMAGE_MEAN, IMAGE_STD,
    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
    MAX_RESIZE_RATIO,
    MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline
from texteller.utils import transform as inference_transform
from .augraphy_pipe import get_custom_augraphy

# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()

general_transform_pipeline = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
    v2.Grayscale(),

    v2.Resize(
        size=FIXED_IMG_SIZE - 1,
        interpolation=v2.InterpolationMode.BICUBIC,
        max_size=FIXED_IMG_SIZE,
        antialias=True
    ),

    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),

    # v2.ToPILImage()
])
augraphy_pipeline = get_custom_augraphy()


def trim_white_border(image: np.ndarray):
@@ -45,8 +26,7 @@ def trim_white_border(image: np.ndarray):
    if image.dtype != np.uint8:
        raise ValueError(f"Image should stored in uint8")

    corners = [tuple(image[0, 0]), tuple(image[0, -1]),
               tuple(image[-1, 0]), tuple(image[-1, -1])]
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

@@ -61,7 +41,7 @@ def trim_white_border(image: np.ndarray):

    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y:y+h, x:x+w]
    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image

@@ -70,44 +50,42 @@ def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
    randi = [random.randint(0, max_size) for _ in range(4)]
    pad_height_size = randi[1] + randi[3]
    pad_width_size = randi[0] + randi[2]
    if (pad_height_size + image.shape[0] < 30):
    if pad_height_size + image.shape[0] < 30:
        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
        randi[1] += compensate_height
        randi[3] += compensate_height
    if (pad_width_size + image.shape[1] < 30):
    if pad_width_size + image.shape[1] < 30:
        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
        randi[0] += compensate_width
        randi[2] += compensate_width
    return v2.functional.pad(
        torch.from_numpy(image).permute(2, 0, 1),
        padding=randi,
        padding_mode='constant',
        fill=(255, 255, 255)
        padding_mode="constant",
        fill=(255, 255, 255),
    )


def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
    images = [
        v2.functional.pad(
            img,
            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(
    images: List[np.ndarray],
    minr: float,
    maxr: float
) -> List[np.ndarray]:
def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4)  # anti-aliasing
        # Anti-aliasing
        cv2.resize(
            img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
        )
        for img, r in zip(images, ratios)
    ]

@@ -133,7 +111,9 @@ def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    rotation_mat[1, 2] += (new_height / 2) - image_center[1]

    # Rotate the image with the specified border color (white in this case)
    rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
    rotated_image = cv2.warpAffine(
        image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
    )

    return rotated_image

@@ -142,14 +122,14 @@ def ocr_aug(image: np.ndarray) -> np.ndarray:
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
    image = train_pipeline(image)
    image = augraphy_pipeline(image)
    return image


def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"
def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    images = [np.array(img.convert('RGB')) for img in images]
    images = [np.array(img.convert("RGB")) for img in images]
    # random resize first
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    images = [trim_white_border(image) for image in images]
@@ -158,19 +138,17 @@ def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
    images = [ocr_aug(image) for image in images]

    # general transform pipeline
    images = [general_transform_pipeline(image) for image in images]
    # padding to fixed size
    images = padding(images, FIXED_IMG_SIZE)
    images = inference_transform(images)
    return images


def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"
    images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
    images = [trim_white_border(image) for image in images]
    # general transform pipeline
    images = [general_transform_pipeline(image) for image in images]  # imgs: List[PIL.Image.Image]
    # padding to fixed size
    images = padding(images, FIXED_IMG_SIZE)
def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    processed_img = train_transform(samples["pixel_values"])
    samples["pixel_values"] = processed_img
    return samples

    return images

def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    processed_img = inference_transform(samples["pixel_values"])
    samples["pixel_values"] = processed_img
    return samples
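These two wrappers are what train.py passes to Dataset.with_transform. A sketch of the sample shape they expect (the dict layout follows tokenize_fn above, which stores PIL images under "pixel_values"; this assumes the texteller package is importable):

    from PIL import Image

    # Hypothetical sample dict in the shape with_transform provides.
    samples = {"pixel_values": [Image.new("RGB", (128, 48), "white")]}

    out = img_inf_transform(samples)
    print(type(out["pixel_values"][0]))  # a torch.Tensor after the inference transform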
87
pyproject.toml
Normal file
@@ -0,0 +1,87 @@
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "texteller"
authors = [
    { name="OleehyO", email="leehy0357@gmail.com" }
]
dynamic = ["version"]
description = "Texteller is a tool for converting rendered image to original latex code"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
dependencies = [
    "click>=8.1.8",
    "colorama>=0.4.6",
    "opencv-python-headless>=4.11.0.86",
    "pyclipper>=1.3.0.post6",
    "shapely>=2.1.0",
    "streamlit>=1.44.1",
    "streamlit-paste-button>=0.1.2",
    "torch>=2.6.0",
    "torchvision>=0.21.0",
    "transformers==4.47",
    "wget>=3.2",
    "optimum[onnxruntime]>=1.24.0",
    "python-multipart>=0.0.20",
    "ray[serve]>=2.44.1",
]

[tool.hatch.version]
source = "vcs"

[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py310"
line-length = 100

[tool.ruff.format]
line-ending = "lf"
quote-style = "double"

[tool.ruff.lint]
select = ["E", "W"]
ignore = [
    "E999",
    "EXE001",
    "UP009",
    "F401",
    "TID252",
    "F403",
    "F841",
    "E501",
    "W291",
    "W293",
    "E741",
    "E712",
]

[tool.hatch.build.targets.wheel]
packages = ["texteller"]

[project.scripts]
texteller = "texteller.cli:cli"

[project.optional-dependencies]
onnxruntime-gpu = [
    "onnxruntime-gpu>=1.21.0",
]
test = [
    "pytest>=8.3.5",
]
train = [
    "accelerate>=1.6.0",
    "augraphy>=8.2.6",
    "datasets>=3.5.0",
    "tensorboardx>=2.6.2.2",
]
docs = [
    "myst-parser>=4.0.1",
    "nbsphinx>=0.9.7",
    "sphinx>=8.1.3",
    "sphinx-book-theme>=1.1.4",
    "sphinx-copybutton>=0.5.2",
    "sphinx-design>=0.6.1",
]
@@ -1,19 +0,0 @@
transformers==4.45.2
sentence-transformers==3.1.1
datasets
evaluate
opencv-python
ray[serve]
accelerate
tensorboardX
nltk
python-multipart

augraphy

streamlit==1.30
streamlit-paste-button

shapely
pyclipper
onnxruntime-gpu
42
setup.py
@@ -1,42 +0,0 @@
from setuptools import setup, find_packages


# Define the base dependencies
install_requires = [
    "torch",
    "torchvision",
    "transformers",
    "datasets",
    "evaluate",
    "opencv-python",
    "ray[serve]",
    "accelerate",
    "tensorboardX",
    "nltk",
    "python-multipart",
    "augraphy",
    "streamlit==1.30",
    "streamlit-paste-button",
    "shapely",
    "pyclipper",

    "optimum[exporters]",
]

setup(
    name="texteller",
    version="0.1.2",
    author="OleehyO",
    author_email="1258009915@qq.com",
    description="A meta-package for installing dependencies",
    long_description=open('README.md').read(),
    long_description_content_type="text/markdown",
    url="https://github.com/OleehyO/TexTeller",
    packages=find_packages(),
    install_requires=install_requires,
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.10',
)
@@ -1,12 +0,0 @@
import requests

rec_server_url = "http://127.0.0.1:8000/frec"
det_server_url = "http://127.0.0.1:8000/fdet"

img_path = "/your/image/path/"
with open(img_path, 'rb') as img:
    files = {'img': img}
    response = requests.post(rec_server_url, files=files)
    # response = requests.post(det_server_url, files=files)

print(response.text)
@@ -1,85 +0,0 @@
import os
import argparse
import glob
import subprocess

import onnxruntime
from pathlib import Path

from models.det_model.inference import PredictConfig, predict_image


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
                    default="./models/det_model/model/infer_cfg.yml")
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
                    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True)


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


def download_file(url, filename):
    print(f"Downloading {filename}...")
    subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
    print("Download complete.")


if __name__ == '__main__':
    cur_path = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    FLAGS = parser.parse_args()

    if not os.path.exists(FLAGS.infer_cfg):
        infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
        download_file(infer_cfg_url, FLAGS.infer_cfg)

    if not os.path.exists(FLAGS.onnx_file):
        onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
        download_file(onnx_file_url, FLAGS.onnx_file)

    # load image list
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)

    if FLAGS.use_gpu:
        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider'])
    else:
        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
    # load infer config
    infer_config = PredictConfig(FLAGS.infer_cfg)

    predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)

    os.chdir(cur_path)
@@ -1,85 +0,0 @@
import os
import argparse
import cv2 as cv

from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility

from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference

from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig


if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps'
    )
    parser.add_argument(
        '--num-beam',
        type=int,
        default=1,
        help='number of beams for beam-search decoding'
    )
    parser.add_argument(
        '-mix',
        action='store_true',
        help='use mix mode'
    )

    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = args.img
    img = cv.imread(img_path)
    print('Inference...')
    if not args.mix:
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

        use_gpu = args.inference_mode == 'cuda'
        SIZE_LIMIT = 20 * 1024 * 1024
        det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
        rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
        # CPU inference of the detection model is faster than GPU inference (in onnxruntime)
        det_use_gpu = False
        rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)

        paddleocr_args = utility.parse_args()
        paddleocr_args.use_onnx = True
        paddleocr_args.det_model_dir = det_model_dir
        paddleocr_args.rec_model_dir = rec_model_dir

        paddleocr_args.use_gpu = det_use_gpu
        detector = predict_det.TextDetector(paddleocr_args)
        paddleocr_args.use_gpu = rec_use_gpu
        recognizer = predict_rec.TextRecognizer(paddleocr_args)

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
        res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
        print(res)
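The mix mode of this deleted script corresponds to the new public API shown in docs/source/index.rst above; a sketch of the equivalent call with the new package:

    from texteller import (
        load_model, load_tokenizer, load_latexdet_model,
        load_textdet_model, load_textrec_model, paragraph2md,
    )

    # Detection, text OCR, and formula recognition models, as in the old mix mode.
    md = paragraph2md(
        "path/to/mixed_image.png",
        load_latexdet_model(),
        load_textdet_model(),
        load_textrec_model(),
        load_model(),
        load_tokenizer(),
    )
    print(md)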
@@ -1,195 +0,0 @@
import os
import time
import yaml
import numpy as np
import cv2

from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox


# Global dictionary
SUPPORT_MODELS = {
    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
    'DETR'
}


class PredictConfig(object):
    """set config of preprocess, postprocess and visualize
    Args:
        infer_config (str): path of infer_cfg.yml
    """

    def __init__(self, infer_config):
        # parsing Yaml config for Preprocess
        with open(infer_config) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.label_list = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
        self.mask = yml_conf.get("mask", False)
        self.tracker = yml_conf.get("tracker", None)
        self.nms = yml_conf.get("NMS", None)
        self.fpn_stride = yml_conf.get("fpn_stride", None)

        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}

        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def draw_bbox(image, outputs, infer_config):
    for output in outputs:
        cls_id, score, xmin, ymin, xmax, ymax = output
        if score > infer_config.draw_threshold:
            label = infer_config.label_list[int(cls_id)]
            color = infer_config.colors[label]
            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
            cv2.putText(image, "{}: {:.2f}".format(label, score),
                        (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return image


def predict_image(imgsave_dir, infer_config, predictor, img_list):
    # load preprocess transforms
    transforms = Compose(infer_config.preprocess_infos)
    errImgList = []

    # Check and create subimg_save_dir if not exist
    subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
    os.makedirs(subimg_save_dir, exist_ok=True)

    first_image_skipped = False
    total_time = 0
    num_images = 0
    # predict image
    for img_path in tqdm(img_list):
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Could not read image {img_path}. Skipping...")
            errImgList.append(img_path)
            continue

        inputs = transforms(img_path)
        inputs_name = [var.name for var in predictor.get_inputs()]
        inputs = {k: inputs[k][None, ] for k in inputs_name}

        # Start timing
        start_time = time.time()

        outputs = predictor.run(output_names=None, input_feed=inputs)

        # Stop timing
        end_time = time.time()
        inference_time = end_time - start_time
        if not first_image_skipped:
            first_image_skipped = True
        else:
            total_time += inference_time
            num_images += 1
        print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds")

        print("ONNXRuntime predict: ")
        if infer_config.arch in ["HRNet"]:
            print(np.array(outputs[0]))
        else:
            bboxes = np.array(outputs[0])
            for bbox in bboxes:
                if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
                    print(f"{int(bbox[0])} {bbox[1]} "
                          f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")

        # Save the subimages (crop from the original image)
        subimg_counter = 1
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                label = infer_config.label_list[int(cls_id)]
                subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
                if len(subimg) == 0:
                    continue

                subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
                subimg_path = os.path.join(subimg_save_dir, subimg_filename)
                cv2.imwrite(subimg_path, subimg)
                subimg_counter += 1

        # Draw bounding boxes and save the image with bounding boxes
        img_with_mask = img.copy()
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1)  # mask the region with white

        img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)

        output_dir = imgsave_dir
        os.makedirs(output_dir, exist_ok=True)
        draw_box_dir = os.path.join(output_dir, 'draw_box')
        mask_white_dir = os.path.join(output_dir, 'mask_white')
        os.makedirs(draw_box_dir, exist_ok=True)
        os.makedirs(mask_white_dir, exist_ok=True)

        output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
        output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
        cv2.imwrite(output_file_mask, img_with_mask)
        cv2.imwrite(output_file_bbox, img_with_bbox)

    avg_time_per_image = total_time / num_images if num_images > 0 else 0
    print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
    print(f"Average time per image: {avg_time_per_image:.4f} seconds")
    print("ErrorImgs:")
    print(errImgList)


def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
    transforms = Compose(infer_config.preprocess_infos)
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
    inputs = {k: inputs[k][None, ] for k in inputs_name}

    outputs = predictor.run(output_names=None, input_feed=inputs)[0]
    res = []
    for output in outputs:
        cls_name = infer_config.label_list[int(output[0])]
        score = output[1]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        if score > infer_config.draw_threshold:
            res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))

    return res
@@ -1,27 +0,0 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 1600
  - 1600
  type: Resize
- mean:
  - 0.0
  - 0.0
  - 0.0
  norm_type: none
  std:
  - 1.0
  - 1.0
  - 1.0
  type: NormalizeImage
- type: Permute
label_list:
- isolated
- embedding
@@ -1,499 +0,0 @@
import numpy as np
import cv2
import copy


def decode_image(img_path):
    if isinstance(img_path, str):
        with open(img_path, 'rb') as f:
            im_read = f.read()
        data = np.frombuffer(im_read, dtype='uint8')
    else:
        assert isinstance(img_path, np.ndarray)
        data = img_path

    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    img_info = {
        "im_shape": np.array(
            im.shape[:2], dtype=np.float32),
        "scale_factor": np.array(
            [1., 1.], dtype=np.float32)
    }
    return im, img_info


class Resize(object):
    """resize image by target_size and max_size
    Args:
        target_size (int): the target size of image
        keep_ratio (bool): whether keep_ratio or not, default true
        interp (int): method of resize
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_x: the resize ratio of X
            im_scale_y: the resize ratio of Y
        """
        origin_shape = im.shape[:2]
        im_c = im.shape[2]
        if self.keep_ratio:
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x


class NormalizeImage(object):
    """normalize image
    Args:
        mean (list): im - mean
        std (list): im / std
        is_scale (bool): whether need im / 255
        norm_type (str): type in ['mean_std', 'none']
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            scale = 1.0 / 255.0
            im *= scale

        if self.norm_type == 'mean_std':
            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= mean
            im /= std
        return im, im_info


class Permute(object):
    """permute image
    Args:
        to_bgr (bool): whether convert RGB to BGR
        channel_first (bool): whether convert HWC to CHW
    """

    def __init__(self, ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.transpose((2, 0, 1)).copy()
        return im, im_info


class PadStride(object):
    """padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
    Args:
        stride (bool): model with FPN need image shape % stride == 0
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return im, im_info
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im, im_info


class LetterBoxResize(object):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image to a padded rectangular
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        height, width = self.target_size
        h, w = im.shape[:2]
        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)

        new_shape = [round(h * ratio), round(w * ratio)]
        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
        return im, im_info


class Pad(object):
    def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
        """
        Pad image to a specified size.
        Args:
            size (list[int]): image target size
            fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
        """
        super(Pad, self).__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def __call__(self, im, im_info):
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            im = im.astype(np.float32)
            return im, im_info

        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        im = canvas
        return im, im_info


def rotate_point(pt, angle_rad):
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle by radian

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]

    return rotated_pt


def _get_3rd_point(a, b):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point(x,y)
        b (np.ndarray): point(x,y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)

    return third_pt


def get_affine_transform(center,
                         input_size,
                         rot,
                         output_size,
                         shift=(0., 0.),
                         inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (np.ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0., src_w * -0.5], rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
||||
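
# Usage sketch (not part of the original file): a minimal example of how
# get_affine_transform is typically combined with cv2.warpAffine; the array
# values below are illustrative assumptions, not values from this repo.
#
#   center = np.array([320., 240.], dtype=np.float32)  # bbox center (x, y)
#   trans = get_affine_transform(center, 512, rot=0., output_size=[192, 256])
#   warped = cv2.warpAffine(img, trans, (192, 256), flags=cv2.INTER_LINEAR)
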

class WarpAffine(object):
    """Warp-affine the image."""

    def __init__(self,
                 keep_res=False,
                 pad=31,
                 input_h=512,
                 input_w=512,
                 scale=0.4,
                 shift=0.1):
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]

        if self.keep_res:
            # pad each side up to the next multiple of (pad + 1), e.g. pad=31 -> multiple of 32
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)

        else:
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2., h / 2.], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
        return inp, im_info


# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
    """This code is based on
    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Calculate the transformation matrix under the unbiasedness constraint.
    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
    Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        theta (float): Rotation angle in degrees.
        size_input (np.ndarray): Size of input image [w, h].
        size_dst (np.ndarray): Size of output image [w, h].
        size_target (np.ndarray): Size of ROI in input plane [w, h].

    Returns:
        matrix (np.ndarray): A matrix for transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
        np.sin(theta) + 0.5 * size_target[0])
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
        np.cos(theta) + 0.5 * size_target[1])
    return matrix


class TopDownEvalAffine(object):
    """Apply an affine transform to the image and coords.

    Args:
        trainsize (list): [w, h], the standard size used to train
        use_udp (bool): whether to use Unbiased Data Processing.
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self, trainsize, use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp

    def __call__(self, image, im_info):
        rot = 0
        imshape = im_info['im_shape'][::-1]
        center = im_info['center'] if 'center' in im_info else imshape / 2.
        scale = im_info['scale'] if 'scale' in im_info else imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot, center * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
        else:
            trans = get_affine_transform(center, scale, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)

        return image, im_info


class Compose:
    def __init__(self, transforms):
        self.transforms = []
        for op_info in transforms:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # instantiate the transform class named by 'type' with the remaining kwargs
            self.transforms.append(eval(op_type)(**new_op_info))

    def __call__(self, img_path):
        img, im_info = decode_image(img_path)
        for t in self.transforms:
            img, im_info = t(img, im_info)
        inputs = copy.deepcopy(im_info)
        inputs['image'] = img
        return inputs
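
# Usage sketch (not part of the original file): Compose instantiates each
# transform from a config dict by its 'type' key. The pipeline below is a
# hypothetical example; the real op list comes from the model's infer config.
#
#   pipeline = Compose([
#       {'type': 'LetterBoxResize', 'target_size': 640},
#       {'type': 'Pad', 'size': [640, 640]},
#   ])
#   inputs = pipeline('sample.jpg')  # dict with 'image' plus im_info fields
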
@@ -1,45 +0,0 @@
from pathlib import Path

from ...globals import (
    VOCAB_SIZE,
    FIXED_IMG_SIZE,
    IMG_CHANNELS,
    MAX_TOKEN_SIZE
)

from transformers import (
    RobertaTokenizerFast,
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig
)


class TexTeller(VisionEncoderDecoderModel):
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
        if model_path is None or model_path == 'default':
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            else:
                from optimum.onnxruntime import ORTModelForVision2Seq
                use_gpu = (onnx_provider == 'cuda')
                return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        if tokenizer_path is None or tokenizer_path == 'default':
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
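
# Usage sketch (not part of the original file): loading the published weights
# and tokenizer; 'OleehyO/TexTeller' is the repo name defined above.
#
#   model = TexTeller.from_pretrained()    # or pass a local checkpoint path
#   tokenizer = TexTeller.get_tokenizer()  # RobertaTokenizerFast
#   # with ONNX Runtime on GPU (requires optimum):
#   # model = TexTeller.from_pretrained(use_onnx=True, onnx_provider='cuda')
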
@@ -1,168 +0,0 @@
{
  "_name_or_path": "OleehyO/TexTeller",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": 768,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 12,
    "decoder_start_token_id": 2,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "init_std": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layernorm_embedding": true,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 1024,
    "min_length": 0,
    "model_type": "trocr",
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "scale_embedding": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "use_learned_position_embeddings": true,
    "vocab_size": 15000
  },
  "encoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "encoder_stride": 16,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 448,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "vit",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 16,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "qkv_bias": false,
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "is_encoder_decoder": true,
  "model_type": "vision-encoder-decoder",
  "tie_word_embeddings": false,
  "transformers_version": "4.41.2",
  "use_cache": true
}
@@ -1,35 +0,0 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
@@ -1,50 +0,0 @@
from PIL import Image
from pathlib import Path
import datasets
import json

DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')


class LatexFormulas(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = []

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({
                "image": datasets.Image(),
                "latex_formula": datasets.Value("string")
            })
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        dir_path = Path(dl_manager.download(str(DIR_URL)))
        assert dir_path.is_dir()

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    'dir_path': dir_path,
                }
            )
        ]

    def _generate_examples(self, dir_path: Path):
        images_path = dir_path / 'images'
        formulas_path = dir_path / 'formulas.jsonl'

        img2formula = {}
        with formulas_path.open('r', encoding='utf-8') as f:
            for line in f:
                single_json = json.loads(line)
                img2formula[single_json['img_name']] = single_json['formula']

        for img_path in images_path.iterdir():
            if img_path.suffix not in ['.jpg', '.png']:
                continue
            yield str(img_path), {
                "image": Image.open(img_path),
                "latex_formula": img2formula[img_path.name]
            }
@@ -1,109 +0,0 @@
import os

from functools import partial
from pathlib import Path

from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    GenerationConfig
)

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**CONFIG)
    trainer = Trainer(
        model,
        training_args,

        train_dataset=train_dataset,
        eval_dataset=eval_dataset,

        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


def evaluate(model, tokenizer, eval_dataset, collate_fn):
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    eval_config['generation_config'] = generate_config
    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)

    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,

        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
    )

    eval_res = trainer.evaluate()
    print(eval_res)


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # If you want to use your own tokenizer, change the path below to point at it:
    # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
    dataset = dataset.filter(
        filter_fn_with_tokenizer,
        num_proc=8
    )

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)

    # Split the dataset into train and eval at a 9:1 ratio
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch,
    model = TexTeller()
    # or train from the TexTeller pre-trained model: model = TexTeller.from_pretrained()

    # To train from a pre-trained model, point it at your checkpoint, e.g.:
    # model = TexTeller.from_pretrained(
    #     '/path/to/your/model_checkpoint'
    # )

    enable_train = True
    enable_evaluate = False
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
@@ -1,38 +0,0 @@
CONFIG = {
    "seed": 42,                          # Random seed for reproducibility
    "use_cpu": False,                    # Whether to use the CPU (easier to debug on CPU when first testing the code)
    "learning_rate": 5e-5,               # Learning rate
    "num_train_epochs": 10,              # Total number of training epochs
    "per_device_train_batch_size": 4,    # Batch size per GPU for training
    "per_device_eval_batch_size": 8,     # Batch size per GPU for evaluation

    "output_dir": "train_result",        # Output directory
    "overwrite_output_dir": False,       # If the output directory exists, do not delete its contents
    "report_to": ["tensorboard"],        # Report logs to TensorBoard

    "save_strategy": "steps",            # Strategy for saving checkpoints
    "save_steps": 500,                   # Interval of steps between checkpoints; an int, or a float in (0, 1) for a ratio of total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 5,               # Maximum number of checkpoints to keep; the oldest are deleted once this is exceeded

    "logging_strategy": "steps",         # Log every fixed number of steps
    "logging_steps": 500,                # Number of steps between logs
    "logging_nan_inf_filter": False,     # Also record logs when the loss is nan or inf

    "optim": "adamw_torch",              # Optimizer
    "lr_scheduler_type": "cosine",       # Learning rate scheduler
    "warmup_ratio": 0.1,                 # Ratio of warmup steps to total training steps (e.g. for 1000 steps, the first 100 ramp lr from 0 up to the set value)
    "max_grad_norm": 1.0,                # Gradient clipping: keep the gradient norm at or below 1.0 (default 1.0)
    "fp16": False,                       # Whether to train in 16-bit floating point (generally not recommended; the loss can easily explode)
    "bf16": False,                       # Whether to train in bfloat16 (recommended if the hardware supports it)
    "gradient_accumulation_steps": 1,    # Gradient accumulation steps; raise this to emulate a large batch size when the per-device batch cannot be large
    "jit_mode_eval": False,              # Whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static or it will raise errors)
    "torch_compile": False,              # Whether to compile the model with torch.compile (for better training and inference performance)

    "dataloader_pin_memory": True,       # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,         # Number of data-loading workers (1 means no multiprocessing); usually set to 4 * number of GPUs in use

    "evaluation_strategy": "steps",      # Evaluation strategy, either "steps" or "epoch"
    "eval_steps": 500,                   # Steps between evaluations when evaluation_strategy="steps"

    "remove_unused_columns": False,      # Don't change this unless you really know what you are doing.
}
@@ -1,59 +0,0 @@
import torch

from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE


def left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x
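
# Usage sketch (not part of the original file): left_move shifts every row one
# position to the left and fills the vacated last column with pad_val, e.g.
#
#   x = torch.tensor([[1, 2, 3]])
#   left_move(x, -100)  # -> tensor([[  2,   3, -100]])
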

def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    tokenized_formula['pixel_values'] = samples['image']
    return tokenized_formula


def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')

    # Left-shift the labels (and decoder_attention_mask)
    batch['labels'] = left_move(batch['labels'], -100)

    # Convert the list of images into one tensor of shape (B, C, H, W)
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch


def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = train_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples


def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = inference_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
        and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
    )
@@ -1,26 +0,0 @@
import cv2
import numpy as np
from typing import List


def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"Image at {path} could not be read.")
            continue
        if image.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        processed_images.append(image)

    return processed_images
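
# Usage sketch (not part of the original file): convert2rgb normalizes mixed
# inputs (grayscale, BGRA, 16-bit) into 8-bit RGB arrays; paths are examples.
#
#   imgs = convert2rgb(['page1.png', 'scan.jpg'])
#   # -> list of np.ndarray, each of shape (H, W, 3) in RGB order
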
@@ -1,49 +0,0 @@
import torch
import numpy as np

from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union

from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs: Union[List[str], List[np.ndarray]],
    accelerator: str = 'cpu',
    num_beams: int = 1,
    max_tokens=None
) -> List[str]:
    if imgs == []:
        return []
    if hasattr(model, 'eval'):
        # not an ONNX session, so switch the model to eval mode
        model.eval()
    if isinstance(imgs[0], str):
        imgs = convert2rgb(imgs)
    else:
        # already numpy arrays (RGB format)
        assert isinstance(imgs[0], np.ndarray)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if hasattr(model, 'eval'):
        # not an ONNX session, so move the weights and inputs to the target device
        model = model.to(accelerator)
        pixel_values = pixel_values.to(accelerator)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
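
# Usage sketch (not part of the original file): running recognition over a
# couple of formula images; the file names here are illustrative.
#
#   model = TexTeller.from_pretrained()
#   tokenizer = TexTeller.get_tokenizer()
#   latex = inference(model, tokenizer, ['eq1.png', 'eq2.png'], accelerator='cpu')
#   # -> list of LaTeX strings, one per input image
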
@@ -1,23 +0,0 @@
import evaluate
import numpy as np
import os

from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer


def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    metric = evaluate.load('google_bleu')  # Downloads the metric from Hugging Face if it is not already cached
    os.chdir(cur_dir)

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits

    # replace the -100 padding in labels with the pad token id (1) before decoding
    labels = np.where(labels == -100, 1, labels)

    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)
@@ -1,180 +0,0 @@
import re


def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    result = ""
    i = 0
    n = len(input_str)

    while i < n:
        if input_str[i:i + len(old_inst)] == old_inst:
            # check if the old_inst is followed by old_surr_l
            start = i + len(old_inst)
        else:
            result += input_str[i]
            i += 1
            continue

        if start < n and input_str[start] == old_surr_l:
            # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                if input_str[j] == '\\' and not escaped:
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1

            if count == 0:
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1:j]
                # Replace the content with the new pattern
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                assert count >= 1
                assert j == n
                print("Warning: unbalanced surrogate pair in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            result += input_str[i:start]
            i = start

    if old_inst != new_inst and (old_inst + old_surr_l) in result:
        # the replacement may have produced new matches, so recurse until stable
        return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result


def find_substring_positions(string, substring):
    positions = [match.start() for match in re.finditer(re.escape(substring), string)]
    return positions


def rm_dollar_surr(content):
    pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
    matches = pattern.findall(content)

    for match in matches:
        if not re.match(r'\\[a-zA-Z]+', match):
            new_match = match.strip('$')
            content = content.replace(match, ' ' + new_match + ' ')

    return content


def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    pos = find_substring_positions(input_str, old_inst + old_surr_l)
    res = list(input_str)
    for p in pos[::-1]:
        res[p:] = list(change(''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r))
    res = ''.join(res)
    return res


def to_katex(formula: str) -> str:
    res = formula
    # remove the \mbox surrounding
    res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
    res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
    # remove the \hbox surrounding
    res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
    res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
    # remove the \raise surrounding
    res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
    # remove \makebox
    res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
    res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
    # remove the \vbox, \raisebox and \scalebox surroundings
    res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
    res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
    res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')

    # normalize font-size commands: turn e.g. "\large$...$" into "\large{...}"
    origin_instructions = [
        r'\Huge',
        r'\huge',
        r'\LARGE',
        r'\Large',
        r'\large',
        r'\normalsize',
        r'\small',
        r'\footnotesize',
        r'\tiny'
    ]
    for origin_ins in origin_instructions:
        res = change_all(res, origin_ins, origin_ins, r'$', r'$', '{', '}')
    res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')

    # strip braces that directly follow delimiter-size commands
    origin_instructions = [
        r'\left',
        r'\middle',
        r'\right',
        r'\big',
        r'\Big',
        r'\bigg',
        r'\Bigg',
        r'\bigl',
        r'\Bigl',
        r'\biggl',
        r'\Biggl',
        r'\bigm',
        r'\Bigm',
        r'\biggm',
        r'\Biggm',
        r'\bigr',
        r'\Bigr',
        r'\biggr',
        r'\Biggr'
    ]
    for origin_ins in origin_instructions:
        res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')

    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)

    if res.endswith(r'\newline'):
        res = res[:-8]

    # collapse spacing commands into single spaces
    res = re.sub(r'(\\,){1,}', ' ', res)
    res = re.sub(r'(\\!){1,}', ' ', res)
    res = re.sub(r'(\\;){1,}', ' ', res)
    res = re.sub(r'(\\:){1,}', ' ', res)
    res = re.sub(r'\\vspace\{.*?}', '', res)

    # merge consecutive \text{...} blocks
    def merge_texts(match):
        texts = match.group(0)
        merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
        return f'\\text{{{merged_content}}}'
    res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)

    res = res.replace(r'\bf ', '')
    res = rm_dollar_surr(res)

    # remove extra spaces (keeping only one)
    res = re.sub(r' +', ' ', res)

    return res.strip()
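
# Usage sketch (not part of the original file): to_katex normalizes raw model
# output into KaTeX-friendly LaTeX; the input string is an illustrative example.
#
#   to_katex(r'\mbox{rate} \,\, x_{1}')
#   # -> 'rate x_{1}'  (mbox unwrapped, spacing commands collapsed)
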
@@ -1,25 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    tokenizer = TexTeller.get_tokenizer()

    # Don't forget to configure your dataset path in loader.py
    dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']

    new_tokenizer = tokenizer.train_new_from_iterator(
        text_iterator=dataset['latex_formula'],

        # To use a different vocab size, **change VOCAB_SIZE in globals.py**
        vocab_size=VOCAB_SIZE
    )

    # Save the new tokenizer for later training and inference
    new_tokenizer.save_pretrained('./your_dir_name')
@@ -1 +0,0 @@
from .mix_inference import mix_inference
@@ -1,264 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np

from collections import Counter
from typing import List
from PIL import Image

from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes

from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all

MAXV = 999999999


def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
    mask_img = img.copy()
    for bbox in bboxes:
        mask_img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color
    return mask_img


def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
    if (len(sorted_bboxes) == 0):
        return []
    bboxes = sorted_bboxes.copy()
    guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    for curr in bboxes:
        # merge horizontally overlapping boxes on the same row
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            res.append(prev)
            prev = curr
        else:
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res


def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    # log results
    for idx, bbox in enumerate(bboxes):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), bboxes, name="before_split_conflict.png")

    assert len(bboxes) > 1

    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0
    while (len(bboxes) > 0):
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)

        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)

        elif candidate.ur_point.x >= curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")

            if candidate.label == "text":
                assert curr.label != "text"
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None
                    )
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)

    # log results
    for idx, bbox in enumerate(res):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), res, name="after_split_conflict.png")

    return res


def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
    sliced_imgs = []
    for bbox in ocr_bboxes:
        x, y = int(bbox.p.x), int(bbox.p.y)
        w, h = int(bbox.w), int(bbox.h)
        sliced_img = img[y:y + h, x:x + w]
        sliced_imgs.append(sliced_img)
    return sliced_imgs


def mix_inference(
    img_path: str,
    infer_config,
    latex_det_model,

    lang_ocr_models,

    latex_rec_models,
    accelerator="cpu",
    num_beams=1
) -> str:
    '''
    Take a mixed text-and-formula image and return a string in Markdown syntax.
    '''
    global img
    img = cv2.imread(img_path)
    corners = [tuple(img[0, 0]), tuple(img[0, -1]),
               tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
    end_time = time.time()
    print(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
    latex_bboxes = bbox_merge(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
    masked_img = mask_img(img, latex_bboxes, bg_color)

    det_model, rec_model = lang_ocr_models
    start_time = time.time()
    det_prediction, _ = det_model(masked_img)
    end_time = time.time()
    print(f"ocr_det_model time: {end_time - start_time:.2f}s")
    ocr_bboxes = [
        Bbox(
            p[0][0], p[0][1], p[3][1] - p[0][1], p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None
        )
        for p in det_prediction
    ]
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = rec_model(sliced_imgs)
    end_time = time.time()
    print(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]

    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
    start_time = time.time()
    latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800)
    end_time = time.time()
    print(f"latex_rec_model time: {end_time - start_time:.2f}s")

    for bbox, content in zip(latex_bboxes, latex_rec_res):
        bbox.content = to_katex(content)
        if bbox.label == "embedding":
            bbox.content = " $" + bbox.content + "$ "
        elif bbox.label == "isolated":
            bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""

    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if (
            prev.label == "isolated"
            and curr.label == "text"
            and prev.same_row(curr)
        ):
            curr.content = curr.content.strip()
            if curr.content.startswith('(') and curr.content.endswith(')'):
                curr.content = curr.content[1:-1]

            if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
                # in case of multiple tags
                md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
            else:
                md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove the bold/italic effects from inline formulas
            curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')

            # change the split environment into aligned
            curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
            curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r' +', ' ', curr.content)
            assert curr.content.startswith(' $') and curr.content.endswith('$ ')
            curr.content = ' $' + curr.content[2:-2].strip() + '$ '
        md += curr.content
        prev = curr
    return md.strip()
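
# Usage sketch (not part of the original file): mix_inference stitches the
# formula detector, a text OCR pipeline, and the TexTeller recognizer together.
# The detector/OCR objects below are hypothetical stand-ins for whatever the
# caller has loaded; only latex_rec_models mirrors code in this repo.
#
#   latex_rec_models = (TexTeller.from_pretrained(), TexTeller.get_tokenizer())
#   md = mix_inference('page.png', infer_config, latex_det_model,
#                      lang_ocr_models, latex_rec_models, accelerator='cpu')
#   print(md)  # mixed text with $...$ and $$...$$ formulas, in Markdown
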
@@ -1,65 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller


if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img_dir',
        type=str,
        help='path to the directory of input images',
        default='./detect_results/subimages'
    )
    parser.add_argument(
        '-output_dir',
        type=str,
        help='path to the output dir',
        default='./rec_results'
    )
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, one of cpu, cuda, or mps'
    )
    parser.add_argument(
        '--num-beam',
        type=int,
        default=1,
        help='number of beams for beam-search decoding'
    )

    args = parser.parse_args()

    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    # Create the output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Loop through all images in the input directory
    for filename in os.listdir(args.img_dir):
        img_path = os.path.join(args.img_dir, filename)
        img = cv.imread(img_path)

        if img is not None:
            print(f'Inference for {filename}...')
            res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
            res = to_katex(res[0])

            # Save the recognition result to a text file
            output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
            with open(output_file, 'w') as f:
                f.write(res)

            print(f'Result saved to {output_file}')
        else:
            print(f"Warning: Could not read image {img_path}. Skipping...")
157
src/server.py
@@ -1,157 +0,0 @@
import sys
import argparse
import tempfile
import time
import numpy as np
import cv2

from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession

from models.ocr_model.utils.inference import inference as rec_inference
from models.det_model.inference import predict as det_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
from models.ocr_model.utils.to_katex import to_katex


PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'

parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str)
parser.add_argument('-tknz', '--tokenizer_dir', type=str)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)

parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('-onnx', action='store_true', help='use the ONNX runtime')

args = parser.parse_args()
if args.ngpu_per_replica > 0 and args.inference_mode not in ('cuda', 'mps'):
    raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        # Each replica's GPU share is split between the recognition and detection deployments
        "num_gpus": args.ngpu_per_replica / 2
    }
)
class TexTellerRecServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        inf_mode: str = 'cpu',
        use_onnx: bool = False,
        num_beams: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.inf_mode = inf_mode
        self.num_beams = num_beams

        if not use_onnx:
            self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model

    def predict(self, image_nparray) -> str:
        return to_katex(rec_inference(
            self.model, self.tokenizer, [image_nparray],
            accelerator=self.inf_mode, num_beams=self.num_beams
        )[0])


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica / 2,
        "runtime_env": {
            "env_vars": {
                "LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"
            }
        }
    },
)
class TexTellerDetServer:
    def __init__(self, inf_mode='cpu'):
        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        self.latex_det_model = InferenceSession(
            "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
            providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider']
        )

    async def predict(self, image_nparray):
        # Detection expects a file path, so round-trip the array through a temp image
        with tempfile.TemporaryDirectory() as temp_dir:
            img_path = f"{temp_dir}/temp_image.jpg"
            cv2.imwrite(img_path, image_nparray)

            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
            return latex_bboxes


@serve.deployment()
class Ingress:
    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
        self.det_server = det_server
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        request_path = request.url.path
        form = await request.form()
        img_rb = await form['img'].read()

        img_nparray = np.frombuffer(img_rb, np.uint8)
        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)

        if request_path.startswith("/fdet"):
            if self.det_server is None:
                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
            pred = await self.det_server.predict.remote(img_nparray)
            return pred

        elif request_path.startswith("/frec"):
            pred = await self.texteller_server.predict.remote(img_nparray)
            return pred

        else:
            return "[ERROR] Invalid request path"


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
    rec_server = TexTellerRecServer.bind(
        ckpt_dir, tknz_dir,
        inf_mode=args.inference_mode,
        use_onnx=args.onnx,
        num_beams=args.num_beams
    )
    det_server = None
    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
        det_server = TexTellerDetServer.bind(args.inference_mode)
    ingress = Ingress.bind(det_server, rec_server)

    ingress_handle = serve.run(ingress, route_prefix="/")

    # Keep the driver alive so the Serve deployments stay up
    while True:
        time.sleep(1)
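With the server running, a minimal client sketch looks like the following (the image file name is a placeholder; localhost:8000 matches the default --server_port above):

import requests

# Send one image to the recognition endpoint; the Ingress reads it from form['img']
with open('formula.png', 'rb') as f:
    resp = requests.post('http://localhost:8000/frec', files={'img': f})
print(resp.text)  # recognized formula, already converted to KaTeX

Posting to /fdet instead returns formula bounding boxes, provided the detection model was found at startup.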
@@ -1,9 +0,0 @@
@echo off
SETLOCAL ENABLEEXTENSIONS

set CHECKPOINT_DIR=default
set TOKENIZER_DIR=default

streamlit run web.py

ENDLOCAL
@@ -1,7 +0,0 @@
#!/usr/bin/env bash
set -exu

export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"

streamlit run web.py
@@ -1,14 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
gpu_ids: all
num_processes: 1
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
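For context, a config like this is normally consumed by Hugging Face Accelerate's launcher, e.g. accelerate launch --config_file <path-to-this-yaml> <training_script.py>, where both paths are placeholders.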
272
src/web.py
@@ -1,272 +0,0 @@
import os
import io
import re
import base64
import tempfile
import shutil
import streamlit as st

from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility

from models.utils import mix_inference
from models.det_model.inference import PredictConfig

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_recognition
from models.ocr_model.utils.to_katex import to_katex


st.set_page_config(
    page_title="TexTeller",
    page_icon="🧮"
)

html_string = '''
    <h1 style="color: black; text-align: center;">
        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
        𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
    </h1>
'''

suc_gif_html = '''
    <h1 style="color: black; text-align: center;">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
    </h1>
'''

fail_gif_html = '''
    <h1 style="color: black; text-align: center;">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
    </h1>
'''


@st.cache_resource
def get_texteller(use_onnx, accelerator):
    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'], use_onnx=use_onnx, onnx_provider=accelerator)


@st.cache_resource
def get_tokenizer():
    return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])


@st.cache_resource
def get_det_models(accelerator):
    infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
    latex_det_model = InferenceSession(
        "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
        providers=['CUDAExecutionProvider'] if accelerator == 'cuda' else ['CPUExecutionProvider']
    )
    return infer_config, latex_det_model


@st.cache_resource()
def get_ocr_models(accelerator):
    use_gpu = accelerator == 'cuda'

    SIZE_LIMIT = 20 * 1024 * 1024
    det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
    rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
    # In onnxruntime, the detection model runs faster on CPU than on GPU
    det_use_gpu = False
    rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)

    paddleocr_args = utility.parse_args()
    paddleocr_args.use_onnx = True
    paddleocr_args.det_model_dir = det_model_dir
    paddleocr_args.rec_model_dir = rec_model_dir

    paddleocr_args.use_gpu = det_use_gpu
    detector = predict_det.TextDetector(paddleocr_args)
    paddleocr_args.use_gpu = rec_use_gpu
    recognizer = predict_rec.TextRecognizer(paddleocr_args)
    return [detector, recognizer]


def get_image_base64(img_file):
    buffered = io.BytesIO()
    img_file.seek(0)
    img = Image.open(img_file)
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def on_file_upload():
    st.session_state["UPLOADED_FILE_CHANGED"] = True


def change_side_bar():
    st.session_state["CHANGE_SIDEBAR_FLAG"] = True


if "start" not in st.session_state:
    st.session_state["start"] = 1
    st.toast('Hooray!', icon='🎉')

if "UPLOADED_FILE_CHANGED" not in st.session_state:
    st.session_state["UPLOADED_FILE_CHANGED"] = False

if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
    st.session_state["CHANGE_SIDEBAR_FLAG"] = False

if "INF_MODE" not in st.session_state:
    st.session_state["INF_MODE"] = "Formula recognition"


############################## <sidebar> ##############################

with st.sidebar:
    num_beams = 1

    st.markdown("# 🔨️ Config")
    st.markdown("")

    inf_mode = st.selectbox(
        "Inference mode",
        ("Formula recognition", "Paragraph recognition"),
        on_change=change_side_bar
    )

    num_beams = st.number_input(
        'Number of beams',
        min_value=1,
        max_value=20,
        step=1,
        on_change=change_side_bar
    )

    accelerator = st.radio(
        "Accelerator",
        ("cpu", "cuda", "mps"),
        on_change=change_side_bar
    )

    st.markdown("## Speedup")
    use_onnx = st.toggle("ONNX Runtime")


############################## </sidebar> ##############################


################################ <page> ################################

texteller = get_texteller(use_onnx, accelerator)
tokenizer = get_tokenizer()
latex_rec_models = [texteller, tokenizer]

if inf_mode == "Paragraph recognition":
    infer_config, latex_det_model = get_det_models(accelerator)
    lang_ocr_models = get_ocr_models(accelerator)

st.markdown(html_string, unsafe_allow_html=True)

uploaded_file = st.file_uploader(
    " ",
    type=['jpg', 'png'],
    on_change=on_file_upload
)

paste_result = pbutton(
    label="📋 Paste an image",
    background_color="#5BBCFF",
    hover_background_color="#3498db",
)
st.write("")

if st.session_state["CHANGE_SIDEBAR_FLAG"]:
    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
    if not st.session_state["UPLOADED_FILE_CHANGED"] and paste_result.image_data is not None:
        uploaded_file = io.BytesIO()
        paste_result.image_data.save(uploaded_file, format='PNG')
        uploaded_file.seek(0)

    if st.session_state["UPLOADED_FILE_CHANGED"]:
        st.session_state["UPLOADED_FILE_CHANGED"] = False

    img = Image.open(uploaded_file)

    temp_dir = tempfile.mkdtemp()
    png_file_path = os.path.join(temp_dir, 'image.png')
    img.save(png_file_path, 'PNG')

    with st.container(height=300):
        img_base64 = get_image_base64(uploaded_file)

        st.markdown(f"""
        <style>
        .centered-container {{
            text-align: center;
        }}
        .centered-image {{
            display: block;
            margin-left: auto;
            margin-right: auto;
            max-height: 350px;
            max-width: 100%;
        }}
        </style>
        <div class="centered-container">
            <img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
        </div>
        """, unsafe_allow_html=True)
    st.markdown(f"""
    <style>
    .centered-container {{
        text-align: center;
    }}
    </style>
    <div class="centered-container">
        <p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
    </div>
    """, unsafe_allow_html=True)

    st.write("")

    with st.spinner("Predicting..."):
        if inf_mode == "Formula recognition":
            TexTeller_result = latex_recognition(
                texteller,
                tokenizer,
                [png_file_path],
                accelerator=accelerator,
                num_beams=num_beams
            )[0]
            katex_res = to_katex(TexTeller_result)
        else:
            katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)

    st.success('Completed!', icon="✅")
    st.markdown(suc_gif_html, unsafe_allow_html=True)
    st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)

    if inf_mode == "Formula recognition":
        st.latex(katex_res)
    elif inf_mode == "Paragraph recognition":
        mixed_res = re.split(r'(\$\$.*?\$\$)', katex_res)
        for text in mixed_res:
            if text.startswith('$$') and text.endswith('$$'):
                st.latex(text[2:-2])
            else:
                st.markdown(text)

    st.write("")
    st.write("")

    with st.expander(":star2: :gray[Tips for better results]"):
        st.markdown('''
            * :mag_right: Use a clear and high-resolution image.
            * :scissors: Crop the image as tightly as possible.
            * :jigsaw: Split large multi-line formulas into smaller ones.
            * :page_facing_up: Prefer images with **black text on a white background**.
            * :book: Use a font with good readability.
        ''')
    shutil.rmtree(temp_dir)

    paste_result.image_data = None

################################ </page> ################################
67
tests/test_globals.py
Normal file
@@ -0,0 +1,67 @@
import logging

from texteller.globals import Globals


def test_singleton_pattern():
    """Test that Globals uses the singleton pattern correctly."""
    # Create two instances
    globals1 = Globals()
    globals2 = Globals()

    # Both variables should reference the same object
    assert globals1 is globals2

    # Modifying one should affect the other
    globals1.test_attr = "test_value"
    assert globals2.test_attr == "test_value"

    # Clean up after test
    delattr(globals1, "test_attr")


def test_predefined_attributes():
    """Test that predefined attributes have correct default values."""
    globals_instance = Globals()
    assert globals_instance.repo_name == "OleehyO/TexTeller"
    assert globals_instance.logging_level == logging.INFO


def test_attribute_modification():
    """Test that attributes can be modified."""
    globals_instance = Globals()

    # Modify an existing attribute
    original_repo_name = globals_instance.repo_name
    globals_instance.repo_name = "NewRepo/NewName"
    assert globals_instance.repo_name == "NewRepo/NewName"

    assert Globals().logging_level == logging.INFO
    Globals().logging_level = logging.DEBUG
    assert Globals().logging_level == logging.DEBUG

    # Reset for other tests
    globals_instance.repo_name = original_repo_name
    globals_instance.logging_level = logging.INFO


def test_dynamic_attributes():
    """Test that new attributes can be added dynamically."""
    globals_instance = Globals()

    # Add a new attribute
    globals_instance.new_attribute = "new_value"
    assert globals_instance.new_attribute == "new_value"

    # Clean up after test
    delattr(globals_instance, "new_attribute")


def test_representation():
    """Test the string representation of Globals."""
    globals_instance = Globals()
    repr_string = repr(globals_instance)

    # Check that repr contains the class name and is formatted as expected
    assert repr_string.startswith("<Globals:")
    assert "repo_name" in repr_string
    assert "logging_level" in repr_string
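The tests above pin down the contract of texteller.globals.Globals without showing its definition. A minimal sketch that would satisfy them — an assumption for illustration, not the project's actual implementation:

import logging


class Globals:
    """Process-wide settings object; every call to Globals() returns the same instance."""

    _instance = None

    def __new__(cls):
        # Classic singleton: create the shared instance once, then reuse it
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.repo_name = "OleehyO/TexTeller"
            cls._instance.logging_level = logging.INFO
        return cls._instance

    def __repr__(self) -> str:
        attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
        return f"<Globals: {attrs}>"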
8
tests/test_to_katex.py
Normal file
@@ -0,0 +1,8 @@
from texteller import to_katex


def test_to_katex():
    # Test a complex mathematical equation with vectors and symbols
    complex_input = "\\[\\begin{split}&\\mathbb{E}_{\\bm{\\omega}}[h(\\mathbf{x})h(\\mathbf{y})^{ *}]\\\\ &=\\mathbb{E}_{\\bm{\\omega}}[\\exp(i\\bm{\\omega}^{\\top}\\mathbf{x})\\exp (-i\\bm{\\omega}^{\\top}\\mathbf{y}))]\\\\ &=\\mathbb{E}_{\\bm{\\omega}}[\\exp(i\\bm{\\omega}^{\\top}\\bm{\\delta})] \\\\ &=\\int_{\\mathbb{R}^{D}}p(\\bm{\\omega})\\exp(i\\bm{\\omega}^{\\top}\\bm{ \\delta})\\mathrm{d}\\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{ \\omega}^{\\top}\\bm{\\omega}\\big{)}\\exp(i\\bm{\\omega}^{\\top}\\bm{\\delta})\\mathrm{d} \\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{ \\omega}^{\\top}\\bm{\\omega}-i\\bm{\\omega}^{\\top}\\bm{\\delta}\\big{)}\\mathrm{d}\\bm{ \\omega}\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\big{(} \\bm{\\omega}^{\\top}\\bm{\\omega}-2i\\bm{\\omega}^{\\top}\\bm{\\delta}-\\bm{\\delta}^{ \\top}\\bm{\\delta}\\big{)}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{\\delta}\\big{)} \\mathrm{d}\\bm{\\omega}\\\\ &=(2\\pi)^{-D/2}\\exp\\!\\big{(}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{ \\delta}\\big{)}\\!\\underbrace{\\int_{\\mathbb{R}^{D}}\\exp\\!\\big{(}-\\frac{1}{2}\\big{(} \\bm{\\omega}-i\\bm{\\delta}\\big{)}^{\\top}\\big{(}\\bm{\\omega}-i\\bm{\\delta}\\big{)} \\big{)}\\mathrm{d}\\bm{\\omega}}_{(2\\pi)^{D/2}}\\\\ &=\\exp\\!\\big{(}-\\frac{1}{2}\\bm{\\delta}^{\\top}\\bm{\\delta}\\big{)} \\\\ &=k(\\bm{\\delta}).\\end{split}\\]"
    expected_output = "\\begin{split}&\\mathbb{E}_{ \\omega}[h( x)h( y)^{ *}]\\\\ &=\\mathbb{E}_{ \\omega}[\\exp(i \\omega^{\\top} x)\\exp (-i \\omega^{\\top} y))]\\\\ &=\\mathbb{E}_{ \\omega}[\\exp(i \\omega^{\\top} \\delta)] \\\\ &=\\int_{\\mathbb{R}^{D}}p( \\omega)\\exp(i \\omega^{\\top} \\delta)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2} \\omega^{\\top} \\omega\\big)\\exp(i \\omega^{\\top} \\delta)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2} \\omega^{\\top} \\omega-i \\omega^{\\top} \\delta\\big)\\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2}\\big( \\omega^{\\top} \\omega-2i \\omega^{\\top} \\delta- \\delta^{ \\top} \\delta\\big)-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\mathrm{d} \\omega\\\\ &=(2\\pi)^{-D/2}\\exp \\big(-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\underbrace{\\int_{\\mathbb{R}^{D}}\\exp \\big(-\\frac{1}{2}\\big( \\omega-i \\delta\\big)^{\\top}\\big( \\omega-i \\delta\\big) \\big)\\mathrm{d} \\omega}_{(2\\pi)^{D/2}}\\\\ &=\\exp \\big(-\\frac{1}{2} \\delta^{\\top} \\delta\\big) \\\\ &=k( \\delta).\n\\end{split}\n"
    assert to_katex(complex_input).strip() == expected_output.strip()
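As this fixture shows, to_katex strips the \[ ... \] display wrapper, drops LaTeX-only spacing like \!, reduces bold macros such as \bm{...} and \mathbf{...} to plain symbols, and rewrites \big{(} into the KaTeX-compatible \big(.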
5
texteller/__init__.py
Normal file
@@ -0,0 +1,5 @@
from importlib.metadata import version

from texteller.api import *

__version__ = version("texteller")
24
texteller/api/__init__.py
Normal file
@@ -0,0 +1,24 @@
from .detection import latex_detect
from .format import format_latex
from .inference import img2latex, paragraph2md
from .katex import to_katex
from .load import (
    load_latexdet_model,
    load_model,
    load_textdet_model,
    load_textrec_model,
    load_tokenizer,
)

__all__ = [
    "to_katex",
    "format_latex",
    "img2latex",
    "paragraph2md",
    "load_model",
    "load_tokenizer",
    "load_latexdet_model",
    "load_textrec_model",
    "load_textdet_model",
    "latex_detect",
]
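This module is the package's public surface. Of these names, the detection pair is exercised in the latex_detect docstring further below; a short sketch of that confirmed usage (the image path is a placeholder):

from texteller.api import load_latexdet_model, latex_detect

# Load the ONNX detector, then locate LaTeX formulas in a page image
model = load_latexdet_model()
bboxes = latex_detect("path/to/image.png", model)
for bbox in bboxes:
    print(bbox)  # position, 'isolated'/'embedding' label, confidence score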
4
texteller/api/criterias/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .ngram import DetectRepeatingNgramCriteria

__all__ = ["DetectRepeatingNgramCriteria"]
63
texteller/api/criterias/ngram.py
Normal file
@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteria


class DetectRepeatingNgramCriteria(StoppingCriteria):
    """
    Stops generation efficiently if any n-gram repeats.

    This criteria maintains a set of encountered n-grams.
    At each step, it checks whether the *latest* n-gram is already in the set.
    If it is, generation stops; otherwise the n-gram is added to the set.
    """

    def __init__(self, n: int):
        """
        Args:
            n (int): The size of the n-gram to check for repetition.
        """
        if n <= 0:
            raise ValueError("n-gram size 'n' must be positive.")
        self.n = n
        # Stores tuples of token IDs representing seen n-grams
        self.seen_ngrams = set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores.

        Return:
            `bool`: `True` if generation should stop, `False` otherwise.
        """
        batch_size, seq_length = input_ids.shape

        # Need at least n tokens to form the first n-gram
        if seq_length < self.n:
            return False

        # Only the first sequence in the batch is checked. Supporting
        # batch_size > 1 would require one set of seen n-grams per sequence,
        # or a policy such as stopping when *any* sequence repeats.
        sequence = input_ids[0]

        # Get the latest n-gram (the one ending at the last token)
        last_ngram_tensor = sequence[-self.n:]
        # Convert to a hashable tuple for set storage and lookup
        last_ngram_tuple = tuple(last_ngram_tensor.tolist())

        # Check whether this n-gram has been seen at any prior step
        if last_ngram_tuple in self.seen_ngrams:
            return True  # Stop generation
        else:
            # It's a new n-gram; add it to the set and continue
            self.seen_ngrams.add(last_ngram_tuple)
            return False  # Continue generation
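For context, a criterion like this plugs into Hugging Face generation via StoppingCriteriaList; a minimal sketch (gpt2 stands in for any causal LM purely for illustration — TexTeller's own inference code may wire it differently):

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from texteller.api.criterias import DetectRepeatingNgramCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The same phrase over and over:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    # Stop as soon as any 4-gram of token IDs occurs a second time
    stopping_criteria=StoppingCriteriaList([DetectRepeatingNgramCriteria(n=4)]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Because the criterion accumulates state in seen_ngrams, a fresh instance should be constructed for every generate call.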
3
texteller/api/detection/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .detect import latex_detect

__all__ = ["latex_detect"]
69
texteller/api/detection/detect.py
Normal file
@@ -0,0 +1,69 @@
from typing import List

from onnxruntime import InferenceSession

from texteller.types import Bbox

from .preprocess import Compose

_config = {
    "mode": "paddle",
    "draw_threshold": 0.5,
    "metric": "COCO",
    "use_dynamic_shape": False,
    "arch": "DETR",
    "min_subgraph_size": 3,
    "preprocess": [
        {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
        {
            "mean": [0.0, 0.0, 0.0],
            "norm_type": "none",
            "std": [1.0, 1.0, 1.0],
            "type": "NormalizeImage",
        },
        {"type": "Permute"},
    ],
    "label_list": ["isolated", "embedding"],
}


def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
    """
    Detect LaTeX formulas in an image and classify them as isolated or embedded.

    This function uses an ONNX model to detect LaTeX formulas in images. The model
    identifies two types of LaTeX formulas:
    - 'isolated': standalone LaTeX formulas (typically displayed equations)
    - 'embedding': inline LaTeX formulas embedded within text

    Args:
        img_path: Path to the input image file
        predictor: ONNX InferenceSession model for LaTeX detection

    Returns:
        List of Bbox objects representing the detected LaTeX formulas with their
        positions, classifications, and confidence scores

    Example:
        >>> from texteller.api import load_latexdet_model, latex_detect
        >>> model = load_latexdet_model()
        >>> bboxes = latex_detect("path/to/image.png", model)
    """
    transforms = Compose(_config["preprocess"])
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
    # Add a leading batch dimension to every model input
    inputs = {k: inputs[k][None,] for k in inputs_name}

    outputs = predictor.run(output_names=None, input_feed=inputs)[0]
    res = []
    # Each detection row is (class_id, score, xmin, ymin, xmax, ymax)
    for output in outputs:
        cls_name = _config["label_list"][int(output[0])]
        score = output[1]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        if score > 0.5:
            res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))

    return res