# Hugging Face Space page header (status at capture time: Sleeping) - not part of the program.
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import time | |
import threading | |
# Shared module state. load_model() fills these in from a background thread;
# the UI callbacks only read them.
model = tokenizer = None  # set once the model/tokenizer objects are loaded
model_loading = model_loaded = False  # progress flags for the loader
loading_error = None  # text of the last loading failure, if any
def load_model():
    """Load the translation model and tokenizer into the module-level globals.

    The model is first loaded with half precision and automatic device
    placement when CUDA is available (full precision, default placement
    otherwise). If that fails, a plain default load is attempted. A short
    test generation is then run, and only if it succeeds is ``model_loaded``
    set to True.

    Returns:
        bool: True when the model loaded and the test generation succeeded,
        False otherwise (the error text is stored in ``loading_error``).
    """
    global model, tokenizer, model_loading, model_loaded, loading_error
    model_loading = True
    loading_error = None
    model_name = "UnarineLeo/nllb-en-ve-finetuned"
    try:
        print(f"Loading model: {model_name}")
        # The tokenizer does not depend on the model loading options, so it
        # is loaded once instead of once per attempt.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        use_cuda = torch.cuda.is_available()
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
            )
        except Exception as e1:
            print(f"First attempt failed: {e1}")
            # Fallback: try without optimizations
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Test if model works. The inputs must live on the same device as the
        # model, otherwise generation fails when the model was placed on GPU.
        test_input = tokenizer("Hello", return_tensors="pt").to(model.device)
        with torch.no_grad():
            _ = model.generate(**test_input, max_length=10)

        model_loaded = True
        print("Model loaded successfully!")
        return True
    except Exception as e:
        loading_error = str(e)
        model_loaded = False
        print(f"Error loading model: {e}")
        return False
    finally:
        # Loading is over on every path, successful or not.
        model_loading = False
def get_model_status():
    """Return a one-line, user-facing description of the model loading state.

    Reads the module-level ``model_loaded``, ``model_loading`` and
    ``loading_error`` flags, checked in that order.

    Returns:
        str: Status text shown in the "Model Status" box.
    """
    if model_loaded:
        return "✅ Model loaded and ready"
    if model_loading:
        return "⏳ Model is loading, please wait..."
    if loading_error:
        return f"❌ Model loading failed: {loading_error}"
    # None of the flags is set yet: the loader has not started running.
    return "⏳ Initializing model..."
def _strip_translation_prefixes(raw_translation, text, source_lang, target_lang):
    """Remove echoed language tags and echoed source text from the start of a translation."""
    translation = raw_translation
    # Source tag first, then target tag, each followed by a colon.
    for prefix in (f"{source_lang}:", f"{target_lang}:"):
        if translation.startswith(prefix):
            translation = translation[len(prefix):].strip()
    # Remove original input if it appears at the start (case-insensitive)
    if translation.lower().startswith(text.lower()):
        translation = translation[len(text):].strip()
    # Remove any remaining colons or spaces at the start
    return translation.lstrip(': ')


def translate_text(text, max_length=512, num_beams=5):
    """
    Translate English text to Venda using the fine-tuned NLLB model

    Args:
        text (str): Input English text
        max_length (int): Maximum length of translation
        num_beams (int): Number of beams for beam search

    Returns:
        tuple: (translated_text, status_message). ``translated_text`` is an
        empty string when the input is empty, the model is not ready, or an
        error occurred; ``status_message`` explains why.
    """
    # Treat a missing value like empty text instead of crashing on .strip().
    if not text or not text.strip():
        return "", "Please enter some text to translate."

    if not model_loaded:
        if model_loading:
            return "", "⏳ Model is still loading, please wait a moment and try again."
        return "", f"❌ Model not available. {loading_error if loading_error else 'Please refresh the page.'}"

    try:
        # UI sliders may deliver floats; generate() expects integers.
        max_length = int(max_length)
        num_beams = max(1, int(num_beams))

        # Language codes as used in training
        source_lang = "eng_Latn"
        target_lang = "ven_Latn"

        # Format input exactly like in training: "eng_Latn: {text}"
        formatted_input = f"{source_lang}: {text}"

        # Set source language for tokenizer
        if hasattr(tokenizer, 'src_lang'):
            tokenizer.src_lang = source_lang

        # Tokenize input and move it to the device the model lives on, so
        # generation also works when the model was placed on a GPU.
        inputs = tokenizer(
            formatted_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,  # Match training max_length
        ).to(model.device)

        # Fall back to the EOS id only when no pad id is actually set; the
        # attribute can exist and still be None.
        pad_token_id = getattr(tokenizer, 'pad_token_id', None)
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Generate translation
        start_time = time.time()
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=pad_token_id,
            )

        # Decode translation
        raw_translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        # Clean up translation - remove language prefixes if present
        translation = _strip_translation_prefixes(raw_translation, text, source_lang, target_lang)

        end_time = time.time()
        processing_time = round(end_time - start_time, 2)

        if translation and translation != formatted_input:
            status = f"✅ Translation completed in {processing_time} seconds"
        else:
            status = "⚠️ Translation completed but result may be incomplete"
            if not translation:
                translation = "[No translation generated]"

        return translation, status
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        error_msg = f"❌ Translation error: {str(e)}"
        print(f"Translation error: {e}")
        import traceback
        print(f"Full traceback: {traceback.format_exc()}")
        return "", error_msg
