eswardivi committed
Commit 3592f5f · verified · 1 Parent(s): 5a94d47

Updated to ModernBERT for similarity comparison

Files changed (1)
  1. app.py +84 -69
app.py CHANGED
@@ -1,26 +1,14 @@
 import gradio as gr
 import torch
 from transformers import (
-    AutoModelForCausalLM,
+    AutoModel,
     AutoTokenizer,
-    TextIteratorStreamer,
 )
 import os
 from threading import Thread
 import spaces
 import time
 
-token = os.environ["HF_TOKEN"]
-
-
-model = AutoModelForCausalLM.from_pretrained(
-    "microsoft/Phi-3-mini-4k-instruct", token=token, trust_remote_code=True
-)
-tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", token=token)
-terminators = [
-    tok.eos_token_id,
-]
-
 if torch.cuda.is_available():
     device = torch.device("cuda")
     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
@@ -28,70 +16,97 @@ else:
     device = torch.device("cpu")
     print("Using CPU")
 
-model = model.to(device)
-# Dispatch Errors
-
-
-@spaces.GPU(duration=60)
-def chat(message, history, temperature, do_sample, max_tokens):
-    chat = []
-    for item in history:
-        chat.append({"role": "user", "content": item[0]})
-        if item[1] is not None:
-            chat.append({"role": "assistant", "content": item[1]})
-    chat.append({"role": "user", "content": message})
-    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-    model_inputs = tok([messages], return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(
-        tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-    )
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-
-    if temperature == 0:
-        generate_kwargs["do_sample"] = False
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_text = ""
-    for new_text in streamer:
-        partial_text += new_text
-        yield partial_text
-
-
-    yield partial_text
-
-
-demo = gr.ChatInterface(
-    fn=chat,
-    examples=[["Write me a poem about Machine Learning."]],
-    # multimodal=False,
-    additional_inputs_accordion=gr.Accordion(
-        label="⚙️ Parameters", open=False, render=False
-    ),
-    additional_inputs=[
-        gr.Slider(
-            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
-        ),
-        gr.Checkbox(label="Sampling", value=True),
-        gr.Slider(
-            minimum=128,
-            maximum=4096,
-            step=1,
-            value=512,
-            label="Max new tokens",
-            render=False,
-        ),
-    ],
-    stop_btn="Stop Generation",
-    title="Chat With LLMs",
-    description="Now Running [microsoft/Phi-3-mini-4k-instruct](https://huggingface.com/microsoft/Phi-3-mini-4k-instruct)"
-)
-demo.launch()
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = (
+        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    )
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+        input_mask_expanded.sum(1), min=1e-9
+    )
+
+
+def cls_pooling(model_output):
+    return model_output[0][:, 0]
+
+
+@spaces.GPU
+def get_embedding(text, use_mean_pooling, model_id):
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float16)
+
+    model = model.to(device)
+    inputs = tokenizer(
+        text, return_tensors="pt", padding=True, truncation=True, max_length=512
+    )
+    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
+    with torch.no_grad():
+        model_output = model(**inputs)
+    if use_mean_pooling:
+        return mean_pooling(model_output, inputs["attention_mask"])
+    return cls_pooling(model_output)
+
+
+def get_similarity(text1, text2, pooling_method, model_id):
+    use_mean_pooling = pooling_method == "Use Mean Pooling"
+    embedding1 = get_embedding(text1, use_mean_pooling, model_id)
+    embedding2 = get_embedding(text2, use_mean_pooling, model_id)
+    return torch.nn.functional.cosine_similarity(embedding1, embedding2).item()
+
+
+gr.Interface(
+    get_similarity,
+    [
+        gr.Textbox(lines=7, label="Text 1"),
+        gr.Textbox(lines=7, label="Text 2"),
+        gr.Dropdown(
+            choices=["Use Mean Pooling", "Use CLS"],
+            value="Use Mean Pooling",
+            label="Pooling Method",
+            info="Mean Pooling: Averages all token embeddings (better for semantic similarity)\nCLS Pooling: Uses only the [CLS] token embedding (faster, might miss context)",
+        ),
+        gr.Dropdown(
+            choices=[
+                "tasksource/ModernBERT-base-embed",
+                "tasksource/ModernBERT-base-nli",
+                "joe32140/ModernBERT-large-msmarco",
+                "answerdotai/ModernBERT-large",
+                "answerdotai/ModernBERT-base",
+            ],
+            value="answerdotai/ModernBERT-large",
+            label="Model",
+            info="Choose between the variants of ModernBERT \nMight take a few seconds to load the model",
+        ),
+    ],
+    gr.Textbox(label="Similarity"),
+    title="ModernBERT Similarity Demo",
+    description="Compute the similarity between two texts using ModernBERT. Choose between different pooling strategies for embedding generation.",
+    examples=[
+        [
+            "The quick brown fox jumps over the lazy dog",
+            "A swift brown fox leaps above a sleeping canine",
+            "Use Mean Pooling",
+            "answerdotai/ModernBERT-large"
+        ],
+        [
+            "I love programming in Python",
+            "I hate coding with Python",
+            "Use Mean Pooling",
+            "answerdotai/ModernBERT-large"
+        ],
+        [
+            "The weather is beautiful today",
+            "Machine learning models are improving rapidly",
+            "Use Mean Pooling",
+            "answerdotai/ModernBERT-large"
+        ],
+        [
+            "def calculate_sum(a, b):\n    return a + b",
+            "def add_numbers(x, y):\n    result = x + y\n    return result",
+            "Use Mean Pooling",
+            "answerdotai/ModernBERT-large"
+        ]
+    ]
+).launch(share=True)
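
The pooling dropdown's help text states the trade-off in prose; it is easiest to see on a toy batch. A minimal, model-free sketch of the two strategies the new app.py implements (the fake embeddings and mask below are purely illustrative):

```python
import torch

# Fake encoder output for one sequence of 3 tokens, hidden size 2:
# [CLS], one real token, one padding token.
token_embeddings = torch.tensor([[[1.0, 0.0],
                                  [3.0, 4.0],
                                  [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])  # padding masked out

# Mean pooling: average only over non-padding tokens.
mask = attention_mask.unsqueeze(-1).float()
mean_pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

# CLS pooling: take the first token's embedding.
cls_pooled = token_embeddings[:, 0]

print(mean_pooled)  # tensor([[2., 2.]]) -- the padding row is ignored
print(cls_pooled)   # tensor([[1., 0.]])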
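To sanity-check the new similarity path without launching the Gradio UI, the same embed, pool, and compare flow can be run directly. A minimal sketch, assuming a transformers version with ModernBERT support (4.48 or later); the model id and sentences here are illustrative:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"  # smaller sibling of the Space's default
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()

def embed(text):
    # Tokenize, run the encoder, and mean-pool over the attention mask,
    # mirroring mean_pooling() in app.py.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        token_embeddings = model(**inputs)[0]  # (1, seq_len, hidden)
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    return (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

a = embed("The quick brown fox jumps over the lazy dog")
b = embed("A swift brown fox leaps above a sleeping canine")
print(torch.nn.functional.cosine_similarity(a, b).item())  # paraphrases score high
```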
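One note on the new code: get_embedding re-instantiates the tokenizer and model on every call, so each comparison loads the selected model twice. If that latency matters, a hedged option is to memoize loading per model id; functools.lru_cache is standard library, though whether a warm cache interacts well with the Space's @spaces.GPU allocation is an assumption worth testing:

```python
from functools import lru_cache

import torch
from transformers import AutoModel, AutoTokenizer

@lru_cache(maxsize=2)  # keep the two most recently used model ids warm
def load(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float16)
    return tokenizer, model.eval()
```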