om4r932 committed
Commit 1538533 · Parent: 546fbbe

Security added for query reqs via LLM

Files changed (2)
  1. app.py +13 -6
  2. schemas.py +3 -0
app.py CHANGED
@@ -294,6 +294,8 @@ def download_tdocs(req: DownloadRequest):
 async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
     documents = req.documents
     n_docs = len(documents)
+    def prompt(doc_id, full):
+        return f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"
 
     async def process_document(doc):
         doc_id = doc.document
@@ -309,7 +311,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
         async with limiter_mapping[model_used]:
             resp_ai = await llm_router.acompletion(
                 model=model_used,
-                messages=[{"role":"user","content": f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"}],
+                messages=[{"role":"user","content": prompt(doc_id, full)}],
                 response_format=RequirementsResponse
             )
             return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
@@ -320,7 +322,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
         async with limiter_mapping[model_used]:
             resp_ai = await llm_router.acompletion(
                 model=model_used,
-                messages=[{"role":"user","content": f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"}],
+                messages=[{"role":"user","content": prompt(doc_id, full)}],
                 response_format=RequirementsResponse
             )
             return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
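The two hunks above are twin call sites (plausibly a first attempt plus a fallback path) that previously duplicated the long prompt string; both now delegate to the nested `prompt()` helper, so a future prompt change lands in one place instead of two. A minimal sketch of the shared pattern, assuming `limiter_mapping` holds per-model `asyncio.Semaphore`-style limiters and `llm_router` is a litellm-style router exposing `acompletion`; the helper name `call_model` and the simplified schema are illustrative, not the repo's code:

```python
import asyncio

from pydantic import BaseModel

class RequirementsResponse(BaseModel):
    # Simplified stand-in; the app's real schema nests grouped requirements.
    requirements: list

# Assumption: one concurrency limiter per model name.
limiter_mapping = {"gemini-v2": asyncio.Semaphore(4)}

async def call_model(llm_router, model_used: str, doc_id: str, full: str) -> list:
    """Rate-limited structured completion shared by both call sites."""
    def prompt(doc_id: str, full: str) -> str:
        return f"Here's the document whose ID is {doc_id} : {full}\n\n..."

    async with limiter_mapping[model_used]:  # throttle concurrent calls per model
        resp_ai = await llm_router.acompletion(
            model=model_used,
            messages=[{"role": "user", "content": prompt(doc_id, full)}],
            response_format=RequirementsResponse,  # ask for schema-shaped JSON
        )
    # The message content is JSON text; validate it into the Pydantic model.
    return RequirementsResponse.model_validate_json(
        resp_ai.choices[0].message.content
    ).requirements
```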
@@ -357,14 +359,19 @@ def find_requirements_from_problem_description(req: ReqSearchRequest):
     requirements = req.requirements
     query = req.query
 
-    requirements_text = "\n".join([f"[Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
+    requirements_text = "\n".join([f"[Selection ID: {x} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for x, r in enumerate(requirements)])
 
     print("Called the LLM")
     resp_ai = llm_router.completion(
         model="gemini-v2",
-        messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of objects each with document ID, context, and requirement for the most relevant requirements that reference or best cover the problem."}],
-        response_format=ReqSearchResponse
+        messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
+        response_format=ReqSearchLLMResponse
     )
     print("Answered")
+    print(resp_ai.choices[0].message.content)
 
-    return ReqSearchResponse.model_validate_json(resp_ai.choices[0].message.content)
+    out_llm = ReqSearchLLMResponse.model_validate_json(resp_ai.choices[0].message.content).selected
+    if any(i < 0 or i >= len(requirements) for i in out_llm):
+        raise HTTPException(status_code=500, detail="LLM error: generated an out-of-range index, please try again.")
+
+    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
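This last hunk is the security change the commit message refers to. Before, the model was asked to return document/context/requirement objects, so a confabulating model could invent or paraphrase requirements that were never in the input; now each requirement is tagged with a numeric Selection ID, the model may only answer with a list of IDs, and the server rebuilds the response from its own trusted list after bounds-checking the indices. A minimal standalone sketch of that select-by-index pattern; the function name `pick_requirements` is made up, while `SingleRequirement`, `ReqSearchLLMResponse`, and `ReqSearchResponse` mirror the repo's schemas:

```python
from typing import List

from fastapi import HTTPException
from pydantic import BaseModel

class SingleRequirement(BaseModel):
    document: str
    context: str
    requirement: str

class ReqSearchLLMResponse(BaseModel):
    selected: List[int]  # indices chosen by the model

class ReqSearchResponse(BaseModel):
    requirements: List[SingleRequirement]

def pick_requirements(raw_llm_json: str, requirements: List[SingleRequirement]) -> ReqSearchResponse:
    """Map model-chosen indices back onto trusted, server-side objects."""
    selected = ReqSearchLLMResponse.model_validate_json(raw_llm_json).selected
    # Reject hallucinated indices instead of indexing blindly; the explicit
    # lower bound also blocks negatives, which Python indexing would accept.
    if any(i < 0 or i >= len(requirements) for i in selected):
        raise HTTPException(status_code=500, detail="LLM error: generated an out-of-range index, please try again.")
    return ReqSearchResponse(requirements=[requirements[i] for i in selected])
```

Reducing the model's output to a list of integers means the response can only ever contain requirement objects the server already held, and validation becomes a cheap range check rather than trusting free-form text.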
schemas.py CHANGED
@@ -37,6 +37,9 @@ class SingleRequirement(BaseModel):
     context: str
     requirement: str
 
+class ReqSearchLLMResponse(BaseModel):
+    selected: List[int]
+
 class ReqSearchRequest(BaseModel):
     query: str
     requirements: List[SingleRequirement]
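As a quick sanity check on the new schema, here is how Pydantic v2 handles a few made-up payloads of the kind the endpoint now expects:

```python
from typing import List

from pydantic import BaseModel, ValidationError

class ReqSearchLLMResponse(BaseModel):
    selected: List[int]

# A well-formed answer parses straight to a list of ints.
print(ReqSearchLLMResponse.model_validate_json('{"selected": [0, 2, 5]}').selected)  # [0, 2, 5]

# An empty list is valid too, matching the "return an empty list" prompt instruction.
print(ReqSearchLLMResponse.model_validate_json('{"selected": []}').selected)  # []

# Non-integer entries fail fast, before any endpoint logic runs.
try:
    ReqSearchLLMResponse.model_validate_json('{"selected": ["N/A"]}')
except ValidationError as exc:
    print("rejected:", exc.error_count(), "error(s)")
```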