diff --git a/src/models/resizer/inference.py b/src/models/resizer/inference.py index 92b1ba3..a9c6eeb 100644 --- a/src/models/resizer/inference.py +++ b/src/models/resizer/inference.py @@ -11,36 +11,6 @@ from .utils import preprocess_fn from munch import Munch -def load_resizer(): - model = Resizer.from_pretrained('/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000') - model.eval() - return model - - -def load_teller(): - arguments = Munch( - { - 'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml', - 'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth', - 'no_cuda': False, - 'no_resize': False - } -) - ... - - -def inference_v2(img: Image): - # img = img.convert('RGB') if img.format == 'PNG' else img - # processed_img = preprocess_fn({"pixel_values": [img]}) - - # resizer = load_resizer(resizer_path) - # inpu = torch.stack(processed_img['pixel_values']) - # pred_size = resizer(inpu) - - # teller = load_teller(teller_path) - ... - - def inference(args): img = Image.open(args.image) img = img.convert('RGB') if img.format == 'PNG' else img @@ -50,7 +20,7 @@ def inference(args): model = Resizer.from_pretrained(ckt_path) model.eval() inpu = torch.stack(processed_img['pixel_values']) - pred = model(inpu) + pred = model(inpu) * 1.25 print(pred) ...