Security added for query reqs via LLM
- app.py +13 -6
- schemas.py +3 -0
app.py
CHANGED
@@ -294,6 +294,8 @@ def download_tdocs(req: DownloadRequest):
 async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
     documents = req.documents
     n_docs = len(documents)
+    def prompt(doc_id, full):
+        return f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"
 
     async def process_document(doc):
         doc_id = doc.document
@@ -309,7 +311,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
         async with limiter_mapping[model_used]:
             resp_ai = await llm_router.acompletion(
                 model=model_used,
-                messages=[{"role":"user","content":
+                messages=[{"role":"user","content": prompt(doc_id, full)}],
                 response_format=RequirementsResponse
             )
         return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
@@ -320,7 +322,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
         async with limiter_mapping[model_used]:
             resp_ai = await llm_router.acompletion(
                 model=model_used,
-                messages=[{"role":"user","content":
+                messages=[{"role":"user","content": prompt(doc_id, full)}],
                 response_format=RequirementsResponse
             )
         return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
@@ -357,14 +359,19 @@ def find_requirements_from_problem_description(req: ReqSearchRequest):
     requirements = req.requirements
     query = req.query
 
-    requirements_text = "\n".join([f"[Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
+    requirements_text = "\n".join([f"[Selection ID: {x} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for x, r in enumerate(requirements)])
 
     print("Called the LLM")
     resp_ai = llm_router.completion(
         model="gemini-v2",
-        messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of
-        response_format=
+        messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
+        response_format=ReqSearchLLMResponse
     )
     print("Answered")
+    print(resp_ai.choices[0].message.content)
 
-
+    out_llm = ReqSearchLLMResponse.model_validate_json(resp_ai.choices[0].message.content).selected
+    if max(out_llm) > len(out_llm) - 1:
+        raise HTTPException(status_code=500, detail="LLM error : Generated a wrong index, please try again.")
+
+    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
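For reference, here is a standalone sketch of the index guard this commit introduces in find_requirements_from_problem_description, runnable outside the app. Everything in it is hypothetical illustration: SingleRequirement and ReqSearchLLMResponse are redeclared locally, the model reply is a hard-coded string instead of a gemini-v2 call through llm_router, and the bounds check is written against len(requirements) (and tolerates an empty selection) rather than copying the max(out_llm) comparison from the diff.

# sketch_selection_guard.py - illustration only, not the application code
from typing import List
from pydantic import BaseModel

class SingleRequirement(BaseModel):     # mirrors the fields used in the prompt above
    document: str
    context: str
    requirement: str

class ReqSearchLLMResponse(BaseModel):  # same shape as the schema added in schemas.py
    selected: List[int]

# Hypothetical input data standing in for req.requirements
requirements = [
    SingleRequirement(document="DOC-1", context="Upload limits", requirement="Files above 50 MB shall be rejected."),
    SingleRequirement(document="DOC-2", context="Authentication", requirement="Sessions shall expire after 30 minutes."),
]

# Each entry is tagged with a Selection ID so the model can only answer with indices.
requirements_text = "\n".join(
    f"[Selection ID: {x} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
    for x, r in enumerate(requirements)
)

# Hard-coded stand-in for resp_ai.choices[0].message.content
raw_reply = '{"selected": [1]}'
out_llm = ReqSearchLLMResponse.model_validate_json(raw_reply).selected

# Reject hallucinated indices before they are used to index the requirements list.
if any(i < 0 or i >= len(requirements) for i in out_llm):
    raise ValueError("LLM error : Generated a wrong index, please try again.")

print([requirements[i].document for i in out_llm])  # -> ['DOC-2']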
schemas.py
CHANGED
@@ -37,6 +37,9 @@ class SingleRequirement(BaseModel):
     context: str
     requirement: str
 
+class ReqSearchLLMResponse(BaseModel):
+    selected: List[int]
+
 class ReqSearchRequest(BaseModel):
     query: str
     requirements: List[SingleRequirement]
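A quick, hypothetical check of how the new ReqSearchLLMResponse schema behaves when parsed with Pydantic's model_validate_json on different model outputs (the JSON strings below are made up, not real gemini-v2 replies):

from typing import List
from pydantic import BaseModel, ValidationError

class ReqSearchLLMResponse(BaseModel):
    selected: List[int]

# Well-formed replies parse into plain Python lists of ints.
print(ReqSearchLLMResponse.model_validate_json('{"selected": [0, 2]}').selected)  # [0, 2]
print(ReqSearchLLMResponse.model_validate_json('{"selected": []}').selected)      # []

# Replies that do not match the schema raise ValidationError instead of slipping through.
try:
    ReqSearchLLMResponse.model_validate_json('{"selected": ["first", "second"]}')
except ValidationError as exc:
    print(f"{exc.error_count()} validation errors")  # 2 validation errors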