Update app.py
app.py
CHANGED
@@ -2,31 +2,71 @@ import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import time
+import threading
 
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
+model_loading = False
+model_loaded = False
+loading_error = None
 
 def load_model():
     """Load the model and tokenizer"""
-    global model, tokenizer
+    global model, tokenizer, model_loading, model_loaded, loading_error
+
+    model_loading = True
+    loading_error = None
 
     try:
         model_name = "UnarineLeo/nllb_eng_ven_terms"
         print(f"Loading model: {model_name}")
 
-        …
-        …
+        # Try loading with different configurations
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                device_map="auto" if torch.cuda.is_available() else None
+            )
+        except Exception as e1:
+            print(f"First attempt failed: {e1}")
+            # Fallback: try without optimizations
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+        # Test if model works
+        test_input = tokenizer("Hello", return_tensors="pt")
+        with torch.no_grad():
+            _ = model.generate(**test_input, max_length=10)
 
+        model_loaded = True
+        model_loading = False
         print("Model loaded successfully!")
         return True
+
     except Exception as e:
+        loading_error = str(e)
+        model_loading = False
+        model_loaded = False
         print(f"Error loading model: {e}")
         return False
 
+def get_model_status():
+    """Get current model loading status"""
+    if model_loaded:
+        return "✅ Model loaded and ready"
+    elif model_loading:
+        return "⏳ Model is loading, please wait..."
+    elif loading_error:
+        return f"❌ Model loading failed: {loading_error}"
+    else:
+        return "⏳ Initializing model..."
+
 def translate_text(text, max_length=512, num_beams=5):
     """
-    Translate English text to Venda
+    Translate English text to Venda using the fine-tuned NLLB model
 
     Args:
         text (str): Input English text
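Reviewer note on the hunk above: the loader thread and the UI communicate through plain module-level booleans (`model_loading`, `model_loaded`). A `threading.Event` makes that handshake more explicit and gives waiters a built-in timeout; a minimal self-contained sketch of the pattern (the `load_weights` body below is a placeholder, not this app's loader):

    import threading, time

    model_ready = threading.Event()  # set exactly once, when loading is done

    def load_weights():
        time.sleep(1)                # placeholder for the real download/load
        model_ready.set()            # signal the waiting thread

    threading.Thread(target=load_weights, daemon=True).start()
    print("ready" if model_ready.wait(timeout=60) else "still loading")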
@@ -36,45 +76,88 @@ def translate_text(text, max_length=512, num_beams=5):
     Returns:
         tuple: (translated_text, status_message)
     """
-    global model, tokenizer
+    global model, tokenizer, model_loaded, model_loading
 
     if not text.strip():
         return "", "Please enter some text to translate."
 
-    if …
-    …
+    if not model_loaded:
+        if model_loading:
+            return "", "⏳ Model is still loading, please wait a moment and try again."
+        else:
+            return "", f"❌ Model not available. {loading_error if loading_error else 'Please refresh the page.'}"
 
     try:
-        # …
-        …
+        # Language codes as used in training
+        source_lang = "eng_Latn"
+        target_lang = "ven_Latn"
+
+        # Format input exactly like in training: "eng_Latn: {text}"
+        formatted_input = f"{source_lang}: {text}"
+
+        # Set source language for tokenizer
+        if hasattr(tokenizer, 'src_lang'):
+            tokenizer.src_lang = source_lang
 
         # Tokenize input
-        inputs = tokenizer(…
+        inputs = tokenizer(
+            formatted_input,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=128  # Match training max_length
+        )
 
         # Generate translation
         start_time = time.time()
         with torch.no_grad():
             generated_tokens = model.generate(
                 **inputs,
-                forced_bos_token_id=tokenizer.lang_code_to_id["ven_Latn"],
                 max_length=max_length,
                 num_beams=num_beams,
                 early_stopping=True,
-                do_sample=False
+                do_sample=False,
+                pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id
             )
 
         # Decode translation
-        …
+        raw_translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+
+        # Clean up translation - remove language prefixes if present
+        translation = raw_translation
+
+        # Remove source language prefix if it appears in output
+        if translation.startswith(f"{source_lang}:"):
+            translation = translation[len(f"{source_lang}:"):].strip()
+
+        # Remove target language prefix if it appears in output
+        if translation.startswith(f"{target_lang}:"):
+            translation = translation[len(f"{target_lang}:"):].strip()
+
+        # Remove original input if it appears at the start
+        if translation.lower().startswith(text.lower()):
+            translation = translation[len(text):].strip()
+
+        # Remove any remaining colons or prefixes at the start
+        translation = translation.lstrip(': ')
 
         end_time = time.time()
         processing_time = round(end_time - start_time, 2)
 
-        …
+        if translation and translation != formatted_input:
+            status = f"✅ Translation completed in {processing_time} seconds"
+        else:
+            status = "⚠️ Translation completed but result may be incomplete"
+            if not translation:
+                translation = "[No translation generated]"
 
         return translation, status
 
     except Exception as e:
         error_msg = f"❌ Translation error: {str(e)}"
+        print(f"Translation error: {e}")
+        import traceback
+        print(f"Full traceback: {traceback.format_exc()}")
         return "", error_msg
 
 def translate_batch(text_list):
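Reviewer note on the hunk above: the removed `forced_bos_token_id=tokenizer.lang_code_to_id["ven_Latn"]` is how stock NLLB checkpoints pin the output language, and newer transformers releases dropped the `lang_code_to_id` attribute, which is presumably what broke it here. Assuming the fine-tune kept NLLB's special language tokens, a version-portable equivalent would look the token up through the stable tokenizer API:

    # Hypothetical sketch - convert_tokens_to_ids exists on every tokenizer,
    # unlike the removed lang_code_to_id mapping. Uses this function's
    # inputs/max_length/num_beams as defined in the diff.
    ven_id = tokenizer.convert_tokens_to_ids("ven_Latn")
    generated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=ven_id,  # pin decoding to Venda
        max_length=max_length,
        num_beams=num_beams,
    )

Without a forced BOS token the model can echo the source text, which is what the added prefix-stripping block compensates for.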
@@ -116,9 +199,11 @@ def translate_batch(text_list):
     except Exception as e:
         return "", f"❌ Batch translation error: {str(e)}"
 
-# …
+# Start loading model in background thread
 print("Initializing model...")
-…
+loading_thread = threading.Thread(target=load_model)
+loading_thread.daemon = True
+loading_thread.start()
 
 # Create Gradio interface
 with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
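Reviewer note on the hunk above: the truncated removed line was presumably a synchronous `load_model()` call; starting a daemon thread instead lets the Gradio UI come up while the weights are still downloading. If blocking briefly at startup is preferable to serving "still loading" responses, a sketch using the names defined in this commit:

    # Optional: give the loader a head start before building the UI.
    loading_thread.join(timeout=60)   # returns after 60 s even if not done
    if not model_loaded:
        print("Model not ready yet; the status box will keep polling.")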
@@ -132,6 +217,21 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
     **Model:** `UnarineLeo/nllb_eng_ven_terms`
     """)
 
+    # Model status indicator
+    status_indicator = gr.Textbox(
+        value=get_model_status(),
+        label="Model Status",
+        interactive=False,
+        max_lines=1
+    )
+
+    # Auto-refresh status every 3 seconds while loading
+    def update_status():
+        return get_model_status()
+
+    # Set up periodic status updates
+    demo.load(lambda: get_model_status(), outputs=status_indicator, every=3)
+
     with gr.Tab("Single Translation"):
         with gr.Row():
             with gr.Column():
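Reviewer note on the hunk above: `every=3` makes `demo.load` re-run the status function every 3 seconds while a client is connected. Recent Gradio versions expose the same polling through a `gr.Timer` component; a sketch assuming a release where `gr.Timer` is available:

    # Hypothetical equivalent: the tick event fires every 3 seconds and
    # pushes the current status string into the textbox.
    status_timer = gr.Timer(3)
    status_timer.tick(fn=get_model_status, outputs=status_indicator)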
@@ -175,20 +275,22 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
                 lines=1
             )
 
-            # Examples
+            # Examples based on statistical terminology the model was trained on
             gr.Examples(
                 examples=[
-                    ["…
-                    ["…
-                    ["…
-                    ["…
-                    ["…
-                    ["…
-                    ["…
-                    ["…
+                    ["Area planted for grain"],
+                    ["Population census"],
+                    ["Economic growth rate"],
+                    ["Statistical survey"],
+                    ["Data collection"],
+                    ["Income distribution"],
+                    ["Employment rate"],
+                    ["Agricultural production"],
+                    ["Household size"],
+                    ["Rural development"]
                 ],
                 inputs=[input_text],
-                label="Try these …
+                label="Try these statistical terms (model was trained on statistical terminology):"
             )
 
     with gr.Tab("Batch Translation"):
@@ -263,6 +365,13 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
         inputs=[input_text, max_length_slider, num_beams_slider],
         outputs=[output_text, status_text]
     )
+
+    # Refresh status button
+    refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
+    refresh_btn.click(
+        fn=update_status,
+        outputs=[status_indicator]
+    )
 
 # Launch the app
 if __name__ == "__main__":