ashwath-vaithina-ibm commited on
Commit
2b1ac28
·
verified ·
1 Parent(s): f40280b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -35
app.py CHANGED
@@ -30,7 +30,7 @@ from flask import Flask, request, jsonify
30
  from flask_cors import CORS, cross_origin
31
  from flask_restful import Resource, Api, reqparse
32
  import control.recommendation_handler as recommendation_handler
33
- from helpers import get_credentials, authenticate_api, save_model
34
  import config as cfg
35
  import requests
36
  import logging
@@ -61,18 +61,17 @@ FRONT_LOG_FILE = 'front_log.json'
61
  def index():
62
  user_ip = request.remote_addr
63
  logger.info(f'USER {user_ip} - ID {id} - started the app')
64
- return app.send_static_file('demo/index.html')
65
 
66
  @app.route("/recommend", methods=['GET'])
67
  @cross_origin()
68
  def recommend():
69
  user_ip = request.remote_addr
70
- hf_token, hf_url = get_credentials.get_credentials()
71
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
72
  prompt_json = recommendation_handler.populate_json()
73
  args = request.args
74
  prompt = args.get("prompt")
75
- print(prompt)
76
  recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json,
77
  api_url, headers)
78
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
@@ -82,7 +81,7 @@ def recommend():
82
  @app.route("/get_thresholds", methods=['GET'])
83
  @cross_origin()
84
  def get_thresholds():
85
- hf_token, hf_url = get_credentials.get_credentials()
86
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
87
  prompt_json = recommendation_handler.populate_json()
88
  model_id = 'sentence-transformers/all-minilm-l6-v2'
@@ -129,49 +128,25 @@ def log():
129
  def demo_inference():
130
  args = request.args
131
 
132
- model_id = args.get('model_id', default="meta-llama/Llama-4-Scout-17B-16E-Instruct")
 
133
  temperature = args.get('temperature', default=0.5)
134
  max_new_tokens = args.get('max_new_tokens', default=1000)
135
 
136
- hf_token, _ = get_credentials.get_credentials()
137
-
138
  prompt = args.get('prompt')
139
 
140
- API_URL = "https://router.huggingface.co/together/v1/chat/completions"
141
- headers = {
142
- "Authorization": f"Bearer {hf_token}",
143
- }
144
-
145
- response = requests.post(
146
- API_URL,
147
- headers=headers,
148
- json={
149
- "messages": [
150
- {
151
- "role": "user",
152
- "content": [
153
- {
154
- "type": "text",
155
- "text": prompt
156
- },
157
- ]
158
- }
159
- ],
160
- "model": model_id,
161
- 'temperature': temperature,
162
- 'max_new_tokens': max_new_tokens,
163
- }
164
- )
165
  try:
166
- response = response.json()["choices"][0]["message"]
167
  response.update({
 
168
  'model_id': model_id,
169
  'temperature': temperature,
170
  'max_new_tokens': max_new_tokens,
171
  })
 
172
  return response
173
  except:
174
- return response.text, response.status_code
175
 
176
  if __name__=='__main__':
177
  debug_mode = os.getenv('FLASK_DEBUG', 'True').lower() in ['true', '1', 't']
 
30
  from flask_cors import CORS, cross_origin
31
  from flask_restful import Resource, Api, reqparse
32
  import control.recommendation_handler as recommendation_handler
33
+ from helpers import get_credentials, authenticate_api, save_model, inference
34
  import config as cfg
35
  import requests
36
  import logging
 
61
  def index():
62
  user_ip = request.remote_addr
63
  logger.info(f'USER {user_ip} - ID {id} - started the app')
64
+ return "Ready!"
65
 
66
  @app.route("/recommend", methods=['GET'])
67
  @cross_origin()
68
  def recommend():
69
  user_ip = request.remote_addr
70
+ hf_token, hf_url = get_credentials.get_hf_credentials()
71
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
72
  prompt_json = recommendation_handler.populate_json()
73
  args = request.args
74
  prompt = args.get("prompt")
 
75
  recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json,
76
  api_url, headers)
77
  logger.info(f'USER - {user_ip} - ID {id} - accessed recommend route')
 
81
  @app.route("/get_thresholds", methods=['GET'])
82
  @cross_origin()
83
  def get_thresholds():
84
+ hf_token, hf_url = get_credentials.get_hf_credentials()
85
  api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
86
  prompt_json = recommendation_handler.populate_json()
87
  model_id = 'sentence-transformers/all-minilm-l6-v2'
 
128
  def demo_inference():
129
  args = request.args
130
 
131
+ inference_provider = args.get('inference_provider', default='replicate')
132
+ model_id = args.get('model_id', default="ibm-granite/granite-3.3-8b-instruct")
133
  temperature = args.get('temperature', default=0.5)
134
  max_new_tokens = args.get('max_new_tokens', default=1000)
135
 
 
 
136
  prompt = args.get('prompt')
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
+ response = inference.INFERENCE_HANDLER[inference_provider](prompt, model_id, temperature, max_new_tokens)
140
  response.update({
141
+ 'inference_provider': inference_provider,
142
  'model_id': model_id,
143
  'temperature': temperature,
144
  'max_new_tokens': max_new_tokens,
145
  })
146
+
147
  return response
148
  except:
149
+ return "Model Inference failed.", 500
150
 
151
  if __name__=='__main__':
152
  debug_mode = os.getenv('FLASK_DEBUG', 'True').lower() in ['true', '1', 't']