File size: 3,412 Bytes
c6cd0dd
 
 
 
 
 
cf966ea
c6cd0dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf966ea
c6cd0dd
8ddaa0b
c6cd0dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf966ea
 
c6cd0dd
cf966ea
c6cd0dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.tools.retriever import create_retriever_tool
from langchain_qdrant  import QdrantVectorStore
from qdrant_client import QdrantClient
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from tools import multiply,add,subtract,divide,modulus,wiki_search,duckduckgo_search,arvix_search


# Pull API keys and endpoint settings from a local .env file into os.environ.
load_dotenv()

# The system prompt is read once at import time; the path is relative to the
# current working directory.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)

# Embedding model used to query the vector store; runs on CPU.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/static-similarity-mrl-multilingual-v1",
    model_kwargs={"device": "cpu"},
)

# Qdrant connection settings are taken from the environment.
qdrant = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_SERVICE_KEY"),
)

# QdrantVectorStore names its embedding-model parameter `embedding` (singular);
# passing `embeddings=` raises TypeError at construction.
vector_store = QdrantVectorStore(
    client=qdrant,
    collection_name="documents",
    embedding=embeddings,
)

# Bound to its own name so the imported `create_retriever_tool` factory is not
# overwritten by the tool instance it returns.
# NOTE(review): this tool is not included in `tools` below, and its name
# contains a space, which function-calling providers may reject -- confirm
# the intended name before binding it to an LLM.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Tools exposed to the LLM and executed by the graph's tool node.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    duckduckgo_search,
    arvix_search,
]

def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    The graph starts at the ``retriever`` node, continues to ``assistant``,
    and follows ``tools_condition`` from ``assistant``; the ``tools`` node
    always feeds back into ``assistant``.

    Args:
        provider: Which chat model backend to use: ``"google"``, ``"groq"``
            or ``"huggingface"``. The model name (or, for ``"huggingface"``,
            the endpoint URL) is read from the ``GEMINI_MODEL``,
            ``GROQ_MODEL`` or ``HUGGINGFACEHUB_URL`` environment variable
            respectively.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        ValueError: If ``provider`` is not one of the supported values, or
            if the environment variable for the chosen provider is unset or
            empty.
    """
    model_env_vars = {
        "google": "GEMINI_MODEL",
        "groq": "GROQ_MODEL",
        "huggingface": "HUGGINGFACEHUB_URL",
    }
    if provider not in model_env_vars:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Fail early with a readable message instead of handing None to the
    # provider client.
    env_var = model_env_vars[provider]
    model = os.environ.get(env_var)
    if not model:
        raise ValueError(
            f"Environment variable {env_var} must be set for provider '{provider}'."
        )

    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model=model, temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model=model, temperature=0)
    else:
        # Hugging Face: HUGGINGFACEHUB_URL holds the inference endpoint URL,
        # which HuggingFaceEndpoint takes as `endpoint_url`.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url=model,
                temperature=0,
            ),
        )

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-enabled LLM on the message history."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: add the system prompt and a similar stored example."""
        question = state["messages"][0]
        print(question)
        # Messages carry their text in `.content`; the Documents returned by
        # the vector store carry theirs in `.page_content`.
        similar_question = vector_store.similarity_search(question.content)
        if not similar_question:
            # Nothing comparable in the store: continue without an example.
            return {"messages": [sys_msg] + state["messages"]}
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()