// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepcompile.h"

#define USE_C10D_NCCL

namespace dc {

std::shared_ptr<DSParamRegistry> param_registry;
std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets = nullptr;

c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
ncclComm_t nccl_comm;
bool use_symm_mem;
bool clone_custom_op_output;
bool profile = false;
bool pre_div_reduce = true;

bool sync_before_reduce;     // for debugging
bool sync_after_reduce;      // for debugging
bool sync_before_allgather;  // for debugging
bool sync_after_allgather;   // for debugging

std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes)
{
    std::vector<int64_t> result;
    for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); }
    return result;
}

void enable_profiling(bool enable) { profile = enable; }

bool is_profiling() { return profile; }

c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size)
{
    c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
    std::vector<int64_t> sizes = {size};
    std::vector<int64_t> strides = {1};
    at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p(
        {size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt);
    return c10d::symmetric_memory::rendezvous(sym_mem_ws);
}

void lazy_init_symm_memory()
{
    if (use_symm_mem && !symm_mem) {
        // Size the symmetric-memory workspace for the largest registered parameter.
        int64_t max_param_size = 0;
        for (const auto& it : param_registry->getParams()) {
            int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size();
            if (size > max_param_size) { max_param_size = size; }
        }
        symm_mem = getSymmMemWorkspace(max_param_size);
    }
}

ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type)
{
    switch (scalar_type) {
        case at::kFloat: return ncclFloat;
        case at::kHalf: return ncclHalf;
        case at::kDouble: return ncclDouble;
        case at::kBFloat16: return ncclBfloat16;
        case at::kLong: return ncclInt64;
        case at::kInt: return ncclInt;
        case at::kChar: return ncclInt8;
        default: throw std::runtime_error("Unsupported scalar type");
    }
}

void reset()
{
    executors.clear();
    // We keep the buckets for memory estimation
    // reduce_buckets->clear();
}

void cleanup()
{
    reset();

    ncclCommDestroy(nccl_comm);
    process_group = nullptr;
    symm_mem = nullptr;
}

at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    if (sync_before_reduce) { c10::cuda::device_synchronize(); }

    assert(hasKey(executors, graph_id));
    if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); }

    if (sync_after_reduce) { c10::cuda::device_synchronize(); }

    return at::Tensor();
}

at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    // Meta implementation: no-op.
    return at::Tensor();
}

void free_tensors(std::vector<at::Tensor> tensors)
{
    int64_t THRESHOLD = 10 * 1024 * 1024;

    if (!profile) {
        for (auto& tensor : tensors) {
            if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
                // Mark the tensor as in use by the current stream so the caching
                // allocator delays reuse, then release the storage by swapping in
                // an empty tensor.
                tensor.record_stream(at::cuda::getCurrentCUDAStream());
                tensor.set_data(torch::empty({0}, tensor.options()));
            }
        }
    }
}

void free_tensors_meta(std::vector<at::Tensor> tensors) {}

void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
          int64_t initial_reduce_bucket_size,
          bool enable_double_buffer,
          bool _use_symm_mem,
          bool _clone_custom_op_output,
          bool _sync_before_reduce,
          bool _sync_after_reduce,
          bool _sync_before_allgather,
          bool _sync_after_allgather)
{
    process_group = pg;

    ncclUniqueId ncclID;
    ncclGetUniqueId(&ncclID);

    // ProcessGroup doesn't have an API to get the CUDA stream for comm calls.
    // So we create a NCCL communicator and call NCCL APIs directly.
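    // Unique-id exchange, step by step:
    //   1. every rank generates an ncclUniqueId locally;
    //   2. the id bytes are copied into a CUDA uint8 tensor and broadcast through
    //      the existing ProcessGroup (rank 0 is the root by default, so its id wins);
    //   3. each rank copies the broadcast bytes back into ncclID and joins the new
    //      communicator via ncclCommInitRank.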
    auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(&ncclID),
                                    reinterpret_cast<uint8_t*>(&ncclID) + NCCL_UNIQUE_ID_BYTES);
    auto device = torch::Device(torch::kCUDA);
    at::Tensor tensor = torch::from_blob(vec.data(), {static_cast<long>(vec.size())}, torch::kUInt8)
                            .to(torch::Device(torch::kCUDA));
    std::vector<at::Tensor> bcast_input = {tensor};

    process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait();

    // create a new nccl communicator
    std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES);
    ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());

    param_registry = std::make_shared<DSParamRegistry>();
    reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
                                                                  enable_double_buffer);
    use_symm_mem = _use_symm_mem;
    clone_custom_op_output = _clone_custom_op_output;

    sync_before_reduce = _sync_before_reduce;
    sync_after_reduce = _sync_after_reduce;
    sync_before_allgather = _sync_before_allgather;
    sync_after_allgather = _sync_after_allgather;
}

void start_forward()
{
    lazy_init_symm_memory();
    for (auto& it : executors) { it.second->startForward(); }
}

void end_forward()
{
    for (auto& it : executors) { it.second->endForward(); }
}

void start_backward(bool update)
{
    for (auto& it : executors) { it.second->startBackward(update); }
}

// We don't call this
// void end_backward(bool update)
// {
// }

}  // namespace dc
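// Typical call sequence (illustrative, inferred from the entry points above; the
// actual driver lives outside this file):
//   dc::init(pg, bucket_size, double_buffer, ...);  // once, after the process group exists
//   dc::start_forward(); ... dc::end_forward();     // around each forward pass
//   dc::start_backward(update);                     // before backward; compiled graphs then
//                                                   // invoke reduce_grad()/free_tensors()
//   dc::cleanup();                                  // at shutdown, destroys the NCCL comm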