# Source: Hugging Face Space app.py — commit 3f2ad6a (verified), author AshwinSankar
import os
import torch
import spaces
import psycopg2
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gc
# Constants
# Generation limits: MAX_MAX_NEW_TOKENS caps the "max new tokens" slider,
# MAX_INPUT_TOKEN_LENGTH caps (and trims) the prompt length, and
# DEFAULT_MAX_NEW_TOKENS seeds the slider's initial value.
MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
# Hugging Face access token for gated model downloads ("" when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Language lists
# Target languages offered in the dropdowns (22 scheduled Indic languages
# plus English and Sanskrit).
INDIC_LANGUAGES = [
    "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati",
    "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili",
    "Santali", "Kashmiri", "Nepali", "Sindhi", "Konkani", "Dogri",
    "Manipuri", "Bodo", "English", "Sanskrit"
]
# Both model tabs currently share the same language list.
SARVAM_LANGUAGES = INDIC_LANGUAGES
# Model configurations with optimizations
# bfloat16 on GPU roughly halves weight memory; fall back to float32 on CPU.
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
DEVICE_MAP = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): both models are loaded eagerly at import time with
# trust_remote_code=True, so repository code from these model repos runs here.
indictrans_model = AutoModelForCausalLM.from_pretrained(
    "ai4bharat/IndicTrans3-beta",
    torch_dtype=TORCH_DTYPE,
    device_map=DEVICE_MAP,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)
sarvam_model = AutoModelForCausalLM.from_pretrained(
    "sarvamai/sarvam-translate",
    torch_dtype=TORCH_DTYPE,
    device_map=DEVICE_MAP,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)
# A single tokenizer is shared by BOTH models — assumes sarvam-translate is
# compatible with the IndicTrans3 tokenizer/chat template; TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/IndicTrans3-beta",
    trust_remote_code=True
)
def format_message_for_translation(message, target_lang):
    """Build the single-turn instruction prompt sent to the translation model."""
    prompt = f"Translate the following text to {target_lang}: {message}"
    return prompt
