import os
from logging import Logger
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional

import gpustat
import psutil
import psycopg2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache

from benchmarks_entrypoint import MetricsRecorder


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
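

# Background metrics sampler: polls process CPU utilisation, resident memory and GPU
# utilisation/memory roughly every 10 ms and streams each sample to the
# MetricsRecorder until the stop event is set.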
def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
    p = psutil.Process(os.getpid())
    while not continue_metric_collection.is_set():
        with p.oneshot():
            cpu_util = p.cpu_percent()
            mem_megabytes = p.memory_info().rss / (1024 * 1024)
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_util = gpu_stats[0]["utilization.gpu"]
        gpu_mem_megabytes = gpu_stats[0]["memory.used"]
        metrics_recorder.collect_device_measurements(
            benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
        )
        sleep(0.01)
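

# End-to-end Llama benchmark: measures model load time, eager forward passes and
# generation, a compiled per-token decode loop, and `generate` calls backed by a
# StaticCache, recording every timing through the MetricsRecorder.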
def run_benchmark(
    logger: Logger, repository: str, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100
):
    continue_metric_collection = Event()
    metrics_thread = None
    model_id = "meta-llama/Llama-2-7b-hf"
    metrics_recorder = MetricsRecorder(
        psycopg2.connect("dbname=metrics"), logger, repository, branch, commit_id, commit_msg
    )
    try:
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_name = gpu_stats[0]["name"]
        benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
        logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}")
        metrics_thread = Thread(
            target=collect_metrics,
            args=[benchmark_id, continue_metric_collection, metrics_recorder],
        )
        metrics_thread.start()
        logger.info("started background thread to fetch device metrics")
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        device = "cuda"

        logger.info("downloading weights")
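        # The first from_pretrained call only downloads and caches the weights, so the
        # timed load below measures model initialisation rather than network download.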
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
        gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
        logger.info("loading model")
        start = perf_counter()
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, generation_config=gen_config
        ).eval()
        model.to(device)
        torch.cuda.synchronize()
        end = perf_counter()
        model_load_time = end - start
        logger.info(f"loaded model in: {model_load_time}s")

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        prompt = "Why are dogs so cute?"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
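
        # Generation budget: the prompt length plus `num_tokens_to_generate` new tokens.
        # Setting `max_length` on the model's generation config bounds every `generate`
        # call below to the same total length.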
        seq_length = inputs["input_ids"].shape[1]
        model.generation_config.max_length = seq_length + num_tokens_to_generate
        batch_size = inputs["input_ids"].shape[0]
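
        # Sampling helpers in the style of gpt-fast: `multinomial_sample_one_no_sync`
        # samples from the categorical distribution via the exponential-race trick,
        # avoiding the device synchronization that `torch.multinomial` would incur.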
        def multinomial_sample_one_no_sync(probs_sort):
            q = torch.empty_like(probs_sort).exponential_(1)
            return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

        def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            logits = logits / max(temperature, 1e-5)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                pivot = v.select(-1, -1).unsqueeze(-1)
                logits = torch.where(logits < pivot, -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            return probs

        def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            probs = logits_to_probs(logits[:, -1], temperature, top_k)
            idx_next = multinomial_sample_one_no_sync(probs)
            return idx_next, probs
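
        # One decoding step: run the model on the current token(s) against the static
        # KV cache and sample the next token (temperature 0.6, top-k 5). This is the
        # function that gets compiled with `torch.compile` further below.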
        def decode_one_token(model, cur_token, cache_position, past_key_values):
            logits = model(
                cur_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )[0]
            new_token = sample(logits, temperature=0.6, top_k=5)[0]
            return new_token
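
        # Eager (uncompiled) baseline: two forward passes and two full `generate` calls,
        # each with a fresh StaticCache, so that first-call overhead and warmed-up eager
        # latency are both captured.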
        with torch.no_grad():
            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            first_eager_fwd_pass_time = end - start
            logger.info(f"completed first eager fwd pass in: {first_eager_fwd_pass_time}s")
            start = perf_counter()
            output = model.generate(**inputs, do_sample=False)
            end = perf_counter()
            first_eager_generate_time = end - start
            logger.info(f"completed first eager generation in: {first_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
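
            # Second eager pass and generation on a fresh cache, now that everything is
            # warmed up.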
            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            second_eager_fwd_pass_time = end - start
            logger.info(f"completed second eager fwd pass in: {second_eager_fwd_pass_time}s")
            start = perf_counter()
            output = model.generate(**inputs, do_sample=False)
            end = perf_counter()
            second_eager_generate_time = end - start
            logger.info(f"completed second eager generation in: {second_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            torch.compiler.reset()
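
            # Compiled decode loop: compile only the single-token forward
            # (`decode_one_token`) with mode="reduce-overhead" (CUDA graphs) and
            # fullgraph=True instead of compiling `generate` end to end, then time the
            # first, second, third and remaining tokens separately.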
            generated_ids = torch.zeros(
                (batch_size, num_tokens_to_generate + seq_length), dtype=torch.int, device=device
            )
            generated_ids[:, :seq_length] = inputs["input_ids"]
            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate + 10,
            )
            cache_position = torch.arange(seq_length, device=device)
            all_generated_tokens = []
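
            # First token: prefill over the whole prompt plus torch.compile warm-up, so
            # this measurement includes compilation time.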
            start = perf_counter()
            next_token = decode_one_token(
                model, inputs["input_ids"], cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_first_token = end - start
            logger.info(f"completed first compile generation in: {time_to_first_token}s")
            cache_position += 1
            all_generated_tokens += next_token.tolist()
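
            # Second token: the input switches from the full prompt to a single token,
            # which likely triggers a re-compilation for the new shape; the cache
            # position is reset to point just past the prompt.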
            cache_position = torch.tensor([seq_length], device=device)

            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_second_token = end - start
            logger.info(f"completed second compile generation in: {time_to_second_token}s")
            cache_position += 1
            all_generated_tokens += next_token.tolist()
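
            # Third token: same shapes as the second, so this call should hit the
            # already-compiled graph and reflect steady-state per-token latency.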
            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_third_token = end - start
            logger.info(f"completed third compile forward in: {time_to_third_token}s")
            cache_position += 1
            all_generated_tokens += next_token.tolist()
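
            # Remaining tokens: decode in a loop and report the mean per-token latency.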
            start = perf_counter()
            for _ in range(1, num_tokens_to_generate):
                all_generated_tokens += next_token.tolist()
                next_token = decode_one_token(
                    model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
                )
                cache_position += 1
            torch.cuda.synchronize()
            end = perf_counter()
            mean_time_to_next_token = (end - start) / num_tokens_to_generate
            logger.info(f"completed next compile generation in: {mean_time_to_next_token}s")
            logger.info(f"generated: {tokenizer.batch_decode(all_generated_tokens)}")
            torch.compiler.reset()

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            first_compile_generate_time = end - start
            logger.info(f"completed first compile generation in: {first_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            second_compile_generate_time = end - start
            logger.info(f"completed second compile generation in: {second_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            end = perf_counter()
            third_compile_generate_time = end - start
            logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            end = perf_counter()
            fourth_compile_generate_time = end - start
            logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
        metrics_recorder.collect_model_measurements(
            benchmark_id,
            {
                "model_load_time": model_load_time,
                "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
                "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
                "first_eager_generate_time_secs": first_eager_generate_time,
                "second_eager_generate_time_secs": second_eager_generate_time,
                "time_to_first_token_secs": time_to_first_token,
                "time_to_second_token_secs": time_to_second_token,
                "time_to_third_token_secs": time_to_third_token,
                "time_to_next_token_mean_secs": mean_time_to_next_token,
                "first_compile_generate_time_secs": first_compile_generate_time,
                "second_compile_generate_time_secs": second_compile_generate_time,
                "third_compile_generate_time_secs": third_compile_generate_time,
                "fourth_compile_generate_time_secs": fourth_compile_generate_time,
            },
        )
    except Exception as e:
        logger.error(f"Caught exception: {e}")
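
    # Always stop the background metrics thread and close the recorder, whether or not
    # the benchmark itself succeeded.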
    continue_metric_collection.set()
    if metrics_thread is not None:
        metrics_thread.join()
    metrics_recorder.close()