ashwath-vaithina-ibm commited on
Commit
34b9503
·
verified ·
1 Parent(s): d3ff0d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -37,6 +37,7 @@ import logging
37
  import uuid
38
  import json
39
  import os
 
40
 
41
  app = Flask(__name__)
42
 
@@ -66,16 +67,31 @@ def index():
66
  @app.route("/recommend", methods=['GET'])
67
  @cross_origin()
68
  def recommend():
69
- user_ip = request.remote_addr
70
- hf_token, hf_url = get_credentials.get_hf_credentials()
71
- api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
72
  prompt_json = recommendation_handler.populate_json()
73
  args = request.args
 
74
  prompt = args.get("prompt")
75
- recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json,
76
- api_url, headers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
78
  logger.info(f'RECOMMEND ROUTE - request: {prompt} response: {recommendation_json}')
 
79
  return recommendation_json
80
 
81
  @app.route("/get_thresholds", methods=['GET'])
@@ -84,24 +100,27 @@ def get_thresholds():
84
  hf_token, hf_url = get_credentials.get_hf_credentials()
85
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
86
  prompt_json = recommendation_handler.populate_json()
87
- model_id = 'sentence-transformers/all-minilm-l6-v2'
88
  args = request.args
89
- #print("args list = ", args)
90
  prompt = args.get("prompt")
91
- thresholds_json = recommendation_handler.get_thresholds(prompt, prompt_json, api_url,
92
- headers, model_id)
93
  return thresholds_json
94
 
95
  @app.route("/recommend_local", methods=['GET'])
96
  @cross_origin()
97
  def recommend_local():
98
- model_id, model_path = save_model.save_model()
99
- prompt_json = recommendation_handler.populate_json()
100
  args = request.args
101
  print("args list = ", args)
102
  prompt = args.get("prompt")
103
- local_recommendation_json = recommendation_handler.recommend_local(prompt, prompt_json,
104
- model_id, model_path)
 
 
 
 
 
 
105
  return local_recommendation_json
106
 
107
  @app.route("/log", methods=['POST'])
@@ -149,5 +168,5 @@ def demo_inference():
149
  return "Model Inference failed.", 500
150
 
151
  if __name__=='__main__':
152
- debug_mode = os.getenv('FLASK_DEBUG', 'True').lower() in ['true', '1', 't']
153
  app.run(host='0.0.0.0', port='8080', debug=debug_mode)
 
37
  import uuid
38
  import json
39
  import os
40
+ import pickle
41
 
42
  app = Flask(__name__)
43
 
 
67
  @app.route("/recommend", methods=['GET'])
68
  @cross_origin()
69
  def recommend():
70
+ model_id, _ =save_model.save_model()
 
 
71
  prompt_json = recommendation_handler.populate_json()
72
  args = request.args
73
+ print("args list = ", args)
74
  prompt = args.get("prompt")
75
+
76
+ umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
77
+ with open(umap_model_file, 'rb') as f:
78
+ umap_model = pickle.load(f)
79
+
80
+ # Embeddings from HF API
81
+ # hf_token, hf_url = get_credentials.get_hf_credentials()
82
+ # api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
83
+ # api_url = f'https://router.huggingface.co/hf-inference/models/{model_id}/pipeline/feature-extraction'
84
+ # embedding_fn = recommendation_handler.get_embedding_func(inference='huggingface', model_id=model_id, api_url= api_url, headers = headers)
85
+
86
+ # Embeddings from local inference
87
+ embedding_fn = recommendation_handler.get_embedding_func(inference='local', model_id=model_id)
88
+
89
+ recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json, embedding_fn, umap_model=umap_model)
90
+
91
+ user_ip = request.remote_addr
92
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
93
  logger.info(f'RECOMMEND ROUTE - request: {prompt} response: {recommendation_json}')
94
+
95
  return recommendation_json
96
 
97
  @app.route("/get_thresholds", methods=['GET'])
 
100
  hf_token, hf_url = get_credentials.get_hf_credentials()
101
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
102
  prompt_json = recommendation_handler.populate_json()
 
103
  args = request.args
 
104
  prompt = args.get("prompt")
105
+ thresholds_json = recommendation_handler.get_thresholds(prompt, prompt_json, api_url, headers)
 
106
  return thresholds_json
107
 
108
  @app.route("/recommend_local", methods=['GET'])
109
  @cross_origin()
110
  def recommend_local():
111
+ model_id, _ = save_model.save_model()
112
+ prompt_json, _ = recommendation_handler.populate_json()
113
  args = request.args
114
  print("args list = ", args)
115
  prompt = args.get("prompt")
116
+
117
+ umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
118
+ with open(umap_model_file, 'rb') as f:
119
+ umap_model = pickle.load(f)
120
+
121
+ embedding_fn = recommendation_handler.get_embedding_func(inference='local', model_id=model_id)
122
+
123
+ local_recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json, embedding_fn, umap_model=umap_model)
124
  return local_recommendation_json
125
 
126
  @app.route("/log", methods=['POST'])
 
168
  return "Model Inference failed.", 500
169
 
170
  if __name__=='__main__':
171
+ debug_mode = os.getenv('FLASK_DEBUG', 'False').lower() in ['true', '1', 't']
172
  app.run(host='0.0.0.0', port='8080', debug=debug_mode)