```python
import os

import torch
import typer

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    # Set up a (single-process) distributed environment if one is not already initialized.
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    # Load the consolidated BLT checkpoint, its tokenizer, and the training config.
    checkpoint_path = os.path.join("hf-weights", model_name)
    print(f"Loading BLT model: {model_name}")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        checkpoint_path,
    )
    assert isinstance(model, ByteLatentTransformer)
    assert isinstance(tokenizer, BltTokenizer)

    # Build the entropy-model patcher with real-time patching enabled,
    # reusing the patcher settings stored in the training config.
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        checkpoint_path, "entropy_model"
    )
    patcher = patcher_args.build()

    # Generate a completion for the prompt and decode the byte tokens back to text.
    prompts = [prompt]
    outputs = generate_nocache(
        prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


if __name__ == "__main__":
    typer.run(main)
```
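Because the entry point is wrapped with `typer.run`, the function's parameters double as command-line arguments: `prompt` becomes a positional argument and `model_name` an option. Assuming the script is saved as `demo.py` (a hypothetical filename) and the consolidated weights plus entropy model have been downloaded under `hf-weights/blt-1b`, it could be invoked roughly as `python demo.py "The quick brown fox jumped over the" --model-name blt-1b`.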