#!/usr/bin/env python
# coding: utf-8

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from langgraph.graph import StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
from llmlingua import PromptCompressor

# get_model handles LLM instantiation (including Gemini models)
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc


def get_llm_lingua(compress_method: str = "llm_lingua2") -> PromptCompressor:
    """
    Returns a PromptCompressor for the requested method: 'llm_lingua2' (default) or 'llm_lingua'
    """
    # Requires ~2GB of memory
    if compress_method == "llm_lingua2":
        llm_lingua2 = PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu"
        )
        return llm_lingua2

    # Requires ~8GB memory
    elif compress_method == "llm_lingua":
        llm_lingua = PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu"
        )
        return llm_lingua
    raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")

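# Example usage (illustrative, kept as a comment so nothing runs on import):
# compress a long document to roughly a third of its token count. The text to
# compress is the first positional argument of compress_prompt:
#
#   compressor = get_llm_lingua("llm_lingua2")
#   result = compressor.compress_prompt(
#       long_document_text,
#       rate=0.33,
#       force_tokens=['\n', '?', '.', '!', ','],
#   )
#   compressed = result["compressed_prompt"]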

def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result of each doc using llm_lingua
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        doc_process_history.append(llm_lingua.compress_prompt(
            # The text to compress is the first positional argument of compress_prompt;
            # str() guards against non-string entries in the history
            str(doc_process_history[-1]),
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
            )["compressed_prompt"]
        )

    # 'process_steps' must be returned explicitly so it persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
        "process_steps": state.get("process_steps", [])
        }
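
# Example (illustrative) of the "configurable" values consumed by compress:
#   {"configurable": {"compression_method": "llm_lingua2",
#                     "compress_rate": 0.5,
#                     "force_tokens": ['\n', '.']}}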


def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes the latest processing result of each doc in state["docs_in_processing"]
    """

    prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.

Document:
{document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
    doc_process_histories = state["docs_in_processing"]
    llm_summarize = get_model(model)  # get_model handles provider-specific instantiation
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()

    for doc_process_history in doc_process_histories:
        # str() guards against non-string entries in the history
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))

    # 'process_steps' must be returned explicitly so it persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
        "process_steps": state.get("process_steps", [])
        }


def custom_process(state: DocProcessorState):
    """
    Custom processing step. Its parameters are stored as a dict in
    state["process_steps"][state["current_process_step"]]:
    processing_model : the LLM which will perform the processing
    context : indices of the previous processing results to send as context to the LLM
    user_prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    # Use .get() for safer access in case state mapping failed earlier
    process_steps_list = state.get("process_steps", [])
    current_step_index = state.get("current_process_step", 0)

    if not process_steps_list or current_step_index >= len(process_steps_list):
        print(f"Error: Invalid current_process_step ({current_step_index}) or empty process_steps in custom_process.")
        return state

    processing_params = process_steps_list[current_step_index]

    # Ensure processing_params is a dict before accessing keys
    if not isinstance(processing_params, dict):
        print(f"Error: Expected dictionary for process step {current_step_index}, but got {type(processing_params)}. Step details: {processing_params}")
        # Skip this step: increment the counter and return the state otherwise unchanged
        return {
            "docs_in_processing": state.get("docs_in_processing", []),
            "current_process_step": current_step_index + 1,
            "process_steps": process_steps_list
            }

    model = processing_params.get("processing_model") or "gemini-2.0-flash"
    user_prompt = processing_params.get("prompt", "")  # Default to empty string if missing
    context = processing_params.get("context", [0])  # Default to the first entry of the history
    doc_process_histories = state.get("docs_in_processing", [])

    if not doc_process_histories:
        print("Warning: docs_in_processing is empty in custom_process.")
        # No docs to process, just increment the step counter
        return {
            "docs_in_processing": [],
            "current_process_step": current_step_index + 1,
            "process_steps": process_steps_list
            }

    if not isinstance(context, list):
        context = [context]

    processing_chain = get_model(model=model) | StrOutputParser()

    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            # Only use indices that fall within the current history
            if isinstance(context_element, int) and 0 <= context_element < len(doc_process_history):
                # str() guards against non-string entries in the history
                context_str += f"### TECHNICAL INFORMATION {i+1} \n {str(doc_process_history[context_element])}\n\n"
            else:
                print(f"Warning: Invalid context index {context_element} for doc_process_history length {len(doc_process_history)}")
        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))

    # 'process_steps' must be returned explicitly so it persists in the state
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": current_step_index + 1,
        "process_steps": process_steps_list
        }
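
# Example (illustrative) of a custom process step definition:
#   {"processing_model": "gemini-2.0-flash",
#    "prompt": "List the requirements defined in the document.",
#    "context": [0, 1]}  # send the raw doc and the first processing result as context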


def final(state: DocProcessorState):
    """
    A node to store the final results of processing in the 'valid_docs' field
    """
    # Ensure docs_in_processing exists and is a list of lists
    docs_in_processing = state.get("docs_in_processing", [])
    if not isinstance(docs_in_processing, list):
        docs_in_processing = []

    # Safely get the last item from each inner list, default to None if empty
    final_docs = [
        doc_history[-1] if isinstance(doc_history, list) and doc_history else None
        for doc_history in docs_in_processing
    ]
    # Filter out any None values that might have resulted from empty histories
    valid_final_docs = [doc for doc in final_docs if doc is not None]

    return {"valid_docs" : valid_final_docs}


def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Initializes the processing state within the subgraph.
    It receives the 'valid_docs' and potentially 'process_steps' from the parent graph.
    """
    # Initialize docs_in_processing based on valid_docs received from parent state
    valid_docs = state.get("valid_docs", [])
    docs_in_processing_init = [[format_doc(doc)] for doc in valid_docs if doc] # Ensure doc is not None

    # Explicitly return process_steps, getting it from the input state or defaulting
    return {
        "current_process_step": 0,
        "docs_in_processing": docs_in_processing_init,
        "process_steps": state.get("process_steps", []) # Ensure process_steps is set here
    }
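
# After initialization, docs_in_processing holds one history list per doc,
# e.g. (illustrative): [[formatted_doc_1], [formatted_doc_2]]. Each processing
# node appends its output to the corresponding inner list.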


def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function that routes to the next processing step
    """
    # Use .get() for safer access to 'process_steps' and 'current_process_step'
    process_steps = state.get("process_steps", [])
    current_step_index = state.get("current_process_step", 0)

    step = "final"  # Default to the final step

    if not isinstance(process_steps, list):
        print(f"Warning: 'process_steps' is not a list ({type(process_steps)}). Proceeding to final.")
        process_steps = []

    if current_step_index < len(process_steps):
        step_definition = process_steps[current_step_index]
        if isinstance(step_definition, dict):
            # A dict carries the parameters of a 'custom' step
            step = "custom"
        elif isinstance(step_definition, str):
            step_lower = step_definition.lower()
            if step_lower in ["summarize", "compress"]:
                step = step_lower
            else:
                # Unknown strings are routed to 'custom' for flexibility;
                # change this to 'final' if unknown steps should not run custom logic
                print(f"Warning: Unknown process step type string '{step_definition}'. Defaulting to 'custom'.")
                step = "custom"
        else:
            print(f"Warning: Invalid type for process step definition: {type(step_definition)}. Proceeding to final.")
            step = "final"

    print(f"Next processor step determined: {step}")  # Debug output
    return step
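
# Example (illustrative) of a process_steps list mixing built-in and custom steps:
#   ["summarize",
#    "compress",
#    {"prompt": "Extract the key performance indicators.", "context": [2]}]
# This routes to 'summarize', then 'compress', then 'custom', then 'final'.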



def build_data_processor_graph(memory):
    """
    Builds the data processor graph
    """
    # The checkpointer is provided by the caller, e.g. SqliteSaver.from_conn_string(":memory:")

    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")

    # Conditional edges route FROM the node that just finished TO the next one
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",  # Source node
        next_processor_step,  # Function that decides where to go next
        # Map the string returned by next_processor_step to actual node names
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"compress": "compress", "final": "final", "custom": "custom", "summarize": "summarize"}  # "summarize" allows repeated summarize steps
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "custom": "custom", "compress": "compress"}  # "compress" allows repeated compress steps
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"}
    )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
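

# Minimal usage sketch (illustrative: the input document below is a placeholder
# and must match whatever shape ki_gen.utils.format_doc expects; the model and
# thread_id values are assumptions):
if __name__ == "__main__":
    with SqliteSaver.from_conn_string(":memory:") as memory:
        graph = build_data_processor_graph(memory)
        result = graph.invoke(
            {
                "valid_docs": [{"title": "Sample doc", "content": "..."}],  # placeholder doc
                "process_steps": ["summarize"],
            },
            config={"configurable": {"thread_id": "demo", "summarize_model": "gemini-2.0-flash"}},
        )
        print(result["valid_docs"])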