"""Gradio demo for the LLaDA masked-diffusion chat model with JSON tool calls.

The script loads the model at import time, defines a dummy weather tool,
implements block-wise diffusion sampling that records every denoising step
for visualization, and wires everything into a Gradio Blocks UI.

NOTE(review): this file was reconstructed from a whitespace-collapsed copy.
Statement order and every token are preserved, but line breaks inside
multi-line string literals (system prompt, CSS) and the nesting of the
layout context managers had to be inferred -- confirm against the original.
"""
import json
import re
import time

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Proximile/LLaDA-8B-Tools", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "Proximile/LLaDA-8B-Tools",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
model.eval()

# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA

# Colors used by the denoising visualization.
_COLOR_MASKED = "#444444"  # Dark gray for masks
_COLOR_PREVIOUS = "#6699CC"  # Light blue for previously revealed tokens


# Tool class definitions
class ToolBase:
    """Base class describing a callable tool and its JSON function schema."""

    def __init__(
        self,
        programmatic_name,
        natural_name,
        description,
        input_params,
        required_params=None,
    ):
        """Store the tool identity and build its function-calling schema.

        Args:
            programmatic_name: Name the model must emit to call the tool.
            natural_name: Human-readable tool name.
            description: Description placed in the schema.
            input_params: JSON-schema ``properties`` mapping for the arguments.
            required_params: Names of mandatory arguments (defaults to none).
        """
        self.json_name = programmatic_name
        # Previously accepted but silently discarded; keep it for callers.
        self.natural_name = natural_name
        self.json_description = description
        self.schema = {
            "type": "function",
            "function": {
                "name": self.json_name,
                "description": self.json_description,
                "parameters": {
                    "type": "object",
                    "properties": input_params,
                    "required": required_params or [],
                },
            },
        }

    def actual_function(self, **kwargs):
        """Execute the tool; subclasses must override this."""
        raise NotImplementedError("Subclasses must implement this method.")


class WeatherAPITool(ToolBase):
    """Dummy weather tool that returns fixed data for any location."""

    def __init__(self):
        super().__init__(
            programmatic_name="get_weather",
            natural_name="Weather Report Fetcher",
            description="Get the current weather in a given location",
            input_params={
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature",
                },
            },
            required_params=["location", "unit"],
        )

    def actual_function(self, **kwargs):
        """Return canned weather data for ``location`` in the given ``unit``.

        Raises:
            KeyError: If ``location`` or ``unit`` is missing from ``kwargs``.
        """
        # This would normally call an API, but we'll return dummy data
        return {
            "location": kwargs["location"],
            "temperature": 72 if kwargs["unit"] == "fahrenheit" else 22,
            "unit": kwargs["unit"],
            "condition": "Partly Cloudy",
            "humidity": 65,
            "wind_speed": 8,
            "wind_direction": "NE",
        }


# Create the tool
weather_tool = WeatherAPITool()


