Bhavibond committed on
Commit
dceaa5a
·
verified ·
1 Parent(s): 25f3d38

remove ppo training for now

Browse files
Files changed (1) hide show
  1. app.py +28 -57
app.py CHANGED
@@ -1,52 +1,36 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper, set_seed
3
  import datetime
4
- import torch
5
- from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, create_reference_model
6
 
7
  # Load FLAN-T5 for Legal Q&A
8
  model_name = "google/flan-t5-small"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)
11
-
12
- # Create a reference model for PPO
13
- ref_model = create_reference_model(model)
14
-
15
- # PPO Configuration
16
- config = PPOConfig(
17
- batch_size=1,
18
- learning_rate=1e-5,
19
- mini_batch_size=1,
20
- steps=1 # Minimal epochs
21
- )
22
-
23
- # Create PPO Trainer
24
- ppo_trainer = PPOTrainer(
25
- config=config,
26
- model=model,
27
- ref_model=ref_model,
28
- tokenizer=tokenizer
29
- )
30
-
31
- # Translation Models (English <-> Hindi)
32
- translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
33
- tokenizer_en_hi = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
34
-
35
- translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
36
- tokenizer_hi_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
37
 
38
  # Translation Function
39
  def translate(text, src_lang, tgt_lang):
40
- if src_lang == "English" and tgt_lang == "Hindi":
41
- inputs = tokenizer_en_hi(text, return_tensors="pt", padding=True, truncation=True)
42
- outputs = translator_en_hi.generate(**inputs)
43
- return tokenizer_en_hi.decode(outputs[0], skip_special_tokens=True)
44
- elif src_lang == "Hindi" and tgt_lang == "English":
45
- inputs = tokenizer_hi_en(text, return_tensors="pt", padding=True, truncation=True)
46
- outputs = translator_hi_en.generate(**inputs)
47
- return tokenizer_hi_en.decode(outputs[0], skip_special_tokens=True)
48
- else:
49
  return "Translation for this pair not supported yet!"
 
 
 
 
50
 
51
  # Generate Complaint Template
52
  def generate_complaint(issue):
@@ -55,47 +39,34 @@ def generate_complaint(issue):
55
  [Your Name]
56
  [Your Address]
57
  {date}
58
-
59
  To Whom It May Concern,
60
-
61
  **Subject: Complaint Regarding {issue}**
62
-
63
  I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
64
-
65
  - Issue: {issue}
66
  - Evidence: [Provide Evidence]
67
-
68
  I kindly request you to take appropriate action as per the legal guidelines.
69
-
70
  Yours sincerely,
71
  [Your Name]
72
  """
73
  return template.strip()
74
 
75
- # Handle Legal Q&A with PPO
76
  def handle_legal_query(query, language):
77
  if language != "English":
78
  query = translate(query, language, "English")
79
 
80
  # Tokenize input
81
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
82
 
83
- # Logits processing using Top-K and Top-P sampling (replacement for top_k_top_p_filtering)
84
  logits_processor = LogitsProcessorList([
85
- TopKLogitsWarper(50), # Top-k sampling with k=50
86
- TopPLogitsWarper(0.95) # Top-p nucleus sampling with p=0.95
87
  ])
88
 
89
  # Generate Response
90
  outputs = model.generate(**inputs, max_length=150, logits_processor=logits_processor)
91
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
92
 
93
- # Reward Signal for PPO (basic reward)
94
- reward = torch.tensor([1.0]) if "legal" in response.lower() else torch.tensor([-1.0])
95
-
96
- # PPO Step (Reinforcement Learning)
97
- ppo_trainer.step([query], [outputs], [reward])
98
-
99
  if language != "English":
100
  response = translate(response, "English", language)
101
 
@@ -121,5 +92,5 @@ with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
121
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
122
  generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
123
 
124
- # Launch the app on Hugging Face free tier
125
  app.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, LogitsProcessorList, TopKLogitsWarper
3
  import datetime
 
 
4
 
5
  # Load FLAN-T5 for Legal Q&A
6
  model_name = "google/flan-t5-small"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
+
10
# Translation Models (Dynamically Loaded)
# Maps a lowercase "src-tgt" language-prefix pair to the
# (model_name, tokenizer_name) checkpoints on the Hugging Face hub.
translation_models = {
    "en-hi": ("Helsinki-NLP/opus-mt-en-hi", "Helsinki-NLP/opus-mt-en-hi"),
    "hi-en": ("Helsinki-NLP/opus-mt-hi-en", "Helsinki-NLP/opus-mt-hi-en")
}

# Cache of already-loaded (model, tokenizer) pairs so repeated translate()
# calls do not re-download / re-instantiate the same checkpoint each time.
_loaded_translators = {}

def load_translation_model(src_lang, tgt_lang):
    """Return a (model, tokenizer) pair for translating src_lang -> tgt_lang.

    Languages arrive as full names (e.g. "English", "Hindi"); the lookup key
    is the lowercased first two letters of each ("en-hi"). Returns
    (None, None) when the pair is not listed in ``translation_models``.
    """
    # BUG FIX: "English"[:2] is "En", which never matched the lowercase
    # dictionary keys, so every pair looked unsupported. Lowercase the key.
    pair = f"{src_lang[:2]}-{tgt_lang[:2]}".lower()
    if pair not in translation_models:
        return None, None
    if pair not in _loaded_translators:
        model_name, tokenizer_name = translation_models[pair]
        trans_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        trans_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        _loaded_translators[pair] = (trans_model, trans_tokenizer)
    return _loaded_translators[pair]
 
 
 
 
 
 
 
 
 
 
 
24
 
25
# Translation Function
def translate(text, src_lang, tgt_lang):
    """Translate *text* from src_lang to tgt_lang.

    Falls back to a fixed notice string when no model is registered for
    the requested language pair.
    """
    mt_model, mt_tokenizer = load_translation_model(src_lang, tgt_lang)
    if mt_model is None:
        return "Translation for this pair not supported yet!"

    encoded = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    generated = mt_model.generate(**encoded, max_length=256)
    return mt_tokenizer.decode(generated[0], skip_special_tokens=True)
34
 
35
  # Generate Complaint Template
36
  def generate_complaint(issue):
 
39
  [Your Name]
40
  [Your Address]
41
  {date}
 
42
  To Whom It May Concern,
 
43
  **Subject: Complaint Regarding {issue}**
 
44
  I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
 
45
  - Issue: {issue}
46
  - Evidence: [Provide Evidence]
 
47
  I kindly request you to take appropriate action as per the legal guidelines.
 
48
  Yours sincerely,
49
  [Your Name]
50
  """
51
  return template.strip()
52
 
53
+ # Handle Legal Q&A
54
  def handle_legal_query(query, language):
55
  if language != "English":
56
  query = translate(query, language, "English")
57
 
58
  # Tokenize input
59
+ inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=256)
60
 
61
+ # Logits processing using Top-K sampling
62
  logits_processor = LogitsProcessorList([
63
+ TopKLogitsWarper(50) # Use Top-K only
 
64
  ])
65
 
66
  # Generate Response
67
  outputs = model.generate(**inputs, max_length=150, logits_processor=logits_processor)
68
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
 
 
 
 
 
 
70
  if language != "English":
71
  response = translate(response, "English", language)
72
 
 
92
  submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
93
  generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
94
 
95
+ # Launch the app
96
  app.launch()