def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate and persist one user-feedback record to the `feedback` table.

    Shows a gr.Warning and returns None when any input is missing or the chat
    has no completed translation yet; shows gr.Info on success and gr.Error on
    a database failure. Connection parameters come from the DB_* env vars.
    """
    conn = None
    cursor = None
    try:
        # Validation: rating, feedback text and a completed (user, translation)
        # pair must all be present before anything is written.
        if not rating:
            gr.Warning("Please select a rating before submitting feedback.", duration=5)
            return None
        if not feedback_text or feedback_text.strip() == "":
            gr.Warning("Please provide some feedback before submitting.", duration=5)
            return None
        if not chat_history:
            gr.Warning("Please provide the input text before submitting feedback.", duration=5)
            return None
        if len(chat_history[0]) < 2:
            gr.Warning("Please translate the input text before submitting feedback.", duration=5)
            return None
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        cursor = conn.cursor()
        insert_query = """
        INSERT INTO feedback
        (tgt_lang, rating, feedback_txt, chat_history, model_type)
        VALUES (%s, %s, %s, %s, %s)
        """
        # chat_history is passed through unchanged; psycopg2 adapts Python
        # lists to SQL arrays — presumably the column is an array/jsonb type;
        # verify against the schema.
        cursor.execute(insert_query, (tgt_lang, int(rating), feedback_text, chat_history, model_type))
        conn.commit()
        gr.Info("Thank you for your feedback! ๐Ÿ™", duration=5)
    except Exception as e:
        print(f"Database error: {e}")
        gr.Error("An error occurred while storing feedback. Please try again later.", duration=5)
    finally:
        # Fix: the original leaked the cursor/connection whenever execute() or
        # commit() raised; always release them.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def store_output(tgt_lang, input_text, output_text, model_type):
    """Best-effort logging of a completed translation to the `translation` table.

    Failures are printed and swallowed on purpose so logging can never break
    the user-facing translation flow. Connection parameters come from the
    DB_* environment variables.
    """
    conn = None
    cursor = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        cursor = conn.cursor()
        insert_query = """
        INSERT INTO translation
        (input_txt, output_txt, tgt_lang, model_type)
        VALUES (%s, %s, %s, %s)
        """
        cursor.execute(insert_query, (input_text, output_text, tgt_lang, model_type))
        conn.commit()
    except Exception as e:
        print(f"Database error: {e}")
    finally:
        # Fix: the original leaked the cursor/connection whenever execute() or
        # commit() raised; always release them.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans"
) -> Iterator[str]:
    """Stream a translation of `message` into `target_language`.

    Yields progressively longer partial translations as tokens arrive from a
    background generation thread, then logs the final text via store_output().
    `chat_history` is accepted for interface compatibility with the Gradio
    wiring but is not used to build the prompt. Any error is yielded as a
    user-visible string rather than raised.
    """
    if model_type == "indictrans":
        model = indictrans_model
    elif model_type == "sarvam":
        model = sarvam_model
    else:
        # Fix: previously an unrecognized model_type left `model` unbound and
        # the guard below raised NameError instead of reporting the error.
        model = None
    if model is None or tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return
    conversation = []
    translation_request = format_message_for_translation(message, target_language)
    conversation.append({"role": "user", "content": translation_request})
    try:
        input_ids = tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )
        # Keep only the newest MAX_INPUT_TOKEN_LENGTH tokens of an over-long prompt.
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)
        # Generation runs on a worker thread; the streamer hands tokens back here.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # Enable KV cache
        }
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
        # Clean up GPU/heap memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        # Best-effort DB logging of the completed translation.
        store_output(target_language, message, "".join(outputs), model_type)
    except Exception as e:
        yield f"Translation error: {str(e)}"
# Enhanced CSS with beautiful styling
# Dark theme for the whole app: Inter font, #1a1a1a background, indigo accents,
# plus pulse/spin keyframes for loading states. Injected via gr.Blocks(css=...).
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
* {
font-family: 'Inter', sans-serif;
box-sizing: border-box;
}
.gradio-container {
background: #1a1a1a !important;
color: #e0e0e0;
min-height: 100vh;
}
.main-container {
background: #2a2a2a;
border-radius: 12px;
padding: 1.5rem;
margin: 1rem;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.title-container {
text-align: center;
margin-bottom: 1.5rem;
padding: 1rem;
color: #a0a0ff;
}
.model-tab {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.75rem 1.5rem;
transition: all 0.2s ease;
}
.model-tab:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.language-dropdown {
background: #333333;
border: 1px solid #444444;
border-radius: 8px;
padding: 0.5rem;
font-size: 14px;
color: #e0e0e0;
transition: all 0.2s ease;
}
.language-dropdown:focus {
border-color: #6666ff;
box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}
.chat-container {
background: #222222;
border-radius: 8px;
padding: 1rem;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
margin: 1rem 0;
}
.message-input {
background: #333333;
border: 1px solid #444444;
border-radius: 8px;
padding: 0.75rem;
font-size: 14px;
color: #e0e0e0;
transition: all 0.2s ease;
}
.message-input:focus {
border-color: #6666ff;
box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}
.translate-btn {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.75rem 1.5rem;
font-size: 14px;
cursor: pointer;
transition: all 0.2s ease;
}
.translate-btn:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.examples-container {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.feedback-section {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
border: none;
}
.advanced-options {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.slider-container .gr-slider {
background: #444444;
color: #e0e0e0;
}
.rating-container {
display: flex;
gap: 0.5rem;
justify-content: center;
margin: 0.5rem 0;
}
.feedback-btn {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.5rem 1rem;
cursor: pointer;
transition: all 0.2s ease;
}
.feedback-btn:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.stats-card {
background: #333333;
border-radius: 8px;
padding: 0.75rem;
text-align: center;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
margin: 0.5rem;
color: #e0e0e0;
}
.model-info {
background: #3333a0;
color: #ffffff;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.animate-pulse {
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes pulse {
0%, 100% {
opacity: 1;
}
50% {
opacity: 0.5;
}
}
.loading-spinner {
border: 3px solid #444444;
border-top: 3px solid #6666ff;
border-radius: 50%;
width: 30px;
height: 30px;
animation: spin 1.5s linear infinite;
margin: 0 auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
"""
# Model descriptions
# HTML blurbs shown at the top of each model tab.
# NOTE(review): sequences like "๐ŸŒŸ" throughout these strings look like
# mojibake of emoji (UTF-8 bytes decoded with a wrong codepage) — confirm the
# intended characters before normalizing; left byte-identical here.
INDICTRANS_DESCRIPTION = """
<div class="model-info">
<h3>๐ŸŒŸ IndicTrans3-Beta</h3>
<p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
<ul>
<li>โœ… Supports <strong>22 Indic languages</strong></li>
<li>โœ… Document-level machine translation</li>
<li>โœ… Optimized for real-world applications</li>
<li>โœ… Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
SARVAM_DESCRIPTION = """
<div class="model-info">
<h3>๐Ÿš€ Sarvam Translate</h3>
<p><strong>Advanced multilingual translation model</strong></p>
<ul>
<li>โœ… Supports <strong>22 Indic languages</strong></li>
<li>โœ… High-quality translations</li>
<li>โœ… Document-level machine translation</li>
<li>โœ… Optimized for real-world applications</li>
<li>โœ… Optimized for production use</li>
<li>โœ… Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
def create_chatbot_interface(model_type, languages, description):
    """Build one model tab's UI: language picker, chat, examples, feedback, settings.

    Args:
        model_type: "indictrans" or "sarvam" — selects the example texts only;
            the actual model routing happens in the event handlers outside.
        languages: dropdown choices; the first entry is the default.
        description: HTML blurb rendered at the top of the tab.

    Returns a 12-tuple of components, in the exact order the module-level
    event wiring unpacks them: (chatbot, msg, submit_btn, target_language,
    rating, feedback_text, feedback_submit, max_new_tokens, temperature,
    top_p, top_k, repetition_penalty).
    """
    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)
        target_language = gr.Dropdown(
            languages,
            value=languages[0],
            label="๐ŸŒ Select Target Language",
            elem_classes="language-dropdown",
        )
        # Tuple-style chatbot (pairs of [user, bot]); the user()/bot() helpers
        # below rely on this list-of-pairs history format.
        chatbot = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False
        )
        with gr.Row():
            msg = gr.Textbox(
                placeholder="โœ๏ธ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            submit_btn = gr.Button(
                "๐Ÿ”„ Translate",
                scale=1,
                elem_classes="translate-btn"
            )
        # Examples section
        # Longer, culture-themed samples for IndicTrans; short phrases otherwise.
        if model_type == "indictrans":
            examples_data = [
                "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
                "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
                "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
                "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
                "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil."
            ]
        else:
            examples_data = [
                "Hello, how are you today?",
                "I love learning new languages and cultures.",
                "Technology is transforming the way we communicate.",
                "The weather is beautiful today.",
                "Thank you for your help and support."
            ]
        with gr.Accordion("๐Ÿ“š Example Texts", open=False, elem_classes="examples-container"):
            gr.Examples(
                examples=examples_data,
                inputs=msg,
                label="Click on any example to try:"
            )
        # Feedback section
        with gr.Accordion("๐Ÿ’ญ Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### ๐Ÿ“ Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")
            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="๐Ÿ† Translation Quality Rating",
                    value=None
                )
            feedback_text = gr.Textbox(
                placeholder="๐Ÿ’ฌ Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="๐Ÿ“ Your Feedback",
                lines=3,
            )
            feedback_submit = gr.Button(
                "๐Ÿ“ค Submit Feedback",
                elem_classes="feedback-btn"
            )
        # Advanced options
        # Sampling parameters fed straight into translate_message().
        with gr.Accordion("โš™๏ธ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### ๐Ÿ”ง Fine-tune Translation Parameters")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    label="๐Ÿ“ Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes="slider-container"
                )
                temperature = gr.Slider(
                    label="๐ŸŒก๏ธ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes="slider-container"
                )
            with gr.Row():
                top_p = gr.Slider(
                    label="๐ŸŽฏ Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes="slider-container"
                )
                top_k = gr.Slider(
                    label="๐Ÿ” Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes="slider-container"
                )
            repetition_penalty = gr.Slider(
                label="๐Ÿ”„ Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes="slider-container"
            )
    return (chatbot, msg, submit_btn, target_language, rating, feedback_text,
            feedback_submit, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
def user(user_message, history, target_lang):
    """Append the pending user turn to the chat and clear the textbox.

    Returns ("", new_history) for the (msg, chatbot) outputs. target_lang is
    accepted only to mirror the Gradio input wiring; it is not used here.
    """
    updated_history = history + [[user_message, None]]
    return "", updated_history
def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Stream the translation of the newest user turn into the chat history.

    Mutates the last [user, bot] pair in place, yielding the whole history
    after each partial chunk so Gradio can live-update the chatbot.
    """
    pending_message = history[-1][0]
    history[-1][1] = ""
    stream = translate_message(
        pending_message,
        history[:-1],
        target_lang,
        max_tokens,
        temp,
        top_p_val,
        top_k_val,
        rep_penalty,
        model_type,
    )
    for partial in stream:
        history[-1][1] = partial
        yield history
# Main Gradio interface
# Two tabs (one per model) built from the same create_chatbot_interface()
# factory, plus per-tab event wiring and a shared header/footer.
with gr.Blocks(css=css, title="๐ŸŒ Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div class="title-container">
        <h1>๐ŸŒ Advanced Multilingual Translation Hub</h1>
        <p style="font-size: 18px; margin-top: 10px;">
        Experience state-of-the-art translation with multiple AI models
        </p>
        </div>
        """,
        elem_classes="title-container"
    )
    # Statistics cards
    with gr.Row():
        gr.Markdown(
            '<div class="stats-card"><h3>๐ŸŽฏ</h3><p><strong>22+</strong><br>Languages</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐Ÿš€</h3><p><strong>2</strong><br>AI Models</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>โšก</h3><p><strong>Optimized</strong><br>Performance</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐Ÿ”’</h3><p><strong>Secure</strong><br>Processing</p></div>',
            elem_classes="stats-card"
        )
    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("๐Ÿ‡ฎ๐Ÿ‡ณ IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface("indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION)
        with gr.TabItem("๐ŸŒ Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface("sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION)
    # Event handlers for IndicTrans
    # Unpacking order must match create_chatbot_interface()'s return tuple.
    (indictrans_chatbot, indictrans_msg, indictrans_submit, indictrans_lang,
     indictrans_rating, indictrans_feedback, indictrans_feedback_submit,
     indictrans_max_tokens, indictrans_temp, indictrans_top_p,
     indictrans_top_k, indictrans_rep_penalty) = indictrans_components
    # NOTE(review): wrapping the generator `bot` in a lambda hides that it is a
    # generator function from inspect.isgeneratorfunction — confirm this Gradio
    # version still streams the yielded updates rather than showing the raw
    # generator object.
    indictrans_msg.submit(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens,
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )
    indictrans_submit.click(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens,
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )
    # Positional inputs map onto store_feedback(rating, feedback_text,
    # chat_history, tgt_lang); the lambda appends the model_type tag.
    indictrans_feedback_submit.click(
        lambda *args: store_feedback(*args, "indictrans"),
        inputs=[indictrans_rating, indictrans_feedback, indictrans_chatbot, indictrans_lang],
    )
    # Event handlers for Sarvam
    (sarvam_chatbot, sarvam_msg, sarvam_submit, sarvam_lang,
     sarvam_rating, sarvam_feedback, sarvam_feedback_submit,
     sarvam_max_tokens, sarvam_temp, sarvam_top_p,
     sarvam_top_k, sarvam_rep_penalty) = sarvam_components
    sarvam_msg.submit(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens,
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )
    sarvam_submit.click(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens,
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )
    sarvam_feedback_submit.click(
        lambda *args: store_feedback(*args, "sarvam"),
        inputs=[sarvam_rating, sarvam_feedback, sarvam_chatbot, sarvam_lang],
    )
    # Footer
    gr.Markdown(
        """
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
        <p>๐Ÿš€ <strong>Powered by AI4Bharat & Sarvam AI</strong> |
        Built with โค๏ธ using Gradio |
        ๐Ÿ”ง <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
        </div>
        """
    )
# Script entry point: bind on all interfaces, standard HF Spaces port.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )