Initial commit

This commit is contained in:
三洋三洋
2024-01-15 05:48:36 +00:00
commit 126026cb48
12 changed files with 428 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
import os
import argparse
import torch
from pathlib import Path
from PIL import Image
from .model.Resizer import Resizer
from .utils import preprocess_fn
from munch import Munch
def load_resizer():
model = Resizer.from_pretrained('/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000')
model.eval()
return model
def load_teller():
arguments = Munch(
{
'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml',
'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth',
'no_cuda': False,
'no_resize': False
}
)
...
def inference_v2(img: Image):
# img = img.convert('RGB') if img.format == 'PNG' else img
# processed_img = preprocess_fn({"pixel_values": [img]})
# resizer = load_resizer(resizer_path)
# inpu = torch.stack(processed_img['pixel_values'])
# pred_size = resizer(inpu)
# teller = load_teller(teller_path)
...
def inference(args):
img = Image.open(args.image)
img = img.convert('RGB') if img.format == 'PNG' else img
processed_img = preprocess_fn({"pixel_values": [img]})
ckt_path = Path(args.checkpoint).resolve()
model = Resizer.from_pretrained(ckt_path)
model.eval()
inpu = torch.stack(processed_img['pixel_values'])
pred = model(inpu)
print(pred)
...
if __name__ == "__main__":
cur_dirpath = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
parser = argparse.ArgumentParser()
parser.add_argument('-img', '--image', type=str, required=True)
parser.add_argument('-ckt', '--checkpoint', type=str, required=True)
args = parser.parse_args([
'-img', '/home/lhy/code/TeXify/src/models/resizer/foo5_140h.jpg',
'-ckt', '/home/lhy/code/TeXify/src/models/resizer/train/train_result_pred_height_v5'
])
inference(args)
os.chdir(cur_dirpath)