"""Gradio app that plots ByteLatent entropy-patcher scores for a text prompt."""

import os
import traceback

import gradio as gr
import torch

from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present


def _load_tokenizer_and_patcher(model_name: str):
    """Build the tokenizer and the realtime entropy patcher for ``model_name``.

    Reads ``hf-weights/<model_name>/params.json`` and points the patcher at
    ``hf-weights/<model_name>/entropy_model``.

    Args:
        model_name: Sub-directory of ``hf-weights`` holding the model files.

    Returns:
        A ``(tokenizer, patcher)`` tuple.

    Raises:
        FileNotFoundError: If ``params.json`` or the entropy model directory
            does not exist.
        TypeError: If the configured tokenizer is not a ``BltTokenizer``.
    """
    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")

    if not os.path.exists(train_args_path):
        raise FileNotFoundError(
            f"Training args not found at {train_args_path}. "
            f"Ensure model '{model_name}' is downloaded/available."
        )

    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    tokenizer = train_args.data.tokenizer_args.build()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(tokenizer, BltTokenizer):
        raise TypeError(
            f"Expected a BltTokenizer, got {type(tokenizer).__name__}."
        )

    # Work on a copy so the parsed training args are left untouched.
    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    patcher_args.patching_device = device
    patcher_args.device = device

    print("Loading entropy model and patcher...")
    entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
    if not os.path.exists(entropy_model_dir):
        raise FileNotFoundError(
            f"Entropy model directory not found at {entropy_model_dir}."
        )
    patcher_args.entropy_model_checkpoint_dir = entropy_model_dir

    return tokenizer, patcher_args.build()


def process_text(prompt: str, model_name: str = "blt-1b"):
    """Run the entropy patcher on ``prompt`` and return its entropy plot.

    Args:
        prompt: The input text string from the Gradio interface.
        model_name: The name of the model directory under ``hf-weights``.

    Returns:
        The figure produced by ``plot_entropies`` for the prompt, or ``None``
        when the patcher produced nothing to plot.

    Raises:
        gr.Error: If the model files are missing or processing fails. The
            function feeds a ``gr.Plot`` output, which cannot display a plain
            error string, so failures are raised for Gradio to show in the UI.
    """
    try:
        # NOTE(review): the tokenizer and patcher are rebuilt on every request.
        # Caching them would be faster if the patcher is safe to reuse across
        # calls -- confirm before changing.
        tokenizer, patcher = _load_tokenizer_and_patcher(model_name)

        print(f"Processing prompt: '{prompt}'")
        results = patcher_nocache(
            [prompt],
            tokenizer=tokenizer,
            patcher=patcher,
        )

        if not results:
            print("Processing returned no results.")
            # Not a failure: notify the user and leave the plot empty.
            gr.Warning("Processing completed, but no results were generated.")
            return None

        batch_patch_lengths, batch_scores, batch_tokens = results

        decoded_chars_list = [
            tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens
        ]

        fig = None
        if decoded_chars_list:
            # Only one prompt is submitted, so plot the first batch entry.
            fig = plot_entropies(
                batch_patch_lengths[0],
                batch_scores[0],
                decoded_chars_list[0],
                threshold=patcher.threshold,
            )

        print("Processing and decoding complete.")
        return fig

    except FileNotFoundError as e:
        print(f"Error: {e}")
        raise gr.Error(str(e)) from e
    except Exception as e:
        # Top-level UI boundary: log the full traceback, then surface a
        # readable message in the Gradio UI.
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        raise gr.Error(f"An error occurred during processing: {e}") from e


# --- Gradio UI ---
with gr.Blocks() as iface:
    gr.Markdown("# ByteLatent Entropy Visualizer")
    gr.Markdown(
        "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
        "and visualize the token entropies plot below.\n\n"
        "NOTE: this implementation differs slightly by excluding local attention so we limit "
        "the characters limit to 512 to avoid any deviation.",
        line_breaks=True,
    )

    with gr.Column():
        # NOTE(review): max_length counts characters, while the description
        # above speaks of bytes; multi-byte text can exceed 512 bytes.
        prompt_input = gr.Textbox(
            label="Input Prompt",
            value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            max_length=512,
        )
        submit_button = gr.Button("Generate Plot")
        plot_output = gr.Plot(label="Entropy w Threshold")

    # Clicking the button runs the patcher and renders the returned figure.
    submit_button.click(
        fn=process_text,
        inputs=prompt_input,
        outputs=plot_output,
    )

# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Make sure the weights are present before serving requests.
    ensure_present(["blt-1b"])
    iface.launch()