import os
import io
import json
import torch
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr

# ====== Set Google Credentials ======
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "eastern-entity-450514-u2-c9949243357a.json"
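# Note: the key file must belong to a Google Cloud service account with the
# Vision API enabled; the path above is resolved relative to the working directory.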

# ====== Load Vision API and Model ======
client = vision.ImageAnnotatorClient()
model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# ====== Labels ======
labels = [
    "B-O", "I-O",
    "B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
    "B-FULL_NAME", "I-FULL_NAME",
    "B-GENDER", "I-GENDER",
    "B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY",
    "I-BIRTH_MONTH", "B-DISTRICT", "B-MUNCIPALITY", "I-MUNCIPALITY",
    "B-WARD_NO", "I-WARD_NO",
    "B-FATHERS_NAME", "I-FATHERS_NAME",
    "B-MOTHERS_NAME", "I-MOTHERS_NAME",
    "I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY"
]
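# Note: the index order above must mirror the fine-tuned checkpoint's label map.
# If the checkpoint ships id2label in its config, a drift-proof alternative would be:
#   labels = [model.config.id2label[i] for i in range(model.config.num_labels)]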

# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height):
    x_min, y_min, x_max, y_max = bbox
    return [
        int((x_min / width) * 1000),
        int((y_min / height) * 1000),
        int((x_max / width) * 1000),
        int((y_max / height) * 1000),
    ]
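# Example (hypothetical numbers): a word spanning (120, 40)-(360, 80) on a
# 1200x800 image normalizes to [100, 50, 300, 100] on the 0-1000 LayoutLM scale.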

# ====== OCR ======
def extract_text_and_boxes(image):
    width, height = image.size
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    content = img_byte_arr.getvalue()

    image_vision = vision.Image(content=content)
    response = client.document_text_detection(image=image_vision)

    words, box_list = [], []

    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    word_text = ''.join([s.text for s in word.symbols])
                    words.append(word_text)

                    x_min = min([v.x for v in word.bounding_box.vertices])
                    y_min = min([v.y for v in word.bounding_box.vertices])
                    x_max = max([v.x for v in word.bounding_box.vertices])
                    y_max = max([v.y for v in word.bounding_box.vertices])

                    box = normalized_boxes([x_min, y_min, x_max, y_max], width, height)
                    box_list.append(box)

    return words, box_list
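# Return shape (hypothetical two-word example):
#   words    -> ["नागरिकता", "प्रमाणपत्र"]
#   box_list -> [[104, 52, 297, 98], [310, 52, 512, 98]]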

# ====== Inference & Entity Extraction ======
def predict_text(image):
    words, norm_boxes = extract_text_and_boxes(image)
    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)

    encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")
    encoding = {k: v.to(device) for k, v in encoding.items()}

    # Repeat each word's box for every subword token so boxes align with input_ids.
    token_boxes = []
    for word, box in zip(words, norm_boxes):
        word_tokens = tokenizer.tokenize(word)
        token_boxes.extend([box] * len(word_tokens))

    # Dummy boxes for the special tokens, then trim/pad to the (possibly
    # truncated) sequence length so bbox and input_ids always match in size.
    cls_box = [0, 0, 0, 0]
    sep_box = [0, 0, 0, 0]
    token_boxes = [cls_box] + token_boxes[:510] + [sep_box]
    seq_len = encoding["input_ids"].shape[1]
    token_boxes = token_boxes[:seq_len] + [[0, 0, 0, 0]] * (seq_len - len(token_boxes))
    encoding["bbox"] = torch.tensor([token_boxes], dtype=torch.long, device=device)

    with torch.no_grad():
        outputs = model(**encoding)

    predictions = torch.argmax(outputs.logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
    label_output = [labels[idx] for idx in predictions]

    # ====== Parse Entities ======
    entity_map = {
        "CITIZENSHIP_CERTIFICATE_NO": "",
        "FULL_NAME": "",
        "GENDER": "",
        "DISTRICT": "",
        "MUNCIPALITY": "",
        "BIRTH_YEAR": "",
        "BIRTH_MONTH": "",
        "BIRTH_DAY": "",
        "WARD_NO": "",
        "FATHERS_NAME": "",
        "MOTHERS_NAME": ""
    }

    current_label = None
    for token, label in zip(tokens, label_output):
        # XLM-RoBERTa tokenizers use <s>/</s>/<pad>, not BERT-style [CLS]/[SEP]
        if token in (tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token):
            continue

        token = token.replace("▁", "").strip()
        if not token:
            continue

        if label.startswith("B-"):
            ent_type = label[2:]
            if ent_type in entity_map:
                entity_map[ent_type] += (" " if entity_map[ent_type] else "") + token
                current_label = ent_type
            else:
                current_label = None
        elif label.startswith("I-"):
            ent_type = label[2:]
            if current_label == ent_type and ent_type in entity_map:
                entity_map[ent_type] += " " + token
            else:
                current_label = None
        else:
            current_label = None

    # Strip extra spaces
    for k in entity_map:
        entity_map[k] = entity_map[k].strip()

    return json.dumps(entity_map, ensure_ascii=False, indent=2)
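
# Quick sanity check without the UI (hypothetical image path):
#   print(predict_text(Image.open("sample_citizenship.png").convert("RGB")))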


# ====== Gradio Interface ======
gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text").launch()
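# From a notebook or remote machine, launch(share=True) generates a temporary public URL.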