File size: 8,393 Bytes
f57cf41
 
 
 
6583fc2
 
f57cf41
1ea570a
6583fc2
f57cf41
6583fc2
f57cf41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6583fc2
f57cf41
 
 
 
 
 
6583fc2
 
f57cf41
 
 
6583fc2
1ea570a
f57cf41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6583fc2
f57cf41
6583fc2
f57cf41
 
6583fc2
f57cf41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6583fc2
 
3d8470e
 
 
f57cf41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d8470e
f57cf41
3d8470e
 
f57cf41
 
 
 
 
 
 
3d8470e
f57cf41
 
3d8470e
f57cf41
 
 
3d8470e
f57cf41
 
6583fc2
3d8470e
f57cf41
 
 
 
 
 
 
 
6583fc2
f57cf41
 
3d8470e
f57cf41
78c8fb4
f57cf41
 
 
78c8fb4
f57cf41
 
 
1ea570a
6583fc2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import os
import tempfile
import uuid
import concurrent.futures
from typing import List, Tuple

import fitz  # PyMuPDF for PDF operations
import torch
import gradio as gr
import spaces                # HuggingFace Spaces helper (ZeroGPU)
import easyocr
import warnings

# Suppress benign CuDNN LSTM warning
warnings.filterwarnings("ignore", "RNN module weights are not part")

# ----------------------------------------------------------------------
# Configuration constants
# ----------------------------------------------------------------------
# Extensions accepted by the Gradio file picker. ".pdf" gets the PDF flow
# (native text + per-page OCR); every other extension goes to image OCR.
SUPPORTED_FILE_TYPES = [
    ".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif"
]
# EasyOCR language codes offered in the UI dropdown.
LANGUAGES = ["en", "nl", "de", "fr", "es", "it", "pt", "ru", "zh_cn", "ja", "ar"]
# Cap parallel OCR threads to avoid GPU OOM
OCR_THREADS = min(int(os.getenv("OCR_THREADS", "2")), 2)

# ----------------------------------------------------------------------
# EasyOCR reader cache
# ----------------------------------------------------------------------
# Maps a sorted tuple of language codes -> initialized easyocr.Reader,
# so repeated extractions reuse the expensive-to-build model instances.
_READERS = {}

def get_reader(lang_codes: Tuple[str, ...]):
    """
    Return a cached EasyOCR Reader for the requested language set,
    creating one on first use.

    The cache key is the sorted language tuple, so the same reader is
    reused regardless of the order languages were picked in the UI.
    Whether the reader runs on GPU or CPU is decided at creation time
    via spaces.is_gpu_enabled().
    """
    cache_key = tuple(sorted(lang_codes))
    reader = _READERS.get(cache_key)
    if reader is None:
        use_gpu = spaces.is_gpu_enabled()
        reader = easyocr.Reader(list(cache_key), gpu=use_gpu)
        _READERS[cache_key] = reader
        print(f"[Init] EasyOCR reader for {cache_key} (GPU={'yes' if use_gpu else 'no'})")
    return reader

