Finish the web UI and the Ray Serve server; refactor the code
16
src/client_demo.py
Normal file
@@ -0,0 +1,16 @@
import requests

# Service URL
url = "http://127.0.0.1:9900/predict"

# Replace with the path of the image you want to predict
img_path = "/home/lhy/code/TeXify/src/7.png"

# Build the request payload
data = {"img_path": img_path}

# Send the POST request
response = requests.post(url, json=data)

# Print the response
print(response.text)
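The demo above prints the raw response body as-is. A minimal hardening sketch (not part of this commit) would check the HTTP status before trusting the result:

# Sketch only: fail loudly on a non-200 response
if response.status_code != 200:
    raise RuntimeError(f"prediction failed with HTTP {response.status_code}")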
@@ -0,0 +1,40 @@
import os
import argparse

from pathlib import Path
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '-cuda',
        default=False,
        action='store_true',
        help='use cuda or not'
    )

    args = parser.parse_args([
        '-img', './models/ocr_model/test_img/1.png',
        '-cuda'
    ])

    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    model = TexTeller.from_pretrained('./models/ocr_model/model_checkpoint')
    tokenizer = TexTeller.get_tokenizer('./models/tokenizer/roberta-tokenizer-550K')

    # base = '/home/lhy/code/TeXify/src/models/ocr_model/test_img'
    # img_path = [base + f'/{i}.png' for i in range(7, 12)]
    img_path = [args.img]

    res = inference(model, tokenizer, img_path, args.cuda)
    print(res[0])
@@ -1,34 +0,0 @@
import torch
import cv2
import numpy as np

from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List

from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from .utils.helpers import convert2rgb
from ...globals import MAX_TOKEN_SIZE


def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res


if __name__ == '__main__':
    inference()
@@ -1,6 +1,7 @@
 from PIL import Image
+from pathlib import Path

-from ....globals import (
+from models.globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
     OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
 
     @classmethod
     def from_pretrained(cls, model_path: str):
-        return VisionEncoderDecoderModel.from_pretrained(model_path)
+        model_path = Path(model_path).resolve()
+        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

     @classmethod
     def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
-        return RobertaTokenizerFast.from_pretrained(tokenizer_path)
+        tokenizer_path = Path(tokenizer_path).resolve()
+        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))


 if __name__ == "__main__":
     # texteller = TexTeller()
-    from ..inference import inference
+    from ..utils.inference import inference
     model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

@@ -11,7 +11,7 @@ from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
 from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
 from ..utils.metrics import bleu_metric
-from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
+from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


 def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -82,10 +82,10 @@ if __name__ == '__main__':
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')

-    enable_train = True
-    enable_evaluate = False
+    enable_train = False
+    enable_evaluate = True
     if enable_train:
         train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
     if enable_evaluate:
@@ -4,12 +4,14 @@ from typing import List
 from PIL import Image


-def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
+def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
+    # The returned np.ndarray is [H, W, C] (channels in the third dim), in RGB order
     processed_images = []

     for path in image_paths:
         # Read the image
         image = cv2.imread(path, cv2.IMREAD_UNCHANGED)

         if image is None:
             print(f"Image at {path} could not be read.")
             continue
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
         # If BGR (3 channels), convert to RGB
         elif channels == 3:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        processed_images.append(Image.fromarray(image))
+        processed_images.append(image)

     return processed_images
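After this change convert2rgb hands back raw RGB arrays instead of PIL images, so downstream callers (here, inference_transform) must accept ndarrays. A quick sketch of the new contract, using a hypothetical image path:

from models.ocr_model.utils.helpers import convert2rgb  # import path per this commit

img = convert2rgb(['example.png'])[0]        # 'example.png' is a placeholder path
assert img.ndim == 3 and img.shape[2] == 3   # [H, W, C], RGB channel order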
39
src/models/ocr_model/utils/inference.py
Normal file
@@ -0,0 +1,39 @@
import torch

from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    model.eval()
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
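For reference, a minimal way to exercise the new signature, reusing the checkpoint and tokenizer paths that appear elsewhere in this commit:

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference

model = TexTeller.from_pretrained('./models/ocr_model/model_checkpoint')
tokenizer = TexTeller.get_tokenizer('./models/tokenizer/roberta-tokenizer-550K')
# Single-image batch; use_cuda=False keeps everything on CPU
print(inference(model, tokenizer, ['./models/ocr_model/test_img/1.png'], use_cuda=False)[0])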
@@ -7,7 +7,7 @@ from torchvision.transforms import v2
 from typing import List, Union
 from PIL import Image

-from ....globals import (
+from ...globals import (
     OCR_IMG_CHANNELS,
     OCR_IMG_SIZE,
     OCR_FIX_SIZE,
@@ -10,7 +10,7 @@ from transformers import (
 from ..utils import preprocess_fn
 from ..model.Resizer import Resizer
-from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
+from ...globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE


 def train():
@@ -2,7 +2,7 @@ import torch
 from torchvision.transforms import v2

 from PIL import Image, ImageChops
-from ....globals import (
+from ...globals import (
     IMAGE_MEAN, IMAGE_STD,
     LABEL_RATIO,
     RESIZER_IMG_SIZE,
@@ -1,6 +1,6 @@
 from datasets import load_dataset
 from ...ocr_model.model.TexTeller import TexTeller
-from ....globals import VOCAB_SIZE
+from ...globals import VOCAB_SIZE


 if __name__ == '__main__':
91
src/server.py
Normal file
@@ -0,0 +1,91 @@
import argparse
import time

from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str, required=True)
parser.add_argument('-tknz', '--tokenizer_dir', type=str, required=True)

parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)

parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)

# args = parser.parse_args()
args = parser.parse_args([
    '--checkpoint_dir', '/home/lhy/code/TeXify/src/models/ocr_model/model_checkpoint',
    '--tokenizer_dir', '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550K',

    '--server_port', '9900',
    '--num_replicas', '1',
    '--ncpu_per_replica', '1.0',
    '--ngpu_per_replica', '0.0',

    # '--use_cuda',
    '--num_beam', '1'
])

if args.ngpu_per_replica > 0 and not args.use_cuda:
    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        use_cuda: bool = False,
        num_beam: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.use_cuda = use_cuda
        self.num_beam = num_beam

        self.model = self.model.to('cuda') if use_cuda else self.model

    def predict(self, image_path: str) -> str:
        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]


@serve.deployment()
class Ingress:
    def __init__(self, texteller_server: DeploymentHandle) -> None:
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        msg = await request.json()
        img_path: str = msg['img_path']
        pred = await self.texteller_server.predict.remote(img_path)
        return pred


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"port": args.server_port})  # Start a Ray cluster; HTTP is served on port 9900
    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
    ingress = Ingress.bind(texteller_server)

    ingress_handle = serve.run(ingress, route_prefix="/predict")

    while True:
        time.sleep(1)
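The argument list is hard-coded for local testing; restoring the commented-out `args = parser.parse_args()` would let the server be launched from the command line. An illustrative invocation (not from the commit):

# Hypothetical CLI launch once `args = parser.parse_args()` is restored:
#   python server.py -ckpt ./models/ocr_model/model_checkpoint \
#                    -tknz ./models/tokenizer/roberta-tokenizer-550K \
#                    -port 9900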
9
src/start_web.sh
Executable file
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -exu

export CHECKPOINT_DIR=/home/lhy/code/TeXify/src/models/ocr_model/model_checkpoint
export TOKENIZER_DIR=/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550K
export USE_CUDA=False
export NUM_BEAM=3

streamlit run web.py
40
src/web.py
@@ -1,14 +1,25 @@
-import streamlit as st
 import os
 import io
 import base64
 import requests
+import tempfile
+import streamlit as st

 from PIL import Image
+from models.ocr_model.utils.inference import inference
+from models.ocr_model.model.TexTeller import TexTeller

-def post_image(server_url, img_rb):
-    response = requests.post(server_url, files={'image': img_rb})
-    return response.text
+@st.cache_resource
+def get_model():
+    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
+
+@st.cache_resource
+def get_tokenizer():
+    return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
+
+
+model = get_model()
+tokenizer = get_tokenizer()

 # ============================ pages =============================== #
 # Center the title using Markdown and HTML
@@ -41,6 +52,12 @@ uploaded_file = st.file_uploader("",type=['jpg', 'png'])
 if uploaded_file:
     # Open the uploaded image
     img = Image.open(uploaded_file)

+    # Create a temporary directory with tempfile
+    temp_dir = tempfile.mkdtemp()
+    png_file_path = os.path.join(temp_dir, 'image.png')
+    img.save(png_file_path, 'PNG')
+
     # st.image(uploaded_file, caption=f"Input image ({img.height}✖️{img.width})")

     # Convert the BytesIO object to Base64
@@ -78,10 +95,15 @@ if uploaded_file:
     # Predict
     with st.spinner("Predicting..."):
         # Run the prediction
-        server_url = 'http://localhost:8000/'
-        server_url = 'http://localhost:9900/'
-        uploaded_file.seek(0)
-        TeXTeller_result = post_image(server_url, uploaded_file)
-        TeXTeller_result = r"\begin{align*}" + '\n' + TeXTeller_result + '\n' + r'\end{align*}'
+        TeXTeller_result = inference(
+            model,
+            tokenizer,
+            [png_file_path],
+            bool(os.environ['USE_CUDA']),
+            int(os.environ['NUM_BEAM'])
+        )[0]
         # tab1, tab2 = st.tabs(["✨TeXTeller✨", "pix2tex:gray[(9.6K⭐)️]"])
         tab1, tab2 = st.tabs(["🔥👁️", "pix2tex:gray[(9.6K⭐)️]"])
         # with st.container(border=True):
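One caveat worth flagging: `bool(os.environ['USE_CUDA'])` is True for any non-empty string, including the `USE_CUDA=False` exported by start_web.sh. A safer parse would be a sketch like:

# Hypothetical fix, not part of this commit: parse the env var explicitly
use_cuda = os.environ.get('USE_CUDA', 'False').lower() in ('1', 'true', 'yes')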
@@ -91,4 +113,4 @@ if uploaded_file:
             st.code(TeXTeller_result, language='latex')
     st.success('Done!')

-# ============================ pages =============================== #
+# ============================ pages =============================== #