Merlin-AI-Coach / components /stage_mapping.py
naishwarya's picture
temp fix hf space
860db05
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex, Document
# from llama_index.llms.openllm import OpenLLM
from llama_index.llms.nebius import NebiusLLM
import requests
import os
# Load environment variables from .env if present
from dotenv import load_dotenv
load_dotenv()
# Read provider, API credentials, and model names from the environment.
# LLM_PROVIDER is lower-cased so the comparison below is case-insensitive.
LLM_PROVIDER = os.environ.get("LLM_PROVIDER", "openllm").lower()
LLM_API_URL = os.environ.get("LLM_API_URL")
LLM_API_KEY = os.environ.get("LLM_API_KEY")
NEBIUS_API_KEY = os.environ.get("NEBIUS_API_KEY", "")
OPENLLM_MODEL = os.environ.get("OPENLLM_MODEL", "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16")
NEBIUS_MODEL = os.environ.get("NEBIUS_MODEL", "meta-llama/Llama-3.3-70B-Instruct")

# Choose the LLM provider.
if LLM_PROVIDER == "nebius":
    llm = NebiusLLM(
        api_key=NEBIUS_API_KEY,
        model=NEBIUS_MODEL,
    )
else:
    # The OpenLLM client is currently disabled (its import is commented out
    # at the top of the file). Bind ``llm`` anyway so that code referencing
    # the module-level name does not fail with NameError.
    # NOTE(review): passing ``llm=None`` to ``as_query_engine`` presumably
    # makes llama_index fall back to its globally configured LLM -- confirm
    # against the installed llama_index version.
    llm = None
    # llm = OpenLLM(
    #     model=OPENLLM_MODEL,
    #     api_base=LLM_API_URL,
    #     api_key=LLM_API_KEY
    # )
# Example stage catalogue: one Document per stage, each written as
# "<stage name>: <short description>". Edit the texts below to change
# what user input is matched against.
STAGE_DOCS = [
    Document(text=stage_text)
    for stage_text in (
        "Goal setting: Define what you want to achieve.",
        "Research: Gather information and resources.",
        "Planning: Break down your goal into actionable steps.",
        "Execution: Start working on your plan.",
        "Review: Reflect on your progress and adjust as needed.",
    )
]
# Per-stage guidance, keyed by stage name. Insertion order follows the
# order in which the stages are listed here.
STAGE_INSTRUCTIONS = dict(
    [
        (
            "Goal setting",
            "After trying to understand the goal, "
            "before moving to the next phase, "
            "write down key objectives that the user is interested in.",
        ),
        (
            "Research",
            "Before suggesting something to the user, think deeply about what "
            "scientific approach you are using to suggest something or ask a question. "
            "Before moving to a new phase, summarize in a detailed format "
            "the key findings of research and intuition.",
        ),
        (
            "Planning",
            "Provide a detailed actionable plan with a proper timeline. "
            "Try to create tasks in 3 types: "
            "Important and have a deadline, "
            "Important but do not have a timeline, "
            "Not important and has a deadline.",
        ),
        (
            "Execution",
            "Focus on helping the user execute the plan step by step. "
            "Offer encouragement and practical advice.",
        ),
        (
            "Review",
            "Help the user reflect on progress, identify what worked, "
            "and suggest adjustments for future improvement.",
        ),
    ]
)
def get_stage_instruction(stage_name):
    """
    Look up the instruction text registered for ``stage_name``.

    The lookup is an exact key match against STAGE_INSTRUCTIONS; an
    unknown name yields an empty string rather than an error.
    """
    if stage_name in STAGE_INSTRUCTIONS:
        return STAGE_INSTRUCTIONS[stage_name]
    return ""
def build_index():
    """
    Create a VectorStoreIndex over STAGE_DOCS.

    Embeddings come from the "sentence-transformers/all-MiniLM-L6-v2"
    HuggingFace model. The index is always built directly from the
    documents, so the node text is present on retrieved results.
    """
    encoder = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    stage_index = VectorStoreIndex.from_documents(STAGE_DOCS, embed_model=encoder)
    return stage_index
# Build the index once, when this module is imported, and reuse the same
# object for all queries (map_stage and get_stage_and_details both read it).
index = build_index()
def map_stage(user_input):
    """
    Map free-form user input to the best-matching stage.

    Delegates to get_stage_and_details, which runs the query with a single
    top match and the module-level LLM, and repackages its result.

    Returns a dict with two keys:
      "stage"   -- text of the first source node of the response (for the
                   documents in STAGE_DOCS this is the whole
                   "name: description" string, not just the name)
      "details" -- the generated response text
    """
    matched_stage, generated_details = get_stage_and_details(user_input)
    return {"stage": matched_stage, "details": generated_details}
def get_stage_and_details(user_input):
    """
    Query the stage index and return ``(stage, details)`` for ``user_input``.

    ``stage`` is the text of the first source node attached to the
    response; ``details`` is the response text produced by the query engine.
    """
    # Ask for a single top match and answer with the module-level LLM.
    engine = index.as_query_engine(similarity_top_k=1, llm=llm)
    result = engine.query(user_input)
    top_match = result.source_nodes[0]
    return top_match.node.text, result.response
def clear_vector_store():
    """
    Delete the file at VECTOR_STORE_PATH, if such a path is configured.

    Does nothing when the module defines no VECTOR_STORE_PATH (or defines
    it as an empty value), and does nothing when the file is already gone.
    Other OS errors raised by ``os.remove`` still propagate to the caller.
    """
    # Read the name through globals() so that a module without a
    # VECTOR_STORE_PATH constant is treated as "nothing to clear" instead
    # of failing with NameError.
    store_path = globals().get("VECTOR_STORE_PATH")
    if not store_path:
        return
    try:
        os.remove(store_path)
    except FileNotFoundError:
        # Already removed (possibly between two calls); nothing left to do.
        pass
def get_stage_list():
    """
    Return the stage names in their fixed order, first stage first.

    A new list is created on every call, so callers may mutate the result
    without affecting later calls.
    """
    ordered_names = (
        "Goal setting",
        "Research",
        "Planning",
        "Execution",
        "Review",
    )
    return list(ordered_names)
def get_next_stage(current_stage):
    """
    Return the name of the stage that follows ``current_stage``.

    Returns None when ``current_stage`` is the last stage or is not one of
    the names produced by get_stage_list.
    """
    ordered = get_stage_list()
    if current_stage not in ordered:
        return None
    following = ordered.index(current_stage) + 1
    if following >= len(ordered):
        # Already at the final stage.
        return None
    return ordered[following]
def get_stage_index(stage_name):
    """
    Return the zero-based position of ``stage_name`` in the ordered stage
    list, or -1 when the name is not a known stage.
    """
    for position, known_name in enumerate(get_stage_list()):
        if known_name == stage_name:
            return position
    return -1