#!/usr/bin/env python
# coding: utf-8

import os

from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
from llmlingua import PromptCompressor

# get_model handles model instantiation (including Gemini models)
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc

# Requires ~2GB of RAM
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    # Requires ~2GB memory
    if compress_method == "llm_lingua2":
        llm_lingua2 = PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu"
        )
        return llm_lingua2
    # Requires ~8GB memory
    elif compress_method == "llm_lingua":
        llm_lingua = PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu"
        )
        return llm_lingua
    raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
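
# Minimal usage sketch (illustrative only): compress_prompt returns a dict whose
# "compressed_prompt" key holds the shortened text, as used by the compress node below.
#
#   compressor = get_llm_lingua("llm_lingua2")
#   result = compressor.compress_prompt("Some long technical passage ...", rate=0.33)
#   print(result["compressed_prompt"])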


def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result for each doc using llm_lingua
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        doc_process_history.append(llm_lingua.compress_prompt(
            # Pass the latest result positionally (the compressor's context argument);
            # cast to str() in case the previous step produced a non-string result
            str(doc_process_history[-1]),
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
        )["compressed_prompt"])

    # Ensure 'process_steps' persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
        "process_steps": state.get("process_steps", [])  # Pass existing steps along
    }
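
# The compression settings are read from config["configurable"]; an illustrative
# sketch of the relevant keys (the values shown are the defaults used above):
#
#   config = {"configurable": {
#       "compression_method": "llm_lingua2",   # or "llm_lingua"
#       "compress_rate": 0.33,
#       "force_tokens": ['\n', '?', '.', '!', ','],
#   }}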


def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes the last processing result of each doc in state["docs_in_processing"]
    """
    prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    # Default summarization model
    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
    doc_process_histories = state["docs_in_processing"]
    # get_model handles instantiation for the configured provider
    llm_summarize = get_model(model)
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()
    for doc_process_history in doc_process_histories:
        # Cast to str() in case the previous step produced a non-string result
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))

    # Ensure 'process_steps' persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
        "process_steps": state.get("process_steps", [])  # Pass existing steps along
    }
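
# The summarization model is read from config["configurable"]["summarize_model"];
# any model name accepted by ki_gen.utils.get_model should work here, e.g.:
#
#   config = {"configurable": {"summarize_model": "gemini-2.0-flash"}}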


def custom_process(state: DocProcessorState):
    """
    Custom processing step. Its parameters are stored in a dict at
    state["process_steps"][state["current_process_step"]]:
        processing_model : the LLM which will perform the processing
        context : indices of previous processing results to send as context to the LLM
        prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    # Use .get() for safer access in case state mapping fails earlier
    process_steps_list = state.get("process_steps", [])
    current_step_index = state.get("current_process_step", 0)

    if not process_steps_list or current_step_index >= len(process_steps_list):
        print(f"Error: Invalid current_process_step ({current_step_index}) or empty process_steps in custom_process.")
        # Return state unchanged, potentially adding an error indicator if needed
        return state

    processing_params = process_steps_list[current_step_index]
    # Ensure processing_params is a dict before accessing keys
    if not isinstance(processing_params, dict):
        print(f"Error: Expected dictionary for process step {current_step_index}, but got {type(processing_params)}. Step details: {processing_params}")
        # Skip this step by incrementing the counter and returning
        return {
            "docs_in_processing": state.get("docs_in_processing", []),
            "current_process_step": current_step_index + 1,
            "process_steps": process_steps_list
        }

    # Default processing model
    model = processing_params.get("processing_model") or "gemini-2.0-flash"
    user_prompt = processing_params.get("prompt", "")  # Default to empty string if missing
    context = processing_params.get("context", [0])  # Default to [0]
    doc_process_histories = state.get("docs_in_processing", [])  # Default to empty list

    if not doc_process_histories:
        print("Warning: docs_in_processing is empty in custom_process.")
        # No docs to process, just increment the step counter
        return {
            "docs_in_processing": [],
            "current_process_step": current_step_index + 1,
            "process_steps": process_steps_list
        }

    if not isinstance(context, list):
        context = [context]

    # get_model handles instantiation for the configured provider
    processing_chain = get_model(model=model) | StrOutputParser()

    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            # Check that the index is valid for this doc's processing history
            if isinstance(context_element, int) and 0 <= context_element < len(doc_process_history):
                # Cast to str() in case the referenced result is not a string
                context_str += f"### TECHNICAL INFORMATION {i+1} \n {str(doc_process_history[context_element])}\n\n"
            else:
                print(f"Warning: Invalid context index {context_element} for doc_process_history length {len(doc_process_history)}")
        doc_process_history.append(processing_chain.invoke(str(context_str + user_prompt)))

    # Ensure 'process_steps' persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": current_step_index + 1,
        "process_steps": process_steps_list  # Pass existing steps along
    }


def final(state: DocProcessorState):
    """
    A node to store the final results of processing in the 'valid_docs' field
    """
    # Ensure docs_in_processing exists and is a list of lists
    docs_in_processing = state.get("docs_in_processing", [])
    if not isinstance(docs_in_processing, list):
        docs_in_processing = []

    # Safely take the last item from each inner history, defaulting to None if empty
    final_docs = [
        doc_history[-1] if isinstance(doc_history, list) and doc_history else None
        for doc_history in docs_in_processing
    ]
    # Filter out any None values that resulted from empty histories
    valid_final_docs = [doc for doc in final_docs if doc is not None]

    return {"valid_docs": valid_final_docs}


def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Initializes the processing state within the subgraph.
    It receives 'valid_docs' and potentially 'process_steps' from the parent graph.
    """
    # Initialize docs_in_processing based on valid_docs received from the parent state
    valid_docs = state.get("valid_docs", [])
    docs_in_processing_init = [[format_doc(doc)] for doc in valid_docs if doc]  # Skip None docs

    # Explicitly return process_steps, taking it from the input state or defaulting to []
    return {
        "current_process_step": 0,
        "docs_in_processing": docs_in_processing_init,
        "process_steps": state.get("process_steps", [])
    }


def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function that routes to the next processing step
    """
    # Use .get() for safer access to 'process_steps' and 'current_process_step'
    process_steps = state.get("process_steps", [])
    current_step_index = state.get("current_process_step", 0)

    step = "final"  # Default to the final step
    if not isinstance(process_steps, list):
        print(f"Warning: 'process_steps' is not a list ({type(process_steps)}). Proceeding to final.")
        process_steps = []  # Treat as an empty list

    if current_step_index < len(process_steps):
        step_definition = process_steps[current_step_index]
        if isinstance(step_definition, dict):
            # A dict step definition means a 'custom' processing step
            step = "custom"
        elif isinstance(step_definition, str):
            # Map known step names to their node names
            step_lower = step_definition.lower()
            if step_lower in ["summarize", "compress"]:
                step = step_lower
            else:
                # Unknown strings fall back to the custom node for flexibility
                print(f"Warning: Unknown process step type string '{step_definition}'. Defaulting to 'custom'.")
                step = "custom"
        else:
            print(f"Warning: Invalid type for process step definition: {type(step_definition)}. Proceeding to final.")
            step = "final"
    else:
        # current_step_index is out of bounds: all steps are done, go to final
        step = "final"

    print(f"Next processor step determined: {step}")  # Debugging print
    return step
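
# Routing sketch (illustrative): with a mixed process_steps list such as
#   ["summarize", {"prompt": "Extract KPIs", "context": [0]}, "compress"]
# the subgraph visits summarize, then custom, then compress, then final.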


def build_data_processor_graph(memory):
    """
    Builds the data processor graph
    """
    # with SqliteSaver.from_conn_string(":memory:") as memory:
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")

    # Conditional edges route FROM the node that just finished TO the next one
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",   # Source node
        next_processor_step,   # Function that decides where to go next
        # Map the string returned by next_processor_step to actual node names
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"compress": "compress", "final": "final", "custom": "custom", "summarize": "summarize"}  # summarize included to allow loops
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "custom": "custom", "compress": "compress"}  # compress included to allow loops
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"}
    )

    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
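

# Example usage (a minimal sketch; assumes an in-memory SQLite checkpointer and a
# parent state that already provides 'valid_docs' and 'process_steps'; adapt the
# thread_id and configurable keys to your setup):
#
#   with SqliteSaver.from_conn_string(":memory:") as memory:
#       graph = build_data_processor_graph(memory)
#       result = graph.invoke(
#           {"valid_docs": ["some document text"], "process_steps": ["summarize", "compress"]},
#           config={"configurable": {"thread_id": "1", "summarize_model": "gemini-2.0-flash"}},
#       )
#       print(result["valid_docs"])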