import os
import random
import re
import ast
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    VisionEncoderDecoderModel,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.generation import GenerationConfig
from docling_core.types.doc import DoclingDocument, DocTagsDocument
# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load Nanonets-OCR-s
MODEL_ID_M = "nanonets/Nanonets-OCR-s"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load ByteDance's Dolphin
MODEL_ID_K = "ByteDance/Dolphin"
processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
model_k = VisionEncoderDecoderModel.from_pretrained(
    MODEL_ID_K,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load MonkeyOCR (recognition weights live in a subfolder of the repo)
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER,
    torch_dtype=torch.float16
).to(device).eval()
# Preprocessing functions for SmolDocling-256M
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image, filled with the corner pixel color."""
    image = image.convert("RGB")
    width, height = image.size
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))
    corner_pixel = image.getpixel((0, 0))
    padded = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
    return padded
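
# Example usage (illustrative; "page.png" is a hypothetical file):
#   page = Image.open("page.png")
#   padded = add_random_padding(page)   # with the defaults above, each side grows by ~10%
#   assert padded.size[0] > page.size[0] and padded.size[1] > page.size[1]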
def normalize_values(text, target_max=500):
    """Normalize numerical lists in text to a target maximum, emitting <loc_N> tags."""
    def norm_list(vals):
        m = max(vals) if vals else 1
        return [round(v / m * target_max) for v in vals]

    def repl(m):
        lst = ast.literal_eval(m.group(0))
        return "".join(f"<loc_{n}>" for n in norm_list(lst))

    return re.sub(r"\[([\d\.\s,]+)\]", repl, text)
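
# Example (illustrative): each bracketed list is rescaled so its maximum maps to
# target_max, then rewritten as location tags:
#   normalize_values("coords: [10, 20, 40]")  # -> "coords: <loc_125><loc_250><loc_500>"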
def downsample_video(video_path):
    """Extract 10 evenly spaced frames (with timestamps) from a video."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
    frames, indices = [], np.linspace(0, total - 1, 10, dtype=int)
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, img = cap.read()
        if not ok:
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        frames.append((Image.fromarray(img), round(idx / fps, 2)))
    cap.release()
    return frames
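
# Example (illustrative; "clip.mp4" is a hypothetical path). Each entry pairs a
# PIL frame with its timestamp in seconds:
#   for frame, ts in downsample_video("clip.mp4"):
#       print(ts, frame.size)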
# Dolphin-specific inference
def model_chat(prompt, image):
    proc = processor_k
    mdl = model_k
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    # Encode the image
    inputs = proc(image, return_tensors="pt").to(device_str).pixel_values.half()
    # Encode the prompt
    pi = proc.tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False, return_tensors="pt").to(device_str)
    # Build the generation config
    gen_cfg = GenerationConfig.from_model_config(mdl.config)
    gen_cfg.max_length = 4096
    gen_cfg.min_length = 1
    gen_cfg.use_cache = True
    gen_cfg.bad_words_ids = [[proc.tokenizer.unk_token_id]]
    gen_cfg.num_beams = 1
    gen_cfg.do_sample = False
    gen_cfg.repetition_penalty = 1.1
    out = mdl.generate(
        pixel_values=inputs,
        decoder_input_ids=pi.input_ids,
        decoder_attention_mask=pi.attention_mask,
        generation_config=gen_cfg,
        return_dict_in_generate=True,
    )
    seq = proc.tokenizer.batch_decode(out.sequences, skip_special_tokens=False)[0]
    # Strip the prompt echo and special tokens from the decoded sequence
    return seq.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
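
# Example (illustrative; page_image and cropped_region are hypothetical PIL images).
# Dolphin is prompt-driven, so the same helper serves layout analysis and recognition:
#   layout = model_chat("Parse the reading order of this document.", page_image)
#   text = model_chat("Read text in the image.", cropped_region)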
def process_elements(layout_result, image):
    """Crop each detected element and run the matching Dolphin recognition prompt."""
    try:
        elements = ast.literal_eval(layout_result)
    except (ValueError, SyntaxError):
        elements = []
    results, order = [], 0
    for bbox, label in elements:
        x1, y1, x2, y2 = map(int, bbox)
        crop = image.crop((x1, y1, x2, y2))
        if crop.width == 0 or crop.height == 0:
            continue
        if label == "text":
            txt = model_chat("Read text in the image.", crop)
        elif label == "table":
            txt = model_chat("Parse the table in the image.", crop)
        else:
            txt = "[Figure]"
        results.append({
            "label": label,
            "bbox": [x1, y1, x2, y2],
            "text": txt.strip(),
            "reading_order": order
        })
        order += 1
    return results
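
# The layout string is expected to literal-eval to (bbox, label) pairs, e.g.
# (illustrative): "[((10, 10, 300, 80), 'text'), ((10, 100, 300, 400), 'table')]"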
def generate_markdown(recog):
    """Assemble recognized elements into Markdown in reading order."""
    md = ""
    for el in sorted(recog, key=lambda x: x["reading_order"]):
        if el["label"] == "table":
            md += f"**Table:**\n{el['text']}\n\n"
        else:
            md += el["text"] + "\n\n"
    return md.strip()
def process_image_with_dolphin(image):
    """Full Dolphin pipeline: layout analysis, element recognition, Markdown assembly."""
    layout = model_chat("Parse the reading order of this document.", image)
    elems = process_elements(layout, image)
    return generate_markdown(elems)
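
# Example (illustrative; "invoice.png" is a hypothetical file):
#   md = process_image_with_dolphin(Image.open("invoice.png"))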
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream OCR output for a single image with the selected model."""
    if model_name == "ByteDance-s-Dolphin":
        if image is None:
            yield "Please upload an image."
        else:
            yield process_image_with_dolphin(image)
        return
    if model_name == "Nanonets-OCR-s":
        proc, mdl = processor_m, model_m
    elif model_name == "SmolDocling-256M-preview":
        proc, mdl = processor_x, model_x
    elif model_name == "MonkeyOCR-Recognition":
        proc, mdl = processor_g, model_g
    else:
        yield "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image."
        return

    imgs = [image]
    if model_name == "SmolDocling-256M-preview":
        if any(tok in text for tok in ["OTSL", "code"]):
            imgs = [add_random_padding(img) for img in imgs]
        if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
            text = normalize_values(text, target_max=500)

    messages = [
        {"role": "user",
         "content": [{"type": "image"} for _ in imgs] + [{"type": "text", "text": text}]}
    ]
    prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
    inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)

    gen_cfg = GenerationConfig.from_model_config(mdl.config)
    gen_cfg.max_new_tokens = max_new_tokens
    gen_cfg.temperature = temperature
    gen_cfg.top_p = top_p
    gen_cfg.top_k = top_k
    gen_cfg.repetition_penalty = repetition_penalty
    gen_cfg.use_cache = True

    # Run generation on a background thread so tokens can stream as they arrive
    streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_cfg}
    thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
    thread.start()

    buffer = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer

    # Post-process SmolDocling DocTags into Markdown via docling-core
    if model_name == "SmolDocling-256M-preview":
        cleaned = full_output.replace("<end_of_utterance>", "").strip()
        if any(tag in cleaned for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
            if "<chart>" in cleaned:
                cleaned = cleaned.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
                cleaned = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned)
            tags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned], imgs)
            doc = DoclingDocument.load_from_doctags(tags_doc, document_name="Document")
            yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
        else:
            yield cleaned
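
# Example (illustrative; img is an assumed PIL image). The generate_* functions are
# generators, so callers iterate to receive progressively longer partial outputs:
#   for partial in generate_image("Nanonets-OCR-s", "OCR the image", img):
#       print(partial)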
@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream OCR output for a video by running the selected model on sampled frames."""
    if model_name == "ByteDance-s-Dolphin":
        if not video_path:
            yield "Please upload a video."
            return
        md_list = []
        for frame, _ in downsample_video(video_path):
            md_list.append(process_image_with_dolphin(frame))
        yield "\n\n".join(md_list)
        return
    if model_name == "Nanonets-OCR-s":
        proc, mdl = processor_m, model_m
    elif model_name == "SmolDocling-256M-preview":
        proc, mdl = processor_x, model_x
    elif model_name == "MonkeyOCR-Recognition":
        proc, mdl = processor_g, model_g
    else:
        yield "Invalid model selected."
        return
    if not video_path:
        yield "Please upload a video."
        return

    imgs = [f for f, _ in downsample_video(video_path)]
    if model_name == "SmolDocling-256M-preview":
        if any(tok in text for tok in ["OTSL", "code"]):
            imgs = [add_random_padding(img) for img in imgs]
        if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
            text = normalize_values(text, target_max=500)

    messages = [
        {"role": "user",
         "content": [{"type": "image"} for _ in imgs] + [{"type": "text", "text": text}]}
    ]
    prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
    inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)

    gen_cfg = GenerationConfig.from_model_config(mdl.config)
    gen_cfg.max_new_tokens = max_new_tokens
    gen_cfg.temperature = temperature
    gen_cfg.top_p = top_p
    gen_cfg.top_k = top_k
    gen_cfg.repetition_penalty = repetition_penalty
    gen_cfg.use_cache = True

    streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_cfg}
    thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer
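
# Note: Dolphin parses each of the 10 sampled frames separately and concatenates the
# Markdown, while the other models receive all frames together in a single chat prompt.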
# Gradio UI
image_examples = [
    ["Convert this page to docling", "images/1.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/3.png"],
]
video_examples = [
    ["Explain the ad in detail", "example/1.mp4"],
    ["Identify the main actions in the Coca-Cola ad...", "example/2.mp4"]
]
css = """
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    image_upload = gr.Image(type="pil", label="Image")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload]
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload]
                    )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column():
            output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
            model_choice = gr.Radio(
                choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
                label="Select Model",
                value="Nanonets-OCR-s"
            )

    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)