# ----------------------------------------------------------------------
# OCR helpers
# ----------------------------------------------------------------------
@spaces.GPU(duration=600)
def run_ocr_pages(pdf_path: str, page_ids: List[int], lang_codes: Tuple[str, ...]) -> List[Tuple[int, str]]:
    """
    OCR the specified (1-based) pages of a PDF.

    Runs only when GPU is allocated (ZeroGPU); falls back to CPU if
    unavailable. Pages are processed by a small thread pool with
    per-page error handling, so one bad page cannot halt the batch.

    Returns a list of (page_number, text) tuples in completion order;
    callers are expected to sort by page number.
    """
    if not page_ids:
        # Guard: ThreadPoolExecutor(max_workers=0) raises ValueError.
        return []

    reader = get_reader(lang_codes)
    results: List[Tuple[int, str]] = []

    def ocr_page(idx: int) -> Tuple[int, str]:
        img_path = os.path.join(tempfile.gettempdir(), f"ocr_{uuid.uuid4().hex}.png")
        try:
            # PyMuPDF Document objects are not thread-safe, so each worker
            # opens its own handle rather than sharing one across threads.
            with fitz.open(pdf_path) as doc:
                page = doc[idx - 1]
                # Adaptive resolution: up to ~300dpi on normal pages,
                # lower scale for oversized pages to bound pixmap size.
                scale = 2 if max(page.rect.width, page.rect.height) <= 600 else 1.5
                pix = page.get_pixmap(matrix=fitz.Matrix(scale, scale))
                pix.save(img_path)

            # Single-language => detail mode with confidence filtering
            if len(lang_codes) == 1:
                items = reader.readtext(img_path, detail=1)
                lines = [t for _, t, conf in items if conf > 0.2]
            else:
                lines = reader.readtext(img_path, detail=0)
            return idx, "\n".join(lines)
        except Exception as e:
            # Emit a warning instead of halting the entire batch
            msg = f"⚠️ OCR error on page {idx}: {e}"
            print(msg)
            return idx, msg
        finally:
            # Always remove the temp image — the original version leaked
            # it whenever OCR raised after the pixmap had been written.
            try:
                os.remove(img_path)
            except OSError:
                pass

    # Cap threadpool size to avoid overloading GPU
    workers = min(OCR_THREADS, len(page_ids))
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(ocr_page, pid): pid for pid in page_ids}
        for fut in concurrent.futures.as_completed(futures):
            results.append(fut.result())

    return results

def run_ocr_image(image_path: str, lang_codes: Tuple[str, ...]) -> str:
    """
    OCR a single image file and return the recognized text.

    Mirrors run_ocr_pages' per-page logic for one-shot image uploads:
    a single-language request uses detail mode with a confidence filter,
    while multi-language requests take EasyOCR's plain-text output.
    On failure, returns the error message instead of raising.
    """
    reader = get_reader(lang_codes)
    try:
        if len(lang_codes) == 1:
            detailed = reader.readtext(image_path, detail=1)
            text_lines = [text for _, text, score in detailed if score > 0.2]
        else:
            text_lines = reader.readtext(image_path, detail=0)
        return "\n".join(text_lines)
    except Exception as e:
        msg = f"⚠️ OCR error on image: {e}"
        print(msg)
        return msg

# ----------------------------------------------------------------------
# Streamed output helper
# ----------------------------------------------------------------------
def emit_chunk(chunk: str, combined: str, tmp_file) -> Tuple[str, None]:
    """
    Accumulate one piece of extracted text.

    Writes *chunk* to the open temp file (UTF-8 encoded) and returns the
    concatenation of *combined* and *chunk*, paired with None so the
    caller can yield the pair directly as (streamed_text, download_slot).
    """
    tmp_file.write(chunk.encode("utf-8"))
    return combined + chunk, None

