import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
from typing import List, Union

from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE


class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
    """
    Stops generation efficiently if any n-gram repeats.

    This criterion maintains a set of n-grams encountered so far. At each
    step it checks whether the *latest* n-gram is already in the set: if so,
    it stops generation; if not, it adds the n-gram to the set and continues.

    Note that instances are stateful (`seen_ngrams` is never reset), so a
    fresh instance should be created for each `generate` call.
    """

    def __init__(self, n: int):
        """
        Args:
            n (int): The size of the n-gram to check for repetition.
        """
        if n <= 0:
            raise ValueError("n-gram size 'n' must be positive.")
        self.n = n
        # Stores tuples of token IDs representing every n-gram seen so far
        self.seen_ngrams = set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores.

        Return:
            `bool`: `True` if generation should stop, `False` otherwise.
        """
        batch_size, seq_length = input_ids.shape

        # Need at least n tokens to form the first n-gram
        if seq_length < self.n:
            return False

        # Only the first sequence in the batch is tracked, which suffices for
        # this module. Supporting batch_size > 1 would need one set per batch
        # item, plus a stopping policy (e.g., stop if *any* sequence repeats).
        sequence = input_ids[0]

        # Take the latest n-gram (the one ending at the last token) and
        # convert it to a hashable tuple for set storage and lookup.
        last_ngram_tuple = tuple(sequence[-self.n:].tolist())

        # Stop if this n-gram has already been seen at any prior step.
        if last_ngram_tuple in self.seen_ngrams:
            return True
        # New n-gram: record it and continue generation.
        self.seen_ngrams.add(last_ngram_tuple)
        return False
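

# A minimal, illustrative sketch (not used by this module): drive the criteria
# by hand with a growing token sequence and watch it fire once a bigram
# repeats. In real use it would be passed to `model.generate` via
# `stopping_criteria=[...]` (see the commented-out line in `inference` below).
def _demo_repeating_ngram_criteria() -> None:
    criteria = EfficientDetectRepeatingNgramCriteria(n=2)
    scores = torch.zeros((1, 1))  # unused by this criteria
    tokens: List[int] = []
    for tok in [5, 7, 9, 5, 7]:  # the bigram (5, 7) recurs at the last step
        tokens.append(tok)
        input_ids = torch.tensor([tokens], dtype=torch.long)
        if criteria(input_ids, scores):
            print(f"stop: repeated {criteria.n}-gram at length {len(tokens)}")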


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs: Union[List[str], List[np.ndarray]],
    accelerator: str = 'cpu',
    num_beams: int = 1,
    max_tokens: Union[int, None] = None,
) -> List[str]:
    if not imgs:
        return []
    if hasattr(model, 'eval'):
        # not an ONNX session, so switch the model to eval mode
        model.eval()
    if isinstance(imgs[0], str):
        imgs = convert2rgb(imgs)
    else:
        # already numpy arrays (RGB format)
        assert isinstance(imgs[0], np.ndarray)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if hasattr(model, 'eval'):
        # not an ONNX session, so move the weights to the target device
        model = model.to(accelerator)
    pixel_values = pixel_values.to(accelerator)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
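        # A built-in alternative to the custom criteria above (left disabled):
        # it forbids any 10-gram from repeating during decoding, whereas
        # EfficientDetectRepeatingNgramCriteria stops generation outright
        # once a repeat shows up.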
        # no_repeat_ngram_size=10,
    )
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
        # stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
    )

    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
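

# A hypothetical usage sketch. `TexTeller.from_pretrained` and `get_tokenizer`
# are assumed entry points and the image path is a placeholder; adjust both to
# this repo's actual loaders. Because this module uses relative imports, run
# it as a package module (`python -m ...`), not as a standalone script.
if __name__ == "__main__":
    model = TexTeller.from_pretrained()    # assumption: default checkpoint
    tokenizer = TexTeller.get_tokenizer()  # assumption: bundled tokenizer
    latex = inference(model, tokenizer, ["./example_formula.png"], accelerator="cpu")
    print(latex[0])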