Bhavibond commited on
Commit
b383fb3
·
verified ·
1 Parent(s): 1e5bb95

AI legal assistant for disabilities

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import datetime
4
+ import torch
5
+ from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, create_reference_model, set_seed
6
+ from transformers import pipeline
7
+
8
+ # Load FLAN-T5 for Legal Q&A
9
+ model_name = "google/flan-t5-small"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)
12
+
13
+ # Create a reference model for PPO
14
+ ref_model = create_reference_model(model)
15
+
16
+ # PPO Configuration
17
+ config = PPOConfig(
18
+ batch_size=1,
19
+ learning_rate=1e-5,
20
+ mini_batch_size=1,
21
+ ppo_epochs=1 # Minimal epochs
22
+ )
23
+
24
+ # Create PPO Trainer
25
+ ppo_trainer = PPOTrainer(
26
+ config=config,
27
+ model=model,
28
+ ref_model=ref_model,
29
+ tokenizer=tokenizer
30
+ )
31
+
32
+ # Translation Models (English ↔ Hindi)
33
+ translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
34
+ tokenizer_en_hi = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
35
+
36
+ translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
37
+ tokenizer_hi_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
38
+
39
+ # Translation Function
40
+ def translate(text, src_lang, tgt_lang):
41
+ if src_lang == "English" and tgt_lang == "Hindi":
42
+ inputs = tokenizer_en_hi(text, return_tensors="pt", padding=True, truncation=True)
43
+ outputs = translator_en_hi.generate(**inputs)
44
+ return tokenizer_en_hi.decode(outputs[0], skip_special_tokens=True)
45
+ elif src_lang == "Hindi" and tgt_lang == "English":
46
+ inputs = tokenizer_hi_en(text, return_tensors="pt", padding=True, truncation=True)
47
+ outputs = translator_hi_en.generate(**inputs)
48
+ return tokenizer_hi_en.decode(outputs[0], skip_special_tokens=True)
49
+ else:
50
+ return "Translation for this pair not supported yet!"
51
+
52
+ # Generate Complaint Template
53
+ def generate_complaint(issue):
54
+ date = datetime.datetime.now().strftime("%d-%m-%Y")
55
+ template = f"""
56
+ [Your Name]
57
+ [Your Address]
58
+ {date}
59
+
60
+ To Whom It May Concern,
61
+
62
+ **Subject: Complaint Regarding {issue}**
63
+
64
+ I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
65
+
66
+ - Issue: {issue}
67
+ - Evidence: [Provide Evidence]
68
+
69
+ I kindly request you to take appropriate action as per the legal guidelines.
70
+
71
+ Yours sincerely,
72
+ [Your Name]
73
+ """
74
+ return template.strip()
75
+
76
+ # Handle Legal Q&A with PPO
77
+ def handle_legal_query(query, language):
78
+ if language != "English":
79
+ query = translate(query, language, "English")
80
+
81
+ # Tokenize input
82
+ inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
83
+
84
+ # Generate Response
85
+ outputs = model.generate(**inputs, max_length=150)
86
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
87
+
88
+ # Reward Signal for PPO (basic reward)
89
+ reward = torch.tensor([1.0]) if "legal" in response.lower() else torch.tensor([-1.0])
90
+
91
+ # PPO Step (Reinforcement Learning)
92
+ ppo_trainer.step([query], [outputs], [reward])
93
+
94
+ if language != "English":
95
+ response = translate(response, "English", language)
96
+
97
+ return response
98
+
99
+ # Define Gradio Interface
100
+ with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
101
+ gr.Markdown("# AI Legal Assistant\n### Ask legal questions and generate complaints")
102
+
103
+ with gr.Row():
104
+ query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
105
+ lang = gr.Dropdown(["English", "Hindi"], label="Language", value="English")
106
+
107
+ with gr.Row():
108
+ submit_btn = gr.Button("Get Legal Advice")
109
+ output = gr.Textbox(label="Legal Advice", placeholder="Legal advice will appear here")
110
+
111
+ with gr.Row():
112
+ issue = gr.Textbox(label="Describe your issue", placeholder="Facing discrimination at work...")
113
+ generate_btn = gr.Button("Generate Complaint")
114
+ complaint_output = gr.Textbox(label="Generated Complaint", placeholder="Complaint template will appear here")
115
+
116
+ submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
117
+ generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
118
+
119
+ # Launch the app on Hugging Face free tier
120
+ app.launch()