CitiNP / app.py
bashyaldhiraj2067's picture
Upload folder using huggingface_hub
1c50fdc verified
import os
import io
import json
import torch
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification
# ====== Set Google Credential ======
# Fall back to the key file bundled with the app, but never clobber a
# GOOGLE_APPLICATION_CREDENTIALS value the deployment environment already set.
# NOTE(review): the path is relative to the current working directory, and the
# file looks like a service-account key committed alongside the code — prefer
# injecting it as a deployment secret.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS", "eastern-entity-450514-u2-c9949243357a.json"
)

# ====== Load Vision API and Model ======
# Vision client used by extract_text_and_boxes for document OCR.
client = vision.ImageAnnotatorClient()

# Fine-tuned token-classification checkpoint; its tokenizer is loaded from the
# base LiLT XLM-RoBERTa repository.
model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disables dropout
# ====== Labels ======
# Class-index -> BIO tag lookup: predict_text turns each argmax class id into
# labels[id], so the position of every entry is significant.
# NOTE(review): presumably this must mirror the label order the checkpoint was
# fine-tuned with (cf. model.config.id2label) — confirm before editing/reordering.
# "MUNCIPALITY" (sic) is intentionally left as spelled: the entity keys emitted
# in predict_text's JSON output use the same spelling.
labels = [
# Non-entity tags; "O" is not an output field, so these are ignored downstream.
"B-O", "I-O",
"B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
"B-FULL_NAME", "I-FULL_NAME",
"B-GENDER", "I-GENDER",
"B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY",
"I-BIRTH_MONTH", "B-DISTRICT", "B-MUNCIPALITY", "I-MUNCIPALITY",
"B-WARD_NO", "I-WARD_NO",
"B-FATHERS_NAME", "I-FATHERS_NAME",
"B-MOTHERS_NAME", "I-MOTHERS_NAME",
# These I- tags sit at the end (not next to their B- tags); keep the order as-is.
"I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY"
]
# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height, scale=1000):
    """Map a pixel-space ``[x_min, y_min, x_max, y_max]`` box onto a 0..scale grid.

    LiLT/LayoutLM-style models expect token boxes on a 0-1000 grid. OCR
    vertices can fall outside the image, so every coordinate is clamped into
    ``[0, scale]`` to keep it inside the range the model's 2-D position
    embeddings accept.

    Args:
        bbox: Four pixel coordinates ``x_min, y_min, x_max, y_max``.
        width: Image width in pixels (must be > 0).
        height: Image height in pixels (must be > 0).
        scale: Size of the target grid; defaults to 1000.

    Returns:
        A list of four ints, each in ``[0, scale]``.
    """
    x_min, y_min, x_max, y_max = bbox

    def _to_grid(value, extent):
        # int() truncates toward zero (same rounding as before); then clamp.
        return min(max(int((value / extent) * scale), 0), scale)

    return [
        _to_grid(x_min, width),
        _to_grid(y_min, height),
        _to_grid(x_max, width),
        _to_grid(y_max, height),
    ]
# ====== OCR ======
def extract_text_and_boxes(image):
    """Run Google Cloud Vision document OCR on a PIL image.

    The image is re-encoded as PNG and sent to ``client.document_text_detection``.
    Words are collected in the page -> block -> paragraph -> word order returned
    by the API.

    Args:
        image: A PIL image.

    Returns:
        ``(words, box_list)``: parallel lists holding each word's text (its
        symbols concatenated) and its axis-aligned box as produced by
        ``normalized_boxes`` (0-1000 grid).

    Raises:
        RuntimeError: If the Vision API reports an error for the request.
    """
    width, height = image.size
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    response = client.document_text_detection(
        image=vision.Image(content=buffer.getvalue())
    )
    # The Vision client reports per-request failures in ``response.error``
    # rather than raising; surface them instead of pretending no text exists.
    if response.error.message:
        raise RuntimeError(f"Google Vision OCR failed: {response.error.message}")

    words, box_list = [], []
    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    words.append("".join(s.text for s in word.symbols))
                    # Collapse the (possibly rotated) polygon to its axis-aligned hull.
                    vertices = word.bounding_box.vertices
                    xs = [v.x for v in vertices]
                    ys = [v.y for v in vertices]
                    box_list.append(
                        normalized_boxes([min(xs), min(ys), max(xs), max(ys)], width, height)
                    )
    return words, box_list
