Spaces:
Sleeping
Sleeping
import gradio as gr | |
import spaces | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
from huggingface_hub import InferenceClient | |
import os | |
import fitz # PyMuPDF for PDF processing | |
from PIL import Image | |
import pytesseract | |
# Llama 4 Scout, reached through huggingface_hub's InferenceClient with the
# "cerebras" provider; used to polish the NLLB draft translations.
_LLAMA_MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
cerebras_client = InferenceClient(
    _LLAMA_MODEL_ID,
    provider="cerebras",
    token=os.getenv("HF_TOKEN"),
)

# Lazily-initialised NLLB tokenizer/model pairs, one pair per translation
# direction. They stay None until the first translation in that direction.
en_es_tokenizer = en_es_model = None
es_en_tokenizer = es_en_model = None
# NOTE(review): `spaces` is imported but no function is decorated with
# @spaces.GPU. If this Space runs on ZeroGPU hardware, the CUDA calls below
# likely need that decorator (or module-level model loading). Confirm.

# Both translation directions use this single multilingual NLLB-200 checkpoint.
_NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"


def _load_nllb_model():
    """Return the fp16 NLLB-200 model on CUDA, reusing a copy that is already loaded.

    Both directions use the same checkpoint, so loading it once per direction
    would keep two identical copies in GPU memory. The target language is
    chosen per call via ``forced_bos_token_id``, so sharing the weights is safe.
    """
    for loaded in (en_es_model, es_en_model):
        if loaded is not None:
            return loaded
    return AutoModelForSeq2SeqLM.from_pretrained(
        _NLLB_MODEL_NAME,
        torch_dtype=torch.float16,
    ).cuda()


def _nllb_generate(text, tokenizer, model, tgt_lang):
    """Translate ``text`` with ``model``, forcing ``tgt_lang`` as the output language.

    The input is truncated to 512 tokens, the output is capped at 512 tokens,
    and decoding uses 5-beam search. Returns the decoded string without
    special tokens.
    """
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # NLLB picks the output language from the first generated token.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def translate_en_to_es(text):
    """Translate English ``text`` to Spanish with NLLB-200, loading it on first use."""
    global en_es_tokenizer, en_es_model
    if en_es_tokenizer is None or en_es_model is None:
        # The tokenizer is per-direction because src_lang is fixed at load time.
        en_es_tokenizer = AutoTokenizer.from_pretrained(
            _NLLB_MODEL_NAME, src_lang="eng_Latn", tgt_lang="spa_Latn"
        )
        en_es_model = _load_nllb_model()
    return _nllb_generate(text, en_es_tokenizer, en_es_model, "spa_Latn")


def translate_es_to_en(text):
    """Translate Spanish ``text`` to English with NLLB-200, loading it on first use."""
    global es_en_tokenizer, es_en_model
    if es_en_tokenizer is None or es_en_model is None:
        # The tokenizer is per-direction because src_lang is fixed at load time.
        es_en_tokenizer = AutoTokenizer.from_pretrained(
            _NLLB_MODEL_NAME, src_lang="spa_Latn", tgt_lang="eng_Latn"
        )
        es_en_model = _load_nllb_model()
    return _nllb_generate(text, es_en_tokenizer, es_en_model, "eng_Latn")
def extract_text_from_pdf(file_path):
    """Extract the embedded text of every page of a PDF, in page order.

    Args:
        file_path: Path of the PDF to read.

    Returns:
        The concatenated text of all pages. On any failure, returns a message
        starting with "Error extracting text from PDF: " instead of raising.
        A scanned PDF with no text layer yields an empty string, because no
        OCR is applied here.
    """
    try:
        # The context manager closes the document even if a page fails to parse.
        with fitz.open(file_path) as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        return f"Error extracting text from PDF: {str(e)}"
def extract_text_from_image(file_path):
    """Run Tesseract OCR on an image file and return the recognised text.

    Args:
        file_path: Path of the image to read.

    Returns:
        The OCR output string. On any failure, returns a message starting with
        "Error extracting text from image: " instead of raising.
    """
    try:
        # Image.open keeps the file handle open; the context manager releases it.
        with Image.open(file_path) as image:
            return pytesseract.image_to_string(image)
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
# File extensions that are routed to OCR (matched case-insensitively).
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp')


