ashwath-vaithina-ibm commited on
Commit
d3ff0d7
·
verified ·
1 Parent(s): b23eb51

Update control/recommendation_handler.py

Browse files
Files changed (1) hide show
  1. control/recommendation_handler.py +129 -274
control/recommendation_handler.py CHANGED
@@ -29,17 +29,10 @@ import requests
29
  import json
30
  import math
31
  import re
32
- import warnings
33
  import pandas as pd
34
  import numpy as np
35
  from sklearn.metrics.pairwise import cosine_similarity
36
  import os
37
- #os.environ['TRANSFORMERS_CACHE'] ="./models/allmini/cache"
38
- import os.path
39
- from sentence_transformers import SentenceTransformer
40
- from umap import UMAP
41
- import tensorflow as tf
42
- from umap.parametric_umap import ParametricUMAP, load_ParametricUMAP
43
  from sentence_transformers import SentenceTransformer
44
 
45
  def populate_json(json_file_path = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json',
@@ -64,45 +57,31 @@ def populate_json(json_file_path = './prompt-sentences-main/prompt_sentences-all
64
  json_file = json_file_path
65
  if(os.path.isfile(existing_json_populated_file_path)):
66
  json_file = existing_json_populated_file_path
67
- try:
68
- prompt_json = json.load(open(json_file))
69
- json_error = None
70
- return prompt_json, json_error
71
- except Exception as e:
72
- json_error = e
73
- print(f'Error when loading sentences json file: {json_error}')
74
- prompt_json = None
75
- return prompt_json, json_error
76
-
77
- def query(texts, api_url, headers):
78
- """
79
- Function that requests embeddings for a given sentence.
80
-
81
- Args:
82
- texts: The sentence or entered prompt text.
83
- api_url: API url for HF request.
84
- headers: Content headers for HF request.
85
-
86
- Returns:
87
- A json with the sentence embeddings.
88
-
89
- Raises:
90
- Warning: Warns about sentences that have more
91
- than 256 words.
92
- """
93
- for t in texts:
94
- n_words = len(re.split(r"\s+", t))
95
- if(n_words > 256):
96
- # warning in case of prompts longer than 256 words
97
- warnings.warn("Warning: Sentence provided is longer than 256 words. Model all-MiniLM-L6-v2 expects sentences up to 256 words.")
98
- warnings.warn("Word count:{}".format(n_words))
99
- if('sentence-transformers/all-MiniLM-L6-v2' in api_url):
100
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
101
- out = model.encode(texts).tolist()
102
  else:
103
- response = requests.post(api_url, headers=headers, json={"inputs": texts, "options":{"wait_for_model":True}})
104
- out = response.json()
105
- return out
106
 
107
  def split_into_sentences(prompt):
108
  """
@@ -123,27 +102,6 @@ def split_into_sentences(prompt):
123
  sentences = re.split(r'(?<=[.!?]) +', prompt)
124
  return sentences
125
 
126
-
127
- def get_similarity(embedding1, embedding2):
128
- """
129
- Function that returns cosine similarity between
130
- two embeddings.
131
-
132
- Args:
133
- embedding1: first embedding.
134
- embedding2: second embedding.
135
-
136
- Returns:
137
- The similarity value.
138
-
139
- Raises:
140
- Nothing.
141
- """
142
- v1 = np.array( embedding1 ).reshape( 1, -1 )
143
- v2 = np.array( embedding2 ).reshape( 1, -1 )
144
- similarity = cosine_similarity( v1, v2 )
145
- return similarity[0, 0]
146
-
147
  def get_distance(embedding1, embedding2):
148
  """
149
  Function that returns euclidean distance between
@@ -181,17 +139,24 @@ def sort_by_similarity(e):
181
  """
182
  return e['similarity']
183
 
184
- def recommend_prompt(prompt, prompt_json, api_url, headers, add_lower_threshold = 0.3,
185
- add_upper_threshold = 0.5, remove_lower_threshold = 0.1,
186
- remove_upper_threshold = 0.5, model_id = 'sentence-transformers/all-minilm-l6-v2'):
 
 
 
 
 
 
 
187
  """
188
  Function that recommends prompts additions or removals.
189
 
190
  Args:
191
  prompt: The entered prompt text.
192
  prompt_json: Json file populated with embeddings.
193
- api_url: API url for HF request.
194
- headers: Content headers for HF request.
195
  add_lower_threshold: Lower threshold for sentence addition,
196
  the default value is 0.3.
197
  add_upper_threshold: Upper threshold for sentence addition,
@@ -200,7 +165,8 @@ def recommend_prompt(prompt, prompt_json, api_url, headers, add_lower_threshold
200
  the default value is 0.3.
201
  remove_upper_threshold: Upper threshold for sentence removal,
202
  the default value is 0.5.
203
- model_id: Id of the model, the default value is all-minilm-l6-v2 movel.
 
204
 
205
  Returns:
206
  Prompt values to add or remove.
@@ -208,22 +174,9 @@ def recommend_prompt(prompt, prompt_json, api_url, headers, add_lower_threshold
208
  Raises:
209
  Nothing.
210
  """
211
- if(model_id == 'baai/bge-large-en-v1.5' ):
212
- json_file = './prompt-sentences-main/prompt_sentences-bge-large-en-v1.5.json'
213
- umap_folder = './models/umap/BAAI/bge-large-en-v1.5/'
214
- elif(model_id == 'intfloat/multilingual-e5-large'):
215
- json_file = './prompt-sentences-main/prompt_sentences-multilingual-e5-large.json'
216
- umap_folder = './models/umap/intfloat/multilingual-e5-large/'
217
- else: # fall back to all-minilm as default
218
- json_file = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json'
219
- umap_folder = './models/umap/sentence-transformers/all-MiniLM-L6-v2/'
220
-
221
- # Loading the encoder and config separately due to a bug
222
- encoder = tf.keras.models.load_model( umap_folder )
223
- with open( f"{umap_folder}umap_config.json", "r" ) as f:
224
- config = json.load( f )
225
- umap_model = ParametricUMAP( encoder=encoder, **config )
226
- prompt_json = json.load( open( json_file ) )
227
 
228
  # Output initialization
229
  out, out['input'], out['add'], out['remove'] = {}, {}, {}, {}
@@ -236,63 +189,85 @@ def recommend_prompt(prompt, prompt_json, api_url, headers, add_lower_threshold
236
 
237
  # Recommendation of values to add to the current prompt
238
  # Using only the last sentence for the add recommendation
239
- input_embedding = query(input_sentences[-1], api_url, headers)
240
- for v in prompt_json['positive_values']:
241
- # Dealing with values without prompts and makinig sure they have the same dimensions
242
- if(len(v['centroid']) == len(input_embedding)):
243
- if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > add_lower_threshold):
244
- closer_prompt = -1
245
- for p in v['prompts']:
246
- d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
247
- # The sentence_threshold is being used as a ceiling meaning that for high similarities the sentence/value might already be presente in the prompt
248
- # So, we don't want to recommend adding something that is already there
249
- if(d_prompt > closer_prompt and d_prompt > add_lower_threshold and d_prompt < add_upper_threshold):
250
- closer_prompt = d_prompt
251
- items_to_add.append({
252
- 'value': v['label'],
253
- 'prompt': p['text'],
254
- 'similarity': d_prompt,
255
- 'x': p['x'],
256
- 'y': p['y']})
257
- out['add'] = items_to_add
258
 
259
- # Recommendation of values to remove from the current prompt
260
- i = 0
 
261
 
262
- # Recommendation of values to remove from the current prompt
263
- for sentence in input_sentences:
264
- input_embedding = query(sentence, api_url, headers) # remote
265
- # Obtaining XY coords for input sentences from a parametric UMAP model
266
- if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
267
- embeddings_umap = umap_model.transform(tf.expand_dims(pd.DataFrame(input_embedding), axis=0))
268
- input_items.append({
269
- 'sentence': sentence,
270
- 'x': str(embeddings_umap[0][0]),
271
- 'y': str(embeddings_umap[0][1])
272
- })
273
 
274
- for v in prompt_json['negative_values']:
275
  # Dealing with values without prompts and makinig sure they have the same dimensions
276
- if(len(v['centroid']) == len(input_embedding)):
277
- if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > remove_lower_threshold):
278
- closer_prompt = -1
279
- for p in v['prompts']:
280
- d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
281
- # A more restrict threshold is used here to prevent false positives
282
- # The sentence_threshold is being used to indicate that there must be a sentence in the prompt that is similiar to one of our adversarial prompts
283
- # So, yes, we want to recommend the removal of something adversarial we've found
284
- if(d_prompt > closer_prompt and d_prompt > remove_upper_threshold):
285
- closer_prompt = d_prompt
286
- items_to_remove.append({
287
- 'value': v['label'],
288
- 'sentence': sentence,
289
- 'sentence_index': i,
290
- 'closest_harmful_sentence': p['text'],
291
- 'similarity': d_prompt,
292
- 'x': p['x'],
293
- 'y': p['y']})
294
- out['remove'] = items_to_remove
295
- i += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  out['input'] = input_items
298
 
@@ -315,14 +290,19 @@ def recommend_prompt(prompt, prompt_json, api_url, headers, add_lower_threshold
315
  out['remove'] = out['remove'][0:5]
316
  return out
317
 
318
- def get_thresholds(prompts, prompt_json, api_url, headers, model_id = 'sentence-transformers/all-minilm-l6-v2'):
 
 
 
 
319
  """
320
  Function that recommends thresholds given an array of prompts.
321
 
322
  Args:
323
  prompts: The array with samples of prompts to be used in the system.
324
  prompt_json: Sentences to be forwarded to the recommendation endpoint.
325
- model_id: Id of the model, the default value is all-minilm-l6-v2 model.
 
326
 
327
  Returns:
328
  A map with thresholds for the sample prompts and the informed model.
@@ -330,14 +310,15 @@ def get_thresholds(prompts, prompt_json, api_url, headers, model_id = 'sentence-
330
  Raises:
331
  Nothing.
332
  """
333
- # Array limits for retrieving the thresholds
334
- # if( len( prompts ) < 10 or len( prompts ) > 30 ):
335
- # return -1
 
336
  add_similarities = []
337
  remove_similarities = []
338
 
339
  for p_id, p in enumerate(prompts):
340
- out = recommend_prompt(p, prompt_json, api_url, headers, 0, 1, 0, 0, model_id) # Wider possible range
341
 
342
  for r in out['add']:
343
  add_similarities.append(r['similarity'])
@@ -353,130 +334,4 @@ def get_thresholds(prompts, prompt_json, api_url, headers, model_id = 'sentence-
353
  thresholds['remove_lower_threshold'] = round(remove_similarities_df.describe([.1]).loc['10%', 'similarity'], 1)
354
  thresholds['remove_higher_threshold'] = round(remove_similarities_df.describe([.9]).loc['90%', 'similarity'], 1)
355
 
356
- return thresholds
357
-
358
- def recommend_local(prompt, prompt_json, model_id, model_path = './models/all-MiniLM-L6-v2/', add_lower_threshold = 0.3,
359
- add_upper_threshold = 0.5, remove_lower_threshold = 0.1,
360
- remove_upper_threshold = 0.5):
361
- """
362
- Function that recommends prompts additions or removals
363
- using a local model.
364
-
365
- Args:
366
- prompt: The entered prompt text.
367
- prompt_json: Json file populated with embeddings.
368
- model_id: Id of the local model.
369
- model_path: Path to the local model.
370
-
371
- Returns:
372
- Prompt values to add or remove.
373
-
374
- Raises:
375
- Nothing.
376
- """
377
- if(model_id == 'baai/bge-large-en-v1.5' ):
378
- json_file = './prompt-sentences-main/prompt_sentences-bge-large-en-v1.5.json'
379
- umap_folder = './models/umap/BAAI/bge-large-en-v1.5/'
380
- elif(model_id == 'intfloat/multilingual-e5-large'):
381
- json_file = './prompt-sentences-main/prompt_sentences-multilingual-e5-large.json'
382
- umap_folder = './models/umap/intfloat/multilingual-e5-large/'
383
- else: # fall back to all-minilm as default
384
- json_file = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json'
385
- umap_folder = './models/umap/sentence-transformers/all-MiniLM-L6-v2/'
386
-
387
- # Loading the encoder and config separately due to a bug
388
- encoder = tf.keras.models.load_model( umap_folder )
389
- with open( f"{umap_folder}umap_config.json", "r" ) as f:
390
- config = json.load( f )
391
- umap_model = ParametricUMAP( encoder=encoder, **config )
392
- prompt_json = json.load( open( json_file ) )
393
-
394
- # Output initialization
395
- out, out['input'], out['add'], out['remove'] = {}, {}, {}, {}
396
- input_items, items_to_add, items_to_remove = [], [], []
397
-
398
- # Spliting prompt into sentences
399
- input_sentences = split_into_sentences(prompt)
400
-
401
- # Recommendation of values to add to the current prompt
402
- # Using only the last sentence for the add recommendation
403
- model = SentenceTransformer(model_path)
404
- input_embedding = model.encode(input_sentences[-1])
405
-
406
- for v in prompt_json['positive_values']:
407
- # Dealing with values without prompts and makinig sure they have the same dimensions
408
- if(len(v['centroid']) == len(input_embedding)):
409
- if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > add_lower_threshold):
410
- closer_prompt = -1
411
- for p in v['prompts']:
412
- d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
413
- # The sentence_threshold is being used as a ceiling meaning that for high similarities the sentence/value might already be presente in the prompt
414
- # So, we don't want to recommend adding something that is already there
415
- if(d_prompt > closer_prompt and d_prompt > add_lower_threshold and d_prompt < add_upper_threshold):
416
- closer_prompt = d_prompt
417
- items_to_add.append({
418
- 'value': v['label'],
419
- 'prompt': p['text'],
420
- 'similarity': d_prompt,
421
- 'x': p['x'],
422
- 'y': p['y']})
423
- out['add'] = items_to_add
424
-
425
- # Recommendation of values to remove from the current prompt
426
- i = 0
427
-
428
- # Recommendation of values to remove from the current prompt
429
- for sentence in input_sentences:
430
- input_embedding = model.encode(sentence) # local
431
- # Obtaining XY coords for input sentences from a parametric UMAP model
432
- if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
433
- embeddings_umap = umap_model.transform(tf.expand_dims(pd.DataFrame(input_embedding), axis=0))
434
- input_items.append({
435
- 'sentence': sentence,
436
- 'x': str(embeddings_umap[0][0]),
437
- 'y': str(embeddings_umap[0][1])
438
- })
439
-
440
- for v in prompt_json['negative_values']:
441
- # Dealing with values without prompts and makinig sure they have the same dimensions
442
- if(len(v['centroid']) == len(input_embedding)):
443
- if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > remove_lower_threshold):
444
- closer_prompt = -1
445
- for p in v['prompts']:
446
- d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
447
- # A more restrict threshold is used here to prevent false positives
448
- # The sentence_threhold is being used to indicate that there must be a sentence in the prompt that is similiar to one of our adversarial prompts
449
- # So, yes, we want to recommend the revolval of something adversarial we've found
450
- if(d_prompt > closer_prompt and d_prompt > remove_upper_threshold):
451
- closer_prompt = d_prompt
452
- items_to_remove.append({
453
- 'value': v['label'],
454
- 'sentence': sentence,
455
- 'sentence_index': i,
456
- 'closest_harmful_sentence': p['text'],
457
- 'similarity': d_prompt,
458
- 'x': p['x'],
459
- 'y': p['y']})
460
- out['remove'] = items_to_remove
461
- i += 1
462
-
463
- out['input'] = input_items
464
-
465
- out['add'] = sorted(out['add'], key=sort_by_similarity, reverse=True)
466
- values_map = {}
467
- for item in out['add'][:]:
468
- if(item['value'] in values_map):
469
- out['add'].remove(item)
470
- else:
471
- values_map[item['value']] = item['similarity']
472
- out['add'] = out['add'][0:5]
473
-
474
- out['remove'] = sorted(out['remove'], key=sort_by_similarity, reverse=True)
475
- values_map = {}
476
- for item in out['remove'][:]:
477
- if(item['value'] in values_map):
478
- out['remove'].remove(item)
479
- else:
480
- values_map[item['value']] = item['similarity']
481
- out['remove'] = out['remove'][0:5]
482
- return out
 
29
  import json
30
  import math
31
  import re
 
32
  import pandas as pd
33
  import numpy as np
34
  from sklearn.metrics.pairwise import cosine_similarity
35
  import os
 
 
 
 
 
 
36
  from sentence_transformers import SentenceTransformer
37
 
38
  def populate_json(json_file_path = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json',
 
57
  json_file = json_file_path
58
  if(os.path.isfile(existing_json_populated_file_path)):
59
  json_file = existing_json_populated_file_path
60
+ prompt_json = json.load(open(json_file))
61
+ return prompt_json
62
+
63
+ def get_embedding_func(inference = 'huggingface', **kwargs):
64
+ if inference == 'local':
65
+ if 'model_id' not in kwargs:
66
+ raise TypeError("Missing required argument: model_id")
67
+ model = SentenceTransformer(kwargs['model_id'])
68
+
69
+ def embedding_fn(texts):
70
+ return model.encode(texts).tolist()
71
+
72
+ elif inference == 'huggingface':
73
+ if 'api_url' not in kwargs:
74
+ raise TypeError("Missing required argument: api_url")
75
+ if 'headers' not in kwargs:
76
+ raise TypeError("Missing required argument: headers")
77
+
78
+ def embedding_fn(texts):
79
+ response = requests.post(kwargs['api_url'], headers=kwargs['headers'], json={"inputs": texts, "options":{"wait_for_model":True}})
80
+ return response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  else:
82
+ raise ValueError(f"Inference type {inference} is not supported. Please choose one of ['local', 'huggingface'].")
83
+
84
+ return embedding_fn
85
 
86
  def split_into_sentences(prompt):
87
  """
 
102
  sentences = re.split(r'(?<=[.!?]) +', prompt)
103
  return sentences
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def get_distance(embedding1, embedding2):
106
  """
107
  Function that returns euclidean distance between
 
139
  """
140
  return e['similarity']
141
 
142
+ def recommend_prompt(
143
+ prompt,
144
+ prompt_json,
145
+ embedding_fn = None,
146
+ add_lower_threshold = 0.3,
147
+ add_upper_threshold = 0.5,
148
+ remove_lower_threshold = 0.1,
149
+ remove_upper_threshold = 0.5,
150
+ umap_model = None
151
+ ):
152
  """
153
  Function that recommends prompts additions or removals.
154
 
155
  Args:
156
  prompt: The entered prompt text.
157
  prompt_json: Json file populated with embeddings.
158
+ embedding_fn: Embedding function to convert prompt sentences into embeddings.
159
+ If None, uses all-MiniLM-L6-v2 run locally.
160
  add_lower_threshold: Lower threshold for sentence addition,
161
  the default value is 0.3.
162
  add_upper_threshold: Upper threshold for sentence addition,
 
165
  the default value is 0.3.
166
  remove_upper_threshold: Upper threshold for sentence removal,
167
  the default value is 0.5.
168
+ umap_model: Umap model used for visualization.
169
+ If None, the projected embeddings of input sentences will not be returned.
170
 
171
  Returns:
172
  Prompt values to add or remove.
 
174
  Raises:
175
  Nothing.
176
  """
177
+ if embedding_fn is None:
178
+ # Use all-MiniLM-L6-v2 locally by default
179
+ embedding_fn = get_embedding_func('local', model_id='sentence-transformers/all-MiniLM-L6-v2')
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  # Output initialization
182
  out, out['input'], out['add'], out['remove'] = {}, {}, {}, {}
 
189
 
190
  # Recommendation of values to add to the current prompt
191
  # Using only the last sentence for the add recommendation
192
+ input_embedding = embedding_fn(input_sentences[-1])
193
+ input_embedding = np.array(input_embedding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ sentence_embeddings = np.array(
196
+ [v['centroid'] for v in prompt_json['positive_values']]
197
+ )
198
 
199
+ similarities_positive_sent = cosine_similarity(np.expand_dims(input_embedding, axis=0), sentence_embeddings)[0, :]
 
 
 
 
 
 
 
 
 
 
200
 
201
+ for value_idx, v in enumerate(prompt_json['positive_values']):
202
  # Dealing with values without prompts and makinig sure they have the same dimensions
203
+ if(len(v['centroid']) != len(input_embedding)):
204
+ continue
205
+
206
+ if(similarities_positive_sent[value_idx] < add_lower_threshold):
207
+ continue
208
+
209
+ value_sents_similarity = cosine_similarity(
210
+ np.expand_dims(input_embedding, axis=0),
211
+ np.array([p['embedding'] for p in v['prompts']])
212
+ )[0, :]
213
+ closer_prompt_idxs = np.nonzero((add_lower_threshold < value_sents_similarity) & (value_sents_similarity < add_upper_threshold))[0]
214
+
215
+ for idx in closer_prompt_idxs:
216
+ items_to_add.append({
217
+ 'value': v['label'],
218
+ 'prompt': v['prompts'][idx]['text'],
219
+ 'similarity': value_sents_similarity[idx],
220
+ 'x': v['prompts'][idx]['x'],
221
+ 'y': v['prompts'][idx]['y']
222
+ })
223
+ out['add'] = items_to_add
224
+
225
+ inp_sentence_embeddings = np.array([embedding_fn(sent) for sent in input_sentences])
226
+ pairwise_similarities = cosine_similarity(
227
+ inp_sentence_embeddings,
228
+ np.array([v['centroid'] for v in prompt_json['negative_values']])
229
+ )
230
+
231
+ # Recommendation of values to remove from the current prompt
232
+ for sent_idx, sentence in enumerate(input_sentences):
233
+ input_embedding = inp_sentence_embeddings[sent_idx]
234
+ if umap_model:
235
+ # Obtaining XY coords for input sentences from a parametric UMAP model
236
+ if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
237
+ embeddings_umap = umap_model.transform(np.expand_dims(pd.DataFrame(input_embedding).squeeze(), axis=0))
238
+ input_items.append({
239
+ 'sentence': sentence,
240
+ 'x': str(embeddings_umap[0][0]),
241
+ 'y': str(embeddings_umap[0][1])
242
+ })
243
+
244
+ for value_idx, v in enumerate(prompt_json['negative_values']):
245
+ # Dealing with values without prompts and making sure they have the same dimensions
246
+ if(len(v['centroid']) != len(input_embedding)):
247
+ continue
248
+ if(pairwise_similarities[sent_idx][value_idx] < remove_lower_threshold):
249
+ continue
250
+
251
+ # A more restrict threshold is used here to prevent false positives
252
+ # The sentence_threshold is being used to indicate that there must be a sentence in the prompt that is similiar to one of our adversarial prompts
253
+ # So, yes, we want to recommend the removal of something adversarial we've found
254
+ value_sents_similarity = cosine_similarity(
255
+ np.expand_dims(input_embedding, axis=0),
256
+ np.array([p['embedding'] for p in v['prompts']])
257
+ )[0, :]
258
+ closer_prompt_idxs = np.nonzero(value_sents_similarity > remove_upper_threshold)[0]
259
+
260
+ for idx in closer_prompt_idxs:
261
+ items_to_remove.append({
262
+ 'value': v['label'],
263
+ 'sentence': sentence,
264
+ 'sentence_index': sent_idx,
265
+ 'closest_harmful_sentence': v['prompts'][idx]['text'],
266
+ 'similarity': value_sents_similarity[idx],
267
+ 'x': v['prompts'][idx]['x'],
268
+ 'y': v['prompts'][idx]['y']
269
+ })
270
+ out['remove'] = items_to_remove
271
 
272
  out['input'] = input_items
273
 
 
290
  out['remove'] = out['remove'][0:5]
291
  return out
292
 
293
+ def get_thresholds(
294
+ prompts,
295
+ prompt_json,
296
+ embedding_fn = None,
297
+ ):
298
  """
299
  Function that recommends thresholds given an array of prompts.
300
 
301
  Args:
302
  prompts: The array with samples of prompts to be used in the system.
303
  prompt_json: Sentences to be forwarded to the recommendation endpoint.
304
+ embedding_fn: Embedding function to convert prompt sentences into embeddings.
305
+ If None, uses all-MiniLM-L6-v2 run locally.
306
 
307
  Returns:
308
  A map with thresholds for the sample prompts and the informed model.
 
310
  Raises:
311
  Nothing.
312
  """
313
+
314
+ if embedding_fn is None:
315
+ embedding_fn = get_embedding_func('local', model_id='sentence-transformers/all-MiniLM-L6-v2')
316
+
317
  add_similarities = []
318
  remove_similarities = []
319
 
320
  for p_id, p in enumerate(prompts):
321
+ out = recommend_prompt(p, prompt_json, embedding_fn, 0, 1, 0, 0, None) # Wider possible range
322
 
323
  for r in out['add']:
324
  add_similarities.append(r['similarity'])
 
334
  thresholds['remove_lower_threshold'] = round(remove_similarities_df.describe([.1]).loc['10%', 'similarity'], 1)
335
  thresholds['remove_higher_threshold'] = round(remove_similarities_df.describe([.9]).loc['90%', 'similarity'], 1)
336
 
337
+ return thresholds