adrienbrdne committed
Commit 66a2ae2 · verified · 1 Parent(s): dbe6919

Update ki_gen/data_processor.py

Files changed (1):
  1. ki_gen/data_processor.py +154 -43
ki_gen/data_processor.py CHANGED
@@ -14,13 +14,11 @@ from langgraph.graph import StateGraph
 from llmlingua import PromptCompressor
 
 # Import get_model which now handles Gemini
-from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
+from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
 from langgraph.checkpoint.sqlite import SqliteSaver
 
 
-# ... (rest of the imports and llm_lingua functions remain the same)
-
-
+# ... (get_llm_lingua remains the same) ...
 # Requires ~2GB of RAM
 def get_llm_lingua(compress_method:str = "llm_lingua2"):
 
@@ -32,7 +30,7 @@ def get_llm_lingua(compress_method:str = "llm_lingua2"):
             device_map="cpu"
         )
         return llm_lingua2
-
+
     # Requires ~8GB memory
     elif compress_method == "llm_lingua":
         llm_lingua = PromptCompressor(
@@ -43,7 +41,6 @@ def get_llm_lingua(compress_method:str = "llm_lingua2"):
     raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
 
 
-
 def compress(state: DocProcessorState, config: ConfigSchema):
     """
     This node compresses last processing result for each doc using llm_lingua
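
For orientation, this is roughly how the llm_lingua2 path above is used on its own. A minimal sketch assuming the standard llmlingua API; the model name is the usual LLMLingua-2 checkpoint and is an assumption here, since the diff elides the constructor arguments:

```python
from llmlingua import PromptCompressor

# Assumed checkpoint; the diff does not show the model_name used in get_llm_lingua.
llm_lingua2 = PromptCompressor(
    model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
    use_llmlingua2=True,
    device_map="cpu",  # matches the hunk above; ~2GB of RAM
)

result = llm_lingua2.compress_prompt(
    "Some long processing result to shrink before the next LLM call...",
    rate=0.33,                                # same default rate the compress node uses
    force_tokens=['\n', '?', '.', '!', ','],  # tokens that must survive compression
)
print(result["compressed_prompt"])  # the key the compress node reads
```
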
@@ -52,13 +49,21 @@ def compress(state: DocProcessorState, config: ConfigSchema):
     llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
     for doc_process_history in doc_process_histories:
         doc_process_history.append(llm_lingua.compress_prompt(
+            # Use str() to ensure the input is a string; handles potential non-string data
            doc = str(doc_process_history[-1]),
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
            )["compressed_prompt"]
        )
-
-    return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
+
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing": doc_process_histories,
+        "current_process_step": state["current_process_step"] + 1,
+        "process_steps": state.get("process_steps", [])  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
 
 # Update default model
 def summarize_docs(state: DocProcessorState, config: ConfigSchema):
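
As a reader aid, these are the keys `compress` reads from `config["configurable"]`. A hypothetical invocation config matching the hunk above; `thread_id` is an assumption, needed only because the graph compiles with a checkpointer:

```python
# Hypothetical runtime config for the compress node; keys match the .get() calls above.
config = {
    "configurable": {
        "compression_method": "llm_lingua2",         # or "llm_lingua"
        "compress_rate": 0.33,                       # falls back to 0.33 if omitted
        "force_tokens": ['\n', '?', '.', '!', ','],  # falls back to this list if omitted
        "thread_id": "doc-processing-1",             # assumption: required by the checkpointer
    }
}
```
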
@@ -75,16 +80,24 @@ Document:
         ("system", prompt)
     ])
     # Update default model name
-    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
+    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
     doc_process_histories = state["docs_in_processing"]
     # Use get_model to handle instantiation
-    llm_summarize = get_model(model)
+    llm_summarize = get_model(model)
     summarize_chain = sysmsg | llm_summarize | StrOutputParser()
 
     for doc_process_history in doc_process_histories:
+        # Use str() to ensure the input is a string
         doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
 
-    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing": doc_process_histories,
+        "current_process_step": state["current_process_step"] + 1,
+        "process_steps": state.get("process_steps", [])  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
 
 # Update default model
 def custom_process(state: DocProcessorState):
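
The summarize node above is plain LCEL piping. A minimal sketch of the same pattern, assuming `get_model` returns a LangChain chat model; the prompt text here is illustrative, not the one in the file:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from ki_gen.utils import get_model

# Illustrative system prompt; the real one lives in summarize_docs above.
sysmsg = ChatPromptTemplate.from_messages([
    ("system", "Summarize the following document:\n\nDocument:\n{document}")
])
summarize_chain = sysmsg | get_model("gemini-2.0-flash") | StrOutputParser()

summary = summarize_chain.invoke({"document": "…latest processing result for one doc…"})
```
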
@@ -94,13 +107,45 @@ def custom_process(state: DocProcessorState):
     context : the previous processing results to send as context to the LLM
     user_prompt : the prompt/task which will be appended to the context before sending to the LLM
     """
-
-    processing_params = state["process_steps"][state["current_process_step"]]
+    # Use .get() for safer access in case state mapping fails earlier
+    process_steps_list = state.get("process_steps", [])
+    current_step_index = state.get("current_process_step", 0)
+
+    if not process_steps_list or current_step_index >= len(process_steps_list):
+        print(f"Error: Invalid current_process_step ({current_step_index}) or empty process_steps in custom_process.")
+        # Return state, potentially adding an error indicator if needed
+        return state  # Or modify state to indicate error
+
+    processing_params = process_steps_list[current_step_index]
+
+    # Ensure processing_params is a dict before accessing keys
+    if not isinstance(processing_params, dict):
+        print(f"Error: Expected dictionary for process step {current_step_index}, but got {type(processing_params)}. Step details: {processing_params}")
+        # Decide how to handle: skip step, return error state?
+        # For now, let's skip this step by incrementing the counter and returning
+        return {
+            "docs_in_processing": state.get("docs_in_processing", []),
+            "current_process_step": current_step_index + 1,
+            "process_steps": process_steps_list
+        }
+
     # Update default model name
-    model = processing_params.get("processing_model") or "gemini-2.0-flash"
-    user_prompt = processing_params["prompt"]
-    context = processing_params.get("context") or [0]
-    doc_process_histories = state["docs_in_processing"]
+    model = processing_params.get("processing_model") or "gemini-2.0-flash"
+    user_prompt = processing_params.get("prompt", "")  # Default to empty string if missing
+    context = processing_params.get("context", [0])  # Default to [0]
+    doc_process_histories = state.get("docs_in_processing", [])  # Default to empty list
+
+    if not doc_process_histories:
+        print("Warning: docs_in_processing is empty in custom_process.")
+        # No docs to process, just increment step counter
+        return {
+            "docs_in_processing": [],
+            "current_process_step": current_step_index + 1,
+            "process_steps": process_steps_list
+        }
+
     if not isinstance(context, list):
         context = [context]
 
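
For reference, a single entry of `process_steps` as consumed by `custom_process` might look like the sketch below; the keys are the ones read above, the values are illustrative:

```python
# Hypothetical custom step definition; keys match the processing_params.get(...) calls above.
custom_step = {
    "processing_model": "gemini-2.0-flash",  # optional; this is the fallback default
    "prompt": "List the key technical problems discussed.",  # appended after the context block
    "context": [0, 1],  # indices into each doc's processing history; a bare int gets wrapped
}
```
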
@@ -110,49 +155,112 @@ def custom_process(state: DocProcessorState):
     for doc_process_history in doc_process_histories:
         context_str = ""
         for i, context_element in enumerate(context):
-            context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
-        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
-
-    return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
-
-# ... (rest of the file remains the same)
-
+            # Check if index is valid and history is long enough
+            if isinstance(context_element, int) and 0 <= context_element < len(doc_process_history):
+                # Use str() to ensure the context element is a string
+                context_str += f"### TECHNICAL INFORMATION {i+1} \n {str(doc_process_history[context_element])}\n\n"
+            else:
+                print(f"Warning: Invalid context index {context_element} for doc_process_history length {len(doc_process_history)}")
+        # Use str() to ensure the input is a string
+        doc_process_history.append(processing_chain.invoke(str(context_str + user_prompt)))
+
+    # --- MODIFICATION START ---
+    # Ensure 'process_steps' persists in the state
+    return {
+        "docs_in_processing" : doc_process_histories,
+        "current_process_step" : current_step_index + 1,
+        "process_steps": process_steps_list  # Pass existing steps along
+    }
+    # --- MODIFICATION END ---
+
+
+# ... (final node remains the same) ...
 def final(state: DocProcessorState):
     """
-    A node to store the final results of processing in the 'valid_docs' field
+    A node to store the final results of processing in the 'valid_docs' field
     """
-    return {"valid_docs" : [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
+    # Ensure docs_in_processing exists and is a list of lists
+    docs_in_processing = state.get("docs_in_processing", [])
+    if not isinstance(docs_in_processing, list):
+        docs_in_processing = []
+
+    # Safely get the last item from each inner list, default to None if empty
+    final_docs = [
+        doc_history[-1] if isinstance(doc_history, list) and doc_history else None
+        for doc_history in docs_in_processing
+    ]
+    # Filter out any None values that might have resulted from empty histories
+    valid_final_docs = [doc for doc in final_docs if doc is not None]
+
+    return {"valid_docs" : valid_final_docs}
+
 
-# TODO : remove this node and use conditional entry point instead
 def get_process_steps(state: DocProcessorState, config: ConfigSchema):
     """
-    Dummy node
+    Initializes the processing state within the subgraph.
+    It receives the 'valid_docs' and potentially 'process_steps' from the parent graph.
     """
-    # if not process_steps:
-    #     process_steps = eval(input("Enter processing steps: "))
-    return {"current_process_step": 0, "docs_in_processing" : [[format_doc(doc)] for doc in state["valid_docs"]]}
+    # Initialize docs_in_processing based on valid_docs received from parent state
+    valid_docs = state.get("valid_docs", [])
+    docs_in_processing_init = [[format_doc(doc)] for doc in valid_docs if doc]  # Ensure doc is not None
+
+    # Explicitly return process_steps, getting it from the input state or defaulting
+    return {
+        "current_process_step": 0,
+        "docs_in_processing": docs_in_processing_init,
+        "process_steps": state.get("process_steps", [])  # Ensure process_steps is set here
+    }
 
 
 def next_processor_step(state: DocProcessorState):
     """
     Conditional edge function to go to next processing step
     """
-    process_steps = state["process_steps"]
-    if state["current_process_step"] < len(process_steps):
-        step = process_steps[state["current_process_step"]]
-        if isinstance(step, dict):
+    # --- MODIFICATION START ---
+    # Use .get() for safer access to 'process_steps' and 'current_process_step'
+    process_steps = state.get("process_steps", [])
+    current_step_index = state.get("current_process_step", 0)
+    # --- MODIFICATION END ---
+
+    step = "final"  # Default to final step
+
+    if not isinstance(process_steps, list):
+        print(f"Warning: 'process_steps' is not a list ({type(process_steps)}). Proceeding to final.")
+        process_steps = []  # Treat as empty list
+
+    if current_step_index < len(process_steps):
+        step_definition = process_steps[current_step_index]
+        if isinstance(step_definition, dict):
+            # Assuming a dict means a 'custom' step based on original logic
             step = "custom"
+        elif isinstance(step_definition, str):
+            # Map string to node name if it's a known processing type
+            step_lower = step_definition.lower()
+            if step_lower in ["summarize", "compress"]:
+                step = step_lower
+            else:
+                print(f"Warning: Unknown process step type string '{step_definition}'. Defaulting to 'custom'.")
+                # Or default to 'final' if unknown steps shouldn't run custom logic
+                # step = "final"
+                # Let's assume unknown strings map to custom for flexibility, adjust if needed
+                step = "custom"
+        else:
+            print(f"Warning: Invalid type for process step definition: {type(step_definition)}. Proceeding to final.")
+            step = "final"  # Go to final if step definition is an unexpected type
     else:
+        # If current_step_index is out of bounds, we should go to the final step
         step = "final"
 
+    print(f"Next processor step determined: {step}")  # Debugging print
     return step
 
 
+
 def build_data_processor_graph(memory):
     """
     Builds the data processor graph
     """
-    #with SqliteSaver.from_conn_string(":memory:") as memory :
+    #with SqliteSaver.from_conn_string(":memory:") as memory :
 
     graph_builder_doc_processor = StateGraph(DocProcessorState)
 
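
Putting the routing rules together, a mixed `process_steps` list like the sketch below (values illustrative) drives the loop: known strings go to their node, dicts go to `custom`, and exhausting the list goes to `final`:

```python
# Illustrative process_steps and the route next_processor_step picks for each entry:
process_steps = [
    "summarize",                      # str in {"summarize", "compress"} -> that node
    {"prompt": "Extract acronyms."},  # dict -> "custom"
    "compress",                       # -> "compress"
]
# Once current_process_step == len(process_steps), the router returns "final".
```
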
@@ -163,27 +271,30 @@ def build_data_processor_graph(memory):
     graph_builder_doc_processor.add_node("final", final)
 
     graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
+
+    # Conditional edges route FROM the node that just finished TO the next one
     graph_builder_doc_processor.add_conditional_edges(
-        "get_process_steps",
-        next_processor_step,
+        "get_process_steps",  # Source node
+        next_processor_step,  # Function to decide where to go next
+        # Map returned string from next_processor_step to actual node names
         {"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "summarize",
+        "summarize",  # Source node
         next_processor_step,
-        {"compress" : "compress", "final": "final", "custom" : "custom"}
+        {"compress" : "compress", "final": "final", "custom" : "custom", "summarize": "summarize"}  # Added summarize for loops
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "compress",
+        "compress",  # Source node
         next_processor_step,
-        {"summarize" : "summarize", "final": "final", "custom" : "custom"}
+        {"summarize" : "summarize", "final": "final", "custom" : "custom", "compress": "compress"}  # Added compress for loops
     )
     graph_builder_doc_processor.add_conditional_edges(
-        "custom",
+        "custom",  # Source node
         next_processor_step,
         {"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
     )
     graph_builder_doc_processor.add_edge("final", "__end__")
-
+
     graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
     return graph_doc_processor
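
Finally, a hypothetical end-to-end invocation of the compiled subgraph; the state keys follow the nodes above, while the connection setup and `thread_id` are assumptions:

```python
import sqlite3

from langgraph.checkpoint.sqlite import SqliteSaver

# Checkpointer backed by an in-memory SQLite database (assumed setup).
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
graph = build_data_processor_graph(memory)

result = graph.invoke(
    {
        "valid_docs": ["raw document 1", "raw document 2"],
        "process_steps": ["summarize", "compress"],  # routed by next_processor_step
    },
    config={"configurable": {"thread_id": "doc-proc-demo"}},  # required by the checkpointer
)
print(result["valid_docs"])  # final processed text for each document
```
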
 