ashwath-vaithina-ibm committed on
Commit
aefc33c
·
verified ·
1 Parent(s): d8689f7

Upload inference.py

Browse files
Files changed (1) hide show
  1. helpers/inference.py +66 -0
helpers/inference.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from helpers import get_credentials
2
+ import requests
3
+
4
def hf_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a chat-completion request through the Hugging Face router (Together backend).

    Parameters
    ----------
    prompt : str
        User message, sent as a single text content part.
    model_id : str
        Model identifier forwarded to the router.
    temperature : float
        Sampling temperature forwarded to the API.
    max_new_tokens : int
        Generation length cap forwarded to the API.

    Returns
    -------
    dict
        The first choice's ``message`` object from the JSON response.

    Raises
    ------
    requests.HTTPError
        If the API responds with a non-2xx status.
    """
    hf_token, _ = get_credentials.get_hf_credentials()

    api_url = "https://router.huggingface.co/together/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {hf_token}",
    }

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "model": model_id,
        "temperature": temperature,
        # NOTE(review): OpenAI-compatible chat endpoints normally expect
        # "max_tokens"; "max_new_tokens" may be silently ignored by the
        # router — confirm against the provider docs before relying on it.
        "max_new_tokens": max_new_tokens,
    }

    # Bound the request and fail fast on HTTP errors instead of raising a
    # confusing KeyError while indexing the JSON body below.
    response = requests.post(api_url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()

    return response.json()["choices"][0]["message"]
35
+
36
def replicate_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a synchronous prediction against the Replicate models API.

    Uses the ``Prefer: wait`` header so the HTTP call blocks until the
    prediction completes and ``output`` is present in the response.

    Parameters
    ----------
    prompt : str
        Prompt text passed in the prediction ``input``.
    model_id : str
        Replicate model path (interpolated into the endpoint URL).
    temperature : float
        Sampling temperature forwarded to the model.
    max_new_tokens : int
        Forwarded as Replicate's ``max_tokens`` input.

    Returns
    -------
    dict
        ``{"content": <generated text>}`` — mirrors the shape consumed by
        callers of ``hf_inference`` (a message-like dict with ``content``).

    Raises
    ------
    requests.HTTPError
        If the API responds with a non-2xx status.
    """
    repl_token = get_credentials.get_replicate_credentials()

    api_url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {repl_token}",
        "Content-Type": "application/json",
        "Prefer": "wait",
    }

    payload = {
        "input": {
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
        }
    }

    # Bound the request (Prefer: wait can block server-side for a while)
    # and surface HTTP errors before touching the JSON body.
    response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    response.raise_for_status()

    # 'output' is assumed to be a list of string chunks — joined into one
    # string; verify against the specific model's output schema.
    return {
        "content": "".join(response.json()["output"])
    }
62
+
63
# Dispatch table: provider key -> inference function. Both handlers share
# the (prompt, model_id, temperature, max_new_tokens) signature.
INFERENCE_HANDLER = dict(
    huggingface=hf_inference,
    replicate=replicate_inference,
)