Initial Commit
@@ -32,6 +32,10 @@ OCR_IMG_CHANNELS = 1  # grayscale image
 # Maximum number of tokens in the OCR model's training dataset
 MAX_TOKEN_SIZE = 600
 
+# Random resize ratio range used when training the OCR model
+MAX_RESIZE_RATIO = 1.15
+MIN_RESIZE_RATIO = 0.75
+
 # ============================================================================= #
 
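Note: these two constants bound the per-image scale factor drawn during training augmentation. A minimal sketch of how they are consumed, mirroring the random_resize helper added later in this commit (the cv2/numpy usage here is illustrative, not part of this hunk):

import random
import cv2
import numpy as np

def rescale(img: np.ndarray, minr: float = 0.75, maxr: float = 1.15) -> np.ndarray:
    r = random.uniform(minr, maxr)  # uniform scale factor in [minr, maxr]
    h, w = img.shape[:2]
    # cv2.resize takes (width, height)
    return cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LANCZOS4)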
@@ -8,43 +8,10 @@ from typing import List
 
 from .model.TexTeller import TexTeller
 from .utils.transforms import inference_transform
+from .utils.helpers import convert2rgb
 from ...globals import MAX_TOKEN_SIZE
 
 
-def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
-    processed_images = []
-
-    for path in image_paths:
-        # Read the image
-        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-
-        if image is None:
-            print(f"Image at {path} could not be read.")
-            continue
-
-        # Check whether the image is stored as uint16
-        if image.dtype == np.uint16:
-            raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
-
-        # Get the number of channels
-        channels = 1 if len(image.shape) == 2 else image.shape[2]
-
-        # RGBA (4 channels): convert to RGB
-        if channels == 4:
-            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
-
-        # Single-channel grayscale ('I' mode): convert to RGB
-        elif channels == 1:
-            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-
-        # BGR (3 channels): convert to RGB
-        elif channels == 3:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        processed_images.append(Image.fromarray(image))
-
-    return processed_images
-
-
 def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
     imgs = convert2rgb(imgs_path)
     imgs = inference_transform(imgs)
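For context, the refactored entry point is used roughly like this (checkpoint, tokenizer, and image paths are placeholders, not from this commit):

model = TexTeller.from_pretrained('path/to/checkpoint')
tokenizer = TexTeller.get_tokenizer('path/to/tokenizer')
predictions = inference(model, ['formula1.jpg', 'formula2.jpg'], tokenizer)  # -> List[str] of LaTeX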
Binary file not shown.
@@ -1,13 +1,15 @@
 import os
+import numpy as np
 
 from functools import partial
 from pathlib import Path
 
 from datasets import load_dataset
 from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
 
 from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
-from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
+from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
 from ..utils.metrics import bleu_metric
 from ....globals import MAX_TOKEN_SIZE
 
@@ -38,6 +40,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
         eos_token_id=tokenizer.eos_token_id,
         bos_token_id=tokenizer.bos_token_id,
     )
+    # eval_config['use_cpu'] = True
    eval_config['output_dir'] = 'debug_dir'
    eval_config['predict_with_generate'] = True
    eval_config['predict_with_generate'] = True
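The two token ids above feed the decoder's GenerationConfig, which is already imported in this file. A minimal sketch of that construction, assuming MAX_TOKEN_SIZE caps the decode length (only the two id kwargs are confirmed by this hunk):

generation_config = GenerationConfig(
    max_length=MAX_TOKEN_SIZE,            # assumed cap, not shown in this hunk
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
)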
@@ -52,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
         model,
         seq2seq_config,
 
-        eval_dataset=eval_dataset.select(range(16)),
+        eval_dataset=eval_dataset,
         tokenizer=tokenizer,
         data_collator=collate_fn,
         compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
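Dropping the .select(range(16)) cap means evaluation now runs over the full eval split. With predict_with_generate=True, Seq2SeqTrainer decodes via model.generate, so bleu_metric receives generated token sequences rather than logits. A sketch of the call that consumes this setup, assuming the constructed trainer is bound to a name like `trainer` (metric key name illustrative):

metrics = trainer.evaluate()  # runs generate + bleu_metric on the eval split
print(metrics)                # e.g. {'eval_bleu': ..., 'eval_runtime': ...}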
@@ -70,23 +73,27 @@ if __name__ == '__main__':
     os.chdir(script_dirpath)
 
 
+    # dataset = load_dataset(
+    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
+    #     'cleaned_formulas'
+    # )['train']
     dataset = load_dataset(
         '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
         'cleaned_formulas'
-    )['train']
+    )['train'].select(range(1000))
 
-    pause = dataset[0]['image']
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
 
     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
-    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
-    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
+    # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False)
+    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False)
+    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
 
     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # model = TexTeller()
     model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500')
 
     enable_train = False
     enable_evaluate = True
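A note on the two dataset stages above: Dataset.map runs tokenize_fn eagerly and caches its output, while with_transform applies img_transform_fn lazily on every access, which is what lets random augmentations be re-drawn each epoch. Condensed, with names as in this file:

map_fn = partial(tokenize_fn, tokenizer=tokenizer)
# eager + cached: tokenization happens once
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names)
# lazy + per-access: the transform re-runs every time a sample is read
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)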
@@ -1,10 +1,12 @@
 import torch
-from datasets import load_dataset
+import numpy as np
 
 from functools import partial
+from datasets import load_dataset
 
 from transformers import DataCollatorForLanguageModeling
 from typing import List, Dict, Any
-from ...ocr_model.model.TexTeller import TexTeller
+from ..model.TexTeller import TexTeller
 from .transforms import train_transform
 
@@ -19,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val):
 def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
     assert tokenizer is not None, 'tokenizer should not be None'
     tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
-    tokenized_formula['pixel_values'] = samples['image']
+    tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']]
     return tokenized_formula
 
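For reference, the mapping returned by tokenize_fn now carries raw pixel arrays next to the token fields; a sketch of its shape for a one-sample batch, assuming pil_image is any loaded PIL image:

batch = tokenize_fn({'latex_formula': [r'\frac{a}{b}'], 'image': [pil_image]}, tokenizer=tokenizer)
# batch['input_ids']           -> [[bos, ..., eos]]       one id list per formula
# batch['special_tokens_mask'] -> [[1, 0, ..., 1]]
# batch['pixel_values']        -> [np.ndarray (H, W, C)]  one array per image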
@@ -36,14 +38,13 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
 
     # Left-shift labels and decoder_attention_mask
     batch['labels'] = left_move(batch['labels'], -100)
-    batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
 
     # Convert the list of images into a single tensor of shape (B, C, H, W)
     batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
     return batch
 
 
-def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     processed_img = train_transform(samples['pixel_values'])
     samples['pixel_values'] = processed_img
     return samples
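left_move itself is defined outside this hunk; a plausible minimal implementation consistent with these call sites (shift each row one step left and pad the tail, so labels line up for next-token prediction) might look like:

def left_move(x: torch.Tensor, pad_val):
    # [a, b, c] -> [b, c, pad_val], applied row-wise
    shifted = torch.roll(x, shifts=-1, dims=1)
    shifted[:, -1] = pad_val
    return shifted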
@@ -63,7 +64,7 @@ if __name__ == '__main__':
     tokenized_formula = tokenized_formula.to_dict()
     # tokenized_formula['pixel_values'] = dataset['image']
     # tokenized_formula = dataset.from_dict(tokenized_formula)
-    tokenized_dataset = tokenized_formula.with_transform(img_preprocess)
+    tokenized_dataset = tokenized_formula.with_transform(img_transform_fn)
 
     dataset_dict = tokenized_dataset[:]
     dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
 37  src/models/ocr_model/utils/helpers.py  Normal file
@@ -0,0 +1,37 @@
+import cv2
+import numpy as np
+from typing import List
+from PIL import Image
+
+
+def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
+    processed_images = []
+
+    for path in image_paths:
+        # Read the image
+        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        if image is None:
+            print(f"Image at {path} could not be read.")
+            continue
+
+        # Check whether the image is stored as uint16
+        if image.dtype == np.uint16:
+            raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
+
+        # Get the number of channels
+        channels = 1 if len(image.shape) == 2 else image.shape[2]
+
+        # RGBA (4 channels): convert to RGB
+        if channels == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
+
+        # Single-channel grayscale ('I' mode): convert to RGB
+        elif channels == 1:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+        # BGR (3 channels): convert to RGB
+        elif channels == 3:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        processed_images.append(Image.fromarray(image))
+
+    return processed_images
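Usage sketch for the new helper (module path inferred from the file location; image names are placeholders):

from src.models.ocr_model.utils.helpers import convert2rgb

imgs = convert2rgb(['formula1.png', 'formula2.jpg'])  # -> List[PIL.Image.Image], all RGB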
@@ -1,67 +1,155 @@
 import torch
+import random
+import numpy as np
+import cv2
 
 from torchvision.transforms import v2
 from PIL import ImageChops, Image
-from typing import List
+from typing import List, Union
 
-from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
+from ....globals import (
+    OCR_IMG_CHANNELS,
+    OCR_IMG_SIZE,
+    OCR_FIX_SIZE,
+    IMAGE_MEAN, IMAGE_STD,
+    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
+)
 
 
-def trim_white_border(image: Image.Image):
-    if image.mode == 'RGB':
-        bg_color = (255, 255, 255)
-    elif image.mode == 'L':
-        bg_color = 255
-    else:
-        raise ValueError("Only support RGB or L mode")
-    # Create a white background the same size as the image
-    bg = Image.new(image.mode, image.size, bg_color)
-    # Compute the difference between the image and the background; border areas that match
-    # the top-left pixel color come out black in the difference image.
-    diff = ImageChops.difference(image, bg)
-    # Boost the contrast of the difference image so non-background areas stand out. This helps
-    # locate the bounding box, though the parameters may need tuning for specific images.
-    diff = ImageChops.add(diff, diff, 2.0, -100)
-    # Find the bounding box of the non-black area; if one is found, crop the image to it.
-    bbox = diff.getbbox()
-    return image.crop(bbox) if bbox else image
-
-
-def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
-    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
-    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    images = [trim_white_border(image) for image in images]
-    transforms = v2.Compose([
-        v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
-        #+ returns a list of torchvision Images; the list length is the batch size,
-        #+ so the whole Compose pipeline also outputs a list of torchvision Images
-        #+ (note: not one batched torchvision Image; the batch dimension stays outside)
-        v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
-        v2.Grayscale(),  # convert to grayscale (task-dependent)
-        v2.Resize(  # resize onto a fixed square
-            size=OCR_IMG_SIZE - 1,  # size must be smaller than max_size
-            interpolation=v2.InterpolationMode.BICUBIC,
-            max_size=OCR_IMG_SIZE,
-            antialias=True
-        ),
-        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
-        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
-        # v2.ToPILImage()  # for checking the transformed result (debug only)
-    ])
-    images = transforms(images)  # imgs: List[PIL.Image.Image]
+general_transform_pipeline = v2.Compose([
+    v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
+    #+ returns a list of torchvision Images; the list length is the batch size,
+    #+ so the whole Compose pipeline also outputs a list of torchvision Images
+    #+ (note: not one batched torchvision Image; the batch dimension stays outside)
+    v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
+    v2.Grayscale(),  # convert to grayscale (task-dependent)
+
+    v2.Resize(  # resize onto a fixed square
+        size=OCR_IMG_SIZE - 1,  # size must be smaller than max_size
+        interpolation=v2.InterpolationMode.BICUBIC,
+        max_size=OCR_IMG_SIZE,
+        antialias=True
+    ),
+
+    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
+    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
+
+    # v2.ToPILImage()  # for checking the transformed result (debug only)
+])
+
+
+def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
+    # image is a 3-D ndarray in RGB format, laid out as [H, W, C] (channels last)
+
+    # If the input is a nested list structure, convert it to an np.ndarray
+    if isinstance(image, list):
+        image = np.array(image, dtype=np.uint8)
+
+    # Check that the image is RGB with the channel dimension last
+    if len(image.shape) != 3 or image.shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    # Check that the image is stored as uint8
+    if image.dtype != np.uint8:
+        raise ValueError("Image should be stored in uint8")
+
+    # Create a pure-white background the same size as the original image
+    h, w = image.shape[:2]
+    bg = np.full((h, w, 3), 255, dtype=np.uint8)
+
+    # Compute the difference
+    diff = cv2.absdiff(image, bg)
+
+    # Map any difference greater than 1 to 255
+    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
+
+    # Convert the difference to grayscale
+    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
+    # Compute the minimal bounding rectangle of the non-zero pixels
+    x, y, w, h = cv2.boundingRect(gray_diff)
+
+    # Crop the image
+    trimmed_image = image[y:y+h, x:x+w]
+
+    return trimmed_image
+
+
+def padding(images: List[torch.Tensor], required_size: int):
     images = [
         v2.functional.pad(
             img,
-            padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]]
+            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
         )
         for img in images
     ]
     return images
+
+
+def random_resize(
+    images: Union[List[np.ndarray], List[List[List[List]]]],
+    minr: float,
+    maxr: float
+) -> List[np.ndarray]:
+    # np.ndarray format: 3-D, RGB, laid out as [H, W, C] (channels last)
+
+    # If the first element of images is a nested list structure,
+    if isinstance(images[0], list):
+        # convert the nested lists to np.ndarray
+        images = [np.array(img, dtype=np.uint8) for img in images]
+
+    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
+    return [
+        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4)  # anti-aliased
+        for img, r in zip(images, ratios)
+    ]
+
+
+def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+    # Trim the white border
+    images = [trim_white_border(image) for image in images]
+    # general transform pipeline
+    images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
+    # padding to fixed size
+    images = padding(images, OCR_IMG_SIZE)
+    return images
+
+
+def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
+    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+
+    # random resize first
+    # images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
+    return general_transform(images)
 
 
 def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    return train_transform(images)
+
+    return general_transform(images)
+
+
+if __name__ == '__main__':
+    from pathlib import Path
+    from .helpers import convert2rgb
+
+    base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
+    imgs_path = [
+        base_dir / '1.jpg',
+        base_dir / '2.jpg',
+        base_dir / '3.jpg',
+        base_dir / '4.jpg',
+        base_dir / '5.jpg',
+        base_dir / '6.jpg',
+        base_dir / '7.jpg',
+    ]
+    imgs_path = [str(img_path) for img_path in imgs_path]
+    imgs = convert2rgb(imgs_path)
+    # res = train_transform(imgs)
+    # res = [v2.functional.to_pil_image(img) for img in res]
+    res = random_resize(imgs, 0.5, 1.5)
+    pause = 1
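A quick self-contained sanity check for the numpy-based trim_white_border (synthetic values, not from this commit): a white canvas with one black block should be cropped to the block's bounding box.

import numpy as np

canvas = np.full((100, 100, 3), 255, dtype=np.uint8)  # all-white RGB image
canvas[40:60, 30:50] = 0                              # black block = the "content"
trimmed = trim_white_border(canvas)
assert trimmed.shape == (20, 20, 3)                   # cropped to the block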