// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "z3.h"
#include "deepcompile.h"

#define USE_C10D_NCCL

#include <iostream>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <nccl.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/extension.h>

namespace dc {

const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000;

class Z3CustomOpExecutor : public CustomOpExecutor {
public:
    Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
                       std::shared_ptr<DSParamRegistry> param_registry,
                       std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
                       std::vector<long> ds_ids,
                       ncclComm_t nccl_comm,
                       at::cuda::CUDAStream ag_stream,
                       at::cuda::CUDAStream rs_stream,
                       at::cuda::CUDAStream copy_stream,
                       at::cuda::CUDAStream offload_stream,
                       at::cuda::CUDAStream reload_stream,
                       bool pre_div_reduce)
        : CustomOpExecutor(process_group,
                           param_registry,
                           reduce_buckets,
                           ds_ids,
                           nccl_comm,
                           rs_stream,
                           copy_stream,
                           pre_div_reduce),
          ag_stream_(ag_stream),
          offload_stream_(offload_stream),
          reload_stream_(reload_stream)
    {
        for (long ds_id : ds_ids_) {
            ag_comm_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            ag_comp_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            param_use_count_[ds_id] = 0;
        }
    }
    ~Z3CustomOpExecutor() {}

    void endBackward() override
    {
        if (param_updated_) {
            for (auto& it : has_acc_grad_) {
                it.second = false;
                param_registry_->setValid(it.first, false);
            }
        }

        for (auto& it : reload_buffers_) {
            it.second.record_stream(at::cuda::getCurrentCUDAStream());
        }
        reload_buffers_.clear();
    }

    void launchAllGather(at::Tensor output_buf,
                         long ds_id,
                         c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();

        if (symm_mem == nullptr) {
            ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
                                                output_buf.data_ptr(),
                                                ds_tensor.numel(),
                                                get_nccl_data_type(ds_tensor.scalar_type()),
                                                nccl_comm_,
                                                ag_stream_);

            if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
        } else {
            at::cuda::CUDAStreamGuard guard(ag_stream_);
            int world_size = process_group_->getSize();
            int rank = process_group_->getRank();

            at::Tensor local_buf =
                symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
            local_buf.copy_(ds_tensor, true);

            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
            auto chunks = output_buf.flatten().chunk(world_size);
            for (int step = 0; step < world_size; step++) {
                int remote_rank = (rank - step + world_size) % world_size;
                auto src_buf = symm_mem->get_buffer(
                    remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
                chunks[remote_rank].copy_(src_buf.flatten(), true);
            }
            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
        }

        param_registry_->registerGatheredParam(ds_id, output_buf);
        param_registry_->setValid(ds_id, true);
    }
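
    // Each gather is double-gated by per-parameter events:
    //   * ag_comp_done_events_[ds_id] is recorded on the compute stream and blocked on
    //     ag_stream_, so communication never starts before compute has finished touching
    //     the partitioned tensor.
    //   * ag_comm_done_events_[ds_id] is recorded on ag_stream_ once the gather is issued
    //     and blocked on the compute stream in waitAllgather(), so consumers only run on a
    //     fully gathered parameter.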

    at::Tensor allgatherParam(long ds_id,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }

        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();
        at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
                                    ? param_registry_->getGatheredParam(ds_id)
                                    : torch::empty(param.getShape(), ds_tensor.options());

        assert(hasKey(ag_comp_done_events_, ds_id));
        ag_comp_done_events_[ds_id]->record();
        ag_comp_done_events_[ds_id]->block(ag_stream_);

        launchAllGather(output_buf, ds_id, symm_mem);

        ag_comm_done_events_[ds_id]->record(ag_stream_);
        return output_buf;
    }

    void prefetchParamsFused(std::vector<long> ds_ids,
                             c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        std::vector<long> invalid_ds_ids;
        for (const auto& ds_id : ds_ids) {
            if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
        }

        std::unordered_map<long, at::Tensor> output_bufs;
        for (long ds_id : invalid_ds_ids) {
            const DSParam& param = param_registry_->getParam(ds_id);
            if (param_registry_->hasGatheredParam(ds_id)) {
                output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
            } else {
                output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
            }
        }

        for (long ds_id : invalid_ds_ids) {
            ag_comp_done_events_[ds_id]->record();
            ag_comp_done_events_[ds_id]->block(ag_stream_);
        }

        ncclGroupStart();
        for (long ds_id : invalid_ds_ids) {
            assert(hasKey(output_bufs, ds_id));
            launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
        }
        ncclGroupEnd();

        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
    }

    void releaseParam(long ds_id, long n_users)
    {
        const DSParam& param = param_registry_->getParam(ds_id);

        assert(hasKey(param_use_count_, ds_id));
        if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; }
        param_use_count_[ds_id]--;

        if (param_use_count_[ds_id] == 0 && !param.isPersistent()) {
            at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id);

            if (gathered_param.defined()) {  // gathered param is undefined while profiling
                const auto options = gathered_param.options();
                at::Tensor empty_buffer = torch::empty({0}, options);
                gathered_param.set_data(empty_buffer);
            }

            param_registry_->unregisterGatheredParam(ds_id);
        }
    }

    at::Tensor waitAllgather(at::Tensor v, long ds_id)
    {
        assert(hasKey(ag_comm_done_events_, ds_id));
        ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
        return v;
    }
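
    // Drain all pending reduce tasks of one dtype with a single grouped NCCL call.
    // When a parameter already holds an accumulated gradient (has_acc_grad_), the
    // reduce-scatter result first lands in a temporary buffer (tmp_recv_buf) and is then
    // added to the existing grad buffer; otherwise it is written to the grad buffer
    // directly. With pre_div_reduce_, send buffers are pre-divided by the world size and
    // summed (ncclSum) instead of averaged (ncclAvg).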

    void flushReduceBucket(at::ScalarType scalar_type) override
    {
        if (!hasKey(reduce_tasks_, scalar_type)) { return; }

        int64_t tmp_recv_numel = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(rs_stream_);

            if (has_acc_grad_.at(t.getDSId())) {
                tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
            }
        }

        at::Tensor tmp_recv_buf = at::Tensor();
        if (tmp_recv_numel > 0) {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            tmp_recv_buf = torch::empty({tmp_recv_numel},
                                        at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
        }

        ncclGroupStart();
        int64_t offset = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();

            bool acc_grad = has_acc_grad_.at(t.getDSId());
            if (acc_grad) {
                recv_buf =
                    tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
            }

            ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
            if (pre_div_reduce_) {
                at::cuda::CUDAStreamGuard guard(rs_stream_);
                t.getSendBuf().div_(process_group_->getSize());
            }
            ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
                                                    recv_buf.data_ptr(),
                                                    recv_buf.numel(),
                                                    get_nccl_data_type(scalar_type),
                                                    op,
                                                    nccl_comm_,
                                                    rs_stream_);

            if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }

            if (acc_grad) { offset += recv_buf.numel(); }
        }
        ncclGroupEnd();

        {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            int64_t offset = 0;
            for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
                bool acc_grad = has_acc_grad_.at(t.getDSId());
                if (acc_grad) {
                    auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
                    recv_buf.add_(tmp_recv_buf.index(
                        {torch::indexing::Slice(offset, offset + recv_buf.numel())}));
                    offset += recv_buf.numel();
                }
                has_acc_grad_[t.getDSId()] = true;
            }
        }

        reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);

        // Not very sure if this is necessary
        // Want to prevent grad tensor from being released before the copy is done
        auto comp_stream = at::cuda::getCurrentCUDAStream();
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(comp_stream);
        }
        reduce_tasks_[scalar_type].clear();

        if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); }
    }
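
    // Activation offload/reload follows the same event pattern as the all-gather path:
    // offload_comp_done_events_ keep the D2H copy on offload_stream_ behind the compute
    // stream, offload_events_ gate the later H2D copy on reload_stream_, and
    // reload_events_ are blocked on the compute stream in waitReload() before the reloaded
    // tensor is handed back to the graph. Offload buffers are pinned host memory,
    // allocated once per id and then reused.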

    at::Tensor offloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(offload_events_, id)) {
            offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            offload_comp_done_events_[id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

            const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
            offload_buffers_[id] = at::empty_like(tensor, options);
        }

        offload_comp_done_events_[id]->record();
        offload_comp_done_events_[id]->block(offload_stream_);
        {
            at::cuda::CUDAStreamGuard guard(offload_stream_);
            offload_buffers_.at(id).copy_(tensor, true);
        }
        tensor.record_stream(offload_stream_);

        offload_events_[id]->record(offload_stream_);

        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }

    at::Tensor reloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(reload_events_, id)) {
            reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
        }

        assert(hasKey(offload_buffers_, id));
        offload_events_[id]->block(reload_stream_);

        at::Tensor ten;
        {
            at::cuda::CUDAStreamGuard guard(reload_stream_);
            assert(hasKey(offload_buffers_, id));
            at::Tensor buf = offload_buffers_.at(id);
            const auto options = at::TensorOptions().device(torch::kCUDA);
            ten = at::empty_like(buf, options);
            ten.copy_(buf, true);
            reload_buffers_[id] = ten;
        }

        reload_events_[id]->record(reload_stream_);
        return ten;
    }

    at::Tensor waitOffload(at::Tensor tensor, long id)
    {
        assert(hasKey(offload_events_, id));
        offload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }

    at::Tensor waitReload(at::Tensor tensor, long id)
    {
        assert(hasKey(reload_events_, id));
        reload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(reload_buffers_, id));
        auto ten = reload_buffers_.at(id);

        // We can't release the buffer here because the tensor is still being used.
        // We would need a "freeReloadedTensor" op after the last user of the tensor to call
        // ".record_stream". As that is a bit complicated, we clear the buffers and call
        // record_stream at the end of the backward pass for now.
        reload_buffers_.erase(id);
        return ten;
    }

    void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); }
    void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); }

    bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); }

    bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }

private:
    at::cuda::CUDAStream ag_stream_;
    at::cuda::CUDAStream offload_stream_;
    at::cuda::CUDAStream reload_stream_;

    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comm_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
    std::unordered_map<long, at::Tensor> offload_buffers_;
    std::unordered_map<long, at::Tensor> reload_buffers_;

    std::unordered_map<long, long> param_use_count_;
};

static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);

void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
{
    executors[graph_id] = std::make_shared<Z3CustomOpExecutor>(process_group,
                                                               param_registry,
                                                               reduce_buckets,
                                                               ds_ids,
                                                               nccl_comm,
                                                               ag_stream,
                                                               rs_stream,
                                                               copy_stream,
                                                               offload_stream,
                                                               reload_stream,
                                                               pre_div_reduce);
}

void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent)
{
    param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
    if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
}

at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);

    if (sync_before_allgather) { c10::cuda::device_synchronize(); }
    auto ret = executor->allgatherParam(ds_id, symm_mem);
    if (sync_after_allgather) { c10::cuda::device_synchronize(); }
    return ret;
}

void set_persistent(long ds_id)
{
    param_registry->setPersistent(ds_id, true);

    // Allocate the gathered buffer here.
    // Memory fragmentation is more severe if we allocate it during forward/backward.
    for (auto& it : executors) {
        if (it.second->hasParam(ds_id)) {
            auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
            executor->allgatherParam(ds_id, symm_mem);
        }
    }
}

void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor> params,
                           const std::vector<long>& ds_ids)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->prefetchParamsFused(ds_ids, symm_mem);
}

void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor> params,
                                const std::vector<long>& ds_ids)
{
}

// for profiling
void invalidate_gathered_param(long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    if (param.isPersistent()) { return; }

    param_registry->unregisterGatheredParam(ds_id);
    param_registry->registerGatheredParam(ds_id, at::Tensor());
}

void clear_all_gathered_params()
{
    for (const auto& it : param_registry->getParams()) {
        long ds_id = it.first;
        const DSParam& param = param_registry->getParam(ds_id);
        if (param.isPersistent()) { continue; }
        if (param_registry->hasGatheredParam(ds_id)) {
            param_registry->unregisterGatheredParam(ds_id);
        }
    }
}
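
// The *_meta variants below are the meta-device counterparts of the ops above: they run
// when the graph is traced or profiled with meta tensors, reproduce only output shapes
// and dtypes, and perform no communication, copies, or synchronization.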

at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    auto options = param.getDSTensor().options().device(c10::kMeta);
    at::Tensor output_buf = torch::empty(param.getShape(), options);
    return output_buf;
}

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->releaseParam(ds_id, n_users);

    if (clone_custom_op_output) { return dummy.clone(); }
    return dummy;
}

at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    return dummy;
}

at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->waitAllgather(v, ds_id);
    return v;
}

at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; }

at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->offloadTensor(tensor, id);
}

at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->reloadTensor(tensor, id);
}

at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->waitOffload(tensor, id);
}

at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);

    if (profile && !executor->hasReloadBuffer(id)) { return tensor; }
    return executor->waitReload(tensor, id);
}

at::Tensor test_call(at::Tensor a)
{
    std::cout << "test_call" << std::endl;
    return a;
}

void reload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->reloadParameter(tensor, ds_id);
}

void offload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->offloadParameter(tensor, ds_id);
}

void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {}
void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {}

}  // namespace dc
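
// Illustrative sequence for one non-persistent parameter in a compiled graph (a sketch of
// how the ops above compose; the actual schedule is decided by the compiler passes, not by
// this file):
//
//   p = allgather_param(p_partition, graph_id, ds_id);  // start gather on ag_stream
//   p = wait_allgather(p, graph_id, ds_id);             // compute stream waits for gather
//   /* ... ops that consume the gathered p ... */
//   release_param(dummy, graph_id, ds_id, n_users);     // free gathered buffer after last user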