import os
import io
import json
import torch
import gradio as gr
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification
# ====== Set Google Credential ======
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "eastern-entity-450514-u2-c9949243357a.json"
# ====== Load Vision API and Model ======
client = vision.ImageAnnotatorClient()
model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# ====== Labels ======
labels = [
    "B-O", "I-O",
    "B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
    "B-FULL_NAME", "I-FULL_NAME",
    "B-GENDER", "I-GENDER",
    "B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY",
    "I-BIRTH_MONTH", "B-DISTRICT", "B-MUNCIPALITY", "I-MUNCIPALITY",
    "B-WARD_NO", "I-WARD_NO",
    "B-FATHERS_NAME", "I-FATHERS_NAME",
    "B-MOTHERS_NAME", "I-MOTHERS_NAME",
    "I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY",
]
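# Optional sanity check (a sketch; assumes the fine-tuned head was trained
# against exactly this label list): warn early if the counts diverge, since
# a mismatch would make the index-to-label lookup below silently wrong.
if model.config.num_labels != len(labels):
    print(f"Warning: model expects {model.config.num_labels} labels, "
          f"but {len(labels)} are listed here.")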
# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height):
    # LiLT expects coordinates scaled to a 0-1000 grid, independent of image size.
    x_min, y_min, x_max, y_max = bbox
    return [
        int((x_min / width) * 1000),
        int((y_min / height) * 1000),
        int((x_max / width) * 1000),
        int((y_max / height) * 1000),
    ]
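# Worked example (values are illustrative): on a 1000x800 px image, the box
# [100, 40, 300, 90] maps to [100, 50, 300, 112] on the 0-1000 grid, since
# e.g. int((40 / 800) * 1000) == 50 and int((90 / 800) * 1000) == 112.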
# ====== OCR ======
def extract_text_and_boxes(image):
    width, height = image.size
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    content = img_byte_arr.getvalue()
    image_vision = vision.Image(content=content)
    response = client.document_text_detection(image=image_vision)
    if response.error.message:
        raise RuntimeError(f"Vision API error: {response.error.message}")
    words, box_list = [], []
    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    # Reassemble the word from its symbols and take the
                    # axis-aligned extent of its bounding polygon.
                    word_text = ''.join(s.text for s in word.symbols)
                    words.append(word_text)
                    xs = [v.x for v in word.bounding_box.vertices]
                    ys = [v.y for v in word.bounding_box.vertices]
                    box = normalized_boxes([min(xs), min(ys), max(xs), max(ys)], width, height)
                    box_list.append(box)
    return words, box_list
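# Illustrative return value (hypothetical words and boxes for a scanned
# certificate): words -> ["नाम", "थर", ...] and
# box_list -> [[112, 80, 210, 103], [230, 80, 310, 103], ...],
# i.e. one 0-1000-normalized box per OCR word, in reading order.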
# ====== Inference & Entity Extraction ======
def predict_text(image):
    words, norm_boxes = extract_text_and_boxes(image)
    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)
    # Tokenize the pre-split OCR words so every subword can be traced back
    # to its source word via word_ids(); this keeps input_ids and bbox
    # aligned even under truncation.
    encoding = tokenizer(words, is_split_into_words=True, truncation=True,
                         max_length=512, return_tensors="pt")
    # Each subword inherits the box of the word it came from; special
    # tokens (<s>, </s>) map to a dummy [0, 0, 0, 0] box.
    token_boxes = [
        norm_boxes[word_id] if word_id is not None else [0, 0, 0, 0]
        for word_id in encoding.word_ids(batch_index=0)
    ]
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
    encoding = {k: v.to(device) for k, v in encoding.items()}
    encoding["bbox"] = torch.tensor([token_boxes], dtype=torch.long, device=device)
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = torch.argmax(outputs.logits, dim=-1).squeeze(0).tolist()
    label_output = [labels[idx] for idx in predictions]
    # ====== Parse Entities ======
    entity_map = {
        "CITIZENSHIP_CERTIFICATE_NO": "",
        "FULL_NAME": "",
        "GENDER": "",
        "DISTRICT": "",
        "MUNCIPALITY": "",
        "BIRTH_YEAR": "",
        "BIRTH_MONTH": "",
        "BIRTH_DAY": "",
        "WARD_NO": "",
        "FATHERS_NAME": "",
        "MOTHERS_NAME": ""
    }
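    # BIO decoding (sketch of the convention, with hypothetical tokens):
    # "B-X" opens entity X, "I-X" extends it only while X is still open,
    # and any other label closes it. For example, the token/label pairs
    # ("▁राम", "B-FULL_NAME"), ("▁बहादुर", "I-FULL_NAME")
    # would yield FULL_NAME = "राम बहादुर".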
    current_label = None
    for token, label in zip(tokens, label_output):
        # Skip special tokens; this XLM-R tokenizer uses <s>/</s>, not [CLS]/[SEP].
        if token in tokenizer.all_special_tokens:
            continue
        # In SentencePiece output, "▁" marks the start of a new word; use it to
        # decide whether a continuation subword needs a leading space.
        is_word_start = token.startswith("▁")
        token = token.replace("▁", "").strip()
        if not token:
            continue
        if label.startswith("B-"):
            ent_type = label[2:]
            if ent_type in entity_map:
                entity_map[ent_type] += (" " if entity_map[ent_type] else "") + token
                current_label = ent_type
            else:
                current_label = None
        elif label.startswith("I-"):
            ent_type = label[2:]
            if current_label == ent_type and ent_type in entity_map:
                entity_map[ent_type] += (" " if is_word_start and entity_map[ent_type] else "") + token
            else:
                current_label = None
        else:
            current_label = None
    # Strip extra spaces
    for k in entity_map:
        entity_map[k] = entity_map[k].strip()
    return json.dumps(entity_map, ensure_ascii=False, indent=2)
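# Local smoke test (a sketch; "sample_certificate.png" is a hypothetical
# path, not a file shipped with this Space):
#   print(predict_text(Image.open("sample_certificate.png")))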
# ====== Gradio Interface ======
gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text").launch()
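# On Spaces the app is served automatically; for a local run, passing
# share=True to launch() would additionally expose a temporary public URL.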