avfranco committed
Commit 5d3f8a0 · 1 Parent(s): 162ed68

ea4all-gradio-agents-mcp-hackathon-tools-refactoring-vqa

ea4all/ea4all_mcp.py CHANGED
@@ -13,6 +13,9 @@ import ea4all.src.ea4all_vqa.graph as e4v
 import ea4all.src.ea4all_gra.graph as e4t
 import ea4all.src.shared.utils as e4u
 from ea4all.src.graph import super_graph
+from ea4all.src.ea4all_apm.graph import apm_graph
+from ea4all.src.ea4all_vqa.graph import diagram_graph
+
 #from ea4all.src.pmo_crew.crew_runner import run_pmo_crew
 
 from typing import AsyncGenerator
@@ -66,100 +69,11 @@ async def run_qna_agentic_system(question: str) -> AsyncGenerator[list, None]:
     if not question:
         format_response = "Hi, how are you today? To start using the EA4ALL MCP Tool, provide the required Inputs!"
         chat_memory.append(ChatMessage(role="assistant", content=format_response))
-        yield chat_memory
-
-    if not chat_memory:
-        chat_memory.append(ChatMessage(role="user", content=question))
-        yield chat_memory
-
-    if question:
-        #capture user ip
-        #ea4all_user = e4u.get_user_identification(request)
-
-        ##Initialise APM Graph
-        #apm_graph = e4a.apm_graph
-        #inputs = {"question": question, "chat_memory":chat_memory}
-        inputs = {"messages": [{"role": "user", "content": question}]}
-
-        #add question to memory
-        chat_memory.append(ChatMessage(role="user", content=question))
+    else:
+        response = await apm_graph.ainvoke({"question": question}, config=config)
+        chat_memory.append(ChatMessage(role="assistant", content=response['generation']))
 
-        partial_message = ""
-        async for event in super_graph.astream_events(input=inputs, config=config, version="v2"):
-        #async for event in super_graph.astream(input=inputs, config=config, subgraphs=True):
-        #    chat_memory.append(ChatMessage(role="assistant", content=str(event)))
-        #    yield chat_memory
-
-            kind = event["event"]
-            tags = event.get("tags", [])
-            name = event['name']
-
-            #chat_memory.append(ChatMessage(role="assistant", content=f"Running: {name}"))
-            #yield chat_memory
-
-            if name == "safety_check":
-                #if kind == "on_chain_start":
-                #    chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}`"))
-                #    yield chat_memory
-                if kind == "on_chain_stream":
-                    chunk = event['data'].get('chunk')
-                    if chunk and 'safety_status' in chunk and len(chunk['safety_status']) > 0:
-                        chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}`: {chunk['safety_status'][0]}"))
-                        if chunk['safety_status'][0] == 'no' and len(chunk['safety_status']) > 1:
-                            chat_memory.append(ChatMessage(role="assistant", content=f"Safety-status: {chunk['safety_status'][1]}"))
-                    yield chat_memory
-            if kind == "on_chain_end" and name == "route_question":
-                output = event['data'].get('output')
-                if output and 'source' in output:
-                    chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}:` {output['source']}"))
-                else:
-                    chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}:` (no source available)"))
-                yield chat_memory
-            if kind == "on_chain_start" and name == "retrieve":
-                chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}` RAG\n\n"))
-                yield chat_memory
-            if kind == "on_chain_start" and name in ("generate_web_search", "websearch", "stream_generation"):
-                chat_memory.append(ChatMessage(role="assistant", content= f"\n\n- `{name}`\n\n"))
-                yield chat_memory
-            if kind == "on_chain_stream" and name == "stream_generation":
-                data = event["data"]
-                # Accumulate the chunk of data
-                partial_message += data.get('chunk', '')
-                chat_memory[-1].content = partial_message
-                time.sleep(0.05)
-                yield chat_memory
-            if name == "grade_generation_v_documents_and_question":
-                if kind == "on_chain_start":
-                    chat_memory.append(ChatMessage(role="assistant", content=f"\n\n- `{name}`: "))
-                    yield chat_memory
-                if kind == "on_chain_end":
-                    input_data = event['data'].get('input')
-                    if input_data and hasattr(input_data, 'source'):
-                        output_value = event['data'].get('output', '')
-                        chat_memory.append(ChatMessage(role="assistant", content=f"`{input_data.source}:` {output_value}"))
-                    else:
-                        chat_memory.append(ChatMessage(role="assistant", content=f"`{event['data'].get('output', '')}`"))
-                    yield chat_memory
-            if "stream_hallucination" in tags and kind == "on_chain_start":
-                chat_memory.append(ChatMessage(role="assistant", content=f"- `{tags[-1]}`"))
-                yield chat_memory
-            if "stream_grade_answer" in tags and kind == "on_chain_start":
-                chat_memory.append(ChatMessage(role="assistant", content=f"- `{tags[-1]}`"))
-                yield chat_memory
-            if name == "supervisor":
-                if kind == "on_chain_start":
-                    chat_memory.append(ChatMessage(role="assistant", content=f"- `{name}` "))
-                    yield chat_memory
-                if kind == "on_chain_stream":
-                    chunk = event['data'].get('chunk')
-                    if chunk is not None:
-                        chat_memory.append(ChatMessage(role="assistant", content=f"{chunk}"))
-                        yield chat_memory
-
-        # Set environment variable only when 'event' is defined
-        #os.environ["EA4ALL_" + ea4all_user.replace(".", "_")] = str(event['run_id'])
-
-        wait_for_all_tracers()
+        yield chat_memory
 
 #Trigger Solution Architecture Diagram QnA
 async def run_vqa_agentic_system(question: str, diagram: str, request: gr.Request) -> AsyncGenerator[list, None]:
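The block deleted above replayed astream_events progress (safety check, routing, retrieval, generation chunks) into the chat history; the replacement awaits the graph once and appends the final generation value. A minimal runnable sketch of the new pattern, with a toy LangGraph graph standing in for apm_graph (the state keys and node body here are illustrative, not the repo's actual graph):

import asyncio
from typing import TypedDict

from langgraph.graph import StateGraph, START, END

class QnAState(TypedDict):
    question: str
    generation: str

def generate(state: QnAState) -> dict:
    # Stand-in node; the real graph routes, retrieves and generates.
    return {"generation": f"Answer to: {state['question']}"}

builder = StateGraph(QnAState)
builder.add_node("generate", generate)
builder.add_edge(START, "generate")
builder.add_edge("generate", END)
toy_graph = builder.compile()

async def main() -> None:
    # One awaited call returns the final state; no loop over
    # on_chain_start/on_chain_stream/on_chain_end event kinds is needed.
    response = await toy_graph.ainvoke({"question": "What is EA4ALL?"})
    print(response["generation"])

asyncio.run(main())

The trade-off is losing incremental streaming in the Gradio chat in exchange for a single, easily serialized response, which matches the MCP-oriented comment added in graph.py below.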
@@ -202,39 +116,13 @@ async def run_vqa_agentic_system(question: str, diagram: str, request: gr.Reques
         with Image.open(diagram) as diagram_:
             if diagram_.format not in allowed_file_types:
                 chat_memory.append(ChatMessage(role="assistant", content="Invalid file type. Allowed file types are JPEG and PNG."))
-                yield chat_memory
             else:
                 #'vqa_image = e4u.get_raw_image(diagram) #MOVED into Graph
                 vqa_image = diagram
-
-                #Setup Quality Assurance Agentic System
-                #graph = e4v.ea4all_graph(config['configurable']['vqa_model'])
-
-                #Setup enter graph
-                diagram_graph = e4v.diagram_graph
-
-                partial_message = ""
-                chat_memory.append(ChatMessage(role="assistant", content="Hi, I am working on your question..."))
-                async for event in diagram_graph.astream_events(
-                    {"question":msg, "image": vqa_image}, config, version="v2"
-                ):
-                    if (
-                        event["event"] == "on_chat_model_stream"
-                        and "vqa_stream" in event.get('tags', [])
-                        #and event["metadata"].get("langgraph_node") == "tools"
-                    ):
-                        chunk = event["data"].get("chunk")
-                        if chunk is not None and hasattr(chunk, "content"):
-                            partial_message += chunk.content
-                            chat_memory[-1].content = partial_message
-                            time.sleep(e4u.CFG.STREAM_SLEEP)
-                            yield chat_memory #, message to update question
-                    elif not partial_message:
-                        yield chat_memory #, message
-
-                #os.environ["EA4ALL_" + ea4all_user.replace(".", "_")] = str(event['run_id'])
-
-                wait_for_all_tracers()
+                response = await diagram_graph.ainvoke({"question":msg, "image": vqa_image}, config)
+                chat_memory.append(ChatMessage(role="assistant", content=response['messages'][-1].content))
+
+        yield chat_memory
 
     except Exception as e:
         yield (e.args[-1])
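run_vqa_agentic_system gets the same treatment: one ainvoke on diagram_graph, with the answer read from the last message in the returned state. A sketch of that access pattern on a messages-based graph (the toy node below is illustrative; the real graph consumes {"question", "image"} and runs a vision model):

import asyncio

from langchain_core.messages import AIMessage
from langgraph.graph import MessagesState, StateGraph, START, END

def answer(state: MessagesState) -> dict:
    # Stand-in node; the real graph answers from the uploaded diagram.
    return {"messages": [AIMessage(content="The diagram shows three tiers.")]}

builder = StateGraph(MessagesState)
builder.add_node("answer", answer)
builder.add_edge(START, "answer")
builder.add_edge("answer", END)
toy_vqa = builder.compile()

async def main() -> None:
    response = await toy_vqa.ainvoke({"messages": [("user", "What does this diagram show?")]})
    # Same access pattern as the refactored handler above:
    print(response["messages"][-1].content)

asyncio.run(main())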
ea4all/src/ea4all_vqa/graph.py CHANGED
@@ -20,7 +20,8 @@ from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     HumanMessage,
-    ToolMessage
+    ToolMessage,
+    BaseMessage
 )
 
 #pydantic
@@ -73,7 +74,7 @@ class DiagramV2S(BaseModel):
     isSafe: bool = Field(...,description="Should be True if image or question are safe to be processed, False otherwise")
     description: str = Field(description="Should be a string describing the image title.")
 
-@tool("vqa_diagram")
+@tool("vqa_diagram", response_format="content")
 @spaces.GPU
 async def vqa_diagram(next:str, state: Annotated[OverallState, InjectedState], config: RunnableConfig):
     """Diagram Vision Question Answering"""
@@ -100,9 +101,8 @@ async def vqa_diagram(next:str, state: Annotated[OverallState, InjectedState], c
             },
         ],
     )
-
     prompt = ChatPromptTemplate.from_messages([user_message])
-    values = {"question:":question}
+    values = {"question": question}
 
     llm.max_tokens = set_max_new_tokens(get_predicted_num_tokens_from_prompt(llm, prompt, values))
     chain = prompt | llm
@@ -113,7 +113,8 @@ async def vqa_diagram(next:str, state: Annotated[OverallState, InjectedState], c
 
     response = await chain.ainvoke(input=values, config={"tags": ["vqa_stream"]}, kwargs={"max_tokens": configuration.vqa_max_tokens})
 
-    return response
+    ## When exposed as MCP tool, output schema should as simple as possible as output is serialized to a single string
+    return response.content
 
 ##Supervisor Agent Function custom parse with tool calling response support
 def parse(output: ToolMessage) -> dict | AgentFinish:
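The two changes to vqa_diagram work together: response_format="content" tells the @tool decorator that the function returns the tool's content directly (this is the decorator's default in langchain_core; the commit just makes the contract explicit), and returning response.content (a plain string) instead of the whole message object keeps the output trivially serializable over MCP. A self-contained sketch of that contract (tool name and body are illustrative, not the repo's code):

from langchain_core.tools import tool

@tool("describe_diagram", response_format="content")
def describe_diagram(question: str) -> str:
    """Answer a question about an architecture diagram."""
    # A real implementation would run the VQA chain and return
    # response.content; a fixed string keeps the sketch self-contained.
    return f"Plain-string answer to: {question}"

result = describe_diagram.invoke({"question": "What does this diagram show?"})
print(type(result), result)  # <class 'str'> ... serializes to a single string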
ea4all/src/ea4all_vqa/state.py CHANGED
@@ -15,7 +15,6 @@ from typing import (
 from langchain_core.messages import (
     BaseMessage,
 )
-
 from langgraph.graph import MessagesState
 
 # Optional, the InputState is a restricted version of the State that is used to
@@ -44,21 +43,22 @@
     """Represents the output schema for the Diagram agent.
     """
 
-    answer: str
-    """Answer to user's question about the Architectural Diagram."""
+    messages: Optional[Annotated[Sequence[MessagesState], operator.add]] = None
+    safety_status: Optional[bool] = None
+
+    """Attributes:
+        safety_status: safety status of the diagram provided by the user
+        Answer to user's question about the Architectural Diagram.
+    """
 
 @dataclass(kw_only=True)
-class OverallState(InputState):
+class OverallState(InputState, OutputState):
     """Represents the overall state of the Diagram graph."""
 
     """Attributes:
-        messages: list of messages
-        safety_status: safety status of the diagram provided by the user
         error: tool error
         next: next tool to be called
     """
 
-    messages: Optional[Annotated[Sequence[BaseMessage], operator.add]] = None
-    safety_status: Optional[bool] = None
     error: Optional[str] = None
     next: Optional[str] = None
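After this change, OutputState carries the accumulated messages and the safety flag, and OverallState inherits from both InputState and OutputState. A minimal sketch of the resulting layout; the InputState body is invented for illustration, and the reducer is typed here as Sequence[BaseMessage] on the assumption that this is what Sequence[MessagesState] in the diff intends (LangGraph's operator.add reducer concatenates message lists):

import operator
from dataclasses import dataclass
from typing import Annotated, Optional, Sequence

from langchain_core.messages import BaseMessage

@dataclass(kw_only=True)
class InputState:
    question: Optional[str] = None  # illustrative field, not the repo's definition

@dataclass(kw_only=True)
class OutputState:
    # operator.add makes LangGraph append each node's update to the
    # running list instead of overwriting it.
    messages: Optional[Annotated[Sequence[BaseMessage], operator.add]] = None
    safety_status: Optional[bool] = None

@dataclass(kw_only=True)
class OverallState(InputState, OutputState):
    error: Optional[str] = None
    next: Optional[str] = None

state = OverallState(question="Is this diagram safe to process?", safety_status=True)
print(state)

kw_only=True is what lets the three dataclasses combine under multiple inheritance without the usual defaults-before-non-defaults ordering error.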