ZivK committed
Commit 1e6d148 · 1 Parent(s): 3e0c41e

Changed to a chat app

Files changed (1)
  1. app.py +44 -17
app.py CHANGED
@@ -4,43 +4,70 @@ import gradio as gr
from model import SmolLM
from huggingface_hub import hf_hub_download

+
+device = "mps" if torch.backends.mps.is_available() else "cpu"
hf_token = os.environ.get("HF_TOKEN")
repo_id = "ZivK/smollm2-end-of-sentence"
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt"
}
+label_map = {0: "Incomplete", 1: "Complete"}
models = {}
for model_name, filename in model_options.items():
    print(f"Loading {model_name} ...")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
-    models[model_name] = SmolLM.load_from_checkpoint(checkpoint_path)
+    models[model_name] = SmolLM.load_from_checkpoint(checkpoint_path).to(device)
    models[model_name].eval()


def classify_sentence(sentence, model_choice):
    model = models[model_choice]
-    inputs = model.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
+    inputs = model.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(inputs)
    confidence = torch.sigmoid(logits).item() * 100
-    confidence_to_display = confidence if confidence > 50.0 else 100 - confidence
-    label = "Complete" if confidence > 50.0 else "Incomplete"
-    return f"{label} Sentence\nConfidence: {confidence_to_display:.2f}"
+    predicted_class = 1 if confidence > 50.0 else 0
+    return label_map[predicted_class], confidence
+
+
+def chatbot_reply(history, user_input, model_choice):
+    classification, confidence = classify_sentence(user_input, model_choice)
+
+    if classification == "Incomplete":
+        bot_message = "It looks like you may have stopped mid-sentence. Please finish your thought! Confidence: " + \
+                      f"{(100.0-confidence):.2f}"
+    else:
+        bot_message = f"Thank you for sharing a complete sentence! Confidence: {confidence:.2f}"
+
+    # Append the user message and bot response to the conversation history
+    history.append((user_input, bot_message))
+    return history, ""
+
+
-# Create the Gradio interface
-interface = gr.Interface(
-    fn=classify_sentence,
-    inputs=[
-        gr.Textbox(lines=1, placeholder="Enter your sentence here..."),
-        gr.Dropdown(choices=list(model_options.keys()), label="Select Model")
-    ],
-    outputs="text",
-    title="Complete Sentence Classifier",
-    description="## Enter a sentence to determine if it's complete or if it might be cut off"
-)
+with gr.Blocks() as demo:
+    gr.Markdown(
+        "## Sentence Completeness Chatbot\nType a message and see if the model thinks it’s complete or incomplete!")
+
+    # 3. Create a stateful Chatbot plus an input textbox
+    chatbot = gr.Chatbot(label="Chat with Me!")
+    state = gr.State([])  # This will store the conversation history

+    with gr.Row():
+        user_input = gr.Textbox(show_label=False, placeholder="Type your sentence here...")
+        submit_btn = gr.Button("Submit")
+    with gr.Row():
+        model_input = gr.Dropdown(choices=list(model_options.keys()), label="Select Model")

+    # 4. Bind the chatbot function
+    submit_btn.click(fn=chatbot_reply,
+                     inputs=[state, user_input, model_input],
+                     outputs=[chatbot, user_input])

+    # We also want pressing Enter to do the same as clicking submit
+    user_input.submit(fn=chatbot_reply,
+                      inputs=[state, user_input, model_input],
+                      outputs=[chatbot, user_input])

# Launch the demo
-interface.launch()
+demo.launch()
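
A note on the new wiring: chatbot_reply threads the conversation history through gr.State as a list of (user, bot) tuples, and its second return value ("") clears the input textbox because the callback's outputs are bound to [chatbot, user_input]. Below is a minimal, self-contained sketch of that flow with the model stubbed out so it runs without the SmolLM checkpoints; fake_classify, reply, the punctuation heuristic, and the confidence numbers are illustrative assumptions, not part of this commit.

# Sketch only: stands in for classify_sentence so the chat flow can be exercised offline.
def fake_classify(sentence: str) -> tuple[str, float]:
    # Confidence mimics the app's convention: percent probability of "Complete".
    complete = sentence.rstrip().endswith((".", "!", "?"))
    confidence = 87.5 if complete else 12.5
    return ("Complete" if complete else "Incomplete"), confidence


def reply(history: list[tuple[str, str]], user_input: str) -> tuple[list[tuple[str, str]], str]:
    # Same shape as chatbot_reply: append a (user, bot) tuple, return history plus "" to clear the box.
    label, confidence = fake_classify(user_input)
    if label == "Incomplete":
        bot_message = f"It looks like you may have stopped mid-sentence. Confidence: {100.0 - confidence:.2f}"
    else:
        bot_message = f"Thank you for sharing a complete sentence! Confidence: {confidence:.2f}"
    history.append((user_input, bot_message))
    return history, ""


history: list[tuple[str, str]] = []
history, _ = reply(history, "The cat sat on the")
history, _ = reply(history, "The cat sat on the mat.")
print(history)  # two (user, bot) turns, the tuple format the Chatbot component displays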