完成了web,ray server,重构了代码

This commit is contained in:
三洋三洋
2024-02-08 13:48:34 +00:00
parent 07c4c3dc01
commit 04b99b8451
20 changed files with 245 additions and 57 deletions

16
src/client_demo.py Normal file
View File

@@ -0,0 +1,16 @@
import requests

# URL of the prediction endpoint (see src/server.py; served on port 9900).
URL = "http://127.0.0.1:9900/predict"


def request_prediction(img_path: str, url: str = URL, timeout: float = 30.0) -> str:
    """POST an image path to the TexTeller server and return the response body.

    Args:
        img_path: Path to the image file, as seen from the server's filesystem.
        url: Prediction endpoint URL.
        timeout: Seconds before giving up; without a timeout a dead server
            would hang this client forever.

    Returns:
        The server's response text (the predicted LaTeX string).

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.post(url, json={"img_path": img_path}, timeout=timeout)
    # Fail loudly on HTTP errors instead of printing an error page.
    response.raise_for_status()
    return response.text


if __name__ == '__main__':
    # Replace with the path of the image you want to predict.
    img_path = "/home/lhy/code/TeXify/src/7.png"
    print(request_prediction(img_path))

View File

@@ -0,0 +1,40 @@
import os
import argparse
from pathlib import Path

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

if __name__ == '__main__':
    # Command-line entry point for single-image OCR inference.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '-cuda',
        default=False,
        action='store_true',
        help='use cuda or not'
    )
    # BUG FIX: a hard-coded debug argv list was previously passed to
    # parse_args([...]), which silently ignored the real command line.
    args = parser.parse_args()

    # Run relative to this file so the relative checkpoint/tokenizer
    # paths below resolve regardless of the caller's working directory.
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    model = TexTeller.from_pretrained('./models/ocr_model/model_checkpoint')
    tokenizer = TexTeller.get_tokenizer('./models/tokenizer/roberta-tokenizer-550K')

    # inference() takes a batch of paths; wrap the single image in a list.
    img_path = [args.img]
    res = inference(model, tokenizer, img_path, args.cuda)
    print(res[0])

View File

@@ -1,34 +0,0 @@
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List

from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from .utils.helpers import convert2rgb
from ...globals import MAX_TOKEN_SIZE


def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
    """Run batched OCR inference and decode predictions to LaTeX strings.

    Args:
        model: A loaded TexTeller vision-encoder-decoder model.
        imgs_path: Paths of the images to recognize.
        tokenizer: Tokenizer used both for generation config and decoding.

    Returns:
        One decoded LaTeX string per input image path.
    """
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    # Stack per-image tensors into a single (batch, C, H, W) batch.
    pixel_values = torch.stack(imgs)
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res

# BUG FIX: the previous `if __name__ == '__main__': inference()` guard called
# inference() with no arguments, which always raised TypeError; it was removed.

View File

@@ -1,6 +1,7 @@
from PIL import Image from PIL import Image
from pathlib import Path
from ....globals import ( from models.globals import (
VOCAB_SIZE, VOCAB_SIZE,
OCR_IMG_SIZE, OCR_IMG_SIZE,
OCR_IMG_CHANNELS, OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
@classmethod @classmethod
def from_pretrained(cls, model_path: str): def from_pretrained(cls, model_path: str):
return VisionEncoderDecoderModel.from_pretrained(model_path) model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod @classmethod
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast: def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
return RobertaTokenizerFast.from_pretrained(tokenizer_path) tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
if __name__ == "__main__": if __name__ == "__main__":
# texteller = TexTeller() # texteller = TexTeller()
from ..inference import inference from ..utils.inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500') model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

View File

@@ -11,7 +11,7 @@ from .training_args import CONFIG
from ..model.TexTeller import TexTeller from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric from ..utils.metrics import bleu_metric
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -82,10 +82,10 @@ if __name__ == '__main__':
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller() # model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000') model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
enable_train = True enable_train = False
enable_evaluate = False enable_evaluate = True
if enable_train: if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate: if enable_evaluate:

View File

@@ -4,12 +4,14 @@ from typing import List
from PIL import Image from PIL import Image
def convert2rgb(image_paths: List[str]) -> List[Image.Image]: def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
# 输出的np.ndarray的格式为[H, W, C]通道在第三维通道的排列顺序为RGB
processed_images = [] processed_images = []
for path in image_paths: for path in image_paths:
# 读取图片 # 读取图片
image = cv2.imread(path, cv2.IMREAD_UNCHANGED) image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None: if image is None:
print(f"Image at {path} could not be read.") print(f"Image at {path} could not be read.")
continue continue
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
# 如果是 BGR (3通道), 转换为 RGB # 如果是 BGR (3通道), 转换为 RGB
elif channels == 3: elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(Image.fromarray(image)) processed_images.append(image)
return processed_images return processed_images

View File

@@ -0,0 +1,39 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    """Run batched OCR inference and decode predictions to LaTeX strings.

    Args:
        model: A loaded TexTeller vision-encoder-decoder model.
        tokenizer: Tokenizer used both for generation config and decoding.
        imgs_path: Paths of the images to recognize.
        use_cuda: If True, move both model and inputs to the default CUDA device.
        num_beams: Beam-search width for generation (1 = greedy).

    Returns:
        One decoded LaTeX string per input image path.
    """
    model.eval()
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    # Stack per-image tensors into a single (batch, C, H, W) batch.
    pixel_values = torch.stack(imgs)
    if use_cuda:
        # Model and inputs must live on the same device.
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    # BUG FIX: disable autograd during generation -- model.eval() alone does
    # not stop gradient graphs from being built, wasting memory and time.
    with torch.no_grad():
        pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res

View File

@@ -7,7 +7,7 @@ from torchvision.transforms import v2
from typing import List, Union from typing import List, Union
from PIL import Image from PIL import Image
from ....globals import ( from ...globals import (
OCR_IMG_CHANNELS, OCR_IMG_CHANNELS,
OCR_IMG_SIZE, OCR_IMG_SIZE,
OCR_FIX_SIZE, OCR_FIX_SIZE,

View File

@@ -10,7 +10,7 @@ from transformers import (
from ..utils import preprocess_fn from ..utils import preprocess_fn
from ..model.Resizer import Resizer from ..model.Resizer import Resizer
from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE from ...globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
def train(): def train():

View File

@@ -2,7 +2,7 @@ import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
from PIL import Image, ImageChops from PIL import Image, ImageChops
from ....globals import ( from ...globals import (
IMAGE_MEAN, IMAGE_STD, IMAGE_MEAN, IMAGE_STD,
LABEL_RATIO, LABEL_RATIO,
RESIZER_IMG_SIZE, RESIZER_IMG_SIZE,

View File

@@ -1,6 +1,6 @@
from datasets import load_dataset from datasets import load_dataset
from ...ocr_model.model.TexTeller import TexTeller from ...ocr_model.model.TexTeller import TexTeller
from ....globals import VOCAB_SIZE from ...globals import VOCAB_SIZE
if __name__ == '__main__': if __name__ == '__main__':

91
src/server.py Normal file
View File

@@ -0,0 +1,91 @@
import argparse
import time

from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

# ------------------------- command-line interface ------------------------- #
parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str, required=True)
parser.add_argument('-tknz', '--tokenizer_dir', type=str, required=True)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)
# BUG FIX: a hard-coded debug argv list was previously passed to
# parse_args([...]), so real command-line arguments were silently ignored.
args = parser.parse_args()

# A replica must not be granted GPUs unless CUDA execution is enabled.
if args.ngpu_per_replica > 0 and not args.use_cuda:
    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")
@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    """Ray Serve deployment holding one TexTeller model replica for OCR inference."""

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        use_cuda: bool = False,
        num_beam: int = 1
    ) -> None:
        """Load model and tokenizer once per replica; optionally move model to GPU."""
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.use_cuda = use_cuda
        self.num_beam = num_beam
        # Move the model to CUDA only when requested; inputs are moved inside inference().
        self.model = self.model.to('cuda') if use_cuda else self.model

    def predict(self, image_path: str) -> str:
        """Return the predicted LaTeX string for a single image path."""
        # inference() works on batches; wrap the path and unwrap the result.
        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]
@serve.deployment()
class Ingress:
    """HTTP entry point that forwards JSON requests to the TexTellerServer deployment."""

    def __init__(self, texteller_server: DeploymentHandle) -> None:
        # Handle to the model deployment; calls on it are dispatched remotely.
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        # Expects a JSON body of the form {"img_path": "..."}.
        msg = await request.json()
        img_path: str = msg['img_path']
        pred = await self.texteller_server.predict.remote(img_path)
        return pred
if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    # Start a Ray cluster with the HTTP proxy on the requested port.
    serve.start(http_options={"port": args.server_port})
    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
    ingress = Ingress.bind(texteller_server)
    ingress_handle = serve.run(ingress, route_prefix="/predict")

    # Keep the driver process alive so the deployments stay up.
    while True:
        time.sleep(1)

9
src/start_web.sh Executable file
View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
# Launch the Streamlit web UI; web.py reads the model/tokenizer locations
# and inference settings from these environment variables.
set -exu

export CHECKPOINT_DIR=/home/lhy/code/TeXify/src/models/ocr_model/model_checkpoint
export TOKENIZER_DIR=/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550K
# BUG FIX: web.py evaluates bool(os.environ['USE_CUDA']), and ANY non-empty
# string -- including the former "False" -- is truthy in Python, which
# spuriously enabled CUDA. Export an empty string to keep the flag off;
# set USE_CUDA=1 to enable CUDA.
export USE_CUDA=
export NUM_BEAM=3

streamlit run web.py

View File

@@ -1,14 +1,25 @@
import streamlit as st import os
import io import io
import base64 import base64
import requests import tempfile
import streamlit as st
from PIL import Image from PIL import Image
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
def post_image(server_url, img_rb):
response = requests.post(server_url, files={'image': img_rb})
return response.text
@st.cache_resource
def get_model():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@st.cache_resource
def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
model = get_model()
tokenizer = get_tokenizer()
# ============================ pages =============================== # # ============================ pages =============================== #
# 使用 Markdown 和 HTML 将标题居中 # 使用 Markdown 和 HTML 将标题居中
@@ -41,6 +52,12 @@ uploaded_file = st.file_uploader("",type=['jpg', 'png'])
if uploaded_file: if uploaded_file:
# 打开上传图片 # 打开上传图片
img = Image.open(uploaded_file) img = Image.open(uploaded_file)
# 使用tempfile创建临时目录
temp_dir = tempfile.mkdtemp()
png_file_path = os.path.join(temp_dir, 'image.png')
img.save(png_file_path, 'PNG')
# st.image(uploaded_file, caption=f"Input image ({img.height}✖️{img.width})") # st.image(uploaded_file, caption=f"Input image ({img.height}✖️{img.width})")
# 将 BytesIO 对象转换为 Base64 编码 # 将 BytesIO 对象转换为 Base64 编码
@@ -78,10 +95,15 @@ if uploaded_file:
# 预测 # 预测
with st.spinner("Predicting..."): with st.spinner("Predicting..."):
# 预测结果 # 预测结果
server_url = 'http://localhost:8000/' server_url = 'http://localhost:9900/'
uploaded_file.seek(0) uploaded_file.seek(0)
TeXTeller_result = post_image(server_url, uploaded_file) TeXTeller_result = inference(
TeXTeller_result = r"\begin{align*}" + '\n' + TeXTeller_result + '\n' + r'\end{align*}' model,
tokenizer,
[png_file_path],
bool(os.environ['USE_CUDA']),
int(os.environ['NUM_BEAM'])
)[0]
# tab1, tab2 = st.tabs(["✨TeXTeller✨", "pix2tex:gray[(9.6K⭐)]"]) # tab1, tab2 = st.tabs(["✨TeXTeller✨", "pix2tex:gray[(9.6K⭐)]"])
tab1, tab2 = st.tabs(["🔥👁️", "pix2tex:gray[(9.6K⭐)]"]) tab1, tab2 = st.tabs(["🔥👁️", "pix2tex:gray[(9.6K⭐)]"])
# with st.container(border=True): # with st.container(border=True):
@@ -91,4 +113,4 @@ if uploaded_file:
st.code(TeXTeller_result, language='latex') st.code(TeXTeller_result, language='latex')
st.success('Done!') st.success('Done!')
# ============================ pages =============================== # # ============================ pages =============================== #