import torch
import json
import gradio as gr
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import time
# Device setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Load base model and tokenizer (4-bit loading requires the bitsandbytes package)
tokenizer = AutoTokenizer.from_pretrained("Proximile/LLaDA-8B-Tools", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "Proximile/LLaDA-8B-Tools",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
model.eval()
# Constants
MASK_TOKEN = "[MASK]"  # Display placeholder shown for still-masked positions in the visualization
MASK_ID = 126336  # Token ID of the mask token in LLaDA's vocabulary
# Tool class definitions
class ToolBase:
def __init__(self,
programmatic_name,
natural_name,
description,
input_params,
required_params=None,
):
        self.json_name = programmatic_name
        self.natural_name = natural_name  # Human-readable name (not part of the JSON schema)
        self.json_description = description
self.schema = {
"type": "function",
"function": {
"name": self.json_name,
"description": self.json_description,
"parameters": {
"type": "object",
"properties": input_params,
"required": required_params or []
}
}
}
def actual_function(self, **kwargs):
raise NotImplementedError("Subclasses must implement this method.")
class WeatherAPITool(ToolBase):
def __init__(self):
super().__init__(
programmatic_name="get_weather",
natural_name="Weather Report Fetcher",
description="Get the current weather in a given location",
input_params={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature"
}
},
required_params=["location", "unit"],
)
def actual_function(self, **kwargs):
# This would normally call an API, but we'll return dummy data
return {
"location": kwargs["location"],
"temperature": 72 if kwargs["unit"] == "fahrenheit" else 22,
"unit": kwargs["unit"],
"condition": "Partly Cloudy",
"humidity": 65,
"wind_speed": 8,
"wind_direction": "NE"
}
# Create the tool
weather_tool = WeatherAPITool()
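# Illustrative example (hypothetical values): given the schema above, the model
# is expected to emit a tool call such as
#   [{"name": "get_weather", "parameters": {"location": "Tokyo, JP", "unit": "celsius"}}]
# which is parsed by extract_tool_call() and dispatched by handle_tool_call() below.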
# Diffusion model generation functions
def add_gumbel_noise(logits, temperature):
    """
    Gumbel-max is a trick for sampling from categorical distributions.
    For masked diffusion models (MDM), low-precision Gumbel-max improves the
    perplexity score but reduces generation quality, hence the float64 below.
    """
    if temperature <= 0:
        return logits  # Zero temperature falls back to greedy (argmax) decoding
logits = logits.to(torch.float64)
noise = torch.rand_like(logits, dtype=torch.float64)
gumbel_noise = (- torch.log(noise)) ** temperature
return logits.exp() / gumbel_noise
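# Note: log(logits.exp() / (-log(u)) ** T) = logits + T * (-log(-log(u))), and
# -log(-log(u)) with u ~ Uniform(0, 1) is a standard Gumbel sample, so taking the
# argmax of the value returned above is temperature-scaled Gumbel-max sampling.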
def get_num_transfer_tokens(mask_index, steps):
    """
    In the reverse process, precompute how many masked tokens to unmask at each
    step, spreading any remainder over the earliest steps.
    """
mask_num = mask_index.sum(dim=1, keepdim=True)
# Ensure we have at least one step
if steps == 0:
steps = 1
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
if remainder[i] > 0:
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
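# Worked example: with 10 masked tokens and steps=4, base = 2 and remainder = 2,
# so the per-step unmasking schedule is [3, 3, 2, 2] (the remainder goes to the
# earliest steps).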
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=128, steps=128,
temperature=0.1, cfg_scale=0.0, block_length=32,
remasking='low_confidence'):
"""
Generate text with LLaDA model with visualization
"""
# Prepare the prompt using chat template
chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(chat_input)['input_ids']
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
    # Length of the prompt; everything after it is generated
    prompt_length = input_ids.shape[1]
# Initialize the sequence with masks for the response part
x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
x[:, :prompt_length] = input_ids.clone()
# Initialize visualization states for the response part
visualization_states = []
# Add initial state (all masked)
initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
visualization_states.append(initial_state)
    # Track prompt positions; they get re-masked for the unconditional branch of classifier-free guidance
    prompt_index = (x != MASK_ID)
# Ensure block_length is valid
if block_length > gen_length:
block_length = gen_length
# Calculate number of blocks
num_blocks = gen_length // block_length
if gen_length % block_length != 0:
num_blocks += 1
# Adjust steps per block
steps_per_block = steps // num_blocks
if steps_per_block < 1:
steps_per_block = 1
# Process each block
for num_block in range(num_blocks):
# Calculate the start and end indices for the current block
block_start = prompt_length + num_block * block_length
block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
# Get mask indices for the current block
block_mask_index = (x[:, block_start:block_end] == MASK_ID)
# Skip if no masks in this block
if not block_mask_index.any():
continue
# Calculate number of tokens to unmask at each step
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
# Process each step
for i in range(steps_per_block):
# Get all mask positions in the current sequence
mask_index = (x == MASK_ID)
# Skip if no masks
if not mask_index.any():
break
# Apply classifier-free guidance if enabled
if cfg_scale > 0.0:
un_x = x.clone()
un_x[prompt_index] = MASK_ID
x_ = torch.cat([x, un_x], dim=0)
logits = model(x_).logits
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
else:
logits = model(x).logits
# Apply Gumbel noise for sampling
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1)
# Calculate confidence scores for remasking
if remasking == 'low_confidence':
p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # shape: (batch, seq_len)
elif remasking == 'random':
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
else:
raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
# Don't consider positions beyond the current block
x0_p[:, block_end:] = -float('inf')
# Apply predictions where we have masks
old_x = x.clone()
x0 = torch.where(mask_index, x0, x)
confidence = torch.where(mask_index, x0_p, -float('inf'))
# Select tokens to unmask based on confidence
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
for j in range(confidence.shape[0]):
# Only consider positions within the current block for unmasking
block_confidence = confidence[j, block_start:block_end]
if i < steps_per_block - 1: # Not the last step
# Take top-k confidences
_, select_indices = torch.topk(block_confidence,
k=min(num_transfer_tokens[j, i].item(),
block_confidence.numel()))
# Adjust indices to global positions
select_indices = select_indices + block_start
transfer_index[j, select_indices] = True
else: # Last step - unmask everything remaining
transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]
# Apply the selected tokens
x = torch.where(transfer_index, x0, x)
# Create visualization state only for the response part
            current_state = []
            for offset in range(gen_length):  # `offset` avoids shadowing the step counter `i`
                pos = prompt_length + offset  # Absolute position in the sequence
if x[0, pos] == MASK_ID:
# Still masked
current_state.append((MASK_TOKEN, "#444444")) # Dark gray for masks
elif old_x[0, pos] == MASK_ID:
# Newly revealed in this step
token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                    # Color based on confidence (local float; avoids clobbering the `confidence` tensor above)
                    token_confidence = float(x0_p[0, pos].cpu())
                    if token_confidence < 0.3:
                        color = "#FF6666"  # Light red
                    elif token_confidence < 0.7:
                        color = "#FFAA33"  # Orange
                    else:
                        color = "#66CC66"  # Light green
current_state.append((token, color))
else:
# Previously revealed
token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
current_state.append((token, "#6699CC")) # Light blue
visualization_states.append(current_state)
# Extract final text (just the assistant's response)
response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(response_tokens,
                                  skip_special_tokens=False,
                                  clean_up_tokenization_spaces=True).split("<|")[0]  # Trim at the first trailing special token (they all start with "<|")
return visualization_states, final_text
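# Minimal usage sketch (illustrative; assumes the model/tokenizer loaded above):
#   messages = [{"role": "user", "content": "Hello!"}]
#   states, text = generate_response_with_visualization(model, tokenizer, device,
#                                                       messages, gen_length=64, steps=64)
#   print(text)  # decoded response; `states` drives the Gradio animation below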
# Tool handling functions
def is_tool_call(text):
"""Check if the text looks like a JSON tool call"""
# Remove any whitespace at beginning and end
text = text.strip()
# Check if it starts with [ or { (common JSON indicators)
if (text.startswith('[') and text.endswith(']')) or (text.startswith('{') and text.endswith('}')):
try:
# Try to parse as JSON
data = json.loads(text)
# Check if it contains a tool call structure
if isinstance(data, list):
for item in data:
if isinstance(item, dict) and "name" in item and "parameters" in item:
return True
elif isinstance(data, dict) and "name" in data and "parameters" in data:
return True
        except json.JSONDecodeError:
            pass
return False
def extract_tool_call(text):
"""Extract tool call data from text"""
try:
data = json.loads(text)
if isinstance(data, list) and len(data) > 0:
# Return the first valid tool call
for item in data:
if isinstance(item, dict) and "name" in item and "parameters" in item:
return item
elif isinstance(data, dict) and "name" in data and "parameters" in data:
return data
    except json.JSONDecodeError:
        pass
return None
def handle_tool_call(tool_call):
"""Process a tool call and return the result"""
if tool_call["name"] == weather_tool.json_name:
return weather_tool.actual_function(**tool_call["parameters"])
return {"error": f"Tool {tool_call['name']} not found"}
# Custom CSS
css = '''
.category-legend{display:none}
button{height: 60px}
.visualization-container {
margin-top: 20px;
padding: 10px;
background-color: #f8f9fa;
border-radius: 8px;
}
'''
def create_chatbot_demo():
with gr.Blocks(css=css) as demo:
gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
gr.Markdown("This demo showcases the LLaDA diffusion model with the [Proximile/LLaDA-8B-Tools-LoRA](https://huggingface.co/Proximile/LLaDA-8B-Tools-LoRA) adapter for enhanced tool calling capabilities.")
# STATE MANAGEMENT
chat_history = gr.State([])
waiting_for_tool_response = gr.State(False)
current_tool_call = gr.State(None)
# UI COMPONENTS
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(label="Conversation", height=500)
# Message input
with gr.Group():
with gr.Row():
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
show_label=False
)
send_btn = gr.Button("Send")
# Tool response input (initially hidden)
with gr.Group(visible=False) as tool_response_group:
gr.Markdown("## Tool Call Detected")
tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
tool_params_display = gr.JSON(label="Parameters")
tool_response_input = gr.Textbox(
label="Tool Response (JSON)",
placeholder="Enter JSON response for the tool...",
lines=5
)
submit_tool_response = gr.Button("Submit Tool Response")
# Add a button for auto-filling dummy response
dummy_response_btn = gr.Button("Use Dummy Response")
with gr.Column(scale=2):
gr.Markdown("## Diffusion Process Visualization")
gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
output_vis = gr.HighlightedText(
label="Token Denoising",
combine_adjacent=False,
show_legend=True,
elem_classes="visualization-container"
)
gr.Markdown("**Color Legend:**")
gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
gr.Markdown("- **Light Red**: Newly revealed with low confidence")
gr.Markdown("- **Orange**: Newly revealed with medium confidence")
gr.Markdown("- **Light Green**: Newly revealed with high confidence")
gr.Markdown("- **Light Blue**: Previously revealed tokens")
# Advanced generation settings
with gr.Accordion("Generation Settings", open=False):
with gr.Row():
gen_length = gr.Slider(
minimum=8, maximum=128, value=64, step=4,
label="Generation Length"
)
steps = gr.Slider(
minimum=8, maximum=128, value=64, step=4,
label="Denoising Steps"
)
with gr.Row():
temperature = gr.Slider(
minimum=0.0, maximum=1.0, value=0.1, step=0.1,
label="Temperature"
)
cfg_scale = gr.Slider(
minimum=0.0, maximum=2.0, value=0.0, step=0.1,
label="CFG Scale"
)
with gr.Row():
block_length = gr.Slider(
minimum=8, maximum=128, value=32, step=8,
label="Block Length"
)
remasking_strategy = gr.Radio(
choices=["low_confidence", "random"],
value="low_confidence",
label="Remasking Strategy"
)
with gr.Row():
visualization_delay = gr.Slider(
minimum=0.0, maximum=1.0, value=0.1, step=0.1,
label="Visualization Delay (seconds)"
)
# Current response text box (hidden)
current_response = gr.Textbox(
label="Current Response",
placeholder="The assistant's response will appear here...",
lines=3,
visible=False
)
# Clear button
clear_btn = gr.Button("Clear Conversation")
gr.Markdown("### Try asking about the weather to trigger a tool call!")
gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")
# System prompt for the model
system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answers the given prompt.
Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.
Here are the tool functions available to you:
{json.dumps([weather_tool.schema], indent=4)}
After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.
If the user request does not necessitate a function call, simply respond to the user's query directly."""
# HELPER FUNCTIONS
def add_message(history, message, response):
"""Add a message pair to the history and return the updated history"""
history = history.copy()
history.append([message, response])
return history
def user_message_submitted(message, history, waiting_for_tool):
"""Process a submitted user message"""
# Skip empty messages or if waiting for a tool response
if not message.strip() or waiting_for_tool:
# Return current state unchanged
history_for_display = history.copy()
return history, history_for_display, "", [], ""
# Add user message to history
history = add_message(history, message, None)
# Format for display - temporarily show user message with empty response
history_for_display = history.copy()
# Clear the input
message_out = ""
# Return immediately to update UI with user message
return history, history_for_display, message_out, [], ""
def bot_response(history, waiting_for_tool, current_tool,
gen_length, steps, delay, temperature,
cfg_scale, block_length, remasking):
"""Generate bot response for the latest message"""
            if not history or waiting_for_tool:
                # This function is a generator, so the early exit must yield rather than return a value
                yield history, [], "", waiting_for_tool, current_tool, gr.update(visible=False), gr.update(), gr.update()
                return
# Get the last user message
last_user_message = history[-1][0]
try:
# Format the conversation for the model
messages = []
# Add system message first
messages.append({"role": "system", "content": system_prompt})
# Add conversation history
for h in history[:-1]:
messages.append({"role": "user", "content": h[0]})
if h[1]: # Only include assistant responses that exist
messages.append({"role": "assistant", "content": h[1]})
# Add the last user message
messages.append({"role": "user", "content": last_user_message})
# Generate response with visualization
vis_states, response_text = generate_response_with_visualization(
model, tokenizer, device,
messages,
gen_length=gen_length,
steps=steps,
temperature=temperature,
cfg_scale=cfg_scale,
block_length=block_length,
remasking=remasking
)
# Update history with the assistant's response
history[-1][1] = response_text
# Check if the response is a tool call
is_tool = is_tool_call(response_text)
if is_tool:
# Extract tool call information
tool_call = extract_tool_call(response_text)
# Return the initial state immediately
yield (history, vis_states[0], response_text,
True, tool_call,
gr.update(visible=True),
gr.update(value=tool_call["name"]),
gr.update(value=tool_call["parameters"]))
# Then animate through visualization states
for state in vis_states[1:]:
time.sleep(delay)
yield (history, state, response_text,
True, tool_call,
gr.update(visible=True),
gr.update(value=tool_call["name"]),
gr.update(value=tool_call["parameters"]))
else:
# Return the initial state immediately
yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
# Then animate through visualization states
for state in vis_states[1:]:
time.sleep(delay)
yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
except Exception as e:
error_msg = f"Error: {str(e)}"
print(error_msg)
# Show error in visualization
error_vis = [(error_msg, "red")]
# Don't update history with error
yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
def process_tool_response(tool_response, history, current_tool,
gen_length, steps, delay, temperature,
cfg_scale, block_length, remasking):
"""Process tool response and generate a follow-up response"""
            if not history or not current_tool:
                # Early exit must yield (this function is a generator)
                yield history, [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
                return
try:
# Parse the tool response
response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response
# Format the conversation for the model
messages = []
# Add system message first
messages.append({"role": "system", "content": system_prompt})
# Add conversation history
for h in history:
messages.append({"role": "user", "content": h[0]})
if h[1]: # Only include assistant responses that exist
messages.append({"role": "assistant", "content": h[1]})
# Add the tool response
messages.append({"role": "ipython", "content": json.dumps({
"name": current_tool["name"],
"return": response_data
})})
# Generate response with visualization
vis_states, response_text = generate_response_with_visualization(
model, tokenizer, device,
messages,
gen_length=gen_length,
steps=steps,
temperature=temperature,
cfg_scale=cfg_scale,
block_length=block_length,
remasking=remasking
)
# Add a new message pair for the tool-processed response
history = add_message(history, "Tool response processed", response_text)
# Return the initial state immediately
yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
# Then animate through visualization states
for state in vis_states[1:]:
time.sleep(delay)
yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()
except Exception as e:
error_msg = f"Error processing tool response: {str(e)}"
print(error_msg)
# Show error in visualization
error_vis = [(error_msg, "red")]
# Don't update history with error
yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
def generate_dummy_response(current_tool):
"""Generate a dummy response for a tool call"""
if not current_tool:
return ""
# Process based on tool name
if current_tool["name"] == weather_tool.json_name:
location = current_tool["parameters"].get("location", "Unknown")
unit = current_tool["parameters"].get("unit", "celsius")
dummy_data = {
"location": location,
"temperature": 72 if unit == "fahrenheit" else 22,
"unit": unit,
"condition": "Partly Cloudy",
"humidity": 65,
"wind_speed": 8,
"wind_direction": "NE"
}
return json.dumps(dummy_data, indent=2)
return "{}"
def clear_conversation():
"""Clear the conversation history"""
return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
# EVENT HANDLERS
# Clear button handler
clear_btn.click(
fn=clear_conversation,
inputs=[],
outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
current_tool_call, tool_response_group, tool_name_display, tool_params_display]
)
# Dummy response button handler
dummy_response_btn.click(
fn=generate_dummy_response,
inputs=[current_tool_call],
outputs=[tool_response_input]
)
# User message submission flow
msg_submit = user_input.submit(
fn=user_message_submitted,
inputs=[user_input, chat_history, waiting_for_tool_response],
outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
)
# Also connect the send button
send_click = send_btn.click(
fn=user_message_submitted,
inputs=[user_input, chat_history, waiting_for_tool_response],
outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
)
# Generate bot response
msg_submit.then(
fn=bot_response,
inputs=[
chat_history, waiting_for_tool_response, current_tool_call,
gen_length, steps, visualization_delay, temperature,
cfg_scale, block_length, remasking_strategy
],
outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
current_tool_call, tool_response_group, tool_name_display, tool_params_display]
)
send_click.then(
fn=bot_response,
inputs=[
chat_history, waiting_for_tool_response, current_tool_call,
gen_length, steps, visualization_delay, temperature,
cfg_scale, block_length, remasking_strategy
],
outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
current_tool_call, tool_response_group, tool_name_display, tool_params_display]
)
# Tool response submission
submit_tool_response.click(
fn=process_tool_response,
inputs=[
tool_response_input, chat_history, current_tool_call,
gen_length, steps, visualization_delay, temperature,
cfg_scale, block_length, remasking_strategy
],
outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
current_tool_call, tool_response_group, tool_name_display, tool_params_display]
)
return demo
demo = create_chatbot_demo()
demo.queue().launch(share=True) |