#
# Exceptions and other classes
#

class ExceededContextLengthException(Exception):
    """Exception raised when an input exceeds a model's context length"""


class _LlamaStopwatch:
    """Track elapsed time for prompt processing and text generation

    Four independent stopwatches are kept: prompt processing (pp), text
    generation (tg), wall time, and a generic one. Each ``start_*`` call
    records the current clock reading; the matching ``stop_*`` call adds the
    time since that start to a running total in nanoseconds. Calling a
    ``stop_*`` method while that stopwatch is not running is a no-op.

    Token counts for pp and tg are tracked separately via
    ``increment_pp_tokens`` / ``increment_tg_tokens`` and are used by
    ``print_stats`` to compute throughput."""

    #
    # Q: why don't you use llama_perf_context?
    #
    # A: comments in llama.h state to only use that in llama.cpp examples,
    #    and to do your own performance measurements instead.
    #
    #    trying to use llama_perf_context leads to output with
    #    "0.00 ms per token" and "inf tokens per second"
    #

    def __init__(self):
        # a fresh stopwatch is exactly a reset one; keep the field list in
        # a single place (reset) so the two can never drift apart
        self.reset()

    @staticmethod
    def _now() -> int:
        """Return the current reading of a monotonic high-resolution clock, in ns.

        ``time.perf_counter_ns`` is used rather than ``time.time_ns``: the
        system wall clock can be adjusted (NTP sync, manual changes) while a
        stopwatch is running, which would corrupt or even negate the measured
        interval. Only differences between readings are meaningful."""
        return time.perf_counter_ns()

    @staticmethod
    def _tokens_per_second(n_tokens: int, elapsed_ns: int) -> float:
        """Return throughput in tokens per second.

        Returns NaN when no time was recorded (``elapsed_ns <= 0``), e.g. when
        tokens were counted but the stopwatch for that phase was never run,
        instead of raising ZeroDivisionError."""
        if elapsed_ns <= 0:
            return float('nan')
        return n_tokens / (elapsed_ns / 1e9)

    def start_pp(self):
        """Start prompt processing stopwatch"""
        self.pp_start_time = self._now()

    def stop_pp(self):
        """Stop prompt processing stopwatch"""
        if self.pp_start_time is not None:
            self.pp_elapsed_time += self._now() - self.pp_start_time
            self.pp_start_time = None

    def start_tg(self):
        """Start text generation stopwatch"""
        self.tg_start_time = self._now()

    def stop_tg(self):
        """Stop text generation stopwatch"""
        if self.tg_start_time is not None:
            self.tg_elapsed_time += self._now() - self.tg_start_time
            self.tg_start_time = None

    def start_wall_time(self):
        """Start wall-time stopwatch"""
        self.wall_start_time = self._now()

    def stop_wall_time(self):
        """Stop wall-time stopwatch"""
        if self.wall_start_time is not None:
            self.wall_elapsed_time += self._now() - self.wall_start_time
            self.wall_start_time = None

    def start_generic(self):
        """Start generic stopwatch (not shown in print_stats)"""
        self.generic_start_time = self._now()

    def stop_generic(self):
        """Stop generic stopwatch"""
        if self.generic_start_time is not None:
            self.generic_elapsed_time += self._now() - self.generic_start_time
            self.generic_start_time = None

    def get_elapsed_time_pp(self) -> int:
        """Total nanoseconds elapsed during prompt processing"""
        return self.pp_elapsed_time

    def get_elapsed_time_tg(self) -> int:
        """Total nanoseconds elapsed during text generation"""
        return self.tg_elapsed_time

    def get_elapsed_wall_time(self) -> int:
        """Total wall-time nanoseconds elapsed"""
        return self.wall_elapsed_time

    def get_elapsed_time_generic(self) -> int:
        """Total generic nanoseconds elapsed"""
        return self.generic_elapsed_time

    def increment_pp_tokens(self, n: int):
        """Add ``n`` to the prompt-processing token count.

        Raises:
            ValueError: if ``n`` is negative."""
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_pp_tokens += n

    def increment_tg_tokens(self, n: int):
        """Add ``n`` to the text-generation token count.

        Raises:
            ValueError: if ``n`` is negative."""
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_tg_tokens += n

    def reset(self):
        """Reset the stopwatch to its original state"""
        # start readings (ns, from _now); None means "not running"
        self.pp_start_time = None
        self.tg_start_time = None
        self.wall_start_time = None
        self.generic_start_time = None
        # accumulated elapsed time per stopwatch, in nanoseconds
        self.pp_elapsed_time = 0
        self.tg_elapsed_time = 0
        self.wall_elapsed_time = 0
        self.generic_elapsed_time = 0
        # token counters used by print_stats for throughput
        self.n_pp_tokens = 0
        self.n_tg_tokens = 0

    def print_stats(self):
        """Print performance statistics using current stopwatch state

        #### NOTE:
        The `n_tg_tokens` value will be equal to the number of calls to
        llama_decode which have a batch size of 1, which is technically not
        always equal to the number of tokens generated - it may be off by one.

        If tokens were counted for a phase whose stopwatch recorded no elapsed
        time, that phase's throughput is shown as ``nan`` instead of raising
        ZeroDivisionError."""
        print(file=sys.stderr, flush=True)
        if self.n_pp_tokens + self.n_tg_tokens == 0:
            print_stopwatch('print_stats was called but no tokens were processed or generated')
        if self.n_pp_tokens > 0:
            pp_elapsed_ns = self.get_elapsed_time_pp()
            pp_elapsed_ms = pp_elapsed_ns / 1e6
            pp_tps = self._tokens_per_second(self.n_pp_tokens, pp_elapsed_ns)
            print_stopwatch(
                f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
                f'({pp_tps:>10.2f} tok/s)'
            )
        if self.n_tg_tokens > 0:
            tg_elapsed_ns = self.get_elapsed_time_tg()
            tg_elapsed_ms = tg_elapsed_ns / 1e6
            tg_tps = self._tokens_per_second(self.n_tg_tokens, tg_elapsed_ns)
            # leading spaces right-align the label with 'prompt processing:'
            print_stopwatch(
                f'  text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
                f'({tg_tps:>10.2f} tok/s)'
            )
        wall_elapsed_ms = self.get_elapsed_wall_time() / 1e6
        # 8 leading spaces + 19 padding put the ms column under the lines above
        print_stopwatch(f"        wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms")