Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
# Gradio app for Dhivehi typo correction | |
import difflib | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
import gradio as gr | |
import spaces | |
# Available models | |
# Maps a human-readable model label to the model identifier that
# load_model() passes to `from_pretrained`.
MODEL_OPTIONS_TYPO = {
    "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
    "XS Model": "alakxender/dhivehi-quick-spell-check-t5",
}
# Function to load model and tokenizer | |
def load_model(model_choice):
    """Load the seq2seq model and tokenizer registered under *model_choice*.

    Args:
        model_choice: Key of ``MODEL_OPTIONS_TYPO`` naming the checkpoint.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where *device* is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise. Loading is
        best-effort: on any failure the error is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    print("Loading model and tokenizer...")
    try:
        selected_model = MODEL_OPTIONS_TYPO.get(model_choice)
        if selected_model is None:
            # Name the bad key and the valid ones; a bare KeyError repr
            # ("'Foo'") in the log is hard to interpret.
            raise ValueError(
                f"unknown model choice {model_choice!r}; "
                f"expected one of {list(MODEL_OPTIONS_TYPO)}"
            )

        tokenizer = AutoTokenizer.from_pretrained(selected_model)
        if tokenizer.pad_token is None:
            # No pad token configured: reuse the end-of-sequence token.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)

        # Move model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        print(f"Model loaded successfully on {device}")
        return model, tokenizer, device
    except Exception as e:
        # Deliberate catch-all: callers test the first element for None.
        print(f"Error loading model: {e}")
        return None, None, None
# Function to correct typos (reverted to single output) | |
def correct_typo(text, model, tokenizer, device, *, max_length=128, num_beams=4):
    """Return the model's corrected version of *text*.

    Args:
        text: Text to correct.
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer used to encode the input and decode the output.
        device: Device the tokenized inputs are moved to before generation.
        max_length: Token limit applied both when tokenizing the input
            (longer input is truncated) and when generating the output.
        num_beams: Beam-search width passed to ``generate``.

    Returns:
        The decoded correction, or a string starting with ``"Error: "`` if
        tokenization, generation or decoding fails.

    Raises:
        gr.Error: If *text* is ``None``, empty or whitespace-only, or if its
            stripped length exceeds 1024 characters.
    """
    # Treat a missing value like an empty string so it gets the same
    # user-facing message instead of an AttributeError on None.strip().
    stripped = (text or "").strip()
    if not stripped:
        raise gr.Error("Please enter some text💥!", duration=5)
    if len(stripped) > 1024:
        raise gr.Error("Shorter the better💥!", duration=5)

    try:
        # Prepare input with prefix
        input_text = "fix: " + text

        # Tokenize input; anything beyond max_length tokens is cut off.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        )
        inputs = inputs.to(device)

        # Generate output without tracking gradients.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        # Decode the first returned sequence.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: report the failure as the result text rather than raising.
        return f"Error: {str(e)}"
# Initialize model and tokenizer | |
# Load the default model once at import time. load_model() returns
# (None, None, None) on failure, so the app still starts and only logs here.
model, tokenizer, device = load_model("A3 Model")
if model is None:
    print("Failed to load model. Please check your model and tokenizer paths.")
# Function to highlight differences between original and corrected text | |
def highlight_differences(original, corrected):
    """Render a word-level diff of *original* vs *corrected* as HTML.

    Both texts are split on whitespace and compared with ``difflib.Differ``.
    Each word becomes a ``<span>``:

    * unchanged words get no styling;
    * a removed word immediately followed by an added word is shown as
      ``old→new`` (yellow ``#fff3cd`` then green ``#d4edda``);
    * other removed words are red (``#f8d7da``), other added words green.

    Args:
        original: Text as entered.
        corrected: Text after correction.

    Returns:
        A ``<div class="dhivehi-diff">`` element containing the spans joined
        by single spaces. All words are HTML-escaped.
    """
    # Local import keeps the block self-contained.
    from html import escape

    # For similar words Differ emits "? " hint lines between the "- " and
    # "+ " entries. Left in place they hide the removed/added adjacency, so
    # typo-like changes were never rendered as an "old→new" pair.
    diff = [
        entry
        for entry in difflib.Differ().compare(original.split(), corrected.split())
        if not entry.startswith("? ")
    ]

    html_parts = []
    i = 0
    while i < len(diff):
        # Escape every word: it originates from user input / model output.
        tag, word = diff[i][:2], escape(diff[i][2:])
        if tag == "  ":  # Unchanged
            html_parts.append(f"<span>{word}</span>")
        elif tag == "- ":  # Removed
            if i + 1 < len(diff) and diff[i + 1].startswith("+ "):
                # Changed word - show correction
                new_word = escape(diff[i + 1][2:])
                html_parts.append(
                    f'<span style="background-color: #fff3cd">{word}</span>'
                    f'→<span style="background-color: #d4edda">{new_word}</span>'
                )
                i += 1  # The paired "+ " entry has been consumed.
            else:
                # Removed word
                html_parts.append(
                    f'<span style="background-color: #f8d7da">{word}</span>'
                )
        elif tag == "+ ":  # Added
            html_parts.append(
                f'<span style="background-color: #d4edda">{word}</span>'
            )
        i += 1

    return f'<div class="dhivehi-diff">{" ".join(html_parts)}</div>'
# Function to process the input for Gradio | |
# Key of MODEL_OPTIONS_TYPO whose checkpoint is held in the module-level
# `model`. The startup code loads "A3 Model", so that is the initial value.
_active_model_choice = "A3 Model"


def process_input(text, model_choice):
    """Correct *text* with the selected model and build the diff view.

    The module-level ``model``/``tokenizer``/``device`` are (re)loaded when
    no model is loaded yet, or when *model_choice* names a different entry
    of ``MODEL_OPTIONS_TYPO`` than the one currently loaded. An unrecognised
    *model_choice* keeps using the model that is already loaded.

    Args:
        text: Text to correct.
        model_choice: Key of ``MODEL_OPTIONS_TYPO``.

    Returns:
        A ``(corrected, highlighted)`` tuple: the corrected text and the HTML
        produced by ``highlight_differences``.

    Raises:
        gr.Error: If a required model load fails, plus whatever
            ``correct_typo`` raises for invalid input.
    """
    global model, tokenizer, device, _active_model_choice

    switch_requested = (
        model_choice in MODEL_OPTIONS_TYPO
        and model_choice != _active_model_choice
    )
    if model is None or switch_requested:
        # Keep the result: previously the return value was discarded, so the
        # globals stayed None and the selected model was never applied.
        loaded_model, loaded_tokenizer, loaded_device = load_model(model_choice)
        if loaded_model is None:
            raise gr.Error("Failed to load the selected model💥!", duration=5)
        model, tokenizer, device = loaded_model, loaded_tokenizer, loaded_device
        _active_model_choice = model_choice

    corrected = correct_typo(text, model, tokenizer, device)
    highlighted = highlight_differences(text, corrected)
    return corrected, highlighted
# Define CSS for Dhivehi font styling | |
# Stylesheet for the Dhivehi (right-to-left) text widgets and the diff view.
# The attribute selectors below match the inline `background-color` values
# emitted by highlight_differences() and restyle them for a dark background.
css = """
.textbox1 textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
}
.dhivehi-text {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 10px !important;
    background: transparent !important; /* Make background transparent */
    border-radius: 4px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Style for the highlighted differences */
.dhivehi-diff {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 15px !important;
    background: transparent !important; /* Make background transparent */
    border: 1px solid rgba(255, 255, 255, 0.1) !important; /* Subtle border */
    border-radius: 4px !important;
    margin-top: 10px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Ensure the highlighted spans have good contrast */
.dhivehi-diff span {
    padding: 2px 5px !important;
    border-radius: 3px !important;
    margin: 0 2px !important;
}
/* Original text (yellow background) */
.dhivehi-diff span[style*="background-color: #fff3cd"] {
    background-color: rgba(255, 243, 205, 0.2) !important;
    color: #ffd700 !important; /* Golden yellow for visibility */
    border: 1px solid rgba(255, 243, 205, 0.3) !important;
}
/* Corrected text (green background) */
.dhivehi-diff span[style*="background-color: #d4edda"] {
    background-color: rgba(212, 237, 218, 0.2) !important;
    color: #98ff98 !important; /* Light green for visibility */
    border: 1px solid rgba(212, 237, 218, 0.3) !important;
}
/* Removed text (red background) */
.dhivehi-diff span[style*="background-color: #f8d7da"] {
    background-color: rgba(248, 215, 218, 0.2) !important;
    color: #ff6b6b !important; /* Light red for visibility */
    border: 1px solid rgba(248, 215, 218, 0.3) !important;
}
/* Arrow color: the arrow is plain text between the spans, so it inherits the
   white color set on .dhivehi-diff and needs no rule of its own. */
"""