import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
import threading
# Global variables for model and tokenizer
model = None
tokenizer = None
model_loading = False
model_loaded = False
loading_error = None
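# load_model() below runs in a background daemon thread (started near the bottom of
# this file) so the Gradio UI can come up immediately; these flags let the UI
# callbacks report loading progress without blocking.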
def load_model():
"""Load the model and tokenizer"""
global model, tokenizer, model_loading, model_loaded, loading_error
model_loading = True
loading_error = None
try:
model_name = "UnarineLeo/nllb-en-ve-finetuned"
print(f"Loading model: {model_name}")
# Try loading with different configurations
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
except Exception as e1:
print(f"First attempt failed: {e1}")
# Fallback: try without optimizations
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Run a tiny generation to verify the model works end to end
        test_input = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
_ = model.generate(**test_input, max_length=10)
model_loaded = True
model_loading = False
print("Model loaded successfully!")
return True
except Exception as e:
loading_error = str(e)
model_loading = False
model_loaded = False
print(f"Error loading model: {e}")
return False
def get_model_status():
"""Get current model loading status"""
if model_loaded:
return "βœ… Model loaded and ready"
elif model_loading:
return "⏳ Model is loading, please wait..."
elif loading_error:
return f"❌ Model loading failed: {loading_error}"
else:
return "⏳ Initializing model..."
def translate_text(text, max_length=512, num_beams=5):
"""
Translate English text to Venda using the fine-tuned NLLB model
Args:
text (str): Input English text
max_length (int): Maximum length of translation
num_beams (int): Number of beams for beam search
Returns:
tuple: (translated_text, status_message)
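    Example (output depends on the loaded model):
        >>> venda_text, status = translate_text("Hello, how are you?")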
"""
global model, tokenizer, model_loaded, model_loading
if not text.strip():
return "", "Please enter some text to translate."
if not model_loaded:
if model_loading:
return "", "⏳ Model is still loading, please wait a moment and try again."
else:
return "", f"❌ Model not available. {loading_error if loading_error else 'Please refresh the page.'}"
try:
# Language codes as used in training
source_lang = "eng_Latn"
target_lang = "ven_Latn"
# Format input exactly like in training: "eng_Latn: {text}"
formatted_input = f"{source_lang}: {text}"
# Set source language for tokenizer
if hasattr(tokenizer, 'src_lang'):
tokenizer.src_lang = source_lang
# Tokenize input
inputs = tokenizer(
formatted_input,
return_tensors="pt",
padding=True,
truncation=True,
max_length=128 # Match training max_length
)
        # Move the tokenized inputs onto the model's device (matters when the model is on GPU)
        inputs = inputs.to(model.device)
        # Generate translation
        start_time = time.time()
        with torch.no_grad():
generated_tokens = model.generate(
**inputs,
max_length=max_length,
num_beams=num_beams,
early_stopping=True,
do_sample=False,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
# Decode translation
raw_translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Clean up translation - remove language prefixes if present
translation = raw_translation
# Remove source language prefix if it appears in output
if translation.startswith(f"{source_lang}:"):
translation = translation[len(f"{source_lang}:"):].strip()
# Remove target language prefix if it appears in output
if translation.startswith(f"{target_lang}:"):
translation = translation[len(f"{target_lang}:"):].strip()
# Remove original input if it appears at the start
if translation.lower().startswith(text.lower()):
translation = translation[len(text):].strip()
# Remove any remaining colons or prefixes at the start
translation = translation.lstrip(': ')
end_time = time.time()
processing_time = round(end_time - start_time, 2)
if translation and translation != formatted_input:
status = f"βœ… Translation completed in {processing_time} seconds"
else:
status = "⚠️ Translation completed but result may be incomplete"
if not translation:
translation = "[No translation generated]"
return translation, status
except Exception as e:
error_msg = f"❌ Translation error: {str(e)}"
print(f"Translation error: {e}")
import traceback
print(f"Full traceback: {traceback.format_exc()}")
return "", error_msg
def translate_batch(text_list):
"""
Translate multiple lines of text
Args:
text_list (str): Multi-line text input
Returns:
tuple: (translated_text, status_message)
"""
if not text_list.strip():
return "", "Please enter some text to translate."
lines = [line.strip() for line in text_list.split('\n') if line.strip()]
if not lines:
return "", "No valid text lines found."
try:
        translations = []
        translated_count = 0
        for i, line in enumerate(lines):
            translation, status = translate_text(line)
            if translation:
                translated_count += 1
                translations.append(f"{i+1}. EN: {line}")
                translations.append(f"   VE: {translation}")
                translations.append("")
if translations:
result = "\n".join(translations)
status_msg = f"βœ… Successfully translated {len(lines)} lines"
return result, status_msg
else:
return "", "❌ No translations generated"
except Exception as e:
return "", f"❌ Batch translation error: {str(e)}"
# Start loading model in background thread
print("Initializing model...")
loading_thread = threading.Thread(target=load_model)
loading_thread.daemon = True
loading_thread.start()
# Create Gradio interface
with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
    # 🌍 English to Venda Translator
This app translates English text to Venda (Tshivenda) using the NLLB model.
Venda is a Bantu language spoken primarily in South Africa and Zimbabwe.
    **Model:** `UnarineLeo/nllb-en-ve-finetuned`
""")
# Model status indicator
status_indicator = gr.Textbox(
value=get_model_status(),
label="Model Status",
interactive=False,
max_lines=1
)
    # Refresh the status indicator on page load and via the "Refresh Status" button below
    def update_status():
        return get_model_status()
    # Show the current status when the page first loads
demo.load(fn=update_status, outputs=status_indicator)
with gr.Tab("Single Translation"):
with gr.Row():
with gr.Column():
input_text = gr.Textbox(
label="English Text",
placeholder="Enter English text to translate...",
lines=4,
max_lines=10
)
with gr.Row():
max_length_slider = gr.Slider(
minimum=50,
maximum=1000,
value=512,
step=50,
label="Max Translation Length"
)
num_beams_slider = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
label="Number of Beams (Quality vs Speed)"
)
                translate_btn = gr.Button("🔄 Translate", variant="primary")
with gr.Column():
output_text = gr.Textbox(
label="Venda Translation",
lines=4,
max_lines=10,
interactive=False
)
status_text = gr.Textbox(
label="Status",
interactive=False,
lines=1
)
        # Example sentences users can try
gr.Examples(
examples=[
["Hello, how are you?"],
["Good morning, everyone."],
["Thank you for your help."],
["What is your name?"],
["I am learning Venda."],
["Welcome to our school."],
["The weather is beautiful today."],
["Can you help me please?"]
],
inputs=[input_text],
label="Try these statistical terms (model was trained on statistical terminology):"
)
with gr.Tab("Batch Translation"):
with gr.Row():
with gr.Column():
batch_input = gr.Textbox(
label="Multiple English Sentences",
placeholder="Enter multiple English sentences, one per line...",
lines=8,
max_lines=15
)
                batch_translate_btn = gr.Button("🔄 Translate All", variant="primary")
with gr.Column():
batch_output = gr.Textbox(
label="Batch Translations",
lines=8,
max_lines=15,
interactive=False
)
batch_status = gr.Textbox(
label="Status",
interactive=False,
lines=1
)
with gr.Tab("About"):
gr.Markdown("""
## About This Translator
This application uses a fine-tuned NLLB (No Language Left Behind) model specifically trained for English to Venda translation.
### Features:
- **Single Translation**: Translate individual sentences or paragraphs
- **Batch Translation**: Translate multiple sentences at once
- **Adjustable Parameters**: Control translation quality and length
- **Examples**: Try pre-loaded example sentences
### About Venda (Tshivenda):
- Spoken by approximately 1.2 million people
        - One of South Africa's official languages
- Also spoken in Zimbabwe
- Part of the Bantu language family
### Usage Tips:
- Keep sentences reasonably short for best results
- The model works best with common, everyday language
- Higher beam numbers generally produce better quality but slower translations
### Technical Details:
        - **Model**: UnarineLeo/nllb-en-ve-finetuned
        - **Architecture**: NLLB (No Language Left Behind)
        - **Language Codes**: eng_Latn → ven_Latn
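        ### Using the Model Directly:
        A minimal sketch of calling the checkpoint outside this app with `transformers`, mirroring the prompt format and generation settings used in the code above (the sentence and parameter values are only illustrative):
        ```python
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_name = "UnarineLeo/nllb-en-ve-finetuned"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Same prompt format this app uses: "eng_Latn: <text>"
        inputs = tokenizer("eng_Latn: Hello, how are you?", return_tensors="pt")
        tokens = model.generate(**inputs, max_length=128, num_beams=5)
        print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])
        ```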
""")
# Event handlers
translate_btn.click(
fn=translate_text,
inputs=[input_text, max_length_slider, num_beams_slider],
outputs=[output_text, status_text]
)
batch_translate_btn.click(
fn=translate_batch,
inputs=[batch_input],
outputs=[batch_output, batch_status]
)
    # Translate when the user presses Enter in the input text box
input_text.submit(
fn=translate_text,
inputs=[input_text, max_length_slider, num_beams_slider],
outputs=[output_text, status_text]
)
# Refresh status button
    refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
refresh_btn.click(
fn=update_status,
outputs=[status_indicator]
)
# Launch the app
if __name__ == "__main__":
demo.launch(
share=True,
debug=True,
show_error=True
)