# Module setup: imports, generation constants, and eager loading of the four
# OCR/vision-language models used by the app. All models are loaded in
# float16 onto the first CUDA device when available (CPU fallback otherwise).
# NOTE(review): this file appears to have lost its original line breaks in
# transit; the tokens below are preserved exactly as found.
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    VisionEncoderDecoderModel,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image
from docling_core.types.doc import DoclingDocument, DocTagsDocument
import re
import ast
import html

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Upper bound on prompt length; overridable via environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Nanonets-OCR-s (Qwen2.5-VL-based OCR model)
MODEL_ID_M = "nanonets/Nanonets-OCR-s"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load ByteDance's Dolphin (encoder-decoder document parser)
MODEL_ID_K = "ByteDance/Dolphin"
processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
model_k = VisionEncoderDecoderModel.from_pretrained(
    MODEL_ID_K,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load SmolDocling-256M-preview (emits DocTags that docling converts to Markdown)
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load MonkeyOCR — only the "Recognition" subfolder of the repo is used.
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER,
    torch_dtype=torch.float16
).to(device).eval()

# Preprocessing
# Preprocessing functions for SmolDocling-256M


def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image based on its size."""
    # NOTE(review): min_percent (0.1) and max_percent (0.10) are equal, so the
    # "random" padding is effectively a fixed 10% per side — confirm intent.
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    # Pad symmetrically with the corner pixel's color so the border blends in.
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
    return padded_image


def normalize_values(text, target_max=500):
    """Normalize numerical values in text to a target maximum."""

    def normalize_list(values):
        # Scale every value so the largest maps to target_max; guard empty list.
        max_value = max(values) if values else 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        num_list = ast.literal_eval(match.group(0))
        normalized = normalize_list(num_list)
        # NOTE(review): f"" is an empty f-string, so this join returns "" and
        # discards the normalized numbers — a literal tag template (likely a
        # <loc_...>-style token) appears to have been lost from this file.
        # Preserved as-is; confirm against the original source.
        return "".join([f"" for num in normalized])

    # Matches bracketed number lists such as "[10, 20, 30]".
    pattern = r"\[([\d\.\s,]+)\]"
    normalized_text = re.sub(pattern, process_match, text)
    return normalized_text


def downsample_video(video_path):
    """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # 10 evenly spaced frame indices across the whole clip.
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            # OpenCV decodes BGR; convert to RGB for PIL.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)  # seconds, 2-decimal precision
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames


# Dolphin-specific functions


def model_chat(prompt, image):
    """Use Dolphin model for inference."""
    processor = processor_k
    model = model_k
    # NOTE(review): shadows the module-level `device` with a plain string.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = processor(image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values.half()  # match the float16 model weights
    prompt_inputs = processor.tokenizer(
        f"{prompt} ",
        add_special_tokens=False,
        return_tensors="pt"
    ).to(device)
    outputs = model.generate(
        pixel_values=pixel_values,
        decoder_input_ids=prompt_inputs.input_ids,
        decoder_attention_mask=prompt_inputs.attention_mask,
        min_length=1,
        max_length=4096,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid the unknown token in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        do_sample=False,  # greedy decoding
        num_beams=1,
        repetition_penalty=1.1
    )
    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
    # Strip the echoed prompt from the decoded sequence.
    # NOTE(review): the .replace("", "") calls are no-ops; special-token
    # strings (e.g. "<s>"/"</s>") were probably lost from this file — confirm.
    cleaned = sequence.replace(f"{prompt} ", "").replace("", "").replace("", "").strip()
    return cleaned


def process_elements(layout_results, image):
    """Parse layout results and extract elements from the image."""
    # Placeholder parsing logic based on expected Dolphin output
    # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
    try:
        elements = ast.literal_eval(layout_results)
    # NOTE(review): bare except silently swallows everything (including
    # KeyboardInterrupt); narrowing to (ValueError, SyntaxError) would be safer.
    except:
        elements = []  # Fallback if parsing fails
    recognition_results = []
    reading_order = 0
    # NOTE(review): this unpacking expects ((x1, y1, x2, y2), label) pairs,
    # which contradicts the flat 5-tuple format described in the comment
    # above — one of the two is wrong; verify against real Dolphin output.
    for bbox, label in elements:
        try:
            x1, y1, x2, y2 = map(int, bbox)
            cropped = image.crop((x1, y1, x2, y2))
            # Skip degenerate (zero-area) crops.
            if cropped.size[0] > 0 and cropped.size[1] > 0:
                if label == "text":
                    # Second-stage OCR pass on the cropped text region.
                    text = model_chat("Read text in the image.", cropped)
                    recognition_results.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "text": text.strip(),
                        "reading_order": reading_order
                    })
                elif label == "table":
                    table_text = model_chat("Parse the table in the image.", cropped)
                    recognition_results.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "text": table_text.strip(),
                        "reading_order": reading_order
                    })
                elif label == "figure":
                    recognition_results.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "text": "[Figure]",  # Placeholder for figure content
                        "reading_order": reading_order
                    })
                # NOTE(review): reading_order advances for unknown labels too
                # (nothing appended), leaving gaps in the sequence — confirm.
                reading_order += 1
        except Exception as e:
            print(f"Error processing element: {e}")
            continue
    return recognition_results


def generate_markdown(recognition_results):
    """Generate markdown from extracted elements."""
    markdown = ""
    # Emit elements in their detected reading order.
    for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
        if element["label"] == "text":
            markdown += f"{element['text']}\n\n"
        elif element["label"] == "table":
            markdown += f"**Table:**\n{element['text']}\n\n"
        elif element["label"] == "figure":
            markdown += f"{element['text']}\n\n"
    return markdown.strip()


