// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "z1.h"
#include "deepcompile.h"
#define USE_C10D_NCCL
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace dc {
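
// Executor for ZeRO stage 1 graphs: full gradients are all-reduced across
// ranks, and each parameter's registered slice of the result is copied (or
// accumulated) into its gradient buffer.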
class Z1CustomOpExecutor : public CustomOpExecutor {
public:
    Z1CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
                       std::shared_ptr<DSParamRegistry> param_registry,
                       std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
                       std::vector<long> ds_ids,
                       ncclComm_t nccl_comm,
                       at::cuda::CUDAStream rs_stream,
                       at::cuda::CUDAStream copy_stream,
                       bool pre_div_reduce)
        : CustomOpExecutor(process_group,
                           param_registry,
                           reduce_buckets,
                           ds_ids,
                           nccl_comm,
                           rs_stream,
                           copy_stream,
                           pre_div_reduce)
    {
    }
    ~Z1CustomOpExecutor() {}
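
    // At the end of backward: if the parameters have been updated, reset the
    // accumulation flags so the next pass overwrites each gradient buffer
    // instead of adding to it.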
    void endBackward() override
    {
        if (param_updated_) {
            for (auto& it : has_acc_grad_) { it.second = false; }
        }
    }
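
    // Reduce all pending gradients of the given dtype: wait for their copies
    // into the send buffers, all-reduce them in one NCCL group, then write each
    // result back into the corresponding gradient buffer.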
    void flushReduceBucket(at::ScalarType scalar_type) override
    {
        int rank = process_group_->getRank();

        if (!hasKey(reduce_tasks_, scalar_type)) { return; }

        int64_t tmp_recv_numel = 0;
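
        // Make the reduction stream wait until each gradient has finished
        // copying into its send buffer.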
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(rs_stream_);
        }
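
        // Issue one in-place all-reduce per pending gradient, batched into a
        // single NCCL group on the reduction stream.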
        ncclGroupStart();
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
            if (pre_div_reduce_) {
                at::cuda::CUDAStreamGuard guard(rs_stream_);
                t.getSendBuf().div_(process_group_->getSize());
            }
            // In-place all-reduce: the send buffer is also the receive buffer.
            ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(),
                                                t.getSendBuf().data_ptr(),
                                                t.getSendBuf().numel(),
                                                get_nccl_data_type(scalar_type),
                                                op,
                                                nccl_comm_,
                                                rs_stream_);
            if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); }
        }
        ncclGroupEnd();
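
        // Copy (or accumulate) each parameter's slice of the reduced buffer
        // into its gradient buffer on the reduction stream.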
        {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
                bool acc_grad = has_acc_grad_.at(t.getDSId());
                auto param = param_registry_->getParam(t.getDSId());
                auto grad_buf = param.getGradBuffer().flatten();

                if (grad_buf.numel() == 0) { continue; }

                int64_t offset = param.getOffset();
                auto recv_buf = t.getSendBuf().flatten().index(
                    {torch::indexing::Slice(offset, offset + grad_buf.numel())});
                if (acc_grad) {
                    grad_buf.add_(recv_buf);
                } else {
                    grad_buf.copy_(recv_buf);
                }
                has_acc_grad_[t.getDSId()] = true;
            }
        }
        reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);

        // Possibly unnecessary, but blocking the compute stream on the copy
        // events prevents the grad tensors from being released before the copy
        // is done.
        auto comp_stream = at::cuda::getCurrentCUDAStream();
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(comp_stream);
        }

        reduce_tasks_[scalar_type].clear();
    }
};
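
// Side streams for reduction and copy work, taken from the high-priority CUDA
// stream pool and shared by all Z1 executors.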
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
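
// Build and register the executor that handles gradient reduction for the
// parameters (ds_ids) used by the compiled graph.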
void register_graph_z1(long graph_id, const std::vector<long>& ds_ids)
{
    executors[graph_id] = std::make_shared<Z1CustomOpExecutor>(process_group,
                                                               param_registry,
                                                               reduce_buckets,
                                                               ds_ids,
                                                               nccl_comm,
                                                               rs_stream,
                                                               copy_stream,
                                                               pre_div_reduce);
}
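
// Register a parameter with the global registry. `offset` marks where this
// parameter's slice of the reduced gradient starts; flushReduceBucket uses it
// to select the portion that is copied into `grad_buffer`.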
void register_z1_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       int64_t offset)
{
    param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false);
}
} // namespace dc