diff --git a/.gitignore b/.gitignore
index f44176c..b35b345 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,14 @@
+**/.DS_Store
**/__pycache__
**/.vscode
+
**/train_result
+**/ckpt
+**/*cache
+**/.cache
+**/data
**/logs
-**/.cache
**/tmp*
**/data
**/*cache
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0a79ce2
--- /dev/null
+++ b/README.md
@@ -0,0 +1,206 @@
+📄 [English](./README.md) | [中文](./assets/README_zh.md)
+
+
+
+https://github.com/OleehyO/TexTeller/assets/56267907/b23b2b2e-a663-4abb-b013-bd47238d513b
+
+TexTeller is an end-to-end formula recognition model based on ViT, capable of converting images into corresponding LaTeX formulas.
+
+TexTeller was trained on ~~550K~~ 7.5M image-formula pairs (dataset available [here](https://huggingface.co/datasets/OleehyO/latex-formulas)). Compared to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR), which used a 100K dataset, TexTeller has **stronger generalization abilities** and **higher accuracy**, covering most use cases (**except for scanned images and handwritten formulas**).
+
+> ~~We will soon release a TexTeller checkpoint trained on a 7.5M dataset~~
+
+## 🔄 Change Log
+
+* 📮[2024-03-25] TexTeller 2.0 released! The training data for TexTeller 2.0 has been increased to 7.5M (about **15 times more** than TexTeller 1.0, with improved data quality). The trained TexTeller 2.0 demonstrates **superior performance** on the test set, especially when recognizing rare symbols, complex multi-line formulas, and matrices.
+  > More test images, along with a side-by-side comparison of recognition models from different vendors, can be found [here](./assets/test.pdf).
+
+* 📮[2024-04-12] Trained a **formula detection model**, thereby enhancing the capability to detect and recognize formulas in entire documents (whole-image inference)!
+
+
+## 🔑 Prerequisites
+
+python=3.10
+
+[pytorch](https://pytorch.org/get-started/locally/)
+
+> [!WARNING]
+> Only CUDA versions >= 12.0 have been fully tested, so CUDA >= 12.0 is recommended.
+
+## 🚀 Getting Started
+
+1. Clone the repository:
+
+ ```bash
+ git clone https://github.com/OleehyO/TexTeller
+ ```
+
+2. [Installing pytorch](https://pytorch.org/get-started/locally/#start-locally)
+
+3. Install the project's dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+4. Enter the `TexTeller/src` directory and run the following command in the terminal to start inference:
+
+ ```bash
+ python inference.py -img "/path/to/image.{jpg,png}"
+ # use --inference-mode option to enable GPU(cuda or mps) inference
+ #+e.g. python inference.py -img "./img.jpg" --inference-mode cuda
+ ```
+
+> [!NOTE]
+> The first time you run it, the required checkpoints will be downloaded from Hugging Face
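+
+If your machine cannot reach Hugging Face, you can also load the weights from a local directory. A minimal sketch, assuming you have already downloaded the checkpoint and tokenizer files into `./texteller_ckpt` (the directory name is just a placeholder), for example with `huggingface-cli download OleehyO/TexTeller --local-dir ./texteller_ckpt`:
+
+```python
+from models.ocr_model.model.TexTeller import TexTeller
+
+# Load the model and tokenizer from a local directory instead of the Hugging Face Hub
+latex_rec_model = TexTeller.from_pretrained('./texteller_ckpt')
+tokenizer = TexTeller.get_tokenizer('./texteller_ckpt')
+```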
+
+## 🌐 Web Demo
+
+Go to the `TexTeller/src` directory and run the following command:
+
+```bash
+./start_web.sh
+```
+
+Enter `http://localhost:8501` in a browser to view the web demo.
+
+> [!NOTE]
+> If you are a Windows user, please run the `start_web.bat` file instead.
+
+## 🧠 Full Image Inference
+
+TexTeller also supports **formula detection and recognition** on full images, allowing for the detection of formulas throughout the image, followed by batch recognition of the formulas.
+
+### Download Weights
+
+English documentation formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx?download=true)]: Trained on 8272 images from the [IBEM dataset](https://zenodo.org/records/4757865).
+
+Chinese documentation formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/blob/main/rtdetr_r50vd_6x_coco_trained_on_cn_textbook.onnx)]: Trained on 2560 Chinese textbook images (100+ layouts).
+
+### Formula Detection
+
+Run the following command in the `TexTeller/src` directory:
+
+```bash
+python infer_det.py
+```
+
+This detects all the formulas in the image and saves the results in `TexTeller/src/subimages`.
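+
+By default, `infer_det.py` expects the downloaded ONNX weights at `./models/det_model/model/rtdetr_r50vd_6x_coco.onnx` and reads its preprocessing settings from `./models/det_model/model/infer_cfg.yml`; the image(s) to analyze and the output location can be set with the `--image_file`/`--image_dir` and `--imgsave_dir` options.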
+
+
+
+
+
+### Batch Formula Recognition
+
+After **formula detection**, run the following command in the `TexTeller/src` directory:
+
+```shell
+python rec_infer_from_crop_imgs.py
+```
+
+This will use the results of the previous formula detection to perform batch recognition on all cropped formulas, saving the recognition results as txt files in `TexTeller/src/results`.
+
+## 📡 API Usage
+
+We use [ray serve](https://github.com/ray-project/ray) to provide an API interface for TexTeller, allowing you to integrate TexTeller into your own projects. To start the server, you first need to enter the `TexTeller/src` directory and then run the following command:
+
+```bash
+python server.py # default settings
+```
+
+| Parameter | Description |
+| --- | --- |
+| `-ckpt` | The path to the weights file, *default is TexTeller's pretrained weights*.|
+| `-tknz` | The path to the tokenizer, *default is TexTeller's tokenizer*.|
+| `-port` | The server's service port, *default is 8000*. |
+| `--inference-mode` | Whether to use GPU(cuda or mps) for inference, *default is CPU*. |
+| `--num_beams` | The number of beams for beam search, *default is 1*. |
+| `--num_replicas` | The number of service replicas to run on the server, *default is 1 replica*. You can use more replicas to achieve greater throughput.|
+| `--ncpu_per_replica` | The number of CPU cores used per service replica, *default is 1*. |
+| `--ngpu_per_replica` | The number of GPUs used per service replica, *default is 1*. You can set this to a value between 0 and 1 so that multiple replicas share one GPU, improving GPU utilization. (Note: if `--num_replicas` is 2 and `--ngpu_per_replica` is 0.7, then 2 GPUs must be available.) |
+
+> [!NOTE]
+> A client demo can be found at `TexTeller/client/demo.py`; refer to it for an example of sending requests to the server.
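+
+If you just want to sanity-check the endpoint, the sketch below sends one image to the default server address. It is a hypothetical example: the route `/predict` and port 8000 are the defaults, but the exact payload field is an assumption, so see `client/demo.py` for the canonical request format.
+
+```python
+import requests
+
+server_url = "http://127.0.0.1:8000/predict"  # default port is 8000
+
+# Assumed multipart upload; check client/demo.py for the exact field name/format
+with open("formula.png", "rb") as f:
+    response = requests.post(server_url, files={"img": f})
+print(response.text)  # predicted LaTeX formula
+```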
+
+## 🏋️♂️ Training
+
+### Dataset
+
+We provide an example dataset in the `TexTeller/src/models/ocr_model/train/dataset` directory. You can place your own images in the `images` directory and annotate each image with its corresponding formula in `formulas.jsonl`.
+
+After preparing your dataset, you need to **change the `DIR_URL` variable to your own dataset's path** in `.../dataset/loader.py`
+
+### Retraining the Tokenizer
+
+If you are using a different dataset, you might need to retrain the tokenizer to obtain a different dictionary. After configuring your dataset, you can train your own tokenizer with the following command:
+
+1. In `TexTeller/src/models/tokenizer/train.py`, change `new_tokenizer.save_pretrained('./your_dir_name')` to your custom output directory
+ > If you want to use a different dictionary size (default is 10k tokens), you need to change the `VOCAB_SIZE` variable in `TexTeller/src/models/globals.py`
+
+2. **In the `TexTeller/src` directory**, run the following command:
+
+ ```bash
+ python -m models.tokenizer.train
+ ```
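+
+Once training finishes, you can sanity-check the saved tokenizer before using it for model training. A minimal sketch, assuming the output directory from step 1 was `./your_dir_name`:
+
+```python
+from transformers import RobertaTokenizerFast
+
+# Load the retrained tokenizer and confirm its vocabulary size matches VOCAB_SIZE
+tokenizer = RobertaTokenizerFast.from_pretrained('./your_dir_name')
+print(len(tokenizer))
+```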
+
+### Training the Model
+
+To train the model, you need to run the following command in the `TexTeller/src` directory:
+
+```bash
+python -m models.ocr_model.train.train
+```
+
+You can set your own tokenizer and checkpoint paths in `TexTeller/src/models/ocr_model/train/train.py` (refer to `train.py` for more information). If you are using the same architecture and dictionary as TexTeller, you can also fine-tune TexTeller's default weights with your own dataset.
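+
+For example, the loading step can start from TexTeller's released weights instead of a randomly initialized model. A sketch only (the project's actual training entry point and arguments live in `train.py` and `train_args.py`):
+
+```python
+from models.ocr_model.model.TexTeller import TexTeller
+
+# Fine-tuning setup: start from the pretrained weights and the default tokenizer,
+# or pass local directories to use your own checkpoint/tokenizer instead.
+model = TexTeller.from_pretrained()
+tokenizer = TexTeller.get_tokenizer()
+```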
+
+In `TexTeller/src/globals.py` and `TexTeller/src/models/ocr_model/train/train_args.py`, you can change the model's architecture and training hyperparameters.
+
+> [!NOTE]
+> Our training scripts use the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, so you can refer to their [documentation](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments) for more details and configurations on training parameters.
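+
+As an illustration only (the project's actual defaults are defined in `train_args.py`), the hyperparameters you would typically adjust map onto standard `TrainingArguments` fields:
+
+```python
+from transformers import TrainingArguments
+
+# Hypothetical values for illustration; see train_args.py for the real defaults
+training_args = TrainingArguments(
+    output_dir='./train_result',      # where checkpoints and logs are written
+    per_device_train_batch_size=32,
+    learning_rate=5e-5,
+    num_train_epochs=3,
+    report_to='tensorboard',          # tensorboardX is already in requirements.txt
+)
+```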
+
+## 🚧 Limitations
+
+* Does not support scanned images and PDF document recognition
+
+* Does not support handwritten formulas
+
+## 📅 Plans
+
+- [x] ~~Train the model with a larger dataset (7.5M samples, coming soon)~~
+
+- [ ] Recognition of scanned images
+
+- [ ] PDF document recognition + Support for English and Chinese scenarios
+
+- [ ] Inference acceleration
+
+- [ ] ...
+
+## ⭐️ Stargazers over time
+
+[](https://starchart.cc/OleehyO/TexTeller)
+
+## 💖 Acknowledgments
+
+Thanks to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) which has brought me a lot of inspiration, and [im2latex-100K](https://zenodo.org/records/56198#.V2px0jXT6eA) which enriches our dataset.
+
+## 👥 Contributors
+
+
+
+
+
+
diff --git a/assets/README_zh.md b/assets/README_zh.md
new file mode 100644
index 0000000..b946d49
--- /dev/null
+++ b/assets/README_zh.md
@@ -0,0 +1,233 @@
+📄 [English](../README.md) | [中文](./README_zh.md)
+
+
+
+https://github.com/OleehyO/TexTeller/assets/56267907/fb17af43-f2a5-47ce-ad1d-101db5fd7fbb
+
+TexTeller是一个基于ViT的端到端公式识别模型,可以把图片转换为对应的latex公式
+
+TexTeller用了~~550K~~7.5M的图片-公式对进行训练(数据集可以在[这里](https://huggingface.co/datasets/OleehyO/latex-formulas)获取),相比于[LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)(使用了一个100K的数据集),TexTeller具有**更强的泛化能力**以及**更高的准确率**,可以覆盖大部分的使用场景(**扫描图片,手写公式除外**)。
+
+> ~~我们马上就会发布一个使用7.5M数据集进行训练的TexTeller checkpoint~~
+
+## 🔄 变更信息
+
+* 📮[2024-03-25] TexTeller2.0发布!TexTeller2.0的训练数据增大到了7.5M(相较于TexTeller1.0**增加了~15倍**并且数据质量也有所改善)。训练后的TexTeller2.0在测试集中展现出了**更加优越的性能**,尤其在生僻符号、复杂多行、矩阵的识别场景中。
+
+ > 在[这里](./test.pdf)有更多的测试图片以及各家识别模型的横向对比。
+ >
+* 📮[2024-04-12] 训练了**公式检测模型**,从而增加了对整个文档进行公式检测+公式识别(整图推理)的功能!
+
+## 🔑 前置条件
+
+python=3.10
+
+[pytorch](https://pytorch.org/get-started/locally/)
+
+> [!WARNING]
+> 只有CUDA版本>= 12.0被完全测试过,所以最好使用>= 12.0的CUDA版本
+
+## 🚀 开搞
+
+1. 克隆本仓库:
+
+ ```bash
+ git clone https://github.com/OleehyO/TexTeller
+ ```
+
+2. [安装pytorch](https://pytorch.org/get-started/locally/#start-locally)
+3. 安装本项目的依赖包:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+4. 进入 `TexTeller/src`目录,在终端运行以下命令进行推理:
+
+ ```bash
+ python inference.py -img "/path/to/image.{jpg,png}"
+ # use --inference-mode option to enable GPU(cuda or mps) inference
+ #+e.g. python inference.py -img "./img.jpg" --inference-mode cuda
+ ```
+
+> [!NOTE]
+> 第一次运行时会在hugging face上下载所需要的checkpoints
+
+## ❓ 常见问题:无法连接到Hugging Face
+
+默认情况下,会在Hugging Face中下载模型权重,**如果你的远端服务器无法连接到Hugging Face**,你可以通过以下命令进行加载:
+
+1. 安装huggingface hub包
+
+ ```bash
+ pip install -U "huggingface_hub[cli]"
+ ```
+
+2. 在能连接Hugging Face的机器上下载模型权重:
+
+ ```bash
+ huggingface-cli download OleehyO/TexTeller --include "*.json" "*.bin" "*.txt" --repo-type model --local-dir "your/dir/path"
+ ```
+
+3. 把包含权重的目录上传远端服务器,然后把 `TexTeller/src/models/ocr_model/model/TexTeller.py`中的 `REPO_NAME = 'OleehyO/TexTeller'`修改为 `REPO_NAME = 'your/dir/path'`
+
+如果你还想在训练模型时开启evaluate,你需要提前下载metric脚本并上传远端服务器:
+
+1. 在能连接Hugging Face的机器上下载metric脚本
+
+ ```bash
+ huggingface-cli download evaluate-metric/google_bleu --repo-type space --local-dir "your/dir/path"
+ ```
+
+2. 把这个目录上传远端服务器,并在 `TexTeller/src/models/ocr_model/utils/metrics.py`中把 `evaluate.load('google_bleu')`改为 `evaluate.load('your/dir/path/google_bleu.py')`
+
+## 🌐 网页演示
+
+进入 `TexTeller/src` 目录,运行以下命令
+
+```bash
+./start_web.sh
+```
+
+在浏览器里输入 `http://localhost:8501`就可以看到web demo
+
+> [!NOTE]
+> 对于Windows用户, 请运行 `start_web.bat`文件.
+
+## 🧠 整图推理
+
+TexTeller还支持对整张图片进行**公式检测+公式识别**,从而对整图公式进行检测,然后进行批公式识别。
+
+### 下载权重
+
+英文文档公式检测 [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx?download=true)]:在8272张[IBEM数据集](https://zenodo.org/records/4757865)上训练得到
+
+中文文档公式检测 [[link](https://huggingface.co/TonyLee1256/texteller_det/blob/main/rtdetr_r50vd_6x_coco_trained_on_cn_textbook.onnx)]:在2560张中文教材数据(100+版式)上训练得到
+
+### 公式检测
+
+在 `TexTeller/src` 目录下运行以下命令
+
+```bash
+python infer_det.py
+```
+
+对整张图中的所有公式进行检测,结果保存在 `TexTeller/src/subimages`
+
+
+
+
+
+### 公式批识别
+
+在进行**公式检测**后,在 `TexTeller/src` 目录下运行以下命令
+
+```shell
+python rec_infer_from_crop_imgs.py
+```
+
+会基于上一步公式检测的结果,对裁剪出的所有公式进行批量识别,将识别结果在 `TexTeller/src/results`中保存为txt文件。
+
+## 📡 API调用
+
+我们使用[ray serve](https://github.com/ray-project/ray)来对外提供一个TexTeller的API接口,通过使用这个接口,你可以把TexTeller整合到自己的项目里。要想启动server,你需要先进入 `TexTeller/src`目录然后运行以下命令:
+
+```bash
+python server.py
+```
+
+| 参数 | 描述 |
+| - | - |
+| `-ckpt` | 权重文件的路径,*默认为TexTeller的预训练权重*。 |
+| `-tknz` | 分词器的路径,*默认为TexTeller的分词器*。 |
+| `-port` | 服务器的服务端口,*默认是8000*。 |
+| `--inference-mode`| 是否使用GPU(cuda或mps)推理,*默认为CPU*。 |
+| `--num_beams` | beam search的beam数量,*默认是1*。 |
+| `--num_replicas`| 在服务器上运行的服务副本数量,*默认1个副本*。你可以使用更多的副本来获取更大的吞吐量。 |
+| `--ncpu_per_replica` | 每个服务副本所用的CPU核心数,*默认为1*。 |
+| `--ngpu_per_replica` | 每个服务副本所用的GPU数量,*默认为1*。你可以把这个值设置成 0~1之间的数,这样会在一个GPU上运行多个服务副本来共享GPU,从而提高GPU的利用率。(注意,如果 --num_replicas 2, --ngpu_per_replica 0.7, 那么就必须要有2个GPU可用) |
+
+> [!NOTE]
+> 一个客户端demo可以在 `TexTeller/client/demo.py`找到,你可以参考 `demo.py`来给server发送请求
+
+## 🏋️♂️ 训练
+
+### 数据集
+
+我们在 `TexTeller/src/models/ocr_model/train/dataset`目录中提供了一个数据集的例子,你可以把自己的图片放在 `images`目录然后在 `formulas.jsonl`中为每张图片标注对应的公式。
+
+准备好数据集后,你需要在 `.../dataset/loader.py`中把 **`DIR_URL`变量改成你自己数据集的路径**
+
+### 重新训练分词器
+
+如果你使用了不一样的数据集,你可能需要重新训练tokenizer来得到一个不一样的字典。配置好数据集后,可以通过以下命令来训练自己的tokenizer:
+
+1. 在 `TexTeller/src/models/tokenizer/train.py`中,修改 `new_tokenizer.save_pretrained('./your_dir_name')`为你自定义的输出目录
+
+ > 注意:如果要用一个不一样大小的字典(默认1W个token),你需要在 `TexTeller/src/models/globals.py`中修改 `VOCAB_SIZE`变量
+ >
+2. **在 `TexTeller/src` 目录下**运行以下命令:
+
+ ```bash
+ python -m models.tokenizer.train
+ ```
+
+### 训练模型
+
+要想训练模型, 你需要在 `TexTeller/src`目录下运行以下命令:
+
+```bash
+python -m models.ocr_model.train.train
+```
+
+你可以在 `TexTeller/src/models/ocr_model/train/train.py`中设置自己的tokenizer和checkpoint路径(请参考 `train.py`)。如果你使用了与TexTeller一样的架构和相同的字典,你还可以用自己的数据集来微调TexTeller的默认权重。
+
+在 `TexTeller/src/globals.py`和 `TexTeller/src/models/ocr_model/train/train_args.py`中,你可以改变模型的架构以及训练的超参数。
+
+> [!NOTE]
+> 我们的训练脚本使用了[Hugging Face Transformers](https://github.com/huggingface/transformers)库, 所以你可以参考他们提供的[文档](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments)来获取更多训练参数的细节以及配置。
+
+## 🚧 不足
+
+* 不支持扫描图片以及PDF文档识别
+* 不支持手写体公式
+
+## 📅 计划
+
+- [X] ~~使用更大的数据集来训练模型(7.5M样本,即将发布)~~
+
+- [ ] 扫描图片识别
+
+- [ ] PDF文档识别 + 中英文场景支持
+
+- [ ] 推理加速
+
+- [ ] ...
+
+## ⭐️ 观星曲线
+
+[](https://starchart.cc/OleehyO/TexTeller)
+
+## 💖 感谢
+
+Thanks to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) which has brought me a lot of inspiration, and [im2latex-100K](https://zenodo.org/records/56198#.V2px0jXT6eA) which enriches our dataset.
+
+## 👥 贡献者
+
+
+
+
+
+
diff --git a/assets/css.css b/assets/css.css
deleted file mode 100644
index 00122f6..0000000
--- a/assets/css.css
+++ /dev/null
@@ -1,157 +0,0 @@
-html {
- font-family: Inter;
- font-size: 16px;
- font-weight: 400;
- line-height: 1.5;
- -webkit-text-size-adjust: 100%;
- background: #fff;
- color: #323232;
- -webkit-font-smoothing: antialiased;
- -moz-osx-font-smoothing: grayscale;
- text-rendering: optimizeLegibility;
-}
-
-:root {
- --space: 1;
- --vspace: calc(var(--space) * 1rem);
- --vspace-0: calc(3 * var(--space) * 1rem);
- --vspace-1: calc(2 * var(--space) * 1rem);
- --vspace-2: calc(1.5 * var(--space) * 1rem);
- --vspace-3: calc(0.5 * var(--space) * 1rem);
-}
-
-.app {
- max-width: 748px !important;
-}
-
-.prose p {
- margin: var(--vspace) 0;
- line-height: var(--vspace * 2);
- font-size: 1rem;
-}
-
-code {
- font-family: "inconsolata", sans-serif;
- font-size: 16px;
-}
-
-h1,
-h1 code {
- font-weight: 400;
- line-height: calc(2.5 / var(--space) * var(--vspace));
-}
-
-h1 code {
- background: none;
- border: none;
- letter-spacing: 0.05em;
- padding-bottom: 5px;
- position: relative;
- padding: 0;
-}
-
-h2 {
- margin: var(--vspace-1) 0 var(--vspace-2) 0;
- line-height: 1em;
-}
-
-h3,
-h3 code {
- margin: var(--vspace-1) 0 var(--vspace-2) 0;
- line-height: 1em;
-}
-
-h4,
-h5,
-h6 {
- margin: var(--vspace-3) 0 var(--vspace-3) 0;
- line-height: var(--vspace);
-}
-
-.bigtitle,
-h1,
-h1 code {
- font-size: calc(8px * 4.5);
- word-break: break-word;
-}
-
-.title,
-h2,
-h2 code {
- font-size: calc(8px * 3.375);
- font-weight: lighter;
- word-break: break-word;
- border: none;
- background: none;
-}
-
-.subheading1,
-h3,
-h3 code {
- font-size: calc(8px * 1.8);
- font-weight: 600;
- border: none;
- background: none;
- letter-spacing: 0.1em;
- text-transform: uppercase;
-}
-
-h2 code {
- padding: 0;
- position: relative;
- letter-spacing: 0.05em;
-}
-
-blockquote {
- font-size: calc(8px * 1.1667);
- font-style: italic;
- line-height: calc(1.1667 * var(--vspace));
- margin: var(--vspace-2) var(--vspace-2);
-}
-
-.subheading2,
-h4 {
- font-size: calc(8px * 1.4292);
- text-transform: uppercase;
- font-weight: 600;
-}
-
-.subheading3,
-h5 {
- font-size: calc(8px * 1.2917);
- line-height: calc(1.2917 * var(--vspace));
-
- font-weight: lighter;
- text-transform: uppercase;
- letter-spacing: 0.15em;
-}
-
-h6 {
- font-size: calc(8px * 1.1667);
- font-size: 1.1667em;
- font-weight: normal;
- font-style: italic;
- font-family: "le-monde-livre-classic-byol", serif !important;
- letter-spacing: 0px !important;
-}
-
-#start .md > *:first-child {
- margin-top: 0;
-}
-
-h2 + h3 {
- margin-top: 0;
-}
-
-.md hr {
- border: none;
- border-top: 1px solid var(--block-border-color);
- margin: var(--vspace-2) 0 var(--vspace-2) 0;
-}
-.prose ul {
- margin: var(--vspace-2) 0 var(--vspace-1) 0;
-}
-
-.gap {
- gap: 0;
-}
\ No newline at end of file
diff --git a/assets/det_rec.png b/assets/det_rec.png
new file mode 100644
index 0000000..dbd1ffc
Binary files /dev/null and b/assets/det_rec.png differ
diff --git a/assets/fire.svg b/assets/fire.svg
new file mode 100644
index 0000000..8f9f7eb
--- /dev/null
+++ b/assets/fire.svg
diff --git a/assets/test.pdf b/assets/test.pdf
new file mode 100644
index 0000000..f587024
Binary files /dev/null and b/assets/test.pdf differ
diff --git a/assets/web_demo.gif b/assets/web_demo.gif
new file mode 100644
index 0000000..0403d86
Binary files /dev/null and b/assets/web_demo.gif differ
diff --git a/requirements.txt b/requirements.txt
index e677590..c5f254c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,6 @@
transformers
datasets
evaluate
-streamlit
-
-gradio
-
opencv-python
ray[serve]
accelerate
@@ -12,4 +8,8 @@ tensorboardX
nltk
python-multipart
-augraphy
\ No newline at end of file
+augraphy
+onnxruntime
+
+streamlit==1.30
+streamlit-paste-button
diff --git a/src/infer_det.py b/src/infer_det.py
new file mode 100644
index 0000000..cc5da44
--- /dev/null
+++ b/src/infer_det.py
@@ -0,0 +1,197 @@
+import os
+import yaml
+import argparse
+import numpy as np
+import glob
+from onnxruntime import InferenceSession
+from tqdm import tqdm
+
+from models.det_model.preprocess import Compose
+import cv2
+
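+# This script runs the ONNX formula-detection model over a single image (--image_file)
+# or a directory of images (--image_dir), crops every detected region into
+# <imgsave_dir>/subimages, and saves a copy of each input image with the detected
+# boxes drawn on it to --imgsave_dir.
+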
+# Note: file names should follow a standard convention; prefer underscores.
+
+# Global dictionary
+SUPPORT_MODELS = {
+ 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
+ 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
+ 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
+ 'DETR'
+}
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
+ default="./models/det_model/model/infer_cfg.yml"
+ )
+parser.add_argument('--onnx_file', type=str, help="onnx model file path",
+ default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx"
+ )
+parser.add_argument("--image_dir", type=str)
+parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg')
+parser.add_argument("--imgsave_dir", type=str,
+ default="."
+ )
+
+def get_test_images(infer_dir, infer_img):
+ """
+ Get image path list in TEST mode
+ """
+ assert infer_img is not None or infer_dir is not None, \
+ "--image_file or --image_dir should be set"
+ assert infer_img is None or os.path.isfile(infer_img), \
+ "{} is not a file".format(infer_img)
+ assert infer_dir is None or os.path.isdir(infer_dir), \
+ "{} is not a directory".format(infer_dir)
+
+ # infer_img has a higher priority
+ if infer_img and os.path.isfile(infer_img):
+ return [infer_img]
+
+ images = set()
+ infer_dir = os.path.abspath(infer_dir)
+ assert os.path.isdir(infer_dir), \
+ "infer_dir {} is not a directory".format(infer_dir)
+ exts = ['jpg', 'jpeg', 'png', 'bmp']
+ exts += [ext.upper() for ext in exts]
+ for ext in exts:
+ images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+ images = list(images)
+
+ assert len(images) > 0, "no image found in {}".format(infer_dir)
+ print("Found {} inference images in total.".format(len(images)))
+
+ return images
+
+
+class PredictConfig(object):
+ """set config of preprocess, postprocess and visualize
+ Args:
+ infer_config (str): path of infer_cfg.yml
+ """
+
+ def __init__(self, infer_config):
+ # parsing Yaml config for Preprocess
+ with open(infer_config) as f:
+ yml_conf = yaml.safe_load(f)
+ self.check_model(yml_conf)
+ self.arch = yml_conf['arch']
+ self.preprocess_infos = yml_conf['Preprocess']
+ self.min_subgraph_size = yml_conf['min_subgraph_size']
+ self.label_list = yml_conf['label_list']
+ self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+ self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
+ self.mask = yml_conf.get("mask", False)
+ self.tracker = yml_conf.get("tracker", None)
+ self.nms = yml_conf.get("NMS", None)
+ self.fpn_stride = yml_conf.get("fpn_stride", None)
+
+        # Predefined color pool
+ color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+        # Dynamically build a label -> color mapping from label_list
+ self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
+
+ if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+ print(
+ 'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+ )
+ self.print_config()
+
+ def check_model(self, yml_conf):
+ """
+ Raises:
+ ValueError: loaded model not in supported model type
+ """
+ for support_model in SUPPORT_MODELS:
+ if support_model in yml_conf['arch']:
+ return True
+ raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+ 'arch'], SUPPORT_MODELS))
+
+ def print_config(self):
+ print('----------- Model Configuration -----------')
+ print('%s: %s' % ('Model Arch', self.arch))
+ print('%s: ' % ('Transform Order'))
+ for op_info in self.preprocess_infos:
+ print('--%s: %s' % ('transform op', op_info['type']))
+ print('--------------------------------------------')
+
+
+def draw_bbox(image, outputs, infer_config):
+ for output in outputs:
+ cls_id, score, xmin, ymin, xmax, ymax = output
+ if score > infer_config.draw_threshold:
+            # Get the class name
+ label = infer_config.label_list[int(cls_id)]
+            # Look up the color assigned to this class
+ color = infer_config.colors[label]
+ cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
+ cv2.putText(image, "{}: {:.2f}".format(label, score),
+ (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+ return image
+
+
+def predict_image(infer_config, predictor, img_list):
+ # load preprocess transforms
+ transforms = Compose(infer_config.preprocess_infos)
+ errImgList = []
+
+ # Check and create subimg_save_dir if not exist
+ subimg_save_dir = os.path.join(FLAGS.imgsave_dir, 'subimages')
+ os.makedirs(subimg_save_dir, exist_ok=True)
+
+ # predict image
+ for img_path in tqdm(img_list):
+ img = cv2.imread(img_path)
+ if img is None:
+ print(f"Warning: Could not read image {img_path}. Skipping...")
+ errImgList.append(img_path)
+ continue
+
+ inputs = transforms(img_path)
+ inputs_name = [var.name for var in predictor.get_inputs()]
+ inputs = {k: inputs[k][None, ] for k in inputs_name}
+
+ outputs = predictor.run(output_names=None, input_feed=inputs)
+
+ print("ONNXRuntime predict: ")
+ if infer_config.arch in ["HRNet"]:
+ print(np.array(outputs[0]))
+ else:
+ bboxes = np.array(outputs[0])
+ for bbox in bboxes:
+ if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
+ print(f"{int(bbox[0])} {bbox[1]} "
+ f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
+
+ # Save the subimages (crop from the original image)
+ subimg_counter = 1
+ for output in np.array(outputs[0]):
+ cls_id, score, xmin, ymin, xmax, ymax = output
+ if score > infer_config.draw_threshold:
+ label = infer_config.label_list[int(cls_id)]
+ subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)]
+ subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
+ subimg_path = os.path.join(subimg_save_dir, subimg_filename)
+ cv2.imwrite(subimg_path, subimg)
+ subimg_counter += 1
+
+ # Draw bounding boxes and save the image with bounding boxes
+ img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
+ output_dir = FLAGS.imgsave_dir
+ os.makedirs(output_dir, exist_ok=True)
+ output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
+ cv2.imwrite(output_file, img_with_bbox)
+
+ print("ErrorImgs:")
+ print(errImgList)
+
+if __name__ == '__main__':
+ FLAGS = parser.parse_args()
+ # load image list
+ img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
+ # load predictor
+ predictor = InferenceSession(FLAGS.onnx_file)
+ # load infer config
+ infer_config = PredictConfig(FLAGS.infer_cfg)
+
+ predict_image(infer_config, predictor, img_list)
diff --git a/src/inference.py b/src/inference.py
index c0e263c..bc7c006 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -19,10 +19,16 @@ if __name__ == '__main__':
help='path to the input image'
)
parser.add_argument(
- '-cuda',
- default=False,
- action='store_true',
- help='use cuda or not'
+ '--inference-mode',
+ type=str,
+ default='cpu',
+ help='Inference mode, select one of cpu, cuda, or mps'
+ )
+ parser.add_argument(
+ '--num-beam',
+ type=int,
+ default=1,
+ help='number of beam search for decoding'
)
# ================= new feature ==================
parser.add_argument(
@@ -37,6 +43,7 @@ if __name__ == '__main__':
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
@@ -44,7 +51,7 @@ if __name__ == '__main__':
img = cv.imread(args.img)
print('Inference...')
if not args.mix:
- res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+ res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
res = to_katex(res[0])
print(res)
else:
diff --git a/src/models/det_model/model/infer_cfg.yml b/src/models/det_model/model/infer_cfg.yml
new file mode 100644
index 0000000..0c156fc
--- /dev/null
+++ b/src/models/det_model/model/infer_cfg.yml
@@ -0,0 +1,27 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: DETR
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+ keep_ratio: false
+ target_size:
+ - 640
+ - 640
+ type: Resize
+- mean:
+ - 0.0
+ - 0.0
+ - 0.0
+ norm_type: none
+ std:
+ - 1.0
+ - 1.0
+ - 1.0
+ type: NormalizeImage
+- type: Permute
+label_list:
+- isolated
+- embedding
diff --git a/src/models/det_model/preprocess.py b/src/models/det_model/preprocess.py
new file mode 100644
index 0000000..3554b7f
--- /dev/null
+++ b/src/models/det_model/preprocess.py
@@ -0,0 +1,494 @@
+import numpy as np
+import cv2
+import copy
+
+
+def decode_image(img_path):
+ with open(img_path, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ img_info = {
+ "im_shape": np.array(
+ im.shape[:2], dtype=np.float32),
+ "scale_factor": np.array(
+ [1., 1.], dtype=np.float32)
+ }
+ return im, img_info
+
+
+class Resize(object):
+ """resize image by target_size and max_size
+ Args:
+ target_size (int): the target size of image
+ keep_ratio (bool): whether keep_ratio or not, default true
+ interp (int): method of resize
+ """
+
+ def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+ if isinstance(target_size, int):
+ target_size = [target_size, target_size]
+ self.target_size = target_size
+ self.keep_ratio = keep_ratio
+ self.interp = interp
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ assert len(self.target_size) == 2
+ assert self.target_size[0] > 0 and self.target_size[1] > 0
+ im_channel = im.shape[2]
+ im_scale_y, im_scale_x = self.generate_scale(im)
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+ im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+ im_info['scale_factor'] = np.array(
+ [im_scale_y, im_scale_x]).astype('float32')
+ return im, im_info
+
+ def generate_scale(self, im):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ origin_shape = im.shape[:2]
+ im_c = im.shape[2]
+ if self.keep_ratio:
+ im_size_min = np.min(origin_shape)
+ im_size_max = np.max(origin_shape)
+ target_size_min = np.min(self.target_size)
+ target_size_max = np.max(self.target_size)
+ im_scale = float(target_size_min) / float(im_size_min)
+ if np.round(im_scale * im_size_max) > target_size_max:
+ im_scale = float(target_size_max) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+ else:
+ resize_h, resize_w = self.target_size
+ im_scale_y = resize_h / float(origin_shape[0])
+ im_scale_x = resize_w / float(origin_shape[1])
+ return im_scale_y, im_scale_x
+
+
+class NormalizeImage(object):
+ """normalize image
+ Args:
+ mean (list): im - mean
+ std (list): im / std
+ is_scale (bool): whether need im / 255
+ norm_type (str): type in ['mean_std', 'none']
+ """
+
+ def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
+ self.mean = mean
+ self.std = std
+ self.is_scale = is_scale
+ self.norm_type = norm_type
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.astype(np.float32, copy=False)
+ if self.is_scale:
+ scale = 1.0 / 255.0
+ im *= scale
+
+ if self.norm_type == 'mean_std':
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ im -= mean
+ im /= std
+ return im, im_info
+
+
+class Permute(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, ):
+ super(Permute, self).__init__()
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.transpose((2, 0, 1)).copy()
+ return im, im_info
+
+
+class PadStride(object):
+ """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
+ Args:
+ stride (bool): model with FPN need image shape % stride == 0
+ """
+
+ def __init__(self, stride=0):
+ self.coarsest_stride = stride
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ coarsest_stride = self.coarsest_stride
+ if coarsest_stride <= 0:
+ return im, im_info
+ im_c, im_h, im_w = im.shape
+ pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+ pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+ padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = im
+ return padding_im, im_info
+
+
+class LetterBoxResize(object):
+ def __init__(self, target_size):
+ """
+ Resize image to target size, convert normalized xywh to pixel xyxy
+ format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
+ Args:
+ target_size (int|list): image target size.
+ """
+ super(LetterBoxResize, self).__init__()
+ if isinstance(target_size, int):
+ target_size = [target_size, target_size]
+ self.target_size = target_size
+
+ def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
+ # letterbox: resize a rectangular image to a padded rectangular
+ shape = img.shape[:2] # [height, width]
+ ratio_h = float(height) / shape[0]
+ ratio_w = float(width) / shape[1]
+ ratio = min(ratio_h, ratio_w)
+ new_shape = (round(shape[1] * ratio),
+ round(shape[0] * ratio)) # [width, height]
+ padw = (width - new_shape[0]) / 2
+ padh = (height - new_shape[1]) / 2
+ top, bottom = round(padh - 0.1), round(padh + 0.1)
+ left, right = round(padw - 0.1), round(padw + 0.1)
+
+ img = cv2.resize(
+ img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
+ img = cv2.copyMakeBorder(
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT,
+ value=color) # padded rectangular
+ return img, ratio, padw, padh
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ assert len(self.target_size) == 2
+ assert self.target_size[0] > 0 and self.target_size[1] > 0
+ height, width = self.target_size
+ h, w = im.shape[:2]
+ im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
+
+ new_shape = [round(h * ratio), round(w * ratio)]
+ im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
+ im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
+ return im, im_info
+
+
+class Pad(object):
+ def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
+ """
+ Pad image to a specified size.
+ Args:
+ size (list[int]): image target size
+ fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
+ """
+ super(Pad, self).__init__()
+ if isinstance(size, int):
+ size = [size, size]
+ self.size = size
+ self.fill_value = fill_value
+
+ def __call__(self, im, im_info):
+ im_h, im_w = im.shape[:2]
+ h, w = self.size
+ if h == im_h and w == im_w:
+ im = im.astype(np.float32)
+ return im, im_info
+
+ canvas = np.ones((h, w, 3), dtype=np.float32)
+ canvas *= np.array(self.fill_value, dtype=np.float32)
+ canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
+ im = canvas
+ return im, im_info
+
+
+def rotate_point(pt, angle_rad):
+ """Rotate a point by an angle.
+
+ Args:
+ pt (list[float]): 2 dimensional point to be rotated
+ angle_rad (float): rotation angle by radian
+
+ Returns:
+ list[float]: Rotated point.
+ """
+ assert len(pt) == 2
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ new_x = pt[0] * cs - pt[1] * sn
+ new_y = pt[0] * sn + pt[1] * cs
+ rotated_pt = [new_x, new_y]
+
+ return rotated_pt
+
+
+def _get_3rd_point(a, b):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): point(x,y)
+ b (np.ndarray): point(x,y)
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ assert len(a) == 2
+ assert len(b) == 2
+ direction = a - b
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+ return third_pt
+
+
+def get_affine_transform(center,
+ input_size,
+ rot,
+ output_size,
+ shift=(0., 0.),
+ inv=False):
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: The transform matrix.
+ """
+ assert len(center) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+ if not isinstance(input_size, (np.ndarray, list)):
+ input_size = np.array([input_size, input_size], dtype=np.float32)
+ scale_tmp = input_size
+
+ shift = np.array(shift)
+ src_w = scale_tmp[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = rotate_point([0., src_w * -0.5], rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+class WarpAffine(object):
+ """Warp affine the image
+ """
+
+ def __init__(self,
+ keep_res=False,
+ pad=31,
+ input_h=512,
+ input_w=512,
+ scale=0.4,
+ shift=0.1):
+ self.keep_res = keep_res
+ self.pad = pad
+ self.input_h = input_h
+ self.input_w = input_w
+ self.scale = scale
+ self.shift = shift
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+ h, w = img.shape[:2]
+
+ if self.keep_res:
+ input_h = (h | self.pad) + 1
+ input_w = (w | self.pad) + 1
+ s = np.array([input_w, input_h], dtype=np.float32)
+ c = np.array([w // 2, h // 2], dtype=np.float32)
+
+ else:
+ s = max(h, w) * 1.0
+ input_h, input_w = self.input_h, self.input_w
+ c = np.array([w / 2., h / 2.], dtype=np.float32)
+
+ trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
+ img = cv2.resize(img, (w, h))
+ inp = cv2.warpAffine(
+ img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+ return inp, im_info
+
+
+# keypoint preprocess
+def get_warp_matrix(theta, size_input, size_dst, size_target):
+ """This code is based on
+ https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
+
+ Calculate the transformation matrix under the constraint of unbiased.
+ Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
+ Data Processing for Human Pose Estimation (CVPR 2020).
+
+ Args:
+ theta (float): Rotation angle in degrees.
+ size_input (np.ndarray): Size of input image [w, h].
+ size_dst (np.ndarray): Size of output image [w, h].
+ size_target (np.ndarray): Size of ROI in input plane [w, h].
+
+ Returns:
+ matrix (np.ndarray): A matrix for transformation.
+ """
+ theta = np.deg2rad(theta)
+ matrix = np.zeros((2, 3), dtype=np.float32)
+ scale_x = size_dst[0] / size_target[0]
+ scale_y = size_dst[1] / size_target[1]
+ matrix[0, 0] = np.cos(theta) * scale_x
+ matrix[0, 1] = -np.sin(theta) * scale_x
+ matrix[0, 2] = scale_x * (
+ -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
+ np.sin(theta) + 0.5 * size_target[0])
+ matrix[1, 0] = np.sin(theta) * scale_y
+ matrix[1, 1] = np.cos(theta) * scale_y
+ matrix[1, 2] = scale_y * (
+ -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
+ np.cos(theta) + 0.5 * size_target[1])
+ return matrix
+
+
+class TopDownEvalAffine(object):
+ """apply affine transform to image and coords
+
+ Args:
+ trainsize (list): [w, h], the standard size used to train
+ use_udp (bool): whether to use Unbiased Data Processing.
+ records(dict): the dict contained the image and coords
+
+ Returns:
+ records (dict): contain the image and coords after tranformed
+
+ """
+
+ def __init__(self, trainsize, use_udp=False):
+ self.trainsize = trainsize
+ self.use_udp = use_udp
+
+ def __call__(self, image, im_info):
+ rot = 0
+ imshape = im_info['im_shape'][::-1]
+ center = im_info['center'] if 'center' in im_info else imshape / 2.
+ scale = im_info['scale'] if 'scale' in im_info else imshape
+ if self.use_udp:
+ trans = get_warp_matrix(
+ rot, center * 2.0,
+ [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
+ image = cv2.warpAffine(
+ image,
+ trans, (int(self.trainsize[0]), int(self.trainsize[1])),
+ flags=cv2.INTER_LINEAR)
+ else:
+ trans = get_affine_transform(center, scale, rot, self.trainsize)
+ image = cv2.warpAffine(
+ image,
+ trans, (int(self.trainsize[0]), int(self.trainsize[1])),
+ flags=cv2.INTER_LINEAR)
+
+ return image, im_info
+
+
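+# Compose chains the preprocessing ops declared under 'Preprocess' in infer_cfg.yml
+# (e.g. Resize, NormalizeImage, Permute): each entry's 'type' field is resolved to one
+# of the classes defined above and instantiated with the remaining keys as kwargs.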
+class Compose:
+ def __init__(self, transforms):
+ self.transforms = []
+ for op_info in transforms:
+ new_op_info = op_info.copy()
+ op_type = new_op_info.pop('type')
+ self.transforms.append(eval(op_type)(**new_op_info))
+
+ def __call__(self, img_path):
+ img, im_info = decode_image(img_path)
+ for t in self.transforms:
+ img, im_info = t(img, im_info)
+ inputs = copy.deepcopy(im_info)
+ inputs['image'] = img
+ return inputs
diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py
index fcff742..63273a8 100644
--- a/src/models/ocr_model/utils/inference.py
+++ b/src/models/ocr_model/utils/inference.py
@@ -13,8 +13,8 @@ from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
- imgs: Union[List[str], List[np.ndarray]],
- use_cuda: bool,
+ imgs_path: Union[List[str], List[np.ndarray]],
+ inf_mode: str = 'cpu',
num_beams: int = 1,
) -> List[str]:
model.eval()
@@ -26,9 +26,8 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
- if use_cuda:
- model = model.to('cuda')
- pixel_values = pixel_values.to('cuda')
+ model = model.to(inf_mode)
+ pixel_values = pixel_values.to(inf_mode)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
diff --git a/src/rec_infer_from_crop_imgs.py b/src/rec_infer_from_crop_imgs.py
new file mode 100644
index 0000000..1cacbc6
--- /dev/null
+++ b/src/rec_infer_from_crop_imgs.py
@@ -0,0 +1,59 @@
+import os
+import argparse
+import cv2 as cv
+from pathlib import Path
+from utils import to_katex
+from models.ocr_model.utils.inference import inference as latex_inference
+from models.ocr_model.model.TexTeller import TexTeller
+
+
+if __name__ == '__main__':
+ os.chdir(Path(__file__).resolve().parent)
+ parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--img-dir',
+        type=str,
+        default='./subimages',
+        help='path to the directory of cropped formula images (output of infer_det.py)'
+    )
+    parser.add_argument(
+        '--output-dir',
+        type=str,
+        default='./results',
+        help='directory where the recognition results are saved as .txt files'
+    )
+ parser.add_argument(
+ '--inference-mode',
+ type=str,
+ default='cpu',
+ help='Inference mode, select one of cpu, cuda, or mps'
+ )
+ parser.add_argument(
+ '--num-beam',
+ type=int,
+ default=1,
+ help='number of beam search for decoding'
+ )
+
+ args = parser.parse_args()
+
+ print('Loading model and tokenizer...')
+ latex_rec_model = TexTeller.from_pretrained()
+ tokenizer = TexTeller.get_tokenizer()
+ print('Model and tokenizer loaded.')
+
+ # Create the output directory if it doesn't exist
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Loop through all images in the input directory
+ for filename in os.listdir(args.img_dir):
+ img_path = os.path.join(args.img_dir, filename)
+ img = cv.imread(img_path)
+
+ if img is not None:
+ print(f'Inference for {filename}...')
+ res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
+ res = to_katex(res[0])
+
+ # Save the recognition result to a text file
+ output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
+ with open(output_file, 'w') as f:
+ f.write(res)
+
+ print(f'Result saved to {output_file}')
+ else:
+ print(f"Warning: Could not read image {img_path}. Skipping...")
diff --git a/src/server.py b/src/server.py
index 11adfa3..520d908 100644
--- a/src/server.py
+++ b/src/server.py
@@ -23,8 +23,8 @@ parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
-parser.add_argument('--use_cuda', action='store_true', default=False)
-parser.add_argument('--num_beam', type=int, default=1)
+parser.add_argument('--inference-mode', type=str, default='cpu')
+parser.add_argument('--num_beams', type=int, default=1)
args = parser.parse_args()
-if args.ngpu_per_replica > 0 and not args.use_cuda:
+if args.ngpu_per_replica > 0 and args.inference_mode != 'cuda':
@@ -43,18 +43,21 @@ class TexTellerServer:
self,
checkpoint_path: str,
tokenizer_path: str,
- use_cuda: bool = False,
- num_beam: int = 1
+ inf_mode: str = 'cpu',
+ num_beams: int = 1
) -> None:
self.model = TexTeller.from_pretrained(checkpoint_path)
self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
- self.use_cuda = use_cuda
- self.num_beam = num_beam
+ self.inf_mode = inf_mode
+ self.num_beams = num_beams
- self.model = self.model.to('cuda') if use_cuda else self.model
+ self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
- return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0]
+ return inference(
+ self.model, self.tokenizer, [image_nparray],
+ inf_mode=self.inf_mode, num_beams=self.num_beams
+ )[0]
@serve.deployment()
@@ -78,7 +81,11 @@ if __name__ == '__main__':
tknz_dir = args.tokenizer_dir
serve.start(http_options={"port": args.server_port})
- texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
+ texteller_server = TexTellerServer.bind(
+ ckpt_dir, tknz_dir,
+ inf_mode=args.inference_mode,
+ num_beams=args.num_beams
+ )
ingress = Ingress.bind(texteller_server)
ingress_handle = serve.run(ingress, route_prefix="/predict")
diff --git a/src/start_web.bat b/src/start_web.bat
new file mode 100644
index 0000000..e235cca
--- /dev/null
+++ b/src/start_web.bat
@@ -0,0 +1,9 @@
+@echo off
+SETLOCAL ENABLEEXTENSIONS
+
+set CHECKPOINT_DIR=default
+set TOKENIZER_DIR=default
+
+streamlit run web.py
+
+ENDLOCAL
diff --git a/src/start_web.sh b/src/start_web.sh
index 7dab5d2..6ec8f7b 100755
--- a/src/start_web.sh
+++ b/src/start_web.sh
@@ -1,10 +1,7 @@
#!/usr/bin/env bash
set -exu
-export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-460000"
-# export CHECKPOINT_DIR="default"
-export TOKENIZER_DIR="/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas"
-export USE_CUDA=True # True or False (case-sensitive)
-export NUM_BEAM=3
+export CHECKPOINT_DIR="default"
+export TOKENIZER_DIR="default"
streamlit run web.py
diff --git a/src/web.py b/src/web.py
index 5eb2413..379a609 100644
--- a/src/web.py
+++ b/src/web.py
@@ -6,16 +6,22 @@ import shutil
import streamlit as st
from PIL import Image
+from streamlit_paste_button import paste_image_button as pbutton
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
+st.set_page_config(
+ page_title="TexTeller",
+ page_icon="🧮"
+)
+
html_string = '''
-
- TexTeller
-
+
+ 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
+
'''
@@ -35,8 +41,6 @@ fail_gif_html = '''
'''
-
-
@st.cache_resource
def get_model():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@@ -52,6 +56,12 @@ def get_image_base64(img_file):
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
+def on_file_upload():
+ st.session_state["UPLOADED_FILE_CHANGED"] = True
+
+def change_side_bar():
+ st.session_state["CHANGE_SIDEBAR_FLAG"] = True
+
model = get_model()
tokenizer = get_tokenizer()
@@ -59,37 +69,106 @@ if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
+if "UPLOADED_FILE_CHANGED" not in st.session_state:
+ st.session_state["UPLOADED_FILE_CHANGED"] = False
-# ============================ pages =============================== #
+if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
+ st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+
+# ============================ begin sidebar =============================== #
+
+with st.sidebar:
+ num_beams = 1
+ inf_mode = 'cpu'
+
+ st.markdown("# 🔨️ Config")
+ st.markdown("")
+
+ model_type = st.selectbox(
+ "Model type",
+ ("TexTeller", "None"),
+ on_change=change_side_bar
+ )
+ if model_type == "TexTeller":
+ num_beams = st.number_input(
+ 'Number of beams',
+ min_value=1,
+ max_value=20,
+ step=1,
+ on_change=change_side_bar
+ )
+
+ inf_mode = st.radio(
+ "Inference mode",
+ ("cpu", "cuda", "mps"),
+ on_change=change_side_bar
+ )
+
+# ============================ end sidebar =============================== #
+
+
+
+# ============================ begin pages =============================== #
st.markdown(html_string, unsafe_allow_html=True)
-uploaded_file = st.file_uploader("",type=['jpg', 'png', 'pdf'])
+uploaded_file = st.file_uploader(
+ " ",
+ type=['jpg', 'png'],
+ on_change=on_file_upload
+)
+
+paste_result = pbutton(
+ label="📋 Paste an image",
+ background_color="#5BBCFF",
+ hover_background_color="#3498db",
+)
+st.write("")
+
+if st.session_state["CHANGE_SIDEBAR_FLAG"] == True:
+ st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+elif uploaded_file or paste_result.image_data is not None:
+ if st.session_state["UPLOADED_FILE_CHANGED"] == False and paste_result.image_data is not None:
+ uploaded_file = io.BytesIO()
+ paste_result.image_data.save(uploaded_file, format='PNG')
+ uploaded_file.seek(0)
+
+ if st.session_state["UPLOADED_FILE_CHANGED"] == True:
+ st.session_state["UPLOADED_FILE_CHANGED"] = False
-if uploaded_file:
img = Image.open(uploaded_file)
temp_dir = tempfile.mkdtemp()
png_file_path = os.path.join(temp_dir, 'image.png')
img.save(png_file_path, 'PNG')
- img_base64 = get_image_base64(uploaded_file)
+ with st.container(height=300):
+ img_base64 = get_image_base64(uploaded_file)
+ st.markdown(f"""
+
+
+
+
+ """, unsafe_allow_html=True)
st.markdown(f"""
-
Input image ({img.height}✖️{img.width})
""", unsafe_allow_html=True)
@@ -102,15 +181,28 @@ if uploaded_file:
model,
tokenizer,
[png_file_path],
- True if os.environ['USE_CUDA'] == 'True' else False,
- int(os.environ['NUM_BEAM'])
+ inf_mode=inf_mode,
+ num_beams=num_beams
)[0]
st.success('Completed!', icon="✅")
st.markdown(suc_gif_html, unsafe_allow_html=True)
katex_res = to_katex(TexTeller_result)
- st.text_area(":red[Predicted formula]", katex_res, height=150)
+ st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
st.latex(katex_res)
+ st.write("")
+ st.write("")
+
+ with st.expander(":star2: :gray[Tips for better results]"):
+ st.markdown('''
+ * :mag_right: Use a clear and high-resolution image.
+ * :scissors: Crop images as accurately as possible.
+ * :jigsaw: Split large multi line formulas into smaller ones.
+ * :page_facing_up: Use images with **white background and black text** as much as possible.
+ * :book: Use a font with good readability.
+ ''')
shutil.rmtree(temp_dir)
-# ============================ pages =============================== #
+ paste_result.image_data = None
+
+# ============================ end pages =============================== #