import json
import time

import torch
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
# Device setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Load base model and tokenizer (4-bit loading requires the bitsandbytes package)
tokenizer = AutoTokenizer.from_pretrained("Proximile/LLaDA-8B-Tools", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "Proximile/LLaDA-8B-Tools",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
model.eval()
# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA
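# Note: MASK_ID is the mask-token id the sequence is actually filled with;
# MASK_TOKEN is only a display label used in the visualization below, not a
# tokenizer special token.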

# Tool class definitions
class ToolBase:
    def __init__(self,
                 programmatic_name,
                 natural_name,
                 description,
                 input_params,
                 required_params=None,
                 ):
        self.json_name = programmatic_name
        self.natural_name = natural_name
        self.json_description = description
        self.schema = {
            "type": "function",
            "function": {
                "name": self.json_name,
                "description": self.json_description,
                "parameters": {
                    "type": "object",
                    "properties": input_params,
                    "required": required_params or []
                }
            }
        }

    def actual_function(self, **kwargs):
        raise NotImplementedError("Subclasses must implement this method.")

class WeatherAPITool(ToolBase):
    def __init__(self):
        super().__init__(
            programmatic_name="get_weather",
            natural_name="Weather Report Fetcher",
            description="Get the current weather in a given location",
            input_params={
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature"
                }
            },
            required_params=["location", "unit"],
        )

    def actual_function(self, **kwargs):
        # This would normally call an API, but we'll return dummy data
        return {
            "location": kwargs["location"],
            "temperature": 72 if kwargs["unit"] == "fahrenheit" else 22,
            "unit": kwargs["unit"],
            "condition": "Partly Cloudy",
            "humidity": 65,
            "wind_speed": 8,
            "wind_direction": "NE"
        }

# Create the tool
weather_tool = WeatherAPITool()
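# Illustrative sanity check of the dummy backend (safe to remove). The schema is
# what the model sees; actual_function is what the app would call on its behalf.
assert weather_tool.actual_function(location="Austin, TX", unit="fahrenheit")["temperature"] == 72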

# Diffusion model generation functions
def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick is a method for sampling from categorical distributions.
    For MDM, low-precision Gumbel-max improves the perplexity score but reduces generation quality.
    '''
    if temperature <= 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
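# Minimal illustration of the behavior (safe to remove): at temperature 0 the
# sampler degenerates to a plain argmax over the logits; for temperature > 0,
# argmax(exp(logits) / gumbel_noise) amounts to tempered Gumbel-max sampling.
_demo_logits = torch.tensor([[0.1, 2.0, 0.3]])
assert add_gumbel_noise(_demo_logits, temperature=0).argmax(dim=-1).item() == 1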

def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, we precompute the number of tokens to transition at each step.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)
    # Ensure we have at least one step
    if steps == 0:
        steps = 1
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        if remainder[i] > 0:
            num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens
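# Illustrative example of the schedule (safe to remove): 10 masked tokens spread
# over 4 steps front-loads the remainder, yielding [3, 3, 2, 2] tokens per step.
_demo_mask = torch.ones(1, 10, dtype=torch.bool)
assert get_num_transfer_tokens(_demo_mask, 4).tolist() == [[3, 3, 2, 2]]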

def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=128, steps=128,
                                         temperature=0.1, cfg_scale=0.0, block_length=32,
                                         remasking='low_confidence'):
    """
    Generate text with the LLaDA model, recording a visualization state per denoising step.
    """
    # Prepare the prompt using the chat template
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    # Initialize visualization states for the response part
    visualization_states = []

    # Add the initial state (all masked)
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    # Mark prompt positions to exclude them from masking during classifier-free guidance
    prompt_index = (x != MASK_ID)

    # Ensure block_length is valid
    if block_length > gen_length:
        block_length = gen_length

    # Calculate the number of blocks
    num_blocks = gen_length // block_length
    if gen_length % block_length != 0:
        num_blocks += 1

    # Adjust steps per block
    steps_per_block = steps // num_blocks
    if steps_per_block < 1:
        steps_per_block = 1

    # Process each block
    for num_block in range(num_blocks):
        # Calculate the start and end indices for the current block
        block_start = prompt_length + num_block * block_length
        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])

        # Get mask indices for the current block
        block_mask_index = (x[:, block_start:block_end] == MASK_ID)

        # Skip if there are no masks in this block
        if not block_mask_index.any():
            continue

        # Calculate the number of tokens to unmask at each step
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # Process each step
        for i in range(steps_per_block):
            # Get all mask positions in the current sequence
            mask_index = (x == MASK_ID)

            # Stop early if nothing is masked
            if not mask_index.any():
                break

            # Apply classifier-free guidance if enabled
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = MASK_ID
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # Apply Gumbel noise for sampling
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            # Calculate confidence scores for remasking
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")

            # Don't consider positions beyond the current block
            x0_p[:, block_end:] = -float('inf')

            # Apply predictions where we have masks
            old_x = x.clone()
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -float('inf'))

            # Select tokens to unmask based on confidence
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                # Only consider positions within the current block for unmasking
                block_confidence = confidence[j, block_start:block_end]
                if i < steps_per_block - 1:  # Not the last step
                    # Take the top-k most confident positions
                    _, select_indices = torch.topk(block_confidence,
                                                   k=min(num_transfer_tokens[j, i].item(),
                                                         block_confidence.numel()))
                    # Shift block-local indices to global positions
                    select_indices = select_indices + block_start
                    transfer_index[j, select_indices] = True
                else:  # Last step: unmask everything remaining in the block
                    transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]

            # Apply the selected tokens
            x = torch.where(transfer_index, x0, x)

            # Create a visualization state for the response part only
            # (use a distinct loop variable so the step counter `i` is not shadowed)
            current_state = []
            for vis_idx in range(gen_length):
                pos = prompt_length + vis_idx  # Absolute position in the sequence
                if x[0, pos] == MASK_ID:
                    # Still masked
                    current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
                elif old_x[0, pos] == MASK_ID:
                    # Newly revealed in this step
                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                    # Color based on confidence
                    token_conf = float(x0_p[0, pos].cpu())
                    if token_conf < 0.3:
                        color = "#FF6666"  # Light red
                    elif token_conf < 0.7:
                        color = "#FFAA33"  # Orange
                    else:
                        color = "#66CC66"  # Light green
                    current_state.append((token, color))
                else:
                    # Previously revealed
                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                    current_state.append((token, "#6699CC"))  # Light blue

            visualization_states.append(current_state)

    # Extract the final text (just the assistant's response)
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(response_tokens,
                                  skip_special_tokens=False,
                                  clean_up_tokenization_spaces=True).split("<|")[0]

    return visualization_states, final_text
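# Illustrative call (commented out: it runs a full denoising pass, so only use
# it for local testing; the exact output depends on the loaded checkpoint):
# _states, _text = generate_response_with_visualization(
#     model, tokenizer, device,
#     [{"role": "user", "content": "What's the weather like in New York?"}],
#     gen_length=64, steps=64)
# print(_text)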

# Tool handling functions
def is_tool_call(text):
    """Check if the text looks like a JSON tool call"""
    # Remove any whitespace at the beginning and end
    text = text.strip()

    # Check if it starts with [ or { (common JSON indicators)
    if (text.startswith('[') and text.endswith(']')) or (text.startswith('{') and text.endswith('}')):
        try:
            # Try to parse as JSON
            data = json.loads(text)
            # Check if it contains a tool call structure
            if isinstance(data, list):
                for item in data:
                    if isinstance(item, dict) and "name" in item and "parameters" in item:
                        return True
            elif isinstance(data, dict) and "name" in data and "parameters" in data:
                return True
        except json.JSONDecodeError:
            pass
    return False

def extract_tool_call(text):
    """Extract tool call data from text"""
    try:
        data = json.loads(text)
        if isinstance(data, list) and len(data) > 0:
            # Return the first valid tool call
            for item in data:
                if isinstance(item, dict) and "name" in item and "parameters" in item:
                    return item
        elif isinstance(data, dict) and "name" in data and "parameters" in data:
            return data
    except json.JSONDecodeError:
        pass
    return None
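# Quick illustrative check of the parsers (hypothetical payload, safe to remove):
_demo_call = '[{"name": "get_weather", "parameters": {"location": "Tokyo", "unit": "celsius"}}]'
assert is_tool_call(_demo_call)
assert extract_tool_call(_demo_call)["name"] == "get_weather"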

def handle_tool_call(tool_call):
    """Process a tool call and return the result"""
    if tool_call["name"] == weather_tool.json_name:
        return weather_tool.actual_function(**tool_call["parameters"])
    return {"error": f"Tool {tool_call['name']} not found"}

# Custom CSS
css = '''
.category-legend{display:none}
button{height: 60px}
.visualization-container {
    margin-top: 20px;
    padding: 10px;
    background-color: #f8f9fa;
    border-radius: 8px;
}
'''

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
        gr.Markdown("This demo showcases [Proximile/LLaDA-8B-Tools](https://huggingface.co/Proximile/LLaDA-8B-Tools), a LLaDA 8B fine-tune with enhanced tool calling capabilities.")
        # STATE MANAGEMENT
        chat_history = gr.State([])
        waiting_for_tool_response = gr.State(False)
        current_tool_call = gr.State(None)

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")

                # Tool response input (initially hidden)
                with gr.Group(visible=False) as tool_response_group:
                    gr.Markdown("## Tool Call Detected")
                    tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
                    tool_params_display = gr.JSON(label="Parameters")
                    tool_response_input = gr.Textbox(
                        label="Tool Response (JSON)",
                        placeholder="Enter JSON response for the tool...",
                        lines=5
                    )
                    submit_tool_response = gr.Button("Submit Tool Response")
                    # Button for auto-filling a dummy response
                    dummy_response_btn = gr.Button("Use Dummy Response")

            with gr.Column(scale=2):
                gr.Markdown("## Diffusion Process Visualization")
                gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
                output_vis = gr.HighlightedText(
                    label="Token Denoising",
                    combine_adjacent=False,
                    show_legend=True,
                    elem_classes="visualization-container"
                )
                gr.Markdown("**Color Legend:**")
                gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
                gr.Markdown("- **Light Red**: Newly revealed with low confidence")
                gr.Markdown("- **Orange**: Newly revealed with medium confidence")
                gr.Markdown("- **Light Green**: Newly revealed with high confidence")
                gr.Markdown("- **Light Blue**: Previously revealed tokens")

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8,
                    label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Visualization Delay (seconds)"
                )

        # Current response text box (hidden)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        gr.Markdown("### Try asking about the weather to trigger a tool call!")
        gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")
        # System prompt for the model
        system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answer the given prompt.
Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.
Here are the tool functions available to you:
{json.dumps([weather_tool.schema], indent=4)}
After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.
If the user request does not necessitate a function call, simply respond to the user's query directly."""
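        # For reference, the exact shape the prompt asks the model to emit
        # (illustrative values):
        # [{"name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit"}}]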
        # HELPER FUNCTIONS
        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(message, history, waiting_for_tool):
            """Process a submitted user message"""
            # Skip empty messages or if waiting for a tool response
            if not message.strip() or waiting_for_tool:
                # Return the current state unchanged
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Add the user message to the history
            history = add_message(history, message, None)

            # Format for display: temporarily show the user message with an empty response
            history_for_display = history.copy()

            # Clear the input
            message_out = ""

            # Return immediately to update the UI with the user message
            return history, history_for_display, message_out, [], ""
        def bot_response(history, waiting_for_tool, current_tool,
                         gen_length, steps, delay, temperature,
                         cfg_scale, block_length, remasking):
            """Generate the bot response for the latest message"""
            if not history or waiting_for_tool:
                # This function is a generator, so early exits must yield:
                # a bare return would end the stream without updating the UI
                yield history, [], "", waiting_for_tool, current_tool, gr.update(visible=False), gr.update(), gr.update()
                return

            # Get the last user message
            last_user_message = history[-1][0]

            try:
                # Format the conversation for the model
                messages = []

                # Add the system message first
                messages.append({"role": "system", "content": system_prompt})

                # Add the conversation history
                for h in history[:-1]:
                    messages.append({"role": "user", "content": h[0]})
                    if h[1]:  # Only include assistant responses that exist
                        messages.append({"role": "assistant", "content": h[1]})

                # Add the last user message
                messages.append({"role": "user", "content": last_user_message})

                # Generate the response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )

                # Update the history with the assistant's response
                history[-1][1] = response_text

                # Check whether the response is a tool call
                is_tool = is_tool_call(response_text)

                if is_tool:
                    # Extract the tool call information
                    tool_call = extract_tool_call(response_text)

                    # Yield the initial state immediately
                    yield (history, vis_states[0], response_text,
                           True, tool_call,
                           gr.update(visible=True),
                           gr.update(value=tool_call["name"]),
                           gr.update(value=tool_call["parameters"]))

                    # Then animate through the visualization states
                    for state in vis_states[1:]:
                        time.sleep(delay)
                        yield (history, state, response_text,
                               True, tool_call,
                               gr.update(visible=True),
                               gr.update(value=tool_call["name"]),
                               gr.update(value=tool_call["parameters"]))
                else:
                    # Yield the initial state immediately
                    yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()

                    # Then animate through the visualization states
                    for state in vis_states[1:]:
                        time.sleep(delay)
                        yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show the error in the visualization
                error_vis = [(error_msg, "red")]

                # Don't update the history with the error
                yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
        def process_tool_response(tool_response, history, current_tool,
                                  gen_length, steps, delay, temperature,
                                  cfg_scale, block_length, remasking):
            """Process the tool response and generate a follow-up response"""
            if not history or not current_tool:
                # Generator early exit: yield once so the UI still updates
                yield history, [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
                return

            try:
                # Parse the tool response
                response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response

                # Format the conversation for the model
                messages = []

                # Add the system message first
                messages.append({"role": "system", "content": system_prompt})

                # Add the conversation history
                for h in history:
                    messages.append({"role": "user", "content": h[0]})
                    if h[1]:  # Only include assistant responses that exist
                        messages.append({"role": "assistant", "content": h[1]})

                # Add the tool response
                messages.append({"role": "ipython", "content": json.dumps({
                    "name": current_tool["name"],
                    "return": response_data
                })})

                # Generate the response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )

                # Add a new message pair for the tool-processed response
                history = add_message(history, "Tool response processed", response_text)

                # Yield the initial state immediately
                yield history, vis_states[0], response_text, False, None, gr.update(visible=False), gr.update(), gr.update()

                # Then animate through the visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text, False, None, gr.update(visible=False), gr.update(), gr.update()

            except Exception as e:
                error_msg = f"Error processing tool response: {str(e)}"
                print(error_msg)

                # Show the error in the visualization
                error_vis = [(error_msg, "red")]

                # Don't update the history with the error
                yield history, error_vis, error_msg, False, None, gr.update(visible=False), gr.update(), gr.update()
        def generate_dummy_response(current_tool):
            """Generate a dummy response for a tool call"""
            if not current_tool:
                return ""

            # Process based on the tool name
            if current_tool["name"] == weather_tool.json_name:
                location = current_tool["parameters"].get("location", "Unknown")
                unit = current_tool["parameters"].get("unit", "celsius")

                dummy_data = {
                    "location": location,
                    "temperature": 72 if unit == "fahrenheit" else 22,
                    "unit": unit,
                    "condition": "Partly Cloudy",
                    "humidity": 65,
                    "wind_speed": 8,
                    "wind_direction": "NE"
                }
                return json.dumps(dummy_data, indent=2)

            return "{}"
        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()
        # EVENT HANDLERS

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )

        # Dummy response button handler
        dummy_response_btn.click(
            fn=generate_dummy_response,
            inputs=[current_tool_call],
            outputs=[tool_response_input]
        )

        # User message submission flow
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, waiting_for_tool_response],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, waiting_for_tool_response],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Generate the bot response after either submission path
        msg_submit.then(
            fn=bot_response,
            inputs=[
                chat_history, waiting_for_tool_response, current_tool_call,
                gen_length, steps, visualization_delay, temperature,
                cfg_scale, block_length, remasking_strategy
            ],
            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )
        send_click.then(
            fn=bot_response,
            inputs=[
                chat_history, waiting_for_tool_response, current_tool_call,
                gen_length, steps, visualization_delay, temperature,
                cfg_scale, block_length, remasking_strategy
            ],
            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )

        # Tool response submission
        submit_tool_response.click(
            fn=process_tool_response,
            inputs=[
                tool_response_input, chat_history, current_tool_call,
                gen_length, steps, visualization_delay, temperature,
                cfg_scale, block_length, remasking_strategy
            ],
            outputs=[chatbot_ui, output_vis, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )
    return demo


demo = create_chatbot_demo()
demo.queue().launch(share=True)