删除了resiezer中inference.py里面无用的代码

This commit is contained in:
三洋三洋
2024-01-23 06:07:09 +00:00
parent 703ac7441c
commit 0f619b1812

View File

@@ -11,36 +11,6 @@ from .utils import preprocess_fn
from munch import Munch
def load_resizer():
model = Resizer.from_pretrained('/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000')
model.eval()
return model
def load_teller():
arguments = Munch(
{
'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml',
'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth',
'no_cuda': False,
'no_resize': False
}
)
...
def inference_v2(img: Image):
# img = img.convert('RGB') if img.format == 'PNG' else img
# processed_img = preprocess_fn({"pixel_values": [img]})
# resizer = load_resizer(resizer_path)
# inpu = torch.stack(processed_img['pixel_values'])
# pred_size = resizer(inpu)
# teller = load_teller(teller_path)
...
def inference(args):
img = Image.open(args.image)
img = img.convert('RGB') if img.format == 'PNG' else img
@@ -50,7 +20,7 @@ def inference(args):
model = Resizer.from_pretrained(ckt_path)
model.eval()
inpu = torch.stack(processed_img['pixel_values'])
pred = model(inpu)
pred = model(inpu) * 1.25
print(pred)
...