ProximileAdmin's picture
Update app.py
4a09143 verified
import json
import re
import time

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
# Device setup: use the GPU when torch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load base model and tokenizer
MODEL_NAME = "Proximile/LLaDA-8B-Tools"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if device == 'cuda':
    # 4-bit bitsandbytes quantization. Passing `load_in_4bit=True` directly to
    # from_pretrained is deprecated; transformers expects a BitsAndBytesConfig.
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
else:
    # bitsandbytes 4-bit loading has historically required a CUDA GPU, so on a
    # CPU-only host load the unquantized bfloat16 weights instead of failing at startup.
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device)
model.eval()

# Constants
MASK_TOKEN = "[MASK]"  # text shown in the visualization for still-masked positions
MASK_ID = 126336  # The token ID of [MASK] in LLaDA
# Tool class definitions
class ToolBase:
    """Base class for a tool the model may call.

    Subclasses describe their parameters through ``__init__`` (which builds an
    OpenAI-style function ``schema``) and implement ``actual_function``.
    """

    def __init__(self,
                 programmatic_name,
                 natural_name,
                 description,
                 input_params,
                 required_params=None,
                 ):
        """
        Args:
            programmatic_name: identifier the model uses as ``"name"`` in a tool call.
            natural_name: human-readable tool name, kept as ``self.natural_name``.
            description: short description placed in the schema.
            input_params: JSON-schema ``properties`` mapping describing the parameters.
            required_params: names of mandatory parameters; ``None`` means none.
        """
        self.json_name = programmatic_name
        # Human-readable label; it is not part of the JSON schema.
        self.natural_name = natural_name
        self.json_description = description
        self.schema = {
            "type": "function",
            "function": {
                "name": self.json_name,
                "description": self.json_description,
                "parameters": {
                    "type": "object",
                    "properties": input_params,
                    # Copy so later changes to the caller's list cannot alter the schema.
                    "required": list(required_params or []),
                },
            },
        }

    def actual_function(self, **kwargs):
        """Run the tool with the call's parameters; subclasses must override this."""
        raise NotImplementedError("Subclasses must implement this method.")
class WeatherAPITool(ToolBase):
    """The ``get_weather`` tool; returns canned weather data instead of calling a real API."""

    def __init__(self):
        super().__init__(
            programmatic_name="get_weather",
            natural_name="Weather Report Fetcher",
            description="Get the current weather in a given location",
            input_params={
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature"
                }
            },
            required_params=["location", "unit"],
        )

    def actual_function(self, **kwargs):
        """Return dummy weather data for ``location`` in ``unit``.

        Tool calls are generated by the model and may omit arguments, so a missing
        ``location``/``unit`` falls back to ``"Unknown"``/``"celsius"`` (the same
        defaults the UI's dummy-response generator uses) rather than raising KeyError.
        The temperature is 72 for ``"fahrenheit"`` and 22 for any other unit.
        """
        location = kwargs.get("location", "Unknown")
        unit = kwargs.get("unit", "celsius")
        # This would normally call an API, but we'll return dummy data
        return {
            "location": location,
            "temperature": 72 if unit == "fahrenheit" else 22,
            "unit": unit,
            "condition": "Partly Cloudy",
            "humidity": 65,
            "wind_speed": 8,
            "wind_direction": "NE"
        }


