Files
TexTeller/src/models/resizer/inference.py
2024-01-23 06:07:09 +00:00

44 lines
1.1 KiB
Python

#!/usr/bin/env python3
import os
import argparse
import torch
from pathlib import Path
from PIL import Image
from .model.Resizer import Resizer
from .utils import preprocess_fn
from munch import Munch
def inference(args):
    """Run the Resizer model on a single image and print the scaled output.

    Args:
        args: Object exposing two attributes: ``image`` (path of the image
            file to open) and ``checkpoint`` (path passed, after being made
            absolute, to ``Resizer.from_pretrained``).

    Returns:
        None. The prediction is printed to stdout.
    """
    # Convert inside the ``with`` so the pixel data is loaded before the file
    # handle is released. The conversion is unconditional: keying it on the
    # file format (PNG only) let grayscale/CMYK/palette images in other
    # formats through unconverted.
    # NOTE(review): assumes preprocess_fn wants 3-channel RGB input, as the
    # original PNG-only conversion suggests -- confirm against utils.
    with Image.open(args.image) as raw_img:
        img = raw_img.convert('RGB')

    processed_img = preprocess_fn({"pixel_values": [img]})

    ckt_path = Path(args.checkpoint).resolve()
    model = Resizer.from_pretrained(ckt_path)
    model.eval()

    batch = torch.stack(processed_img['pixel_values'])

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # NOTE(review): the 1.25 factor is carried over unchanged; its
        # derivation is not visible in this file -- confirm with training code.
        pred = model(batch) * 1.25

    print(pred)
if __name__ == "__main__":
    # Command-line entry point: parse the image and checkpoint paths, run
    # inference from the script's own directory, then restore the caller's
    # working directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-img', '--image', type=str, required=True)
    parser.add_argument('-ckt', '--checkpoint', type=str, required=True)

    # Read the real command line. A hard-coded argument list here would make
    # the required flags above meaningless and tie the script to one machine.
    args = parser.parse_args()

    # Make both paths absolute *before* changing directory, so relative paths
    # given by the user are interpreted against the directory they ran from.
    args.image = str(Path(args.image).resolve())
    args.checkpoint = str(Path(args.checkpoint).resolve())

    cur_dirpath = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)
    try:
        inference(args)
    finally:
        # Restore the original working directory even if inference fails.
        os.chdir(cur_dirpath)