|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib
import functools
import time
from collections.abc import Callable, Generator

from transformers import Trainer, is_wandb_available


if is_wandb_available():
    import wandb
|
|
|
|
|
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.

    The wall-clock duration of the block is measured with `time.perf_counter`. The duration is logged only when
    all of the following hold:

    - the block completes without raising;
    - `"wandb"` is listed in `trainer.args.report_to`;
    - the `wandb` package is available;
    - a W&B run is active;
    - the current process is the main process.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer

    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    duration = time.perf_counter() - start_time

    # The module-level name `wandb` is only bound when `is_wandb_available()` was true at import time, so check
    # availability before dereferencing it; otherwise a `report_to` containing "wandb" would raise NameError here.
    if (
        "wandb" in trainer.args.report_to
        and is_wandb_available()
        and wandb.run is not None
        and trainer.accelerator.is_main_process
    ):
        wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
|
|
|
|
|
def profiling_decorator(func: Callable) -> Callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    The decorated function is expected to be a method. Its first positional argument (`self`, the trainer instance)
    is forwarded to `profiling_context`, and the function's `__name__` is used as the profiling key.

    Args:
        func (`Callable`):
            Function to be profiled.

    Returns:
        `Callable`: A wrapper that times each call of `func`. Its metadata (`__name__`, `__doc__`, `__wrapped__`,
        ...) is copied from `func` via `functools.wraps`.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer

    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # `self` doubles as the trainer handed to the profiler; the return value of `func` is passed through unchanged.
        with profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return wrapper
|
|