Bhavibond commited on
Commit
6eb22ad
·
verified ·
1 Parent(s): 9bf77ba

Added optimizer.zero_grad(), loss.backward(), and optimizer.step() properly

Browse files
Files changed (1) hide show
  1. app.py +25 -34
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
3
  import datetime
4
  import torch
5
  import torch.nn.functional as F
@@ -8,7 +8,9 @@ import torch.nn.functional as F
8
  model_name = "google/flan-t5-small"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
- optimizer = AdamW(model.parameters(), lr=5e-5)
 
 
12
 
13
  # Translation Models (English <-> Hindi)
14
  translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
@@ -28,7 +30,7 @@ def translate(text, src_lang, tgt_lang):
28
  outputs = translator_hi_en.generate(**inputs)
29
  return tokenizer_hi_en.decode(outputs[0], skip_special_tokens=True)
30
  else:
31
- return "Translation for this pair is not supported yet!"
32
 
33
  # Generate Complaint Template
34
  def generate_complaint(issue):
@@ -38,15 +40,11 @@ def generate_complaint(issue):
38
  [Your Address]
39
  {date}
40
  To Whom It May Concern,
41
-
42
  **Subject: Complaint Regarding {issue}**
43
-
44
  I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
45
- - **Issue:** {issue}
46
- - **Evidence:** [Provide Evidence]
47
-
48
  I kindly request you to take appropriate action as per the legal guidelines.
49
-
50
  Yours sincerely,
51
  [Your Name]
52
  """
@@ -59,31 +57,29 @@ def compute_loss(logits, labels):
59
  loss = -gathered_log_probs.mean()
60
  return loss
61
 
62
- # Legal Query Handling with Reinforcement Learning
63
  def handle_legal_query(query, language):
64
  if language != "English":
65
  query = translate(query, language, "English")
66
 
67
  inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
 
 
68
  outputs = model.generate(**inputs, max_length=150)
69
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
70
 
71
- # Simple reward function - reward if response mentions key legal terms
72
- reward_terms = ["law", "legal", "rights", "disability", "act", "compliance"]
73
- reward = 1.0 if any(term in response.lower() for term in reward_terms) else -0.5
74
 
75
  # Compute SCST Loss
76
  labels = inputs['input_ids']
77
  logits = model(**inputs).logits
78
  loss = compute_loss(logits, labels)
79
 
80
- # Adjust loss based on reward
 
81
  loss = loss * torch.tensor(reward, dtype=torch.float)
82
-
83
- # Backpropagation with optimizer
84
- optimizer.zero_grad()
85
- loss.backward()
86
- optimizer.step()
87
 
88
  if language != "English":
89
  response = translate(response, "English", language)
@@ -94,23 +90,19 @@ def handle_legal_query(query, language):
94
  def generate_email(issue):
95
  template = f"""
96
  Subject: Complaint Regarding {issue}
97
-
98
- Dear Sir/Madam,
99
-
100
- I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
101
- - **Issue:** {issue}
102
- - **Evidence:** [Provide Evidence]
103
-
104
- I kindly request you to take appropriate action as per the legal guidelines.
105
-
106
  Yours sincerely,
107
  [Your Name]
108
  """
109
  return template.strip()
110
 
111
  # Gradio Interface
112
- with gr.Blocks(css=".container {width: 100%; max-width: 800px;}") as app:
113
- gr.Markdown("# AI Legal Assistant for Disabilities\n### Get Legal Advice and Generate Complaint Templates Instantly")
114
 
115
  with gr.Row():
116
  query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
@@ -129,10 +121,9 @@ with gr.Blocks(css=".container {width: 100%; max-width: 800px;}") as app:
129
  email_btn = gr.Button("Generate Email")
130
  email_output = gr.Textbox(label="Generated Email", placeholder="Generated email will appear here")
131
 
132
- # Connect functions to buttons
133
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
134
- generate_btn.click(generate_complaint, inputs=[issue], outputs=complaint_output)
135
- email_btn.click(generate_email, inputs=[issue], outputs=email_output)
136
 
137
- # Launch app
138
  app.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import datetime
4
  import torch
5
  import torch.nn.functional as F
 
8
  model_name = "google/flan-t5-small"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
+
12
+ # Define optimizer for FLAN-T5 model
13
+ optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
14
 
15
  # Translation Models (English <-> Hindi)
16
  translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
 
30
  outputs = translator_hi_en.generate(**inputs)
31
  return tokenizer_hi_en.decode(outputs[0], skip_special_tokens=True)
32
  else:
33
+ return "Translation for this pair not supported yet!"
34
 
35
  # Generate Complaint Template
36
  def generate_complaint(issue):
 
40
  [Your Address]
41
  {date}
42
  To Whom It May Concern,
 
43
  **Subject: Complaint Regarding {issue}**
 
44
  I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
45
+ - Issue: {issue}
46
+ - Evidence: [Provide Evidence]
 
47
  I kindly request you to take appropriate action as per the legal guidelines.
 
48
  Yours sincerely,
49
  [Your Name]
50
  """
 
57
  loss = -gathered_log_probs.mean()
58
  return loss
59
 
 
60
  def handle_legal_query(query, language):
61
  if language != "English":
62
  query = translate(query, language, "English")
63
 
64
  inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
65
+
66
+ # Generate output
67
  outputs = model.generate(**inputs, max_length=150)
68
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
70
+ # Simple reward function (reward if response mentions legal terms)
71
+ reward = 1.0 if "law" in response.lower() or "legal" in response.lower() else -1.0
 
72
 
73
  # Compute SCST Loss
74
  labels = inputs['input_ids']
75
  logits = model(**inputs).logits
76
  loss = compute_loss(logits, labels)
77
 
78
+ # Update model weights based on reward signal
79
+ optimizer.zero_grad() # Reset gradients
80
  loss = loss * torch.tensor(reward, dtype=torch.float)
81
+ loss.backward() # Backpropagation
82
+ optimizer.step() # Update model weights
 
 
 
83
 
84
  if language != "English":
85
  response = translate(response, "English", language)
 
90
  def generate_email(issue):
91
  template = f"""
92
  Subject: Complaint Regarding {issue}
93
+ Dear Sir/Madam,
94
+ I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
95
+ - Issue: {issue}
96
+ - Evidence: [Provide Evidence]
97
+ I kindly request you to take appropriate action as per the legal guidelines.
 
 
 
 
98
  Yours sincerely,
99
  [Your Name]
100
  """
101
  return template.strip()
102
 
103
  # Gradio Interface
104
+ with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
105
+ gr.Markdown("# AI Legal Assistant for Disabilities \n### Ask legal questions and generate complaints")
106
 
107
  with gr.Row():
108
  query = gr.Textbox(label="Ask your legal question", placeholder="What are my rights as a disabled person?")
 
121
  email_btn = gr.Button("Generate Email")
122
  email_output = gr.Textbox(label="Generated Email", placeholder="Generated email will appear here")
123
 
 
124
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
125
+ generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
126
+ email_btn.click(generate_email, inputs=issue, outputs=email_output)
127
 
128
+ # Launch the app
129
  app.launch()