# Source: ea4all/src/shared/configuration.py (commit 828e50a)
"""Define the configurable parameters for the agent."""
from __future__ import annotations

import ast
import logging
from dataclasses import dataclass, field, fields
from typing import Annotated, Any, Optional, Type, TypeVar, Literal

from langchain_core.runnables import RunnableConfig, ensure_config
# This file contains sample APPLICATIONS to index
DEFAULT_APM_CATALOGUE = "APM-ea4all (test-split).xlsx"
# These files contains sample QUESTIONS
APM_MOCK_QNA = "apm_qna_mock.txt"
PMO_MOCK_QNA = "pmo_qna_mock.txt"
@dataclass(kw_only=True)
class BaseConfiguration:
"""Configuration class for all Agents.
This class defines the parameters needed for configuring the indexing and
retrieval processes, including embedding model selection, retriever provider choice, and search parameters.
"""
supervisor_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
default="gpt-4o-mini",
metadata={
"description": "The language model used for supervisor agents. Should be in the form: provider/model-name."
},
)
api_base_url: Annotated[str, {"__template_metadata__": {"kind": "hosting"}}] = field(
default="https://api-inference.huggingface.co/models/",
metadata={
"description": "The base url for models hosted on Hugging Face's model hub."
},
)
max_tokens: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field(
default=4096,
metadata={
"description": "The maximum number of tokens allowed for in general question and answer model."
},
)
temperature: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field(
default=0,
metadata={
"description": "The default tempature to infere the LLM."
},
)
streaming: Annotated[bool, {"__template_metadata__": {"kind": "llm"}}] = field(
default=True,
metadata={
"description": "Default streaming mode."
},
)
ea4all_images: str = field(
default="ea4all/images",
metadata={
"description": "Configuration for the EA4ALL images folder."
},
)
ea4all_store: Annotated[str, {"__template_metadata__": {"kind": "infra"}}] = field(
default="ea4all/ea4all_store",
metadata={
"description": "The EA4ALL folder for mock & demo content."
},
)
ea4all_ask_human: Annotated[str, {"__template_metadata__": {"kind": "integration"}}] = field(
default="interrupt", #"Frontend"
metadata={
"description": "Trigger EA4ALL ask human input via interruption or receive from external frontend."
},
)
ea4all_recursion_limit: Annotated[int, {"__template_metadata__": {"kind": "graph"}}] = field(
default=25,
metadata={
"description": "Maximum recursion allowed for EA4ALL graphs."
},
)
# models
embedding_model: Annotated[str, {"__template_metadata__": {"kind": "embeddings"}}] = field(
default="openai/text-embedding-3-small",
metadata={
"description": "Name of the embedding model to use. Must be a valid embedding model name."
},
)
retriever_provider: Annotated[
Literal["faiss"],
{"__template_metadata__": {"kind": "retriever"}},
] = field(
default="faiss",
metadata={
"description": "The vector store provider to use for retrieval. Options are 'FAISS' at moment only."
},
)
apm_faiss: Annotated[str, {"__template_metadata__": {"kind": "infra"}}] = field(
default="apm_faiss_index",
metadata={
"description": "The EA4ALL APM default Vectorstore index name."
},
)
apm_catalogue: str = field(
default=DEFAULT_APM_CATALOGUE,
metadata={
"description": "The EA4ALL APM default Vectorstore index name."
},
)
search_kwargs: Annotated[str, {"__template_metadata__": {"kind": "retriever"}}] = field(
#default="{'k': 50, 'score_threshold': 0.8, 'filter': {'namespace':'ea4all_agent'}}",
default="{'k':10, 'fetch_k':50}",
metadata={
"description": "Additional keyword arguments to pass to the search function of the retriever."
}
)
def __post_init__(self):
# Convert search_kwargs from string to dictionary
try:
if isinstance(self.search_kwargs, str):
self.search_kwargs = ast.literal_eval(self.search_kwargs)
except (SyntaxError, ValueError):
# Fallback to an empty dict or log an error
self.search_kwargs = {}
print("Error parsing search_kwargs")
@classmethod
def from_runnable_config(
cls: Type[T], config: Optional[RunnableConfig] = None
) -> T:
"""Create an IndexConfiguration instance from a RunnableConfig object.
Args:
cls (Type[T]): The class itself.
config (Optional[RunnableConfig]): The configuration object to use.
Returns:
T: An instance of IndexConfiguration with the specified configuration.
"""
config = ensure_config(config)
configurable = config.get("configurable") or {}
_fields = {f.name for f in fields(cls) if f.init}
# Special handling for search_kwargs
if 'search_kwargs' in configurable and isinstance(configurable['search_kwargs'], str):
try:
configurable['search_kwargs'] = ast.literal_eval(configurable['search_kwargs'])
except (SyntaxError, ValueError):
configurable['search_kwargs'] = {}
return cls(**{k: v for k, v in configurable.items() if k in _fields})
T = TypeVar("T", bound=BaseConfiguration)