#include "z3.h" |
|
#include "deepcompile.h" |
|
|
|
#define USE_C10D_NCCL |
|
|
|
#include <ATen/cuda/CUDAEvent.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
#include <c10/cuda/CUDAStream.h> |
|
#include <torch/csrc/cuda/nccl.h> |
|
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
|
|
|
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp> |
|
|
|
namespace dc { |
|
|
|
const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000; |
|
|
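
// Executor for the ZeRO-3 (Z3) custom ops compiled into a graph: it launches
// parameter allgathers, releases gathered parameters after their last use,
// reduce-scatters gradients, and offloads/reloads tensors on dedicated CUDA
// streams so communication and copies can overlap with compute.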
class Z3CustomOpExecutor : public CustomOpExecutor {
public:
    Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
                       std::shared_ptr<DSParamRegistry> param_registry,
                       std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
                       std::vector<long> ds_ids,
                       ncclComm_t nccl_comm,
                       at::cuda::CUDAStream ag_stream,
                       at::cuda::CUDAStream rs_stream,
                       at::cuda::CUDAStream copy_stream,
                       at::cuda::CUDAStream offload_stream,
                       at::cuda::CUDAStream reload_stream,
                       bool pre_div_reduce)
        : CustomOpExecutor(process_group,
                           param_registry,
                           reduce_buckets,
                           ds_ids,
                           nccl_comm,
                           rs_stream,
                           copy_stream,
                           pre_div_reduce),
          ag_stream_(ag_stream),
          offload_stream_(offload_stream),
          reload_stream_(reload_stream)
    {
        for (long ds_id : ds_ids_) {
            ag_comm_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            ag_comp_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

            param_use_count_[ds_id] = 0;
        }
    }
    ~Z3CustomOpExecutor() {}
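
    // End of the backward pass: if the parameters were updated this step,
    // reset the per-parameter gradient-accumulation flags and invalidate the
    // gathered copies; then record the current stream on each reloaded buffer
    // and drop the references so the allocator can reclaim them.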
    void endBackward() override
    {
        if (param_updated_) {
            for (auto& it : has_acc_grad_) {
                it.second = false;
                param_registry_->setValid(it.first, false);
            }
        }

        for (auto& it : reload_buffers_) {
            it.second.record_stream(at::cuda::getCurrentCUDAStream());
        }
        reload_buffers_.clear();
    }
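
    // Launch an allgather of ds_id's partitioned tensor into output_buf on the
    // allgather stream. Uses ncclAllGather by default; if a SymmetricMemory
    // handle is provided, each rank copies its shard into its symmetric buffer
    // and pulls the remote shards directly, with barriers around the exchange.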
    void launchAllGather(at::Tensor output_buf,
                         long ds_id,
                         c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();

        if (symm_mem == nullptr) {
            ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
                                                output_buf.data_ptr(),
                                                ds_tensor.numel(),
                                                get_nccl_data_type(ds_tensor.scalar_type()),
                                                nccl_comm_,
                                                ag_stream_);

            if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
        } else {
            at::cuda::CUDAStreamGuard guard(ag_stream_);
            int world_size = process_group_->getSize();
            int rank = process_group_->getRank();

            at::Tensor local_buf =
                symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
            local_buf.copy_(ds_tensor, true);

            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
            auto chunks = output_buf.flatten().chunk(world_size);
            for (int step = 0; step < world_size; step++) {
                int remote_rank = (rank - step + world_size) % world_size;
                auto src_buf = symm_mem->get_buffer(
                    remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
                chunks[remote_rank].copy_(src_buf.flatten(), true);
            }
            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
        }

        param_registry_->registerGatheredParam(ds_id, output_buf);
        param_registry_->setValid(ds_id, true);
    }
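
    // Gather the full parameter for ds_id, reusing a previously registered
    // output buffer when one exists. An event recorded on the compute stream
    // and blocked on the allgather stream ensures the gather starts only after
    // preceding work on the compute stream has finished; a second event marks
    // completion for waitAllgather.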
    at::Tensor allgatherParam(long ds_id,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }

        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();
        at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
                                    ? param_registry_->getGatheredParam(ds_id)
                                    : torch::empty(param.getShape(), ds_tensor.options());

        assert(hasKey(ag_comp_done_events_, ds_id));
        ag_comp_done_events_[ds_id]->record();
        ag_comp_done_events_[ds_id]->block(ag_stream_);

        launchAllGather(output_buf, ds_id, symm_mem);

        ag_comm_done_events_[ds_id]->record(ag_stream_);
        return output_buf;
    }
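
    // Prefetch a batch of parameters: gather only those whose gathered copy is
    // not currently valid, and batch the per-parameter allgathers inside a
    // single ncclGroupStart/ncclGroupEnd so they are issued as one fused
    // communication launch.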
    void prefetchParamsFused(std::vector<int64_t> ds_ids,
                             c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        std::vector<int64_t> invalid_ds_ids;
        for (const auto& ds_id : ds_ids) {
            if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
        }

        std::unordered_map<long, at::Tensor> output_bufs;
        for (long ds_id : invalid_ds_ids) {
            const DSParam& param = param_registry_->getParam(ds_id);
            if (param_registry_->hasGatheredParam(ds_id)) {
                output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
            } else {
                output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
            }
        }

        for (long ds_id : invalid_ds_ids) {
            ag_comp_done_events_[ds_id]->record();
            ag_comp_done_events_[ds_id]->block(ag_stream_);
        }

        ncclGroupStart();
        for (long ds_id : invalid_ds_ids) {
            assert(hasKey(output_bufs, ds_id));
            launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
        }
        ncclGroupEnd();

        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
    }
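
    // Decrement the remaining-use count for ds_id (initializing it to n_users
    // on the first release of the pass). Once the count reaches zero and the
    // parameter is not persistent, drop the gathered buffer by pointing the
    // tensor at an empty placeholder and unregistering it.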
    void releaseParam(long ds_id, long n_users)
    {
        const DSParam& param = param_registry_->getParam(ds_id);

        assert(hasKey(param_use_count_, ds_id));
        if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; }
        param_use_count_[ds_id]--;

        if (param_use_count_[ds_id] == 0 && !param.isPersistent()) {
            at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id);

            if (gathered_param.defined()) {
                const auto options = gathered_param.options();
                at::Tensor empty_buffer = torch::empty({0}, options);
                gathered_param.set_data(empty_buffer);
            }

            param_registry_->unregisterGatheredParam(ds_id);
        }
    }

    at::Tensor waitAllgather(at::Tensor v, long ds_id)
    {
        assert(hasKey(ag_comm_done_events_, ds_id));
        ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
        return v;
    }
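
    // Flush the reduce bucket for the given dtype: wait on the reduce-scatter
    // stream for the gradient copies into the bucket to finish, reduce-scatter
    // every pending task, and, for parameters that already accumulated a
    // gradient, receive into a temporary buffer and add it into the existing
    // gradient buffer.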
    void flushReduceBucket(at::ScalarType scalar_type) override
    {
        if (!hasKey(reduce_tasks_, scalar_type)) { return; }

        int64_t tmp_recv_numel = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(rs_stream_);

            if (has_acc_grad_.at(t.getDSId())) {
                tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
            }
        }

        at::Tensor tmp_recv_buf = at::Tensor();
        if (tmp_recv_numel > 0) {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            tmp_recv_buf = torch::empty({tmp_recv_numel},
                                        at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
        }

        ncclGroupStart();
        int64_t offset = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();

            bool acc_grad = has_acc_grad_.at(t.getDSId());

            if (acc_grad) {
                recv_buf =
                    tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
            }

            ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
            if (pre_div_reduce_) {
                at::cuda::CUDAStreamGuard guard(rs_stream_);
                t.getSendBuf().div_(process_group_->getSize());
            }
            ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
                                                    recv_buf.data_ptr(),
                                                    recv_buf.numel(),
                                                    get_nccl_data_type(scalar_type),
                                                    op,
                                                    nccl_comm_,
                                                    rs_stream_);
            if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }

            if (acc_grad) { offset += recv_buf.numel(); }
        }
        ncclGroupEnd();

        {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            int64_t offset = 0;
            for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
                bool acc_grad = has_acc_grad_.at(t.getDSId());

                if (acc_grad) {
                    auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
                    recv_buf.add_(tmp_recv_buf.index(
                        {torch::indexing::Slice(offset, offset + recv_buf.numel())}));
                    offset += recv_buf.numel();
                }
                has_acc_grad_[t.getDSId()] = true;
            }
        }

        reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);
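
        // Make the compute stream wait on each copy-done event so the gradient
        // tensors that fed this bucket are not freed or reused while the copies
        // and reduce-scatters are still in flight.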
        auto comp_stream = at::cuda::getCurrentCUDAStream();
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(comp_stream);
        }
        reduce_tasks_[scalar_type].clear();

        if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); }
    }
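
    // Copy a GPU tensor into a pinned host buffer on the offload stream,
    // creating the buffer and its events on first use for this id. Returns the
    // host buffer; callers synchronize through waitOffload.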
    at::Tensor offloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(offload_events_, id)) {
            offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            offload_comp_done_events_[id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

            const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
            offload_buffers_[id] = at::empty_like(tensor, options);
        }

        offload_comp_done_events_[id]->record();
        offload_comp_done_events_[id]->block(offload_stream_);
        {
            at::cuda::CUDAStreamGuard guard(offload_stream_);
            offload_buffers_.at(id).copy_(tensor, true);
        }

        tensor.record_stream(offload_stream_);

        offload_events_[id]->record(offload_stream_);
        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }
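
    // Copy a previously offloaded host buffer back into a freshly allocated
    // GPU tensor on the reload stream, after the offload for the same id has
    // completed. The result is kept in reload_buffers_ until endBackward.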
    at::Tensor reloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(reload_events_, id)) {
            reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
        }

        assert(hasKey(offload_buffers_, id));
        offload_events_[id]->block(reload_stream_);

        at::Tensor ten;
        {
            at::cuda::CUDAStreamGuard guard(reload_stream_);

            assert(hasKey(offload_buffers_, id));
            at::Tensor buf = offload_buffers_.at(id);
            const auto options = at::TensorOptions().device(torch::kCUDA);
            ten = at::empty_like(buf, options);
            ten.copy_(buf, true);

            reload_buffers_[id] = ten;
        }

        reload_events_[id]->record(reload_stream_);
        return ten;
    }

    at::Tensor waitOffload(at::Tensor tensor, long id)
    {
        assert(hasKey(offload_events_, id));
        offload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }

    at::Tensor waitReload(at::Tensor tensor, long id)
    {
        assert(hasKey(reload_events_, id));
        reload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(reload_buffers_, id));
        auto ten = reload_buffers_.at(id);
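
        // The entry is intentionally not erased here: later users may still
        // read the reloaded tensor. It is released in endBackward(), which
        // records the current stream on each reload buffer before clearing
        // the map.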
        return ten;
    }

    void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); }
    void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); }

    bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); }

    bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }

private:
    at::cuda::CUDAStream ag_stream_;
    at::cuda::CUDAStream offload_stream_;
    at::cuda::CUDAStream reload_stream_;

    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comm_done_events_;

    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
    std::unordered_map<long, at::Tensor> offload_buffers_;
    std::unordered_map<long, at::Tensor> reload_buffers_;

    std::unordered_map<long, long> param_use_count_;
};
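
// Dedicated side streams (taken from the high-priority pool) so allgather,
// reduce-scatter, bucket copies, and offload/reload traffic can overlap with
// compute on the default stream.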
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
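
// Create the Z3 executor for a compiled graph and register it under graph_id.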
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
{
    executors[graph_id] = std::make_shared<Z3CustomOpExecutor>(process_group,
                                                               param_registry,
                                                               reduce_buckets,
                                                               ds_ids,
                                                               nccl_comm,
                                                               ag_stream,
                                                               rs_stream,
                                                               copy_stream,
                                                               offload_stream,
                                                               reload_stream,
                                                               pre_div_reduce);
}
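
// Register a ZeRO-3 partitioned parameter with the registry. A persistent
// parameter registers its ds_tensor as the gathered copy up front, so it is
// neither re-gathered nor released.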
void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent)
{
    param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
    if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
}
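
// Custom-op entry point: gather the parameter for ds_id through the executor
// registered for graph_id, with optional device-wide synchronizations gated by
// sync_before_allgather / sync_after_allgather.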
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);

    if (sync_before_allgather) { c10::cuda::device_synchronize(); }
    auto ret = executor->allgatherParam(ds_id, symm_mem);
    if (sync_after_allgather) { c10::cuda::device_synchronize(); }
    return ret;
}

void set_persistent(long ds_id)
{
    param_registry->setPersistent(ds_id, true);
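
    // Gather the parameter right away on every executor that knows about it,
    // since a persistent parameter is expected to stay gathered from now on.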
    for (auto& it : executors) {
        if (it.second->hasParam(ds_id)) {
            auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
            executor->allgatherParam(ds_id, symm_mem);
        }
    }
}

void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor> params,
                           const std::vector<long>& ds_ids)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->prefetchParamsFused(ds_ids, symm_mem);
}

void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor> params,
                                const std::vector<long>& ds_ids)
{
}
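
// Helpers used between steps: invalidate_gathered_param drops the gathered
// copy of a single non-persistent parameter (leaving an empty placeholder
// registered), while clear_all_gathered_params unregisters every gathered,
// non-persistent parameter.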
void invalidate_gathered_param(long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    if (param.isPersistent()) { return; }

    param_registry->unregisterGatheredParam(ds_id);
    param_registry->registerGatheredParam(ds_id, at::Tensor());
}

void clear_all_gathered_params()
{
    for (const auto& it : param_registry->getParams()) {
        long ds_id = it.first;
        const DSParam& param = param_registry->getParam(ds_id);
        if (param.isPersistent()) { continue; }
        if (param_registry->hasGatheredParam(ds_id)) {
            param_registry->unregisterGatheredParam(ds_id);
        }
    }
}

at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    auto options = param.getDSTensor().options().device(c10::kMeta);
    at::Tensor output_buf = torch::empty(param.getShape(), options);
    return output_buf;
}

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->releaseParam(ds_id, n_users);

    if (clone_custom_op_output) { return dummy.clone(); }
    return dummy;
}

at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    return dummy;
}

at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->waitAllgather(v, ds_id);
    return v;
}

at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; }

at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->offloadTensor(tensor, id);
}

at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->reloadTensor(tensor, id);
}

at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->waitOffload(tensor, id);
}
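
// In profiling mode the reload may not actually have been launched, so fall
// back to the input tensor when no reload buffer exists for this id.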
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    if (profile && !executor->hasReloadBuffer(id)) { return tensor; }
    return executor->waitReload(tensor, id);
}

at::Tensor test_call(at::Tensor a)
{
    std::cout << "test_call" << std::endl;
    return a;
}

void reload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->reloadParameter(tensor, ds_id);
}

void offload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->offloadParameter(tensor, ds_id);
}

void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {}
void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {}

}  // namespace dc