def process_uploaded_file(file):
    """Extract text from an uploaded PDF or image, dispatching on the file extension.

    Args:
        file: The upload value from ``gr.File``. This may be a filepath string
            or PathLike (Gradio 4's default), an object exposing ``.name`` (the
            older tempfile-wrapper form), or None when nothing was uploaded.

    Returns:
        The extracted text, or one of the status messages "No file uploaded" /
        "Unsupported file format. Please upload PDF or image files.". The
        extract_* helpers may also return their own error messages.
    """
    if file is None:
        return "No file uploaded"
    # A plain path string has no `.name` attribute, so handle path-like values directly.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    file_extension = os.path.splitext(file_path)[1].lower()
    if file_extension == '.pdf':
        return extract_text_from_pdf(file_path)
    if file_extension in _IMAGE_EXTENSIONS:
        return extract_text_from_image(file_path)
    return "Unsupported file format. Please upload PDF or image files."
def _build_refine_prompt(original_text, translation, direction, region, formality):
    """Return the Llama prompt asking it to refine ``translation`` of ``original_text``.

    A direction of "en_to_es" yields a region-specific Spanish prompt. Any other
    value yields the Spanish-to-English prompt, which does not use ``region``.
    """
    if direction == "en_to_es":
        return f"""You are an expert Spanish translator specializing in {region} Spanish. Refine the following translation and explain your changes:
Original English: {original_text}
Initial Spanish translation: {translation}
Region: {region}
Formality level: {formality}
Requirements:
1. Use {region} Spanish vocabulary and expressions
2. Adjust for {formality} formality level
3. Fix any contextual errors or awkward phrasing
4. Preserve idiomatic expressions appropriately for {region} Spanish
Respond in this format:
TRANSLATION: [your refined translation]
EXPLANATION: [Brief explanation of changes made and why this version fits {formality} {region} Spanish better]"""
    return f"""You are an expert English translator. Refine the following translation and explain your changes:
Original Spanish: {original_text}
Initial English translation: {translation}
Formality level: {formality}
Requirements:
1. Use natural English expressions
2. Adjust for {formality} formality level
3. Fix any contextual errors or awkward phrasing
4. Preserve meaning while making it sound natural
Respond in this format:
TRANSLATION: [your refined translation]
EXPLANATION: [Brief explanation of changes made and why this version fits {formality} English better]"""


def _parse_refinement(content):
    """Split a "TRANSLATION: ... EXPLANATION: ..." reply into (translation, explanation).

    The reply counts as well-formed only when "TRANSLATION:" appears and is
    followed by "EXPLANATION:". In that case everything after the first
    following "EXPLANATION:" is the explanation, so a repeated marker cannot
    truncate it. Otherwise the whole reply is returned together with a fallback
    explanation.
    """
    _, found_translation, after_translation = content.partition("TRANSLATION:")
    translation_part, found_explanation, explanation_part = after_translation.partition("EXPLANATION:")
    if found_translation and found_explanation:
        return translation_part.strip(), explanation_part.strip()
    return content, "Explanation not available in expected format"


