Finished the web demo and the Ray Serve server; refactored the code

三洋三洋
2024-02-08 13:48:34 +00:00
parent 07c4c3dc01
commit 04b99b8451
20 changed files with 245 additions and 57 deletions

src/client_demo.py (new file)

@@ -0,0 +1,16 @@
import requests

# URL of the prediction service
url = "http://127.0.0.1:9900/predict"
# Replace with the path of the image you want to predict
img_path = "/home/lhy/code/TeXify/src/7.png"
# Build the request payload
data = {"img_path": img_path}
# Send the POST request
response = requests.post(url, json=data)
# Print the response
print(response.text)
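Note that the demo sends a filesystem path, so client and server must share a filesystem. A slightly more defensive variant of the same request (a sketch; the endpoint and payload shape come from src/server.py below, the image path is hypothetical):

import requests

url = "http://127.0.0.1:9900/predict"
try:
    # timeout guards against a hung server; raise_for_status surfaces HTTP errors
    response = requests.post(url, json={"img_path": "/path/to/formula.png"}, timeout=60)
    response.raise_for_status()
    print(response.text)  # the predicted LaTeX string
except requests.RequestException as e:
    print(f"request failed: {e}")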


@@ -0,0 +1,40 @@
import os
import argparse
from pathlib import Path

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '-cuda',
        default=False,
        action='store_true',
        help='use cuda or not'
    )
    # Debug override: hardcoded arguments instead of reading the real CLI
    args = parser.parse_args([
        '-img', './models/ocr_model/test_img/1.png',
        '-cuda'
    ])

    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    model = TexTeller.from_pretrained('./models/ocr_model/model_checkpoint')
    tokenizer = TexTeller.get_tokenizer('./models/tokenizer/roberta-tokenizer-550K')

    # base = '/home/lhy/code/TeXify/src/models/ocr_model/test_img'
    # img_path = [base + f'/{i}.png' for i in range(7, 12)]
    img_path = [args.img]
    res = inference(model, tokenizer, img_path, args.cuda)
    print(res[0])

src/models/ocr_model/inference.py (deleted; replaced by models/ocr_model/utils/inference.py below)

@@ -1,34 +0,0 @@
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List

from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from .utils.helpers import convert2rgb
from ...globals import MAX_TOKEN_SIZE

def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res

if __name__ == '__main__':
    inference()

src/models/ocr_model/model/TexTeller.py

@@ -1,6 +1,7 @@
 from PIL import Image
+from pathlib import Path
-from ....globals import (
+from models.globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
     OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
     @classmethod
     def from_pretrained(cls, model_path: str):
-        return VisionEncoderDecoderModel.from_pretrained(model_path)
+        model_path = Path(model_path).resolve()
+        return VisionEncoderDecoderModel.from_pretrained(str(model_path))
     @classmethod
     def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
-        return RobertaTokenizerFast.from_pretrained(tokenizer_path)
+        tokenizer_path = Path(tokenizer_path).resolve()
+        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
 if __name__ == "__main__":
     # texteller = TexTeller()
-    from ..inference import inference
+    from ..utils.inference import inference
     model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
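A note on the resolve() change above: a relative checkpoint path is interpreted against the current working directory, which the entry scripts in this commit change via os.chdir(), so the path is pinned to absolute form first. A minimal illustration (hypothetical layout):

import os
from pathlib import Path

p = Path('./models/ocr_model/model_checkpoint').resolve()
os.chdir('/tmp')        # a later chdir no longer affects the resolved path
assert p.is_absolute()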

src/models/ocr_model/train/train.py

@@ -11,7 +11,7 @@ from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
 from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
 from ..utils.metrics import bleu_metric
-from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
+from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
 def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -82,10 +82,10 @@ if __name__ == '__main__':
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
-    enable_train = True
-    enable_evaluate = False
+    enable_train = False
+    enable_evaluate = True
     if enable_train:
         train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
     if enable_evaluate:

src/models/ocr_model/utils/helpers.py

@@ -4,12 +4,14 @@ from typing import List
 from PIL import Image
-def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
+def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
+    # Returned np.ndarray is [H, W, C]; channels along the third axis are in RGB order
     processed_images = []
     for path in image_paths:
         # Read the image
         image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
         if image is None:
             print(f"Image at {path} could not be read.")
             continue
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
         # If BGR (3 channels), convert to RGB
         elif channels == 3:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        processed_images.append(Image.fromarray(image))
+        processed_images.append(image)
     return processed_images

src/models/ocr_model/utils/inference.py (new file)

@@ -0,0 +1,39 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE

def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    model.eval()
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)
    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
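A minimal usage sketch of the new inference API (paths are hypothetical; run from src/ so the models package is importable, as the other scripts in this commit do):

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference

model = TexTeller.from_pretrained('./models/ocr_model/model_checkpoint')
tokenizer = TexTeller.get_tokenizer('./models/tokenizer/roberta-tokenizer-550K')

# imgs_path is a batch; one LaTeX string comes back per image
latex = inference(model, tokenizer, ['./models/ocr_model/test_img/1.png'],
                  use_cuda=False, num_beams=3)
print(latex[0])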

src/models/ocr_model/utils/transforms.py

@@ -7,7 +7,7 @@ from torchvision.transforms import v2
 from typing import List, Union
 from PIL import Image
-from ....globals import (
+from ...globals import (
     OCR_IMG_CHANNELS,
     OCR_IMG_SIZE,
     OCR_FIX_SIZE,


@@ -10,7 +10,7 @@ from transformers import (
 from ..utils import preprocess_fn
 from ..model.Resizer import Resizer
-from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
+from ...globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
 def train():


@@ -2,7 +2,7 @@ import torch
 from torchvision.transforms import v2
 from PIL import Image, ImageChops
-from ....globals import (
+from ...globals import (
     IMAGE_MEAN, IMAGE_STD,
     LABEL_RATIO,
     RESIZER_IMG_SIZE,


@@ -1,6 +1,6 @@
 from datasets import load_dataset
 from ...ocr_model.model.TexTeller import TexTeller
-from ....globals import VOCAB_SIZE
+from ...globals import VOCAB_SIZE
 if __name__ == '__main__':

src/server.py (new file)

@@ -0,0 +1,91 @@
import argparse
import time

from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str, required=True)
parser.add_argument('-tknz', '--tokenizer_dir', type=str, required=True)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)

# args = parser.parse_args()
# Debug override: hardcoded arguments instead of reading the real CLI
args = parser.parse_args([
    '--checkpoint_dir', '/home/lhy/code/TeXify/src/models/ocr_model/model_checkpoint',
    '--tokenizer_dir', '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550K',
    '--server_port', '9900',
    '--num_replicas', '1',
    '--ncpu_per_replica', '1.0',
    '--ngpu_per_replica', '0.0',
    # '--use_cuda',
    '--num_beam', '1'
])

if args.ngpu_per_replica > 0 and not args.use_cuda:
    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")

@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        use_cuda: bool = False,
        num_beam: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.use_cuda = use_cuda
        self.num_beam = num_beam
        self.model = self.model.to('cuda') if use_cuda else self.model

    def predict(self, image_path: str) -> str:
        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]

@serve.deployment()
class Ingress:
    def __init__(self, texteller_server: DeploymentHandle) -> None:
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        msg = await request.json()
        img_path: str = msg['img_path']
        pred = await self.texteller_server.predict.remote(img_path)
        return pred

if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"port": args.server_port})  # start the Ray Serve HTTP proxy on --server_port (9900 in the debug override)
    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
    ingress = Ingress.bind(texteller_server)

    ingress_handle = serve.run(ingress, route_prefix="/predict")
    # Keep the driver process alive so the deployments stay up
    while True:
        time.sleep(1)
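Since --ngpu_per_replica is passed straight into ray_actor_options["num_gpus"], which accepts fractional values, several replicas can share one GPU. A hypothetical debug override in the same style as the one above, placing four replicas on two GPUs (paths and values are illustrative, flag names come from the argparse block above):

args = parser.parse_args([
    '--checkpoint_dir', '/path/to/model_checkpoint',
    '--tokenizer_dir', '/path/to/tokenizer',
    '--server_port', '9900',
    '--num_replicas', '4',
    '--ngpu_per_replica', '0.5',
    '--use_cuda',
])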

src/start_web.sh (new executable file)

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -exu
export CHECKPOINT_DIR=/home/lhy/code/TeXify/src/models/ocr_model/model_checkpoint
export TOKENIZER_DIR=/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550K
export USE_CUDA=False
export NUM_BEAM=3
streamlit run web.py
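These exports reach web.py as strings, and bool('False') is True in Python, so web.py compares against the literal 'True' instead of casting. A hypothetical helper if more spellings ever need to be accepted (env_flag is not part of this commit):

import os

def env_flag(name: str, default: str = 'False') -> bool:
    # accept 'True'/'true'/'1'; everything else is False
    return os.environ.get(name, default).strip().lower() in ('true', '1')

use_cuda = env_flag('USE_CUDA')
num_beam = int(os.environ.get('NUM_BEAM', '1'))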

src/web.py

@@ -1,14 +1,25 @@
-import streamlit as st
 import os
 import io
 import base64
 import requests
+import tempfile
+import streamlit as st
 from PIL import Image
+from models.ocr_model.utils.inference import inference
+from models.ocr_model.model.TexTeller import TexTeller
 def post_image(server_url, img_rb):
     response = requests.post(server_url, files={'image': img_rb})
     return response.text
+@st.cache_resource
+def get_model():
+    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
+@st.cache_resource
+def get_tokenizer():
+    return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
+model = get_model()
+tokenizer = get_tokenizer()
 # ============================ pages =============================== #
 # Center the title using Markdown and HTML
@@ -41,6 +52,12 @@ uploaded_file = st.file_uploader("",type=['jpg', 'png'])
 if uploaded_file:
     # Open the uploaded image
     img = Image.open(uploaded_file)
+    # Create a temporary directory with tempfile
+    temp_dir = tempfile.mkdtemp()
+    png_file_path = os.path.join(temp_dir, 'image.png')
+    img.save(png_file_path, 'PNG')
     # st.image(uploaded_file, caption=f"Input image ({img.height}✖️{img.width})")
     # Convert the BytesIO object to Base64
@@ -78,10 +95,15 @@ if uploaded_file:
     # Predict
     with st.spinner("Predicting..."):
         # Prediction result
-        server_url = 'http://localhost:8000/'
+        server_url = 'http://localhost:9900/'
         uploaded_file.seek(0)
-        TeXTeller_result = post_image(server_url, uploaded_file)
-        TeXTeller_result = r"\begin{align*}" + '\n' + TeXTeller_result + '\n' + r'\end{align*}'
+        TeXTeller_result = inference(
+            model,
+            tokenizer,
+            [png_file_path],
+            os.environ['USE_CUDA'] == 'True',  # bool('False') is True, so compare the string
+            int(os.environ['NUM_BEAM'])
+        )[0]
     # tab1, tab2 = st.tabs(["✨TeXTeller✨", "pix2tex:gray[(9.6K⭐)]"])
     tab1, tab2 = st.tabs(["🔥👁️", "pix2tex:gray[(9.6K⭐)]"])
     # with st.container(border=True):
# with st.container(border=True):
@@ -91,4 +113,4 @@ if uploaded_file:
         st.code(TeXTeller_result, language='latex')
     st.success('Done!')
-# ============================ pages =============================== #
\ No newline at end of file
+# ============================ pages =============================== #