#pragma once
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
// NOTE [Tracing Mode Switches]
//
// Historically, tracing function was controlled by two switches:
//
// - `AutoDispatchBelowADInplaceOrView` guard
//
// Tracing function used to be script-generated inside `VariableType_*.cpp`
// kernels, sharing the same `Autograd` dispatch key with autograd function.
// Therefore, before tracing function was moved out of VariableType,
// `AutoDispatchBelowADInplaceOrView` guard could also disable tracing as a
// side effect of disabling `Autograd` dispatching.
//
// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
//
// It stores tracing data in a `TracingState` object in TLS. If the
// `TracingState` object in TLS is `null`, then tracing is paused.
//
// The `TracingState` object is created in `tracer::trace()` - the main
// entrance of tracing function. It's temporarily set to `null` inside
// generated VariableType (now TraceType) to bypass tracing for intermediate
// ops (ops being called by other ops). After the intermediate op call
// finishes it's set back to the original `TracingState` object.
//
// The `TracingState` object in TLS can also be read/written via its Python
// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
// which are also exposed as `TORCH_API`.
//
// Two new switches were introduced since tracing function was moved out of
// VariableType:
//
// - `tracer::impl::set_dispatch_enabled()` API
//
// Unlike the special `Autograd` dispatch key which is included in dispatch
// key set by default, `Tracer` dispatch key is off by default. The
// dispatching switch can be toggled via this new API.
//
// - `tracer::impl::NoTracerDispatchMode` guard
//
// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
// after tracing was moved out of VariableType.
//
// Before tracing function was moved out of VariableType, tracing was enabled
// when the following conditions were satisfied:
//
// 1) `TracingState` object in TLS != null;
// - Either inside the execution scope of `tracer::trace()`, or
// - Eagerly called `setTracingState()` with non-null object.
// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
//
// After:
//
// 1) `TracingState` object in TLS != null;
// 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
// 3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
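//
// As a rough sketch (not code from this header), the "After" conditions
// combine as shown below. `tracing_active` is a hypothetical helper name used
// only for illustration; `getTracingState()` is the C++ getter mentioned
// above and `is_dispatch_enabled()` is defined at the bottom of this file:
//
//   bool tracing_active() {
//     // 1) A non-null `TracingState` object is stored in TLS.
//     // 2) + 3) `Tracer` is in the TLS included set and not in the TLS
//     //         excluded set (for example via `NoTracerDispatchMode`).
//     return torch::jit::tracer::getTracingState() != nullptr &&
//         at::tracer::impl::is_dispatch_enabled();
//   }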
//
// [TODOs]
//
// - `setTracingState()` vs. `tracer::impl::set_dispatch_enabled()`
//
// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
// to keep the semantics exactly the same as before - it's confusing to keep
// both switches, though. We should consider simplifying/limiting the exposed
// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
// these two can be unified.
//
// - `AutoDispatchBelowADInplaceOrView` vs.
// `tracer::impl::NoTracerDispatchMode`
//
// We don't need to always set both guards together to keep semantics
// unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
// we don't need to set the new tracer guard:
//
// * Script-generated VariableType kernels. The guard is not necessary as
// tracing is already disabled explicitly by `setTracingState(null)` in
// generated TraceType kernels - we could keep it as is or use the new guard
// instead.
//
// * Custom ops. Will be handled by fallback kernel for `Tracer`.
//
// * Functions that are not likely to be called in tracing context (no python
// binding / not an operator), e.g.: all mobile forward() wrappers, test
// binaries, etc.
//
// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
// It's not necessary as tracing is off by default.
//
// For the rest of the cases we might need to have both:
//
// * Functions that might be reachable from eager mode python (especially
// factory methods), e.g.:
// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
// Without the new guard it will add `aten::empty` to the traced graph.
//
// * Some manually maintained functions, e.g.:
// `torch/csrc/autograd/VariableTypeManual.cpp`.
// Set the new guard if it's not obvious whether `setTracingState(null)`
// has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
// guard.
//
// We might need to tweak the usage of the new guard to optimize/fix things.
// It should only affect the correctness of tracing function, because the
// guard is essentially no-op when the master `setTracingState()` switch is
// off.
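//
// For the cases above that might need both guards, a call site could look
// like the following sketch. `make_tensor_outside_trace` is a hypothetical
// function name; the guard variable names are also illustrative:
//
//   void make_tensor_outside_trace() {
//     // Skip autograd (and ADInplaceOrView) dispatch for this scope.
//     at::AutoDispatchBelowADInplaceOrView below_autograd;
//     // Also exclude the `Tracer` key, so that factory calls made here
//     // (e.g. `aten::empty`) are not added to the traced graph.
//     at::tracer::impl::NoTracerDispatchMode no_tracer;
//     // ... construct the tensor ...
//   }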
// TODO: move this from `at::` to `jit::torch::` after
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
namespace at::tracer::impl {
inline bool is_dispatch_enabled() {
  return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
}

inline void set_dispatch_enabled(bool enabled) {
  TORCH_INTERNAL_ASSERT(
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
      "Cannot enable tracing within the scope of NoTracerDispatchMode!");
  c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
}

struct NoTracerDispatchMode {
  c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
};
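
// Usage sketch for `NoTracerDispatchMode`, kept as a comment so that the
// header's contents are unchanged. The enclosing block is a hypothetical
// caller. The guard is RAII-scoped: the `Tracer` key is excluded while the
// guard is alive, and the previous TLS state is restored when it goes out of
// scope.
//
//   {
//     at::tracer::impl::NoTracerDispatchMode guard;
//     // Ops dispatched here skip the `Tracer` key.
//     // is_dispatch_enabled() returns false inside this scope.
//   }
//   // The TLS excluded set is back to what it was before the guard.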
} // namespace at::tracer::impl