# Create the tool (module-level instance referenced by the prompt and the handlers)
weather_tool = WeatherAPITool()
# Diffusion model generation functions
def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that taking an argmax samples a token (Gumbel-max trick).

    For ``temperature > 0`` this returns ``exp(logits) / (-log u) ** temperature`` with
    ``u`` uniform in [0, 1). Its argmax is a sample from ``softmax(logits / temperature)``.
    The work is done in float64 because, for masked diffusion models, low-precision
    Gumbel-max improves the perplexity score but degrades generation quality.

    A non-positive ``temperature`` returns ``logits`` itself, untouched, so an argmax
    over the result is greedy decoding.
    """
    if temperature <= 0:
        return logits
    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = (-torch.log(uniform)) ** temperature
    return logits64.exp() / denominator
def get_num_transfer_tokens(mask_index, steps):
    """Spread each row's masked-token count as evenly as possible over ``steps`` steps.

    Used to precompute how many tokens the reverse process reveals at each step.

    Args:
        mask_index: tensor summed over dim 1 to count masked positions per row;
            the generator passes a boolean (batch, block_length) mask.
        steps: number of denoising steps; 0 is treated as 1.

    Returns:
        tensor of shape (batch, steps), int64 for a boolean mask, whose rows sum to the
        per-row mask counts. The first ``count % steps`` steps of a row get one extra token.
    """
    if steps == 0:
        steps = 1  # always at least one step
    masked = mask_index.sum(dim=1, keepdim=True)  # (batch, 1)
    quotient, extra = masked // steps, masked % steps
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # (batch, 1) broadcasts against (1, steps): every step gets `quotient`, and the
    # first `extra` steps of each row get one more.
    return quotient + (step_ids < extra).to(torch.int64)
def _visualization_frame(tokenizer, x, old_x, x0_p, prompt_length, gen_length):
    """Build one visualization frame: a (token text, hex colour) pair per response position.

    Only batch row 0 is visualized. Colours:
    - dark gray: position still masked;
    - light blue: token revealed in an earlier step (already unmasked in ``old_x``);
    - red / orange / green: token revealed in this step, by its score in ``x0_p``
      (below 0.3 / below 0.7 / otherwise).
    """
    frame = []
    for offset in range(gen_length):
        pos = prompt_length + offset  # Absolute position in the sequence
        if x[0, pos] == MASK_ID:
            frame.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
            continue
        token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
        if old_x[0, pos] != MASK_ID:
            frame.append((token, "#6699CC"))  # Light blue: previously revealed
            continue
        score = float(x0_p[0, pos].cpu())
        if score < 0.3:
            color = "#FF6666"  # Light red
        elif score < 0.7:
            color = "#FFAA33"  # Orange
        else:
            color = "#66CC66"  # Light green
        frame.append((token, color))
    return frame


@torch.no_grad()
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=128, steps=128,
                                         temperature=0.1, cfg_scale=0.0, block_length=32,
                                         remasking='low_confidence'):
    """Generate a reply with the LLaDA masked-diffusion model, recording every denoising step.

    Runs under ``torch.no_grad()``. Generation never back-propagates, and without this
    every forward pass would build an autograd graph.

    Args:
        model: called as ``model(ids).logits`` to get per-position logits.
        tokenizer: provides ``apply_chat_template``, ``__call__`` and ``decode``.
        device: device the token tensors are moved to.
        messages: chat messages given to ``tokenizer.apply_chat_template``.
        gen_length: number of response positions; all start as [MASK].
        steps: total denoising steps, shared evenly between blocks (at least 1 per block).
        temperature: Gumbel-noise temperature; ``<= 0`` gives greedy argmax decoding.
        cfg_scale: classifier-free guidance scale; ``0`` disables guidance.
        block_length: response positions denoised per left-to-right block, clamped
            to ``[1, gen_length]``.
        remasking: ``'low_confidence'`` reveals the most probable candidates first;
            ``'random'`` uses random scores.
        gen_length, steps and block_length are cast to int (UI sliders may send floats).

    Returns:
        ``(visualization_states, final_text)``. ``visualization_states`` holds one list of
        (token, colour) pairs per frame, the first being all masked. ``final_text`` is the
        decoded response cut at the first ``"<|"``.

    Raises:
        NotImplementedError: for an unknown ``remasking`` strategy.
    """
    gen_length, steps, block_length = int(gen_length), int(steps), int(block_length)

    # Prepare the prompt using the chat template.
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = torch.tensor(tokenizer(chat_input)['input_ids']).to(device).unsqueeze(0)
    prompt_length = input_ids.shape[1]

    # Sequence = prompt followed by gen_length [MASK] positions for the response.
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    # Frame 0: the response area before any denoising (all masked).
    visualization_states = [[(MASK_TOKEN, "#444444") for _ in range(gen_length)]]

    # Prompt (non-mask) positions; they are blanked to [MASK] for the unconditional CFG pass.
    prompt_index = (x != MASK_ID)

    # Clamp to [1, gen_length]; the lower bound avoids dividing by zero.
    block_length = max(1, min(block_length, gen_length))
    num_blocks = (gen_length + block_length - 1) // block_length  # ceil division
    steps_per_block = max(1, steps // max(num_blocks, 1))

    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = min(block_start + block_length, x.shape[1])

        block_mask_index = (x[:, block_start:block_end] == MASK_ID)
        if not block_mask_index.any():
            continue
        # Number of this block's tokens to reveal at each of its steps.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for step in range(steps_per_block):
            mask_index = (x == MASK_ID)
            if not mask_index.any():
                break

            if cfg_scale > 0.0:
                # Classifier-free guidance: batch a conditional and a prompt-masked pass.
                un_x = x.clone()
                un_x[prompt_index] = MASK_ID
                logits = model(torch.cat([x, un_x], dim=0)).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # Candidate token for every position (Gumbel-max sampling).
            x0 = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)

            # Score each candidate; higher scores are revealed first.
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")

            # Positions after the current block are never revealed while in this block.
            x0_p[:, block_end:] = -float('inf')

            old_x = x.clone()
            x0 = torch.where(mask_index, x0, x)  # keep already-revealed tokens
            confidence = torch.where(mask_index, x0_p, -float('inf'))

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            is_last_step = step == steps_per_block - 1
            for row in range(confidence.shape[0]):
                if is_last_step:
                    # Last step of the block: reveal everything still masked inside it.
                    transfer_index[row, block_start:block_end] = mask_index[row, block_start:block_end]
                else:
                    # Reveal the top-k scored positions inside the current block only.
                    block_confidence = confidence[row, block_start:block_end]
                    k = min(num_transfer_tokens[row, step].item(), block_confidence.numel())
                    _, select_indices = torch.topk(block_confidence, k=k)
                    transfer_index[row, select_indices + block_start] = True

            x = torch.where(transfer_index, x0, x)
            visualization_states.append(
                _visualization_frame(tokenizer, x, old_x, x0_p, prompt_length, gen_length))

    # Decode just the assistant's response and cut it at the first special-token marker.
    final_text = tokenizer.decode(x[0, prompt_length:],
                                  skip_special_tokens=False,
                                  clean_up_tokenization_spaces=True).split("<|")[0]
    return visualization_states, final_text
# Tool handling functions
def is_tool_call(text):
    """Return True if ``text`` looks like a JSON tool call.

    A tool call is a JSON object with both ``"name"`` and ``"parameters"`` keys, or a
    JSON array containing at least one such object. Surrounding whitespace is ignored.
    Text that is not a JSON array/object, or fails to parse, gives False.
    """
    text = text.strip()
    # Cheap pre-filter: only a JSON array or object can be a tool call.
    if not ((text.startswith('[') and text.endswith(']'))
            or (text.startswith('{') and text.endswith('}'))):
        return False
    try:
        data = json.loads(text)
    except (ValueError, RecursionError):
        # JSONDecodeError is a ValueError; pathologically deep nesting raises RecursionError.
        # Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit propagate.
        return False
    candidates = data if isinstance(data, list) else [data]
    return any(isinstance(item, dict) and "name" in item and "parameters" in item
               for item in candidates)
def extract_tool_call(text):
    """Return the first tool-call dict in JSON ``text``, or None.

    Accepts a single object with ``"name"`` and ``"parameters"`` keys, or an array
    whose first such object is returned. Returns None when ``text`` is not valid JSON,
    is not a string/bytes, or contains no tool call.
    """
    try:
        data = json.loads(text)
    except (TypeError, ValueError, RecursionError):
        # TypeError: not str/bytes; ValueError: malformed JSON; RecursionError: absurd nesting.
        return None
    candidates = data if isinstance(data, list) else [data]
    for item in candidates:
        if isinstance(item, dict) and "name" in item and "parameters" in item:
            return item
    return None
def handle_tool_call(tool_call):
    """Execute a parsed tool call against the matching local tool.

    Returns the tool's result. If no tool is registered under ``tool_call["name"]``,
    returns an ``{"error": ...}`` dict.
    """
    requested = tool_call["name"]
    if requested != weather_tool.json_name:
        return {"error": f"Tool {tool_call['name']} not found"}
    return weather_tool.actual_function(**tool_call["parameters"])
# Custom CSS, passed to gr.Blocks(css=css) in create_chatbot_demo().
# - `.category-legend{display:none}` hides elements with that class. This is presumably
#   Gradio's built-in HighlightedText legend (NOTE(review): internal Gradio class name,
#   confirm for the installed version); the app renders its own colour legend in Markdown.
# - `button{height: 60px}` gives every button a fixed 60px height.
# - `.visualization-container` styles the token-denoising panel, which opts in via
#   elem_classes="visualization-container".
css = '''
.category-legend{display:none}
button{height: 60px}
.visualization-container {
margin-top: 20px;
padding: 10px;
background-color: #f8f9fa;
border-radius: 8px;
}
'''
def create_chatbot_demo():
    """Build the Gradio UI for chatting with LLaDA, answering its tool calls and
    watching the diffusion denoising animation.

    Session state:
        chat_history: list of ``[user_message, assistant_response]`` pairs; the
            response is None until it has been generated.
        waiting_for_tool_response: True while a detected tool call awaits its result.
        current_tool_call: the pending tool-call dict (``"name"``, ``"parameters"``) or None.

    Returns:
        gr.Blocks: the assembled, not yet launched, demo.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
        gr.Markdown("This demo showcases the LLaDA diffusion model with the [Proximile/LLaDA-8B-Tools-LoRA](https://huggingface.co/Proximile/LLaDA-8B-Tools-LoRA) adapter for enhanced tool calling capabilities.")

        # STATE MANAGEMENT
        chat_history = gr.State([])
        waiting_for_tool_response = gr.State(False)
        current_tool_call = gr.State(None)

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)
                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")
                # Tool response input (initially hidden)
                with gr.Group(visible=False) as tool_response_group:
                    gr.Markdown("## Tool Call Detected")
                    tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
                    tool_params_display = gr.JSON(label="Parameters")
                    tool_response_input = gr.Textbox(
                        label="Tool Response (JSON)",
                        placeholder="Enter JSON response for the tool...",
                        lines=5
                    )
                    submit_tool_response = gr.Button("Submit Tool Response")
                    # Add a button for auto-filling dummy response
                    dummy_response_btn = gr.Button("Use Dummy Response")
            with gr.Column(scale=2):
                gr.Markdown("## Diffusion Process Visualization")
                gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
                output_vis = gr.HighlightedText(
                    label="Token Denoising",
                    combine_adjacent=False,
                    show_legend=True,
                    elem_classes="visualization-container"
                )
                gr.Markdown("**Color Legend:**")
                gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
                gr.Markdown("- **Light Red**: Newly revealed with low confidence")
                gr.Markdown("- **Orange**: Newly revealed with medium confidence")
                gr.Markdown("- **Light Green**: Newly revealed with high confidence")
                gr.Markdown("- **Light Blue**: Previously revealed tokens")

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8,
                    label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Visualization Delay (seconds)"
                )

        # Current response text box (hidden)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )
        # Clear button
        clear_btn = gr.Button("Clear Conversation")
        gr.Markdown("### Try asking about the weather to trigger a tool call!")
        gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")

        # System prompt for the model
        system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answers the given prompt.
Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.
Here are the tool functions available to you:
{json.dumps([weather_tool.schema], indent=4)}
After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.
If the user request does not necessitate a function call, simply respond to the user's query directly."""

        # HELPER FUNCTIONS
        def add_message(history, message, response):
            """Return a copy of ``history`` with the pair ``[message, response]`` appended."""
            history = history.copy()
            history.append([message, response])
            return history

        def build_messages(history):
            """Turn ``[user, assistant]`` pairs into chat messages following the system prompt.

            Assistant entries that are empty or None (not generated yet) are omitted.
            """
            messages = [{"role": "system", "content": system_prompt}]
            for user_text, assistant_text in history:
                messages.append({"role": "user", "content": user_text})
                if assistant_text:  # Only include assistant responses that exist
                    messages.append({"role": "assistant", "content": assistant_text})
            return messages

        def turn_outputs(history, vis, text, tool_call=None):
            """Build the value tuple for ``turn_output_components``.

            With ``tool_call``, the tool panel is shown and filled in and the app waits for
            the tool result. Without it, the panel is hidden and nothing is pending.
            ``history`` goes both to the chatbot and back into the ``chat_history`` state.
            """
            if tool_call is None:
                return (history, vis, text, False, None,
                        gr.update(visible=False), gr.update(), gr.update(), history)
            return (history, vis, text, True, tool_call,
                    gr.update(visible=True),
                    gr.update(value=tool_call["name"]),
                    gr.update(value=tool_call["parameters"]),
                    history)

        def user_message_submitted(message, history, waiting_for_tool):
            """Append the user's message (with no reply yet) to the history and clear the input.

            Blank messages are ignored. While a tool call is pending the message is not sent,
            but it stays in the textbox so the user does not lose it.
            """
            if not message.strip() or waiting_for_tool:
                return history, history.copy(), message if waiting_for_tool else "", [], ""
            history = add_message(history, message, None)
            return history, history.copy(), "", [], ""

        def bot_response(history, waiting_for_tool, current_tool,
                         gen_length, steps, delay, temperature,
                         cfg_scale, block_length, remasking):
            """Generate the reply to the newest user message and stream its denoising frames.

            Yields one output tuple per visualization frame, sleeping ``delay`` seconds between
            frames. If the reply is a JSON tool call, the tool panel is shown and the app waits
            for the tool result. Errors are shown in the visualization panel and are not written
            to the history.
            """
            # Skip for an empty chat, while a tool result is pending, or when the newest message
            # already has a reply (e.g. a blank message was submitted; without this check the
            # previous answer would be regenerated). This is a generator, so the update must be
            # yielded: Gradio ignores a generator's `return value`. gr.update() leaves the tool
            # panel as it is so a pending tool call stays answerable.
            if not history or waiting_for_tool or history[-1][1] is not None:
                yield (history, [], "", waiting_for_tool, current_tool,
                       gr.update(), gr.update(), gr.update(), history)
                return
            try:
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    build_messages(history),
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )
                # Update history with the assistant's response
                history[-1][1] = response_text
                tool_call = extract_tool_call(response_text) if is_tool_call(response_text) else None
                # Show the first frame immediately, then animate the remaining ones.
                for index, state in enumerate(vis_states):
                    if index:
                        time.sleep(delay)
                    yield turn_outputs(history, state, response_text, tool_call)
            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)
                # Show error in visualization; the history is left unanswered.
                yield turn_outputs(history, [(error_msg, "red")], error_msg)

        def process_tool_response(tool_response, history, current_tool,
                                  gen_length, steps, delay, temperature,
                                  cfg_scale, block_length, remasking):
            """Send the user-supplied tool result to the model and stream its follow-up answer.

            ``tool_response`` is parsed as JSON (non-strings are used as-is) and passed as an
            ``ipython`` message ``{"name": ..., "return": ...}``. The answer is appended under a
            "Tool response processed" placeholder user turn. On failure (e.g. invalid JSON) the
            error is shown and the tool call stays pending so a corrected result can be sent.
            """
            if not history or not current_tool:
                # Yield rather than `return value`: Gradio ignores generator return values.
                yield turn_outputs(history, [], "")
                return
            try:
                response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response
                messages = build_messages(history)
                messages.append({"role": "ipython", "content": json.dumps({
                    "name": current_tool["name"],
                    "return": response_data
                })})
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )
                # Add a new message pair for the tool-processed response
                history = add_message(history, "Tool response processed", response_text)
                for index, state in enumerate(vis_states):
                    if index:
                        time.sleep(delay)
                    yield turn_outputs(history, state, response_text)
            except Exception as e:
                error_msg = f"Error processing tool response: {str(e)}"
                print(error_msg)
                # Keep the tool call pending (panel visible) so the user can retry.
                yield (history, [(error_msg, "red")], error_msg, True, current_tool,
                       gr.update(visible=True), gr.update(), gr.update(), history)

        def generate_dummy_response(current_tool):
            """Return a JSON string with a plausible result for the pending tool call.

            Returns "" when no call is pending and "{}" for unknown tools. Weather results
            come from ``weather_tool.actual_function``, so the canned data lives in one place.
            """
            if not current_tool:
                return ""
            if current_tool["name"] == weather_tool.json_name:
                params = current_tool["parameters"]
                dummy_data = weather_tool.actual_function(
                    location=params.get("location", "Unknown"),
                    unit=params.get("unit", "celsius"),
                )
                return json.dumps(dummy_data, indent=2)
            return "{}"

        def clear_conversation():
            """Reset the chat, the hidden response box and any pending tool call."""
            return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()

        # EVENT HANDLERS
        # Outputs of bot_response / process_tool_response, in turn_outputs() order.
        # chat_history is included so generated replies are persisted to the session state.
        turn_output_components = [chatbot_ui, output_vis, current_response, waiting_for_tool_response,
                                  current_tool_call, tool_response_group, tool_name_display,
                                  tool_params_display, chat_history]
        generation_settings = [gen_length, steps, visualization_delay, temperature,
                               cfg_scale, block_length, remasking_strategy]

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )
        # Dummy response button handler
        dummy_response_btn.click(
            fn=generate_dummy_response,
            inputs=[current_tool_call],
            outputs=[tool_response_input]
        )
        # Pressing Enter and clicking Send both record the message, then generate the reply.
        for trigger in (user_input.submit, send_btn.click):
            trigger(
                fn=user_message_submitted,
                inputs=[user_input, chat_history, waiting_for_tool_response],
                outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
            ).then(
                fn=bot_response,
                inputs=[chat_history, waiting_for_tool_response, current_tool_call, *generation_settings],
                outputs=turn_output_components
            )
        # Tool response submission
        submit_tool_response.click(
            fn=process_tool_response,
            inputs=[tool_response_input, chat_history, current_tool_call, *generation_settings],
            outputs=turn_output_components
        )
    return demo
demo = create_chatbot_demo()

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module does not
    # launch the app (the model is still loaded at import time).
    demo.queue().launch(share=True)