|
|
|
import functools |
|
import logging |
|
import os |
|
import sys |
|
import tempfile |
|
from typing import Any, Callable, Optional, TypeVar |
|
from typing_extensions import ParamSpec |
|
|
|
import torch |
|
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler |
|
|
|
|
|
_T = TypeVar("_T") |
|
_P = ParamSpec("_P") |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
# Opt-in switch for Strobelight compile-time profiling. The raw environment
# string is only truth-tested, so any non-empty value (even "0") counts as set.
if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
    import shutil

    # The profiler is only turned on when a `strobeclient` executable can be
    # located on PATH; otherwise we just log and carry on.
    if shutil.which("strobeclient"):
        log.info("Strobelight profiler is enabled via environment variable")
        StrobelightCompileTimeProfiler.enable()
    else:
        log.info(
            "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
if torch._running_with_deploy(): |
|
|
|
|
|
|
|
torch_parent = "" |
|
else: |
|
if os.path.basename(os.path.dirname(__file__)) == "shared": |
|
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) |
|
else: |
|
torch_parent = os.path.dirname(os.path.dirname(__file__)) |
|
|
|
|
|
def get_file_path(*path_components: str) -> str:
    """Join ``path_components`` onto the module-level ``torch_parent`` directory.

    Args:
        *path_components: Path segments to append, in order.

    Returns:
        ``torch_parent`` followed by the given segments, combined with
        ``os.path.join``.
    """
    segments = (torch_parent, *path_components)
    return os.path.join(*segments)
|
|
|
|
|
def get_file_path_2(*path_components: str) -> str:
    """Join the given path segments with ``os.path.join``.

    Unlike ``get_file_path``, nothing is prepended to the segments.

    Args:
        *path_components: Path segments to join, in order (at least one).

    Returns:
        The joined path.
    """
    joined = os.path.join(*path_components)
    return joined
|
|
|
|
|
def get_writable_path(path: str) -> str:
    """Return ``path`` if it is writable, else a freshly created temp directory.

    Args:
        path: Candidate location, checked with ``os.access(path, os.W_OK)``.

    Returns:
        ``path`` unchanged when the write check passes; otherwise the path of a
        new directory from ``tempfile.mkdtemp`` whose name ends with the
        basename of ``path``.
    """
    if not os.access(path, os.W_OK):
        # Fall back to a new temp directory, keeping the basename as a suffix.
        return tempfile.mkdtemp(suffix=os.path.basename(path))
    return path
|
|
|
|
|
def prepare_multiprocessing_environment(path: str) -> None:
    """Do nothing; this implementation is a no-op hook.

    Args:
        path: Accepted for interface compatibility and ignored here.
    """
    return None
|
|
|
|
|
def resolve_library_path(path: str) -> str:
    """Return the canonical form of ``path`` as computed by ``os.path.realpath``.

    Args:
        path: Path to resolve.

    Returns:
        The resolved path, with symbolic links eliminated.
    """
    resolved = os.path.realpath(path)
    return resolved
|
|
|
|
|
def throw_abstract_impl_not_imported_error(opname, module, context):
    """Raise ``NotImplementedError`` reporting a missing fake impl for ``opname``.

    This function never returns normally.

    Args:
        opname: Operator name, placed at the start of the message.
        module: Name of the Python module said to provide the fake impl.
        context: Extra text appended to the message when ``module`` has not
            been imported.

    Raises:
        NotImplementedError: Always. If ``module`` is missing from
            ``sys.modules`` the message also suggests importing it.
    """
    module_missing = module not in sys.modules
    message = f"{opname}: We could not find the fake impl for this operator. "
    if module_missing:
        # The module was never imported, so point the user at it.
        message += (
            f"The operator specified that you may need to import the '{module}' "
            f"Python module to load the fake impl. {context}"
        )
    raise NotImplementedError(message)