# ----------------------------------------------------------------------
# Main extraction pipeline
# ----------------------------------------------------------------------
def pipeline(upload, langs, mode, progress=gr.Progress(track_tqdm=False)):
    """
    Extract text from an uploaded PDF or image, streaming results.

    Yields (combined_text, download_path) tuples; intermediate yields
    carry None in the download slot, and the final yield supplies the
    path of a .txt file containing everything extracted.

    Parameters:
        upload: Gradio file object (has a .name path) or None.
        langs: str or list of str — EasyOCR language codes.
        mode: "native" (embedded text only), "ocr" (OCR every page),
              or "auto" (native text where present, OCR elsewhere).
        progress: Gradio progress tracker. It must be a default-valued
                  parameter for Gradio to inject and display it — the
                  previous version created it inside the body, where it
                  was never picked up.

    Raises:
        gr.Error: missing upload, file over 200 MB, or an extension not
                  in SUPPORTED_FILE_TYPES.
    """
    if upload is None:
        raise gr.Error("Please upload a file.")
    # File-size guard (200MB)
    if os.path.getsize(upload.name) > 200 * 1024 * 1024:
        raise gr.Error("File larger than 200 MB; please split it.")

    ext = os.path.splitext(upload.name)[1].lower()
    # Reject extensions the pipeline cannot handle instead of blindly
    # feeding unknown formats (e.g. .txt) into the image-OCR path.
    if ext not in SUPPORTED_FILE_TYPES:
        raise gr.Error(f"Unsupported file type: {ext}")

    # Prepare languages and temp output
    langs = langs if isinstance(langs, list) else [langs]
    lang_tuple = tuple(langs)
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
    combined = ""

    try:
        # PDF flow
        if ext == ".pdf":
            # Phase 1: native-text extraction & OCR scheduling.
            # A single open suffices for both the page count and the scan
            # (the original opened the document twice).
            ocr_pages = []
            with fitz.open(upload.name) as doc:
                total_pages = doc.page_count
                for i, page in enumerate(doc, start=1):
                    text = page.get_text("text") if mode in ("native", "auto") else ""
                    if text.strip():
                        chunk = f"--- Page {i} (native) ---\n{text}\n"
                        combined, _ = emit_chunk(chunk, combined, tmp)
                        yield combined, None
                    else:
                        if mode in ("ocr", "auto"):
                            ocr_pages.append(i)
                    progress(i / total_pages)

            # Phase 2: OCR pass on scheduled pages
            if ocr_pages:
                ocr_results = run_ocr_pages(upload.name, ocr_pages, lang_tuple)
                for idx, txt in sorted(ocr_results, key=lambda x: x[0]):
                    chunk = f"--- Page {idx} (OCR) ---\n{txt}\n"
                    combined, _ = emit_chunk(chunk, combined, tmp)
                    yield combined, None

        # Image flow
        else:
            txt = run_ocr_image(upload.name, lang_tuple)
            chunk = f"--- Image OCR ---\n{txt}\n"
            combined, _ = emit_chunk(chunk, combined, tmp)
            yield combined, None
    finally:
        # Close the temp file on every path — the original leaked the
        # handle whenever extraction raised mid-way.
        tmp.close()

    # Final step: offer download link
    yield combined or "⚠️ No text detected.", tmp.name

# ----------------------------------------------------------------------
# Gradio UI (Blocks + streaming)
# ----------------------------------------------------------------------
# Build the Blocks layout: file/language/mode inputs on the left,
# streamed text output plus a download slot on the right.
theme = gr.themes.Base(primary_hue="purple")
with gr.Blocks(theme=theme, title="ZeroGPU OCR PDF & Image Extractor") as demo:
    gr.Markdown("## 📚 ZeroGPU Multilingual OCR Extractor")
    with gr.Row():
        with gr.Column(scale=1):
            # Upload restricted client-side to the extensions the
            # pipeline handles.
            file_in = gr.File(label="Upload PDF or image",
                              file_types=SUPPORTED_FILE_TYPES)
            lang_in = gr.Dropdown(LANGUAGES, multiselect=True, value=["en"],
                                  label="OCR language(s)")
            mode_in = gr.Radio(["native", "ocr", "auto"], value="auto",
                               label="Mode",
                               info="native=text · ocr=image · auto=mix")
            btn = gr.Button("Extract", variant="primary")
        with gr.Column(scale=2):
            out_txt = gr.Textbox(label="Extracted Text", lines=18,
                                 show_copy_button=True)
            dl = gr.File(label="Download .txt")

    # Use a list for outputs to match Gradio API
    # pipeline is a generator: each yield streams (text, download_path)
    # into (out_txt, dl).
    btn.click(
        fn=pipeline,
        inputs=[file_in, lang_in, mode_in],
        outputs=[out_txt, dl]
    )
    # Queueing is required for streamed (generator) outputs.
    demo.queue()

if __name__ == "__main__":
    demo.launch()