|
"""Define the configurable parameters for the agent.""" |
|
|
|
from __future__ import annotations

import ast
import logging
from dataclasses import dataclass, field, fields
from typing import Annotated, Any, Optional, Type, TypeVar, Literal

from langchain_core.runnables import RunnableConfig, ensure_config
|
|
|
|
|
# Default Excel workbook holding the APM (application portfolio) catalogue
# used as demo content.
DEFAULT_APM_CATALOGUE = "APM-ea4all (test-split).xlsx"


# Mock question-and-answer fixture file names for the APM and PMO agents
# (demo/test content).
APM_MOCK_QNA = "apm_qna_mock.txt"

PMO_MOCK_QNA = "pmo_qna_mock.txt"
|
|
|
@dataclass(kw_only=True) |
|
class BaseConfiguration: |
|
"""Configuration class for all Agents. |
|
|
|
This class defines the parameters needed for configuring the indexing and |
|
retrieval processes, including embedding model selection, retriever provider choice, and search parameters. |
|
""" |
|
|
|
supervisor_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( |
|
default="gpt-4o-mini", |
|
metadata={ |
|
"description": "The language model used for supervisor agents. Should be in the form: provider/model-name." |
|
}, |
|
) |
|
|
|
api_base_url: Annotated[str, {"__template_metadata__": {"kind": "hosting"}}] = field( |
|
default="https://api-inference.huggingface.co/models/", |
|
metadata={ |
|
"description": "The base url for models hosted on Hugging Face's model hub." |
|
}, |
|
) |
|
|
|
max_tokens: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field( |
|
default=4096, |
|
metadata={ |
|
"description": "The maximum number of tokens allowed for in general question and answer model." |
|
}, |
|
) |
|
|
|
temperature: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field( |
|
default=0, |
|
metadata={ |
|
"description": "The default tempature to infere the LLM." |
|
}, |
|
) |
|
|
|
streaming: Annotated[bool, {"__template_metadata__": {"kind": "llm"}}] = field( |
|
default=True, |
|
metadata={ |
|
"description": "Default streaming mode." |
|
}, |
|
) |
|
|
|
ea4all_images: str = field( |
|
default="ea4all/images", |
|
metadata={ |
|
"description": "Configuration for the EA4ALL images folder." |
|
}, |
|
) |
|
|
|
ea4all_store: Annotated[str, {"__template_metadata__": {"kind": "infra"}}] = field( |
|
default="ea4all/ea4all_store", |
|
metadata={ |
|
"description": "The EA4ALL folder for mock & demo content." |
|
}, |
|
) |
|
|
|
ea4all_ask_human: Annotated[str, {"__template_metadata__": {"kind": "integration"}}] = field( |
|
default="interrupt", |
|
metadata={ |
|
"description": "Trigger EA4ALL ask human input via interruption or receive from external frontend." |
|
}, |
|
) |
|
|
|
ea4all_recursion_limit: Annotated[int, {"__template_metadata__": {"kind": "graph"}}] = field( |
|
default=25, |
|
metadata={ |
|
"description": "Maximum recursion allowed for EA4ALL graphs." |
|
}, |
|
) |
|
|
|
|
|
embedding_model: Annotated[str, {"__template_metadata__": {"kind": "embeddings"}}] = field( |
|
default="openai/text-embedding-3-small", |
|
metadata={ |
|
"description": "Name of the embedding model to use. Must be a valid embedding model name." |
|
}, |
|
) |
|
|
|
retriever_provider: Annotated[ |
|
Literal["faiss"], |
|
{"__template_metadata__": {"kind": "retriever"}}, |
|
] = field( |
|
default="faiss", |
|
metadata={ |
|
"description": "The vector store provider to use for retrieval. Options are 'FAISS' at moment only." |
|
}, |
|
) |
|
|
|
apm_faiss: Annotated[str, {"__template_metadata__": {"kind": "infra"}}] = field( |
|
default="apm_faiss_index", |
|
metadata={ |
|
"description": "The EA4ALL APM default Vectorstore index name." |
|
}, |
|
) |
|
|
|
apm_catalogue: str = field( |
|
default=DEFAULT_APM_CATALOGUE, |
|
metadata={ |
|
"description": "The EA4ALL APM default Vectorstore index name." |
|
}, |
|
) |
|
|
|
search_kwargs: Annotated[str, {"__template_metadata__": {"kind": "retriever"}}] = field( |
|
|
|
default="{'k':10, 'fetch_k':50}", |
|
metadata={ |
|
"description": "Additional keyword arguments to pass to the search function of the retriever." |
|
} |
|
) |
|
|
|
def __post_init__(self): |
|
|
|
try: |
|
if isinstance(self.search_kwargs, str): |
|
self.search_kwargs = ast.literal_eval(self.search_kwargs) |
|
except (SyntaxError, ValueError): |
|
|
|
self.search_kwargs = {} |
|
print("Error parsing search_kwargs") |
|
|
|
@classmethod |
|
def from_runnable_config( |
|
cls: Type[T], config: Optional[RunnableConfig] = None |
|
) -> T: |
|
"""Create an IndexConfiguration instance from a RunnableConfig object. |
|
|
|
Args: |
|
cls (Type[T]): The class itself. |
|
config (Optional[RunnableConfig]): The configuration object to use. |
|
|
|
Returns: |
|
T: An instance of IndexConfiguration with the specified configuration. |
|
""" |
|
config = ensure_config(config) |
|
configurable = config.get("configurable") or {} |
|
_fields = {f.name for f in fields(cls) if f.init} |
|
|
|
|
|
if 'search_kwargs' in configurable and isinstance(configurable['search_kwargs'], str): |
|
try: |
|
configurable['search_kwargs'] = ast.literal_eval(configurable['search_kwargs']) |
|
except (SyntaxError, ValueError): |
|
configurable['search_kwargs'] = {} |
|
|
|
return cls(**{k: v for k, v in configurable.items() if k in _fields}) |
|
|
|
# TypeVar bound to BaseConfiguration so ``from_runnable_config`` returns the
# concrete subclass it is invoked on; defined after the class so the bound can
# reference it directly (annotations are lazy via ``from __future__ import annotations``).
T = TypeVar("T", bound=BaseConfiguration)
|
|