#!/usr/bin/env python # coding: utf-8 from langchain_openai import ChatOpenAI from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate # Remove ChatGroq import # from langchain_groq import ChatGroq # Add ChatGoogleGenerativeAI import from langchain_google_genai import ChatGoogleGenerativeAI import os # Add os import for getenv from langgraph.graph import StateGraph from llmlingua import PromptCompressor # Import get_model which now handles Gemini from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc from langgraph.checkpoint.sqlite import SqliteSaver # ... (get_llm_lingua remains the same) ... # Requires ~2GB of RAM def get_llm_lingua(compress_method:str = "llm_lingua2"): # Requires ~2GB memory if compress_method == "llm_lingua2": llm_lingua2 = PromptCompressor( model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True, device_map="cpu" ) return llm_lingua2 # Requires ~8GB memory elif compress_method == "llm_lingua": llm_lingua = PromptCompressor( model_name="microsoft/phi-2", device_map="cpu" ) return llm_lingua raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'") def compress(state: DocProcessorState, config: ConfigSchema): """ This node compresses last processing result for each doc using llm_lingua """ doc_process_histories = state["docs_in_processing"] llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2") for doc_process_history in doc_process_histories: doc_process_history.append(llm_lingua.compress_prompt( # Use str() to ensure the input is string, handle potential non-string data doc = str(doc_process_history[-1]), rate=config["configurable"].get("compress_rate") or 0.33, force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ','] )["compressed_prompt"] ) # --- MODIFICATION START --- # Ensure 'process_steps' persists in the state return { "docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1, "process_steps": state.get("process_steps", []) # Pass existing steps along } # --- MODIFICATION END --- # Update default model def summarize_docs(state: DocProcessorState, config: ConfigSchema): """ This node summarizes all docs in state["valid_docs"] """ prompt = """You are a 3GPP standardization expert. Summarize the provided document in simple technical English for other experts in the field. Document: {document}""" sysmsg = ChatPromptTemplate.from_messages([ ("system", prompt) ]) # Update default model name model = config["configurable"].get("summarize_model") or "gemini-2.0-flash" doc_process_histories = state["docs_in_processing"] # Use get_model to handle instantiation llm_summarize = get_model(model) summarize_chain = sysmsg | llm_summarize | StrOutputParser() for doc_process_history in doc_process_histories: # Use str() to ensure the input is string doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])})) # --- MODIFICATION START --- # Ensure 'process_steps' persists in the state return { "docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1, "process_steps": state.get("process_steps", []) # Pass existing steps along } # --- MODIFICATION END --- # Update default model def custom_process(state: DocProcessorState): """ Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]] processing_model : the LLM which will perform the processing context : the previous processing results to send as context to the LLM user_prompt : the prompt/task which will be appended to the context before sending to the LLM """ # Use .get() for safer access in case state mapping fails earlier process_steps_list = state.get("process_steps", []) current_step_index = state.get("current_process_step", 0) if not process_steps_list or current_step_index >= len(process_steps_list): print(f"Error: Invalid current_process_step ({current_step_index}) or empty process_steps in custom_process.") # Return state, potentially adding an error indicator if needed return state # Or modify state to indicate error processing_params = process_steps_list[current_step_index] # Ensure processing_params is a dict before accessing keys if not isinstance(processing_params, dict): print(f"Error: Expected dictionary for process step {current_step_index}, but got {type(processing_params)}. Step details: {processing_params}") # Decide how to handle: skip step, return error state? # For now, let's skip this step by incrementing the counter and returning return { "docs_in_processing": state.get("docs_in_processing", []), "current_process_step": current_step_index + 1, "process_steps": process_steps_list } # Update default model name model = processing_params.get("processing_model") or "gemini-2.0-flash" user_prompt = processing_params.get("prompt", "") # Default to empty string if missing context = processing_params.get("context", [0]) # Default to [0] doc_process_histories = state.get("docs_in_processing", []) # Default to empty list if not doc_process_histories: print("Warning: docs_in_processing is empty in custom_process.") # No docs to process, just increment step counter return { "docs_in_processing": [], "current_process_step": current_step_index + 1, "process_steps": process_steps_list } if not isinstance(context, list): context = [context] # Use get_model processing_chain = get_model(model=model) | StrOutputParser() for doc_process_history in doc_process_histories: context_str = "" for i, context_element in enumerate(context): # Check if index is valid and history is long enough if isinstance(context_element, int) and 0 <= context_element < len(doc_process_history): # Use str() to ensure context element is string context_str += f"### TECHNICAL INFORMATION {i+1} \n {str(doc_process_history[context_element])}\n\n" else: print(f"Warning: Invalid context index {context_element} for doc_process_history length {len(doc_process_history)}") # Use str() ensure input is string doc_process_history.append(processing_chain.invoke(str(context_str + user_prompt))) # --- MODIFICATION START --- # Ensure 'process_steps' persists in the state return { "docs_in_processing" : doc_process_histories, "current_process_step" : current_step_index + 1, "process_steps": process_steps_list # Pass existing steps along } # --- MODIFICATION END --- # ... (final node remains the same) ... def final(state: DocProcessorState): """ A node to store the final results of processing in the 'valid_docs' field """ # Ensure docs_in_processing exists and is a list of lists docs_in_processing = state.get("docs_in_processing", []) if not isinstance(docs_in_processing, list): docs_in_processing = [] # Safely get the last item from each inner list, default to None if empty final_docs = [ doc_history[-1] if isinstance(doc_history, list) and doc_history else None for doc_history in docs_in_processing ] # Filter out any None values that might have resulted from empty histories valid_final_docs = [doc for doc in final_docs if doc is not None] return {"valid_docs" : valid_final_docs} def get_process_steps(state: DocProcessorState, config: ConfigSchema): """ Initializes the processing state within the subgraph. It receives the 'valid_docs' and potentially 'process_steps' from the parent graph. """ # Initialize docs_in_processing based on valid_docs received from parent state valid_docs = state.get("valid_docs", []) docs_in_processing_init = [[format_doc(doc)] for doc in valid_docs if doc] # Ensure doc is not None # Explicitly return process_steps, getting it from the input state or defaulting return { "current_process_step": 0, "docs_in_processing": docs_in_processing_init, "process_steps": state.get("process_steps", []) # Ensure process_steps is set here } def next_processor_step(state: DocProcessorState): """ Conditional edge function to go to next processing step """ # --- MODIFICATION START --- # Use .get() for safer access to 'process_steps' and 'current_process_step' process_steps = state.get("process_steps", []) current_step_index = state.get("current_process_step", 0) # --- MODIFICATION END --- step = "final" # Default to final step if not isinstance(process_steps, list): print(f"Warning: 'process_steps' is not a list ({type(process_steps)}). Proceeding to final.") process_steps = [] # Treat as empty list if current_step_index < len(process_steps): step_definition = process_steps[current_step_index] if isinstance(step_definition, dict): # Assuming a dict means a 'custom' step based on original logic step = "custom" elif isinstance(step_definition, str): # Map string to node name if it's a known processing type step_lower = step_definition.lower() if step_lower in ["summarize", "compress"]: step = step_lower else: print(f"Warning: Unknown process step type string '{step_definition}'. Defaulting to 'custom'.") # Or default to 'final' if unknown steps shouldn't run custom logic # step = "final" # Let's assume unknown strings map to custom for flexibility, adjust if needed step = "custom" else: print(f"Warning: Invalid type for process step definition: {type(step_definition)}. Proceeding to final.") step = "final" # Go to final if step definition is unexpected type else: # If current_step_index is out of bounds, we should go to the final step step = "final" print(f"Next processor step determined: {step}") # Debugging print return step def build_data_processor_graph(memory): """ Builds the data processor graph """ #with SqliteSaver.from_conn_string(":memory:") as memory : graph_builder_doc_processor = StateGraph(DocProcessorState) graph_builder_doc_processor.add_node("get_process_steps", get_process_steps) graph_builder_doc_processor.add_node("summarize", summarize_docs) graph_builder_doc_processor.add_node("compress", compress) graph_builder_doc_processor.add_node("custom", custom_process) graph_builder_doc_processor.add_node("final", final) graph_builder_doc_processor.add_edge("__start__", "get_process_steps") # Conditional edges route FROM the node that just finished TO the next one graph_builder_doc_processor.add_conditional_edges( "get_process_steps", # Source node next_processor_step, # Function to decide where to go next # Map returned string from next_processor_step to actual node names {"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"} ) graph_builder_doc_processor.add_conditional_edges( "summarize", # Source node next_processor_step, {"compress" : "compress", "final": "final", "custom" : "custom", "summarize": "summarize"} # Added summarize for loops ) graph_builder_doc_processor.add_conditional_edges( "compress", # Source node next_processor_step, {"summarize" : "summarize", "final": "final", "custom" : "custom", "compress": "compress"} # Added compress for loops ) graph_builder_doc_processor.add_conditional_edges( "custom", # Source node next_processor_step, {"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"} ) graph_builder_doc_processor.add_edge("final", "__end__") graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory) return graph_doc_processor