File size: 10,721 Bytes
f6bffda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
import asyncio
import json
import logging
from fastapi import APIRouter, Depends, HTTPException
from httpx import AsyncClient
from jinja2 import Environment
from litellm.router import Router
from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates
from typing import Awaitable, Callable, TypeVar
from schemas import _RefinedSolutionModel, _ReqGroupingCategory, _ReqGroupingOutput, _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionSearchResponse, SolutionSearchV2Request, TechnologyData
# Generic type variables for ``retry_until``: A is the type of the argument passed to
# the retried coroutine function, T the type of the value it resolves to.
A = TypeVar("A")
T = TypeVar("T")

# APIRouter grouping the solution generation, critique and refinement endpoints
router = APIRouter(tags=["solution generation and critique"])
# ============== utilities =======================
async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Await ``func(arg)`` and re-await it while ``predicate`` rejects the latest result.

    ``func`` is always awaited once, then at most ``max_retries`` additional times.
    The most recent result is returned even if ``predicate`` never accepted it, so
    callers must re-check the result if acceptance matters. Exceptions raised by
    ``func`` are not retried; they propagate to the caller.

    Args:
        func: coroutine function to call.
        arg: the single argument passed to ``func`` on every attempt.
        predicate: returns True when a result is acceptable.
        max_retries: number of extra attempts allowed (<= 0 means a single call).

    Returns:
        The first accepted result, or the last result obtained.
    """
    result = await func(arg)
    attempts_left = max_retries
    # the predicate is only consulted while another attempt is still allowed
    while attempts_left > 0 and not predicate(result):
        result = await func(arg)
        attempts_left -= 1
    return result
# =================================================
@router.post("/search_solutions_gemini/v2", response_model=SolutionSearchResponse)
async def search_solutions(params: SolutionSearchV2Request, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionSearchResponse:
    """Searches solutions solving the given grouping params and respecting the user constraints using Gemini and grounded on google search.

    Every category in ``params.categories`` is processed concurrently. A solution that
    comes back without any grounding references is regenerated up to two more times.
    Categories whose generation raises are logged and left out of the response
    instead of failing the whole request.
    """
    logging.info(f"Searching solutions for categories: {params}")

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding
        req_prompt = await prompt_env.get_template("search_solution_v2.txt").render_async(**{
            "category": cat.model_dump(),
            "user_constraints": params.user_constraints,
        })

        # generate the completion in non-structured mode.
        # the googleSearch tool enables grounding gemini with google search;
        # tool_choice="required" forces gemini to perform a tool call
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as a json ===================================
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })
        structured_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object ================================
        # Extract the sources from the grounding metadata, if gemini actually called the search
        # tool (and didn't hallucinate). A missing attribute/key, an empty or None metadata list,
        # or a malformed chunk all mean "no usable sources": the solution is returned without
        # references, which makes the retry predicate below request another attempt.
        sources_metadata = []
        try:
            grounding_chunks = req_completion["vertex_ai_grounding_metadata"][0]["groundingChunks"]
            sources_metadata = [{"name": a["web"]["title"], "url": a["web"]["uri"]}
                                for a in grounding_chunks]
        except (KeyError, AttributeError, IndexError, TypeError) as err:
            logging.debug(f"No usable grounding metadata for category {cat.id}: {err!r}")

        # requirement_ids are produced by the LLM: skip ids that don't index into this
        # category's requirements instead of failing (or silently wrapping on negative ids).
        n_requirements = len(cat.requirements)
        requirements = [
            cat.requirements[i].requirement
            for i in solution_model.requirement_ids
            if 0 <= i < n_requirements
        ]

        return SolutionModel(
            Context="",
            Requirements=requirements,
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )

    # regenerate a category (at most 2 extra attempts) while its solution has no references
    solutions = await asyncio.gather(
        *[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories],
        return_exceptions=True,
    )

    final_solutions = []
    for cat, sol in zip(params.categories, solutions):
        # with return_exceptions=True, failures (including cancellations, which are
        # BaseException but not Exception) come back as values in the result list
        if isinstance(sol, BaseException):
            logging.warning(f"Solution search failed for category {cat.id}: {sol!r}")
            continue
        final_solutions.append(sol)
    return SolutionSearchResponse(solutions=final_solutions)
@router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions.

    Each solution is critiqued with its own LLM call; the calls run concurrently.
    If any single critique fails, the whole request fails (no partial results).

    Raises:
        HTTPException: 500 when the LLM's structured output contains no criticism
            for a solution.
    """

    async def __criticize_single(solution: SolutionModel) -> SolutionCriticism:
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })
        req_completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )
        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )
        # The prompt holds exactly one solution, so only the first criticism is used.
        # An empty list would otherwise surface as an opaque IndexError.
        if not criticism_out.criticisms:
            raise HTTPException(
                status_code=500, detail="The LLM returned no criticism for a solution")
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    # gather propagates the first failure, failing the whole request
    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions])
    return CritiqueResponse(critiques=critiques)
