t5-ft-demo / typo_check.py
#!/usr/bin/env python
# Gradio app for Dhivehi typo correction
import difflib
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import spaces
# Available models
MODEL_OPTIONS_TYPO = {
    "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
    "XS Model": "alakxender/dhivehi-quick-spell-check-t5",
}
# Function to load model and tokenizer
def load_model(model_choice):
    print("Loading model and tokenizer...")
    try:
        selected_model = MODEL_OPTIONS_TYPO[model_choice]
        tokenizer = AutoTokenizer.from_pretrained(selected_model)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)
        # Move model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        print(f"Model loaded successfully on {device}")
        return model, tokenizer, device
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, None
# Function to correct typos (reverted to single output)
def correct_typo(text, model, tokenizer, device):
    if not text.strip():
        raise gr.Error("Please enter some text💥!", duration=5)
    if len(text.strip()) > 1024:
        raise gr.Error("Input is too long (max 1024 characters)💥!", duration=5)
    try:
        # Prepare input with the "fix: " task prefix
        input_text = "fix: " + text
        # Tokenize input
        inputs = tokenizer(input_text, return_tensors="pt", max_length=128, truncation=True)
        inputs = inputs.to(device)
        # Generate output
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_length=128,
                num_beams=4,
                early_stopping=True,
            )
        # Decode the output
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected_text
    except Exception as e:
        return f"Error: {str(e)}"
# Initialize model and tokenizer
model, tokenizer, device = load_model("A3 Model")
if model is None:
    print("Failed to load model. Please check your model and tokenizer paths.")
# Function to highlight differences between original and corrected text
def highlight_differences(original, corrected):
    d = difflib.Differ()
    orig_words = original.split()
    corr_words = corrected.split()
    diff = list(d.compare(orig_words, corr_words))
    html_parts = []
    i = 0
    while i < len(diff):
        if diff[i].startswith('  '):  # Unchanged word
            html_parts.append(f'<span>{diff[i][2:]}</span>')
        elif diff[i].startswith('- '):  # Removed from the original
            if i + 1 < len(diff) and diff[i + 1].startswith('+ '):
                # Changed word - show original → correction
                old_word = diff[i][2:]
                new_word = diff[i + 1][2:]
                html_parts.append(
                    f'<span style="background-color: #fff3cd">{old_word}</span>'
                    f'→<span style="background-color: #d4edda">{new_word}</span>'
                )
                i += 1  # Skip the paired '+ ' entry
            else:
                # Removed word with no replacement
                html_parts.append(f'<span style="background-color: #f8d7da">{diff[i][2:]}</span>')
        elif diff[i].startswith('+ '):  # Added word
            html_parts.append(f'<span style="background-color: #d4edda">{diff[i][2:]}</span>')
        # '? ' hint lines emitted by Differ match no branch and are skipped
        i += 1
    return f'<div class="dhivehi-diff">{" ".join(html_parts)}</div>'
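# Hedged illustration (not from the original file), using ASCII stand-in words:
#   highlight_differences("teh quick fox", "the quick brown fox")
# renders "teh"→"the" as a changed pair (yellow original, green correction),
# "brown" as an insertion (green), and the unchanged words as plain <span>s,
# all wrapped in a right-to-left <div class="dhivehi-diff"> container.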
# Function to process the input for Gradio
@spaces.GPU()  # Hugging Face Spaces ZeroGPU: allocate a GPU for the duration of this call
def process_input(text, model_choice):
    global model, tokenizer, device
    # Lazily (re)load if startup loading failed; note this does not switch
    # checkpoints when model_choice changes after a model is already loaded
    if model is None:
        model, tokenizer, device = load_model(model_choice)
    corrected = correct_typo(text, model, tokenizer, device)
    highlighted = highlight_differences(text, corrected)
    return corrected, highlighted
# Define CSS for Dhivehi font styling
css = """
.textbox1 textarea {
font-size: 18px !important;
font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
line-height: 1.8 !important;
direction: rtl !important;
}
.dhivehi-text {
font-size: 18px !important;
font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
line-height: 1.8 !important;
direction: rtl !important;
text-align: right !important;
padding: 10px !important;
background: transparent !important; /* Make background transparent */
border-radius: 4px !important;
color: #ffffff !important; /* White text for dark background */
}
/* Style for the highlighted differences */
.dhivehi-diff {
font-size: 18px !important;
font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
line-height: 1.8 !important;
direction: rtl !important;
text-align: right !important;
padding: 15px !important;
background: transparent !important; /* Make background transparent */
border: 1px solid rgba(255, 255, 255, 0.1) !important; /* Subtle border */
border-radius: 4px !important;
margin-top: 10px !important;
color: #ffffff !important; /* White text for dark background */
}
/* Ensure the highlighted spans have good contrast */
.dhivehi-diff span {
padding: 2px 5px !important;
border-radius: 3px !important;
margin: 0 2px !important;
}
/* Original text (yellow background) */
.dhivehi-diff span[style*="background-color: #fff3cd"] {
background-color: rgba(255, 243, 205, 0.2) !important;
color: #ffd700 !important; /* Golden yellow for visibility */
border: 1px solid rgba(255, 243, 205, 0.3) !important;
}
/* Corrected text (green background) */
.dhivehi-diff span[style*="background-color: #d4edda"] {
background-color: rgba(212, 237, 218, 0.2) !important;
color: #98ff98 !important; /* Light green for visibility */
border: 1px solid rgba(212, 237, 218, 0.3) !important;
}
/* Removed text (red background) */
.dhivehi-diff span[style*="background-color: #f8d7da"] {
background-color: rgba(248, 215, 218, 0.2) !important;
color: #ff6b6b !important; /* Light red for visibility */
border: 1px solid rgba(248, 215, 218, 0.3) !important;
}
/* Arrow between changed words: it is a plain text node inside .dhivehi-diff, so it
   inherits the white color set above. (A :contains() selector is not standard CSS
   and is ignored by browsers, so no separate rule is needed here.) */
"""