def process_image_with_dolphin(image):
    """Process a single image with Dolphin model."""
    # Two-stage pipeline: layout analysis, then per-element recognition.
    layout_output = model_chat("Parse the reading order of this document.", image)
    elements = process_elements(layout_output, image)
    markdown_content = generate_markdown(elements)
    return markdown_content


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for image input using the selected model.

    Yields partial output strings so Gradio can stream them to the UI.
    """
    if model_name == "ByteDance-s-Dolphin":
        if image is None:
            yield "Please upload an image."
            return
        markdown_content = process_image_with_dolphin(image)
        yield markdown_content
    else:
        # Existing logic for other models
        if model_name == "Nanonets-OCR-s":
            processor = processor_m
            model = model_m
        elif model_name == "MonkeyOCR-Recognition":
            processor = processor_g
            model = model_g
        elif model_name == "SmolDocling-256M-preview":
            processor = processor_x
            model = model_x
        else:
            yield "Invalid model selected."
            return

        if image is None:
            yield "Please upload an image."
            return

        images = [image]

        # SmolDocling-specific prompt/image preprocessing.
        if model_name == "SmolDocling-256M-preview":
            if "OTSL" in text or "code" in text:
                images = [add_random_padding(img) for img in images]
            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
                text = normalize_values(text, target_max=500)

        # Chat-template message: one image placeholder per image, then the query.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"} for _ in images] + [
                    {"type": "text", "text": text}
                ]
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

        # Stream tokens from a background generation thread.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""       # what the user sees (end-of-turn marker stripped)
        full_output = ""  # raw accumulated text for DocTags post-processing
        for new_text in streamer:
            full_output += new_text
            buffer += new_text.replace("<|im_end|>", "")
            yield buffer

        # Convert SmolDocling's DocTags output to Markdown via docling.
        if model_name == "SmolDocling-256M-preview":
            # NOTE(review): the empty-string replace/tag-list/regex literals
            # below look like DocTags markers lost from this file; with ""
            # in the tag list, `any(...)` is always True. Preserved as-is.
            cleaned_output = full_output.replace("", "").strip()
            if any(tag in cleaned_output for tag in ["", "", "", "", ""]):
                if "" in cleaned_output:
                    cleaned_output = cleaned_output.replace("", "").replace("", "")
                cleaned_output = re.sub(r'()(?!.*)<[^>]+>', r'\1', cleaned_output)
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                yield f"**MD Output:**\n\n{markdown_output}"
            else:
                yield cleaned_output


@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for video input using the selected model.

    Samples 10 evenly spaced frames and feeds them all as images; yields
    partial output strings so Gradio can stream them to the UI.
    """
    if model_name == "ByteDance-s-Dolphin":
        if video_path is None:
            yield "Please upload a video."
            return
        # Dolphin has no video mode: run the document pipeline per frame.
        frames = downsample_video(video_path)
        markdown_contents = []
        for frame, _ in frames:
            markdown_content = process_image_with_dolphin(frame)
            markdown_contents.append(markdown_content)
        combined_markdown = "\n\n".join(markdown_contents)
        yield combined_markdown
    else:
        # Existing logic for other models
        if model_name == "Nanonets-OCR-s":
            processor = processor_m
            model = model_m
        elif model_name == "MonkeyOCR-Recognition":
            processor = processor_g
            model = model_g
        elif model_name == "SmolDocling-256M-preview":
            processor = processor_x
            model = model_x
        else:
            yield "Invalid model selected."
            return

        if video_path is None:
            yield "Please upload a video."
            return

        frames = downsample_video(video_path)
        images = [frame for frame, _ in frames]  # timestamps are unused here

        # SmolDocling-specific prompt/image preprocessing (mirrors generate_image).
        if model_name == "SmolDocling-256M-preview":
            if "OTSL" in text or "code" in text:
                images = [add_random_padding(img) for img in images]
            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
                text = normalize_values(text, target_max=500)

        messages = [
            {
                "role": "user",
                "content": [{"type": "image"} for _ in images] + [
                    {"type": "text", "text": text}
                ]
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

        # Stream tokens from a background generation thread.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""       # what the user sees (end-of-turn marker stripped)
        full_output = ""  # raw accumulated text for DocTags post-processing
        for new_text in streamer:
            full_output += new_text
            buffer += new_text.replace("<|im_end|>", "")
            yield buffer

        # Convert SmolDocling's DocTags output to Markdown via docling.
        if model_name == "SmolDocling-256M-preview":
            # NOTE(review): same apparently-lost DocTags literals as in
            # generate_image; with "" in the tag list, `any(...)` is always
            # True. Preserved as-is.
            cleaned_output = full_output.replace("", "").strip()
            if any(tag in cleaned_output for tag in ["", "", "", "", ""]):
                if "" in cleaned_output:
                    cleaned_output = cleaned_output.replace("", "").replace("", "")
                cleaned_output = re.sub(r'()(?!.*)<[^>]+>', r'\1', cleaned_output)
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                yield f"**MD Output:**\n\n{markdown_output}"
            else:
                yield cleaned_output


# Define examples for image and video inference
image_examples = [
    ["Convert this page to docling", "images/1.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/3.png"],
]

video_examples = [
    ["Explain the ad in detail", "example/1.mp4"],
    ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
]

# Styling for the submit buttons.
css = """
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# Create the Gradio Interface
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        # Left column: inputs (image/video tabs + generation options).
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    image_upload = gr.Image(type="pil", label="Image")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload]
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload]
                    )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        # Right column: streamed output + model selector.
        with gr.Column():
            output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
            model_choice = gr.Radio(
                choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
                label="Select Model",
                value="Nanonets-OCR-s"
            )

    # Wire the submit buttons to the streaming generators.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)