File size: 10,721 Bytes
f6bffda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import asyncio
import json
import logging
from fastapi import APIRouter, Depends, HTTPException
from httpx import AsyncClient
from jinja2 import Environment
from litellm.router import Router
from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates
from typing import Awaitable, Callable, TypeVar
from schemas import _RefinedSolutionModel, _ReqGroupingCategory, _ReqGroupingOutput, _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionSearchResponse, SolutionSearchV2Request, TechnologyData

# APIRouter on which the solution search, critique and refinement endpoints below are registered.
router = APIRouter(
    tags=["solution generation and critique"],
)


# ============== utilities =======================
T = TypeVar("T")  # result type produced by the retried coroutine
A = TypeVar("A")  # argument type forwarded to the retried coroutine


async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Await ``func(arg)`` repeatedly until ``predicate`` accepts the result.

    The first call always happens. After it, up to ``max_retries`` additional
    calls are made while ``predicate`` rejects the most recent result. Once the
    retry budget is exhausted, the latest result is returned without being
    passed to ``predicate``. It may therefore not satisfy it.

    Exceptions raised by ``func`` or ``predicate`` are not caught; they
    propagate immediately and are not retried.

    Args:
        func: coroutine function to invoke.
        arg: single argument passed to every ``func`` call.
        predicate: acceptance test applied to a result.
        max_retries: number of extra calls allowed after the first one.

    Returns:
        The first accepted result, or the last result obtained.
    """
    retries_done = 0
    while True:
        outcome = await func(arg)
        # budget check first: the final result is returned as-is, unchecked
        if retries_done >= max_retries or predicate(outcome):
            return outcome
        retries_done += 1

# =================================================


def _extract_grounding_sources(completion) -> list:
    """Return the web sources Gemini grounded ``completion`` on, as ``{"name", "url"}`` dicts.

    Sources are read from ``completion["vertex_ai_grounding_metadata"][0]["groundingChunks"]``.
    Each chunk contributes its ``web.title`` (as ``name``) and ``web.uri`` (as ``url``).
    Missing or malformed metadata (e.g. Gemini answered without actually searching) yields an
    empty list instead of an exception, so callers can treat "no sources" uniformly.
    """
    try:
        chunks = completion["vertex_ai_grounding_metadata"][0]["groundingChunks"]
    except (KeyError, IndexError, AttributeError, TypeError):
        # NOTE(review): depending on the litellm response object, a missing field may surface as
        # AttributeError rather than KeyError, and an empty metadata list as IndexError. All of
        # these simply mean "no grounding sources".
        return []

    sources = []
    for chunk in chunks or ():
        try:
            web = chunk["web"]
            sources.append({"name": web["title"], "url": web["uri"]})
        except (KeyError, TypeError):
            # skip non-web / incomplete chunks instead of discarding every other source
            continue
    return sources


@router.post("/search_solutions_gemini/v2", response_model=SolutionSearchResponse)
async def search_solutions(params: SolutionSearchV2Request, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionSearchResponse:
    """Searches solutions solving the given grouping params and respecting the user constraints using Gemini and grounded on google search

    Each category is processed concurrently in three steps:
    1. a free-form, Google-Search-grounded Gemini completion;
    2. a second completion that structures the answer as ``_SearchedSolutionModel``;
    3. assembly of the final ``SolutionModel``.

    A category is retried up to 2 times when its solution has no grounded references.
    Categories whose generation raises are logged and omitted from the response.
    """
    max_retries = 2

    logging.info("Searching solutions for categories: %s", params)

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding
        req_prompt = await prompt_env.get_template("search_solution_v2.txt").render_async(
            category=cat.model_dump(),
            user_constraints=params.user_constraints,
        )

        # generate the completion in non-structured mode.
        # the googleSearch tool enables grounding gemini with google search
        # this also forces gemini to perform a tool call
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as a json ===================================
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(
            solution=req_completion.choices[0].message.content,
            response_schema=_SearchedSolutionModel.model_json_schema(),
        )

        structured_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object ================================
        # The LLM picks requirement indices; drop any outside the category. Otherwise a negative
        # index would silently select from the end of the list, and a too-large one would abort
        # the whole solution with an IndexError.
        requirements = [
            cat.requirements[i].requirement
            for i in solution_model.requirement_ids
            if 0 <= i < len(cat.requirements)
        ]
        if len(requirements) != len(solution_model.requirement_ids):
            logging.warning("Ignoring out-of-range requirement ids %s for category %s",
                            solution_model.requirement_ids, cat.id)

        return SolutionModel(
            Context="",
            Requirements=requirements,
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            # empty when gemini did not actually search (or hallucinated) -> triggers a retry below
            References=_extract_grounding_sources(req_completion),
            Category_Id=cat.id,
        )

    results = await asyncio.gather(
        *[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, max_retries)
          for cat in params.categories],
        return_exceptions=True,
    )

    final_solutions = []
    for cat, result in zip(params.categories, results):
        # BaseException (not Exception) so a cancelled sub-task is also excluded
        if isinstance(result, BaseException):
            logging.warning("Solution search failed for category %s", cat.id, exc_info=result)
            continue
        final_solutions.append(result)

    logging.info("Found %d solution(s) for %d categories", len(final_solutions), len(results))
    return SolutionSearchResponse(solutions=final_solutions)


@router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions.

    Every solution is critiqued on its own, with one LLM call per solution, all run
    concurrently. Each solution is paired with the first criticism the model returned
    for it. Any failure (LLM error, invalid JSON, empty criticism list) propagates and
    fails the whole request.
    """

    async def critique_one(solution: SolutionModel) -> SolutionCriticism:
        # the template expects a list of solutions; we always send exactly one
        prompt = await prompt_env.get_template("criticize.txt").render_async(
            solutions=[solution.model_dump()],
            response_schema=_SolutionCriticismOutput.model_json_schema(),
        )

        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt}],
            response_format=_SolutionCriticismOutput,
        )

        raw_json = completion.choices[0].message.content
        parsed = _SolutionCriticismOutput.model_validate_json(raw_json)
        first_criticism = parsed.criticisms[0]
        return SolutionCriticism(solution=solution, criticism=first_criticism)

    pending = (critique_one(sol) for sol in params.solutions)
    critiques = await asyncio.gather(*pending)
    return CritiqueResponse(critiques=critiques)


