Merge branch 'pre_release' into dev

This commit is contained in:
三洋三洋
2024-04-17 10:30:09 +00:00
19 changed files with 1843 additions and 208 deletions

View File

@@ -13,8 +13,8 @@ from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
use_cuda: bool,
imgs_path: Union[List[str], List[np.ndarray]],
inf_mode: str = 'cpu',
num_beams: int = 1,
) -> List[str]:
model.eval()
@@ -26,9 +26,8 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if use_cuda:
model = model.to('cuda')
pixel_values = pixel_values.to('cuda')
model = model.to(inf_mode)
pixel_values = pixel_values.to(inf_mode)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,