Sambhavnoobcoder committed
Commit 5bdaeed · 1 Parent(s): 3e95584

Create api.py

Files changed (1)
  1. api.py +177 -0
api.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import re
+ import shutil
+ import urllib.request
+ from pathlib import Path
+ from tempfile import NamedTemporaryFile
+
+ import fitz  # PyMuPDF
+ import numpy as np
+ import openai
+ import tensorflow_hub as hub
+ from fastapi import UploadFile
+ from lcserve import serving
+ from sklearn.neighbors import NearestNeighbors
+
+
+ recommender = None  # global SemanticSearch instance, created in load_recommender()
+
+
+ def download_pdf(url, output_path):
+     urllib.request.urlretrieve(url, output_path)
+
+
+ def preprocess(text):
+     text = text.replace('\n', ' ')
+     text = re.sub(r'\s+', ' ', text)
+     return text
+
+
+ def pdf_to_text(path, start_page=1, end_page=None):
+     doc = fitz.open(path)
+     total_pages = doc.page_count
+
+     if end_page is None:
+         end_page = total_pages
+
+     text_list = []
+
+     for i in range(start_page - 1, end_page):
+         text = doc.load_page(i).get_text("text")
+         text = preprocess(text)
+         text_list.append(text)
+
+     doc.close()
+     return text_list
+
+
+ def text_to_chunks(texts, word_length=150, start_page=1):
+     text_toks = [t.split(' ') for t in texts]
+     chunks = []
+
+     for idx, words in enumerate(text_toks):
+         for i in range(0, len(words), word_length):
+             chunk = words[i : i + word_length]
+             if (
+                 (i + word_length) > len(words)
+                 and (len(chunk) < word_length)
+                 and (len(text_toks) != (idx + 1))
+             ):
+                 text_toks[idx + 1] = chunk + text_toks[idx + 1]  # carry a short tail chunk into the next page
+                 continue
+             chunk = ' '.join(chunk).strip()
+             chunk = f'[Page no. {idx+start_page}]' + ' ' + '"' + chunk + '"'
+             chunks.append(chunk)
+     return chunks
+
+
+ class SemanticSearch:
+     def __init__(self):
+         self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
+         self.fitted = False
+
+     def fit(self, data, batch=1000, n_neighbors=5):
+         self.data = data
+         self.embeddings = self.get_text_embedding(data, batch=batch)
+         n_neighbors = min(n_neighbors, len(self.embeddings))
+         self.nn = NearestNeighbors(n_neighbors=n_neighbors)
+         self.nn.fit(self.embeddings)
+         self.fitted = True
+
+     def __call__(self, text, return_data=True):
+         inp_emb = self.use([text])
+         neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
+
+         if return_data:
+             return [self.data[i] for i in neighbors]
+         else:
+             return neighbors
+
+     def get_text_embedding(self, texts, batch=1000):
+         embeddings = []
+         for i in range(0, len(texts), batch):
+             text_batch = texts[i : (i + batch)]
+             emb_batch = self.use(text_batch)
+             embeddings.append(emb_batch)
+         embeddings = np.vstack(embeddings)
+         return embeddings
+
+
+ def load_recommender(path, start_page=1):
+     global recommender
+     if recommender is None:
+         recommender = SemanticSearch()
+
+     texts = pdf_to_text(path, start_page=start_page)
+     chunks = text_to_chunks(texts, start_page=start_page)
+     recommender.fit(chunks)
+     return 'Corpus Loaded.'
+
+
+ def generate_text(openAI_key, prompt, engine="text-davinci-003"):
+     openai.api_key = openAI_key
+     try:
+         completions = openai.Completion.create(  # legacy completions API (pre-1.0 openai SDK)
+             engine=engine,
+             prompt=prompt,
+             max_tokens=512,
+             n=1,
+             stop=None,
+             temperature=0.7,
+         )
+         message = completions.choices[0].text
+     except Exception as e:
+         message = f'API Error: {str(e)}'
+     return message
+
+
+ def generate_answer(question, openAI_key):
+     topn_chunks = recommender(question)
+     prompt = ""
+     prompt += 'search results:\n\n'
+     for c in topn_chunks:
+         prompt += c + '\n\n'
+
+     prompt += (
+         "Instructions: Compose a comprehensive reply to the query using the search results given. "
+         "Cite each reference using [Page no.] notation (every result has this number at the beginning). "
+         "Citation should be done at the end of each sentence. If the search results mention multiple subjects "
+         "with the same name, create separate answers for each. Only include information found in the results and "
+         "don't add any additional information. Make sure the answer is correct and don't output false content. "
+         "If the text does not relate to the query, simply state 'Text Not Found in PDF'. Ignore outlier "
+         "search results which have nothing to do with the question. Only answer what is asked. The "
+         "answer should be short and concise. Answer step-by-step.\n\n"
+     )
+
+     prompt += f"Query: {question}\nAnswer:"
+     answer = generate_text(openAI_key, prompt, "text-davinci-003")
+     return answer
+
+
+ def load_openai_key() -> str:
+     key = os.environ.get("OPENAI_API_KEY")
+     if key is None:
+         raise ValueError(
+             "[ERROR]: Please pass your OPENAI_API_KEY. Get your key here: https://platform.openai.com/account/api-keys"
+         )
+     return key
+
+
+ @serving
+ def ask_url(url: str, question: str):
+     download_pdf(url, 'corpus.pdf')
+     load_recommender('corpus.pdf')
+     openAI_key = load_openai_key()
+     return generate_answer(question, openAI_key)
+
+
+ @serving
+ async def ask_file(file: UploadFile, question: str) -> str:
+     suffix = Path(file.filename).suffix
+     with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+         shutil.copyfileobj(file.file, tmp)
+         tmp_path = Path(tmp.name)
+
+     load_recommender(str(tmp_path))
+     openAI_key = load_openai_key()
+     return generate_answer(question, openAI_key)
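
The retrieval core (pdf_to_text → text_to_chunks → SemanticSearch) can be exercised without the server. A minimal sketch, assuming a local sample.pdf exists; note that constructing SemanticSearch downloads the Universal Sentence Encoder from TF Hub on first run:

from api import SemanticSearch, pdf_to_text, text_to_chunks

texts = pdf_to_text('sample.pdf')      # one preprocessed string per page
chunks = text_to_chunks(texts)         # ~150-word chunks, each tagged '[Page no. N]'
searcher = SemanticSearch()            # downloads the USE model on first use
searcher.fit(chunks)                   # embeds the chunks and fits NearestNeighbors
print(searcher("What is the main contribution?")[:2])  # two most similar chunks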
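
To serve the endpoints, the @serving decorator comes from langchain-serve, so lc-serve deploy local api should expose them on localhost, with OPENAI_API_KEY set in the server's environment (see load_openai_key). Below is a hypothetical client sketch for ask_url; the port is lcserve's assumed local default and the PDF URL is a placeholder.

import requests

# POST the endpoint's keyword arguments as a JSON body
response = requests.post(
    "http://localhost:8080/ask_url",              # assumed default local port
    json={
        "url": "https://example.com/sample.pdf",  # placeholder PDF URL
        "question": "What is this document about?",
    },
)
print(response.json())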