# =================================================================== Refine solution ====================================

@router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionSearchResponse:
    """Refines the previously critiqued solutions.

    For each (solution, criticism) pair, the LLM rewrites the problem and solution
    descriptions. Every other field of the solution (requirements, references,
    category id, context) is carried over from a deep copy of the original.
    Pairs are refined concurrently, and any failure fails the whole request.
    """

    async def refine_one(critique: SolutionCriticism) -> SolutionModel:
        prompt = await prompt_env.get_template("refine_solution.txt").render_async(
            solution=critique.solution.model_dump(),
            criticism=critique.criticism,
            response_schema=_RefinedSolutionModel.model_json_schema(),
        )

        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt}],
            response_format=_RefinedSolutionModel,
        )
        refined = _RefinedSolutionModel.model_validate_json(completion.choices[0].message.content)

        # keep everything from the original solution, only swap the two descriptions
        updated = critique.solution.model_copy(deep=True)
        updated.Problem_Description = refined.problem_description
        updated.Solution_Description = refined.solution_description
        return updated

    return SolutionSearchResponse(
        solutions=await asyncio.gather(*(refine_one(c) for c in params.critiques))
    )

# =============================================================== Search solutions =========================================


@router.post("/search_solutions_if")
async def search_solutions_if(req: SolutionSearchV2Request, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionSearchResponse:
    """Search solutions for each requirement category using the Insight Finder service.

    Per category, processed concurrently:
    1. the LLM rewrites the category's requirements as ``InsightFinderConstraintsList``;
    2. the constraints are POSTed as JSON to ``INSIGHT_FINDER_BASE_URL + "process-constraints"``;
    3. the returned technologies are synthesized by the LLM into a ``SolutionModel``
       (with no references).

    Categories that fail at any step are logged and omitted from the response.
    """

    async def _search_solution_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # process requirements into insight finder format
        fmt_prompt = await prompt_env.get_template("if/format_requirements.txt").render_async(
            category=cat.model_dump(),
            response_schema=InsightFinderConstraintsList.model_json_schema(),
        )
        fmt_completion = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": fmt_prompt}
        ], response_format=InsightFinderConstraintsList)

        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate from a structured output to a dict for insights finder
        formatted_constraints = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from insight finder
        technologies_req = await http_client.post(
            INSIGHT_FINDER_BASE_URL + "process-constraints",
            content=json.dumps(formatted_constraints),
            headers={"Content-Type": "application/json"},
        )
        # surface HTTP errors directly instead of failing later while validating an error payload
        technologies_req.raise_for_status()
        technologies = TechnologyData.model_validate(technologies_req.json())

        # =============================================================== synthesize solution using LLM =========================================
        synth_prompt = await prompt_env.get_template("if/synthesize_solution.txt").render_async(
            category=cat.model_dump(),
            technologies=technologies.model_dump()["technologies"],
            user_constraints=req.user_constraints,
            response_schema=_SearchedSolutionModel.model_json_schema(),
        )
        format_solution = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": synth_prompt}
        ], response_format=_SearchedSolutionModel)

        format_solution_model = _SearchedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        # The LLM picks requirement indices; drop any outside the category. Otherwise a negative
        # index would silently select from the end of the list, and a too-large one would abort
        # the whole solution with an IndexError.
        requirements = [
            cat.requirements[i].requirement
            for i in format_solution_model.requirement_ids
            if 0 <= i < len(cat.requirements)
        ]
        if len(requirements) != len(format_solution_model.requirement_ids):
            logging.warning("Ignoring out-of-range requirement ids %s for category %s",
                            format_solution_model.requirement_ids, cat.id)

        return SolutionModel(
            Context="",
            Requirements=requirements,
            Problem_Description=format_solution_model.problem_description,
            Solution_Description=format_solution_model.solution_description,
            References=[],
            Category_Id=cat.id,
        )

    results = await asyncio.gather(*[_search_solution_inner(cat) for cat in req.categories], return_exceptions=True)

    final_solutions = []
    for cat, result in zip(req.categories, results):
        # BaseException (not Exception) so a cancelled sub-task is also excluded
        if isinstance(result, BaseException):
            logging.warning("Insight Finder solution search failed for category %s", cat.id, exc_info=result)
            continue
        final_solutions.append(result)

    return SolutionSearchResponse(solutions=final_solutions)