ProximileAdmin committed on
Commit d8fb53c · verified · 1 Parent(s): c51bc1a

Update app.py

Files changed (1)
  1. app.py +186 -776
app.py CHANGED
@@ -1,801 +1,211 @@
- import torch
- import json
  import gradio as gr
- import torch.nn.functional as F
- from transformers import AutoTokenizer, AutoModel
  import time
  import os
- import re
- import threading

- # Global variables for the loaded model and tokenizer
- model = None
- tokenizer = None
- MASK_TOKEN = "[MASK]"
- MASK_ID = 126336 # The token ID of [MASK] in LLaDA
- is_model_loaded = False

- # Tool class definitions
- class ToolBase:
-     def __init__(self,
-                  programmatic_name,
-                  natural_name,
-                  description,
-                  input_params,
-                  required_params=None,
-                  ):
-         self.json_name = programmatic_name
-         self.json_description = description
-         self.schema = {
-             "type": "function",
-             "function": {
-                 "name": self.json_name,
-                 "description": self.json_description,
-                 "parameters": {
-                     "type": "object",
-                     "properties": input_params,
-                     "required": required_params or []
-                 }
-             }
-         }
-
-     def actual_function(self, **kwargs):
-         raise NotImplementedError("Subclasses must implement this method.")
-
- class WeatherAPITool(ToolBase):
-     def __init__(self):
-         super().__init__(
-             programmatic_name="get_weather",
-             natural_name="Weather Report Fetcher",
-             description="Get the current weather in a given location",
-             input_params={
-                 "location": {
-                     "type": "string",
-                     "description": "The city and state, e.g. San Francisco, CA"
-                 },
-                 "unit": {
-                     "type": "string",
-                     "enum": ["celsius", "fahrenheit"],
-                     "description": "The unit of temperature"
-                 }
-             },
-             required_params=["location", "unit"],
-         )

-     def actual_function(self, **kwargs):
-         # This would normally call an API, but we'll return dummy data
-         return {
-             "location": kwargs["location"],
-             "temperature": 72 if kwargs["unit"] == "fahrenheit" else 22,
-             "unit": kwargs["unit"],
-             "condition": "Partly Cloudy",
-             "humidity": 65,
-             "wind_speed": 8,
-             "wind_direction": "NE"
-         }

- # Create the tool
- weather_tool = WeatherAPITool()

- # Custom CSS
- css = '''
- .category-legend{display:none}
- button{height: 60px}
- .visualization-container {
-     margin-top: 20px;
-     padding: 10px;
-     background-color: #f8f9fa;
-     border-radius: 8px;
- }
- .loading-container {
-     text-align: center;
-     padding: 40px;
-     max-width: 800px;
-     margin: 0 auto;
- }
- .loading-text {
-     font-size: 18px;
-     margin: 20px 0;
- }
- .loading-spinner {
-     display: inline-block;
-     width: 50px;
-     height: 50px;
-     border: 5px solid rgba(0,0,0,.1);
-     border-radius: 50%;
-     border-top-color: #3498db;
-     animation: spin 1s ease-in-out infinite;
- }
- @keyframes spin {
-     to { transform: rotate(360deg); }
- }
- '''
-
- # Diffusion model generation functions
- def add_gumbel_noise(logits, temperature):
-     '''
-     The Gumbel max is a method for sampling categorical distributions.
-     For MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
-     '''
-     if temperature <= 0:
-         return logits
-
-     logits = logits.to(torch.float64)
-     noise = torch.rand_like(logits, dtype=torch.float64)
-     gumbel_noise = (- torch.log(noise)) ** temperature
-     return logits.exp() / gumbel_noise
-
- def get_num_transfer_tokens(mask_index, steps):
-     '''
-     In the reverse process, we precompute the number of tokens to transition at each step.
-     '''
-     mask_num = mask_index.sum(dim=1, keepdim=True)

-     # Ensure we have at least one step
-     if steps == 0:
-         steps = 1
-
-     base = mask_num // steps
-     remainder = mask_num % steps

-     num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

-     for i in range(mask_num.size(0)):
-         if remainder[i] > 0:
-             num_transfer_tokens[i, :remainder[i]] += 1
-
-     return num_transfer_tokens
-
- def generate_response_with_visualization(messages, gen_length=128, steps=128,
-                                          temperature=0.1, cfg_scale=0.0, block_length=32,
-                                          remasking='low_confidence'):
-     """
-     Generate text with LLaDA model with visualization
-     """
-     global model, tokenizer, MASK_TOKEN, MASK_ID

-     # First make sure model is loaded
-     if model is None or tokenizer is None:
-         return [("Model not loaded yet. Please wait...", "red")], "Model not loaded yet. Please wait..."

-     # Get device
-     device = next(model.parameters()).device

-     # Prepare the prompt using chat template
-     chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-     input_ids = tokenizer(chat_input)['input_ids']
-     input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

-     # For generation
-     prompt_length = input_ids.shape[1]

-     # Initialize the sequence with masks for the response part
-     x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
-     x[:, :prompt_length] = input_ids.clone()

-     # Initialize visualization states for the response part
-     visualization_states = []

-     # Add initial state (all masked)
-     initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
-     visualization_states.append(initial_state)

-     # Mark prompt positions to exclude them from masking during classifier-free guidance
-     prompt_index = (x != MASK_ID)

-     # Ensure block_length is valid
-     if block_length > gen_length:
-         block_length = gen_length

-     # Calculate number of blocks
-     num_blocks = gen_length // block_length
-     if gen_length % block_length != 0:
-         num_blocks += 1

-     # Adjust steps per block
-     steps_per_block = steps // num_blocks
-     if steps_per_block < 1:
-         steps_per_block = 1

-     # Process each block
-     for num_block in range(num_blocks):
-         # Calculate the start and end indices for the current block
-         block_start = prompt_length + num_block * block_length
-         block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
-
-         # Get mask indices for the current block
-         block_mask_index = (x[:, block_start:block_end] == MASK_ID)
-
-         # Skip if no masks in this block
-         if not block_mask_index.any():
-             continue
-
-         # Calculate number of tokens to unmask at each step
-         num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
-
-         # Process each step
-         for i in range(steps_per_block):
-             # Get all mask positions in the current sequence
-             mask_index = (x == MASK_ID)
-
-             # Skip if no masks
-             if not mask_index.any():
-                 break
-
-             # Apply classifier-free guidance if enabled
-             if cfg_scale > 0.0:
-                 un_x = x.clone()
-                 un_x[prompt_index] = MASK_ID
-                 x_ = torch.cat([x, un_x], dim=0)
-                 logits = model(x_).logits
-                 logits, un_logits = torch.chunk(logits, 2, dim=0)
-                 logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
-             else:
-                 logits = model(x).logits
-
-             # Apply Gumbel noise for sampling
-             logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
-             x0 = torch.argmax(logits_with_noise, dim=-1)
-
-             # Calculate confidence scores for remasking
-             if remasking == 'low_confidence':
-                 p = F.softmax(logits.to(torch.float64), dim=-1)
-                 x0_p = torch.squeeze(
-                     torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
-             elif remasking == 'random':
-                 x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
-             else:
-                 raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
-
-             # Don't consider positions beyond the current block
-             x0_p[:, block_end:] = -float('inf')
-
-             # Apply predictions where we have masks
-             old_x = x.clone()
-             x0 = torch.where(mask_index, x0, x)
-             confidence = torch.where(mask_index, x0_p, -float('inf'))
-
-             # Select tokens to unmask based on confidence
-             transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
-             for j in range(confidence.shape[0]):
-                 # Only consider positions within the current block for unmasking
-                 block_confidence = confidence[j, block_start:block_end]
-                 if i < steps_per_block - 1: # Not the last step
-                     # Take top-k confidences
-                     _, select_indices = torch.topk(block_confidence,
-                                                    k=min(num_transfer_tokens[j, i].item(),
-                                                          block_confidence.numel()))
-                     # Adjust indices to global positions
-                     select_indices = select_indices + block_start
-                     transfer_index[j, select_indices] = True
-                 else: # Last step - unmask everything remaining
-                     transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]
-
-             # Apply the selected tokens
-             x = torch.where(transfer_index, x0, x)
-
-             # Create visualization state only for the response part
-             current_state = []
-             for i in range(gen_length):
-                 pos = prompt_length + i # Absolute position in the sequence
-
-                 if x[0, pos] == MASK_ID:
-                     # Still masked
-                     current_state.append((MASK_TOKEN, "#444444")) # Dark gray for masks
-
-                 elif old_x[0, pos] == MASK_ID:
-                     # Newly revealed in this step
-                     token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
-                     # Color based on confidence
-                     confidence = float(x0_p[0, pos].cpu())
-                     if confidence < 0.3:
-                         color = "#FF6666" # Light red
-                     elif confidence < 0.7:
-                         color = "#FFAA33" # Orange
-                     else:
-                         color = "#66CC66" # Light green
-
-                     current_state.append((token, color))
-
-                 else:
-                     # Previously revealed
-                     token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
-                     current_state.append((token, "#6699CC")) # Light blue
-
-             visualization_states.append(current_state)
-
-     # Extract final text (just the assistant's response)
-     response_tokens = x[0, prompt_length:]
-     final_text = tokenizer.decode(response_tokens,
-                                   skip_special_tokens=False,
-                                   clean_up_tokenization_spaces=True).split("<|")[0]
-
-     return visualization_states, final_text
-
- # Tool handling functions
- def is_tool_call(text):
-     """Check if the text looks like a JSON tool call"""
-     # Remove any whitespace at beginning and end
-     text = text.strip()
-     # Check if it starts with [ or { (common JSON indicators)
-     if (text.startswith('[') and text.endswith(']')) or (text.startswith('{') and text.endswith('}')):
-         try:
-             # Try to parse as JSON
-             data = json.loads(text)
-             # Check if it contains a tool call structure
-             if isinstance(data, list):
-                 for item in data:
-                     if isinstance(item, dict) and "name" in item and "parameters" in item:
-                         return True
-             elif isinstance(data, dict) and "name" in data and "parameters" in data:
-                 return True
-         except:
-             pass
-     return False
-
- def extract_tool_call(text):
-     """Extract tool call data from text"""
-     try:
-         data = json.loads(text)
-         if isinstance(data, list) and len(data) > 0:
-             # Return the first valid tool call
-             for item in data:
-                 if isinstance(item, dict) and "name" in item and "parameters" in item:
-                     return item
-         elif isinstance(data, dict) and "name" in data and "parameters" in data:
-             return data
-     except:
-         pass
-     return None
-
- def handle_tool_call(tool_call):
-     """Process a tool call and return the result"""
-     if tool_call["name"] == weather_tool.json_name:
-         return weather_tool.actual_function(**tool_call["parameters"])
-     return {"error": f"Tool {tool_call['name']} not found"}
-
- def load_model():
-     """Load the model and tokenizer"""
-     global model, tokenizer, MASK_TOKEN, MASK_ID, is_model_loaded
-
-     try:
-         # Device setup
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         print(f"Using device: {device}")
-
-         # Load model and tokenizer
-         tokenizer = AutoTokenizer.from_pretrained("Proximile/LLaDA-8B-Tools", trust_remote_code=True)
-         model = AutoModel.from_pretrained(
-             "Proximile/LLaDA-8B-Tools",
-             trust_remote_code=True,
-             torch_dtype=torch.bfloat16,
-             load_in_4bit=True
-         )
-         model.eval()
-
-         # Set constants
-         MASK_TOKEN = "[MASK]"
-         MASK_ID = 126336 # The token ID of [MASK] in LLaDA
-
-         is_model_loaded = True
-         return True
-     except Exception as e:
-         print(f"Error loading model: {str(e)}")
-         return False
-
- def loading_app():
-     with gr.Blocks(css=css) as loading_demo:
-         with gr.Column(elem_classes="loading-container"):
-             gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
-             gr.Markdown("### Loading Model and Tokenizer...")
-             gr.HTML('<div class="loading-spinner"></div>')
-             gr.Markdown("Please wait while the model is loading. This may take several minutes.")
-
-     return loading_demo
-
- def chat_app():
-     with gr.Blocks(css=css) as chat_demo:
-         gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
-         gr.Markdown("This demo showcases the LLaDA diffusion model with the [Proximile/LLaDA-8B-Tools-LoRA](https://huggingface.co/Proximile/LLaDA-8B-Tools-LoRA) adapter for enhanced tool calling capabilities.")
-
-         # STATE MANAGEMENT
-         chat_history = gr.State([])
-         waiting_for_tool_response = gr.State(False)
-         current_tool_call = gr.State(None)
-
-         # UI COMPONENTS
-         with gr.Row():
-             with gr.Column(scale=3):
-                 chatbot_ui = gr.Chatbot(
-                     label="Conversation",
-                     height=500,
-                     type="messages" # Fix the warning about type parameter
-                 )
-
-                 # Message input
-                 with gr.Group():
-                     with gr.Row():
-                         user_input = gr.Textbox(
-                             label="Your Message",
-                             placeholder="Type your message here...",
-                             show_label=False
-                         )
-                         send_btn = gr.Button("Send")
-
-                 # Tool response input (initially hidden)
-                 with gr.Group(visible=False) as tool_response_group:
-                     gr.Markdown("## Tool Call Detected")
-                     tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
-                     tool_params_display = gr.JSON(label="Parameters")
-                     tool_response_input = gr.Textbox(
-                         label="Tool Response (JSON)",
-                         placeholder="Enter JSON response for the tool...",
-                         lines=5
-                     )
-                     submit_tool_response = gr.Button("Submit Tool Response")
-
-                     # Add a button for auto-filling dummy response
-                     dummy_response_btn = gr.Button("Use Dummy Response")
-
-             with gr.Column(scale=2):
-                 gr.Markdown("## Diffusion Process Visualization")
-                 gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
-                 output_vis = gr.HighlightedText(
-                     label="Token Denoising",
-                     combine_adjacent=False,
-                     show_legend=True,
-                     elem_classes="visualization-container"
-                 )
-                 gr.Markdown("**Color Legend:**")
-                 gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
-                 gr.Markdown("- **Light Red**: Newly revealed with low confidence")
-                 gr.Markdown("- **Orange**: Newly revealed with medium confidence")
-                 gr.Markdown("- **Light Green**: Newly revealed with high confidence")
-                 gr.Markdown("- **Light Blue**: Previously revealed tokens")
-
-         # Advanced generation settings
-         with gr.Accordion("Generation Settings", open=False):
-             with gr.Row():
-                 gen_length = gr.Slider(
-                     minimum=8, maximum=128, value=64, step=4,
-                     label="Generation Length"
-                 )
-                 steps = gr.Slider(
-                     minimum=8, maximum=128, value=64, step=4,
-                     label="Denoising Steps"
-                 )
-             with gr.Row():
-                 temperature = gr.Slider(
-                     minimum=0.0, maximum=1.0, value=0.1, step=0.1,
-                     label="Temperature"
-                 )
-                 cfg_scale = gr.Slider(
-                     minimum=0.0, maximum=2.0, value=0.0, step=0.1,
-                     label="CFG Scale"
-                 )
-             with gr.Row():
-                 block_length = gr.Slider(
-                     minimum=8, maximum=128, value=32, step=8,
-                     label="Block Length"
-                 )
-                 remasking_strategy = gr.Radio(
-                     choices=["low_confidence", "random"],
-                     value="low_confidence",
-                     label="Remasking Strategy"
-                 )
-             with gr.Row():
-                 visualization_delay = gr.Slider(
-                     minimum=0.0, maximum=1.0, value=0.1, step=0.1,
-                     label="Visualization Delay (seconds)"
-                 )
-
-         # Current response text box (hidden)
-         current_response = gr.Textbox(
-             label="Current Response",
-             placeholder="The assistant's response will appear here...",
-             lines=3,
-             visible=False
-         )
-
-         # Clear button
-         clear_btn = gr.Button("Clear Conversation")
-
-         gr.Markdown("### Try asking about the weather to trigger a tool call!")
-         gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")
-
-         # System prompt for the model
-         system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
-
- If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answers the given prompt.
-
- Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.
-
- Here are the tool functions available to you:
-
- {json.dumps([weather_tool.schema], indent=4)}
-
- After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.
-
- If the user request does not necessitate a function call, simply respond to the user's query directly."""
-
-         # HELPER FUNCTIONS
-         def add_message(history, message, response):
-             """Add a message pair to the history and return the updated history"""
-             history = history.copy()
-             history.append([message, response])
-             return history
-
-         def user_message_submitted(message, history, waiting_for_tool):
-             """Process a submitted user message"""
-             # Skip empty messages or if waiting for a tool response
-             if not message.strip() or waiting_for_tool:
-                 # Return current state unchanged
-                 history_for_display = history.copy()
-                 return history, history_for_display, "", [], ""
-
-             # Add user message to history
-             history = add_message(history, message, None)
-
-             # Format for display - temporarily show user message with empty response
-             history_for_display = history.copy()
-
-             # Clear the input
-             message_out = ""
-
-             # Return immediately to update UI with user message
-             return history, history_for_display, message_out, [], ""
-
-         def bot_response(history, waiting_for_tool, current_tool,
-                          gen_length, steps, delay, temperature,
-                          cfg_scale, block_length, remasking):
-             """Generate bot response for the latest message"""
-             if not history or waiting_for_tool:
-                 return history, [], "", waiting_for_tool, current_tool, gr.update(visible=False), gr.update(), gr.update()
-
-             # Get the last user message
-             last_user_message = history[-1][0]
-
-             try:
-                 # Format the conversation for the model
-                 messages = []
-
-                 # Add system message first
-                 messages.append({"role": "system", "content": system_prompt})
-
-                 # Add conversation history
-                 for h in history[:-1]:
-                     messages.append({"role": "user", "content": h[0]})
-                     if h[1]: # Only include assistant responses that exist
-                         messages.append({"role": "assistant", "content": h[1]})
-
-                 # Add the last user message
-                 messages.append({"role": "user", "content": last_user_message})
-
-                 # Generate response with visualization
-                 vis_states, response_text = generate_response_with_visualization(
-                     messages,
-                     gen_length=gen_length,
-                     steps=steps,
-                     temperature=temperature,
-                     cfg_scale=cfg_scale,
-                     block_length=block_length,
-                     remasking=remasking
-                 )
-
-                 # Update history with the assistant's response
-                 history[-1][1] = response_text
-
-                 # Check if the response is a tool call
-                 is_tool = is_tool_call(response_text)
-
-                 if is_tool:
-                     # Extract tool call information
-                     tool_call = extract_tool_call(response_text)
-
-                     # Return the initial state immediately
-                     yield (history, vis_states[0], response_text,
-                            True, tool_call,
-                            gr.update(visible=True),
-                            gr.update(value=tool_call["name"]),
-                            gr.update(value=tool_call["parameters"]))
-
-                     # Then animate through visualization states
-                     for state in vis_states[1:]:
-                         time.sleep(delay)
-                         yield (history, state, response_text,
-                                True, tool_call,
-                                gr.update(visible=True),
-                                gr.update(value=tool_call["name"]),
-                                gr.update(value=tool_call["parameters"]))
-                 else:
-                     # Return the initial state immediately
-                     yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-                     # Then animate through visualization states
-                     for state in vis_states[1:]:
-                         time.sleep(delay)
-                         yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-             except Exception as e:
-                 error_msg = f"Error: {str(e)}"
-                 print(error_msg)
-
-                 # Show error in visualization
-                 error_vis = [(error_msg, "red")]
-
-                 # Don't update history with error
-                 yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-         def process_tool_response(tool_response, history, current_tool,
-                                   gen_length, steps, delay, temperature,
-                                   cfg_scale, block_length, remasking):
-             """Process tool response and generate a follow-up response"""
-             if not history or not current_tool:
-                 return history, [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
-
-             try:
-                 # Parse the tool response
-                 response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response
-
-                 # Format the conversation for the model
-                 messages = []
-
-                 # Add system message first
-                 messages.append({"role": "system", "content": system_prompt})
-
-                 # Add conversation history
-                 for h in history:
-                     messages.append({"role": "user", "content": h[0]})
-                     if h[1]: # Only include assistant responses that exist
-                         messages.append({"role": "assistant", "content": h[1]})
-
-                 # Add the tool response
-                 messages.append({"role": "ipython", "content": json.dumps({
-                     "name": current_tool["name"],
-                     "return": response_data
-                 })})
-
-                 # Generate response with visualization
-                 vis_states, response_text = generate_response_with_visualization(
-                     messages,
-                     gen_length=gen_length,
-                     steps=steps,
-                     temperature=temperature,
-                     cfg_scale=cfg_scale,
-                     block_length=block_length,
-                     remasking=remasking
-                 )
-
-                 # Add a new message pair for the tool-processed response
-                 history = add_message(history, "Tool response processed", response_text)
-
-                 # Return the initial state immediately
-                 yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-                 # Then animate through visualization states
-                 for state in vis_states[1:]:
-                     time.sleep(delay)
-                     yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-             except Exception as e:
-                 error_msg = f"Error processing tool response: {str(e)}"
-                 print(error_msg)
-
-                 # Show error in visualization
-                 error_vis = [(error_msg, "red")]
-
-                 # Don't update history with error
-                 yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
-
-         def generate_dummy_response(current_tool):
-             """Generate a dummy response for a tool call"""
-             if not current_tool:
-                 return ""
-
-             # Process based on tool name
-             if current_tool["name"] == weather_tool.json_name:
-                 location = current_tool["parameters"].get("location", "Unknown")
-                 unit = current_tool["parameters"].get("unit", "celsius")
-
-                 dummy_data = {
-                     "location": location,
-                     "temperature": 72 if unit == "fahrenheit" else 22,
-                     "unit": unit,
-                     "condition": "Partly Cloudy",
-                     "humidity": 65,
-                     "wind_speed": 8,
-                     "wind_direction": "NE"
-                 }
-
-                 return json.dumps(dummy_data, indent=2)
-
-             return "{}"
-
-         def clear_conversation():
-             """Clear the conversation history"""
-             return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
-
-         # Connect event handlers for chat interface
-         clear_btn.click(
-             fn=clear_conversation,
-             inputs=[],
-             outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
-                      current_tool_call, tool_response_group, tool_name_display, tool_params_display]
-         )
-
-         dummy_response_btn.click(
-             fn=generate_dummy_response,
-             inputs=[current_tool_call],
-             outputs=[tool_response_input]
-         )
-
-         msg_submit = user_input.submit(
-             fn=user_message_submitted,
-             inputs=[user_input, chat_history, waiting_for_tool_response],
-             outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
          )
-
-         send_click = send_btn.click(
-             fn=user_message_submitted,
-             inputs=[user_input, chat_history, waiting_for_tool_response],
-             outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
          )
-
-         msg_submit.then(
-             fn=bot_response,
-             inputs=[
-                 chat_history, waiting_for_tool_response, current_tool_call,
-                 gen_length, steps, visualization_delay, temperature,
-                 cfg_scale, block_length, remasking_strategy
-             ],
-             outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
-                      current_tool_call, tool_response_group, tool_name_display, tool_params_display]
-         )
-
-         send_click.then(
-             fn=bot_response,
-             inputs=[
-                 chat_history, waiting_for_tool_response, current_tool_call,
-                 gen_length, steps, visualization_delay, temperature,
-                 cfg_scale, block_length, remasking_strategy
-             ],
-             outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
-                      current_tool_call, tool_response_group, tool_name_display, tool_params_display]
-         )
-
-         submit_tool_response.click(
-             fn=process_tool_response,
-             inputs=[
-                 tool_response_input, chat_history, current_tool_call,
-                 gen_length, steps, visualization_delay, temperature,
-                 cfg_scale, block_length, remasking_strategy
-             ],
-             outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
-                      current_tool_call, tool_response_group, tool_name_display, tool_params_display]
-         )
-
-     return chat_demo
-
- # Start loading the model in a separate thread immediately
- loading_thread = threading.Thread(target=load_model)
- loading_thread.daemon = True
- loading_thread.start()
-
- # Choose which app to launch based on model loading state
- if is_model_loaded:
-     # Model is already loaded (unlikely but possible)
-     demo = chat_app()
-     demo.queue().launch(share=True)
- else:
-     # First show the loading app
-     loading = loading_app()
-     loading.queue().launch(share=True)
-
-     # Wait for model to load
-     while not is_model_loaded:
-         time.sleep(1)
-
-     # Close the loading app and show the chat app
-     loading.close()
-     demo = chat_app()
-     demo.queue().launch(share=True)

  import gradio as gr
+ from gradio_client import Client
+ import json
  import time
  import os

+ # Set the API endpoint (can be customized)
+ API_ENDPOINT = os.environ.get("LLADA_API_ENDPOINT", "http://127.0.0.1:7880")

+ # Create the client
+ client = Client(API_ENDPOINT)

+ # Default generation parameters
+ DEFAULT_GEN_LENGTH = 64
+ DEFAULT_STEPS = 64
+ DEFAULT_DELAY = 0.1
+ DEFAULT_TEMPERATURE = 0.1
+ DEFAULT_CFG_SCALE = 0
+ DEFAULT_BLOCK_LENGTH = 32
+ DEFAULT_REMASKING = "low_confidence"

+ def clear_conversation():
+     """Clear the conversation history"""
+     result = client.predict(api_name="/clear_conversation")
+     return [], "", None

+ def send_user_message(message, history, waiting_for_tool, current_tool_call):
+     """Send a user message to the API"""
+     if not message.strip() or waiting_for_tool:
+         return history, "", waiting_for_tool, current_tool_call

+     # Add user message to history for display
+     history = history + [[message, None]]

+     # Send message to API
+     result = client.predict(message, api_name="/user_message_submitted")

+     # Process result - get bot response
+     bot_result = client.predict(
+         DEFAULT_GEN_LENGTH, DEFAULT_STEPS, DEFAULT_DELAY,
+         DEFAULT_TEMPERATURE, DEFAULT_CFG_SCALE, DEFAULT_BLOCK_LENGTH,
+         DEFAULT_REMASKING, api_name="/bot_response"
+     )

+     # Extract the response text and update history
+     response_text = bot_result[2]
+     history[-1][1] = response_text

+     # Check if the response is a tool call
+     is_tool_call = False
+     tool_call = None

+     try:
+         # Check if response is a valid JSON tool call
+         if (response_text.strip().startswith('{') and response_text.strip().endswith('}')) or \
+            (response_text.strip().startswith('[') and response_text.strip().endswith(']')):
+             json_data = json.loads(response_text)
+
+             # Check if it has the tool call structure
+             if isinstance(json_data, list):
+                 for item in json_data:
+                     if isinstance(item, dict) and "name" in item and "parameters" in item:
+                         is_tool_call = True
+                         tool_call = item
+                         break
+             elif isinstance(json_data, dict) and "name" in json_data and "parameters" in json_data:
+                 is_tool_call = True
+                 tool_call = json_data
+     except:
+         # Not a valid JSON or tool call
+         pass

+     return history, "", is_tool_call, tool_call
+
+ def generate_dummy_response(tool_call):
+     """Generate a dummy response for a tool call"""
+     if not tool_call:
+         return ""

+     result = client.predict(api_name="/generate_dummy_response")
+     return result
+
+ def submit_tool_response(tool_response, history, tool_call):
+     """Submit a tool response to the API"""
+     if not tool_response.strip() or not tool_call:
+         return history, False, None

+     # Process the tool response
+     result = client.predict(
+         tool_response, DEFAULT_GEN_LENGTH, DEFAULT_STEPS, DEFAULT_DELAY,
+         DEFAULT_TEMPERATURE, DEFAULT_CFG_SCALE, DEFAULT_BLOCK_LENGTH,
+         DEFAULT_REMASKING, api_name="/process_tool_response"
+     )

+     # Extract the follow-up response
+     response_text = result[2]

+     # Add the response to history
+     history = history + [["Tool response processed", response_text]]

+     return history, False, None
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# LLaDA Chat Client")
+     gr.Markdown("This is a client application for the LLaDA Diffusion Model with Tool Calls.")

+     # State variables
+     chat_history = gr.State([])
+     waiting_for_tool = gr.State(False)
+     current_tool_call = gr.State(None)

+     # Chat interface
+     chatbot = gr.Chatbot(label="Conversation")

+     # Main chat input
+     with gr.Row():
+         msg = gr.Textbox(
+             show_label=False,
+             placeholder="Enter your message here...",
+             scale=7
          )
+         submit_btn = gr.Button("Send", scale=1)
+
+     # Add clickable examples
+     gr.Examples(
+         examples=["What is the weather like in NYC? Use Fahrenheit."],
+         inputs=msg
+     )
+
+     # Tool response section (initially hidden)
+     with gr.Group(visible=False) as tool_group:
+         gr.Markdown("## Tool Call Detected")
+         tool_name = gr.Textbox(label="Tool Name", interactive=False)
+         tool_params = gr.JSON(label="Parameters")
+         tool_response = gr.Textbox(
+             label="Tool Response (JSON)",
+             placeholder="Enter JSON response for the tool...",
+             lines=5
          )
+         with gr.Row():
+             dummy_btn = gr.Button("Use Dummy Response")
+             submit_tool_btn = gr.Button("Submit Tool Response")
+
+     # Clear conversation button
+     clear_btn = gr.Button("Clear Conversation")
+
+     # Update UI based on tool call detection
+     def update_tool_ui(is_tool_call, tool_call):
+         if is_tool_call and tool_call:
+             return (
+                 gr.update(visible=True),
+                 gr.update(value=tool_call.get("name", "")),
+                 gr.update(value=tool_call.get("parameters", {}))
+             )
+         else:
+             return (
+                 gr.update(visible=False),
+                 gr.update(value=""),
+                 gr.update(value={})
+             )
+
+     # Connect components
+     submit_btn.click(
+         fn=send_user_message,
+         inputs=[msg, chat_history, waiting_for_tool, current_tool_call],
+         outputs=[chatbot, msg, waiting_for_tool, current_tool_call]
+     ).then(
+         fn=update_tool_ui,
+         inputs=[waiting_for_tool, current_tool_call],
+         outputs=[tool_group, tool_name, tool_params]
+     )
+
+     msg.submit(
+         fn=send_user_message,
+         inputs=[msg, chat_history, waiting_for_tool, current_tool_call],
+         outputs=[chatbot, msg, waiting_for_tool, current_tool_call]
+     ).then(
+         fn=update_tool_ui,
+         inputs=[waiting_for_tool, current_tool_call],
+         outputs=[tool_group, tool_name, tool_params]
+     )
+
+     clear_btn.click(
+         fn=clear_conversation,
+         inputs=[],
+         outputs=[chatbot, msg, current_tool_call]
+     ).then(
+         fn=lambda: gr.update(visible=False),
+         inputs=[],
+         outputs=[tool_group]
+     )
+
+     dummy_btn.click(
+         fn=generate_dummy_response,
+         inputs=[current_tool_call],
+         outputs=[tool_response]
+     )
+
+     submit_tool_btn.click(
+         fn=submit_tool_response,
+         inputs=[tool_response, chat_history, current_tool_call],
+         outputs=[chatbot, waiting_for_tool, current_tool_call]
+     ).then(
+         fn=update_tool_ui,
+         inputs=[waiting_for_tool, current_tool_call],
+         outputs=[tool_group, tool_name, tool_params]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
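
After this commit, app.py no longer loads the model itself; it is a thin gradio_client wrapper around a separately hosted LLaDA backend. A minimal sketch of driving that backend directly, assuming a server is reachable at LLADA_API_ENDPOINT and exposes the /user_message_submitted and /bot_response routes with the positional signature the new app.py uses (the backend itself is not part of this commit):

# Hedged sketch: endpoint URL, route names, and argument order are taken
# from the diff above; the backend service is assumed, not provided here.
import os
from gradio_client import Client

client = Client(os.environ.get("LLADA_API_ENDPOINT", "http://127.0.0.1:7880"))

# Register the user turn, then request a generation with the client file's
# defaults: gen_length, steps, delay, temperature, cfg_scale, block_length,
# remasking strategy.
client.predict("What is the weather like in NYC? Use Fahrenheit.",
               api_name="/user_message_submitted")
result = client.predict(64, 64, 0.1, 0.1, 0, 32, "low_confidence",
                        api_name="/bot_response")
print(result[2])  # index 2 holds the response text, matching send_user_message

If the response text parses as a JSON tool call, the follow-up would go through /process_tool_response with the tool output JSON as the first argument, mirroring submit_tool_response above.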