From 0f619b1812bdad2694f7e4e8c44547c051353ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Tue, 23 Jan 2024 06:07:09 +0000 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=BA=86resiezer=E4=B8=ADinf?= =?UTF-8?q?erence.py=E9=87=8C=E9=9D=A2=E6=97=A0=E7=94=A8=E7=9A=84=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/resizer/inference.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/src/models/resizer/inference.py b/src/models/resizer/inference.py index 92b1ba3..a9c6eeb 100644 --- a/src/models/resizer/inference.py +++ b/src/models/resizer/inference.py @@ -11,36 +11,6 @@ from .utils import preprocess_fn from munch import Munch -def load_resizer(): - model = Resizer.from_pretrained('/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000') - model.eval() - return model - - -def load_teller(): - arguments = Munch( - { - 'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml', - 'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth', - 'no_cuda': False, - 'no_resize': False - } -) - ... - - -def inference_v2(img: Image): - # img = img.convert('RGB') if img.format == 'PNG' else img - # processed_img = preprocess_fn({"pixel_values": [img]}) - - # resizer = load_resizer(resizer_path) - # inpu = torch.stack(processed_img['pixel_values']) - # pred_size = resizer(inpu) - - # teller = load_teller(teller_path) - ... - - def inference(args): img = Image.open(args.image) img = img.convert('RGB') if img.format == 'PNG' else img @@ -50,7 +20,7 @@ def inference(args): model = Resizer.from_pretrained(ckt_path) model.eval() inpu = torch.stack(processed_img['pixel_values']) - pred = model(inpu) + pred = model(inpu) * 1.25 print(pred) ...