naa142 committed
Commit 896157d · verified · 1 Parent(s): 4982e19

Update app.py

Files changed (1): app.py +80 -80
app.py CHANGED
@@ -1,80 +1,80 @@
- # app.py
-
- import streamlit as st
- import torch
- import torch.nn as nn
- from transformers import AutoTokenizer, AutoModel
-
- # ✅ Device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # ✅ Load Tokenizer
- tokenizer = AutoTokenizer.from_pretrained("./final_deberta_model")
-
- # ✅ Define Model
- class ScoringModel(nn.Module):
-     def __init__(self, base_model_path="./final_deberta_model", dropout_rate=0.242):
-         super().__init__()
-         self.base = AutoModel.from_pretrained(base_model_path)
-         self.base.gradient_checkpointing_enable()
-         self.dropout1 = nn.Dropout(dropout_rate)
-         self.dropout2 = nn.Dropout(dropout_rate)
-         self.dropout3 = nn.Dropout(dropout_rate)
-         self.classifier = nn.Linear(self.base.config.hidden_size, 1)
-
-     def forward(self, input_ids, attention_mask):
-         hidden = self.base(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
-         logits = (self.classifier(self.dropout1(hidden)) +
-                   self.classifier(self.dropout2(hidden)) +
-                   self.classifier(self.dropout3(hidden))) / 3
-         return logits
-
- # ✅ Instantiate and Load
- model = ScoringModel()
- model.load_state_dict(torch.load("./final_deberta_model/scoring_model.pt", map_location=device))
- model.to(device)
- model.eval()
-
- # ✅ Prediction function
- def predict(prompt, response_a, response_b):
-     model.eval()
-     with torch.no_grad():
-         text_a = f"Prompt: {prompt} [SEP] {response_a}"
-         text_b = f"Prompt: {prompt} [SEP] {response_b}"
-
-         encoded_a = tokenizer(text_a, return_tensors='pt', padding="max_length", truncation=True, max_length=186)
-         encoded_b = tokenizer(text_b, return_tensors='pt', padding="max_length", truncation=True, max_length=186)
-
-         inputs_a = {k: v.to(device) for k, v in encoded_a.items()}
-         inputs_b = {k: v.to(device) for k, v in encoded_b.items()}
-
-         score_a = model(**inputs_a).squeeze()
-         score_b = model(**inputs_b).squeeze()
-
-         prob_a = torch.sigmoid(score_a).item()
-         prob_b = torch.sigmoid(score_b).item()
-
-         return prob_a, prob_b
-
- # ✅ Streamlit App
- st.title("🔍 Fine-Tuned DeBERTa-v3-small: Response Quality Evaluator")
-
- prompt = st.text_area("Enter your prompt:", height=100)
- response_a = st.text_area("Enter Response A:", height=100)
- response_b = st.text_area("Enter Response B:", height=100)
-
- if st.button("Predict Better Response"):
-     if prompt and response_a and response_b:
-         prob_a, prob_b = predict(prompt, response_a, response_b)
-
-         st.write(f"🔵 **Response A Probability:** {prob_a:.4f}")
-         st.write(f"🟠 **Response B Probability:** {prob_b:.4f}")
-
-         if prob_b > prob_a:
-             st.success("✅ Model predicts: **Response B** is better!")
-         else:
-             st.success("✅ Model predicts: **Response A** is better!")
-     else:
-         st.warning("⚠️ Please fill in all fields before predicting.")
-
-
+ # app.py
+
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, PreTrainedTokenizerFast
+
+ # ✅ Device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ✅ Load Tokenizer
+ tokenizer = PreTrainedTokenizerFast(tokenizer_file="final_deberta_model/tokenizer.json")
+
+ # ✅ Define Model
+ class ScoringModel(nn.Module):
+     def __init__(self, base_model_name="microsoft/deberta-v3-small", dropout_rate=0.242):
+         super().__init__()
+         self.base = AutoModel.from_pretrained(base_model_name)
+         self.base.gradient_checkpointing_enable()
+         self.dropout1 = nn.Dropout(dropout_rate)
+         self.dropout2 = nn.Dropout(dropout_rate)
+         self.dropout3 = nn.Dropout(dropout_rate)
+         self.classifier = nn.Linear(self.base.config.hidden_size, 1)
+
+     def forward(self, input_ids, attention_mask):
+         hidden = self.base(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
+         logits = (self.classifier(self.dropout1(hidden)) +
+                   self.classifier(self.dropout2(hidden)) +
+                   self.classifier(self.dropout3(hidden))) / 3
+         return logits
+
+ # ✅ Instantiate and Load
+ model = ScoringModel()
+ model.load_state_dict(torch.load("final_deberta_model/scoring_model.pt", map_location=device))
+ model.to(device)
+ model.eval()
+
+ # ✅ Prediction function
+ def predict(prompt, response_a, response_b):
+     model.eval()
+     with torch.no_grad():
+         text_a = f"Prompt: {prompt} [SEP] {response_a}"
+         text_b = f"Prompt: {prompt} [SEP] {response_b}"
+
+         encoded_a = tokenizer(text_a, return_tensors='pt', padding="max_length", truncation=True, max_length=186)
+         encoded_b = tokenizer(text_b, return_tensors='pt', padding="max_length", truncation=True, max_length=186)
+
+         inputs_a = {k: v.to(device) for k, v in encoded_a.items()}
+         inputs_b = {k: v.to(device) for k, v in encoded_b.items()}
+
+         score_a = model(**inputs_a).squeeze()
+         score_b = model(**inputs_b).squeeze()
+
+         prob_a = torch.sigmoid(score_a).item()
+         prob_b = torch.sigmoid(score_b).item()
+
+         return prob_a, prob_b
+
+ # ✅ Streamlit App
+ st.title("🔍 Fine-Tuned DeBERTa-v3-small: Response Quality Evaluator")
+
+ prompt = st.text_area("Enter your prompt:", height=100)
+ response_a = st.text_area("Enter Response A:", height=100)
+ response_b = st.text_area("Enter Response B:", height=100)
+
+ if st.button("Predict Better Response"):
+     if prompt and response_a and response_b:
+         prob_a, prob_b = predict(prompt, response_a, response_b)
+
+         st.write(f"🔵 **Response A Probability:** {prob_a:.4f}")
+         st.write(f"🟠 **Response B Probability:** {prob_b:.4f}")
+
+         if prob_b > prob_a:
+             st.success("✅ Model predicts: **Response B** is better!")
+         else:
+             st.success("✅ Model predicts: **Response A** is better!")
+     else:
+         st.warning("⚠️ Please fill in all fields before predicting.")
+
+
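Review note on this change: the tokenizer is now built directly from tokenizer.json via PreTrainedTokenizerFast instead of AutoTokenizer.from_pretrained, and the base encoder is instantiated from the microsoft/deberta-v3-small hub checkpoint before the fine-tuned state dict is loaded over it (load_state_dict then replaces those weights, so the hub download effectively supplies only the architecture and config). One caveat worth flagging: constructing PreTrainedTokenizerFast from tokenizer_file alone does not restore the special-token settings that AutoTokenizer reads from tokenizer_config.json, so padding="max_length" can raise "Asking to pad but the tokenizer does not have a padding token". Below is a minimal defensive-loading sketch, not part of the commit; the explicit token strings assume the exported tokenizer.json kept DeBERTa-v3's standard [CLS]/[SEP]/[PAD]/[UNK] vocabulary.

# A hedged sketch, not part of the commit: load the tokenizer the way the new
# app.py does, then restore special tokens if the wrapper did not pick them up.
# Assumption: final_deberta_model/tokenizer.json was exported from a DeBERTa-v3
# tokenizer whose vocabulary already contains these token strings.
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(tokenizer_file="final_deberta_model/tokenizer.json")

# tokenizer_file carries the vocab and normalizer, but not the special-token
# attributes; set them explicitly so padding/truncation behave as with AutoTokenizer.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({
        "pad_token": "[PAD]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]",
        "unk_token": "[UNK]",
    })

enc = tokenizer(
    "Prompt: example [SEP] candidate response",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=186,
)
print(enc["input_ids"].shape)  # expected: torch.Size([1, 186])

Separately, the three-way dropout average in ScoringModel.forward is multi-sample dropout: it only changes behavior during training. Once model.eval() puts the nn.Dropout layers in identity mode, the three terms are equal and the forward pass reduces to a single classifier call, so inference results are unaffected by the averaging.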