diff --git a/.gitignore b/.gitignore index f44176c..b35b345 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,14 @@ +**/.DS_Store **/__pycache__ **/.vscode + **/train_result +**/ckpt +**/*cache +**/.cache +**/data **/logs -**/.cache **/tmp* **/data **/*cache diff --git a/README.md b/README.md new file mode 100644 index 0000000..0a79ce2 --- /dev/null +++ b/README.md @@ -0,0 +1,206 @@ +📄 English | 中文 + +
+

+ + 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 + +

+

+ 🤗 Hugging Face +

+ +
+

https://github.com/OleehyO/TexTeller/assets/56267907/b23b2b2e-a663-4abb-b013-bd47238d513b

TexTeller is an end-to-end formula recognition model based on ViT, capable of converting images into their corresponding LaTeX formulas.

TexTeller was trained with ~~550K~~7.5M image-formula pairs (dataset available [here](https://huggingface.co/datasets/OleehyO/latex-formulas)). Compared to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR), which used a 100K dataset, TexTeller has **stronger generalization abilities** and **higher accuracy**, covering most use cases (**except for scanned images and handwritten formulas**).

> ~~We will soon release a TexTeller checkpoint trained on a 7.5M dataset~~

## 🔄 Change Log

* 📮[2024-03-25] TexTeller 2.0 released! The training data for TexTeller 2.0 has been increased to 7.5M (about **15 times more** than TexTeller 1.0, with improved data quality). TexTeller 2.0 demonstrates **superior performance** on the test set, especially when recognizing rare symbols, complex multi-line formulas, and matrices.
  > More test images, along with a side-by-side comparison of recognition models from different vendors, can be found [here](./assets/test.pdf).

* 📮[2024-04-12] Trained a **formula detection model**, adding the ability to detect and then recognize formulas across entire documents (whole-image inference)!

## 🔑 Prerequisites

python=3.10

[pytorch](https://pytorch.org/get-started/locally/)

> [!WARNING]
> Only CUDA >= 12.0 has been fully tested, so CUDA >= 12.0 is recommended.

## 🚀 Getting Started

1. Clone the repository:

   ```bash
   git clone https://github.com/OleehyO/TexTeller
   ```

2. [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

3. Install the project's dependencies:

   ```bash
   pip install -r requirements.txt
   ```

4. Enter the `TexTeller/src` directory and run the following command in the terminal to start inference:

   ```bash
   python inference.py -img "/path/to/image.{jpg,png}"
   # use the --inference-mode option to enable GPU (cuda or mps) inference
   # e.g. python inference.py -img "./img.jpg" --inference-mode cuda
   ```

> [!NOTE]
> The first time you run it, the required checkpoints will be downloaded from Hugging Face.

## 🌐 Web Demo

Go to the `TexTeller/src` directory and run the following command:

```bash
./start_web.sh
```

Enter `http://localhost:8501` in a browser to view the web demo.

> [!NOTE]
> If you are a Windows user, please run the `start_web.bat` file instead.

## 🧠 Full Image Inference

TexTeller also supports **formula detection and recognition** on full images: it first detects all formulas in the image, then recognizes the cropped formulas in batch.

### Download Weights

English document formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx?download=true)]: trained on 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865).

Chinese document formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/blob/main/rtdetr_r50vd_6x_coco_trained_on_cn_textbook.onnx)]: trained on 2,560 Chinese textbook images (100+ layouts).

### Formula Detection

Run the following command in the `TexTeller/src` directory:

```bash
python infer_det.py
```

This detects all formulas in the full image and saves the results in `TexTeller/src/subimages`.
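The detector's inputs and outputs can be redirected through the flags defined in `src/infer_det.py`. A sketch with placeholder paths (point `--onnx_file` at the weights downloaded above; use `--image_dir` instead of `--image_file` to process a whole folder):

```bash
# --onnx_file, --image_file and --imgsave_dir are the flags defined in infer_det.py;
# all paths below are placeholders.
python infer_det.py \
    --onnx_file ./models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx \
    --image_file /path/to/your_page.jpg \
    --imgsave_dir .
```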
+ +
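The recognizer can also be driven directly from Python. Below is a minimal sketch based on the functions used in `src/inference.py` (run it from the `TexTeller/src` directory; the image path is a placeholder):

```python
# Minimal programmatic recognition, mirroring src/inference.py.
import cv2 as cv

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_inference
from utils import to_katex

model = TexTeller.from_pretrained()   # downloads the default checkpoint on first use
tokenizer = TexTeller.get_tokenizer()

img = cv.imread('./img.jpg')          # placeholder image path
res = latex_inference(model, tokenizer, [img], inf_mode='cpu', num_beams=1)
print(to_katex(res[0]))
```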
+

### Batch Formula Recognition

After **formula detection**, run the following command in the `TexTeller/src` directory:

```shell
python rec_infer_from_crop_imgs.py
```

This uses the results of the previous formula detection step to perform batch recognition on all cropped formulas, saving the recognition results as txt files in `TexTeller/src/results`.

## 📡 API Usage

We use [ray serve](https://github.com/ray-project/ray) to provide an API interface for TexTeller, allowing you to integrate TexTeller into your own projects. To start the server, first enter the `TexTeller/src` directory and then run the following command:

```bash
python server.py  # default settings
```

| Parameter | Description |
| --- | --- |
| `-ckpt` | The path to the weights file, *default is TexTeller's pretrained weights*. |
| `-tknz` | The path to the tokenizer, *default is TexTeller's tokenizer*. |
| `-port` | The server's service port, *default is 8000*. |
| `--inference-mode` | Whether to use a GPU (cuda or mps) for inference, *default is CPU*. |
| `--num_beams` | The number of beams for beam search, *default is 1*. |
| `--num_replicas` | The number of service replicas to run on the server, *default is 1 replica*. You can use more replicas to achieve greater throughput. |
| `--ncpu_per_replica` | The number of CPU cores used per service replica, *default is 1*. |
| `--ngpu_per_replica` | The number of GPUs used per service replica, *default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU and improve GPU utilization. (Note: if --num_replicas is 2 and --ngpu_per_replica is 0.7, then 2 GPUs must be available.) |

> [!NOTE]
> A client demo can be found at `TexTeller/client/demo.py`; refer to `demo.py` for how to send requests to the server.

## 🏋️‍♂️ Training

### Dataset

We provide an example dataset in the `TexTeller/src/models/ocr_model/train/dataset` directory; you can place your own images in the `images` directory and annotate each image with its corresponding formula in `formulas.jsonl`.

After preparing your dataset, you need to **change the `DIR_URL` variable to your own dataset's path** in `.../dataset/loader.py`.

### Retraining the Tokenizer

If you are using a different dataset, you might need to retrain the tokenizer to obtain a different dictionary. After configuring your dataset, you can train your own tokenizer as follows:

1. In `TexTeller/src/models/tokenizer/train.py`, change `new_tokenizer.save_pretrained('./your_dir_name')` to your custom output directory.
   > If you want to use a different dictionary size (default is 10k tokens), you need to change the `VOCAB_SIZE` variable in `TexTeller/src/models/globals.py`.

2. **In the `TexTeller/src` directory**, run the following command:

   ```bash
   python -m models.tokenizer.train
   ```

### Training the Model

To train the model, run the following command in the `TexTeller/src` directory:

```bash
python -m models.ocr_model.train.train
```

You can set your own tokenizer and checkpoint paths in `TexTeller/src/models/ocr_model/train/train.py` (refer to `train.py` for more information). If you are using the same architecture and dictionary as TexTeller, you can also fine-tune TexTeller's default weights with your own dataset.

In `TexTeller/src/globals.py` and `TexTeller/src/models/ocr_model/train/train_args.py`, you can change the model's architecture and training hyperparameters.
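As an illustration of the kind of settings involved, here is a hedged sketch of training-hyperparameter overrides in the style of `train_args.py`. The field names are standard `transformers.TrainingArguments`; this diff does not show the file's actual defaults, so every value below is a placeholder:

```python
# Illustrative overrides in the style of src/models/ocr_model/train/train_args.py.
# Field names come from transformers.TrainingArguments; all values are placeholders.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./train_result',     # the directory name ignored by .gitignore
    per_device_train_batch_size=8,   # illustrative value
    learning_rate=5e-5,              # illustrative value
    num_train_epochs=3,              # illustrative value
    report_to='tensorboard',         # tensorboardX is listed in requirements.txt
)
```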
+ +> [!NOTE] +> Our training scripts use the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, so you can refer to their [documentation](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments) for more details and configurations on training parameters. + +## 🚧 Limitations + +* Does not support scanned images and PDF document recognition + +* Does not support handwritten formulas + +## 📅 Plans + +- [x] ~~Train the model with a larger dataset (7.5M samples, coming soon)~~ + +- [ ] Recognition of scanned images + +- [ ] PDF document recognition + Support for English and Chinese scenarios + +- [ ] Inference acceleration + +- [ ] ... + +## ⭐️ Stargazers over time + +[![Stargazers over time](https://starchart.cc/OleehyO/TexTeller.svg?variant=adaptive)](https://starchart.cc/OleehyO/TexTeller) + +## 💖 Acknowledgments + +Thanks to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) which has brought me a lot of inspiration, and [im2latex-100K](https://zenodo.org/records/56198#.V2px0jXT6eA) which enriches our dataset. + +## 👥 Contributors + + + + + + diff --git a/assets/README_zh.md b/assets/README_zh.md new file mode 100644 index 0000000..b946d49 --- /dev/null +++ b/assets/README_zh.md @@ -0,0 +1,233 @@ +📄 English | 中文 + +
+

+ + 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 + +

+

+ 🤗 Hugging Face +

+ +
+

https://github.com/OleehyO/TexTeller/assets/56267907/fb17af43-f2a5-47ce-ad1d-101db5fd7fbb

TexTeller is an end-to-end formula recognition model based on ViT that converts images into their corresponding LaTeX formulas.

TexTeller was trained with ~~550K~~7.5M image-formula pairs (the dataset is available [here](https://huggingface.co/datasets/OleehyO/latex-formulas)). Compared to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR), which used a 100K dataset, TexTeller has **stronger generalization abilities** and **higher accuracy**, covering most use cases (**except for scanned images and handwritten formulas**).

> ~~We will soon release a TexTeller checkpoint trained on a 7.5M dataset~~

## 🔄 Change Log

* 📮[2024-03-25] TexTeller 2.0 released! The training data for TexTeller 2.0 has been increased to 7.5M (about **15 times more** than TexTeller 1.0, with improved data quality). TexTeller 2.0 demonstrates **superior performance** on the test set, especially when recognizing rare symbols, complex multi-line formulas, and matrices.

  > More test images, along with a side-by-side comparison of recognition models from different vendors, can be found [here](./test.pdf).

* 📮[2024-04-12] Trained a **formula detection model**, adding the ability to detect and then recognize formulas across entire documents (whole-image inference)!

## 🔑 Prerequisites

python=3.10

[pytorch](https://pytorch.org/get-started/locally/)

> [!WARNING]
> Only CUDA >= 12.0 has been fully tested, so CUDA >= 12.0 is recommended.

## 🚀 Getting Started

1. Clone the repository:

   ```bash
   git clone https://github.com/OleehyO/TexTeller
   ```

2. [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

3. Install the project's dependencies:

   ```bash
   pip install -r requirements.txt
   ```

4. Enter the `TexTeller/src` directory and run the following command in the terminal to start inference:

   ```bash
   python inference.py -img "/path/to/image.{jpg,png}"
   # use the --inference-mode option to enable GPU (cuda or mps) inference
   # e.g. python inference.py -img "./img.jpg" --inference-mode cuda
   ```

> [!NOTE]
> The first time you run it, the required checkpoints will be downloaded from Hugging Face.

## ❓ FAQ: Cannot Connect to Hugging Face

By default, the model weights are downloaded from Hugging Face. **If your remote server cannot connect to Hugging Face**, you can load them as follows:

1. Install the huggingface hub package:

   ```bash
   pip install -U "huggingface_hub[cli]"
   ```

2. On a machine that can reach Hugging Face, download the model weights:

   ```bash
   huggingface-cli download OleehyO/TexTeller --include "*.json" "*.bin" "*.txt" --repo-type model --local-dir "your/dir/path"
   ```

3. Upload the directory containing the weights to the remote server, then change `REPO_NAME = 'OleehyO/TexTeller'` to `REPO_NAME = 'your/dir/path'` in `TexTeller/src/models/ocr_model/model/TexTeller.py`.

If you also want to enable evaluation during training, you need to download the metric script in advance and upload it to the remote server:

1. On a machine that can reach Hugging Face, download the metric script:

   ```bash
   huggingface-cli download evaluate-metric/google_bleu --repo-type space --local-dir "your/dir/path"
   ```

2. Upload this directory to the remote server and change `evaluate.load('google_bleu')` to `evaluate.load('your/dir/path/google_bleu.py')` in `TexTeller/src/models/ocr_model/utils/metrics.py`.

## 🌐 Web Demo

Go to the `TexTeller/src` directory and run the following command:

```bash
./start_web.sh
```

Enter `http://localhost:8501` in a browser to view the web demo.

> [!NOTE]
> If you are a Windows user, please run the `start_web.bat` file instead.

## 🧠 Full Image Inference

TexTeller also supports **formula detection and recognition** on full images: it first detects all formulas in the image, then recognizes the cropped formulas in batch.

### Download Weights

English document formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx?download=true)]: trained on 8,272 images from the [IBEM dataset](https://zenodo.org/records/4757865).

Chinese document formula detection [[link](https://huggingface.co/TonyLee1256/texteller_det/blob/main/rtdetr_r50vd_6x_coco_trained_on_cn_textbook.onnx)]: trained on 2,560 Chinese textbook images (100+ layouts).

### Formula Detection

Run the following command in the `TexTeller/src` directory:

```bash
python infer_det.py
```

This detects all formulas in the full image and saves the results in `TexTeller/src/subimages`.
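The detector's inputs and outputs can be pointed elsewhere via the flags defined in `src/infer_det.py`. A sketch with placeholder paths (point `--onnx_file` at the weights downloaded above; `--image_dir` processes a whole folder instead of a single `--image_file`):

```bash
# --onnx_file, --image_file and --imgsave_dir are the flags defined in infer_det.py;
# all paths below are placeholders.
python infer_det.py \
    --onnx_file ./models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_cn_textbook.onnx \
    --image_file /path/to/your_page.jpg \
    --imgsave_dir .
```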
+ +
+

### Batch Formula Recognition

After **formula detection**, run the following command in the `TexTeller/src` directory:

```shell
python rec_infer_from_crop_imgs.py
```

This uses the results of the previous formula detection step to perform batch recognition on all cropped formulas, saving the recognition results as txt files in `TexTeller/src/results`.

## 📡 API Usage

We use [ray serve](https://github.com/ray-project/ray) to provide a TexTeller API so that you can integrate TexTeller into your own projects. To start the server, first enter the `TexTeller/src` directory and then run the following command:

```bash
python server.py
```

| Parameter | Description |
| - | - |
| `-ckpt` | The path to the weights file, *default is TexTeller's pretrained weights*. |
| `-tknz` | The path to the tokenizer, *default is TexTeller's tokenizer*. |
| `-port` | The server's service port, *default is 8000*. |
| `--inference-mode` | Whether to use a GPU (cuda or mps) for inference, *default is CPU*. |
| `--num_beams` | The number of beams for beam search, *default is 1*. |
| `--num_replicas` | The number of service replicas to run on the server, *default is 1 replica*. You can use more replicas to achieve greater throughput. |
| `--ncpu_per_replica` | The number of CPU cores used per service replica, *default is 1*. |
| `--ngpu_per_replica` | The number of GPUs used per service replica, *default is 1*. You can set this value between 0 and 1 to run multiple service replicas on one GPU and improve GPU utilization. (Note: if --num_replicas is 2 and --ngpu_per_replica is 0.7, then 2 GPUs must be available.) |

> [!NOTE]
> A client demo can be found at `TexTeller/client/demo.py`; refer to `demo.py` for how to send requests to the server.

## 🏋️‍♂️ Training

### Dataset

We provide an example dataset in the `TexTeller/src/models/ocr_model/train/dataset` directory; you can place your own images in the `images` directory and annotate each image with its corresponding formula in `formulas.jsonl`.

After preparing your dataset, you need to **change the `DIR_URL` variable to your own dataset's path** in `.../dataset/loader.py`.

### Retraining the Tokenizer

If you are using a different dataset, you might need to retrain the tokenizer to obtain a different dictionary. After configuring your dataset, you can train your own tokenizer as follows:

1. In `TexTeller/src/models/tokenizer/train.py`, change `new_tokenizer.save_pretrained('./your_dir_name')` to your custom output directory.

   > Note: if you want to use a dictionary of a different size (default is 10k tokens), change the `VOCAB_SIZE` variable in `TexTeller/src/models/globals.py`.

2. **In the `TexTeller/src` directory**, run the following command:

   ```bash
   python -m models.tokenizer.train
   ```

### Training the Model

To train the model, run the following command in the `TexTeller/src` directory:

```bash
python -m models.ocr_model.train.train
```

You can set your own tokenizer and checkpoint paths in `TexTeller/src/models/ocr_model/train/train.py` (refer to `train.py` for details). If you are using the same architecture and dictionary as TexTeller, you can also fine-tune TexTeller's default weights with your own dataset.

In `TexTeller/src/globals.py` and `TexTeller/src/models/ocr_model/train/train_args.py`, you can change the model's architecture and the training hyperparameters.

> [!NOTE]
> Our training scripts use the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, so you can refer to their [documentation](https://huggingface.co/docs/transformers/v4.32.1/main_classes/trainer#transformers.TrainingArguments) for more details and configuration of the training parameters.

## 🚧 Limitations

* No support for scanned images or PDF document recognition
* No support for handwritten formulas

## 📅 Plans

- [X] ~~Train the model with a larger dataset (7.5M samples, coming soon)~~

- [ ] Recognition of scanned images

- [ ] PDF document recognition, with support for English and Chinese scenarios

- [ ] Inference acceleration

- [ ] ...

## ⭐️ Stargazers over time

[![Stargazers over time](https://starchart.cc/OleehyO/TexTeller.svg?variant=adaptive)](https://starchart.cc/OleehyO/TexTeller)

## 💖 Acknowledgments

Thanks to [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) which has brought me a lot of inspiration, and [im2latex-100K](https://zenodo.org/records/56198#.V2px0jXT6eA) which enriches our dataset.
+ +## 👥 贡献者 + + + + + + diff --git a/assets/css.css b/assets/css.css deleted file mode 100644 index 00122f6..0000000 --- a/assets/css.css +++ /dev/null @@ -1,157 +0,0 @@ -html { - font-family: Inter; - font-size: 16px; - font-weight: 400; - line-height: 1.5; - -webkit-text-size-adjust: 100%; - background: #fff; - color: #323232; - -webkit-font-smoothing: antialiased; - -moz-osx-font-smoothing: grayscale; - text-rendering: optimizeLegibility; -} - -:root { - --space: 1; - --vspace: calc(var(--space) * 1rem); - --vspace-0: calc(3 * var(--space) * 1rem); - --vspace-1: calc(2 * var(--space) * 1rem); - --vspace-2: calc(1.5 * var(--space) * 1rem); - --vspace-3: calc(0.5 * var(--space) * 1rem); -} - -.app { - max-width: 748px !important; -} - -.prose p { - margin: var(--vspace) 0; - line-height: var(--vspace * 2); - font-size: 1rem; -} - -code { - font-family: "inconsolata", sans-serif; - font-size: 16px; -} - -h1, -h1 code { - font-weight: 400; - line-height: calc(2.5 / var(--space) * var(--vspace)); -} - -h1 code { - background: none; - border: none; - letter-spacing: 0.05em; - padding-bottom: 5px; - position: relative; - padding: 0; -} - -h2 { - margin: var(--vspace-1) 0 var(--vspace-2) 0; - line-height: 1em; -} - -h3, -h3 code { - margin: var(--vspace-1) 0 var(--vspace-2) 0; - line-height: 1em; -} - -h4, -h5, -h6 { - margin: var(--vspace-3) 0 var(--vspace-3) 0; - line-height: var(--vspace); -} - -.bigtitle, -h1, -h1 code { - font-size: calc(8px * 4.5); - word-break: break-word; -} - -.title, -h2, -h2 code { - font-size: calc(8px * 3.375); - font-weight: lighter; - word-break: break-word; - border: none; - background: none; -} - -.subheading1, -h3, -h3 code { - font-size: calc(8px * 1.8); - font-weight: 600; - border: none; - background: none; - letter-spacing: 0.1em; - text-transform: uppercase; -} - -h2 code { - padding: 0; - position: relative; - letter-spacing: 0.05em; -} - -blockquote { - font-size: calc(8px * 1.1667); - font-style: italic; - line-height: calc(1.1667 * var(--vspace)); - margin: var(--vspace-2) var(--vspace-2); -} - -.subheading2, -h4 { - font-size: calc(8px * 1.4292); - text-transform: uppercase; - font-weight: 600; -} - -.subheading3, -h5 { - font-size: calc(8px * 1.2917); - line-height: calc(1.2917 * var(--vspace)); - - font-weight: lighter; - text-transform: uppercase; - letter-spacing: 0.15em; -} - -h6 { - font-size: calc(8px * 1.1667); - font-size: 1.1667em; - font-weight: normal; - font-style: italic; - font-family: "le-monde-livre-classic-byol", serif !important; - letter-spacing: 0px !important; -} - -#start .md > *:first-child { - margin-top: 0; -} - -h2 + h3 { - margin-top: 0; -} - -.md hr { - border: none; - border-top: 1px solid var(--block-border-color); - margin: var(--vspace-2) 0 var(--vspace-2) 0; -} -.prose ul { - margin: var(--vspace-2) 0 var(--vspace-1) 0; -} - -.gap { - gap: 0; -} \ No newline at end of file diff --git a/assets/det_rec.png b/assets/det_rec.png new file mode 100644 index 0000000..dbd1ffc Binary files /dev/null and b/assets/det_rec.png differ diff --git a/assets/fire.svg b/assets/fire.svg new file mode 100644 index 0000000..8f9f7eb --- /dev/null +++ b/assets/fire.svg @@ -0,0 +1,460 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/test.pdf b/assets/test.pdf new file mode 100644 index 0000000..f587024 Binary files /dev/null and b/assets/test.pdf differ diff --git a/assets/web_demo.gif b/assets/web_demo.gif new file mode 100644 index 0000000..0403d86 Binary files /dev/null and b/assets/web_demo.gif differ diff --git a/requirements.txt b/requirements.txt index e677590..c5f254c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,6 @@ transformers datasets evaluate -streamlit - -gradio - opencv-python ray[serve] accelerate @@ -12,4 +8,8 @@ tensorboardX nltk python-multipart -augraphy \ No newline at end of file +augraphy +onnxruntime + +streamlit==1.30 +streamlit-paste-button diff --git a/src/infer_det.py b/src/infer_det.py new file mode 100644 index 0000000..cc5da44 --- /dev/null +++ b/src/infer_det.py @@ -0,0 +1,197 @@ +import os +import yaml +import argparse +import numpy as np +import glob +from onnxruntime import InferenceSession +from tqdm import tqdm + +from models.det_model.preprocess import Compose +import cv2 + +# 注意:文件名要标准,最好都用下划线 + +# Global dictionary +SUPPORT_MODELS = { + 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', + 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', + 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet', + 'DETR' +} + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml", + default="./models/det_model/model/infer_cfg.yml" + ) +parser.add_argument('--onnx_file', type=str, help="onnx model file path", + default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx" + ) +parser.add_argument("--image_dir", type=str) +parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg') +parser.add_argument("--imgsave_dir", type=str, + default="." 
+                    )
+
+def get_test_images(infer_dir, infer_img):
+    """
+    Get image path list in TEST mode
+    """
+    assert infer_img is not None or infer_dir is not None, \
+        "--image_file or --image_dir should be set"
+    assert infer_img is None or os.path.isfile(infer_img), \
+        "{} is not a file".format(infer_img)
+    assert infer_dir is None or os.path.isdir(infer_dir), \
+        "{} is not a directory".format(infer_dir)
+
+    # infer_img has a higher priority
+    if infer_img and os.path.isfile(infer_img):
+        return [infer_img]
+
+    images = set()
+    infer_dir = os.path.abspath(infer_dir)
+    assert os.path.isdir(infer_dir), \
+        "infer_dir {} is not a directory".format(infer_dir)
+    exts = ['jpg', 'jpeg', 'png', 'bmp']
+    exts += [ext.upper() for ext in exts]
+    for ext in exts:
+        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+    images = list(images)
+
+    assert len(images) > 0, "no image found in {}".format(infer_dir)
+    print("Found {} inference images in total.".format(len(images)))
+
+    return images
+
+
+class PredictConfig(object):
+    """set config of preprocess, postprocess and visualize
+    Args:
+        infer_config (str): path of infer_cfg.yml
+    """
+
+    def __init__(self, infer_config):
+        # parsing Yaml config for Preprocess
+        with open(infer_config) as f:
+            yml_conf = yaml.safe_load(f)
+        self.check_model(yml_conf)
+        self.arch = yml_conf['arch']
+        self.preprocess_infos = yml_conf['Preprocess']
+        self.min_subgraph_size = yml_conf['min_subgraph_size']
+        self.label_list = yml_conf['label_list']
+        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
+        self.mask = yml_conf.get("mask", False)
+        self.tracker = yml_conf.get("tracker", None)
+        self.nms = yml_conf.get("NMS", None)
+        self.fpn_stride = yml_conf.get("fpn_stride", None)
+
+        # predefined color pool
+        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+        # dynamically build the label-to-color mapping from label_list
+        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
+
+        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+            print(
+                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+            )
+        self.print_config()
+
+    def check_model(self, yml_conf):
+        """
+        Raises:
+            ValueError: loaded model not in supported model type
+        """
+        for support_model in SUPPORT_MODELS:
+            if support_model in yml_conf['arch']:
+                return True
+        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+            'arch'], SUPPORT_MODELS))
+
+    def print_config(self):
+        print('----------- Model Configuration -----------')
+        print('%s: %s' % ('Model Arch', self.arch))
+        print('%s: ' % ('Transform Order'))
+        for op_info in self.preprocess_infos:
+            print('--%s: %s' % ('transform op', op_info['type']))
+        print('--------------------------------------------')
+
+
+def draw_bbox(image, outputs, infer_config):
+    for output in outputs:
+        cls_id, score, xmin, ymin, xmax, ymax = output
+        if score > infer_config.draw_threshold:
+            # look up the class name
+            label = infer_config.label_list[int(cls_id)]
+            # pick the color assigned to this class
+            color = infer_config.colors[label]
+            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
+            cv2.putText(image, "{}: {:.2f}".format(label, score),
+                        (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+    return image
+
+
+def predict_image(infer_config, predictor, img_list):
+    # load preprocess transforms
+    transforms = Compose(infer_config.preprocess_infos)
+    errImgList = []
+
+    # Check and create subimg_save_dir if not exist
+    subimg_save_dir = 
os.path.join(FLAGS.imgsave_dir, 'subimages') + os.makedirs(subimg_save_dir, exist_ok=True) + + # predict image + for img_path in tqdm(img_list): + img = cv2.imread(img_path) + if img is None: + print(f"Warning: Could not read image {img_path}. Skipping...") + errImgList.append(img_path) + continue + + inputs = transforms(img_path) + inputs_name = [var.name for var in predictor.get_inputs()] + inputs = {k: inputs[k][None, ] for k in inputs_name} + + outputs = predictor.run(output_names=None, input_feed=inputs) + + print("ONNXRuntime predict: ") + if infer_config.arch in ["HRNet"]: + print(np.array(outputs[0])) + else: + bboxes = np.array(outputs[0]) + for bbox in bboxes: + if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold: + print(f"{int(bbox[0])} {bbox[1]} " + f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}") + + # Save the subimages (crop from the original image) + subimg_counter = 1 + for output in np.array(outputs[0]): + cls_id, score, xmin, ymin, xmax, ymax = output + if score > infer_config.draw_threshold: + label = infer_config.label_list[int(cls_id)] + subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)] + subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg" + subimg_path = os.path.join(subimg_save_dir, subimg_filename) + cv2.imwrite(subimg_path, subimg) + subimg_counter += 1 + + # Draw bounding boxes and save the image with bounding boxes + img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) + output_dir = FLAGS.imgsave_dir + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path)) + cv2.imwrite(output_file, img_with_bbox) + + print("ErrorImgs:") + print(errImgList) + +if __name__ == '__main__': + FLAGS = parser.parse_args() + # load image list + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + # load predictor + predictor = InferenceSession(FLAGS.onnx_file) + # load infer config + infer_config = PredictConfig(FLAGS.infer_cfg) + + predict_image(infer_config, predictor, img_list) diff --git a/src/inference.py b/src/inference.py index c0e263c..bc7c006 100644 --- a/src/inference.py +++ b/src/inference.py @@ -19,10 +19,16 @@ if __name__ == '__main__': help='path to the input image' ) parser.add_argument( - '-cuda', - default=False, - action='store_true', - help='use cuda or not' + '--inference-mode', + type=str, + default='cpu', + help='Inference mode, select one of cpu, cuda, or mps' + ) + parser.add_argument( + '--num-beam', + type=int, + default=1, + help='number of beam search for decoding' ) # ================= new feature ================== parser.add_argument( @@ -37,6 +43,7 @@ if __name__ == '__main__': # You can use your own checkpoint and tokenizer path. 
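+    # e.g. TexTeller.from_pretrained('/your/ckpt/dir') and TexTeller.get_tokenizer('/your/tokenizer/dir') (both paths are hypothetical)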
     print('Loading model and tokenizer...')
     latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
@@ -44,7 +51,7 @@
     img = cv.imread(args.img)
     print('Inference...')
     if not args.mix:
-        res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
         res = to_katex(res[0])
         print(res)
     else:
diff --git a/src/models/det_model/model/infer_cfg.yml b/src/models/det_model/model/infer_cfg.yml
new file mode 100644
index 0000000..0c156fc
--- /dev/null
+++ b/src/models/det_model/model/infer_cfg.yml
@@ -0,0 +1,27 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: DETR
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+  keep_ratio: false
+  target_size:
+  - 640
+  - 640
+  type: Resize
+- mean:
+  - 0.0
+  - 0.0
+  - 0.0
+  norm_type: none
+  std:
+  - 1.0
+  - 1.0
+  - 1.0
+  type: NormalizeImage
+- type: Permute
+label_list:
+- isolated
+- embedding
diff --git a/src/models/det_model/preprocess.py b/src/models/det_model/preprocess.py
new file mode 100644
index 0000000..3554b7f
--- /dev/null
+++ b/src/models/det_model/preprocess.py
@@ -0,0 +1,494 @@
+import numpy as np
+import cv2
+import copy
+
+
+def decode_image(img_path):
+    with open(img_path, 'rb') as f:
+        im_read = f.read()
+    data = np.frombuffer(im_read, dtype='uint8')
+    im = cv2.imdecode(data, 1)  # imdecode returns BGR, but the model needs RGB
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    img_info = {
+        "im_shape": np.array(
+            im.shape[:2], dtype=np.float32),
+        "scale_factor": np.array(
+            [1., 1.], dtype=np.float32)
+    }
+    return im, img_info
+
+
+class Resize(object):
+    """resize image by target_size and max_size
+    Args:
+        target_size (int): the target size of image
+        keep_ratio (bool): whether keep_ratio or not, default true
+        interp (int): method of resize
+    """
+
+    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+        self.keep_ratio = keep_ratio
+        self.interp = interp
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        assert len(self.target_size) == 2
+        assert self.target_size[0] > 0 and self.target_size[1] > 0
+        im_channel = im.shape[2]
+        im_scale_y, im_scale_x = self.generate_scale(im)
+        im = cv2.resize(
+            im,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=self.interp)
+        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+        im_info['scale_factor'] = np.array(
+            [im_scale_y, im_scale_x]).astype('float32')
+        return im, im_info
+
+    def generate_scale(self, im):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+        Returns:
+            im_scale_x: the resize ratio of X
+            im_scale_y: the resize ratio of Y
+        """
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.keep_ratio:
+            im_size_min = np.min(origin_shape)
+            im_size_max = np.max(origin_shape)
+            target_size_min = np.min(self.target_size)
+            target_size_max = np.max(self.target_size)
+            im_scale = float(target_size_min) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > target_size_max:
+                im_scale = float(target_size_max) / float(im_size_max)
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+        else:
+            resize_h, resize_w 
= self.target_size + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + +class NormalizeImage(object): + """normalize image + Args: + mean (list): im - mean + std (list): im / std + is_scale (bool): whether need im / 255 + norm_type (str): type in ['mean_std', 'none'] + """ + + def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): + self.mean = mean + self.std = std + self.is_scale = is_scale + self.norm_type = norm_type + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.astype(np.float32, copy=False) + if self.is_scale: + scale = 1.0 / 255.0 + im *= scale + + if self.norm_type == 'mean_std': + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im -= mean + im /= std + return im, im_info + + +class Permute(object): + """permute image + Args: + to_bgr (bool): whether convert RGB to BGR + channel_first (bool): whether convert HWC to CHW + """ + + def __init__(self, ): + super(Permute, self).__init__() + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.transpose((2, 0, 1)).copy() + return im, im_info + + +class PadStride(object): + """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config + Args: + stride (bool): model with FPN need image shape % stride == 0 + """ + + def __init__(self, stride=0): + self.coarsest_stride = stride + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + coarsest_stride = self.coarsest_stride + if coarsest_stride <= 0: + return im, im_info + im_c, im_h, im_w = im.shape + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + return padding_im, im_info + + +class LetterBoxResize(object): + def __init__(self, target_size): + """ + Resize image to target size, convert normalized xywh to pixel xyxy + format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). + Args: + target_size (int|list): image target size. 
+ """ + super(LetterBoxResize, self).__init__() + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + + def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): + # letterbox: resize a rectangular image to a padded rectangular + shape = img.shape[:2] # [height, width] + ratio_h = float(height) / shape[0] + ratio_w = float(width) / shape[1] + ratio = min(ratio_h, ratio_w) + new_shape = (round(shape[1] * ratio), + round(shape[0] * ratio)) # [width, height] + padw = (width - new_shape[0]) / 2 + padh = (height - new_shape[1]) / 2 + top, bottom = round(padh - 0.1), round(padh + 0.1) + left, right = round(padw - 0.1), round(padw + 0.1) + + img = cv2.resize( + img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, + value=color) # padded rectangular + return img, ratio, padw, padh + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + height, width = self.target_size + h, w = im.shape[:2] + im, ratio, padw, padh = self.letterbox(im, height=height, width=width) + + new_shape = [round(h * ratio), round(w * ratio)] + im_info['im_shape'] = np.array(new_shape, dtype=np.float32) + im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) + return im, im_info + + +class Pad(object): + def __init__(self, size, fill_value=[114.0, 114.0, 114.0]): + """ + Pad image to a specified size. + Args: + size (list[int]): image target size + fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0) + """ + super(Pad, self).__init__() + if isinstance(size, int): + size = [size, size] + self.size = size + self.fill_value = fill_value + + def __call__(self, im, im_info): + im_h, im_w = im.shape[:2] + h, w = self.size + if h == im_h and w == im_w: + im = im.astype(np.float32) + return im, im_info + + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array(self.fill_value, dtype=np.float32) + canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) + im = canvas + return im, im_info + + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + + Returns: + list[float]: Rotated point. + """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + + Returns: + np.ndarray: The 3rd point. + """ + assert len(a) == 2 + assert len(b) == 2 + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + + +def get_affine_transform(center, + input_size, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. 
+ + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ]): Size of the destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + if not isinstance(input_size, (np.ndarray, list)): + input_size = np.array([input_size, input_size], dtype=np.float32) + scale_tmp = input_size + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +class WarpAffine(object): + """Warp affine the image + """ + + def __init__(self, + keep_res=False, + pad=31, + input_h=512, + input_w=512, + scale=0.4, + shift=0.1): + self.keep_res = keep_res + self.pad = pad + self.input_h = input_h + self.input_w = input_w + self.scale = scale + self.shift = shift + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + + h, w = img.shape[:2] + + if self.keep_res: + input_h = (h | self.pad) + 1 + input_w = (w | self.pad) + 1 + s = np.array([input_w, input_h], dtype=np.float32) + c = np.array([w // 2, h // 2], dtype=np.float32) + + else: + s = max(h, w) * 1.0 + input_h, input_w = self.input_h, self.input_w + c = np.array([w / 2., h / 2.], dtype=np.float32) + + trans_input = get_affine_transform(c, s, 0, [input_w, input_h]) + img = cv2.resize(img, (w, h)) + inp = cv2.warpAffine( + img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR) + return inp, im_info + + +# keypoint preprocess +def get_warp_matrix(theta, size_input, size_dst, size_target): + """This code is based on + https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py + + Calculate the transformation matrix under the constraint of unbiased. + Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased + Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + theta (float): Rotation angle in degrees. + size_input (np.ndarray): Size of input image [w, h]. + size_dst (np.ndarray): Size of output image [w, h]. + size_target (np.ndarray): Size of ROI in input plane [w, h]. + + Returns: + matrix (np.ndarray): A matrix for transformation. 
+ """ + theta = np.deg2rad(theta) + matrix = np.zeros((2, 3), dtype=np.float32) + scale_x = size_dst[0] / size_target[0] + scale_y = size_dst[1] / size_target[1] + matrix[0, 0] = np.cos(theta) * scale_x + matrix[0, 1] = -np.sin(theta) * scale_x + matrix[0, 2] = scale_x * ( + -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] * + np.sin(theta) + 0.5 * size_target[0]) + matrix[1, 0] = np.sin(theta) * scale_y + matrix[1, 1] = np.cos(theta) * scale_y + matrix[1, 2] = scale_y * ( + -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] * + np.cos(theta) + 0.5 * size_target[1]) + return matrix + + +class TopDownEvalAffine(object): + """apply affine transform to image and coords + + Args: + trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether to use Unbiased Data Processing. + records(dict): the dict contained the image and coords + + Returns: + records (dict): contain the image and coords after tranformed + + """ + + def __init__(self, trainsize, use_udp=False): + self.trainsize = trainsize + self.use_udp = use_udp + + def __call__(self, image, im_info): + rot = 0 + imshape = im_info['im_shape'][::-1] + center = im_info['center'] if 'center' in im_info else imshape / 2. + scale = im_info['scale'] if 'scale' in im_info else imshape + if self.use_udp: + trans = get_warp_matrix( + rot, center * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + else: + trans = get_affine_transform(center, scale, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + + return image, im_info + + +class Compose: + def __init__(self, transforms): + self.transforms = [] + for op_info in transforms: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + self.transforms.append(eval(op_type)(**new_op_info)) + + def __call__(self, img_path): + img, im_info = decode_image(img_path) + for t in self.transforms: + img, im_info = t(img, im_info) + inputs = copy.deepcopy(im_info) + inputs['image'] = img + return inputs diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py index fcff742..63273a8 100644 --- a/src/models/ocr_model/utils/inference.py +++ b/src/models/ocr_model/utils/inference.py @@ -13,8 +13,8 @@ from models.globals import MAX_TOKEN_SIZE def inference( model: TexTeller, tokenizer: RobertaTokenizerFast, - imgs: Union[List[str], List[np.ndarray]], - use_cuda: bool, + imgs_path: Union[List[str], List[np.ndarray]], + inf_mode: str = 'cpu', num_beams: int = 1, ) -> List[str]: model.eval() @@ -26,9 +26,8 @@ def inference( imgs = inference_transform(imgs) pixel_values = torch.stack(imgs) - if use_cuda: - model = model.to('cuda') - pixel_values = pixel_values.to('cuda') + model = model.to(inf_mode) + pixel_values = pixel_values.to(inf_mode) generate_config = GenerationConfig( max_new_tokens=MAX_TOKEN_SIZE, diff --git a/src/rec_infer_from_crop_imgs.py b/src/rec_infer_from_crop_imgs.py new file mode 100644 index 0000000..1cacbc6 --- /dev/null +++ b/src/rec_infer_from_crop_imgs.py @@ -0,0 +1,59 @@ +import os +import argparse +import cv2 as cv +from pathlib import Path +from utils import to_katex +from models.ocr_model.utils.inference import inference as latex_inference +from models.ocr_model.model.TexTeller import TexTeller + + +if __name__ == '__main__': + os.chdir(Path(__file__).resolve().parent) + parser = 
argparse.ArgumentParser()
+    parser.add_argument(
+        '--img-dir',
+        type=str,
+        default='./subimages',   # per the README, detection crops land in TexTeller/src/subimages
+        help='path to the directory of cropped formula images'
+    )
+    parser.add_argument(
+        '--output-dir',
+        type=str,
+        default='./results',     # per the README, results are saved to TexTeller/src/results
+        help='directory in which to save the recognition results'
+    )
+    parser.add_argument(
+        '--inference-mode',
+        type=str,
+        default='cpu',
+        help='Inference mode, select one of cpu, cuda, or mps'
+    )
+    parser.add_argument(
+        '--num-beam',
+        type=int,
+        default=1,
+        help='number of beams to use for decoding'
+    )
+
+    args = parser.parse_args()
+
+    print('Loading model and tokenizer...')
+    latex_rec_model = TexTeller.from_pretrained()
+    tokenizer = TexTeller.get_tokenizer()
+    print('Model and tokenizer loaded.')
+
+    # Create the output directory if it doesn't exist
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Loop through all images in the input directory
+    for filename in os.listdir(args.img_dir):
+        img_path = os.path.join(args.img_dir, filename)
+        img = cv.imread(img_path)
+
+        if img is not None:
+            print(f'Inference for {filename}...')
+            res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
+            res = to_katex(res[0])
+
+            # Save the recognition result to a text file
+            output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
+            with open(output_file, 'w') as f:
+                f.write(res)
+
+            print(f'Result saved to {output_file}')
+        else:
+            print(f"Warning: Could not read image {img_path}. Skipping...")
diff --git a/src/server.py b/src/server.py
index 11adfa3..520d908 100644
--- a/src/server.py
+++ b/src/server.py
@@ -23,8 +23,8 @@
 parser.add_argument('--num_replicas', type=int, default=1)
 parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
 parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
-parser.add_argument('--use_cuda', action='store_true', default=False)
-parser.add_argument('--num_beam', type=int, default=1)
+parser.add_argument('--inference-mode', type=str, default='cpu')
+parser.add_argument('--num_beams', type=int, default=1)
 args = parser.parse_args()
 
-if args.ngpu_per_replica > 0 and not args.use_cuda:
+if args.ngpu_per_replica > 0 and args.inference_mode != 'cuda':
@@ -43,18 +43,21 @@ class TexTellerServer:
         self,
         checkpoint_path: str,
         tokenizer_path: str,
-        use_cuda: bool = False,
-        num_beam: int = 1
+        inf_mode: str = 'cpu',
+        num_beams: int = 1
     ) -> None:
         self.model = TexTeller.from_pretrained(checkpoint_path)
         self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
-        self.use_cuda = use_cuda
-        self.num_beam = num_beam
+        self.inf_mode = inf_mode
+        self.num_beams = num_beams
 
-        self.model = self.model.to('cuda') if use_cuda else self.model
+        self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
 
     def predict(self, image_nparray) -> str:
-        return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0]
+        return inference(
+            self.model, self.tokenizer, [image_nparray],
+            inf_mode=self.inf_mode, num_beams=self.num_beams
+        )[0]
 
 
 @serve.deployment()
@@ -78,7 +81,11 @@ if __name__ == '__main__':
     tknz_dir = args.tokenizer_dir
 
     serve.start(http_options={"port": args.server_port})
-    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
+    texteller_server = TexTellerServer.bind(
+        ckpt_dir, tknz_dir,
+        inf_mode=args.inference_mode,
+        num_beams=args.num_beams
+    )
 
     ingress = Ingress.bind(texteller_server)
     ingress_handle = serve.run(ingress, route_prefix="/predict")
diff --git a/src/start_web.bat b/src/start_web.bat
new file mode 100644
index 0000000..e235cca
--- /dev/null
+++ b/src/start_web.bat
@@ -0,0 +1,9 @@
+@echo off
+SETLOCAL ENABLEEXTENSIONS
+
+set CHECKPOINT_DIR=default
+set 
TOKENIZER_DIR=default + +streamlit run web.py + +ENDLOCAL diff --git a/src/start_web.sh b/src/start_web.sh index 7dab5d2..6ec8f7b 100755 --- a/src/start_web.sh +++ b/src/start_web.sh @@ -1,10 +1,7 @@ #!/usr/bin/env bash set -exu -export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-460000" -# export CHECKPOINT_DIR="default" -export TOKENIZER_DIR="/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas" -export USE_CUDA=True # True or False (case-sensitive) -export NUM_BEAM=3 +export CHECKPOINT_DIR="default" +export TOKENIZER_DIR="default" streamlit run web.py diff --git a/src/web.py b/src/web.py index 5eb2413..379a609 100644 --- a/src/web.py +++ b/src/web.py @@ -6,16 +6,22 @@ import shutil import streamlit as st from PIL import Image +from streamlit_paste_button import paste_image_button as pbutton from models.ocr_model.utils.inference import inference from models.ocr_model.model.TexTeller import TexTeller from utils import to_katex +st.set_page_config( + page_title="TexTeller", + page_icon="🧮" +) + html_string = '''

- - TexTeller - + + 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 +

''' @@ -35,8 +41,6 @@ fail_gif_html = ''' ''' - - @st.cache_resource def get_model(): return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR']) @@ -52,6 +56,12 @@ def get_image_base64(img_file): img.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode() +def on_file_upload(): + st.session_state["UPLOADED_FILE_CHANGED"] = True + +def change_side_bar(): + st.session_state["CHANGE_SIDEBAR_FLAG"] = True + model = get_model() tokenizer = get_tokenizer() @@ -59,37 +69,106 @@ if "start" not in st.session_state: st.session_state["start"] = 1 st.toast('Hooray!', icon='🎉') +if "UPLOADED_FILE_CHANGED" not in st.session_state: + st.session_state["UPLOADED_FILE_CHANGED"] = False -# ============================ pages =============================== # +if "CHANGE_SIDEBAR_FLAG" not in st.session_state: + st.session_state["CHANGE_SIDEBAR_FLAG"] = False + +# ============================ begin sidebar =============================== # + +with st.sidebar: + num_beams = 1 + inf_mode = 'cpu' + + st.markdown("# 🔨️ Config") + st.markdown("") + + model_type = st.selectbox( + "Model type", + ("TexTeller", "None"), + on_change=change_side_bar + ) + if model_type == "TexTeller": + num_beams = st.number_input( + 'Number of beams', + min_value=1, + max_value=20, + step=1, + on_change=change_side_bar + ) + + inf_mode = st.radio( + "Inference mode", + ("cpu", "cuda", "mps"), + on_change=change_side_bar + ) + +# ============================ end sidebar =============================== # + + + +# ============================ begin pages =============================== # st.markdown(html_string, unsafe_allow_html=True) -uploaded_file = st.file_uploader("",type=['jpg', 'png', 'pdf']) +uploaded_file = st.file_uploader( + " ", + type=['jpg', 'png'], + on_change=on_file_upload +) + +paste_result = pbutton( + label="📋 Paste an image", + background_color="#5BBCFF", + hover_background_color="#3498db", +) +st.write("") + +if st.session_state["CHANGE_SIDEBAR_FLAG"] == True: + st.session_state["CHANGE_SIDEBAR_FLAG"] = False +elif uploaded_file or paste_result.image_data is not None: + if st.session_state["UPLOADED_FILE_CHANGED"] == False and paste_result.image_data is not None: + uploaded_file = io.BytesIO() + paste_result.image_data.save(uploaded_file, format='PNG') + uploaded_file.seek(0) + + if st.session_state["UPLOADED_FILE_CHANGED"] == True: + st.session_state["UPLOADED_FILE_CHANGED"] = False -if uploaded_file: img = Image.open(uploaded_file) temp_dir = tempfile.mkdtemp() png_file_path = os.path.join(temp_dir, 'image.png') img.save(png_file_path, 'PNG') - img_base64 = get_image_base64(uploaded_file) + with st.container(height=300): + img_base64 = get_image_base64(uploaded_file) + st.markdown(f""" + +
+ Input image +
+ """, unsafe_allow_html=True) st.markdown(f"""
- Input image

Input image ({img.height}✖️{img.width})

""", unsafe_allow_html=True) @@ -102,15 +181,28 @@ if uploaded_file: model, tokenizer, [png_file_path], - True if os.environ['USE_CUDA'] == 'True' else False, - int(os.environ['NUM_BEAM']) + inf_mode=inf_mode, + num_beams=num_beams )[0] st.success('Completed!', icon="✅") st.markdown(suc_gif_html, unsafe_allow_html=True) katex_res = to_katex(TexTeller_result) - st.text_area(":red[Predicted formula]", katex_res, height=150) + st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150) st.latex(katex_res) + st.write("") + st.write("") + + with st.expander(":star2: :gray[Tips for better results]"): + st.markdown(''' + * :mag_right: Use a clear and high-resolution image. + * :scissors: Crop images as accurately as possible. + * :jigsaw: Split large multi line formulas into smaller ones. + * :page_facing_up: Use images with **white background and black text** as much as possible. + * :book: Use a font with good readability. + ''') shutil.rmtree(temp_dir) -# ============================ pages =============================== # + paste_result.image_data = None + +# ============================ end pages =============================== #