SlouchyBuffalo commited on
Commit
72fdbe7
·
verified ·
1 Parent(s): 446b071

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import torch
5
+ from huggingface_hub import InferenceClient
6
+ import os
7
+
8
+ # Initialize Cerebras client for Llama 4
9
+ cerebras_client = InferenceClient(
10
+ "meta-llama/Llama-4-Scout-17B-16E-Instruct",
11
+ provider="cerebras",
12
+ token=os.getenv("HF_TOKEN"),
13
+ )
14
+
15
+ # Global variables for models and tokenizers
16
+ en_es_tokenizer = None
17
+ en_es_model = None
18
+ es_en_tokenizer = None
19
+ es_en_model = None
20
+
21
+ @spaces.GPU(duration=60)
22
+ def translate_en_to_es(text):
23
+ global en_es_tokenizer, en_es_model
24
+
25
+ # Initialize EN->ES model if needed
26
+ if en_es_tokenizer is None or en_es_model is None:
27
+ en_es_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="spa_Latn")
28
+ en_es_model = AutoModelForSeq2SeqLM.from_pretrained(
29
+ "facebook/nllb-200-distilled-600M",
30
+ torch_dtype=torch.float16
31
+ ).cuda()
32
+
33
+ # Translate
34
+ inputs = en_es_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to("cuda")
35
+ with torch.no_grad():
36
+ outputs = en_es_model.generate(
37
+ **inputs,
38
+ forced_bos_token_id=en_es_tokenizer.convert_tokens_to_ids("spa_Latn"),
39
+ max_length=512,
40
+ num_beams=5,
41
+ early_stopping=True
42
+ )
43
+
44
+ translation = en_es_tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return translation
46
+
47
+ @spaces.GPU(duration=60)
48
+ def translate_es_to_en(text):
49
+ global es_en_tokenizer, es_en_model
50
+
51
+ # Initialize ES->EN model if needed
52
+ if es_en_tokenizer is None or es_en_model is None:
53
+ es_en_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="spa_Latn", tgt_lang="eng_Latn")
54
+ es_en_model = AutoModelForSeq2SeqLM.from_pretrained(
55
+ "facebook/nllb-200-distilled-600M",
56
+ torch_dtype=torch.float16
57
+ ).cuda()
58
+
59
+ # Translate
60
+ inputs = es_en_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to("cuda")
61
+ with torch.no_grad():
62
+ outputs = es_en_model.generate(
63
+ **inputs,
64
+ forced_bos_token_id=es_en_tokenizer.convert_tokens_to_ids("eng_Latn"),
65
+ max_length=512,
66
+ num_beams=5,
67
+ early_stopping=True
68
+ )
69
+
70
+ translation = es_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+ return translation
72
+
73
+ def refine_with_llama(original_text, translation, direction, region="general", formality="neutral"):
74
+ if direction == "en_to_es":
75
+ refine_prompt = f"""You are an expert Spanish translator. Refine the following translation to address these common issues:
76
+
77
+ 1. Context and ambiguity resolution
78
+ 2. Cultural nuances and regional variations for {region}
79
+ 3. Tone and formality ({formality})
80
+ 4. Grammatical correctness
81
+ 5. Idiomatic expressions
82
+
83
+ Original English: {original_text}
84
+ Initial Spanish translation: {translation}
85
+ Region preference: {region}
86
+
87
+ Provide only the refined Spanish translation, nothing else."""
88
+ else:
89
+ refine_prompt = f"""You are an expert English translator. Refine the following translation to address these common issues:
90
+
91
+ 1. Context and ambiguity resolution
92
+ 2. Cultural nuances and natural English expressions
93
+ 3. Tone and formality ({formality})
94
+ 4. Grammatical correctness
95
+ 5. Idiomatic expressions
96
+
97
+ Original Spanish: {original_text}
98
+ Initial English translation: {translation}
99
+ Formality: {formality}
100
+
101
+ Provide only the refined English translation, nothing else."""
102
+
103
+ try:
104
+ response = cerebras_client.chat_completion(
105
+ messages=[{"role": "user", "content": refine_prompt}],
106
+ max_tokens=512,
107
+ temperature=0.3
108
+ )
109
+ return response.choices[0].message.content.strip()
110
+ except Exception as e:
111
+ return f"Refinement error: {str(e)}"
112
+
113
+ def complete_translation(text, direction, region, formality):
114
+ if not text.strip():
115
+ return "", ""
116
+
117
+ try:
118
+ # Step 1: Initial translation
119
+ if direction == "English to Spanish":
120
+ initial_translation = translate_en_to_es(text)
121
+ refined_translation = refine_with_llama(text, initial_translation, "en_to_es", region, formality)
122
+ else: # Spanish to English
123
+ initial_translation = translate_es_to_en(text)
124
+ refined_translation = refine_with_llama(text, initial_translation, "es_to_en", region, formality)
125
+
126
+ return initial_translation, refined_translation
127
+ except Exception as e:
128
+ return f"Error: {str(e)}", ""
129
+
130
+ # Create Gradio interface
131
+ with gr.Blocks(title="Bidirectional English-Spanish Translator") as demo:
132
+ gr.Markdown("# Bidirectional English-Spanish Translator")
133
+ gr.Markdown("Powered by NLLB-200 + Llama 4 via Cerebras for context-aware, culturally nuanced translations")
134
+
135
+ with gr.Row():
136
+ with gr.Column(scale=2):
137
+ input_text = gr.Textbox(
138
+ label="Text to Translate",
139
+ placeholder="Enter text in English or Spanish...",
140
+ lines=6
141
+ )
142
+
143
+ with gr.Row():
144
+ direction = gr.Dropdown(
145
+ choices=["English to Spanish", "Spanish to English"],
146
+ value="English to Spanish",
147
+ label="Translation Direction"
148
+ )
149
+
150
+ with gr.Row():
151
+ region = gr.Dropdown(
152
+ choices=["general", "Mexico", "Spain", "Argentina", "Colombia", "Peru"],
153
+ value="general",
154
+ label="Spanish Variant (for ES translations)"
155
+ )
156
+ formality = gr.Dropdown(
157
+ choices=["neutral", "formal", "informal"],
158
+ value="neutral",
159
+ label="Formality Level"
160
+ )
161
+
162
+ translate_btn = gr.Button("Translate", variant="primary", size="lg")
163
+
164
+ with gr.Column(scale=2):
165
+ with gr.Row():
166
+ initial_output = gr.Textbox(
167
+ label="Initial Translation (NLLB-200)",
168
+ lines=3,
169
+ interactive=False
170
+ )
171
+ refined_output = gr.Textbox(
172
+ label="Refined Translation (Llama 4)",
173
+ lines=3,
174
+ interactive=False
175
+ )
176
+
177
+ # Connect function
178
+ translate_btn.click(
179
+ fn=complete_translation,
180
+ inputs=[input_text, direction, region, formality],
181
+ outputs=[initial_output, refined_output]
182
+ )
183
+
184
+ input_text.submit(
185
+ fn=complete_translation,
186
+ inputs=[input_text, direction, region, formality],
187
+ outputs=[initial_output, refined_output]
188
+ )
189
+
190
+ if __name__ == "__main__":
191
+ demo.launch()