import os import torch import spaces import psycopg2 import gradio as gr from threading import Thread from collections.abc import Iterator from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer import gc # Constants MAX_MAX_NEW_TOKENS = 4096 MAX_INPUT_TOKEN_LENGTH = 4096 DEFAULT_MAX_NEW_TOKENS = 2048 HF_TOKEN = os.environ.get("HF_TOKEN", "") # Language lists INDIC_LANGUAGES = [ "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati", "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili", "Santali", "Kashmiri", "Nepali", "Sindhi", "Konkani", "Dogri", "Manipuri", "Bodo", "English", "Sanskrit" ] SARVAM_LANGUAGES = INDIC_LANGUAGES # Model configurations with optimizations TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 DEVICE_MAP = "auto" if torch.cuda.is_available() else None indictrans_model = AutoModelForCausalLM.from_pretrained( "ai4bharat/IndicTrans3-beta", torch_dtype=TORCH_DTYPE, device_map=DEVICE_MAP, token=HF_TOKEN, low_cpu_mem_usage=True, trust_remote_code=True ) sarvam_model = AutoModelForCausalLM.from_pretrained( "sarvamai/sarvam-translate", torch_dtype=TORCH_DTYPE, device_map=DEVICE_MAP, token=HF_TOKEN, low_cpu_mem_usage=True, trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained( "ai4bharat/IndicTrans3-beta", trust_remote_code=True ) def format_message_for_translation(message, target_lang): return f"Translate the following text to {target_lang}: {message}" def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type): try: if not rating: gr.Warning("Please select a rating before submitting feedback.", duration=5) return None if not feedback_text or feedback_text.strip() == "": gr.Warning("Please provide some feedback before submitting.", duration=5) return None if not chat_history: gr.Warning("Please provide the input text before submitting feedback.", duration=5) return None if len(chat_history[0]) < 2: gr.Warning("Please translate the input text before submitting feedback.", duration=5) return None conn = psycopg2.connect( host=os.getenv("DB_HOST"), database=os.getenv("DB_NAME"), user=os.getenv("DB_USER"), password=os.getenv("DB_PASSWORD"), port=os.getenv("DB_PORT"), ) cursor = conn.cursor() insert_query = """ INSERT INTO feedback (tgt_lang, rating, feedback_txt, chat_history, model_type) VALUES (%s, %s, %s, %s, %s) """ cursor.execute(insert_query, (tgt_lang, int(rating), feedback_text, chat_history, model_type)) conn.commit() cursor.close() conn.close() gr.Info("Thank you for your feedback! 🙏", duration=5) except Exception as e: print(f"Database error: {e}") gr.Error("An error occurred while storing feedback. Please try again later.", duration=5) def store_output(tgt_lang, input_text, output_text, model_type): try: conn = psycopg2.connect( host=os.getenv("DB_HOST"), database=os.getenv("DB_NAME"), user=os.getenv("DB_USER"), password=os.getenv("DB_PASSWORD"), port=os.getenv("DB_PORT"), ) cursor = conn.cursor() insert_query = """ INSERT INTO translation (input_txt, output_txt, tgt_lang, model_type) VALUES (%s, %s, %s, %s) """ cursor.execute(insert_query, (input_text, output_text, tgt_lang, model_type)) conn.commit() cursor.close() conn.close() except Exception as e: print(f"Database error: {e}") @spaces.GPU def translate_message( message: str, chat_history: list[dict], target_language: str = "Hindi", max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2, model_type: str = "indictrans" ) -> Iterator[str]: if model_type == "indictrans": model = indictrans_model elif model_type == "sarvam": model = sarvam_model if model is None or tokenizer is None: yield "Error: Model failed to load. Please try again." return conversation = [] translation_request = format_message_for_translation(message, target_language) conversation.append({"role": "user", "content": translation_request}) try: input_ids = tokenizer.apply_chat_template( conversation, return_tensors="pt", add_generation_prompt=True ) if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer( tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True ) generate_kwargs = { "input_ids": input_ids, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "top_p": top_p, "top_k": top_k, "temperature": temperature, "num_beams": 1, "repetition_penalty": repetition_penalty, "use_cache": True, # Enable KV cache "pad_token_id": tokenizer.eos_token_id, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) # Clean up if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() store_output(target_language, message, "".join(outputs), model_type) except Exception as e: yield f"Translation error: {str(e)}" # Enhanced CSS with beautiful styling css = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); * { font-family: 'Inter', sans-serif; box-sizing: border-box; } .gradio-container { background: #1a1a1a !important; color: #e0e0e0; min-height: 100vh; } .main-container { background: #2a2a2a; border-radius: 12px; padding: 1.5rem; margin: 1rem; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); } .title-container { text-align: center; margin-bottom: 1.5rem; padding: 1rem; color: #a0a0ff; } .model-tab { background: #3333a0; border: none; border-radius: 8px; color: #ffffff; font-weight: 500; padding: 0.75rem 1.5rem; transition: all 0.2s ease; } .model-tab:hover { background: #4444b0; transform: translateY(-1px); box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4); } .language-dropdown { background: #333333; border: 1px solid #444444; border-radius: 8px; padding: 0.5rem; font-size: 14px; color: #e0e0e0; transition: all 0.2s ease; } .language-dropdown:focus { border-color: #6666ff; box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2); } .chat-container { background: #222222; border-radius: 8px; padding: 1rem; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); margin: 1rem 0; } .message-input { background: #333333; border: 1px solid #444444; border-radius: 8px; padding: 0.75rem; font-size: 14px; color: #e0e0e0; transition: all 0.2s ease; } .message-input:focus { border-color: #6666ff; box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2); } .translate-btn { background: #3333a0; border: none; border-radius: 8px; color: #ffffff; font-weight: 500; padding: 0.75rem 1.5rem; font-size: 14px; cursor: pointer; transition: all 0.2s ease; } .translate-btn:hover { background: #4444b0; transform: translateY(-1px); box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4); } .examples-container { background: #2a2a2a; border-radius: 8px; padding: 1rem; margin: 1rem 0; } .feedback-section { background: #2a2a2a; border-radius: 8px; padding: 1rem; margin: 1rem 0; border: none; } .advanced-options { background: #2a2a2a; border-radius: 8px; padding: 1rem; margin: 1rem 0; } .slider-container .gr-slider { background: #444444; color: #e0e0e0; } .rating-container { display: flex; gap: 0.5rem; justify-content: center; margin: 0.5rem 0; } .feedback-btn { background: #3333a0; border: none; border-radius: 8px; color: #ffffff; font-weight: 500; padding: 0.5rem 1rem; cursor: pointer; transition: all 0.2s ease; } .feedback-btn:hover { background: #4444b0; transform: translateY(-1px); box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4); } .stats-card { background: #333333; border-radius: 8px; padding: 0.75rem; text-align: center; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); margin: 0.5rem; color: #e0e0e0; } .model-info { background: #3333a0; color: #ffffff; border-radius: 8px; padding: 1rem; margin: 1rem 0; } .animate-pulse { animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite; } @keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } } .loading-spinner { border: 3px solid #444444; border-top: 3px solid #6666ff; border-radius: 50%; width: 30px; height: 30px; animation: spin 1.5s linear infinite; margin: 0 auto; } @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } } """ # Model descriptions INDICTRANS_DESCRIPTION = """
Latest SOTA translation model from AI4Bharat
Advanced multilingual translation model
Experience state-of-the-art translation with multiple AI models
22+
Languages
2
AI Models
Optimized
Performance
Secure
Processing
🚀 Powered by AI4Bharat & Sarvam AI | Built with ❤️ using Gradio | 🔧 Optimized with KV Caching & Advanced Memory Management