Spaces:
Sleeping
Sleeping
File size: 5,110 Bytes
b383fb3 6eb22ad b383fb3 94e54fd b383fb3 dceaa5a 6eb22ad dceaa5a 94e54fd b383fb3 94e54fd 6eb22ad b383fb3 6eb22ad b383fb3 94e54fd b383fb3 94e54fd 6eb22ad 94e54fd b383fb3 6eb22ad 94e54fd 6eb22ad 94e54fd 6eb22ad 94e54fd b383fb3 94e54fd 6eb22ad 94e54fd 6eb22ad b383fb3 94e54fd b383fb3 6eb22ad b383fb3 6eb22ad b383fb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import datetime
import torch
import torch.nn.functional as F
# FLAN-T5 (small) answers the legal questions.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# AdamW over all FLAN-T5 parameters; handle_legal_query steps it once per request.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Helsinki-NLP OPUS-MT checkpoints, one per translation direction.
_EN_HI_CHECKPOINT = "Helsinki-NLP/opus-mt-en-hi"
_HI_EN_CHECKPOINT = "Helsinki-NLP/opus-mt-hi-en"

# English -> Hindi
translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained(_EN_HI_CHECKPOINT)
tokenizer_en_hi = AutoTokenizer.from_pretrained(_EN_HI_CHECKPOINT)

