diff --git a/texteller/api/detection/detect.py b/texteller/api/detection/detect.py
index 184116d..c0331c1 100644
--- a/texteller/api/detection/detect.py
+++ b/texteller/api/detection/detect.py
@@ -28,6 +28,27 @@ _config = {
 
 
 def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
+    """
+    Detect LaTeX formulas in an image and classify them as isolated or embedded.
+
+    This function uses an ONNX model to detect LaTeX formulas in images. The model
+    identifies two types of LaTeX formulas:
+    - 'isolated': Standalone LaTeX formulas (typically displayed equations)
+    - 'embedding': Inline LaTeX formulas embedded within text
+
+    Args:
+        img_path: Path to the input image file
+        predictor: ONNX InferenceSession model for LaTeX detection
+
+    Returns:
+        List of Bbox objects representing the detected LaTeX formulas with their
+        positions, classifications, and confidence scores
+
+    Example:
+        >>> from texteller.api import load_latexdet_model, latex_detect
+        >>> model = load_latexdet_model()
+        >>> bboxes = latex_detect("path/to/image.png", model)
+    """
     transforms = Compose(_config["preprocess"])
     inputs = transforms(img_path)
     inputs_name = [var.name for var in predictor.get_inputs()]
diff --git a/texteller/api/inference.py b/texteller/api/inference.py
index 68ef8be..aea312b 100644
--- a/texteller/api/inference.py
+++ b/texteller/api/inference.py
@@ -61,14 +61,14 @@ def img2latex(
     Returns:
         List of LaTeX or KaTeX strings corresponding to each input image
 
-    Example usage:
+    Example:
         >>> import torch
         >>> from texteller import load_model, load_tokenizer, img2latex
-
+        >>>
         >>> model = load_model(model_path=None, use_onnx=False)
         >>> tokenizer = load_tokenizer(tokenizer_path=None)
         >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+        >>>
         >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
     """
     assert isinstance(images, list)
@@ -132,7 +132,47 @@ def paragraph2md(
     num_beams=1,
 ) -> str:
     """
-    Input a mixed image of formula text and output str (in markdown syntax)
+    Convert an image containing both text and mathematical formulas to markdown format.
+
+    This function processes a mixed-content image by:
+    1. Detecting mathematical formulas using a latex detection model
+    2. Masking detected formula areas and detecting text regions using OCR
+    3. Recognizing text in the detected regions
+    4. Converting formula regions to LaTeX using the latex recognition model
+    5. Combining all detected elements into a properly formatted markdown string
+
+    Args:
+        img_path: Path to the input image containing text and formulas
+        latexdet_model: ONNX InferenceSession for LaTeX formula detection
+        textdet_model: OCR text detector model
+        textrec_model: OCR text recognition model
+        latexrec_model: TexTeller model for LaTeX formula recognition
+        tokenizer: Tokenizer for the LaTeX recognition model
+        device: The torch device to use (defaults to available GPU or CPU)
+        num_beams: Number of beams for beam search during LaTeX generation
+
+    Returns:
+        Markdown formatted string containing the recognized text and formulas
+
+    Example:
+        >>> from texteller import load_latexdet_model, load_model, load_textdet_model, load_textrec_model, load_tokenizer, paragraph2md
+        >>>
+        >>> # Load all required models
+        >>> latexdet_model = load_latexdet_model()
+        >>> textdet_model = load_textdet_model()
+        >>> textrec_model = load_textrec_model()
+        >>> latexrec_model = load_model()
+        >>> tokenizer = load_tokenizer()
+        >>>
+        >>> # Convert image to markdown
+        >>> markdown_text = paragraph2md(
+        ...     img_path="path/to/mixed_content_image.jpg",
+        ...     latexdet_model=latexdet_model,
+        ...     textdet_model=textdet_model,
+        ...     textrec_model=textrec_model,
+        ...     latexrec_model=latexrec_model,
+        ...     tokenizer=tokenizer,
+        ... )
     """
     img = cv2.imread(img_path)
     corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
diff --git a/texteller/api/katex.py b/texteller/api/katex.py
index cc32b0e..83eefdf 100644
--- a/texteller/api/katex.py
+++ b/texteller/api/katex.py
@@ -17,6 +17,20 @@ def _rm_dollar_surr(content):
 
 
 def to_katex(formula: str) -> str:
+    """
+    Convert LaTeX formula to KaTeX-compatible format.
+
+    This function processes a LaTeX formula string and converts it to a format
+    that is compatible with KaTeX rendering. It removes unsupported commands
+    and structures, simplifies LaTeX environments, and optimizes the formula
+    for web display.
+
+    Args:
+        formula: LaTeX formula string to convert
+
+    Returns:
+        KaTeX-compatible formula string
+    """
     res = formula
     # remove mbox surrounding
     res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
diff --git a/texteller/api/load.py b/texteller/api/load.py
index de454bd..0fe880b 100644
--- a/texteller/api/load.py
+++ b/texteller/api/load.py
@@ -17,14 +17,65 @@ _logger = get_logger(__name__)
 
 
 def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
+    """
+    Load the TexTeller model for LaTeX recognition.
+
+    This function loads the main TexTeller model, which is responsible for
+    converting images to LaTeX. It can load either the standard PyTorch model
+    or the optimized ONNX version.
+
+    Args:
+        model_dir: Directory containing the model files. If None, uses the default model.
+        use_onnx: Whether to load the ONNX version of the model for faster inference.
+            Requires the 'optimum' package and ONNX Runtime.
+
+    Returns:
+        Loaded TexTeller model instance
+
+    Example:
+        >>> from texteller import load_model
+        >>>
+        >>> model = load_model(use_onnx=True)
+    """
     return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
 
 
 def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
+    """
+    Load the tokenizer for the TexTeller model.
+
+    This function loads the tokenizer used by the TexTeller model for
+    encoding and decoding LaTeX sequences.
+
+    Args:
+        tokenizer_dir: Directory containing the tokenizer files. If None, uses the default tokenizer.
+
+    Returns:
+        RobertaTokenizerFast instance
+
+    Example:
+        >>> from texteller import load_tokenizer
+        >>>
+        >>> tokenizer = load_tokenizer()
+    """
     return TexTeller.get_tokenizer(tokenizer_dir)
 
 
 def load_latexdet_model() -> InferenceSession:
+    """
+    Load the LaTeX detection model.
+
+    This function loads the model responsible for detecting LaTeX formulas in images.
+    The model is implemented as an ONNX InferenceSession for optimal performance.
+
+    Returns:
+        ONNX InferenceSession for LaTeX detection
+
+    Example:
+        >>> from texteller import load_latexdet_model
+        >>>
+        >>> detector = load_latexdet_model()
+    """
     fpath = _maybe_download(LATEX_DET_MODEL_URL)
     return InferenceSession(
         resolve_path(fpath),
@@ -33,6 +84,20 @@ def load_latexdet_model() -> InferenceSession:
 
 
 def load_textrec_model() -> predict_rec.TextRecognizer:
+    """
+    Load the text recognition model.
+
+    This function loads the model responsible for recognizing regular text in images.
+    It's based on PaddleOCR's text recognition model.
+
+    Returns:
+        PaddleOCR TextRecognizer instance
+
+    Example:
+        >>> from texteller import load_textrec_model
+        >>>
+        >>> text_recognizer = load_textrec_model()
+    """
     fpath = _maybe_download(TEXT_REC_MODEL_URL)
     paddleocr_args = parse_args()
     paddleocr_args.use_onnx = True
@@ -43,6 +108,20 @@ def load_textrec_model() -> predict_rec.TextRecognizer:
 
 
 def load_textdet_model() -> predict_det.TextDetector:
+    """
+    Load the text detection model.
+
+    This function loads the model responsible for detecting text regions in images.
+    It's based on PaddleOCR's text detection model.
+
+    Returns:
+        PaddleOCR TextDetector instance
+
+    Example:
+        >>> from texteller import load_textdet_model
+        >>>
+        >>> text_detector = load_textdet_model()
+    """
     fpath = _maybe_download(TEXT_DET_MODEL_URL)
     paddleocr_args = parse_args()
     paddleocr_args.use_onnx = True
@@ -53,6 +132,17 @@ def load_textdet_model() -> predict_det.TextDetector:
 
 
 def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
+    """
+    Download a file if it doesn't already exist.
+
+    Args:
+        url: URL to download from
+        dirpath: Directory to save the file in. If None, uses the default cache directory.
+        force: Whether to force download even if the file already exists
+
+    Returns:
+        Path to the downloaded file
+    """
     if dirpath is None:
         dirpath = Globals().cache_dir
     mkdir(dirpath)
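
Note on to_katex: unlike the other public functions documented in this patch, its docstring carries no Example section. A minimal usage sketch, assuming to_katex is importable from the top-level texteller package in the same way as the loader functions (if it is not re-exported there, texteller.api.katex would be the import path):

    from texteller import to_katex  # assumed top-level export; fall back to texteller.api.katex

    # Post-process a recognized LaTeX string into a form KaTeX can render in the browser
    raw_latex = r"\frac{a}{b} + \sqrt{c}"
    katex_str = to_katex(raw_latex)
    print(katex_str)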