删除了resiezer中inference.py里面无用的代码
This commit is contained in:
@@ -11,36 +11,6 @@ from .utils import preprocess_fn
|
|||||||
from munch import Munch
|
from munch import Munch
|
||||||
|
|
||||||
|
|
||||||
def load_resizer():
|
|
||||||
model = Resizer.from_pretrained('/home/lhy/code/TeXify/src/models/resizer/train/res_wo_sigmoid_train_result_v2/checkpoint-96000')
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_teller():
|
|
||||||
arguments = Munch(
|
|
||||||
{
|
|
||||||
'config': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex/config.yaml',
|
|
||||||
'checkpoint': '/home/lhy/code/LaTeX-OCR/pix2tex/model/checkpoints/pix2tex_v1/pix2tex_v1_e30_step4265.pth',
|
|
||||||
'no_cuda': False,
|
|
||||||
'no_resize': False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def inference_v2(img: Image):
|
|
||||||
# img = img.convert('RGB') if img.format == 'PNG' else img
|
|
||||||
# processed_img = preprocess_fn({"pixel_values": [img]})
|
|
||||||
|
|
||||||
# resizer = load_resizer(resizer_path)
|
|
||||||
# inpu = torch.stack(processed_img['pixel_values'])
|
|
||||||
# pred_size = resizer(inpu)
|
|
||||||
|
|
||||||
# teller = load_teller(teller_path)
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def inference(args):
|
def inference(args):
|
||||||
img = Image.open(args.image)
|
img = Image.open(args.image)
|
||||||
img = img.convert('RGB') if img.format == 'PNG' else img
|
img = img.convert('RGB') if img.format == 'PNG' else img
|
||||||
@@ -50,7 +20,7 @@ def inference(args):
|
|||||||
model = Resizer.from_pretrained(ckt_path)
|
model = Resizer.from_pretrained(ckt_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
inpu = torch.stack(processed_img['pixel_values'])
|
inpu = torch.stack(processed_img['pixel_values'])
|
||||||
pred = model(inpu)
|
pred = model(inpu) * 1.25
|
||||||
print(pred)
|
print(pred)
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|||||||
Reference in New Issue
Block a user