#include "deepcompile.h"

#define USE_C10D_NCCL

namespace dc {
|
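// Shared module-level state for the DeepCompile custom ops: the parameter
// registry, the per-graph op executors, the double-buffered reduce buckets,
// and the communication handles / debug flags configured in init().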
std::shared_ptr<DSParamRegistry> param_registry;
std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets = nullptr;

c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
ncclComm_t nccl_comm;
bool use_symm_mem;
bool clone_custom_op_output;
bool profile = false;
bool pre_div_reduce = true;

bool sync_before_reduce;
bool sync_after_reduce;
bool sync_before_allgather;
bool sync_after_allgather;
|
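// Copy an at::IntArrayRef into an owned std::vector<int64_t>.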
std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes)
{
    std::vector<int64_t> result;
    result.reserve(sizes.size());
    for (const int64_t s : sizes) { result.push_back(s); }
    return result;
}
|
void enable_profiling(bool enable) { profile = enable; }

bool is_profiling() { return profile; }
|
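// Allocate a `size`-byte workspace in CUDA symmetric memory and rendezvous it
// across the process group so every rank can access it.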
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size)
{
    c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
    std::vector<int64_t> sizes = {size};
    std::vector<int64_t> strides = {1};
    at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p(
        sizes, strides, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt);
    return c10d::symmetric_memory::rendezvous(sym_mem_ws);
}
|
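// Lazily create the symmetric-memory workspace, sized to the largest
// registered parameter in bytes. Does nothing if symmetric memory is disabled
// or the workspace already exists.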
void lazy_init_symm_memory()
{
    if (use_symm_mem && !symm_mem) {
        int64_t max_param_size = 0;
        for (const auto& it : param_registry->getParams()) {
            int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size();
            if (size > max_param_size) { max_param_size = size; }
        }
        symm_mem = getSymmMemWorkspace(max_param_size);
    }
}
|
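// Map an ATen scalar type to the corresponding NCCL data type; throws for
// unsupported types.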
ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type)
{
    switch (scalar_type) {
        case at::kFloat: return ncclFloat;
        case at::kHalf: return ncclHalf;
        case at::kDouble: return ncclDouble;
        case at::kBFloat16: return ncclBfloat16;
        case at::kLong: return ncclInt64;
        case at::kInt: return ncclInt;
        case at::kChar: return ncclInt8;
        default: throw std::runtime_error("Unsupported scalar type");
    }
}
|
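// Drop all per-graph executors. Other global state (registry, buckets,
// communicators) is left untouched.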
void reset()
{
    executors.clear();
}
|
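// Release communication resources: clear the executors, destroy the NCCL
// communicator, and drop the process group and symmetric-memory references.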
void cleanup()
{
    reset();

    ncclCommDestroy(nccl_comm);
    process_group = nullptr;
    symm_mem = nullptr;
}
|
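// Custom-op entry point for gradient reduction: hands grad_tensor to the
// executor registered for graph_id. The optional device synchronizations and
// the profiling short-circuit are debugging aids. Always returns an empty
// tensor.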
at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    if (sync_before_reduce) { c10::cuda::device_synchronize(); }

    assert(hasKey(executors, graph_id));
    if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); }

    if (sync_after_reduce) { c10::cuda::device_synchronize(); }

    return at::Tensor();
}
|
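// Meta implementation of reduce_grad: performs no communication and returns an
// empty tensor (used when the op runs on meta/fake tensors, e.g. during
// tracing).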
at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    return at::Tensor();
}
|
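// Eagerly release the storage of large CUDA tensors (more than ~10M elements):
// record each tensor on the current stream so the caching allocator does not
// reuse its memory while kernels are still pending, then swap in an empty
// tensor to free the original storage.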
void free_tensors(std::vector<at::Tensor> tensors)
{
    const int64_t THRESHOLD = 10 * 1024 * 1024;

    if (!profile) {
        for (auto& tensor : tensors) {
            if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
                tensor.record_stream(at::cuda::getCurrentCUDAStream());
                tensor.set_data(torch::empty({0}, tensor.options()));
            }
        }
    }
}
|
void free_tensors_meta(std::vector<at::Tensor> tensors) {}
|
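// One-time setup: store the process group, bootstrap a dedicated NCCL
// communicator, create the parameter registry and reduce buckets, and record
// the symmetric-memory / synchronization flags.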
void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
          int64_t initial_reduce_bucket_size,
          bool enable_double_buffer,
          bool _use_symm_mem,
          bool _clone_custom_op_output,
          bool _sync_before_reduce,
          bool _sync_after_reduce,
          bool _sync_before_allgather,
          bool _sync_after_allgather)
{
    process_group = pg;

    // Broadcast rank 0's NCCL unique ID to all ranks through the process
    // group, then let every rank join the dedicated communicator.
    ncclUniqueId ncclID;
    ncclGetUniqueId(&ncclID);

    auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(&ncclID),
                                    reinterpret_cast<uint8_t*>(&ncclID) + NCCL_UNIQUE_ID_BYTES);
    auto device = torch::Device(torch::kCUDA);
    at::Tensor tensor = torch::from_blob(vec.data(), {static_cast<long>(vec.size())}, torch::kUInt8)
                            .to(device);
    std::vector<at::Tensor> bcast_input = {tensor};

    process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait();

    std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES);
    ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());

    param_registry = std::make_shared<DSParamRegistry>();
    reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
                                                                  enable_double_buffer);
    use_symm_mem = _use_symm_mem;
    clone_custom_op_output = _clone_custom_op_output;

    sync_before_reduce = _sync_before_reduce;
    sync_after_reduce = _sync_after_reduce;
    sync_before_allgather = _sync_before_allgather;
    sync_after_allgather = _sync_after_allgather;
}
|
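// Forward/backward lifecycle hooks: each fans the event out to every graph
// executor. start_forward also lazily initializes the symmetric-memory
// workspace when it is enabled.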
void start_forward()
{
    lazy_init_symm_memory();
    for (auto& it : executors) { it.second->startForward(); }
}
|
void end_forward()
{
    for (auto& it : executors) { it.second->endForward(); }
}
|
void start_backward(bool update)
{
    for (auto& it : executors) { it.second->startBackward(update); }
}

}  // namespace dc