2024-01-15 05:48:36 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
import os
|
|
|
|
|
import argparse
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from .model.Resizer import Resizer
|
|
|
|
|
from .utils import preprocess_fn
|
|
|
|
|
|
|
|
|
|
from munch import Munch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(args):
    """Run a pretrained ``Resizer`` on a single image and print the result.

    Args:
        args: Namespace-like object exposing two attributes:
            ``image`` -- path of the image file to open, and
            ``checkpoint`` -- path that is resolved to an absolute path and
            handed to ``Resizer.from_pretrained``.

    Returns:
        None. The (scaled) model output is written to stdout via ``print``.
    """
    # PIL opens files lazily, so everything that reads pixel data has to run
    # while the handle is still open; the ``with`` block closes it afterwards.
    with Image.open(args.image) as img:
        # Only PNG input is converted to RGB; every other format is passed
        # through unchanged.
        if img.format == 'PNG':
            img = img.convert('RGB')
        processed_img = preprocess_fn({"pixel_values": [img]})
        # Stack the preprocessed entries into one batch tensor.
        batch = torch.stack(processed_img['pixel_values'])

    ckt_path = Path(args.checkpoint).resolve()
    model = Resizer.from_pretrained(ckt_path)
    model.eval()

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # NOTE(review): the 1.25 factor is not explained anywhere in this
        # file -- confirm what it compensates for.
        pred = model(batch) * 1.25

    print(pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the image/checkpoint paths, run
    # inference from the script's own directory, then restore the caller's
    # working directory.
    parser = argparse.ArgumentParser()
    # The defaults are the paths that used to be hard-coded into the
    # parse_args() call, so running the script without arguments behaves as
    # before while explicit flags are now actually honoured.
    parser.add_argument(
        '-img', '--image', type=str,
        default='/home/lhy/code/TeXify/src/models/resizer/foo5_140h.jpg',
    )
    parser.add_argument(
        '-ckt', '--checkpoint', type=str,
        default='/home/lhy/code/TeXify/src/models/resizer/train/train_result_pred_height_v5',
    )
    args = parser.parse_args()

    # Resolve user-supplied paths against the caller's working directory
    # before changing into the script directory below.
    args.image = str(Path(args.image).resolve())
    args.checkpoint = str(Path(args.checkpoint).resolve())

    cur_dirpath = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)
    try:
        inference(args)
    finally:
        # Restore the original working directory even if inference fails.
        os.chdir(cur_dirpath)
|