Update ki_gen/data_processor.py

ki_gen/data_processor.py  (+154 -43)  CHANGED
@@ -14,13 +14,11 @@ from langgraph.graph import StateGraph
 from llmlingua import PromptCompressor
 
 # Import get_model which now handles Gemini
-from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
+from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
 from langgraph.checkpoint.sqlite import SqliteSaver
 
 
-# ... (
-
-
+# ... (get_llm_lingua remains the same) ...
 # Requires ~2GB of RAM
 def get_llm_lingua(compress_method:str = "llm_lingua2"):
 
@@ -32,7 +30,7 @@ def get_llm_lingua(compress_method:str = "llm_lingua2"):
             device_map="cpu"
         )
         return llm_lingua2
-
+
     # Requires ~8GB memory
     elif compress_method == "llm_lingua":
         llm_lingua = PromptCompressor(
@@ -43,7 +41,6 @@ def get_llm_lingua(compress_method:str = "llm_lingua2"):
         raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
 
 
-
 def compress(state: DocProcessorState, config: ConfigSchema):
     """
     This node compresses last processing result for each doc using llm_lingua
@@ -52,13 +49,21 @@ def compress(state: DocProcessorState, config: ConfigSchema):
     llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
     for doc_process_history in doc_process_histories:
         doc_process_history.append(llm_lingua.compress_prompt(
+            # Use str() to ensure the input is a string and handle potential non-string data
            doc = str(doc_process_history[-1]),
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
            )["compressed_prompt"]
        )
-
-
+
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing": doc_process_histories,
+        "current_process_step": state["current_process_step"] + 1,
+        "process_steps": state.get("process_steps", [])  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
 
 # Update default model
 def summarize_docs(state: DocProcessorState, config: ConfigSchema):
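
Note: the compress_prompt call above follows the standard llmlingua API: the text goes in with a compression rate and a set of tokens that must survive. A minimal standalone sketch, assuming the LLMLingua-2 model name from the llmlingua README (all values illustrative):

    # Standalone sketch of the llmlingua-2 usage that get_llm_lingua/compress wrap.
    from llmlingua import PromptCompressor

    compressor = PromptCompressor(
        model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",  # assumed model
        use_llmlingua2=True,
        device_map="cpu",
    )
    result = compressor.compress_prompt(
        "Some long intermediate processing result...",
        rate=0.33,                                # keep roughly a third of the tokens
        force_tokens=['\n', '?', '.', '!', ','],  # punctuation is never dropped
    )
    print(result["compressed_prompt"])
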
@@ -75,16 +80,24 @@ Document:
         ("system", prompt)
     ])
     # Update default model name
-    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
+    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
     doc_process_histories = state["docs_in_processing"]
     # Use get_model to handle instantiation
-    llm_summarize = get_model(model)
+    llm_summarize = get_model(model)
     summarize_chain = sysmsg | llm_summarize | StrOutputParser()
 
     for doc_process_history in doc_process_histories:
+        # Use str() to ensure the input is a string
         doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
 
-
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing": doc_process_histories,
+        "current_process_step": state["current_process_step"] + 1,
+        "process_steps": state.get("process_steps", [])  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
 
 # Update default model
 def custom_process(state: DocProcessorState):
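
Note: summarize_chain is standard LCEL composition (prompt | model | parser). A self-contained sketch, assuming ChatGoogleGenerativeAI is the kind of chat model get_model returns for a Gemini name; the prompt text is illustrative, not the repo's actual prompt:

    # Sketch of the prompt | llm | parser pipeline built in summarize_docs.
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_google_genai import ChatGoogleGenerativeAI

    sysmsg = ChatPromptTemplate.from_messages([
        ("system", "Summarize the following document.\n\nDocument:\n{document}")
    ])
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
    chain = sysmsg | llm | StrOutputParser()

    summary = chain.invoke({"document": "long document text here"})
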
@@ -94,13 +107,45 @@ def custom_process(state: DocProcessorState):
     context : the previous processing results to send as context to the LLM
     user_prompt : the prompt/task which will be appended to the context before sending to the LLM
     """
+    # Use .get() for safer access in case state mapping fails earlier
+    process_steps_list = state.get("process_steps", [])
+    current_step_index = state.get("current_process_step", 0)
+
+    if not process_steps_list or current_step_index >= len(process_steps_list):
+        print(f"Error: Invalid current_process_step ({current_step_index}) or empty process_steps in custom_process.")
+        # Return state, potentially adding an error indicator if needed
+        return state  # Or modify state to indicate error
+
+    processing_params = process_steps_list[current_step_index]
+
+    # Ensure processing_params is a dict before accessing keys
+    if not isinstance(processing_params, dict):
+        print(f"Error: Expected dictionary for process step {current_step_index}, but got {type(processing_params)}. Step details: {processing_params}")
+        # Decide how to handle: skip step, return error state?
+        # For now, let's skip this step by incrementing the counter and returning
+        return {
+            "docs_in_processing": state.get("docs_in_processing", []),
+            "current_process_step": current_step_index + 1,
+            "process_steps": process_steps_list
+        }
+
 
-    processing_params = state["process_steps"][state["current_process_step"]]
     # Update default model name
-    model = processing_params.get("processing_model") or "gemini-2.0-flash"
-    user_prompt = processing_params
-    context = processing_params.get("context")
-    doc_process_histories = state
+    model = processing_params.get("processing_model") or "gemini-2.0-flash"
+    user_prompt = processing_params.get("prompt", "")  # Default to empty string if missing
+    context = processing_params.get("context", [0])  # Default to [0]
+    doc_process_histories = state.get("docs_in_processing", [])  # Default to empty list
+
+    if not doc_process_histories:
+        print("Warning: docs_in_processing is empty in custom_process.")
+        # No docs to process, just increment step counter
+        return {
+            "docs_in_processing": [],
+            "current_process_step": current_step_index + 1,
+            "process_steps": process_steps_list
+        }
+
+
     if not isinstance(context, list):
         context = [context]
 
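
Note: the .get() calls above imply a shape for each custom entry in process_steps. A hypothetical value, inferred from the keys read here rather than taken from the repo:

    # Hypothetical process_steps list; string entries select built-in nodes,
    # dict entries are treated as 'custom' steps (see next_processor_step below).
    process_steps = [
        "summarize",
        "compress",
        {
            "processing_model": "gemini-2.0-flash",
            "prompt": "Extract the key technical requirements.",
            "context": [0, 1],   # indices into each doc's processing history
        },
    ]
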
@@ -110,49 +155,112 @@ def custom_process(state: DocProcessorState):
     for doc_process_history in doc_process_histories:
         context_str = ""
         for i, context_element in enumerate(context):
-
-
-
-
-
-
-
+            # Check if index is valid and history is long enough
+            if isinstance(context_element, int) and 0 <= context_element < len(doc_process_history):
+                # Use str() to ensure the context element is a string
+                context_str += f"### TECHNICAL INFORMATION {i+1} \n {str(doc_process_history[context_element])}\n\n"
+            else:
+                print(f"Warning: Invalid context index {context_element} for doc_process_history length {len(doc_process_history)}")
+        # Use str() to ensure the input is a string
+        doc_process_history.append(processing_chain.invoke(str(context_str + user_prompt)))
+
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing": doc_process_histories,
+        "current_process_step": current_step_index + 1,
+        "process_steps": process_steps_list  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
+
+
+# ... (final node remains the same) ...
 def final(state: DocProcessorState):
     """
-    A node to store the final results of processing in the 'valid_docs' field
+    A node to store the final results of processing in the 'valid_docs' field
     """
-
+    # Ensure docs_in_processing exists and is a list of lists
+    docs_in_processing = state.get("docs_in_processing", [])
+    if not isinstance(docs_in_processing, list):
+        docs_in_processing = []
+
+    # Safely get the last item from each inner list, default to None if empty
+    final_docs = [
+        doc_history[-1] if isinstance(doc_history, list) and doc_history else None
+        for doc_history in docs_in_processing
+    ]
+    # Filter out any None values that might have resulted from empty histories
+    valid_final_docs = [doc for doc in final_docs if doc is not None]
+
+    return {"valid_docs": valid_final_docs}
+
 
-# TODO : remove this node and use conditional entry point instead
 def get_process_steps(state: DocProcessorState, config: ConfigSchema):
     """
-
+    Initializes the processing state within the subgraph.
+    It receives the 'valid_docs' and potentially 'process_steps' from the parent graph.
     """
-    #
-
-
+    # Initialize docs_in_processing based on valid_docs received from parent state
+    valid_docs = state.get("valid_docs", [])
+    docs_in_processing_init = [[format_doc(doc)] for doc in valid_docs if doc]  # Ensure doc is not None
+
+    # Explicitly return process_steps, getting it from the input state or defaulting
+    return {
+        "current_process_step": 0,
+        "docs_in_processing": docs_in_processing_init,
+        "process_steps": state.get("process_steps", [])  # Ensure process_steps is set here
+    }
 
 
 def next_processor_step(state: DocProcessorState):
     """
     Conditional edge function to go to next processing step
     """
-
-
-
-
+    # --- MODIFICATION START ---
+    # Use .get() for safer access to 'process_steps' and 'current_process_step'
+    process_steps = state.get("process_steps", [])
+    current_step_index = state.get("current_process_step", 0)
+    # --- MODIFICATION END ---
+
+    step = "final"  # Default to final step
+
+    if not isinstance(process_steps, list):
+        print(f"Warning: 'process_steps' is not a list ({type(process_steps)}). Proceeding to final.")
+        process_steps = []  # Treat as an empty list
+
+    if current_step_index < len(process_steps):
+        step_definition = process_steps[current_step_index]
+        if isinstance(step_definition, dict):
+            # Assuming a dict means a 'custom' step, based on the original logic
            step = "custom"
+        elif isinstance(step_definition, str):
+            # Map the string to a node name if it is a known processing type
+            step_lower = step_definition.lower()
+            if step_lower in ["summarize", "compress"]:
+                step = step_lower
+            else:
+                print(f"Warning: Unknown process step type string '{step_definition}'. Defaulting to 'custom'.")
+                # Or default to 'final' if unknown steps shouldn't run custom logic
+                # step = "final"
+                # Unknown strings map to custom for flexibility; adjust if needed
+                step = "custom"
+        else:
+            print(f"Warning: Invalid type for process step definition: {type(step_definition)}. Proceeding to final.")
+            step = "final"  # Go to final if the step definition is an unexpected type
    else:
+        # If current_step_index is out of bounds, go to the final step
        step = "final"
 
+    print(f"Next processor step determined: {step}")  # Debugging print
    return step
 
 
+
 def build_data_processor_graph(memory):
     """
     Builds the data processor graph
     """
-    #with SqliteSaver.from_conn_string(":memory:") as memory :
+    #with SqliteSaver.from_conn_string(":memory:") as memory :
 
     graph_builder_doc_processor = StateGraph(DocProcessorState)
 
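
Note: with the new guards in place, routing is driven purely by the type of the current step entry. An illustrative trace (values assumed, outcomes follow the code above):

    # How next_processor_step resolves a mixed plan, step by step.
    state = {
        "process_steps": ["summarize", {"prompt": "..."}, "compress"],
        "current_process_step": 0,
    }
    # index 0 -> "summarize"  (known string)
    # index 1 -> "custom"     (dict entry)
    # index 2 -> "compress"   (known string)
    # index 3 -> "final"      (index out of bounds)
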
@@ -163,27 +271,30 @@ def build_data_processor_graph(memory):
     graph_builder_doc_processor.add_node("final", final)
 
     graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
+
+    # Conditional edges route FROM the node that just finished TO the next one
     graph_builder_doc_processor.add_conditional_edges(
-        "get_process_steps",
-        next_processor_step,
+        "get_process_steps",  # Source node
+        next_processor_step,  # Function to decide where to go next
+        # Map the returned string from next_processor_step to actual node names
         {"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "summarize",
+        "summarize",  # Source node
         next_processor_step,
-        {"compress" : "compress", "final": "final", "custom" : "custom"}
+        {"compress" : "compress", "final": "final", "custom" : "custom", "summarize": "summarize"}  # Added summarize for loops
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "compress",
+        "compress",  # Source node
         next_processor_step,
-        {"summarize" : "summarize", "final": "final", "custom" : "custom"}
+        {"summarize" : "summarize", "final": "final", "custom" : "custom", "compress": "compress"}  # Added compress for loops
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "custom",
+        "custom",  # Source node
         next_processor_step,
         {"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
     )
     graph_builder_doc_processor.add_edge("final", "__end__")
-
+
     graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
     return graph_doc_processor
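
Note: a minimal driver for the compiled subgraph, modeled on the commented-out SqliteSaver line above; the thread id, compress_rate, and document shape are assumptions for illustration:

    # Sketch: compile and invoke the doc-processor subgraph with a checkpointer.
    from langgraph.checkpoint.sqlite import SqliteSaver

    docs = [...]  # documents in whatever shape format_doc expects

    with SqliteSaver.from_conn_string(":memory:") as memory:
        graph = build_data_processor_graph(memory)
        result = graph.invoke(
            {"valid_docs": docs, "process_steps": ["summarize", "compress"]},
            config={"configurable": {"thread_id": "doc-proc-1", "compress_rate": 0.5}},
        )
        print(result["valid_docs"])
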