From 6bd68ad3b7b3ecf807ba6caeeb31b01ddeff15e7 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Wed, 2 Apr 2025 03:23:27 +0000
Subject: [PATCH] [feat] Support n-gram stop criteria

---
 texteller/models/ocr_model/utils/inference.py | 71 ++++++++++++++++++-
 1 file changed, 69 insertions(+), 2 deletions(-)

diff --git a/texteller/models/ocr_model/utils/inference.py b/texteller/models/ocr_model/utils/inference.py
index d07100b..e00ea12 100644
--- a/texteller/models/ocr_model/utils/inference.py
+++ b/texteller/models/ocr_model/utils/inference.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 
-from transformers import RobertaTokenizerFast, GenerationConfig
+from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
 from typing import List, Union
 
 from .transforms import inference_transform
@@ -10,6 +10,67 @@ from ..model.TexTeller import TexTeller
 from ...globals import MAX_TOKEN_SIZE
 
 
+class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
+    """
+    Stops generation efficiently if any n-gram repeats.
+
+    This criteria maintains a set of encountered n-grams.
+    At each step, it checks whether the *latest* n-gram is already in the set.
+    If it is, generation stops; otherwise the n-gram is added to the set.
+    """
+
+    def __init__(self, n: int):
+        """
+        Args:
+            n (int): The size of the n-gram to check for repetition.
+        """
+        if n <= 0:
+            raise ValueError("n-gram size 'n' must be positive.")
+        self.n = n
+        # Stores tuples of token IDs representing seen n-grams
+        self.seen_ngrams = set()
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores.
+
+        Return:
+            `bool`: `True` if generation should stop, `False` otherwise.
+        """
+        batch_size, seq_length = input_ids.shape
+
+        # Need at least n tokens to form the first n-gram
+        if seq_length < self.n:
+            return False
+
+        # --- Efficient check ---
+        # Only the first sequence in the batch is considered, for simplicity.
+        if batch_size > 1:
+            # Handling batch_size > 1 would require one set of seen n-grams per
+            # batch item, plus a stopping policy (e.g., stop if *any* sequence
+            # repeats). For now, only the first sequence is checked.
+            pass
+
+        sequence = input_ids[0]  # Get the first sequence
+
+        # Get the latest n-gram (the one ending at the last token)
+        last_ngram_tensor = sequence[-self.n :]
+        # Convert to a hashable tuple for set storage and lookup
+        last_ngram_tuple = tuple(last_ngram_tensor.tolist())
+
+        # Check whether this n-gram has been seen at any prior step
+        if last_ngram_tuple in self.seen_ngrams:
+            return True  # Stop generation
+        else:
+            # It's a new n-gram: record it and continue
+            self.seen_ngrams.add(last_ngram_tuple)
+            return False  # Continue generation
+
+
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
@@ -43,7 +104,13 @@ def inference(
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id,
         bos_token_id=tokenizer.bos_token_id,
+        # no_repeat_ngram_size=10,
     )
-    pred = model.generate(pixel_values.to(model.device), generation_config=generate_config)
+    pred = model.generate(
+        pixel_values.to(model.device),
+        generation_config=generate_config,
+        # stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
+    )
+
     res = tokenizer.batch_decode(pred, skip_special_tokens=True)
     return res
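
A minimal usage sketch (not part of the patch), showing how the new criteria could be enabled explicitly. It assumes the standard transformers `StoppingCriteriaList` / `stopping_criteria` API and reuses the `model`, `tokenizer`, `pixel_values`, and `generate_config` objects prepared inside `inference()` above:

    from transformers import StoppingCriteriaList

    # Build a fresh criteria object for every generate() call: `seen_ngrams`
    # accumulates state across decoding steps and must not be reused.
    stop_on_repeat = StoppingCriteriaList([EfficientDetectRepeatingNgramCriteria(n=20)])

    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
        stopping_criteria=stop_on_repeat,
    )
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)

Unlike the commented-out `no_repeat_ngram_size=10`, which blocks a repeated n-gram from ever being generated, this criteria allows the first repetition and then terminates the sequence, cutting off degenerate looping output instead of constraining the decoder.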