Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM | |
import torch | |
import numpy as np | |
# Initialize models.
# Each model is loaded independently so that one failure (e.g. a failed
# download) does not leave every other task without its model. A model that
# fails to load is set to None; the handlers below wrap their model calls in
# try/except, so the affected tab reports an error while the others keep working.


def _load_model(description, loader):
    """Return ``loader()``, or print the error and return None if it raises.

    Args:
        description: Human-readable name used in the error message.
        loader: Zero-argument callable that loads and returns the model.
    """
    try:
        return loader()
    except Exception as e:
        print(f"Error loading {description}: {str(e)}")
        return None


# Text Generation with TinyLlama
generator_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# float16 halves memory on a GPU, but half-precision support on CPU is
# limited and slow, so fall back to float32 when no CUDA device is present.
generator_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = _load_model(
    "TinyLlama tokenizer",
    lambda: AutoTokenizer.from_pretrained(generator_name),
)
generator_model = _load_model(
    "TinyLlama model",
    lambda: AutoModelForCausalLM.from_pretrained(
        generator_name,
        torch_dtype=generator_dtype,
        device_map="auto",
    ),
)
# Text Summarization
summarizer = _load_model(
    "summarization model",
    lambda: pipeline("summarization", model="facebook/bart-large-cnn"),
)
# Sentiment Analysis
sentiment_analyzer = _load_model(
    "sentiment model",
    lambda: pipeline(
        "sentiment-analysis",
        model="distilbert-base-uncased-finetuned-sst-2-english",
    ),
)
# Question Answering
qa_model = _load_model(
    "question-answering model",
    lambda: pipeline("question-answering", model="distilbert-base-cased-distilled-squad"),
)
# Translation (English to multiple languages)
translator = _load_model(
    "translation model",
    lambda: pipeline("translation", model="Helsinki-NLP/opus-mt-en-ROMANCE"),
)
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a chat reply to *prompt* using TinyLlama.

    Args:
        prompt: The user's message.
        max_length: Maximum number of tokens to generate for the reply. The
            prompt does not count against this budget, so a long prompt
            cannot use up the whole allowance.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        The generated reply, "Please enter a prompt." for blank input, or a
        string starting with "Error in text generation:" if anything fails.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        # Use the chat template bundled with the tokenizer rather than
        # hand-written "<human>:/<assistant>:" tags, which do not match the
        # chat format this model expects.
        messages = [{"role": "user", "content": prompt}]
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(generator_model.device)
        outputs = generator_model.generate(
            **inputs,
            # max_new_tokens (unlike max_length) excludes the prompt tokens.
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the tokens generated after the prompt, so neither the
        # prompt text nor the chat markup leaks into the reply.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error in text generation: {str(e)}"
def summarize_text(text, max_length=130, min_length=30):
    """Summarize *text* with the BART summarization pipeline.

    Args:
        text: The text to summarize.
        max_length: Upper bound on the summary length, in tokens.
        min_length: Lower bound on the summary length, in tokens. It is
            clamped to ``max_length`` because the UI sliders allow the
            minimum to exceed the maximum.

    Returns:
        The summary, "Please enter text to summarize." for blank input, or a
        string starting with "Error in summarization:" if anything fails.
    """
    if not text or not text.strip():
        return "Please enter text to summarize."
    try:
        # Slider values may arrive as floats; generation expects ints.
        max_length = int(max_length)
        min_length = min(int(min_length), max_length)
        summary = summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            # Input longer than the model's maximum would otherwise make the
            # call fail; truncate it instead.
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:
        return f"Error in summarization: {str(e)}"
def analyze_sentiment(text):
    """Classify the sentiment of *text*.

    Returns:
        ``{"Sentiment": label, "Confidence": "NN.NN%"}`` for the top label,
        or ``{"error": message}`` for blank input or if the model call fails.
    """
    if not text or not text.strip():
        return {"error": "Please enter some text."}
    try:
        # Input longer than the model's maximum length would otherwise raise;
        # truncate it so long texts still get a result.
        result = sentiment_analyzer(text, truncation=True)
        top = result[0]
        return {
            "Sentiment": top['label'],
            "Confidence": f"{top['score']:.2%}",
        }
    except Exception as e:
        return {"error": str(e)}
def answer_question(context, question):
    """Extract the answer to *question* from the passage *context*.

    Returns:
        ``{"Answer": span, "Confidence": "NN.NN%"}`` on success, or
        ``{"error": message}`` if the model call or result lookup raises.
    """
    try:
        prediction = qa_model(question=question, context=context)
        answer = prediction['answer']
        confidence = f"{prediction['score']:.2%}"
    except Exception as exc:
        return {"error": str(exc)}
    return {"Answer": answer, "Confidence": confidence}
def translate_text(text, target_lang):
    """Translate English *text* into *target_lang*.

    Args:
        text: English source text.
        target_lang: Target language code, e.g. "es", "fr", "it", "pt", "ro".

    Returns:
        The translation, "Please enter text to translate." for blank input,
        or a string starting with "Error in translation:" if anything fails.
    """
    if not text or not text.strip():
        return "Please enter text to translate."
    try:
        # opus-mt-en-ROMANCE is one multi-target Marian model: it chooses the
        # output language from a leading ">>xx<<" token in the source text.
        # The translation pipeline does not apply src_lang/tgt_lang for Marian
        # tokenizers, so without this token the language selection was ignored.
        translation = translator(f">>{target_lang}<< {text}")
        return translation[0]['translation_text']
    except Exception as e:
        return f"Error in translation: {str(e)}"
# Create the Gradio interface
# Intro text shown at the top of the app. It is kept at column 0 so no line can
# be indented into a Markdown code block. The blank lines make the bullet list
# and the closing sentence render as separate elements instead of the sentence
# being folded into the last bullet.
_INTRO_MARKDOWN = """
# 🤖 Advanced NLP
## Multi-task Language Model Application

This application demonstrates various Natural Language Processing capabilities:

- Text Generation (TinyLlama)
- Text Summarization (BART)
- Sentiment Analysis (DistilBERT)
- Question Answering (DistilBERT)
- Multi-language Translation (Helsinki-NLP OPUS-MT, English to Romance languages)

Try out different tasks using the options below!
"""

with gr.Blocks(title="Advanced NLP") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                gen_input = gr.Textbox(label="Enter your prompt", lines=3)
                gen_length = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Maximum Length")
                gen_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
                gen_button = gr.Button("Generate")
            with gr.Column():
                gen_output = gr.Textbox(label="Generated Text", lines=5)

    with gr.Tab("Text Summarization"):
        with gr.Row():
            with gr.Column():
                sum_input = gr.Textbox(label="Enter text to summarize", lines=8)
                sum_max_length = gr.Slider(minimum=50, maximum=200, value=130, step=10, label="Maximum Summary Length")
                sum_min_length = gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Summary Length")
                sum_button = gr.Button("Summarize")
            with gr.Column():
                sum_output = gr.Textbox(label="Summary", lines=4)

    with gr.Tab("Sentiment Analysis"):
        with gr.Row():
            with gr.Column():
                sent_input = gr.Textbox(label="Enter text for sentiment analysis", lines=3)
                sent_button = gr.Button("Analyze Sentiment")
            with gr.Column():
                sent_output = gr.JSON(label="Sentiment Analysis Results")

    with gr.Tab("Question Answering"):
        with gr.Row():
            with gr.Column():
                qa_context = gr.Textbox(label="Enter the context", lines=6)
                qa_question = gr.Textbox(label="Enter your question", lines=2)
                qa_button = gr.Button("Get Answer")
            with gr.Column():
                qa_output = gr.JSON(label="Answer")

    with gr.Tab("Translation"):
        with gr.Row():
            with gr.Column():
                trans_input = gr.Textbox(label="Enter text to translate (English)", lines=3)
                trans_lang = gr.Dropdown(
                    choices=["es", "fr", "it", "pt", "ro"],
                    value="es",
                    label="Target Language",
                )
                trans_button = gr.Button("Translate")
            with gr.Column():
                trans_output = gr.Textbox(label="Translated Text", lines=3)

    # Set up event handlers: each button runs its task function on the
    # listed input components and writes the result to its output component.
    gen_button.click(
        fn=generate_text,
        inputs=[gen_input, gen_length, gen_temp],
        outputs=gen_output,
    )
    sum_button.click(
        fn=summarize_text,
        inputs=[sum_input, sum_max_length, sum_min_length],
        outputs=sum_output,
    )
    sent_button.click(
        fn=analyze_sentiment,
        inputs=sent_input,
        outputs=sent_output,
    )
    qa_button.click(
        fn=answer_question,
        inputs=[qa_context, qa_question],
        outputs=qa_output,
    )
    trans_button.click(
        fn=translate_text,
        inputs=[trans_input, trans_lang],
        outputs=trans_output,
    )

if __name__ == "__main__":
    # share=True requests a publicly shareable Gradio link in addition to the
    # local server.
    demo.launch(share=True)