Bhavibond committed on
Commit
94e54fd
·
verified ·
1 Parent(s): 6adcae6

Use SCST RLAI and check

Browse files
Files changed (1) hide show
  1. app.py +69 -34
app.py CHANGED
@@ -1,36 +1,33 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, LogitsProcessorList, TopKLogitsWarper
3
  import datetime
 
 
4
 
5
  # Load FLAN-T5 for Legal Q&A
6
  model_name = "google/flan-t5-small"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
- # Translation Models (Dynamically Loaded)
11
- translation_models = {
12
- "en-hi": ("Helsinki-NLP/opus-mt-en-hi", "Helsinki-NLP/opus-mt-en-hi"),
13
- "hi-en": ("Helsinki-NLP/opus-mt-hi-en", "Helsinki-NLP/opus-mt-hi-en")
14
- }
15
-
16
- def load_translation_model(src_lang, tgt_lang):
17
- pair = f"{src_lang[:2]}-{tgt_lang[:2]}"
18
- if pair in translation_models:
19
- model_name, tokenizer_name = translation_models[pair]
20
- trans_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
21
- trans_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
22
- return trans_model, trans_tokenizer
23
- return None, None
24
 
25
  # Translation Function
26
  def translate(text, src_lang, tgt_lang):
27
- trans_model, trans_tokenizer = load_translation_model(src_lang, tgt_lang)
28
- if trans_model is None:
 
 
 
 
 
 
 
29
  return "Translation for this pair not supported yet!"
30
-
31
- inputs = trans_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
32
- outputs = trans_model.generate(**inputs, max_length=256)
33
- return trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
34
 
35
  # Generate Complaint Template
36
  def generate_complaint(issue):
@@ -50,31 +47,65 @@ Yours sincerely,
50
  """
51
  return template.strip()
52
 
53
- # Handle Legal Q&A
 
 
 
 
 
 
54
  def handle_legal_query(query, language):
55
  if language != "English":
56
  query = translate(query, language, "English")
57
 
58
- # Tokenize input
59
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=256)
60
-
61
- # Logits processing using Top-K sampling
62
- logits_processor = LogitsProcessorList([
63
- TopKLogitsWarper(50) # Use Top-K only
64
- ])
65
 
66
- # Generate Response
67
- outputs = model.generate(**inputs, max_length=150, logits_processor=logits_processor)
68
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  if language != "English":
71
  response = translate(response, "English", language)
72
 
73
  return response
74
 
75
- # Define Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
77
- gr.Markdown("# AI Legal Assistant\n### Ask legal questions and generate complaints")
 
78
 
79
  with gr.Row():
80
  query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
@@ -89,8 +120,12 @@ with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
89
  generate_btn = gr.Button("Generate Complaint")
90
  complaint_output = gr.Textbox(label="Generated Complaint", placeholder="Complaint template will appear here")
91
 
 
 
 
 
92
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
93
  generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
 
94
 
95
- # Launch the app
96
  app.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import datetime
4
+ import torch
5
+ import torch.nn.functional as F
6
 
7
  # Load FLAN-T5 for Legal Q&A
8
  model_name = "google/flan-t5-small"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
 
12
+ # Translation Models (English <-> Hindi)
13
+ translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
14
+ tokenizer_en_hi = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
15
+
16
+ translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
17
+ tokenizer_hi_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
 
 
 
 
 
 
 
 
18
 
19
  # Translation Function
20
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` between English and Hindi.

    Args:
        text: The text to translate.
        src_lang: Source language name, "English" or "Hindi".
        tgt_lang: Target language name, "English" or "Hindi".

    Returns:
        The translated text. ``text`` is returned unchanged when the source
        and target languages are the same or when there is nothing to
        translate. For any other unsupported pair the fixed message
        "Translation for this pair not supported yet!" is returned.
    """
    # Same-language requests need no model call; pass the text through
    # instead of reporting the pair as unsupported.
    if src_lang == tgt_lang:
        return text

    if (src_lang, tgt_lang) not in {("English", "Hindi"), ("Hindi", "English")}:
        return "Translation for this pair not supported yet!"

    # Avoid running the MT model on empty input.
    if not text or not text.strip():
        return text

    if src_lang == "English":
        mt_tokenizer, mt_model = tokenizer_en_hi, translator_en_hi
    else:
        mt_tokenizer, mt_model = tokenizer_hi_en, translator_hi_en

    inputs = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = mt_model.generate(**inputs)
    return mt_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
31
 
32
  # Generate Complaint Template
33
  def generate_complaint(issue):
 
47
  """
48
  return template.strip()
49
 
50
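+ # Mean token-level negative log-likelihood; handle_legal_query scales it by the reward as a simplified stand-in for Self-Critical Sequence Training (SCST)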
51
def compute_loss(logits, labels, ignore_index=-100):
    """Return the mean negative log-likelihood of ``labels`` under ``logits``.

    Args:
        logits: Float tensor whose last dimension indexes the vocabulary.
        labels: Integer tensor of token ids with the shape of ``logits``
            minus its last dimension.
        ignore_index: Label value marking positions to exclude from the
            mean (for example padding). Defaults to -100.

    Returns:
        A scalar tensor: the negative log-probability of the label tokens,
        averaged over the positions that are not ignored. The result is 0
        when every position is ignored.
    """
    log_probs = F.log_softmax(logits, dim=-1)

    # gather() cannot index with negative ids, so point ignored positions at
    # token 0 and zero their contribution with the mask afterwards.
    keep = labels != ignore_index
    safe_labels = labels.masked_fill(~keep, 0)
    token_log_probs = log_probs.gather(dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)

    weights = keep.to(token_log_probs.dtype)
    # clamp avoids a 0/0 NaN when no position is kept.
    kept_count = weights.sum().clamp(min=1)
    return -(token_log_probs * weights).sum() / kept_count
56
+
57
def _get_optimizer(lr=1e-5):
    """Return the optimizer attached to ``model``, creating it on first use."""
    optimizer = getattr(model, "optimizer", None)
    if optimizer is None:
        # Hugging Face models carry no optimizer of their own; attach one so
        # that every request updates the same optimizer state.
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        model.optimizer = optimizer
    return optimizer


def handle_legal_query(query, language):
    """Answer a legal question and apply one reward-weighted update to the model.

    Args:
        query: The user's question, written in ``language``.
        language: Language name of the query; anything other than "English"
            is translated to English before generation and back afterwards.

    Returns:
        The generated answer, translated back to ``language`` when needed.
    """
    if language != "English":
        query = translate(query, language, "English")

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

    # Generate output
    outputs = model.generate(**inputs, max_length=150)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Simple reward function (reward if response mentions legal terms)
    lowered = response.lower()
    reward = 1.0 if "law" in lowered or "legal" in lowered else -1.0

    # The reward scores the generated answer, so the generated tokens are the
    # training target. Drop the leading decoder-start token: the model
    # re-inserts it when it shifts the labels right into decoder inputs.
    labels = outputs[:, 1:].contiguous()
    if labels.numel() > 0:
        # Seq2seq forward needs decoder inputs; passing labels supplies them.
        logits = model(**inputs, labels=labels).logits

        # Reward-weighted NLL: a positive reward makes the answer more
        # likely, a negative reward makes it less likely.
        loss = compute_loss(logits, labels) * reward

        optimizer = _get_optimizer()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if language != "English":
        response = translate(response, "English", language)

    return response
85
 
86
+ # Generate Email
87
def generate_email(issue):
    """Build a complaint email template for ``issue``.

    Args:
        issue: Short description of the problem. Surrounding whitespace is
            removed; an empty or missing value is replaced by the
            placeholder "[Describe Issue]" so the email never contains a
            dangling "regarding ." phrase.

    Returns:
        The email text, starting with the subject line and ending with the
        signature placeholder, with no leading or trailing whitespace.
    """
    issue = (issue or "").strip() or "[Describe Issue]"

    # Assemble line by line so the output does not depend on how this
    # function body is indented in the source file.
    lines = [
        f"Subject: Complaint Regarding {issue}",
        "",
        "Dear Sir/Madam,",
        "",
        f"I am writing to formally lodge a complaint regarding {issue}. "
        "The incident occurred on [Date/Location]. "
        "The specific details are as follows:",
        "",
        f"- Issue: {issue}",
        "- Evidence: [Provide Evidence]",
        "",
        "I kindly request you to take appropriate action as per the legal guidelines.",
        "",
        "Yours sincerely,",
        "[Your Name]",
    ]
    return "\n".join(lines)
104
+
105
+ # Gradio Interface
106
with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
    # Title and subtitle are two Markdown lines; the line break must be an
    # escaped "\n" because a single-quoted string literal cannot span lines.
    gr.Markdown("# AI Legal Assistant for disabilities\n### Ask legal questions and generate complaints")
109
 
110
  with gr.Row():
111
  query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
 
120
  generate_btn = gr.Button("Generate Complaint")
121
  complaint_output = gr.Textbox(label="Generated Complaint", placeholder="Complaint template will appear here")
122
 
123
+ with gr.Row():
124
+ email_btn = gr.Button("Generate Email")
125
+ email_output = gr.Textbox(label="Generated Email", placeholder="Generated email will appear here")
126
+
127
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
128
  generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
129
+ email_btn.click(generate_email, inputs=issue, outputs=email_output)
130
 
 
131
  app.launch()