def translate_batch(text_list):
    """
    Translate multiple lines of text

    Each non-blank line is translated with ``translate_text`` using its
    default parameters. Lines that produce no translation are left out of
    the result and counted as failures in the status message.

    Args:
        text_list (str): Multi-line text input

    Returns:
        tuple: (translated_text, status_message)
    """
    # Treat a missing value like empty text instead of crashing on .strip().
    if not text_list or not text_list.strip():
        return "", "Please enter some text to translate."

    lines = [line.strip() for line in text_list.split('\n') if line.strip()]
    if not lines:
        return "", "No valid text lines found."

    try:
        translations = []
        translated_count = 0
        last_failure = None
        for number, line in enumerate(lines, start=1):
            translation, status = translate_text(line)
            if not translation:
                # Remember why this line failed so it can be reported.
                last_failure = status
                continue
            translated_count += 1
            translations.extend([
                f"{number}. EN: {line}",
                f" VE: {translation}",
                "",
            ])

        if not translations:
            # Surface the underlying reason (e.g. model still loading) when
            # there is one, rather than only a generic message.
            return "", last_failure or "❌ No translations generated"

        result = "\n".join(translations)
        if translated_count == len(lines):
            status_msg = f"✅ Successfully translated {len(lines)} lines"
        else:
            status_msg = (
                f"⚠️ Translated {translated_count} of {len(lines)} lines. "
                f"Last failure: {last_failure}"
            )
        return result, status_msg
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return "", f"❌ Batch translation error: {str(e)}"
# Kick off model loading in the background so the rest of the module keeps
# executing while the model loads. The thread is a daemon, so it does not
# keep the process alive on exit.
print("Initializing model...")
loading_thread = threading.Thread(target=load_model, daemon=True)
loading_thread.start()
# Create Gradio interface: a status box, three tabs (single translation,
# batch translation, about) and a manual status refresh button.
with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
    # NOTE: the model name shown here must match the one loaded in load_model().
    gr.Markdown("""
    # 🌍 English to Venda Translator

    This app translates English text to Venda (Tshivenda) using the NLLB model.
    Venda is a Bantu language spoken primarily in South Africa and Zimbabwe.

    **Model:** `UnarineLeo/nllb-en-ve-finetuned`
    """)

    # Model status indicator
    status_indicator = gr.Textbox(
        value=get_model_status(),
        label="Model Status",
        interactive=False,
        max_lines=1
    )

    def update_status():
        """Return the current model status text for the status box."""
        return get_model_status()

    # Populate the status once when the page loads. There is no periodic
    # polling; the "Refresh Status" button below updates it on demand.
    demo.load(fn=update_status, outputs=status_indicator)

    with gr.Tab("Single Translation"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="English Text",
                    placeholder="Enter English text to translate...",
                    lines=4,
                    max_lines=10
                )
                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=50,
                        maximum=1000,
                        value=512,
                        step=50,
                        label="Max Translation Length"
                    )
                    num_beams_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Beams (Quality vs Speed)"
                    )
                translate_btn = gr.Button("🔄 Translate", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Venda Translation",
                    lines=4,
                    max_lines=10,
                    interactive=False
                )
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1
                )

        # Everyday example sentences; selecting one fills the input box.
        gr.Examples(
            examples=[
                ["Hello, how are you?"],
                ["Good morning, everyone."],
                ["Thank you for your help."],
                ["What is your name?"],
                ["I am learning Venda."],
                ["Welcome to our school."],
                ["The weather is beautiful today."],
                ["Can you help me please?"]
            ],
            inputs=[input_text],
            label="Try these example sentences:"
        )

    with gr.Tab("Batch Translation"):
        with gr.Row():
            with gr.Column():
                batch_input = gr.Textbox(
                    label="Multiple English Sentences",
                    placeholder="Enter multiple English sentences, one per line...",
                    lines=8,
                    max_lines=15
                )
                batch_translate_btn = gr.Button("🔄 Translate All", variant="primary")
            with gr.Column():
                batch_output = gr.Textbox(
                    label="Batch Translations",
                    lines=8,
                    max_lines=15,
                    interactive=False
                )
                batch_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1
                )

    with gr.Tab("About"):
        gr.Markdown("""
        ## About This Translator

        This application uses a fine-tuned NLLB (No Language Left Behind) model specifically trained for English to Venda translation.

        ### Features:
        - **Single Translation**: Translate individual sentences or paragraphs
        - **Batch Translation**: Translate multiple sentences at once
        - **Adjustable Parameters**: Control translation quality and length
        - **Examples**: Try pre-loaded example sentences

        ### About Venda (Tshivenda):
        - Spoken by approximately 1.2 million people
        - Official language of South Africa
        - Also spoken in Zimbabwe
        - Part of the Bantu language family

        ### Usage Tips:
        - Keep sentences reasonably short for best results
        - The model works best with common, everyday language
        - Higher beam numbers generally produce better quality but slower translations

        ### Technical Details:
        - **Model**: UnarineLeo/nllb-en-ve-finetuned
        - **Architecture**: NLLB (No Language Left Behind)
        - **Language Codes**: eng_Latn → ven_Latn
        """)

    # Event handlers
    translate_btn.click(
        fn=translate_text,
        inputs=[input_text, max_length_slider, num_beams_slider],
        outputs=[output_text, status_text]
    )
    batch_translate_btn.click(
        fn=translate_batch,
        inputs=[batch_input],
        outputs=[batch_output, batch_status]
    )

    # Also translate when the user submits the input box (Enter key).
    input_text.submit(
        fn=translate_text,
        inputs=[input_text, max_length_slider, num_beams_slider],
        outputs=[output_text, status_text]
    )

    # Refresh status button
    refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
    refresh_btn.click(
        fn=update_status,
        outputs=[status_indicator]
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)