# Doc-VLMs-OCR / app.py
import os
import random
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
Qwen2_5_VLForConditionalGeneration,
VisionEncoderDecoderModel,
AutoModelForVision2Seq,
AutoProcessor,
TextIteratorStreamer,
)
from transformers.generation import GenerationConfig
from docling_core.types.doc import DoclingDocument, DocTagsDocument
import re
import ast
# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load Nanonets-OCR-s
MODEL_ID_M = "nanonets/Nanonets-OCR-s"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_M,
trust_remote_code=True,
torch_dtype=torch.float16
).to(device).eval()
# Load ByteDance's Dolphin
MODEL_ID_K = "ByteDance/Dolphin"
processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
model_k = VisionEncoderDecoderModel.from_pretrained(
MODEL_ID_K,
trust_remote_code=True,
torch_dtype=torch.float16
).to(device).eval()
# Load SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = AutoModelForVision2Seq.from_pretrained(
MODEL_ID_X,
trust_remote_code=True,
torch_dtype=torch.float16
).to(device).eval()
# Load MonkeyOCR
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
MODEL_ID_G,
trust_remote_code=True,
subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_G,
trust_remote_code=True,
subfolder=SUBFOLDER,
torch_dtype=torch.float16
).to(device).eval()
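# Note: all four checkpoints above are loaded eagerly at startup in float16.
# The @spaces.GPU decorators on the handlers below request GPU time per call
# (assuming the Hugging Face ZeroGPU deployment pattern for this Space).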
# Preprocessing functions for SmolDocling-256M
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Pad an image on all sides by a random fraction of its size
    (with the default bounds equal, this is effectively a fixed 10% pad)."""
image = image.convert("RGB")
width, height = image.size
pad_w = int(width * random.uniform(min_percent, max_percent))
pad_h = int(height * random.uniform(min_percent, max_percent))
corner_pixel = image.getpixel((0, 0))
padded = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
return padded
def normalize_values(text, target_max=500):
"""Normalize numerical lists in text to a target maximum."""
def norm_list(vals):
m = max(vals) if vals else 1
return [round(v / m * target_max) for v in vals]
def repl(m):
lst = ast.literal_eval(m.group(0))
return "".join(f"<loc_{n}>" for n in norm_list(lst))
return re.sub(r"\[([\d\.\s,]+)\]", repl, text)
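# Example: normalize_values("[10, 20, 40]") rescales against the list max (40)
# and emits "<loc_125><loc_250><loc_500>".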
def downsample_video(video_path):
"""Extract 10 evenly spaced frames (with timestamps) from a video."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against missing FPS metadata
    frames, indices = [], np.linspace(0, max(total - 1, 0), 10, dtype=int)
for idx in indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
ok, img = cap.read()
if not ok:
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
frames.append((Image.fromarray(img), round(idx / fps, 2)))
cap.release()
return frames
# Dolphin-specific inference
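# Dolphin is an encoder-decoder model driven in two passes: a layout/reading-order
# prompt over the full page, then per-element prompts over cropped regions
# (see process_image_with_dolphin below).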
def model_chat(prompt, image):
proc = processor_k
mdl = model_k
device_str = "cuda" if torch.cuda.is_available() else "cpu"
# encode image
inputs = proc(image, return_tensors="pt").to(device_str).pixel_values.half()
# encode prompt
pi = proc.tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False, return_tensors="pt").to(device_str)
# build generation config
gen_cfg = GenerationConfig.from_model_config(mdl.config)
gen_cfg.max_length = 4096
gen_cfg.min_length = 1
gen_cfg.use_cache = True
gen_cfg.bad_words_ids = [[proc.tokenizer.unk_token_id]]
gen_cfg.num_beams = 1
gen_cfg.do_sample = False
gen_cfg.repetition_penalty = 1.1
out = mdl.generate(
pixel_values=inputs,
decoder_input_ids=pi.input_ids,
decoder_attention_mask=pi.attention_mask,
generation_config=gen_cfg,
return_dict_in_generate=True,
)
seq = proc.tokenizer.batch_decode(out.sequences, skip_special_tokens=False)[0]
return seq.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
def process_elements(layout_result, image):
    try:
        elements = ast.literal_eval(layout_result)
    except (ValueError, SyntaxError):
        # The layout pass did not return a parseable Python literal.
        elements = []
results, order = [], 0
for bbox, label in elements:
x1, y1, x2, y2 = map(int, bbox)
crop = image.crop((x1, y1, x2, y2))
if crop.width == 0 or crop.height == 0:
continue
if label == "text":
txt = model_chat("Read text in the image.", crop)
elif label == "table":
txt = model_chat("Parse the table in the image.", crop)
else:
txt = "[Figure]"
results.append({
"label": label,
"bbox": [x1, y1, x2, y2],
"text": txt.strip(),
"reading_order": order
})
order += 1
return results
def generate_markdown(recog):
md = ""
for el in sorted(recog, key=lambda x: x["reading_order"]):
if el["label"] == "text":
md += el["text"] + "\n\n"
elif el["label"] == "table":
md += f"**Table:**\n{el['text']}\n\n"
else:
md += el["text"] + "\n\n"
return md.strip()
def process_image_with_dolphin(image):
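    """Full Dolphin pipeline: layout analysis, per-element OCR, markdown assembly."""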
layout = model_chat("Parse the reading order of this document.", image)
elems = process_elements(layout, image)
return generate_markdown(elems)
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2):
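    """Stream model output for an image query; Dolphin is handled as a special case."""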
if model_name == "ByteDance-s-Dolphin":
if image is None:
yield "Please upload an image."
else:
yield process_image_with_dolphin(image)
return
if model_name == "Nanonets-OCR-s":
proc, mdl = processor_m, model_m
elif model_name == "SmolDocling-256M-preview":
proc, mdl = processor_x, model_x
elif model_name == "MonkeyOCR-Recognition":
proc, mdl = processor_g, model_g
else:
yield "Invalid model selected."
return
if image is None:
yield "Please upload an image."
return
imgs = [image]
if model_name == "SmolDocling-256M-preview":
if any(tok in text for tok in ["OTSL", "code"]):
imgs = [add_random_padding(img) for img in imgs]
if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
text = normalize_values(text, target_max=500)
messages = [
{"role":"user",
"content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
}
]
prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
gen_cfg = GenerationConfig.from_model_config(mdl.config)
gen_cfg.max_new_tokens = max_new_tokens
gen_cfg.temperature = temperature
gen_cfg.top_p = top_p
gen_cfg.top_k = top_k
gen_cfg.repetition_penalty = repetition_penalty
gen_cfg.use_cache = True
streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
**inputs,
"streamer": streamer,
"generation_config": gen_cfg,
}
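    # Run generation on a background thread so tokens can be streamed to the UI.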
thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
thread.start()
buffer = ""
full_output = ""
for new_text in streamer:
full_output += new_text
buffer += new_text.replace("<|im_end|>", "")
yield buffer
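    # SmolDocling emits DocTags; convert them into a DoclingDocument for markdown export.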
if model_name == "SmolDocling-256M-preview":
cleaned = full_output.replace("<end_of_utterance>", "").strip()
if any(tag in cleaned for tag in ["<doctag>","<otsl>","<code>","<chart>","<formula>"]):
if "<chart>" in cleaned:
cleaned = cleaned.replace("<chart>","<otsl>").replace("</chart>","</otsl>")
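            # Drop a stray trailing tag emitted right after the final <loc_500> token.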
cleaned = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned)
tags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned], imgs)
doc = DoclingDocument.load_from_doctags(tags_doc, document_name="Document")
yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
else:
yield cleaned
@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2):
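    """Stream model output for a video query; frames are sampled via downsample_video."""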
if model_name == "ByteDance-s-Dolphin":
if not video_path:
yield "Please upload a video."
return
md_list = []
for frame, _ in downsample_video(video_path):
md_list.append(process_image_with_dolphin(frame))
yield "\n\n".join(md_list)
return
if model_name == "Nanonets-OCR-s":
proc, mdl = processor_m, model_m
elif model_name == "SmolDocling-256M-preview":
proc, mdl = processor_x, model_x
elif model_name == "MonkeyOCR-Recognition":
proc, mdl = processor_g, model_g
else:
yield "Invalid model selected."
return
if not video_path:
yield "Please upload a video."
return
frames = [f for f, _ in downsample_video(video_path)]
imgs = frames
if model_name == "SmolDocling-256M-preview":
if any(tok in text for tok in ["OTSL", "code"]):
imgs = [add_random_padding(img) for img in imgs]
if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
            text = normalize_values(text, target_max=500)
messages = [
{"role":"user",
"content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
}
]
prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
gen_cfg = GenerationConfig.from_model_config(mdl.config)
gen_cfg.max_new_tokens = max_new_tokens
gen_cfg.temperature = temperature
gen_cfg.top_p = top_p
gen_cfg.top_k = top_k
gen_cfg.repetition_penalty = repetition_penalty
gen_cfg.use_cache = True
streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
**inputs,
"streamer": streamer,
"generation_config": gen_cfg,
}
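    # Same background-thread streaming pattern as generate_image.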
thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
thread.start()
buff = ""
full = ""
for nt in streamer:
full += nt
buff += nt.replace("<|im_end|>", "")
yield buff
# Gradio UI
image_examples = [
["Convert this page to docling", "images/1.png"],
["OCR the image", "images/2.jpg"],
["Convert this page to docling", "images/3.png"],
]
video_examples = [
["Explain the ad in detail", "example/1.mp4"],
["Identify the main actions in the coca cola ad...", "example/2.mp4"]
]
css = """
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
"""
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
with gr.Row():
with gr.Column():
with gr.Tabs():
with gr.TabItem("Image Inference"):
image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
image_upload = gr.Image(type="pil", label="Image")
image_submit = gr.Button("Submit", elem_classes="submit-btn")
gr.Examples(
examples=image_examples,
inputs=[image_query, image_upload]
)
with gr.TabItem("Video Inference"):
video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
video_upload = gr.Video(label="Video")
video_submit = gr.Button("Submit", elem_classes="submit-btn")
gr.Examples(
examples=video_examples,
inputs=[video_query, video_upload]
)
with gr.Accordion("Advanced options", open=False):
max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
with gr.Column():
output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
model_choice = gr.Radio(
choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
label="Select Model",
value="Nanonets-OCR-s"
)
image_submit.click(
fn=generate_image,
inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
outputs=output
)
video_submit.click(
fn=generate_video,
inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
outputs=output
)
if __name__ == "__main__":
demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)