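"""Visualize BLT entropy-based patching for a single prompt.

Loads the BLT tokenizer and the entropy-model patcher from a local copy of
the model weights, patches the prompt in realtime, and plots the per-token
entropy scores together with the resulting patch boundaries.
"""
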
import os

import torch
import typer

from bytelatent.data.file_util import get_fs
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    from bytelatent.args import TrainArgs
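    # Read the training config (params.json) stored alongside the consolidated
    # weights and validate it into a TrainArgs object.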
    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
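
    # Build the byte-level BLT tokenizer described by the training config.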
    tokenizer = train_args.data.tokenizer_args.build()
    assert isinstance(tokenizer, BltTokenizer)
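
    # Copy the patcher config and switch it to realtime (on-the-fly) patching.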
    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    # NOTE: CPU is currently unsupported due to the reliance on xformers
    patcher_args.patching_device = "cuda"
    patcher_args.device = "cuda"
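
    # The entropy model checkpoint ships inside the consolidated weights dir.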
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        consolidated_path, "entropy_model"
    )
    patcher = patcher_args.build()
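
    # Patch the prompt; patcher_nocache returns per-batch patch lengths,
    # per-token entropy scores, and the tokenized prompts.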
    prompts = [prompt]
    results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)
    if not results:
        raise ValueError("Patching returned no results for the given prompt")
    batch_patch_lengths, batch_scores, batch_tokens = results
    decoded_chars = [
        tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens
    ]
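
    # Plot the entropy of each byte alongside the patch boundaries that the
    # patcher derived from the entropy threshold.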
    plot_entropies(
        batch_patch_lengths[0],
        batch_scores[0],
        decoded_chars[0],
        threshold=patcher.threshold,
    )


if __name__ == "__main__":
    typer.run(main)
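
# Example invocation (a sketch: assumes this file is saved as demo.py and the
# blt-1b weights have already been downloaded under hf-weights/blt-1b):
#
#   python demo.py "The quick brown fox jumps over the lazy dog"
#   python demo.py --model-name blt-1b "Another prompt to patch"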