# ====== Inference & Entity Extraction ======
def _encode_words(words, boxes, max_length=512):
    """Tokenize OCR words individually so every sub-token inherits its word's box.

    Returns ``(input_ids, token_boxes, word_index)``. ``input_ids`` and
    ``token_boxes`` are wrapped in the tokenizer's CLS/SEP specials (with zero
    boxes). ``word_index[i]`` is the index into ``words`` of the i-th
    non-special token. At most ``max_length`` ids are produced; tokens past
    that budget are dropped.
    """
    room = max_length - 2  # reserve two slots for the CLS / SEP specials
    # One batched call: a list of strings yields one id list per word.
    per_word_ids = tokenizer(list(words), add_special_tokens=False)["input_ids"]
    input_ids, token_boxes, word_index = [], [], []
    for idx, (piece_ids, box) in enumerate(zip(per_word_ids, boxes)):
        if room <= 0:
            break
        piece_ids = piece_ids[:room]
        room -= len(piece_ids)
        input_ids.extend(piece_ids)
        token_boxes.extend([box] * len(piece_ids))
        word_index.extend([idx] * len(piece_ids))
    input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id]
    token_boxes = [[0, 0, 0, 0]] + token_boxes + [[0, 0, 0, 0]]
    return input_ids, token_boxes, word_index


def _collect_entities(words, word_index, token_labels):
    """Merge per-token BIO tags into one space-joined string per entity type.

    Each word takes the tag of its first sub-token. The word's original OCR
    text is emitted, so sub-word pieces are never split apart by spaces.
    ``B-X`` starts (or appends another span to) field X. ``I-X`` continues only
    when the previous word was in X. Anything else resets the current span.
    """
    entity_map = {
        "CITIZENSHIP_CERTIFICATE_NO": "",
        "FULL_NAME": "",
        "GENDER": "",
        "DISTRICT": "",
        "MUNCIPALITY": "",
        "BIRTH_YEAR": "",
        "BIRTH_MONTH": "",
        "BIRTH_DAY": "",
        "WARD_NO": "",
        "FATHERS_NAME": "",
        "MOTHERS_NAME": "",
    }
    # First sub-token tag of every word, in reading order (word_index is non-decreasing).
    word_labels = {}
    for w_idx, label in zip(word_index, token_labels):
        word_labels.setdefault(w_idx, label)

    current_label = None
    for w_idx, label in word_labels.items():
        prefix, _, ent_type = label.partition("-")
        text = words[w_idx]
        if prefix == "B" and ent_type in entity_map:
            entity_map[ent_type] += (" " if entity_map[ent_type] else "") + text
            current_label = ent_type
        elif prefix == "I" and ent_type == current_label:
            entity_map[ent_type] += " " + text
        else:
            current_label = None

    return {key: value.strip() for key, value in entity_map.items()}


def predict_text(image):
    """OCR a citizenship-document image and extract its labelled fields.

    Args:
        image: A PIL image.

    Returns:
        A JSON string. It maps each entity type to the words the model tagged
        with it, or ``{"error": "No text detected"}`` when OCR finds no words.
    """
    words, norm_boxes = extract_text_and_boxes(image)
    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)

    # Building ids and boxes together guarantees the bbox tensor always has
    # exactly one box per input id.
    input_ids, token_boxes, word_index = _encode_words(words, norm_boxes)
    encoding = {
        "input_ids": torch.tensor([input_ids], dtype=torch.long, device=device),
        "bbox": torch.tensor([token_boxes], dtype=torch.long, device=device),
        "attention_mask": torch.ones((1, len(input_ids)), dtype=torch.long, device=device),
    }
    with torch.no_grad():
        logits = model(**encoding).logits
    # Index [0] (rather than squeeze) always yields a list of class ids.
    predictions = logits.argmax(dim=-1)[0].tolist()
    # Skip the CLS prediction at position 0; the trailing SEP falls outside the slice.
    token_labels = [labels[i] for i in predictions[1 : 1 + len(word_index)]]

    entity_map = _collect_entities(words, word_index, token_labels)
    return json.dumps(entity_map, ensure_ascii=False, indent=2)
# ====== Gradio Interface ======
import gradio as gr

# Build the UI at import time and expose it as ``demo`` (the name Gradio's
# reload mode looks for). Start the server only when run as a script, so
# importing this module (e.g. to call predict_text directly) does not block.
demo = gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text")

if __name__ == "__main__":
    demo.launch()