ashwath-vaithina-ibm committed on
Commit
e08c0b3
·
verified ·
1 Parent(s): 6cc4d08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -52
app.py CHANGED
@@ -26,19 +26,20 @@ __license__ = "Apache 2.0"
26
  __version__ = "0.0.1"
27
 
28
 
29
- from flask import Flask, request, jsonify, render_template
30
  from flask_cors import CORS, cross_origin
31
  from flask_restful import Resource, Api, reqparse
32
  import control.recommendation_handler as recommendation_handler
33
- from helpers import get_credentials, authenticate_api, save_model
34
  import config as cfg
 
35
  import logging
36
  import uuid
37
  import json
38
  import os
39
- import requests
40
 
41
- app = Flask(__name__, static_folder='static')
42
 
43
  # configure logging
44
  logging.basicConfig(
@@ -66,42 +67,60 @@ def index():
66
  @app.route("/recommend", methods=['GET'])
67
  @cross_origin()
68
  def recommend():
69
- user_ip = request.remote_addr
70
- hf_token, hf_url = get_credentials.get_credentials()
71
- api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
72
  prompt_json = recommendation_handler.populate_json()
73
  args = request.args
 
74
  prompt = args.get("prompt")
75
- recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json,
76
- api_url, headers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
78
  logger.info(f'RECOMMEND ROUTE - request: {prompt} response: {recommendation_json}')
 
79
  return recommendation_json
80
 
81
  @app.route("/get_thresholds", methods=['GET'])
82
  @cross_origin()
83
  def get_thresholds():
84
- hf_token, hf_url = get_credentials.get_credentials()
85
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
86
  prompt_json = recommendation_handler.populate_json()
87
- model_id = 'sentence-transformers/all-minilm-l6-v2'
88
  args = request.args
89
- #print("args list = ", args)
90
  prompt = args.get("prompt")
91
- thresholds_json = recommendation_handler.get_thresholds(prompt, prompt_json, api_url,
92
- headers, model_id)
93
  return thresholds_json
94
 
95
  @app.route("/recommend_local", methods=['GET'])
96
  @cross_origin()
97
  def recommend_local():
98
- model_id, model_path = save_model.save_model()
99
- prompt_json = recommendation_handler.populate_json()
100
  args = request.args
101
  print("args list = ", args)
102
  prompt = args.get("prompt")
103
- local_recommendation_json = recommendation_handler.recommend_local(prompt, prompt_json,
104
- model_id, model_path)
 
 
 
 
 
 
105
  return local_recommendation_json
106
 
107
  @app.route("/log", methods=['POST'])
@@ -127,51 +146,27 @@ def log():
127
  @cross_origin()
128
  def demo_inference():
129
  args = request.args
130
- # model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
131
- model_id = args.get('model_id', default="meta-llama/Llama-4-Scout-17B-16E-Instruct")
132
- temperature = 0.5
133
- max_new_tokens = 1000
134
 
135
- hf_token, hf_url = get_credentials.get_credentials()
 
 
 
136
 
137
  prompt = args.get('prompt')
138
 
139
- API_URL = "https://router.huggingface.co/together/v1/chat/completions"
140
- headers = {
141
- "Authorization": f"Bearer {hf_token}",
142
- }
143
-
144
- response = requests.post(
145
- API_URL,
146
- headers=headers,
147
- json={
148
- "messages": [
149
- {
150
- "role": "user",
151
- "content": [
152
- {
153
- "type": "text",
154
- "text": prompt
155
- },
156
- ]
157
- }
158
- ],
159
- "model": model_id,
160
- 'temperature': temperature,
161
- 'max_new_tokens': max_new_tokens,
162
- }
163
- )
164
  try:
165
- response = response.json()["choices"][0]["message"]
166
  response.update({
 
167
  'model_id': model_id,
168
  'temperature': temperature,
169
  'max_new_tokens': max_new_tokens,
170
  })
 
171
  return response
172
  except:
173
- return response.text, response.status_code
174
 
175
  if __name__=='__main__':
176
- debug_mode = os.getenv('FLASK_DEBUG', 'True').lower() in ['true', '1', 't']
177
- app.run(host='0.0.0.0', port='7860', debug=debug_mode)
 
26
  __version__ = "0.0.1"
27
 
28
 
29
+ from flask import Flask, request, jsonify
30
  from flask_cors import CORS, cross_origin
31
  from flask_restful import Resource, Api, reqparse
32
  import control.recommendation_handler as recommendation_handler
33
+ from helpers import get_credentials, authenticate_api, save_model, inference
34
  import config as cfg
35
+ import requests
36
  import logging
37
  import uuid
38
  import json
39
  import os
40
+ import pickle
41
 
42
+ app = Flask(__name__)
43
 
44
  # configure logging
45
  logging.basicConfig(
 
67
  @app.route("/recommend", methods=['GET'])
68
  @cross_origin()
69
  def recommend():
70
+ model_id, _ =save_model.save_model()
 
 
71
  prompt_json = recommendation_handler.populate_json()
72
  args = request.args
73
+ print("args list = ", args)
74
  prompt = args.get("prompt")
75
+
76
+ umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
77
+ with open(umap_model_file, 'rb') as f:
78
+ umap_model = pickle.load(f)
79
+
80
+ # Embeddings from HF API
81
+ # hf_token, hf_url = get_credentials.get_hf_credentials()
82
+ # api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
83
+ # api_url = f'https://router.huggingface.co/hf-inference/models/{model_id}/pipeline/feature-extraction'
84
+ # embedding_fn = recommendation_handler.get_embedding_func(inference='huggingface', model_id=model_id, api_url= api_url, headers = headers)
85
+
86
+ # Embeddings from local inference
87
+ embedding_fn = recommendation_handler.get_embedding_func(inference='local', model_id=model_id)
88
+
89
+ recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json, embedding_fn, umap_model=umap_model)
90
+
91
+ user_ip = request.remote_addr
92
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
93
  logger.info(f'RECOMMEND ROUTE - request: {prompt} response: {recommendation_json}')
94
+
95
  return recommendation_json
96
 
97
  @app.route("/get_thresholds", methods=['GET'])
98
  @cross_origin()
99
  def get_thresholds():
100
+ hf_token, hf_url = get_credentials.get_hf_credentials()
101
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
102
  prompt_json = recommendation_handler.populate_json()
 
103
  args = request.args
 
104
  prompt = args.get("prompt")
105
+ thresholds_json = recommendation_handler.get_thresholds(prompt, prompt_json, api_url, headers)
 
106
  return thresholds_json
107
 
108
  @app.route("/recommend_local", methods=['GET'])
109
  @cross_origin()
110
  def recommend_local():
111
+ model_id, _ = save_model.save_model()
112
+ prompt_json, _ = recommendation_handler.populate_json()
113
  args = request.args
114
  print("args list = ", args)
115
  prompt = args.get("prompt")
116
+
117
+ umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
118
+ with open(umap_model_file, 'rb') as f:
119
+ umap_model = pickle.load(f)
120
+
121
+ embedding_fn = recommendation_handler.get_embedding_func(inference='local', model_id=model_id)
122
+
123
+ local_recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json, embedding_fn, umap_model=umap_model)
124
  return local_recommendation_json
125
 
126
  @app.route("/log", methods=['POST'])
 
146
  @cross_origin()
147
  def demo_inference():
148
  args = request.args
 
 
 
 
149
 
150
+ inference_provider = args.get('inference_provider', default='replicate')
151
+ model_id = args.get('model_id', default="ibm-granite/granite-3.3-8b-instruct")
152
+ temperature = args.get('temperature', default=0.5)
153
+ max_new_tokens = args.get('max_new_tokens', default=1000)
154
 
155
  prompt = args.get('prompt')
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  try:
158
+ response = inference.INFERENCE_HANDLER[inference_provider](prompt, model_id, temperature, max_new_tokens)
159
  response.update({
160
+ 'inference_provider': inference_provider,
161
  'model_id': model_id,
162
  'temperature': temperature,
163
  'max_new_tokens': max_new_tokens,
164
  })
165
+
166
  return response
167
  except:
168
+ return "Model Inference failed.", 500
169
 
170
  if __name__=='__main__':
171
+ debug_mode = os.getenv('FLASK_DEBUG', 'False').lower() in ['true', '1', 't']
172
+ app.run(host='0.0.0.0', port='8080', debug=debug_mode)