UnarineLeo commited on
Commit
21b365a
Β·
verified Β·
1 Parent(s): 9c42d8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -26
app.py CHANGED
@@ -2,31 +2,71 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import time
 
5
 
6
  # Global variables for model and tokenizer
7
  model = None
8
  tokenizer = None
 
 
 
9
 
10
  def load_model():
11
  """Load the model and tokenizer"""
12
- global model, tokenizer
 
 
 
13
 
14
  try:
15
  model_name = "UnarineLeo/nllb_eng_ven_terms"
16
  print(f"Loading model: {model_name}")
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
 
21
  print("Model loaded successfully!")
22
  return True
 
23
  except Exception as e:
 
 
 
24
  print(f"Error loading model: {e}")
25
  return False
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  def translate_text(text, max_length=512, num_beams=5):
28
  """
29
- Translate English text to Venda
30
 
31
  Args:
32
  text (str): Input English text
@@ -36,45 +76,88 @@ def translate_text(text, max_length=512, num_beams=5):
36
  Returns:
37
  tuple: (translated_text, status_message)
38
  """
39
- global model, tokenizer
40
 
41
  if not text.strip():
42
  return "", "Please enter some text to translate."
43
 
44
- if model is None or tokenizer is None:
45
- return "", "Model not loaded. Please wait while the model loads."
 
 
 
46
 
47
  try:
48
- # Set source language
49
- tokenizer.src_lang = "eng_Latn"
 
 
 
 
 
 
 
 
50
 
51
  # Tokenize input
52
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
 
 
 
 
 
 
53
 
54
  # Generate translation
55
  start_time = time.time()
56
  with torch.no_grad():
57
  generated_tokens = model.generate(
58
  **inputs,
59
- forced_bos_token_id=tokenizer.lang_code_to_id["ven_Latn"],
60
  max_length=max_length,
61
  num_beams=num_beams,
62
  early_stopping=True,
63
- do_sample=False
 
64
  )
65
 
66
  # Decode translation
67
- translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  end_time = time.time()
70
  processing_time = round(end_time - start_time, 2)
71
 
72
- status = f"βœ… Translation completed in {processing_time} seconds"
 
 
 
 
 
73
 
74
  return translation, status
75
 
76
  except Exception as e:
77
  error_msg = f"❌ Translation error: {str(e)}"
 
 
 
78
  return "", error_msg
79
 
80
  def translate_batch(text_list):
@@ -116,9 +199,11 @@ def translate_batch(text_list):
116
  except Exception as e:
117
  return "", f"❌ Batch translation error: {str(e)}"
118
 
119
- # Load model on startup
120
  print("Initializing model...")
121
- model_loaded = load_model()
 
 
122
 
123
  # Create Gradio interface
124
  with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
@@ -132,6 +217,21 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
132
  **Model:** `UnarineLeo/nllb_eng_ven_terms`
133
  """)
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  with gr.Tab("Single Translation"):
136
  with gr.Row():
137
  with gr.Column():
@@ -175,20 +275,22 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
175
  lines=1
176
  )
177
 
178
- # Examples
179
  gr.Examples(
180
  examples=[
181
- ["Hello, how are you?"],
182
- ["Good morning, everyone."],
183
- ["Thank you for your help."],
184
- ["What is your name?"],
185
- ["I am learning Venda."],
186
- ["Welcome to our school."],
187
- ["The weather is beautiful today."],
188
- ["Can you help me please?"]
 
 
189
  ],
190
  inputs=[input_text],
191
- label="Try these examples:"
192
  )
193
 
194
  with gr.Tab("Batch Translation"):
@@ -263,6 +365,13 @@ with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as d
263
  inputs=[input_text, max_length_slider, num_beams_slider],
264
  outputs=[output_text, status_text]
265
  )
 
 
 
 
 
 
 
266
 
267
  # Launch the app
268
  if __name__ == "__main__":
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import time
5
+ import threading
6
 
7
  # Global variables for model and tokenizer
8
  model = None
9
  tokenizer = None
10
+ model_loading = False
11
+ model_loaded = False
12
+ loading_error = None
13
 
14
  def load_model():
15
  """Load the model and tokenizer"""
16
+ global model, tokenizer, model_loading, model_loaded, loading_error
17
+
18
+ model_loading = True
19
+ loading_error = None
20
 
21
  try:
22
  model_name = "UnarineLeo/nllb_eng_ven_terms"
23
  print(f"Loading model: {model_name}")
24
 
25
+ # Try loading with different configurations
26
+ try:
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ model = AutoModelForSeq2SeqLM.from_pretrained(
29
+ model_name,
30
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
31
+ device_map="auto" if torch.cuda.is_available() else None
32
+ )
33
+ except Exception as e1:
34
+ print(f"First attempt failed: {e1}")
35
+ # Fallback: try without optimizations
36
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
38
+
39
+ # Test if model works
40
+ test_input = tokenizer("Hello", return_tensors="pt")
41
+ with torch.no_grad():
42
+ _ = model.generate(**test_input, max_length=10)
43
 
44
+ model_loaded = True
45
+ model_loading = False
46
  print("Model loaded successfully!")
47
  return True
48
+
49
  except Exception as e:
50
+ loading_error = str(e)
51
+ model_loading = False
52
+ model_loaded = False
53
  print(f"Error loading model: {e}")
54
  return False
55
 
56
+ def get_model_status():
57
+ """Get current model loading status"""
58
+ if model_loaded:
59
+ return "βœ… Model loaded and ready"
60
+ elif model_loading:
61
+ return "⏳ Model is loading, please wait..."
62
+ elif loading_error:
63
+ return f"❌ Model loading failed: {loading_error}"
64
+ else:
65
+ return "⏳ Initializing model..."
66
+
67
  def translate_text(text, max_length=512, num_beams=5):
68
  """
