JJS0321 committed on
Commit
7a6934f
·
1 Parent(s): e2c8f63
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
5
+ import torch
6
+ import traceback
7
+
8
# 1) Load the pretrained Donut processor and model from the same checkpoint
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# 2) Use CUDA when torch reports it as available, otherwise fall back to CPU,
#    then move the model onto the chosen device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
16
+
17
+ # 3) Inference function with debugging
18
def ocr_donut(image):
    """Run the Donut model on an image and return its output as a dict.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.
            It must support ``.convert("RGB")`` (presumably a PIL image,
            given the Gradio input is declared with ``type="pil"``).

    Returns:
        ``{"result": <parsed output>}`` on success, where the parsed output
        comes from ``processor.token2json``; otherwise ``{"error": <text>}``.
        The error text is a fixed message when ``image`` is ``None`` and the
        formatted traceback when any step raises. Exceptions are never
        propagated to the caller; the traceback is also printed.
    """
    # Guard clause: nothing to process.
    if image is None:
        return {"error": "No image provided."}

    try:
        tokenizer = processor.tokenizer

        # Encode the task prompt on its own (no extra special tokens) so it
        # can be passed as the decoder's starting input.
        prompt_encoding = tokenizer(
            "<s_cord-v2>",
            add_special_tokens=False,
            return_tensors="pt",
        )
        prompt_ids = prompt_encoding.input_ids.to(device)

        # Turn the image into model-ready pixel values on the same device.
        rgb_image = image.convert("RGB")
        image_encoding = processor(rgb_image, return_tensors="pt")
        pixel_values = image_encoding.pixel_values.to(device)

        # Generate the output token sequence; the unknown token is banned.
        generation = model.generate(
            pixel_values,
            decoder_input_ids=prompt_ids,
            max_length=model.config.decoder.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode the first (only) sequence and strip EOS / padding markers.
        decoded = processor.batch_decode(generation.sequences)[0]
        decoded = decoded.replace(tokenizer.eos_token, "")
        decoded = decoded.replace(tokenizer.pad_token, "")
        # Remove only the first angle-bracket tag (count=1), then trim.
        decoded = re.sub(r"<.*?>", "", decoded, count=1).strip()

        return {"result": processor.token2json(decoded)}
    except Exception:
        # Surface the full traceback both in the logs and in the response.
        trace_text = traceback.format_exc()
        print(trace_text)
        return {"error": trace_text}
58
+
59
# 4) Build the Gradio interface: one image input, one JSON output.
#    The input component is created before the output component.
_image_input = gr.Image(type="pil", label="Upload Document Image")
_json_output = gr.JSON(label="Output")
demo = gr.Interface(
    fn=ocr_donut,
    inputs=_image_input,
    outputs=_json_output,
    title="Donut OCR Gradio App",
    description="Upload a document image and get structured JSON output. Errors will be shown for debugging.",
)

# 5) Launch for Spaces: listen on all interfaces, on the port given by the
#    PORT environment variable (7860 when it is unset), with debug enabled.
_listen_port = int(os.getenv("PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=_listen_port, debug=True)