File size: 2,684 Bytes
96f6720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import os

from dotenv import load_dotenv
from duckduckgo_search import DDGS
from langchain_core.messages.tool import BaseMessage, ToolMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import END, MessageGraph
from langgraph.prebuilt import ToolNode
from typing import TypedDict

from llm import get_text_llm
from log_util import logger
from time_it import time_it
from util import load_prompt

# Load variables from a .env file into the process environment. This must run
# before the os.getenv lookup below so .env values are visible to it.
# NOTE(review): assumes python-dotenv's default behavior of not overriding
# variables already set in the real environment — confirm if that matters.
load_dotenv()

# Default cap on how many image results search_images requests per query.
# Overridable via the MAX_IMAGE_SEARCH_RESULTS environment variable; falls back
# to 3. A non-integer value raises ValueError at import time.
MAX_IMAGE_SEARCH_RESULTS = int(os.getenv('MAX_IMAGE_SEARCH_RESULTS', '3'))

# One image hit produced by search_images: the result's title and the URL of
# the image itself. Declared with the functional TypedDict form; instances are
# ordinary dicts carrying exactly the keys 'title' and 'url'.
ImageSearchResult = TypedDict('ImageSearchResult', {'title': str, 'url': str})

def _is_affirmative(content) -> bool:
    '''Return True when an LLM reply means "yes".

    Tolerates surrounding whitespace, letter case and trailing "."/"!" (e.g.
    "Yes." or " yes\\n"). Non-string content (e.g. a list of content blocks)
    is treated as not affirmative instead of raising AttributeError.
    '''
    if not isinstance(content, str):
        return False
    return content.strip().rstrip('.!').strip().lower() == 'yes'


def _parse_image_results(content) -> list:
    '''Decode a ToolMessage payload into a list of image-result dicts.

    Returns [] when the payload is empty, not a string, or cannot be decoded
    into a list.
    '''
    if not content or not isinstance(content, str):
        return []
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # NOTE(review): depending on the langgraph version, ToolNode may str()
        # a non-string tool result, producing a Python repr rather than JSON.
        # ast.literal_eval only parses literals, so it is safe on this payload.
        import ast
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, SyntaxError):
            logger.warning(f'Could not decode image search tool output: {content!r}')
            return []
    return parsed if isinstance(parsed, list) else []


def _first_image_url(messages: list) -> str | None:
    '''Return the 'url' of the first result in the first ToolMessage, if any.'''
    for message in messages:
        if isinstance(message, ToolMessage):
            meal_images = _parse_image_results(message.content)
            if meal_images and isinstance(meal_images[0], dict):
                return meal_images[0].get('url')
            # Only the first tool message is considered, as before.
            return None
    return None


@time_it
def search_meal_image(meal: str) -> str | None:
    '''Find an image URL for `meal`, or return None.

    Builds a LangGraph MessageGraph:
      1. 'validate_is_meal' - the LLM answers the validate_is_meal prompt for
         `meal`; the graph continues only if the answer is "yes".
      2. 'is_meal' - the same LLM, with search_meal_images bound as a tool,
         decides on a tool call.
      3. 'call_tools' - runs the requested tool, only if one was requested.

    Args:
        meal: Free-text phrase to validate and search for.

    Returns:
        The 'url' of the first image result, or None when the phrase is not
        judged to be a meal, no tool was called, or no results were found.
    '''
    prompt_text = load_prompt('validate_is_meal.prompt.txt')

    llm = get_text_llm()
    tools = [search_meal_images]

    def is_meal_router(messages: list[BaseMessage]) -> str:
        # Continue only when the validator LLM answered "yes".
        if _is_affirmative(messages[-1].content):
            return 'is_meal'
        return END

    def tool_call_router(messages: list[BaseMessage]) -> str:
        # Run tools only if the model actually requested one; otherwise finish.
        if getattr(messages[-1], 'tool_calls', None):
            return 'call_tools'
        return END

    graph = MessageGraph()
    graph.add_node('validate_is_meal', llm)
    graph.add_conditional_edges('validate_is_meal', is_meal_router)
    graph.add_node('is_meal', llm.bind_tools(tools))
    graph.add_conditional_edges('is_meal', tool_call_router)
    graph.add_node('call_tools', ToolNode(tools))
    graph.add_edge('call_tools', END)
    graph.set_entry_point('validate_is_meal')

    validation_prompt = PromptTemplate.from_template(prompt_text).format(phrase=meal)

    workflow = graph.compile()
    messages: list = workflow.invoke(validation_prompt)

    meal_image_url = _first_image_url(messages)
    if meal_image_url is not None:
        logger.info(f'{meal_image_url=}')
    return meal_image_url

# LangChain tool exposed to the LLM (bound via llm.bind_tools in
# search_meal_image). NOTE(review): @tool presumably derives the tool's name,
# description and argument schema from this function's name, docstring and
# `meal` parameter, so renaming or rewording them may change what the model sees.
@tool
def search_meal_images(meal: str) -> list[ImageSearchResult]:
    '''Searches for images of the given meal.'''
    # Thin adapter: delegates to search_images using its default max_results.
    return search_images(meal)

@time_it
def search_images(keywords: str, max_results: int | None=MAX_IMAGE_SEARCH_RESULTS) -> list[ImageSearchResult]:
    '''Search DuckDuckGo Images for colour photos matching `keywords`.

    The query uses worldwide region ('wt-wt'), safe search on, colour-only
    photos, and no size/layout/license filter.

    Args:
        keywords: Free-text search query.
        max_results: Upper bound on results requested; defaults to
            MAX_IMAGE_SEARCH_RESULTS.

    Returns:
        One ImageSearchResult per hit that carries an image URL, in the order
        the search returned them. A missing title becomes ''. Hits without an
        'image' URL are skipped rather than raising KeyError.
    '''
    # Use DDGS as a context manager so its HTTP session is closed afterwards.
    # Materialize the results inside the block in case the installed version
    # returns a lazy iterator tied to the open session.
    with DDGS() as ddgs:
        results = ddgs.images(
            keywords=keywords,
            region='wt-wt',
            safesearch='on',
            size=None,
            color='color',
            type_image='photo',
            layout=None,
            license_image=None,
            max_results=max_results,
        )
        results = list(results or [])
    logger.info(f'{keywords=}: {results=}')
    return [
        ImageSearchResult(title=result.get('title', ''), url=result['image'])
        for result in results
        if result.get('image')
    ]