|
|
|
|
|
|
|
def compile_time_strobelight_meta(
    phase_name: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Build a decorator that optionally profiles the wrapped function.

    Args:
        phase_name: Label passed to
            ``StrobelightCompileTimeProfiler.profile_compile_time``.

    Returns:
        A decorator. The decorated function is called directly while
        ``StrobelightCompileTimeProfiler.enabled`` is falsy and through
        ``profile_compile_time`` otherwise.
    """

    def compile_time_strobelight_meta_inner(
        function: Callable[_P, _T],
    ) -> Callable[_P, _T]:
        @functools.wraps(function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # An integer ``skip`` keyword is bumped by one before forwarding.
            # NOTE(review): presumably this accounts for the extra wrapper
            # frame added by this decorator -- confirm against callers.
            skip = kwargs.get("skip")
            if isinstance(skip, int):
                kwargs["skip"] = skip + 1

            if StrobelightCompileTimeProfiler.enabled:
                return StrobelightCompileTimeProfiler.profile_compile_time(
                    function, phase_name, *args, **kwargs
                )

            # Profiling is off: call the target directly, bypassing the
            # profiler entirely.
            return function(*args, **kwargs)

        return wrapper_function

    return compile_time_strobelight_meta_inner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
    """Emit a signpost through this module's logger at INFO level.

    Args:
        category: First field of the log line.
        name: Second field of the log line.
        parameters: Payload, rendered with ``%r``.
    """
    log.info(
        "%s %s: %r",
        category,
        name,
        parameters,
    )
|
|
|
|
|
def log_compilation_event(metrics):
    """Write ``metrics`` to this module's logger at INFO level.

    Args:
        metrics: Object to log; formatted lazily with ``%s``.
    """
    log.info(
        "%s",
        metrics,
    )
|
|
|
|
|
def upload_graph(graph):
    """Do nothing; this implementation is a no-op.

    Args:
        graph: Accepted and ignored.
    """
    return None
|
|
|
|
|
def set_pytorch_distributed_envs_from_justknobs():
    """Do nothing; this implementation is a no-op and returns ``None``."""
    return None
|
|
|
|
|
def log_export_usage(**kwargs):
    """Do nothing; every keyword argument is accepted and ignored."""
    return None
|
|
|
|
|
def log_trace_structured_event(*args, **kwargs) -> None:
    """Do nothing; all positional and keyword arguments are ignored."""
    return None
|
|
|
|
|
def log_cache_bypass(*args, **kwargs) -> None:
    """Do nothing; all positional and keyword arguments are ignored."""
    return None
|
|
|
|
|
def log_torchscript_usage(api: str, **kwargs):
    """Do nothing; this implementation is a no-op.

    Args:
        api: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """
    return None
|
|
|
|
|
def check_if_torch_exportable():
    """Return ``False`` unconditionally in this implementation."""
    exportable = False
    return exportable
|
|
|
|
|
def export_training_ir_rollout_check() -> bool:
    """Return ``True`` unconditionally in this implementation."""
    rolled_out = True
    return rolled_out
|
|
|
|
|
def log_torch_jit_trace_exportability(
    api: str,
    type_of_export: str,
    export_outcome: str,
    result: str,
):
    """Do nothing; this implementation is a no-op.

    Args:
        api: Accepted and ignored.
        type_of_export: Accepted and ignored.
        export_outcome: Accepted and ignored.
        result: Accepted and ignored.
    """
    return None
|
|
|
|
|
def justknobs_check(name: str, default: bool = True) -> bool:
    """
    Killswitch hook; in OSS it simply returns ``default``.

    In FB prod this function can be used to killswitch functionality: the
    value can be toggled to False in JK without having to do a code push.
    In OSS everything is turned on all the time, because downstream users can
    simply choose not to update PyTorch. (If more fine-grained enable/disable
    is needed, we could potentially keep a map in which ``name`` is looked up
    to toggle behavior. Either way it is all tied to source code in OSS, since
    there is no live server to query.)

    This is the bare minimum functionality needed to do some killswitches.
    A more detailed plan is at
    https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit
    In particular, in some circumstances it may be necessary to read in a knob
    once at process start and then use it consistently for the rest of the
    process. Future functionality will codify these patterns into a better
    high-level API.

    WARNING: Do NOT call this function at module import time. JK is not fork
    safe, and you will break anyone who forks the process and then hits JK
    again.

    Args:
        name: Knob name; not consulted by this implementation.
        default: Value to return.

    Returns:
        ``default``, unchanged.
    """
    knob_value = default
    return knob_value
|
|
|
|
|
def justknobs_getval_int(name: str) -> int:
    """
    Return ``0`` unconditionally; ``name`` is not consulted here.

    Read the warning on justknobs_check.
    """
    knob_value = 0
    return knob_value
|
|
|
|
|
def is_fb_unit_test() -> bool:
    """Return ``False`` unconditionally in this implementation."""
    in_fb_unit_test = False
    return in_fb_unit_test
|
|
|
|
|
@functools.lru_cache(None)
def max_clock_rate():
    """Return a maximum clock rate for the GPU; the result is cached.

    When ``torch.version.hip`` is falsy, the value is the first element of
    ``triton.testing.nvsmi(["clocks.max.sm"])``. Otherwise a hard-coded number
    is chosen from the ``gcnArchName`` of device 0.

    NOTE(review): the unit is presumably MHz -- confirm against callers.
    """
    if torch.version.hip:
        # Keep only the part of the arch name before the first ":".
        arch = str(torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0])

        # Ordered (substring, rate) pairs. The first substring found in the
        # arch name wins, matching the evaluation order of an if/elif chain.
        rates_by_arch = (
            ("gfx94", 1700),
            ("gfx90a", 1700),
            ("gfx908", 1502),
            ("gfx12", 1700),
            ("gfx11", 1700),
            ("gfx103", 1967),
            ("gfx101", 1144),
            ("gfx95", 1700),
        )
        for marker, rate in rates_by_arch:
            if marker in arch:
                return rate
        # No known architecture matched: use the fallback value.
        return 1100

    # Imported lazily so triton is only needed on this code path.
    from triton.testing import nvsmi

    return nvsmi(["clocks.max.sm"])[0]
|
|
|
|
|
def get_mast_job_name_version() -> Optional[tuple[str, int]]:
    """Return ``None`` unconditionally in this implementation.

    The annotation allows a ``(str, int)`` tuple, but this implementation
    never produces one.
    """
    job_info = None
    return job_info
|
|
|
|
|
# NOTE(review): judging by the names, these are the loopback address and port
# used as the "master" endpoint in tests -- confirm against callers.
TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500

# NOTE(review): presumably controls whether a global-dependencies library is
# loaded at import time -- confirm where this flag is read.
USE_GLOBAL_DEPS = True

# NOTE(review): presumably controls whether libtorch is loaded with
# RTLD_GLOBAL -- confirm where this flag is read.
USE_RTLD_GLOBAL_WITH_LIBTORCH = False

# NOTE(review): presumably controls whether custom ops must declare a Python
# module -- confirm where this flag is read.
REQUIRES_SET_PYTHON_MODULE = False
|
|
|
|
|
def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
    """Print a notice and return ``None``; nothing is uploaded here.

    Args:
        profile_path: Accepted and ignored by this implementation.

    Returns:
        Always ``None``.
    """
    notice = "Uploading profile stats (fb-only otherwise no-op)"
    print(notice)
    return None
|
|
|
|
|
def log_chromium_event_internal(
    event: dict[str, Any],
    stack: list[str],
    logger_uuid: str,
    start_time_ns: int,
):
    """Do nothing and return ``None``; all arguments are ignored.

    Args:
        event: Accepted and ignored.
        stack: Accepted and ignored.
        logger_uuid: Accepted and ignored.
        start_time_ns: Accepted and ignored.
    """
    outcome = None
    return outcome
|
|
|
|
|
def record_chromium_event_internal(
    event: dict[str, Any],
):
    """Do nothing and return ``None``.

    Args:
        event: Accepted and ignored.
    """
    outcome = None
    return outcome
|
|