from langchain_core.runnables import RunnableConfig
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStoreRetriever

import ast
import numpy as np
import pandas as pd
from contextlib import contextmanager
from typing import Generator

from ea4all.src.shared.utils import _join_paths
from ea4all.src.shared.configuration import BaseConfiguration

# Module-level cache: the FAISS vectorstore is loaded or created once per process.
_vectorstore = None

def make_text_encoder(model: str) -> Embeddings:
    """Connect to the configured text encoder."""
    provider, model = model.split("/", maxsplit=1)
    match provider:
        case "openai":
            from langchain_openai import OpenAIEmbeddings

            return OpenAIEmbeddings(model=model)
        case _:
            raise ValueError(f"Unsupported embedding provider: {provider}")

@contextmanager
def make_faiss_retriever(
    configuration: BaseConfiguration, embeddings: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to connect to a FAISS index & namespaces."""
    from langchain_community.docstore.in_memory import InMemoryDocstore
    from langchain_community.vectorstores import FAISS
    import faiss

    global _vectorstore

    if _vectorstore is None:
        try:
            _vectorstore = FAISS.load_local(
                folder_path=configuration.ea4all_store,
                embeddings=embeddings,
                index_name=configuration.apm_faiss,
                allow_dangerous_deserialization=True)
            
        except Exception:
            # Index not found (or failed to load): create an empty index sized to the embedding dimension.
            index = faiss.IndexFlatL2(len(embeddings.embed_query("")))

            # Initialise an empty FAISS vectorstore backed by an in-memory docstore.
            _vectorstore = FAISS(
                embedding_function=embeddings,
                index=index,
                docstore=InMemoryDocstore(),
                index_to_docstore_id={},
            )
            # To seed and persist the index from the APM catalogue instead:
            #apm_docs = get_apm_excel_content(configuration)
            #_vectorstore = FAISS.from_documents(apm_docs, embeddings)
            #_vectorstore.save_local(folder_path=configuration.ea4all_store, index_name=configuration.apm_faiss)

    search_kwargs = configuration.search_kwargs

    yield _vectorstore.as_retriever(search_type="similarity", search_kwargs=search_kwargs)
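# Example (sketch, not executed): using the FAISS retriever directly, e.g. from a
# script or test, outside the RunnableConfig flow. Assumes BaseConfiguration provides
# usable defaults; the query is illustrative.
#
#   configuration = BaseConfiguration()
#   embeddings = make_text_encoder(configuration.embedding_model)
#   with make_faiss_retriever(configuration, embeddings) as retriever:
#       docs = retriever.invoke("applications supporting customer onboarding")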

@contextmanager
def make_retriever(
    config: RunnableConfig,
) -> Generator[VectorStoreRetriever, None, None]:
    """Create a retriever for the agent, based on the current configuration."""
    configuration = BaseConfiguration.from_runnable_config(config)
    embeddings = make_text_encoder(configuration.embedding_model)
    match configuration.retriever_provider:
        case "faiss":
            with make_faiss_retriever(configuration, embeddings) as retriever:
                yield retriever

        case _:
            raise ValueError(
                "Unrecognized retriever_provider in configuration. "
                f"Expected one of: {', '.join(BaseConfiguration.__annotations__['retriever_provider'].__args__)}\n"
                f"Got: {configuration.retriever_provider}"
            )
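# Example (sketch, not executed): make_retriever is the entry point used by the agent
# graph. The "configurable" keys shown are illustrative and depend on BaseConfiguration.
#
#   config = RunnableConfig(configurable={"retriever_provider": "faiss"})
#   with make_retriever(config) as retriever:
#       docs = retriever.invoke("Which applications support order management?")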

# Convert an APM dataframe into LangChain Document objects, tagging each row with the user's namespace.
def panda_to_langchain_document(dataframe, user_ip):
    # Create an empty list to store the documents.
    apm_documents = []
    # Build one Document per row: concatenate every column into page_content and
    # promote well-known columns (application, capability, ...) into metadata.
    for index, row in dataframe.iterrows():
        page_content = ""
        application = ""
        capability = ""
        description = ""
        fit = ""
        roadmap = ""
        for column in dataframe.columns:
            # Normalise the column label for matching/display, but keep the original
            # column name to index the row (avoids a KeyError when a header carries extra whitespace).
            label = ' '.join(str(column).split())
            value = row[column]
            page_content += f" {label}:{value}"
            if 'application' in label.lower(): application = value
            elif 'capabilit' in label.lower(): capability = value
            elif 'desc' in label.lower(): description = value
            elif 'business fit' in label.lower(): fit = value
            elif 'roadmap' in label.lower(): roadmap = value
        doc = Document(
            page_content=page_content,
            metadata={
                "source": application,
                "capability": capability,
                "description": description,
                "business fit": fit,
                "roadmap": roadmap,
                "row_number": index,
                "namespace": user_ip,
            }
        )
        # Append the document object to the list.
        apm_documents.append(doc)
    return apm_documents

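# Example (sketch, not executed): a two-column dataframe yields one Document per row,
# with the "Application" column promoted into the "source" metadata field.
#
#   df = pd.DataFrame({"Application": ["CRM"], "Description": ["Customer records"]})
#   docs = panda_to_langchain_document(df, user_ip="ea4all_agent")
#   docs[0].metadata["source"]      # -> "CRM"
#   docs[0].metadata["namespace"]   # -> "ea4all_agent"
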
# Load the local landscape data (Excel file) into a dataframe.
def apm_dataframe_loader(file):
    pd.set_option('display.max_colwidth', None)
    df = pd.read_excel(file)
    # Drop fully empty rows and columns, then replace remaining NaNs with a literal marker.
    df = df.dropna(axis=0, how='all')
    df = df.dropna(axis=1, how='all')
    df = df.fillna('NaN')

    return df

# Load the APM Excel catalogue and convert it into LangChain documents.
def get_apm_excel_content(config: RunnableConfig, file=None, user_ip="ea4all_agent"):
    if file is None:
        # Fall back to the catalogue shipped with the EA4ALL store when no file is provided.
        file = _join_paths(
            getattr(config, "ea4all_store", BaseConfiguration.ea4all_store),
            getattr(config, "apm_catalogue", BaseConfiguration.apm_catalogue)
        )

    # Load the file into a dataframe and tag every row with the user's namespace.
    df = apm_dataframe_loader(file)
    df['namespace'] = user_ip

    apm_docs = panda_to_langchain_document(df, user_ip)
    return apm_docs

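# Example (sketch, not executed): indexing a user-uploaded catalogue. The file path and
# user id are illustrative; `db` is assumed to be the loaded FAISS vectorstore.
#
#   docs = get_apm_excel_content(config, file="uploads/my_landscape.xlsx", user_ip="10.0.0.1")
#   db.add_documents(docs)
#   db.save_local(folder_path=config.ea4all_store, index_name=config.apm_faiss)
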
def remove_user_apm_faiss(config, db, ea4all_user):
    """Remove any APM documents previously uploaded by this user from the FAISS index."""
    # Check whether the user has uploaded an APM catalogue before (BYOD).
    byod = ea4all_user in str(db.docstore._dict.values())

    if byod:
        # Collect the docstore ids that belong to the user's namespace.
        removed_ids = [
            doc_id
            for doc_id, doc in db.docstore._dict.items()
            if doc.metadata['namespace'] == ea4all_user
        ]

        if removed_ids:
            # FAISS.delete removes the documents from the docstore, drops the matching
            # embeddings from the index and re-maps index_to_docstore_id
            # (langchain_community behaviour), so no manual index surgery is needed.
            db.delete(ids=removed_ids)
            # Persist the updated index.
            db.save_local(folder_path=config.ea4all_store, index_name=config.apm_faiss)

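# Example (sketch, not executed): clearing a user's previous upload before re-indexing
# a new catalogue. The user id is illustrative.
#
#   remove_user_apm_faiss(config, db, ea4all_user="10.0.0.1")
#   db.add_documents(get_apm_excel_content(config, file="uploads/new_landscape.xlsx", user_ip="10.0.0.1"))
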
# Expose the FAISS index as a retriever filtered to the caller's namespace.
def retriever_faiss(db, user_ip="ea4all_agent"):
    # Index size: len(db.index_to_docstore_id) == db.index.ntotal

    # Check whether the user has brought their own data (BYOD);
    # otherwise fall back to the shared "ea4all_agent" namespace.
    byod = user_ip in str(db.docstore._dict.values())
    namespace = user_ip if byod else "ea4all_agent"

    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={'k': 50, 'score_threshold': 0.8, 'filter': {'namespace': namespace}})

    return retriever
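# Example (sketch, not executed): querying the landscape for a given user; results fall
# back to the shared "ea4all_agent" namespace when the user never uploaded data.
#
#   retriever = retriever_faiss(db, user_ip="10.0.0.1")
#   docs = retriever.invoke("applications mapped to the payments capability")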