# Hindi -> English
translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained(_HI_EN_CHECKPOINT)
tokenizer_hi_en = AutoTokenizer.from_pretrained(_HI_EN_CHECKPOINT)
# Translation Function
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` between English and Hindi.

    Args:
        text: The text to translate.
        src_lang: Source language name, "English" or "Hindi".
        tgt_lang: Target language name, "English" or "Hindi".

    Returns:
        The translated string. An empty or whitespace-only ``text`` yields
        ``""``; when source and target are the same language ``text`` is
        returned unchanged; any other language pair yields the fixed message
        "Translation for this pair not supported yet!".
    """
    # Nothing to translate: skip the model instead of decoding whatever it
    # generates for an empty prompt.
    if text is None or (isinstance(text, str) and not text.strip()):
        return ""

    # Same language on both sides is an identity translation, not an error.
    if src_lang == tgt_lang:
        return text

    if src_lang == "English" and tgt_lang == "Hindi":
        pair_tokenizer, pair_model = tokenizer_en_hi, translator_en_hi
    elif src_lang == "Hindi" and tgt_lang == "English":
        pair_tokenizer, pair_model = tokenizer_hi_en, translator_hi_en
    else:
        return "Translation for this pair not supported yet!"

    inputs = pair_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = pair_model.generate(**inputs)
    # Only the first generated sequence is decoded.
    return pair_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Generate Complaint Template
def generate_complaint(issue):
    """Build a formal complaint letter about ``issue``, dated today.

    Args:
        issue: Short description of the problem. Surrounding whitespace is
            trimmed; if nothing remains (empty, whitespace-only or ``None``)
            the placeholder "[Describe Issue]" is used so the letter still
            reads as a fill-in template.

    Returns:
        The letter as a newline-separated string with no leading or trailing
        whitespace. Bracketed fields are left for the user to fill in.
    """
    subject = "" if issue is None else str(issue).strip()
    if not subject:
        subject = "[Describe Issue]"

    # Local date in day-month-year order.
    date = datetime.datetime.now().strftime("%d-%m-%Y")

    letter_lines = (
        "[Your Name]",
        "[Your Address]",
        f"{date}",
        "To Whom It May Concern,",
        f"**Subject: Complaint Regarding {subject}**",
        f"I am writing to formally lodge a complaint regarding {subject}. "
        "The incident occurred on [Date/Location]. "
        "The specific details are as follows:",
        f"- Issue: {subject}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "Yours sincerely,",
        "[Your Name]",
    )
    return "\n".join(letter_lines)
# Token-level negative log-likelihood; the caller scales it by the reward.
# NOTE: no self-critical baseline is computed here, only the plain mean NLL.
def compute_loss(logits, labels):
    """Return the mean negative log-probability ``logits`` assign to ``labels``.

    Args:
        logits: Unnormalised scores; the last dimension indexes the classes.
        labels: Integer class indices shaped like ``logits`` without its last
            dimension (one index per position).

    Returns:
        A zero-dimensional tensor: the negated average, over all positions,
        of the log-softmax score of each position's label.
    """
    # Pick, at every position, the log-probability of that position's label.
    picked = F.log_softmax(logits, dim=-1).gather(-1, labels[..., None])[..., 0]
    return picked.mean().neg()
def handle_legal_query(query, language):
    """Answer a legal question and apply one reward-weighted model update.

    The query is translated to English when needed, answered by the FLAN-T5
    model, scored with a keyword reward, and the model takes a single
    optimizer step on the reward-scaled negative log-likelihood of the answer
    it just generated. The answer is translated back before being returned.

    Args:
        query: The user's question.
        language: Language of ``query`` and of the returned answer
            ("English" or "Hindi" in the UI).

    Returns:
        The generated answer in ``language``, or a short prompt asking for a
        question when ``query`` is empty or whitespace-only.
    """
    # Nothing to answer: skip translation, generation and the weight update.
    if not query or not query.strip():
        return "Please enter a legal question."

    if language != "English":
        query = translate(query, language, "English")

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

    # Generate output
    generated = model.generate(**inputs, max_length=150)
    response = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Simple reward function (reward if response mentions legal terms)
    lowered = response.lower()
    reward = 1.0 if "law" in lowered or "legal" in lowered else -1.0

    # A seq2seq forward pass needs decoder inputs, and the reward must score
    # the generated answer rather than the prompt. The generated sequence
    # starts with the decoder start token, so all-but-last tokens are the
    # teacher-forcing inputs and all-but-first are the targets.
    decoder_input_ids = generated[:, :-1]
    labels = generated[:, 1:]

    if labels.numel() > 0:
        logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
        # reward * NLL: a positive reward makes the answer more likely,
        # a negative reward makes it less likely.
        loss = reward * compute_loss(logits, labels)

        # NOTE(review): every request mutates the shared model weights;
        # confirm this online update is wanted in a deployed app.
        optimizer.zero_grad()  # Reset gradients
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model weights

    if language != "English":
        response = translate(response, "English", language)
    return response
# Generate Email
def generate_email(issue):
    """Build a complaint e-mail (subject line plus body) about ``issue``.

    Args:
        issue: Short description of the problem. Surrounding whitespace is
            trimmed; if nothing remains (empty, whitespace-only or ``None``)
            the placeholder "[Describe Issue]" is used so the e-mail still
            reads as a fill-in template.

    Returns:
        The e-mail as a newline-separated string with no leading or trailing
        whitespace. Bracketed fields are left for the user to fill in.
    """
    subject = "" if issue is None else str(issue).strip()
    if not subject:
        subject = "[Describe Issue]"

    email_lines = (
        f"Subject: Complaint Regarding {subject}",
        "Dear Sir/Madam,",
        f"I am writing to formally lodge a complaint regarding {subject}. "
        "The incident occurred on [Date/Location]. "
        "The specific details are as follows:",
        f"- Issue: {subject}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "Yours sincerely,",
        "[Your Name]",
    )
    return "\n".join(email_lines)
# Gradio Interface
# NOTE(review): the nesting of components inside each Row was reconstructed
# from a copy of this file that had lost its indentation; confirm the layout.
with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
    gr.Markdown("# AI Legal Assistant for Disabilities \n### Ask legal questions and generate complaints")

    # Legal Q&A: question and language in, advice out.
    with gr.Row():
        query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
        lang = gr.Dropdown(["English", "Hindi"], label="Language", value="English")
    with gr.Row():
        submit_btn = gr.Button("Get Legal Advice")
        output = gr.Textbox(label="Legal Advice", placeholder="Legal advice will appear here")

    # Complaint letter and e-mail generation share the same issue description.
    with gr.Row():
        issue = gr.Textbox(label="Describe your issue", placeholder="Facing discrimination at work...")
        generate_btn = gr.Button("Generate Complaint")
        complaint_output = gr.Textbox(label="Generated Complaint", placeholder="Complaint template will appear here")
    with gr.Row():
        email_btn = gr.Button("Generate Email")
        email_output = gr.Textbox(label="Generated Email", placeholder="Generated email will appear here")

    # Event listeners are registered while the Blocks context is still open.
    submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
    generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
    email_btn.click(generate_email, inputs=issue, outputs=email_output)

# Launch the app
app.launch()