# Hugging Face Space - status when captured: Sleeping
import datetime

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# FLAN-T5 (small) answers the legal questions.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# AdamW over every FLAN-T5 parameter; handle_legal_query steps it when it
# updates the model from a reward signal.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)


def _load_translation_pair(repo_id):
    """Return ``(model, tokenizer)`` for a seq2seq checkpoint, loading the model first."""
    mt_model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
    mt_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return mt_model, mt_tokenizer


# English <-> Hindi translation models used by translate().
translator_en_hi, tokenizer_en_hi = _load_translation_pair("Helsinki-NLP/opus-mt-en-hi")
translator_hi_en, tokenizer_hi_en = _load_translation_pair("Helsinki-NLP/opus-mt-hi-en")
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` between English and Hindi.

    Only the ("English", "Hindi") and ("Hindi", "English") pairs are wired to
    the module-level Helsinki-NLP opus-mt models. Any other pair returns the
    message "Translation for this pair not supported yet!" instead of raising.

    Args:
        text: Source text to translate.
        src_lang: Source language name ("English" or "Hindi").
        tgt_lang: Target language name ("English" or "Hindi").

    Returns:
        The decoded first generated sequence, or the "not supported" message.
    """
    direction = (src_lang, tgt_lang)
    if direction == ("English", "Hindi"):
        mt_tokenizer, mt_model = tokenizer_en_hi, translator_en_hi
    elif direction == ("Hindi", "English"):
        mt_tokenizer, mt_model = tokenizer_hi_en, translator_hi_en
    else:
        return "Translation for this pair not supported yet!"

    encoded = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    generated = mt_model.generate(**encoded)
    # Only the first sequence is decoded; callers pass a single string.
    return mt_tokenizer.decode(generated[0], skip_special_tokens=True)
def generate_complaint(issue):
    """Build a fill-in-the-blanks formal complaint letter about ``issue``.

    The letter is dated with today's local date as DD-MM-YYYY. Square-bracketed
    fields ([Your Name], [Date/Location], ...) are left for the user to complete.

    Surrounding whitespace in ``issue`` is trimmed. A blank or missing issue is
    replaced by "[Describe Issue]", so the letter never contains empty slots
    such as "Complaint Regarding " or "regarding . The incident".

    Args:
        issue: Short free-text description of the problem.

    Returns:
        The complaint letter as plain text, one field or paragraph per line.
    """
    topic = (issue or "").strip() or "[Describe Issue]"
    date = datetime.datetime.now().strftime("%d-%m-%Y")
    # Plain text only: the result is shown in a gr.Textbox, which displays
    # Markdown markers such as ** literally.
    lines = [
        "[Your Name]",
        "[Your Address]",
        date,
        "To Whom It May Concern,",
        f"Subject: Complaint Regarding {topic}",
        f"I am writing to formally lodge a complaint regarding {topic}. "
        "The incident occurred on [Date/Location]. "
        "The specific details are as follows:",
        f"- Issue: {topic}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "Yours sincerely,",
        "[Your Name]",
    ]
    return "\n".join(lines)
# Self-Critical Sequence Training (SCST) for RL | |
def compute_loss(logits, labels): | |
log_probs = F.log_softmax(logits, dim=-1) | |
gathered_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) | |
loss = -gathered_log_probs.mean() | |
return loss | |
def _reward_weighted_update(inputs, generated, response):
    """Take one optimizer step on the model's own response, scaled by a keyword reward."""
    # Simple reward: +1 if the answer mentions legal terms, otherwise -1.
    lowered = response.lower()
    reward = 1.0 if "law" in lowered or "legal" in lowered else -1.0

    # For encoder-decoder models, generate() returns sequences that begin with
    # the decoder start token. Teacher-force on the generated tokens and score
    # the *response* (the original code scored the input query instead, and
    # called the model with no decoder inputs, which raises for T5).
    decoder_input_ids = generated[:, :-1]
    labels = generated[:, 1:]
    if labels.numel() == 0:
        return  # nothing generated beyond the start token; no signal to learn from

    logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
    loss = compute_loss(logits, labels) * reward

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def handle_legal_query(query, language, *, update_model=True):
    """Answer a legal question with FLAN-T5, optionally taking one online learning step.

    A non-English query is translated to English with ``translate`` before
    generation, and the answer is translated back to ``language`` afterwards.

    When ``update_model`` is true, the shared ``model`` is modified in place.
    The mean negative log-likelihood of the response it just generated is
    multiplied by a keyword reward (+1.0 if the response contains "law" or
    "legal", else -1.0), and one ``optimizer`` step is taken.

    Args:
        query: The user's question.
        language: "English" or a language ``translate`` pairs with English.
        update_model: Keyword-only. Pass False to answer without changing weights.

    Returns:
        The generated answer (greedy decoding, at most 150 tokens) in ``language``.
    """
    if language != "English":
        query = translate(query, language, "English")

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    generated = model.generate(**inputs, max_length=150)
    response = tokenizer.decode(generated[0], skip_special_tokens=True)

    if update_model:
        _reward_weighted_update(inputs, generated, response)

    if language != "English":
        response = translate(response, "English", language)
    return response
def generate_email(issue):
    """Build a fill-in-the-blanks complaint email about ``issue``.

    Square-bracketed fields ([Date/Location], [Provide Evidence], [Your Name])
    are left for the user to complete.

    Surrounding whitespace in ``issue`` is trimmed. A blank or missing issue is
    replaced by "[Describe Issue]", so the email never contains empty slots
    such as "Complaint Regarding " or "regarding . The incident".

    Args:
        issue: Short free-text description of the problem.

    Returns:
        The email as plain text starting with its "Subject:" line.
    """
    topic = (issue or "").strip() or "[Describe Issue]"
    lines = [
        f"Subject: Complaint Regarding {topic}",
        "Dear Sir/Madam,",
        f"I am writing to formally lodge a complaint regarding {topic}. "
        "The incident occurred on [Date/Location]. "
        "The specific details are as follows:",
        f"- Issue: {topic}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "Yours sincerely,",
        "[Your Name]",
    ]
    return "\n".join(lines)
# Gradio interface: legal Q&A plus complaint-letter and email generators.
_APP_CSS = ".container {width: 100%; max-width: 600px;}"
_APP_HEADER = "# AI Legal Assistant for Disabilities \n### Ask legal questions and generate complaints"

with gr.Blocks(css=_APP_CSS) as app:
    gr.Markdown(_APP_HEADER)
    # NOTE(review): the copy this was restored from had no indentation, so which
    # Row each output Textbox belonged to is a reconstruction — confirm layout.

    # Legal Q&A: question and language picker, then the answer box.
    with gr.Row():
        query = gr.Textbox(
            label="Ask your legal question",
            placeholder="What are my rights as a disabled person?",
        )
        lang = gr.Dropdown(choices=["English", "Hindi"], label="Language", value="English")
    with gr.Row():
        submit_btn = gr.Button("Get Legal Advice")
    output = gr.Textbox(label="Legal Advice", placeholder="Legal advice will appear here")

    # Complaint letter generated from a free-text issue description.
    with gr.Row():
        issue = gr.Textbox(
            label="Describe your issue",
            placeholder="Facing discrimination at work...",
        )
        generate_btn = gr.Button("Generate Complaint")
    complaint_output = gr.Textbox(
        label="Generated Complaint",
        placeholder="Complaint template will appear here",
    )

    # Email version of the complaint; reuses the same issue box.
    with gr.Row():
        email_btn = gr.Button("Generate Email")
    email_output = gr.Textbox(
        label="Generated Email",
        placeholder="Generated email will appear here",
    )

    # Event wiring must happen inside the Blocks context.
    submit_btn.click(fn=handle_legal_query, inputs=[query, lang], outputs=output)
    generate_btn.click(fn=generate_complaint, inputs=issue, outputs=complaint_output)
    email_btn.click(fn=generate_email, inputs=issue, outputs=email_output)

# Start the web server (runs at import time, as in the original script).
app.launch()