kig_test / ki_gen /data_processor.py
adrienbrdne's picture
Update ki_gen/data_processor.py
66a2ae2 verified
#!/usr/bin/env python
# coding: utf-8
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
# Remove ChatGroq import
# from langchain_groq import ChatGroq
# Add ChatGoogleGenerativeAI import
from langchain_google_genai import ChatGoogleGenerativeAI
import os # Add os import for getenv
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor
# Import get_model which now handles Gemini
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
from langgraph.checkpoint.sqlite import SqliteSaver
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """Instantiate an LLMLingua ``PromptCompressor`` that runs on the CPU.

    Args:
        compress_method: ``"llm_lingua2"`` (default) loads the
            llmlingua-2 XLM-RoBERTa meetingbank checkpoint (roughly 2GB of
            memory); ``"llm_lingua"`` loads microsoft/phi-2 (roughly 8GB).

    Returns:
        A newly constructed ``PromptCompressor``.

    Raises:
        ValueError: if ``compress_method`` is not one of the two values above.
    """
    if compress_method not in ("llm_lingua2", "llm_lingua"):
        raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")

    if compress_method == "llm_lingua":
        # Classic LLMLingua on phi-2 -- the heavier option (~8GB).
        return PromptCompressor(model_name="microsoft/phi-2", device_map="cpu")

    # LLMLingua-2 -- the lighter option (~2GB).
    lingua2_settings = {
        "model_name": "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
        "use_llmlingua2": True,
        "device_map": "cpu",
    }
    return PromptCompressor(**lingua2_settings)