# Diffusion model generation functions
def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` for Gumbel-max sampling of a categorical distribution.

    For MDM, low-precision Gumbel Max improves perplexity score but reduces
    generation quality, so the computation is done in float64.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Sampling temperature; ``<= 0`` disables the noise.

    Returns:
        ``logits`` unchanged when ``temperature <= 0``; otherwise a float64
        tensor ``exp(logits) / (-log(u)) ** temperature`` with ``u`` uniform
        noise, intended to be fed to ``argmax``.
    """
    if temperature <= 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    """Precompute how many tokens to unmask at each reverse-process step.

    The masked positions of every row are split as evenly as possible over
    ``steps``; the first ``mask_count % steps`` steps receive one extra token.

    Args:
        mask_index: Boolean tensor whose dim 1 is summed to count the masks.
        steps: Number of denoising steps; ``0`` is treated as ``1``.

    Returns:
        An int64 tensor with one row per row of ``mask_index`` and ``steps``
        columns.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)

    # Ensure we have at least one step
    if steps == 0:
        steps = 1

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(
        mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
    ) + base

    for row in range(mask_num.size(0)):
        extra = int(remainder[row].item())
        if extra > 0:
            num_transfer_tokens[row, :extra] += 1

    return num_transfer_tokens


def _confidence_color(confidence):
    """Map a confidence value to the color of a newly revealed token."""
    if confidence < 0.3:
        return "#FF6666"  # Light red
    if confidence < 0.7:
        return "#FFAA33"  # Orange
    return "#66CC66"  # Light green


def _build_visualization_state(tokenizer, token_ids, previous_ids, confidences):
    """Build one ``(text, color)`` frame for the response part of the sequence.

    Args:
        tokenizer: Tokenizer used to decode single token ids.
        token_ids: Response token ids after the current step (plain ints).
        previous_ids: Response token ids before the current step.
        confidences: Per-position confidence of the current predictions.
    """
    state = []
    for token_id, previous_id, confidence in zip(token_ids, previous_ids, confidences):
        if token_id == MASK_ID:
            # Still masked
            state.append((MASK_TOKEN, _COLOR_MASKED))
            continue
        token = tokenizer.decode([token_id], skip_special_tokens=True)
        if previous_id == MASK_ID:
            # Newly revealed in this step: color based on confidence
            state.append((token, _confidence_color(confidence)))
        else:
            # Previously revealed
            state.append((token, _COLOR_PREVIOUS))
    return state


@torch.no_grad()  # inference only: never build autograd graphs for the forward passes
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=128, steps=128,
                                         temperature=0.1, cfg_scale=0.0, block_length=32,
                                         remasking='low_confidence'):
    """Generate text with the LLaDA model and record every denoising step.

    The response region starts fully masked and is filled block by block; at
    each step the most confident predictions inside the current block are
    committed.

    Args:
        model: Model whose call returns an object with a ``.logits`` tensor.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        device: Device on which the token tensors are created.
        messages: Chat messages passed to the tokenizer's chat template.
        gen_length: Number of response tokens to generate.
        steps: Total denoising steps, divided evenly between blocks.
        temperature: Gumbel-noise temperature (``0`` means greedy decoding).
        cfg_scale: Classifier-free guidance scale; ``0`` disables guidance.
        block_length: Tokens per block (clamped to ``gen_length``).
        remasking: ``'low_confidence'`` or ``'random'``.

    Returns:
        ``(visualization_states, final_text)`` where ``visualization_states``
        is a list of frames (lists of ``(token_text, color)`` tuples, the first
        frame fully masked) and ``final_text`` is the decoded response cut at
        the first ``"<|"``.

    Raises:
        NotImplementedError: If ``remasking`` names an unknown strategy.
        ValueError: If ``gen_length`` or ``block_length`` is not positive.
    """
    # UI sliders may deliver floats; sizes and loop bounds must be ints.
    gen_length, steps, block_length = int(gen_length), int(steps), int(block_length)
    if gen_length < 1 or block_length < 1:
        raise ValueError("gen_length and block_length must be positive")
    # Fail before running the model rather than after the first forward pass.
    if remasking not in ('low_confidence', 'random'):
        raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")

    # Prepare the prompt using chat template
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    # For generation
    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    # Visualization frames for the response part, starting fully masked
    visualization_states = [[(MASK_TOKEN, _COLOR_MASKED) for _ in range(gen_length)]]

    # Mark prompt positions so classifier-free guidance can mask them out
    prompt_index = (x != MASK_ID)

    # Ensure block_length is valid
    block_length = min(block_length, gen_length)

    # Number of blocks (ceil division) and steps per block (at least one)
    num_blocks = -(-gen_length // block_length)
    steps_per_block = max(steps // num_blocks, 1)

    # Process each block
    for num_block in range(num_blocks):
        # Start and end indices of the current block
        block_start = prompt_length + num_block * block_length
        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])

        # Get mask indices for the current block
        block_mask_index = (x[:, block_start:block_end] == MASK_ID)

        # Skip if no masks in this block
        if not block_mask_index.any():
            continue

        # Number of tokens to unmask at each step
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # Process each step
        for step in range(steps_per_block):
            # Get all mask positions in the current sequence
            mask_index = (x == MASK_ID)

            # Stop if nothing is left to unmask
            if not mask_index.any():
                break

            # Apply classifier-free guidance if enabled
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = MASK_ID
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # Apply Gumbel noise for sampling
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            # Confidence scores used to decide what stays masked
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            else:  # 'random' (validated above)
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)

            # Don't consider positions beyond the current block
            x0_p[:, block_end:] = -float('inf')

            # Apply predictions where we have masks
            old_x = x.clone()
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -float('inf'))

            # Select tokens to unmask based on confidence
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            is_last_step = step == steps_per_block - 1
            for row in range(confidence.shape[0]):
                if is_last_step:
                    # Last step - unmask everything remaining in the block
                    transfer_index[row, block_start:block_end] = mask_index[row, block_start:block_end]
                    continue
                # Only positions within the current block compete for unmasking
                block_confidence = confidence[row, block_start:block_end]
                k = min(num_transfer_tokens[row, step].item(), block_confidence.numel())
                _, select_indices = torch.topk(block_confidence, k=k)
                # Adjust indices to global positions
                transfer_index[row, select_indices + block_start] = True

            # Apply the selected tokens
            x = torch.where(transfer_index, x0, x)

            # Record a frame for the response part; move the data off the
            # tensors once per step instead of once per token.
            visualization_states.append(_build_visualization_state(
                tokenizer,
                x[0, prompt_length:].tolist(),
                old_x[0, prompt_length:].tolist(),
                x0_p[0, prompt_length:].tolist(),
            ))

    # Extract final text (just the assistant's response)
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(
        response_tokens,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
    ).split("<|")[0]

    return visualization_states, final_text


# Tool handling functions
def _find_tool_call(data):
    """Return the first dict in ``data`` with "name" and "parameters" keys.

    ``data`` may be a single object or a list of objects; ``None`` is returned
    when nothing matches.
    """
    candidates = data if isinstance(data, list) else [data]
    for item in candidates:
        if isinstance(item, dict) and "name" in item and "parameters" in item:
            return item
    return None


def is_tool_call(text):
    """Check if the text looks like a JSON tool call.

    Returns ``True`` only when the stripped text is a JSON object or array
    containing at least one ``{"name": ..., "parameters": ...}`` entry.
    """
    text = text.strip()
    looks_like_json = (
        (text.startswith('[') and text.endswith(']'))
        or (text.startswith('{') and text.endswith('}'))
    )
    if not looks_like_json:
        return False
    return extract_tool_call(text) is not None


def extract_tool_call(text):
    """Extract the first valid tool call from ``text``.

    Returns:
        The tool-call dict, or ``None`` when ``text`` is not valid JSON or
        contains no tool call.
    """
    try:
        data = json.loads(text)
    except (TypeError, ValueError):
        # Not JSON (json.JSONDecodeError is a ValueError) or not a string.
        return None
    return _find_tool_call(data)


def handle_tool_call(tool_call):
    """Process a tool call and return the result.

    Unknown tool names produce an ``{"error": ...}`` dict instead of raising.
    """
    if tool_call["name"] == weather_tool.json_name:
        return weather_tool.actual_function(**tool_call["parameters"])
    return {"error": f"Tool {tool_call['name']} not found"}


# Custom CSS
css = '''
.category-legend{display:none}
button{height: 60px}
.visualization-container {
    margin-top: 20px;
    padding: 10px;
    background-color: #f8f9fa;
    border-radius: 8px;
}
'''


def create_chatbot_demo():
    """Build the Gradio Blocks app: chat, tool-response panel, visualization.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Diffusion Model with Tool Calls Demo")
        gr.Markdown("This demo showcases the LLaDA diffusion model with the [Proximile/LLaDA-8B-Tools-LoRA](https://huggingface.co/Proximile/LLaDA-8B-Tools-LoRA) adapter for enhanced tool calling capabilities.")

        # STATE MANAGEMENT
        chat_history = gr.State([])
        waiting_for_tool_response = gr.State(False)
        current_tool_call = gr.State(None)

        # UI COMPONENTS
        # NOTE(review): container nesting below is inferred -- confirm layout.
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")

                # Tool response input (initially hidden)
                with gr.Group(visible=False) as tool_response_group:
                    gr.Markdown("## Tool Call Detected")
                    tool_name_display = gr.Textbox(label="Tool Name", interactive=False)
                    tool_params_display = gr.JSON(label="Parameters")
                    tool_response_input = gr.Textbox(
                        label="Tool Response (JSON)",
                        placeholder="Enter JSON response for the tool...",
                        lines=5
                    )
                    submit_tool_response = gr.Button("Submit Tool Response")
                    # Add a button for auto-filling dummy response
                    dummy_response_btn = gr.Button("Use Dummy Response")

            with gr.Column(scale=2):
                gr.Markdown("## Diffusion Process Visualization")
                gr.Markdown("Watch tokens appear in real-time as the diffusion process progresses:")
                output_vis = gr.HighlightedText(
                    label="Token Denoising",
                    combine_adjacent=False,
                    show_legend=True,
                    elem_classes="visualization-container"
                )
                gr.Markdown("**Color Legend:**")
                gr.Markdown("- **Dark Gray** [MASK]: Not yet revealed")
                gr.Markdown("- **Light Red**: Newly revealed with low confidence")
                gr.Markdown("- **Orange**: Newly revealed with medium confidence")
                gr.Markdown("- **Light Green**: Newly revealed with high confidence")
                gr.Markdown("- **Light Blue**: Previously revealed tokens")

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=64, step=4,
                    label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8,
                    label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Visualization Delay (seconds)"
                )

        # Current response text box (hidden)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        gr.Markdown("### Try asking about the weather to trigger a tool call!")
        gr.Markdown("Examples: 'What's the weather like in New York?', 'How hot is it in Tokyo right now?'")

        # System prompt for the model
        # NOTE(review): the wording is unchanged, but the paragraph breaks were
        # lost in the collapsed source and are reconstructed here -- confirm
        # them against the prompt format the model was tuned on.
        system_prompt = f"""You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.

If you choose to use one or more of the following tool functions, respond with a list of JSON function calls, each with the proper arguments that best answers the given prompt.

Each tool request within the list should be in the exact format {{"name": function name, "parameters": {{dictionary of argument names and values}}}}. Do not use variables. Just a list of two-key dictionaries, each starting with the function name, followed by a dictionary of parameters.

Here are the tool functions available to you:
{json.dumps([weather_tool.schema], indent=4)}

After receiving the results back from a function call, you have to formulate your response to the user. If the information needed is not found in the returned data, either attempt a new function call, or inform the user that you cannot answer based on your available knowledge. The user cannot see the function results. You have to interpret the data and provide a response based on it.

If the user request does not necessitate a function call, simply respond to the user's query directly."""

        # HELPER FUNCTIONS
        def add_message(history, message, response):
            """Return a copy of ``history`` with ``[message, response]`` appended."""
            history = history.copy()
            history.append([message, response])
            return history

        def build_messages(turns):
            """Convert ``[user, assistant]`` turns into chat-template messages.

            The system prompt always comes first; assistant entries that are
            empty or ``None`` are skipped.
            """
            messages = [{"role": "system", "content": system_prompt}]
            for turn in turns:
                messages.append({"role": "user", "content": turn[0]})
                if turn[1]:  # Only include assistant responses that exist
                    messages.append({"role": "assistant", "content": turn[1]})
            return messages

        def tool_panel_updates(tool_call):
            """Return fresh updates for the tool panel, name box and params box.

            New update objects are created on every call because they are
            handed to Gradio once per yielded frame.
            """
            if tool_call is None:
                return gr.update(visible=False), gr.update(), gr.update()
            return (
                gr.update(visible=True),
                gr.update(value=tool_call["name"]),
                gr.update(value=tool_call["parameters"]),
            )

        def stream_states(history, vis_states, response_text, delay, tool_call=None):
            """Yield one output tuple per visualization frame.

            The first frame is emitted immediately and every later frame after
            ``delay`` seconds. ``history`` feeds both the ``chat_history`` state
            and the chatbot display so the two can never diverge.
            """
            for index, state in enumerate(vis_states):
                if index:
                    time.sleep(delay)
                yield (
                    history, history, state, response_text,
                    tool_call is not None, tool_call,
                    *tool_panel_updates(tool_call),
                )

        def user_message_submitted(message, history, waiting_for_tool):
            """Process a submitted user message.

            Returns ``(history, history_for_display, input_text, vis, response)``.
            Empty messages, and any message sent while a tool response is
            pending, leave the history unchanged.
            """
            # Skip empty messages or if waiting for a tool response
            if not message.strip() or waiting_for_tool:
                return history, history.copy(), "", [], ""

            # Add user message with an empty response; clear the input box
            history = add_message(history, message, None)
            return history, history.copy(), "", [], ""

        def bot_response(history, waiting_for_tool, current_tool, gen_length, steps, delay,
                         temperature, cfg_scale, block_length, remasking):
            """Generate the bot response for the latest message (generator)."""
            if not history or waiting_for_tool:
                # This is a generator, so a returned value would be discarded.
                # Finishing without yielding leaves every output untouched,
                # which keeps a pending tool panel visible.
                return

            try:
                # Conversation so far, followed by the latest user message
                last_user_message = history[-1][0]
                messages = build_messages(history[:-1])
                messages.append({"role": "user", "content": last_user_message})

                # Generate response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )

                # Attach the reply to the last turn in a new list; the stored
                # state is updated through the outputs, not mutated in place.
                history = history[:-1] + [[last_user_message, response_text]]

                # A tool call opens the tool panel; plain text does not
                tool_call = extract_tool_call(response_text) if is_tool_call(response_text) else None
                yield from stream_states(history, vis_states, response_text, delay, tool_call)

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show the error in the visualization without recording it
                error_vis = [(error_msg, "red")]
                yield (history, history, error_vis, error_msg, False, None,
                       *tool_panel_updates(None))

        def process_tool_response(tool_response, history, current_tool, gen_length, steps, delay,
                                  temperature, cfg_scale, block_length, remasking):
            """Feed a tool result to the model and stream its follow-up (generator)."""
            if not history or not current_tool:
                # Generator: finishing without yielding leaves outputs untouched.
                return

            try:
                # Parse the tool response
                response_data = json.loads(tool_response) if isinstance(tool_response, str) else tool_response

                # Whole conversation, followed by the tool result
                messages = build_messages(history)
                messages.append({"role": "ipython", "content": json.dumps({
                    "name": current_tool["name"],
                    "return": response_data
                })})

                # Generate response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking
                )

                # Add a new message pair for the tool-processed response. The
                # result is also written to chat_history via the outputs;
                # otherwise this turn would vanish on the next user message.
                history = add_message(history, "Tool response processed", response_text)
                yield from stream_states(history, vis_states, response_text, delay)

            except Exception as e:
                error_msg = f"Error processing tool response: {str(e)}"
                print(error_msg)

                # Show the error in the visualization without recording it
                error_vis = [(error_msg, "red")]
                yield (history, history, error_vis, error_msg, False, None,
                       *tool_panel_updates(None))

        def generate_dummy_response(current_tool):
            """Return a canned JSON response (as text) for the pending tool call."""
            if not current_tool:
                return ""

            # Process based on tool name
            if current_tool["name"] == weather_tool.json_name:
                params = current_tool.get("parameters")
                if not isinstance(params, dict):
                    # The model may emit a non-object here; use the defaults.
                    params = {}
                unit = params.get("unit", "celsius")
                dummy_data = {
                    "location": params.get("location", "Unknown"),
                    "temperature": 72 if unit == "fahrenheit" else 22,
                    "unit": unit,
                    "condition": "Partly Cloudy",
                    "humidity": 65,
                    "wind_speed": 8,
                    "wind_direction": "NE"
                }
                return json.dumps(dummy_data, indent=2)

            return "{}"

        def clear_conversation():
            """Clear the conversation history and hide the tool panel."""
            return [], [], "", False, None, gr.update(visible=False), gr.update(), gr.update()

        # EVENT HANDLERS
        generation_settings = [
            gen_length, steps, visualization_delay, temperature,
            cfg_scale, block_length, remasking_strategy
        ]
        # Outputs shared by both generators. chat_history comes first so the
        # stored conversation always matches what the chatbot displays.
        response_outputs = [
            chat_history, chatbot_ui, output_vis, current_response,
            waiting_for_tool_response, current_tool_call,
            tool_response_group, tool_name_display, tool_params_display
        ]
        submit_inputs = [user_input, chat_history, waiting_for_tool_response]
        submit_outputs = [chat_history, chatbot_ui, user_input, output_vis, current_response]
        bot_inputs = [chat_history, waiting_for_tool_response, current_tool_call] + generation_settings

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, waiting_for_tool_response,
                     current_tool_call, tool_response_group, tool_name_display, tool_params_display]
        )

        # Dummy response button handler
        dummy_response_btn.click(
            fn=generate_dummy_response,
            inputs=[current_tool_call],
            outputs=[tool_response_input]
        )

        # User message submission flow
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=submit_inputs,
            outputs=submit_outputs
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=submit_inputs,
            outputs=submit_outputs
        )

        # Generate bot response
        msg_submit.then(
            fn=bot_response,
            inputs=bot_inputs,
            outputs=response_outputs
        )

        send_click.then(
            fn=bot_response,
            inputs=bot_inputs,
            outputs=response_outputs
        )

        # Tool response submission
        submit_tool_response.click(
            fn=process_tool_response,
            inputs=[tool_response_input, chat_history, current_tool_call] + generation_settings,
            outputs=response_outputs
        )

    return demo


demo = create_chatbot_demo()
demo.queue().launch(share=True)