Spaces:
Sleeping
Sleeping
remove ppo training for now
Browse files
app.py
CHANGED
@@ -1,52 +1,36 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, LogitsProcessorList, TopKLogitsWarper
|
3 |
import datetime
|
4 |
-
import torch
|
5 |
-
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, create_reference_model
|
6 |
|
7 |
# Load FLAN-T5 for Legal Q&A
|
8 |
model_name = "google/flan-t5-small"
|
9 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
10 |
-
model =
|
11 |
-
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
model=model,
|
27 |
-
ref_model=ref_model,
|
28 |
-
tokenizer=tokenizer
|
29 |
-
)
|
30 |
-
|
31 |
-
# Translation Models (English <-> Hindi)
|
32 |
-
translator_en_hi = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
|
33 |
-
tokenizer_en_hi = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
|
34 |
-
|
35 |
-
translator_hi_en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
|
36 |
-
tokenizer_hi_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-hi-en")
|
37 |
|
38 |
# Translation Function
|
39 |
def translate(text, src_lang, tgt_lang):
|
40 |
-
|
41 |
-
|
42 |
-
outputs = translator_en_hi.generate(**inputs)
|
43 |
-
return tokenizer_en_hi.decode(outputs[0], skip_special_tokens=True)
|
44 |
-
elif src_lang == "Hindi" and tgt_lang == "English":
|
45 |
-
inputs = tokenizer_hi_en(text, return_tensors="pt", padding=True, truncation=True)
|
46 |
-
outputs = translator_hi_en.generate(**inputs)
|
47 |
-
return tokenizer_hi_en.decode(outputs[0], skip_special_tokens=True)
|
48 |
-
else:
|
49 |
return "Translation for this pair not supported yet!"
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Generate Complaint Template
|
52 |
def generate_complaint(issue):
|
@@ -55,47 +39,34 @@ def generate_complaint(issue):
|
|
55 |
[Your Name]
|
56 |
[Your Address]
|
57 |
{date}
|
58 |
-
|
59 |
To Whom It May Concern,
|
60 |
-
|
61 |
**Subject: Complaint Regarding {issue}**
|
62 |
-
|
63 |
I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
|
64 |
-
|
65 |
- Issue: {issue}
|
66 |
- Evidence: [Provide Evidence]
|
67 |
-
|
68 |
I kindly request you to take appropriate action as per the legal guidelines.
|
69 |
-
|
70 |
Yours sincerely,
|
71 |
[Your Name]
|
72 |
"""
|
73 |
return template.strip()
|
74 |
|
75 |
-
# Handle Legal Q&A
|
76 |
def handle_legal_query(query, language):
|
77 |
if language != "English":
|
78 |
query = translate(query, language, "English")
|
79 |
|
80 |
# Tokenize input
|
81 |
-
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
|
82 |
|
83 |
-
# Logits processing using Top-K
|
84 |
logits_processor = LogitsProcessorList([
|
85 |
-
TopKLogitsWarper(50)
|
86 |
-
TopPLogitsWarper(0.95) # Top-p nucleus sampling with p=0.95
|
87 |
])
|
88 |
|
89 |
# Generate Response
|
90 |
outputs = model.generate(**inputs, max_length=150, logits_processor=logits_processor)
|
91 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
92 |
|
93 |
-
# Reward Signal for PPO (basic reward)
|
94 |
-
reward = torch.tensor([1.0]) if "legal" in response.lower() else torch.tensor([-1.0])
|
95 |
-
|
96 |
-
# PPO Step (Reinforcement Learning)
|
97 |
-
ppo_trainer.step([query], [outputs], [reward])
|
98 |
-
|
99 |
if language != "English":
|
100 |
response = translate(response, "English", language)
|
101 |
|
@@ -121,5 +92,5 @@ with gr.Blocks(css=".container {width: 100%; max-width: 600px;}") as app:
|
|
121 |
submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
|
122 |
generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
|
123 |
|
124 |
-
# Launch the app
|
125 |
app.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, LogitsProcessorList, TopKLogitsWarper
|
3 |
import datetime
|
|
|
|
|
4 |
|
5 |
# Load FLAN-T5 for Legal Q&A
|
6 |
model_name = "google/flan-t5-small"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
9 |
+
|
10 |
+
# Translation Models (Dynamically Loaded)
|
11 |
+
translation_models = {
|
12 |
+
"en-hi": ("Helsinki-NLP/opus-mt-en-hi", "Helsinki-NLP/opus-mt-en-hi"),
|
13 |
+
"hi-en": ("Helsinki-NLP/opus-mt-hi-en", "Helsinki-NLP/opus-mt-hi-en")
|
14 |
+
}
|
15 |
+
|
16 |
+
def load_translation_model(src_lang, tgt_lang):
|
17 |
+
pair = f"{src_lang[:2]}-{tgt_lang[:2]}"
|
18 |
+
if pair in translation_models:
|
19 |
+
model_name, tokenizer_name = translation_models[pair]
|
20 |
+
trans_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
21 |
+
trans_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
22 |
+
return trans_model, trans_tokenizer
|
23 |
+
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Translation Function
|
26 |
def translate(text, src_lang, tgt_lang):
|
27 |
+
trans_model, trans_tokenizer = load_translation_model(src_lang, tgt_lang)
|
28 |
+
if trans_model is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
return "Translation for this pair not supported yet!"
|
30 |
+
|
31 |
+
inputs = trans_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
32 |
+
outputs = trans_model.generate(**inputs, max_length=256)
|
33 |
+
return trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
34 |
|
35 |
# Generate Complaint Template
|
36 |
def generate_complaint(issue):
|
|
|
39 |
[Your Name]
|
40 |
[Your Address]
|
41 |
{date}
|
|
|
42 |
To Whom It May Concern,
|
|
|
43 |
**Subject: Complaint Regarding {issue}**
|
|
|
44 |
I am writing to formally lodge a complaint regarding {issue}. The incident occurred on [Date/Location]. The specific details are as follows:
|
|
|
45 |
- Issue: {issue}
|
46 |
- Evidence: [Provide Evidence]
|
|
|
47 |
I kindly request you to take appropriate action as per the legal guidelines.
|
|
|
48 |
Yours sincerely,
|
49 |
[Your Name]
|
50 |
"""
|
51 |
return template.strip()
|
52 |
|
53 |
+
# Handle Legal Q&A
|
54 |
def handle_legal_query(query, language):
|
55 |
if language != "English":
|
56 |
query = translate(query, language, "English")
|
57 |
|
58 |
# Tokenize input
|
59 |
+
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=256)
|
60 |
|
61 |
+
# Logits processing using Top-K sampling
|
62 |
logits_processor = LogitsProcessorList([
|
63 |
+
TopKLogitsWarper(50) # Use Top-K only
|
|
|
64 |
])
|
65 |
|
66 |
# Generate Response
|
67 |
outputs = model.generate(**inputs, max_length=150, logits_processor=logits_processor)
|
68 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
if language != "English":
|
71 |
response = translate(response, "English", language)
|
72 |
|
|
|
92 |
submit_btn.click(handle_legal_query, inputs=[query, lang], outputs=output)
|
93 |
generate_btn.click(generate_complaint, inputs=issue, outputs=complaint_output)
|
94 |
|
95 |
+
# Launch the app
|
96 |
app.launch()
|