ProximileAdmin committed on
Commit b9c9c20 · verified · 1 Parent(s): 6811e03

Update app.py

Files changed (1)
  1. app.py +705 -189
app.py CHANGED
@@ -1,211 +1,727 @@
-import gradio as gr
-from gradio_client import Client
 import json
 import time
-import os
-
-# Set the API endpoint (can be customized)
-API_ENDPOINT = os.environ.get("LLADA_API_ENDPOINT", "http://127.0.0.1:7880")
-
-# Create the client
-client = Client(API_ENDPOINT)
-
-# Default generation parameters
-DEFAULT_GEN_LENGTH = 64
-DEFAULT_STEPS = 64
-DEFAULT_DELAY = 0.1
-DEFAULT_TEMPERATURE = 0.1
-DEFAULT_CFG_SCALE = 0
-DEFAULT_BLOCK_LENGTH = 32
-DEFAULT_REMASKING = "low_confidence"
-
-def clear_conversation():
-    """Clear the conversation history"""
-    result = client.predict(api_name="/clear_conversation")
-    return [], "", None
-
-def send_user_message(message, history, waiting_for_tool, current_tool_call):
-    """Send a user message to the API"""
-    if not message.strip() or waiting_for_tool:
-        return history, "", waiting_for_tool, current_tool_call
-
-    # Add user message to history for display
-    history = history + [[message, None]]
-
-    # Send message to API
-    result = client.predict(message, api_name="/user_message_submitted")
-
-    # Process result - get bot response
-    bot_result = client.predict(
-        DEFAULT_GEN_LENGTH, DEFAULT_STEPS, DEFAULT_DELAY,
-        DEFAULT_TEMPERATURE, DEFAULT_CFG_SCALE, DEFAULT_BLOCK_LENGTH,
-        DEFAULT_REMASKING, api_name="/bot_response"
-    )
-
-    # Extract the response text and update history
-    response_text = bot_result[2]
-    history[-1][1] = response_text
-
-    # Check if the response is a tool call
-    is_tool_call = False
-    tool_call = None
-
-    try:
-        # Check if response is a valid JSON tool call
-        if (response_text.strip().startswith('{') and response_text.strip().endswith('}')) or \
-           (response_text.strip().startswith('[') and response_text.strip().endswith(']')):
-            json_data = json.loads(response_text)
-
-            # Check if it has the tool call structure
-            if isinstance(json_data, list):
-                for item in json_data:
-                    if isinstance(item, dict) and "name" in item and "parameters" in item:
-                        is_tool_call = True
-                        tool_call = item
-                        break
-            elif isinstance(json_data, dict) and "name" in json_data and "parameters" in json_data:
-                is_tool_call = True
-                tool_call = json_data
-    except:
-        # Not a valid JSON or tool call
-        pass
-
-    return history, "", is_tool_call, tool_call
-
-def generate_dummy_response(tool_call):
-    """Generate a dummy response for a tool call"""
-    if not tool_call:
-        return ""
-
-    result = client.predict(api_name="/generate_dummy_response")
-    return result
-
-def submit_tool_response(tool_response, history, tool_call):
-    """Submit a tool response to the API"""
-    if not tool_response.strip() or not tool_call:
-        return history, False, None
-
-    # Process the tool response
-    result = client.predict(
-        tool_response, DEFAULT_GEN_LENGTH, DEFAULT_STEPS, DEFAULT_DELAY,
-        DEFAULT_TEMPERATURE, DEFAULT_CFG_SCALE, DEFAULT_BLOCK_LENGTH,
-        DEFAULT_REMASKING, api_name="/process_tool_response"
-    )
-
-    # Extract the follow-up response
-    response_text = result[2]
-
-    # Add the response to history
-    history = history + [["Tool response processed", response_text]]
-
-    return history, False, None
-
-# Create the Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# LLaDA Chat Client")
-    gr.Markdown("This is a client application for the LLaDA Diffusion Model with Tool Calls.")
-
-    # State variables
-    chat_history = gr.State([])
-    waiting_for_tool = gr.State(False)
-    current_tool_call = gr.State(None)
-
-    # Chat interface
-    chatbot = gr.Chatbot(label="Conversation")
-
-    # Main chat input
-    with gr.Row():
-        msg = gr.Textbox(
-            show_label=False,
-            placeholder="Enter your message here...",
-            scale=7
-        )
-        submit_btn = gr.Button("Send", scale=1)
-
-    # Add clickable examples
-    gr.Examples(
-        examples=["What is the weather like in NYC? Use Fahrenheit."],
-        inputs=msg
-    )
-
-    # Tool response section (initially hidden)
-    with gr.Group(visible=False) as tool_group:
-        gr.Markdown("## Tool Call Detected")
-        tool_name = gr.Textbox(label="Tool Name", interactive=False)
-        tool_params = gr.JSON(label="Parameters")
-        tool_response = gr.Textbox(
-            label="Tool Response (JSON)",
-            placeholder="Enter JSON response for the tool...",
-            lines=5
-        )
-        with gr.Row():
-            dummy_btn = gr.Button("Use Dummy Response")
-            submit_tool_btn = gr.Button("Submit Tool Response")
-
-    # Clear conversation button
-    clear_btn = gr.Button("Clear Conversation")
-
-    # Update UI based on tool call detection
-    def update_tool_ui(is_tool_call, tool_call):
-        if is_tool_call and tool_call:
-            return (
-                gr.update(visible=True),
-                gr.update(value=tool_call.get("name", "")),
-                gr.update(value=tool_call.get("parameters", {}))
-            )
-        else:
-            return (
-                gr.update(visible=False),
-                gr.update(value=""),
-                gr.update(value={})
-            )
-
-    # Connect components
-    submit_btn.click(
-        fn=send_user_message,
-        inputs=[msg, chat_history, waiting_for_tool, current_tool_call],
-        outputs=[chatbot, msg, waiting_for_tool, current_tool_call]
-    ).then(
-        fn=update_tool_ui,
-        inputs=[waiting_for_tool, current_tool_call],
-        outputs=[tool_group, tool_name, tool_params]
-    )
-
-    msg.submit(
-        fn=send_user_message,
-        inputs=[msg, chat_history, waiting_for_tool, current_tool_call],
-        outputs=[chatbot, msg, waiting_for_tool, current_tool_call]
-    ).then(
-        fn=update_tool_ui,
-        inputs=[waiting_for_tool, current_tool_call],
-        outputs=[tool_group, tool_name, tool_params]
-    )
-
-    clear_btn.click(
-        fn=clear_conversation,
-        inputs=[],
-        outputs=[chatbot, msg, current_tool_call]
-    ).then(
-        fn=lambda: gr.update(visible=False),
-        inputs=[],
-        outputs=[tool_group]
-    )
-
-    dummy_btn.click(
-        fn=generate_dummy_response,
-        inputs=[current_tool_call],
-        outputs=[tool_response]
-    )
-
-    submit_tool_btn.click(
-        fn=submit_tool_response,
-        inputs=[tool_response, chat_history, current_tool_call],
-        outputs=[chatbot, waiting_for_tool, current_tool_call]
-    ).then(
-        fn=update_tool_ui,
-        inputs=[waiting_for_tool, current_tool_call],
-        outputs=[tool_group, tool_name, tool_params]
-    )
  if __name__ == "__main__":
-    demo.launch()

+import torch
 import json
+import gradio as gr
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModel
+from peft import PeftModel
 import time
+import re
+
+# Device setup
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Using device: {device}")
+
+# Load base model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
+base_model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
+                                       torch_dtype=torch.bfloat16, load_in_8bit=True)
+
+# Load the LoRA adapter for better tool calling
+model = PeftModel.from_pretrained(base_model, "Proximile/LLaDA-8B-Tools-LoRA")
+model.eval()
+
+# Constants
+MASK_TOKEN = "[MASK]"
+MASK_ID = 126336  # The token ID of [MASK] in LLaDA
+
+# Tool class definitions
+class ToolBase:
+    def __init__(self,
+                 programmatic_name,
+                 natural_name,
+                 description,
+                 input_params,
+                 required_params=None,
+                 ):
+        self.json_name = programmatic_name
+        self.json_description = description
+        self.schema = {
+            "type": "function",
+            "function": {
+                "name": self.json_name,
+                "description": self.json_description,
+                "parameters": {
+                    "type": "object",
+                    "properties": input_params,
+                    "required": required_params or []
+                }
+            }
+        }
+
+    def actual_function(self, **kwargs):
+        raise NotImplementedError("Subclasses must implement this method.")
+
+class WeatherAPITool(ToolBase):
+    def __init__(self):
+        super().__init__(
+            programmatic_name="get_weather",
+            natural_name="Weather Report Fetcher",
+            description="Get the current weather in a given location",
+            input_params={
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA"
+                },
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"],
+                    "description": "The unit of temperature"
+                }
+            },
+            required_params=["location", "unit"],
+        )
+
+    def actual_function(self, **kwargs):
+        # This would normally call an API, but we'll return dummy data
+        return {
+            "location": kwargs["location"],
+            "temperature": 72 if kwargs["unit"] == "fahrenheit" else 22,
+            "unit": kwargs["unit"],
+            "condition": "Partly Cloudy",
+            "humidity": 65,
+            "wind_speed": 8,
+            "wind_direction": "NE"
+        }
+
+# Create the tool
+weather_tool = WeatherAPITool()
+
+# Diffusion model generation functions
+def add_gumbel_noise(logits, temperature):
+    '''
+    The Gumbel max is a method for sampling categorical distributions.
+    For MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+    '''
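+    # Note: argmax(exp(logits) / (-log u) ** T) with u ~ Uniform(0, 1) is
+    # equivalent in distribution to argmax(logits / T + Gumbel noise), since
+    # -log(-log u) is standard Gumbel; float64 keeps the exponentials stable.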
+    if temperature <= 0:
+        return logits
+
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
+
+def get_num_transfer_tokens(mask_index, steps):
+    '''
+    In the reverse process, we precompute the number of tokens to transition at each step.
+    '''
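+    # e.g. 10 masked tokens over 4 steps -> [3, 3, 2, 2] tokens per step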
+    mask_num = mask_index.sum(dim=1, keepdim=True)
+
+    # Ensure we have at least one step
+    if steps == 0:
+        steps = 1
+
+    base = mask_num // steps
+    remainder = mask_num % steps
+
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+    for i in range(mask_num.size(0)):
+        if remainder[i] > 0:
+            num_transfer_tokens[i, :remainder[i]] += 1
+
+    return num_transfer_tokens
+
+def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=128, steps=128,
+                                         temperature=0.1, cfg_scale=0.0, block_length=32,
+                                         remasking='low_confidence'):
+    """
+    Generate text with LLaDA model with visualization
+    """
+    # Prepare the prompt using chat template
+    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    input_ids = tokenizer(chat_input)['input_ids']
+    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
+
+    # For generation
+    prompt_length = input_ids.shape[1]
+
+    # Initialize the sequence with masks for the response part
+    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
+    x[:, :prompt_length] = input_ids.clone()
+
+    # Initialize visualization states for the response part
+    visualization_states = []
+
+    # Add initial state (all masked)
+    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
+    visualization_states.append(initial_state)
+
+    # Mark prompt positions to exclude them from masking during classifier-free guidance
+    prompt_index = (x != MASK_ID)
+
+    # Ensure block_length is valid
+    if block_length > gen_length:
+        block_length = gen_length
+
+    # Calculate number of blocks
+    num_blocks = gen_length // block_length
+    if gen_length % block_length != 0:
+        num_blocks += 1
+
+    # Adjust steps per block
+    steps_per_block = steps // num_blocks
+    if steps_per_block < 1:
+        steps_per_block = 1
+
+    # Process each block
+    for num_block in range(num_blocks):
+        # Calculate the start and end indices for the current block
+        block_start = prompt_length + num_block * block_length
+        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
+
+        # Get mask indices for the current block
+        block_mask_index = (x[:, block_start:block_end] == MASK_ID)
+
+        # Skip if no masks in this block
+        if not block_mask_index.any():
+            continue
+
+        # Calculate number of tokens to unmask at each step
+        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
+
+        # Process each step
+        for i in range(steps_per_block):
+            # Get all mask positions in the current sequence
+            mask_index = (x == MASK_ID)
+
+            # Skip if no masks
+            if not mask_index.any():
+                break
+
+            # Apply classifier-free guidance if enabled
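+            # (contrast the real prompt with a fully masked prompt:
+            #  guided = uncond + (cfg_scale + 1) * (cond - uncond))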
+            if cfg_scale > 0.0:
+                un_x = x.clone()
+                un_x[prompt_index] = MASK_ID
+                x_ = torch.cat([x, un_x], dim=0)
+                logits = model(x_).logits
+                logits, un_logits = torch.chunk(logits, 2, dim=0)
+                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+            else:
+                logits = model(x).logits
+
+            # Apply Gumbel noise for sampling
+            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+            x0 = torch.argmax(logits_with_noise, dim=-1)
+
+            # Calculate confidence scores for remasking
+            if remasking == 'low_confidence':
+                p = F.softmax(logits.to(torch.float64), dim=-1)
+                x0_p = torch.squeeze(
+                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
+            elif remasking == 'random':
+                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+            else:
+                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
+
+            # Don't consider positions beyond the current block
+            x0_p[:, block_end:] = -float('inf')
+
+            # Apply predictions where we have masks
+            old_x = x.clone()
+            x0 = torch.where(mask_index, x0, x)
+            confidence = torch.where(mask_index, x0_p, -float('inf'))
+
+            # Select tokens to unmask based on confidence
+            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+            for j in range(confidence.shape[0]):
+                # Only consider positions within the current block for unmasking
+                block_confidence = confidence[j, block_start:block_end]
+                if i < steps_per_block - 1:  # Not the last step
+                    # Take top-k confidences
+                    _, select_indices = torch.topk(block_confidence,
+                                                   k=min(num_transfer_tokens[j, i].item(),
+                                                         block_confidence.numel()))
+                    # Adjust indices to global positions
+                    select_indices = select_indices + block_start
+                    transfer_index[j, select_indices] = True
+                else:  # Last step - unmask everything remaining
+                    transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]
+
+            # Apply the selected tokens
+            x = torch.where(transfer_index, x0, x)
+
+            # Create visualization state only for the response part
+            # (separate loop variable so the step counter `i` is not shadowed)
+            current_state = []
+            for offset in range(gen_length):
+                pos = prompt_length + offset  # Absolute position in the sequence
+
+                if x[0, pos] == MASK_ID:
+                    # Still masked
+                    current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
+
+                elif old_x[0, pos] == MASK_ID:
+                    # Newly revealed in this step
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    # Color based on confidence
+                    token_confidence = float(x0_p[0, pos].cpu())
+                    if token_confidence < 0.3:
+                        color = "#FF6666"  # Light red
+                    elif token_confidence < 0.7:
+                        color = "#FFAA33"  # Orange
+                    else:
+                        color = "#66CC66"  # Light green
+
+                    current_state.append((token, color))
+
+                else:
+                    # Previously revealed
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    current_state.append((token, "#6699CC"))  # Light blue
+
+            visualization_states.append(current_state)
+
+    # Extract final text (just the assistant's response)
+    response_tokens = x[0, prompt_length:]
+    final_text = tokenizer.decode(response_tokens,
+                                  skip_special_tokens=False,
+                                  clean_up_tokenization_spaces=True).split("<|")[0]
+
+    return visualization_states, final_text
+
+# Tool handling functions
+def is_tool_call(text):
+    """Check if the text looks like a JSON tool call"""
+    # Remove any whitespace at beginning and end
+    text = text.strip()
+    # Check if it starts with [ or { (common JSON indicators)
+    if (text.startswith('[') and text.endswith(']')) or (text.startswith('{') and text.endswith('}')):
+        try:
+            # Try to parse as JSON
+            data = json.loads(text)
+            # Check if it contains a tool call structure
+            if isinstance(data, list):
+                for item in data:
+                    if isinstance(item, dict) and "name" in item and "parameters" in item:
+                        return True
+            elif isinstance(data, dict) and "name" in data and "parameters" in data:
+                return True
+        except json.JSONDecodeError:
+            pass
+    return False
+
+def extract_tool_call(text):
+    """Extract tool call data from text"""
+    try:
+        data = json.loads(text)
+        if isinstance(data, list) and len(data) > 0:
+            # Return the first valid tool call
+            for item in data:
+                if isinstance(item, dict) and "name" in item and "parameters" in item:
+                    return item
+        elif isinstance(data, dict) and "name" in data and "parameters" in data:
+            return data
+    except (json.JSONDecodeError, TypeError):
+        pass
+    return None
+
+def handle_tool_call(tool_call):
+    """Process a tool call and return the result"""
+    if tool_call["name"] == weather_tool.json_name:
+        return weather_tool.actual_function(**tool_call["parameters"])
+    return {"error": f"Tool {tool_call['name']} not found"}
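+
+# Note: handle_tool_call is not wired into the UI below; the demo instead lets
+# the user (or the "Use Dummy Response" button) supply the tool output by hand.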
+
+# Custom CSS
+css = '''
+.category-legend{display:none}
+button{height: 60px}
+.visualization-container {
+    margin-top: 20px;
+    padding: 10px;
+    background-color: #f8f9fa;
+    border-radius: 8px;
+}
+'''
+
+def create_chatbot_demo():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
+        gr.Markdown("This demo showcases the LLaDA diffusion model with the [Proximile/LLaDA-8B-Tools-LoRA](https://huggingface.co/Proximile/LLaDA-8B-Tools-LoRA) adapter for enhanced tool calling capabilities.")
+
+        # STATE MANAGEMENT
+        chat_history = gr.State([])
+        waiting_for_tool_response = gr.State(False)
+        current_tool_call = gr.State(None)
+
+        # UI COMPONENTS
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot_ui = gr.Chatbot(label="Conversation", height=500)
+
+                # Message input
+                with gr.Group():
+                    with gr.Row():
+                        user_input = gr.Textbox(
+                            label="Your Message",
+                            placeholder="Type your message here...",
+                            show_label=False
+                        )
+                        send_btn = gr.Button("Send")
+
+                # Tool response input (initially hidden)
+                with gr.Group(visible=False) as tool_response_group:
+                    gr.Markdown("## Tool Call Detected")
+                    tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
+                    tool_params_display = gr.JSON(label="Parameters")
+                    tool_response_input = gr.Textbox(
+                        label="Tool Response (JSON)",
+                        placeholder="Enter JSON response for the tool...",
+                        lines=5
+                    )
+                    submit_tool_response = gr.Button("Submit Tool Response")
+
+                    # Add a button for auto-filling dummy response
+                    dummy_response_btn = gr.Button("Use Dummy Response")
+
+            with gr.Column(scale=2):
+                gr.Markdown("## Diffusion Process Visualization")
+                gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
+                output_vis = gr.HighlightedText(
+                    label="Token Denoising",
+                    combine_adjacent=False,
+                    show_legend=True,
+                    elem_classes="visualization-container"
+                )
+                gr.Markdown("**Color Legend:**")
+                gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
+                gr.Markdown("- **Light Red**: Newly revealed with low confidence")
+                gr.Markdown("- **Orange**: Newly revealed with medium confidence")
+                gr.Markdown("- **Light Green**: Newly revealed with high confidence")
+                gr.Markdown("- **Light Blue**: Previously revealed tokens")
+
+        # Advanced generation settings
+        with gr.Accordion("Generation Settings", open=False):
+            with gr.Row():
+                gen_length = gr.Slider(
+                    minimum=8, maximum=128, value=64, step=4,
+                    label="Generation Length"
+                )
+                steps = gr.Slider(
+                    minimum=8, maximum=128, value=64, step=4,
+                    label="Denoising Steps"
+                )
+            with gr.Row():
+                temperature = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
+                    label="Temperature"
+                )
+                cfg_scale = gr.Slider(
+                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
+                    label="CFG Scale"
+                )
+            with gr.Row():
+                block_length = gr.Slider(
+                    minimum=8, maximum=128, value=32, step=8,
+                    label="Block Length"
+                )
+                remasking_strategy = gr.Radio(
+                    choices=["low_confidence", "random"],
+                    value="low_confidence",
+                    label="Remasking Strategy"
+                )
+            with gr.Row():
+                visualization_delay = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
+                    label="Visualization Delay (seconds)"
+                )
+
+        # Current response text box (hidden)
+        current_response = gr.Textbox(
+            label="Current Response",
+            placeholder="The assistant's response will appear here...",
+            lines=3,
+            visible=False
+        )
+
+        # Clear button
+        clear_btn = gr.Button("Clear Conversation")
+
+        gr.Markdown("### Try asking about the weather to trigger a tool call!")
+        gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")
+
+        # System prompt for the model
+        system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
+
+If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answers the given prompt.
+
+Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.
+
+Here are the tool functions available to you:
+
+{json.dumps([weather_tool.schema], indent=4)}
+
+After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.

+If the user request does not necessitate a function call, simply respond to the user's query directly."""
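+        # For illustration, a reply conforming to the format above would be:
+        # [{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "celsius"}}]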
+
+        # HELPER FUNCTIONS
+        def add_message(history, message, response):
+            """Add a message pair to the history and return the updated history"""
+            history = history.copy()
+            history.append([message, response])
+            return history
+
+        def user_message_submitted(message, history, waiting_for_tool):
+            """Process a submitted user message"""
+            # Skip empty messages or if waiting for a tool response
+            if not message.strip() or waiting_for_tool:
+                # Return current state unchanged
+                history_for_display = history.copy()
+                return history, history_for_display, "", [], ""
+
+            # Add user message to history
+            history = add_message(history, message, None)
+
+            # Format for display - temporarily show user message with empty response
+            history_for_display = history.copy()
+
+            # Clear the input
+            message_out = ""
+
+            # Return immediately to update UI with user message
+            return history, history_for_display, message_out, [], ""
+
+        def bot_response(history, waiting_for_tool, current_tool,
+                         gen_length, steps, delay, temperature,
+                         cfg_scale, block_length, remasking):
+            """Generate bot response for the latest message"""
+            if not history or waiting_for_tool:
+                # This is a generator, so even the early exit must yield
+                yield history, [], "", waiting_for_tool, current_tool, gr.update(visible=False), gr.update(), gr.update()
+                return
+
+            # Get the last user message
+            last_user_message = history[-1][0]
+
+            try:
+                # Format the conversation for the model
+                messages = []
+
+                # Add system message first
+                messages.append({"role": "system", "content": system_prompt})
+
+                # Add conversation history
+                for h in history[:-1]:
+                    messages.append({"role": "user", "content": h[0]})
+                    if h[1]:  # Only include assistant responses that exist
+                        messages.append({"role": "assistant", "content": h[1]})
+
+                # Add the last user message
+                messages.append({"role": "user", "content": last_user_message})
+
+                # Generate response with visualization
+                vis_states, response_text = generate_response_with_visualization(
+                    model, tokenizer, device,
+                    messages,
+                    gen_length=gen_length,
+                    steps=steps,
+                    temperature=temperature,
+                    cfg_scale=cfg_scale,
+                    block_length=block_length,
+                    remasking=remasking
+                )
+
+                # Update history with the assistant's response
+                history[-1][1] = response_text
+
+                # Check if the response is a tool call
+                is_tool = is_tool_call(response_text)
+
+                if is_tool:
+                    # Extract tool call information
+                    tool_call = extract_tool_call(response_text)
+
+                    # Return the initial state immediately
+                    yield (history, vis_states[0], response_text,
+                           True, tool_call,
+                           gr.update(visible=True),
+                           gr.update(value=tool_call["name"]),
+                           gr.update(value=tool_call["parameters"]))
+
+                    # Then animate through visualization states
+                    for state in vis_states[1:]:
+                        time.sleep(delay)
+                        yield (history, state, response_text,
+                               True, tool_call,
+                               gr.update(visible=True),
+                               gr.update(value=tool_call["name"]),
+                               gr.update(value=tool_call["parameters"]))
+                else:
+                    # Return the initial state immediately
+                    yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+                    # Then animate through visualization states
+                    for state in vis_states[1:]:
+                        time.sleep(delay)
+                        yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+            except Exception as e:
+                error_msg = f"Error: {str(e)}"
+                print(error_msg)
+
+                # Show error in visualization
+                error_vis = [(error_msg, "red")]
+
+                # Don't update history with error
+                yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+        def process_tool_response(tool_response, history, current_tool,
+                                  gen_length, steps, delay, temperature,
+                                  cfg_scale, block_length, remasking):
+            """Process tool response and generate a follow-up response"""
+            if not history or not current_tool:
+                # Early exit must also yield (generator function)
+                yield history, [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
+                return
+
+            try:
+                # Parse the tool response
+                response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response
+
+                # Format the conversation for the model
+                messages = []
+
+                # Add system message first
+                messages.append({"role": "system", "content": system_prompt})
+
+                # Add conversation history
+                for h in history:
+                    messages.append({"role": "user", "content": h[0]})
+                    if h[1]:  # Only include assistant responses that exist
+                        messages.append({"role": "assistant", "content": h[1]})
+
+                # Add the tool response
+                messages.append({"role": "ipython", "content": json.dumps({
+                    "name": current_tool["name"],
+                    "return": response_data
+                })})
+
+                # Generate response with visualization
+                vis_states, response_text = generate_response_with_visualization(
+                    model, tokenizer, device,
+                    messages,
+                    gen_length=gen_length,
+                    steps=steps,
+                    temperature=temperature,
+                    cfg_scale=cfg_scale,
+                    block_length=block_length,
+                    remasking=remasking
+                )
+
+                # Add a new message pair for the tool-processed response
+                history = add_message(history, "Tool response processed", response_text)
+
+                # Return the initial state immediately
+                yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+                # Then animate through visualization states
+                for state in vis_states[1:]:
+                    time.sleep(delay)
+                    yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+            except Exception as e:
+                error_msg = f"Error processing tool response: {str(e)}"
+                print(error_msg)
+
+                # Show error in visualization
+                error_vis = [(error_msg, "red")]
+
+                # Don't update history with error
+                yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
+
+        def generate_dummy_response(current_tool):
+            """Generate a dummy response for a tool call"""
+            if not current_tool:
+                return ""
+
+            # Process based on tool name
+            if current_tool["name"] == weather_tool.json_name:
+                location = current_tool["parameters"].get("location", "Unknown")
+                unit = current_tool["parameters"].get("unit", "celsius")
+
+                dummy_data = {
+                    "location": location,
+                    "temperature": 72 if unit == "fahrenheit" else 22,
+                    "unit": unit,
+                    "condition": "Partly Cloudy",
+                    "humidity": 65,
+                    "wind_speed": 8,
+                    "wind_direction": "NE"
+                }
+
+                return json.dumps(dummy_data, indent=2)
+
+            return "{}"
+
+        def clear_conversation():
+            """Clear the conversation history"""
+            return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
+
+        # EVENT HANDLERS
+
+        # Clear button handler
+        clear_btn.click(
+            fn=clear_conversation,
+            inputs=[],
+            outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
+                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
+        )
+
+        # Dummy response button handler
+        dummy_response_btn.click(
+            fn=generate_dummy_response,
+            inputs=[current_tool_call],
+            outputs=[tool_response_input]
+        )
+
+        # User message submission flow
+        msg_submit = user_input.submit(
+            fn=user_message_submitted,
+            inputs=[user_input, chat_history, waiting_for_tool_response],
+            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
+        )
+
+        # Also connect the send button
+        send_click = send_btn.click(
+            fn=user_message_submitted,
+            inputs=[user_input, chat_history, waiting_for_tool_response],
+            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
+        )
+
+        # Generate bot response
+        msg_submit.then(
+            fn=bot_response,
+            inputs=[
+                chat_history, waiting_for_tool_response, current_tool_call,
+                gen_length, steps, visualization_delay, temperature,
+                cfg_scale, block_length, remasking_strategy
+            ],
+            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
+                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
+        )
+
+        send_click.then(
+            fn=bot_response,
+            inputs=[
+                chat_history, waiting_for_tool_response, current_tool_call,
+                gen_length, steps, visualization_delay, temperature,
+                cfg_scale, block_length, remasking_strategy
+            ],
+            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
+                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
+        )
+
+        # Tool response submission
+        submit_tool_response.click(
+            fn=process_tool_response,
+            inputs=[
+                tool_response_input, chat_history, current_tool_call,
+                gen_length, steps, visualization_delay, temperature,
+                cfg_scale, block_length, remasking_strategy
+            ],
+            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
+                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
+        )
+
+    return demo
+
+# Launch the demo
 if __name__ == "__main__":
+    demo = create_chatbot_demo()
+    demo.queue().launch(share=True)
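
For reference, a typical round trip through the updated demo (illustrative only; the model's actual wording will vary) looks like:

    User:      "What is the weather like in NYC? Use Fahrenheit."
    Assistant: [{"name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit"}}]
    Tool:      {"location": "New York, NY", "temperature": 72, "unit": "fahrenheit", "condition": "Partly Cloudy"}
    Assistant: a natural-language summary of the returned weather data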