File size: 5,110 Bytes
b383fb3
6eb22ad
b383fb3
94e54fd
 
b383fb3
 
 
 
dceaa5a
6eb22ad
 
 
dceaa5a
94e54fd
 
 
 
 
 
b383fb3
 
 
94e54fd
 
 
 
 
 
 
 
 
6eb22ad
b383fb3
 
 
 
 
 
 
 
 
 
 
6eb22ad
 
b383fb3
 
 
 
 
 
94e54fd
 
 
 
 
 
 
b383fb3
 
 
 
94e54fd
6eb22ad
 
94e54fd
b383fb3
 
6eb22ad
 
94e54fd
 
 
 
 
 
6eb22ad
 
94e54fd
6eb22ad
 
94e54fd
b383fb3
 
 
 
 
94e54fd
 
 
 
6eb22ad
 
 
 
 
94e54fd
 
 
 
 
 
6eb22ad
 
b383fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
94e54fd
 
 
 
b383fb3
6eb22ad
 
b383fb3
6eb22ad
b383fb3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import datetime
import torch
import torch.nn.functional as F

# Load FLAN-T5 for Legal Q&A
# Everything in this section runs at import time: the tokenizer and model
# objects are created once and shared by every request handler below.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Define optimizer for FLAN-T5 model
# Only the Q&A model's parameters are registered here; the translation
# models loaded below are never passed to an optimizer.
# handle_legal_query() calls optimizer.step() on each request, so the
# in-memory weights of `model` change while the app is running.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Translation Models (English <-> Hindi)
# One model/tokenizer pair per direction; translate() picks the pair that
# matches the requested (source, target) languages.
translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
tokenizer_en_hi = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")

translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
tokenizer_hi_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-hi-en")

