"""Extract citizenship-certificate fields from an image.

Pipeline: Google Cloud Vision OCR -> word-level bounding boxes scaled to a
0-1000 grid -> token-classification model -> JSON string of entity fields,
exposed through a Gradio interface.
"""

import io
import json
import os

import gradio as gr
import torch
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification

# ====== Set Google Credential ======
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "eastern-entity-450514-u2-c9949243357a.json"

# ====== Load Vision API and Model ======
client = vision.ImageAnnotatorClient()
model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# ====== Labels ======
# Index in this list == class id predicted by the model.
labels = [
    "B-O", "I-O",
    "B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
    "B-FULL_NAME", "I-FULL_NAME",
    "B-GENDER", "I-GENDER",
    "B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY", "I-BIRTH_MONTH",
    "B-DISTRICT",
    "B-MUNCIPALITY", "I-MUNCIPALITY",
    "B-WARD_NO", "I-WARD_NO",
    "B-FATHERS_NAME", "I-FATHERS_NAME",
    "B-MOTHERS_NAME", "I-MOTHERS_NAME",
    "I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY",
]

# Fields reported in the output JSON, in output order.
ENTITY_FIELDS = (
    "CITIZENSHIP_CERTIFICATE_NO",
    "FULL_NAME",
    "GENDER",
    "DISTRICT",
    "MUNCIPALITY",
    "BIRTH_YEAR",
    "BIRTH_MONTH",
    "BIRTH_DAY",
    "WARD_NO",
    "FATHERS_NAME",
    "MOTHERS_NAME",
)

MAX_SEQUENCE_LENGTH = 512      # including the two special tokens
BOX_SCALE = 1000               # boxes are expressed on a 0..1000 grid
SPECIAL_TOKEN_BOX = [0, 0, 0, 0]
WORD_START_MARKER = "▁"        # SentencePiece marker for "starts a new word"


# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height):
    """Scale a pixel box to the 0..1000 grid.

    Args:
        bbox: ``[x_min, y_min, x_max, y_max]`` in pixels.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as ints, each clamped to
        ``[0, BOX_SCALE]``.
    """
    x_min, y_min, x_max, y_max = bbox

    def scale(value, extent):
        # OCR vertices can fall slightly outside the image; clamp so the
        # coordinate stays inside the 0..1000 grid.
        return min(max(int((value / extent) * BOX_SCALE), 0), BOX_SCALE)

    return [
        scale(x_min, width),
        scale(y_min, height),
        scale(x_max, width),
        scale(y_max, height),
    ]


# ====== OCR ======
def extract_text_and_boxes(image):
    """Run document OCR on a PIL image.

    Args:
        image: PIL image (must provide ``.size`` and ``.save``).

    Returns:
        ``(words, box_list)``: the recognized words and, for each word, its
        normalized ``[x_min, y_min, x_max, y_max]`` box.

    Raises:
        RuntimeError: If the Vision API response carries an error message.
    """
    width, height = image.size

    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    image_vision = vision.Image(content=buffer.getvalue())

    response = client.document_text_detection(image=image_vision)
    # An API-side failure arrives as a populated `error` field, not as an
    # exception; without this check it would look like "no text found".
    if response.error.message:
        raise RuntimeError(f"Vision API error: {response.error.message}")

    words, box_list = [], []
    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    words.append(''.join(s.text for s in word.symbols))

                    vertices = word.bounding_box.vertices
                    xs = [v.x for v in vertices]
                    ys = [v.y for v in vertices]
                    box_list.append(
                        normalized_boxes(
                            [min(xs), min(ys), max(xs), max(ys)], width, height
                        )
                    )
    return words, box_list


def _tokenize_words(words, boxes, max_pieces):
    """Tokenize word by word so every sub-word piece keeps its word's box.

    Returns ``(pieces, piece_boxes, piece_words)`` truncated to *max_pieces*
    entries; ``piece_words[i]`` is the index into *words* that piece ``i``
    came from.
    """
    pieces, piece_boxes, piece_words = [], [], []
    for word_index, (word, box) in enumerate(zip(words, boxes)):
        word_pieces = tokenizer.tokenize(word)
        pieces.extend(word_pieces)
        piece_boxes.extend([box] * len(word_pieces))
        piece_words.extend([word_index] * len(word_pieces))
        if len(pieces) >= max_pieces:
            break
    return pieces[:max_pieces], piece_boxes[:max_pieces], piece_words[:max_pieces]


def _collect_entities(pieces, piece_labels, piece_words):
    """Assemble entity strings from per-piece BIO labels.

    A ``B-X`` piece is always appended to field ``X``; an ``I-X`` piece is
    appended only when the previous kept piece was also labelled ``X``.
    Pieces belonging to the same OCR word are concatenated without a space;
    pieces from different words are separated by one space.
    """
    entity_map = {field: "" for field in ENTITY_FIELDS}
    last_word = {}          # field -> word index of the last piece appended
    current_label = None

    for piece, label, word_index in zip(pieces, piece_labels, piece_words):
        text = piece.replace(WORD_START_MARKER, "").strip()
        if not text:
            continue

        prefix, ent_type = label[:2], label[2:]
        known = ent_type in entity_map
        keep = known and (
            prefix == "B-" or (prefix == "I-" and current_label == ent_type)
        )
        if not keep:
            current_label = None
            continue

        same_word = last_word.get(ent_type) == word_index
        if entity_map[ent_type] and not same_word:
            entity_map[ent_type] += " "
        entity_map[ent_type] += text
        last_word[ent_type] = word_index
        current_label = ent_type

    return {field: value.strip() for field, value in entity_map.items()}


# ====== Inference & Entity Extraction ======
def predict_text(image):
    """OCR *image*, classify its tokens and return the entities as JSON.

    Args:
        image: PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A JSON string. On success it maps each name in ``ENTITY_FIELDS`` to
        the extracted text (possibly empty). On failure it is
        ``{"error": "<message>"}``.
    """
    if image is None:
        return json.dumps({"error": "No image provided"}, ensure_ascii=False)

    try:
        words, norm_boxes = extract_text_and_boxes(image)
    except RuntimeError as err:
        return json.dumps({"error": str(err)}, ensure_ascii=False)

    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)

    # Build the model input from the per-word tokenization so that
    # input_ids and bbox are guaranteed to have the same length and order.
    pieces, piece_boxes, piece_words = _tokenize_words(
        words, norm_boxes, MAX_SEQUENCE_LENGTH - 2
    )
    input_ids = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(pieces)
        + [tokenizer.sep_token_id]
    )
    token_boxes = [SPECIAL_TOKEN_BOX] + piece_boxes + [SPECIAL_TOKEN_BOX]

    encoding = {
        "input_ids": torch.tensor([input_ids], dtype=torch.long, device=device),
        "attention_mask": torch.ones(
            (1, len(input_ids)), dtype=torch.long, device=device
        ),
        "bbox": torch.tensor([token_boxes], dtype=torch.long, device=device),
    }

    with torch.no_grad():
        outputs = model(**encoding)

    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
    # Drop the leading/trailing special tokens by position; matching on the
    # strings "[CLS]"/"[SEP]" does not work for tokenizers that spell their
    # special tokens differently.
    piece_labels = [labels[idx] for idx in predictions[1:-1]]

    # ====== Parse Entities ======
    entity_map = _collect_entities(pieces, piece_labels, piece_words)
    return json.dumps(entity_map, ensure_ascii=False, indent=2)


# ====== Gradio Interface ======
# Launched at module level, as in the original script.
gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text").launch()