def refine_with_llama(original_text, translation, direction, region="Mexico", formality="neutral"):
    """Ask Llama 4 (via ``cerebras_client``) to polish a machine translation.

    Args:
        original_text: The source text.
        translation: The initial NLLB translation of ``original_text``.
        direction: "en_to_es" for English-to-Spanish. Any other value is
            treated as Spanish-to-English.
        region: The Spanish variant to target. Only used for "en_to_es".
        formality: The formality level to request, e.g. "neutral".

    Returns:
        A tuple ``(refined_translation, explanation)``. On any API error, it is
        ``("Refinement error: <message>", "")`` instead of raising.
    """
    refine_prompt = _build_refine_prompt(original_text, translation, direction, region, formality)
    try:
        response = cerebras_client.chat_completion(
            messages=[{"role": "user", "content": refine_prompt}],
            max_tokens=512,
            temperature=0.3,
        )
        # The message content can be None; treat that as an empty reply rather than crashing.
        content = (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"Refinement error: {str(e)}", ""
    return _parse_refinement(content)
def complete_translation(text, direction, region, formality):
    """Translate ``text`` with NLLB-200, then refine the result with Llama 4.

    Args:
        text: The text to translate. None or whitespace-only input returns
            three empty strings.
        direction: "English to Spanish". Any other value is treated as
            "Spanish to English".
        region: The Spanish variant forwarded to the refinement step.
        formality: The formality level forwarded to the refinement step.

    Returns:
        A tuple ``(initial_translation, refined_translation, explanation)``.
        On an unexpected error, it is ``("Error: <message>", "", "")``.
    """
    # Guard against None (e.g. an API caller omitting the text) as well as blank strings.
    if not text or not text.strip():
        return "", "", ""
    try:
        # Step 1: initial NLLB translation, which also picks the refinement direction.
        if direction == "English to Spanish":
            initial_translation = translate_en_to_es(text)
            refine_direction = "en_to_es"
        else:  # Spanish to English
            initial_translation = translate_es_to_en(text)
            refine_direction = "es_to_en"
        # Step 2: Llama refinement and explanation.
        refined_translation, explanation = refine_with_llama(
            text, initial_translation, refine_direction, region, formality
        )
        return initial_translation, refined_translation, explanation
    except Exception as e:
        return f"Error: {str(e)}", "", ""
# Leading text of the status/error messages that process_uploaded_file and its
# extract_* helpers return in place of document text.
_EXTRACTION_FAILURE_PREFIXES = (
    "No file uploaded",
    "Error extracting text from PDF: ",
    "Error extracting text from image: ",
    "Unsupported file format.",
)


def translate_from_file(file, direction, region, formality):
    """Extract text from an uploaded PDF/image and run the full translation pipeline.

    Returns:
        A tuple ``(extracted_text, initial_translation, refined_translation,
        explanation)``. If extraction fails, the first element is the failure
        message and the other three are empty strings.
    """
    extracted = process_uploaded_file(file)
    # Match the helpers' own messages by prefix. A plain substring test would
    # skip every document that merely mentions "Error", and it missed the
    # unsupported-format message, which then got sent for translation.
    if extracted.startswith(_EXTRACTION_FAILURE_PREFIXES):
        return extracted, "", "", ""
    initial_translation, refined_translation, explanation = complete_translation(
        extracted, direction, region, formality
    )
    return extracted, initial_translation, refined_translation, explanation
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

# Option lists shared by the text tab and the document tab.
_DIRECTION_CHOICES = ("English to Spanish", "Spanish to English")
_REGION_CHOICES = ("Mexico", "Spain", "Argentina", "Colombia", "Peru", "General")
_FORMALITY_CHOICES = ("informal", "neutral", "formal")


def _translation_options():
    """Create the direction / Spanish-variant / formality dropdowns in the open layout.

    Must be called inside an active Gradio layout context. The direction goes
    in its own row; the variant and formality share the next row. Returns
    ``(direction, region, formality)``.
    """
    with gr.Row():
        direction_dd = gr.Dropdown(
            choices=list(_DIRECTION_CHOICES),
            value="English to Spanish",
            label="Translation Direction",
        )
    with gr.Row():
        region_dd = gr.Dropdown(
            choices=list(_REGION_CHOICES),
            value="Mexico",
            label="Spanish Variant",
        )
        formality_dd = gr.Dropdown(
            choices=list(_FORMALITY_CHOICES),
            value="neutral",
            label="Formality Level",
        )
    return direction_dd, region_dd, formality_dd


def _readonly_box(label, lines):
    """Create a non-editable output textbox in the open layout."""
    return gr.Textbox(label=label, lines=lines, interactive=False)


with gr.Blocks(title="Document Translation with Regional Spanish") as demo:
    gr.Markdown("# Document Translation with Regional Spanish")
    gr.Markdown("Upload PDFs or images for OCR, or type text directly. Powered by NLLB-200 + Llama 4 with regional variants")

    with gr.Tabs():
        # Free-text translation tab.
        with gr.TabItem("Text Translation"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(
                        label="Text to Translate",
                        placeholder="Enter text in English or Spanish...",
                        lines=6,
                    )
                    direction, region, formality = _translation_options()
                    translate_btn = gr.Button("Translate", variant="primary", size="lg")
                with gr.Column(scale=2):
                    initial_output = _readonly_box("Initial Translation (NLLB-200)", 2)
                    refined_output = _readonly_box("Refined Translation (Llama 4)", 2)
                    explanation_output = _readonly_box("Explanation of Changes", 4)

        # PDF / image upload tab.
        with gr.TabItem("Document Translation"):
            with gr.Row():
                with gr.Column(scale=2):
                    file_input = gr.File(
                        label="Upload PDF or Image",
                        file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
                    )
                    doc_direction, doc_region, doc_formality = _translation_options()
                    translate_doc_btn = gr.Button("Extract & Translate", variant="primary", size="lg")
                with gr.Column(scale=2):
                    extracted_text = _readonly_box("Extracted Text", 3)
                    doc_initial = _readonly_box("Initial Translation (NLLB-200)", 3)
                    doc_refined = _readonly_box("Refined Translation (Llama 4)", 3)
                    doc_explanation = _readonly_box("Explanation of Changes", 3)

    # The button and pressing Enter in the textbox both run the same text pipeline.
    text_inputs = [input_text, direction, region, formality]
    text_outputs = [initial_output, refined_output, explanation_output]
    for register in (translate_btn.click, input_text.submit):
        register(fn=complete_translation, inputs=text_inputs, outputs=text_outputs)

    translate_doc_btn.click(
        fn=translate_from_file,
        inputs=[file_input, doc_direction, doc_region, doc_formality],
        outputs=[extracted_text, doc_initial, doc_refined, doc_explanation],
    )

if __name__ == "__main__":
    demo.launch()