# Translate between English and Hindi
def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang*.

    Only the "English" -> "Hindi" and "Hindi" -> "English" directions are
    handled. Any other combination (including identical source and target)
    returns a fixed "not supported" message instead of a translation.

    Args:
        text: The text to translate.
        src_lang: Name of the source language.
        tgt_lang: Name of the target language.

    Returns:
        The decoded first generated sequence, or the "not supported" message.
    """
    english_to_hindi = src_lang == "English" and tgt_lang == "Hindi"
    hindi_to_english = src_lang == "Hindi" and tgt_lang == "English"

    if not (english_to_hindi or hindi_to_english):
        return "Translation for this pair not supported yet!"

    # Resolve the direction-specific pair only once a supported direction is
    # known, so unsupported requests never touch the loaded models.
    if english_to_hindi:
        pair_tokenizer, pair_model = tokenizer_en_hi, translator_en_hi
    else:
        pair_tokenizer, pair_model = tokenizer_hi_en, translator_hi_en

    encoded = pair_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    generated = pair_model.generate(**encoded)
    return pair_tokenizer.decode(generated[0], skip_special_tokens=True)

# Build the complaint letter text
def generate_complaint(issue):
    """Return a formal complaint letter about *issue*, dated today.

    The letter is a fixed template with bracketed placeholders for the
    sender's details. Today's local date is inserted as DD-MM-YYYY and
    *issue* is interpolated into the subject, the opening sentence and the
    issue bullet.

    Args:
        issue: Short description of the problem being complained about.

    Returns:
        The letter as a single string. Every line except the last ends with
        two spaces before the newline.
    """
    today = datetime.datetime.now().strftime("%d-%m-%Y")
    letter_lines = (
        "[Your Name]",
        "[Your Address]",
        today,
        "To Whom It May Concern,",
        f"**Subject: Complaint Regarding {issue}**",
        f"I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:",
        f"- Issue: {issue}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "Yours sincerely,",
        "[Your Name]",
    )
    # The separator carries the two trailing spaces explicitly so they cannot
    # be lost to editors that trim whitespace at line ends.
    return "  \n".join(letter_lines)

# Mean negative log-likelihood of the target tokens
def compute_loss(logits, labels):
    """Return the mean negative log-likelihood of *labels* under *logits*.

    A log-softmax is taken over the last dimension of *logits*, the entry at
    each label index is picked out, and the negated mean over all positions
    is returned. No positions are masked, and no reward or baseline is
    applied here; any reward weighting is left to the caller.

    Args:
        logits: Unnormalised scores whose last dimension is indexed by the
            label ids.
        labels: Integer index tensor with the same shape as *logits* minus
            its last dimension.

    Returns:
        A scalar tensor holding the averaged negative log-probability.
    """
    target_index = labels.unsqueeze(-1)
    picked_log_probs = F.log_softmax(logits, dim=-1).gather(dim=-1, index=target_index)
    return -picked_log_probs.squeeze(-1).mean()

def handle_legal_query(query, language):
    """Answer a legal question and apply one reward-weighted model update.

    The question is translated to English when needed, answered by the
    seq2seq model, scored with a keyword reward, and the model is then
    updated once on its own generated answer weighted by that reward.

    Args:
        query: The user's question, written in *language*.
        language: Language of *query*. Anything other than "English" is
            translated to English before generation, and the answer is
            translated back afterwards.

    Returns:
        The generated answer in *language*, or a short prompt asking for a
        question when *query* is empty or whitespace only.
    """
    # A blank question has nothing to answer and must not drive a weight update.
    if not query or not query.strip():
        return "Please enter a legal question."

    if language != "English":
        query = translate(query, language, "English")

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

    # Generate output
    outputs = model.generate(**inputs, max_length=150)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Simple reward function (reward if response mentions legal terms)
    lowered = response.lower()
    reward = 1.0 if "law" in lowered or "legal" in lowered else -1.0

    # Score the *generated* sequence, not the prompt. For an encoder-decoder
    # model, generate() returns the decoder start token followed by the
    # generated tokens, so the sequence shifted by one gives the decoder
    # inputs and the matching next-token targets. A seq2seq forward pass
    # without decoder inputs raises, and the prompt ids would not line up
    # with the decoder logits anyway.
    decoder_input_ids = outputs[:, :-1]
    labels = outputs[:, 1:]

    if labels.numel() > 0:
        logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits

        # Reward-weighted likelihood update with no baseline: a positive
        # reward raises the likelihood of the generated answer, a negative
        # reward lowers it.
        # NOTE(review): a negative reward maximises the NLL without bound;
        # consider a baseline or clipping before relying on this online.
        loss = reward * compute_loss(logits, labels)

        optimizer.zero_grad()  # Reset gradients
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model weights

    if language != "English":
        response = translate(response, "English", language)

    return response

# Build the complaint email text
def generate_email(issue):
    """Return a complaint email about *issue* as plain text.

    The email is a fixed template: a subject line, a salutation, a short
    body with *issue* interpolated in three places, and a sign-off with a
    bracketed name placeholder.

    Args:
        issue: Short description of the problem being complained about.

    Returns:
        The email as a single newline-separated string.
    """
    email_lines = [
        f"Subject: Complaint Regarding {issue}",
        "Dear Sir/Madam,",
        f"I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:",
        f"- Issue: {issue}",
        "- Evidence: [Provide Evidence]",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        # The sign-off line keeps the two trailing spaces of the original template.
        "Yours sincerely,  ",
        "[Your Name]",
    ]
    return "\n".join(email_lines)

# Gradio Interface
# One page with three actions: ask a legal question, generate a complaint
# letter, and generate a complaint email. Components are created in the order
# they should appear; the click handlers are wired up at the end of the block.
with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
    gr.Markdown("# AI Legal Assistant for Disabilities \n### Ask legal questions and generate complaints")

    # Question text and the language it is written in.
    with gr.Row():
        query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
        lang = gr.Dropdown(["English", "Hindi"], label="Language", value="English")

    # Trigger and result box for the Q&A action.
    with gr.Row():
        submit_btn = gr.Button("Get Legal Advice")
        output = gr.Textbox(label="Legal Advice", placeholder="Legal advice will appear here")

    # Issue description, shared by the complaint and email generators.
    with gr.Row():
        issue = gr.Textbox(label="Describe your issue", placeholder="Facing discrimination at work...")
        generate_btn = gr.Button("Generate Complaint")
        complaint_output = gr.Textbox(label="Generated Complaint", placeholder="Complaint template will appear here")

    # The email action has no input of its own; it reads `issue` from the row above.
    with gr.Row():
        email_btn = gr.Button("Generate Email")
        email_output = gr.Textbox(label="Generated Email", placeholder="Generated email will appear here")

    # Event wiring: each button passes its input component values to the
    # handler and writes the handler's return value to the output component.
    submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
    generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
    email_btn.click(generate_email, inputs=issue, outputs=email_output)

# Launch the app
# Called unconditionally at module level (no __main__ guard), so importing
# this module also starts the app.
app.launch()