import os
import io
import json
import torch
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification
# ====== Set Google Credential ======
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "eastern-entity-450514-u2-c9949243357a.json"
# ====== Load Vision API and Model ======
client = vision.ImageAnnotatorClient()
model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
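# The tokenizer comes from the base LiLT + XLM-RoBERTa checkpoint, while the
# token-classification head comes from the fine-tuned checkpoint above; this assumes
# the fine-tuned model shares the base checkpoint's vocabulary.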
# ====== Labels ======
labels = [
    "B-O", "I-O",
    "B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
    "B-FULL_NAME", "I-FULL_NAME",
    "B-GENDER", "I-GENDER",
    "B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY",
    "I-BIRTH_MONTH", "B-DISTRICT", "B-MUNCIPALITY", "I-MUNCIPALITY",
    "B-WARD_NO", "I-WARD_NO",
    "B-FATHERS_NAME", "I-FATHERS_NAME",
    "B-MOTHERS_NAME", "I-MOTHERS_NAME",
    "I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY"
]
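# The label order above is assumed to match the fine-tuned model's id2label mapping.
# A safer alternative (sketch) is to read it straight from the model config:
# labels = [model.config.id2label[i] for i in range(model.config.num_labels)]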
# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height):
    x_min, y_min, x_max, y_max = bbox
    return [
        int((x_min / width) * 1000),
        int((y_min / height) * 1000),
        int((x_max / width) * 1000),
        int((y_max / height) * 1000),
    ]
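# LayoutLM/LiLT-style models expect boxes rescaled to a 0-1000 grid regardless of
# the original image size, e.g. normalized_boxes([100, 50, 300, 80], 1000, 500)
# -> [100, 100, 300, 160].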
# ====== OCR ======
def extract_text_and_boxes(image):
    width, height = image.size
    # Send the image to the Vision API as PNG bytes
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    content = img_byte_arr.getvalue()
    image_vision = vision.Image(content=content)
    response = client.document_text_detection(image=image_vision)
    words, box_list = [], []
    # Walk the page -> block -> paragraph -> word hierarchy of the OCR result
    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    word_text = ''.join([s.text for s in word.symbols])
                    words.append(word_text)
                    x_min = min([v.x for v in word.bounding_box.vertices])
                    y_min = min([v.y for v in word.bounding_box.vertices])
                    x_max = max([v.x for v in word.bounding_box.vertices])
                    y_max = max([v.y for v in word.bounding_box.vertices])
                    box = normalized_boxes([x_min, y_min, x_max, y_max], width, height)
                    box_list.append(box)
    return words, box_list
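# extract_text_and_boxes returns parallel lists: one OCR word per entry and its
# bounding box already normalized to the 0-1000 grid the model expects.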
# ====== Inference & Entity Extraction ======
def predict_text(image):
    words, norm_boxes = extract_text_and_boxes(image)
    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)
    encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")
    encoding = {k: v.to(device).long() if k == "input_ids" else v.to(device) for k, v in encoding.items()}
    # Repeat each word-level box once per sub-word token so boxes line up with input_ids
    token_boxes = []
    for word, box in zip(words, norm_boxes):
        word_tokens = tokenizer.tokenize(word)
        token_boxes.extend([box] * len(word_tokens))
    cls_box = [0, 0, 0, 0]
    sep_box = [0, 0, 0, 0]
    token_boxes = [cls_box] + token_boxes[:510] + [sep_box]
    seq_len = encoding["input_ids"].shape[1]
    # Pad with dummy boxes in case per-word tokenization yields fewer tokens than the joined text
    if len(token_boxes) < seq_len:
        token_boxes += [sep_box] * (seq_len - len(token_boxes))
    encoding["bbox"] = torch.tensor([token_boxes[:seq_len]]).to(device).long()
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
    label_output = [labels[idx] for idx in predictions]
    # ====== Parse Entities ======
    entity_map = {
        "CITIZENSHIP_CERTIFICATE_NO": "",
        "FULL_NAME": "",
        "GENDER": "",
        "DISTRICT": "",
        "MUNCIPALITY": "",
        "BIRTH_YEAR": "",
        "BIRTH_MONTH": "",
        "BIRTH_DAY": "",
        "WARD_NO": "",
        "FATHERS_NAME": "",
        "MOTHERS_NAME": ""
    }
    current_label = None
    for token, label in zip(tokens, label_output):
        # Skip special tokens (XLM-RoBERTa uses <s>/</s> rather than [CLS]/[SEP])
        if token in tokenizer.all_special_tokens:
            continue
        token = token.replace("▁", "").strip()
        if not token:
            continue
        if label.startswith("B-"):
            ent_type = label[2:]
            if ent_type in entity_map:
                entity_map[ent_type] += (" " if entity_map[ent_type] else "") + token
                current_label = ent_type
            else:
                current_label = None
        elif label.startswith("I-"):
            ent_type = label[2:]
            if current_label == ent_type and ent_type in entity_map:
                entity_map[ent_type] += " " + token
            else:
                current_label = None
        else:
            current_label = None
    # Strip extra spaces
    for k in entity_map:
        entity_map[k] = entity_map[k].strip()
    return json.dumps(entity_map, ensure_ascii=False, indent=2)
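# Local sanity check (sketch; "sample_certificate.png" is a placeholder path):
# print(predict_text(Image.open("sample_certificate.png")))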
# ====== Gradio Interface ======
import gradio as gr
gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text").launch()