# =================================================================== Refine solution ====================================
@router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionSearchResponse:
    """Refines the previously critiqued solutions.

    For every (solution, criticism) pair, the LLM rewrites the problem and solution
    descriptions; every other field is carried over from a deep copy of the original
    solution. Pairs are processed concurrently and any failure fails the whole request.
    """

    async def _refine_one(critique: SolutionCriticism):
        prompt = await prompt_env.get_template("refine_solution.txt").render_async(
            solution=critique.solution.model_dump(),
            criticism=critique.criticism,
            response_schema=_RefinedSolutionModel.model_json_schema(),
        )
        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt}],
            response_format=_RefinedSolutionModel,
        )
        refined = _RefinedSolutionModel.model_validate_json(completion.choices[0].message.content)

        # work on a deep copy so the solution inside the incoming critique is left untouched
        updated = critique.solution.model_copy(deep=True)
        updated.Problem_Description = refined.problem_description
        updated.Solution_Description = refined.solution_description
        return updated

    pending = (_refine_one(c) for c in params.critiques)
    refined_solutions = await asyncio.gather(*pending)
    return SolutionSearchResponse(solutions=refined_solutions)
# =============================================================== Search solutions =========================================
@router.post("/search_solutions_if", response_model=SolutionSearchResponse)
async def search_solutions_if(req: SolutionSearchV2Request, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionSearchResponse:
    """Searches solutions for each requirement category using the Insight Finder service.

    For every category (processed concurrently):
      1. the LLM turns the category's requirements into Insight Finder constraints;
      2. the constraints are posted to Insight Finder's ``process-constraints`` endpoint,
         which returns candidate technologies;
      3. the LLM synthesizes a solution from the category, the technologies and the
         user constraints.
    Categories that fail at any step (LLM error, invalid structured output, Insight
    Finder HTTP error) are logged and left out of the response.
    """

    async def _search_solution_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ===== process requirements into insight finder format
        fmt_prompt = await prompt_env.get_template("if/format_requirements.txt").render_async(**{
            "category": cat.model_dump(),
            "response_schema": InsightFinderConstraintsList.model_json_schema()
        })
        fmt_completion = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": fmt_prompt}
        ], response_format=InsightFinderConstraintsList)
        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate the structured output into the {"constraints": {title: description}}
        # payload sent to insight finder
        formatted_constraints = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # ===== fetch technologies from insight finder
        technologies_resp = await http_client.post(
            INSIGHT_FINDER_BASE_URL + "process-constraints",
            content=json.dumps(formatted_constraints),
            headers={"Content-Type": "application/json"},
        )
        # fail this category on HTTP errors instead of validating an error body as TechnologyData
        technologies_resp.raise_for_status()
        technologies = TechnologyData.model_validate(technologies_resp.json())

        # ===== synthesize solution using LLM
        synth_prompt = await prompt_env.get_template("if/synthesize_solution.txt").render_async(**{
            "category": cat.model_dump(),
            "technologies": technologies.model_dump()["technologies"],
            "user_constraints": req.user_constraints,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })
        format_solution = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": synth_prompt}
        ], response_format=_SearchedSolutionModel)
        format_solution_model = _SearchedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        # requirement_ids are produced by the LLM: skip ids that don't index into this
        # category's requirements instead of failing (or silently wrapping on negative ids).
        n_requirements = len(cat.requirements)
        requirements = [
            cat.requirements[i].requirement
            for i in format_solution_model.requirement_ids
            if 0 <= i < n_requirements
        ]

        return SolutionModel(
            Context="",
            Requirements=requirements,
            Problem_Description=format_solution_model.problem_description,
            Solution_Description=format_solution_model.solution_description,
            References=[],
            Category_Id=cat.id,
        )

    tasks = await asyncio.gather(*[_search_solution_inner(cat) for cat in req.categories], return_exceptions=True)

    final_solutions = []
    for cat, sol in zip(req.categories, tasks):
        # with return_exceptions=True, failures (including cancellations, which are
        # BaseException but not Exception) come back as values in the result list
        if isinstance(sol, BaseException):
            logging.warning(f"Insight Finder solution search failed for category {cat.id}: {sol!r}")
            continue
        final_solutions.append(sol)
    return SolutionSearchResponse(solutions=final_solutions)
|