namespace at::internal { | |
template <typename F> | |
inline void invoke_parallel( | |
int64_t begin, | |
int64_t end, | |
int64_t grain_size, | |
const F& f) { | |
std::atomic_flag err_flag = ATOMIC_FLAG_INIT; | |
std::exception_ptr eptr; | |
{ | |
// choose number of tasks based on grain size and number of threads | |
// can't use num_threads clause due to bugs in GOMP's thread pool (See | |
// #32008) | |
int64_t num_threads = omp_get_num_threads(); | |
if (grain_size > 0) { | |
num_threads = std::min(num_threads, divup((end - begin), grain_size)); | |
} | |
int64_t tid = omp_get_thread_num(); | |
int64_t chunk_size = divup((end - begin), num_threads); | |
int64_t begin_tid = begin + tid * chunk_size; | |
if (begin_tid < end) { | |
try { | |
internal::ThreadIdGuard tid_guard(tid); | |
f(begin_tid, std::min(end, chunk_size + begin_tid)); | |
} catch (...) { | |
if (!err_flag.test_and_set()) { | |
eptr = std::current_exception(); | |
} | |
} | |
} | |
} | |
if (eptr) { | |
std::rethrow_exception(eptr); | |
} | |
} | |
} // namespace at::internal | |