def compress(state: DocProcessorState, config: ConfigSchema):
    """Compress the latest processing result of every document with LLMLingua.

    For each per-document history in ``state["docs_in_processing"]``, the last
    entry is compressed and the compressed text is appended to that history.
    The history lists are mutated in place and also returned.

    Optional keys read from ``config["configurable"]``:
        compression_method: passed to ``get_llm_lingua``; defaults to
            ``"llm_lingua2"``.
        compress_rate: target compression rate; defaults to 0.33.
        force_tokens: tokens the compressor must keep; defaults to newline,
            '?', '.', '!' and ','.

    Returns:
        A partial state update with ``docs_in_processing``,
        ``current_process_step`` advanced by one, and ``process_steps``
        passed through unchanged.
    """
    configurable = config.get("configurable") or {}
    doc_process_histories = state.get("docs_in_processing", [])

    if doc_process_histories:
        # Loading the compressor model is expensive (~2GB / ~8GB of memory),
        # so only do it when there is actually something to compress.
        llm_lingua = get_llm_lingua(configurable.get("compression_method") or "llm_lingua2")
        # Loop-invariant settings, resolved once instead of once per document.
        rate = configurable.get("compress_rate") or 0.33
        force_tokens = configurable.get("force_tokens") or ['\n', '?', '.', '!', ',']
        for doc_process_history in doc_process_histories:
            # The text is passed positionally: LLMLingua's compress_prompt
            # names its first parameter ``context``, so a ``doc=`` keyword
            # is not accepted. str() guards against non-string history entries.
            result = llm_lingua.compress_prompt(
                str(doc_process_history[-1]),
                rate=rate,
                force_tokens=force_tokens,
            )
            doc_process_history.append(result["compressed_prompt"])

    # Keep 'process_steps' in the returned state so routing can continue.
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state.get("current_process_step", 0) + 1,
        "process_steps": state.get("process_steps", []),
    }
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """Summarize the latest processing result of every document with an LLM.

    For each per-document history in ``state["docs_in_processing"]``, the
    model's summary of the last entry is appended to that history. The lists
    are mutated in place and also returned. When there are no documents, no
    model is instantiated and only the step counter advances, mirroring
    ``custom_process``.

    The model name comes from ``config["configurable"]["summarize_model"]``
    (default ``"gemini-2.0-flash"``) and is instantiated through ``get_model``.

    Returns:
        A partial state update with ``docs_in_processing``,
        ``current_process_step`` advanced by one, and ``process_steps``
        passed through unchanged.
    """
    configurable = config.get("configurable") or {}
    doc_process_histories = state.get("docs_in_processing", [])

    if doc_process_histories:
        prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}"""
        sysmsg = ChatPromptTemplate.from_messages([
            ("system", prompt)
        ])
        model = configurable.get("summarize_model") or "gemini-2.0-flash"
        summarize_chain = sysmsg | get_model(model) | StrOutputParser()
        for doc_process_history in doc_process_histories:
            # str() guards against non-string history entries.
            doc_process_history.append(
                summarize_chain.invoke({"document": str(doc_process_history[-1])})
            )

    # Keep 'process_steps' in the returned state so routing can continue.
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state.get("current_process_step", 0) + 1,
        "process_steps": state.get("process_steps", []),
    }
def _render_context(doc_process_history, context):
    """Join the requested history entries into a single LLM context string.

    Each entry is labelled "### TECHNICAL INFORMATION <n>", where <n> is its
    1-based position in ``context``. Skipped entries therefore leave gaps in
    the numbering. Only ``int`` indices with 0 <= index < len(history) are
    used; anything else is reported with a warning and skipped.
    """
    sections = []
    for position, index in enumerate(context, start=1):
        if not (isinstance(index, int) and 0 <= index < len(doc_process_history)):
            print(f"Warning: Invalid context index {index} for doc_process_history length {len(doc_process_history)}")
            continue
        sections.append(f"### TECHNICAL INFORMATION {position} \n {str(doc_process_history[index])}\n\n")
    return "".join(sections)


def custom_process(state: DocProcessorState):
    """Run one user-defined LLM processing step over every document.

    The step's parameters are the dict found at
    ``state["process_steps"][state["current_process_step"]]``:
        processing_model: LLM name handed to ``get_model``
            (default "gemini-2.0-flash").
        context: an index, or list of indices, into each document's
            processing history; those entries are sent as context (default [0]).
        prompt: the task appended after the context (default "").

    The LLM's answer is appended to each document's history.

    Returns:
        ``state`` itself when there is no step at the current index.
        Otherwise, a partial update with ``docs_in_processing``, the step
        counter advanced by one, and ``process_steps`` passed through. A
        non-dict step, or an empty document list, is skipped: only the
        counter advances.
    """
    steps = state.get("process_steps", [])
    step_index = state.get("current_process_step", 0)

    if not steps or step_index >= len(steps):
        print(f"Error: Invalid current_process_step ({step_index}) or empty process_steps in custom_process.")
        return state

    params = steps[step_index]
    if not isinstance(params, dict):
        print(f"Error: Expected dictionary for process step {step_index}, but got {type(params)}. Step details: {params}")
        # Skip the malformed step but still advance so routing can move on.
        return {
            "docs_in_processing": state.get("docs_in_processing", []),
            "current_process_step": step_index + 1,
            "process_steps": steps,
        }

    model_name = params.get("processing_model") or "gemini-2.0-flash"
    user_prompt = params.get("prompt", "")
    context = params.get("context", [0])

    histories = state.get("docs_in_processing", [])
    if not histories:
        print("Warning: docs_in_processing is empty in custom_process.")
        return {
            "docs_in_processing": [],
            "current_process_step": step_index + 1,
            "process_steps": steps,
        }

    context_indices = context if isinstance(context, list) else [context]
    chain = get_model(model=model_name) | StrOutputParser()
    for history in histories:
        context_text = _render_context(history, context_indices)
        history.append(chain.invoke(str(context_text + user_prompt)))

    # Keep 'process_steps' in the returned state so routing can continue.
    return {
        "docs_in_processing": histories,
        "current_process_step": step_index + 1,
        "process_steps": steps,
    }
def final(state: DocProcessorState):
    """Collect each document's final processing result into ``valid_docs``.

    The last entry of every non-empty list in ``state["docs_in_processing"]``
    is kept. Non-list items, empty histories, and histories whose last entry
    is ``None`` contribute nothing. A non-list ``docs_in_processing`` is
    treated as empty.
    """
    histories = state.get("docs_in_processing", [])
    if not isinstance(histories, list):
        histories = []

    last_results = []
    for history in histories:
        if isinstance(history, list) and history and history[-1] is not None:
            last_results.append(history[-1])
    return {"valid_docs": last_results}
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """Initialise the subgraph's processing state.

    Every truthy entry of ``state["valid_docs"]`` becomes a one-element
    history ``[format_doc(doc)]``. The step counter is reset to 0, and
    ``process_steps`` is copied from the incoming state (default ``[]``).
    ``config`` is accepted for node-signature compatibility but not used.
    """
    initial_histories = []
    for doc in state.get("valid_docs", []):
        if doc:  # skip falsy entries such as None or empty values
            initial_histories.append([format_doc(doc)])

    return {
        "current_process_step": 0,
        "docs_in_processing": initial_histories,
        "process_steps": state.get("process_steps", []),
    }
def next_processor_step(state: DocProcessorState):
    """Conditional-edge router: pick the node that runs the next processing step.

    The entry ``state["process_steps"][state["current_process_step"]]`` decides:
        * a dict -> "custom";
        * the string "summarize" or "compress" (any case) -> that node;
        * any other string -> "custom" (a warning is printed);
        * any other type -> "final" (a warning is printed).
    Returns "final" when the index is past the end of the list, or when
    ``process_steps`` is not a list. The chosen step is printed for debugging.
    """
    steps = state.get("process_steps", [])
    index = state.get("current_process_step", 0)

    if not isinstance(steps, list):
        print(f"Warning: 'process_steps' is not a list ({type(steps)}). Proceeding to final.")
        steps = []

    if index >= len(steps):
        next_step = "final"
    else:
        definition = steps[index]
        if isinstance(definition, dict):
            next_step = "custom"
        elif isinstance(definition, str):
            lowered = definition.lower()
            if lowered in ("summarize", "compress"):
                next_step = lowered
            else:
                print(f"Warning: Unknown process step type string '{definition}'. Defaulting to 'custom'.")
                # Unknown strings go to "custom", which skips non-dict steps.
                next_step = "custom"
        else:
            print(f"Warning: Invalid type for process step definition: {type(definition)}. Proceeding to final.")
            next_step = "final"

    print(f"Next processor step determined: {next_step}")
    return next_step
def build_data_processor_graph(memory):
    """Build and compile the document-processing subgraph.

    Flow: ``__start__`` -> ``get_process_steps`` -> any sequence of
    ``summarize`` / ``compress`` / ``custom`` nodes, as chosen by
    ``next_processor_step`` -> ``final`` -> ``__end__``.

    Args:
        memory: checkpointer handed to ``StateGraph.compile``.

    Returns:
        The compiled graph.
    """
    builder = StateGraph(DocProcessorState)

    node_functions = {
        "get_process_steps": get_process_steps,
        "summarize": summarize_docs,
        "compress": compress,
        "custom": custom_process,
        "final": final,
    }
    for node_name, node_fn in node_functions.items():
        builder.add_node(node_name, node_fn)

    builder.add_edge("__start__", "get_process_steps")

    # Every node except "final" routes through next_processor_step. Each
    # processing node can also route to itself, so consecutive steps of the
    # same kind loop back.
    routes = {name: name for name in ("compress", "final", "summarize", "custom")}
    for source in ("get_process_steps", "summarize", "compress", "custom"):
        builder.add_conditional_edges(source, next_processor_step, dict(routes))

    builder.add_edge("final", "__end__")
    return builder.compile(checkpointer=memory)