69
+ Translate English text to Venda using the fine-tuned NLLB model
70
 
71
  Args:
72
  text (str): Input English text
 
76
  Returns:
77
  tuple: (translated_text, status_message)
78
  """
79
+ global model, tokenizer, model_loaded, model_loading
80
 
81
  if not text.strip():
82
  return "", "Please enter some text to translate."
83
 
84
+ if not model_loaded:
85
+ if model_loading:
86
+ return "", "⏳ Model is still loading, please wait a moment and try again."
87
+ else:
88
+ return "", f"❌ Model not available. {loading_error if loading_error else 'Please refresh the page.'}"
89
 
90
  try:
91
+ # Language codes as used in training
92
+ source_lang = "eng_Latn"
93
+ target_lang = "ven_Latn"
94
+
95
+ # Format input exactly like in training: "eng_Latn: {text}"
96
+ formatted_input = f"{source_lang}: {text}"
97
+
98
+ # Set source language for tokenizer
99
+ if hasattr(tokenizer, 'src_lang'):
100
+ tokenizer.src_lang = source_lang
101
 
102
  # Tokenize input
103
+ inputs = tokenizer(
104
+ formatted_input,
105
+ return_tensors="pt",
106
+ padding=True,
107
+ truncation=True,
108
+ max_length=128 # Match training max_length
109
+ )
110
 
111
  # Generate translation
112
  start_time = time.time()
113
  with torch.no_grad():
114
  generated_tokens = model.generate(
115
  **inputs,
 
116
  max_length=max_length,
117
  num_beams=num_beams,
118
  early_stopping=True,
119
+ do_sample=False,
120
+ pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id
121
  )
122
 
123
  # Decode translation
124
+ raw_translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
125
+
126
+ # Clean up translation - remove language prefixes if present
127
+ translation = raw_translation
128
+
129
+ # Remove source language prefix if it appears in output
130
+ if translation.startswith(f"{source_lang}:"):
131
+ translation = translation[len(f"{source_lang}:"):].strip()
132
+
133
+ # Remove target language prefix if it appears in output
134
+ if translation.startswith(f"{target_lang}:"):
135
+ translation = translation[len(f"{target_lang}:"):].strip()
136
+
137
+ # Remove original input if it appears at the start
138
+ if translation.lower().startswith(text.lower()):
139
+ translation = translation[len(text):].strip()
140
+
141
+ # Remove any remaining colons or prefixes at the start
142
+ translation = translation.lstrip(': ')
143
 
144
  end_time = time.time()
145
  processing_time = round(end_time - start_time, 2)
146
 
147
+ if translation and translation != formatted_input:
148
+ status = f"βœ… Translation completed in {processing_time} seconds"
149
+ else:
150
+ status = "⚠️ Translation completed but result may be incomplete"
151
+ if not translation:
152
+ translation = "[No translation generated]"
153
 
154
  return translation, status
155
 
156
  except Exception as e:
157
  error_msg = f"❌ Translation error: {str(e)}"
158
+ print(f"Translation error: {e}")
159
+ import traceback
160
+ print(f"Full traceback: {traceback.format_exc()}")
161
  return "", error_msg
162
 
163
  def translate_batch(text_list):
 
199
  except Exception as e:
200
  return "", f"❌ Batch translation error: {str(e)}"
201
 
202
+ # Start loading model in background thread
203
  print("Initializing model...")
204
+ loading_thread = threading.Thread(target=load_model)
205
+ loading_thread.daemon = True
206
+ loading_thread.start()
207
 
208
  # Create Gradio interface
209
  with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
 
217
  **Model:** `UnarineLeo/nllb_eng_ven_terms`
218
  """)
219
 
220
+ # Model status indicator
221
+ status_indicator = gr.Textbox(
222
+ value=get_model_status(),
223
+ label="Model Status",
224
+ interactive=False,
225
+ max_lines=1
226
+ )
227
+
228
+ # Auto-refresh status every 3 seconds while loading
229
+ def update_status():
230
+ return get_model_status()
231
+
232
+ # Set up periodic status updates
233
+ demo.load(lambda: get_model_status(), outputs=status_indicator, every=3)
234
+
235
  with gr.Tab("Single Translation"):
236
  with gr.Row():
237
  with gr.Column():
 
275
  lines=1
276
  )
277
 
278
+ # Examples based on statistical terminology the model was trained on
279
  gr.Examples(
280
  examples=[
281
+ ["Area planted for grain"],
282
+ ["Population census"],
283
+ ["Economic growth rate"],
284
+ ["Statistical survey"],
285
+ ["Data collection"],
286
+ ["Income distribution"],
287
+ ["Employment rate"],
288
+ ["Agricultural production"],
289
+ ["Household size"],
290
+ ["Rural development"]
291
  ],
292
  inputs=[input_text],
293
+ label="Try these statistical terms (model was trained on statistical terminology):"
294
  )
295
 
296
  with gr.Tab("Batch Translation"):
 
365
  inputs=[input_text, max_length_slider, num_beams_slider],
366
  outputs=[output_text, status_text]
367
  )
368
+
369
+ # Refresh status button
370
+ refresh_btn = gr.Button("πŸ”„ Refresh Status", size="sm")
371
+ refresh_btn.click(
372
+ fn=update_status,
373
+ outputs=[status_indicator]
374
+ )
375
 
376
  # Launch the app
377
  if __name__ == "__main__":