|
#pragma once |
|
|
|
#include <c10/core/InferenceMode.h> |
|
#include <c10/core/impl/LocalDispatchKeySet.h> |
|
#include <c10/util/Exception.h> |
|
#include <c10/util/ThreadLocalDebugInfo.h> |
|
|
|
#include <ATen/FuncTorchTLS.h> |
|
#include <ATen/PythonTorchFunctionTLS.h> |
|
#include <ATen/SavedTensorHooks.h> |
|
#include <ATen/ThreadLocalPythonObjects.h> |
|
#include <ATen/record_function.h> |
|
#include <c10/core/impl/PythonDispatcherTLS.h> |
|
#include <c10/core/impl/TorchDispatchModeTLS.h> |
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
class TORCH_API ThreadLocalState { |
|
public: |
|
|
|
|
|
ThreadLocalState(); |
|
|
|
|
|
|
|
|
|
void set_grad_mode(bool enabled); |
|
|
|
|
|
|
|
|
|
|
|
void set_multithreading_enabled(bool enabled); |
|
|
|
|
|
|
|
static void setThreadLocalState(const ThreadLocalState& state); |
|
|
|
private: |
|
c10::impl::LocalDispatchKeySet dispatch_key_; |
|
|
|
|
|
|
|
std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_; |
|
|
|
|
|
RecordFunctionTLS rf_tls_; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_; |
|
|
|
|
|
AutogradState autograd_tls_; |
|
|
|
|
|
c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; |
|
|
|
|
|
c10::impl::PyInterpreter* python_dispatcher_state_; |
|
|
|
|
|
at::impl::PythonTorchFunctionTLS python_torch_function_state_; |
|
|
|
|
|
at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; |
|
|
|
bool functionalization_reapply_views_state_; |
|
|
|
|
|
at::impl::ThreadLocalPythonObjects saved_objects_; |
|
|
|
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ |
|
!defined(BUILD_LITE_INTERPRETER) |
|
|
|
std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES> |
|
autocast_dtypes_{}; |
|
#endif |
|
|
|
friend class ThreadLocalStateGuard; |
|
}; |
|
|
|
|
|
class TORCH_API ThreadLocalStateGuard { |
|
public: |
|
explicit ThreadLocalStateGuard(const ThreadLocalState& state) |
|
: prev_state_(ThreadLocalState()) { |
|
|
|
ThreadLocalState::setThreadLocalState(state); |
|
} |
|
ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete; |
|
ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete; |
|
ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete; |
|
ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete; |
|
|
|
~ThreadLocalStateGuard() { |
|
|
|
ThreadLocalState::setThreadLocalState(prev_state_); |
|
} |
|
|
|
private: |
|
|
|
const ThreadLocalState prev_state_; |
|
}; |
|
|
|
template <typename T> |
|
auto wrapPropagateTLSState(T callback) { |
|
return [tls_state = ThreadLocalState(), |
|
callback = std::move(callback)](auto&&... args) { |
|
ThreadLocalStateGuard g(tls_state); |
|
|
|
return callback(std::forward<decltype(args)>(args)...); |
|
}; |
|
} |
|
|
|
} |
|
|