import logging

from fastapi import APIRouter, Depends, HTTPException
from litellm.router import Router
from pydantic import ValidationError

from dependencies import get_llm_router
from schemas import ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse

logger = logging.getLogger(__name__)
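
# A minimal sketch of what `dependencies.get_llm_router` might provide,
# assuming the "gemini-v2" alias used below is registered in the litellm
# Router's model list (illustrative only; the real setup lives in
# dependencies.py, and the underlying model name is a placeholder):
#
#     _llm_router = Router(model_list=[{
#         "model_name": "gemini-v2",
#         "litellm_params": {"model": "gemini/gemini-2.0-flash"},
#     }])
#
#     def get_llm_router() -> Router:
#         return _llm_router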

router = APIRouter()
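
# For reference, a rough sketch of the Pydantic models this module assumes.
# The real definitions live in schemas.py; the field names below are inferred
# from how the models are used in this file, and the `Requirement` name is a
# placeholder:
#
#     class Requirement(BaseModel):
#         req_id: int
#         document: str
#         context: str
#         requirement: str
#
#     class ReqSearchRequest(BaseModel):
#         requirements: list[Requirement]
#         query: str
#
#     class ReqSearchLLMResponse(BaseModel):
#         selected: list[int]  # "Selection ID" values chosen by the LLM
#
#     class ReqSearchResponse(BaseModel):
#         requirements: list[Requirement]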


@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(
    req: ReqSearchRequest,
    llm_router: Router = Depends(get_llm_router),
):
    """Find the requirements from an extracted list that address a given
    problem description."""

    requirements = req.requirements
    query = req.query

    # Render each requirement on its own line, tagged with the "Selection ID"
    # the model is asked to echo back. The code below treats these IDs as
    # positions in `requirements`, so `req_id` is assumed to equal each item's
    # index in the list.
    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | "
        f"Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements
    )

    logger.info("Querying the LLM for relevant requirements")
    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{
            "role": "user",
            "content": (
                f"Given all the requirements:\n{requirements_text}\n"
                f'and the problem description "{query}", return a list of '
                "'Selection ID' for the most relevant corresponding "
                "requirements that reference or best cover the problem. If "
                "none of the requirements covers the problem, simply return "
                "an empty list."
            ),
        }],
        response_format=ReqSearchLLMResponse,
    )
    logger.debug("LLM raw response: %s", resp_ai.choices[0].message.content)
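
    # The raw message content is expected to be a JSON string matching
    # ReqSearchLLMResponse, e.g. (illustrative values):
    #
    #     {"selected": [0, 2]}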

    # Parse the structured output back into the response schema; a malformed
    # payload would otherwise surface as an unhandled ValidationError.
    try:
        out_llm = ReqSearchLLMResponse.model_validate_json(
            resp_ai.choices[0].message.content
        ).selected
    except ValidationError:
        raise HTTPException(
            status_code=500,
            detail="LLM error: malformed response, please try again.")

    # An empty list is a valid answer (no requirement covers the problem);
    # max() on an empty list would raise, so check each index instead.
    if any(i < 0 or i >= len(requirements) for i in out_llm):
        raise HTTPException(
            status_code=500,
            detail="LLM error: generated an invalid index, please try again.")

    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
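
# Example request to this endpoint (hypothetical payload; field values are
# illustrative, and the router is assumed to be mounted at the app root):
#
#     POST /get_reqs_from_query
#     {
#         "query": "The pump overheats under sustained load",
#         "requirements": [
#             {"req_id": 0, "document": "spec_A.pdf", "context": "Cooling",
#              "requirement": "The pump shall shut down above 90 C."}
#         ]
#     }
#
# The response echoes back the subset of `requirements` the LLM selected, in
# the shape of ReqSearchResponse.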