File size: 474 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#pragma once
// Prototypes for the "z1" registration entry points of namespace dc.
// Only declarations are visible in this header; the definitions (and therefore
// the exact side effects of each call) live in another translation unit.
// NOTE(review): "z1" presumably refers to ZeRO stage 1 and "ds" to DeepSpeed —
// confirm against the implementation.
namespace dc {
// Judging by its name, registers a graph for z1 handling. Returns nothing.
//   graph_id - integer identifier of the graph.
//   ds_ids   - list of ids passed by const reference; presumably the ds_id
//              values of the parameters used by that graph — TODO confirm.
void register_graph_z1(long graph_id, const std::vector<long>& ds_ids);
// Judging by its name, registers a single parameter for z1 handling. Returns nothing.
//   ds_id       - integer identifier of the parameter; presumably the same id
//                 space as the ds_ids given to register_graph_z1 — TODO confirm.
//   ds_shape    - list of 64-bit extents passed by const reference; presumably
//                 the parameter's shape — verify against the caller.
//   ds_tensor   - tensor handle taken by value; presumably the parameter's data.
//   grad_buffer - tensor handle taken by value; presumably the buffer that
//                 receives this parameter's gradient — TODO confirm.
//   offset      - 64-bit integer; presumably this parameter's position inside
//                 grad_buffer. NOTE(review): units (elements vs. bytes) cannot
//                 be determined from this header — confirm.
void register_z1_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
int64_t offset);
} // namespace dc
|