import time

import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

torch.set_float32_matmul_precision("high")
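
# Load the model with the paged SDPA attention backend. As I read it,
# "sdpa_paged" pairs torch SDPA with a paged KV cache for continuous
# batching; this assumes a transformers version recent enough to ship
# generate_batch and the paged attention implementations.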
model_id = "meta-llama/Llama-3.2-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
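
# Continuous-batching knobs (my reading of these settings, not authoritative):
# - num_blocks * block_size caps the paged KV cache at 2048 * 128 = 262,144
#   cached tokens shared across all in-flight requests.
# - max_batch_tokens bounds how many tokens one forward pass may process.
# - scheduler="prefill_first" prioritizes prefilling new requests over
#   decoding running ones.
# - use_cache=False disables the standard dynamic cache; the continuous
#   batching path manages its own paged cache.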
generation_config = GenerationConfig(
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=False,
    num_blocks=2048,
    block_size=128,
    do_sample=True,
    max_batch_tokens=1024,
    scheduler="prefill_first",
)
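
# GSM8K's "socratic" config, test split, serves purely as a source of
# varied-length question prompts for the batch.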
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")

print("--- Running CB Generation Example ---")
def tokenize_function(examples):
    return tokenizer(examples["question"])


tokenized_dataset = dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_dataset]
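
# generate_batch schedules all requests through the paged cache itself,
# interleaving prefill and decode steps instead of padding everything
# into one static batch.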
start_time_simple = time.time()
batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)
end_time_simple = time.time()
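
# batch_outputs maps request ids to results; each result exposes the
# original prompt_ids and the generated_tokens for that request.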
for request in batch_outputs:
    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
    try:
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
    except Exception as e:
        # On a decode failure, retry without the first generated token.
        print(f"Decoding failed for request {request}: {e}")
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
    if len(output_text) > 0:
        print("-" * 20)
        print(f"{request} Input: {input_text}")
        print(f"{request} Output: {output_text}")
    else:
        print(f"{request} produced no output")
print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")

print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")