File size: 3,494 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy

import torch

from deepspeed.accelerator import get_accelerator
from .passes import zero1_compile, zero3_compile
from .backend import make_backend, launch_compile_passes, init_schedule
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

# NOTE(review): not referenced anywhere in this module's visible code; presumably a
# warm-up step count consumed by importers of this module — confirm before changing.
WARMUP = 5


def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None):
    """Prepare a DeepSpeed engine for ZeRO stage 1 under DeepCompile and return the compile backend.

    Steps performed:
      1. Turn off the optimizer's ``contiguous_gradients`` flag and remove (then clear)
         its registered gradient-accumulation hooks.
      2. Initialize the DeepCompile handle with the engine's data-parallel group,
         its ZeRO reduce bucket size and the reduce-related options of ``compile_config``.
      3. For every bit16 parameter group, build cloned per-parameter gradient buffers
         covering this rank's partition, and register every parameter of the group
         with the handle (owned parameters get a real buffer, others an empty tensor).
      4. Install a pre-backward hook that sets ``optimizer.averaged_gradients`` to a
         shallow copy of those buffers.
      5. Validate (or default) the compile-pass schedule, initialize it, attach
         ``launch_compile_passes`` to the engine and build the backend.

    Args:
        engine: DeepSpeed engine. Its ``optimizer`` must expose the ZeRO-1 partition
            attributes used below (``bit16_groups``, ``params_in_partition``,
            ``first_offset``, ``partition_size``, ``is_param_in_current_partition``,
            ``get_flat_partition``, ``get_param_id``, ``_grad_acc_hooks``, ...).
        backend: compile backend, passed to ``is_backend_inductor`` and ``make_backend``
            (presumably a backend name or callable — confirm against ``make_backend``).
        compile_config: DeepCompile config providing ``double_buffer``,
            ``symmetric_memory``, ``sync_before_reduce``, ``sync_after_reduce`` and
            ``debug_log``.
        compile_kwargs: keyword arguments forwarded to ``make_backend``.
        schedule: optional list of ``(step, passes)`` tuples. When ``None``, defaults to
            running ``zero1_compile.add_z1_reduce`` at step 0.

    Returns:
        Whatever ``make_backend`` returns for ``backend``.

    Raises:
        ValueError: if any schedule entry contains the ZeRO-3 pass
            ``zero3_compile.add_z3_gather_release``.
    """
    optimizer = engine.optimizer
    optimizer.contiguous_gradients = False  # Avoid creating unnecessary buffer
    for hook in optimizer._grad_acc_hooks:
        hook.remove()
    optimizer._grad_acc_hooks.clear()

    dc = get_deepcompile_handle()
    # NOTE(review): the two trailing ``False`` positionals are passed through unchanged;
    # their meaning is defined by ``dc.init``'s signature (not visible here) — confirm.
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, False,
            False)

    grad_buffer = {}

    for i, group in enumerate(optimizer.bit16_groups):
        # One tensor per parameter owned by this rank in group i (return_tensor_list=True).
        partition_views = optimizer.get_flat_partition(optimizer.params_in_partition[i],
                                                       optimizer.first_offset[i],
                                                       optimizer.partition_size[i],
                                                       dtype=optimizer.gradient_accumulation_dtype,
                                                       device=get_accelerator().current_device_name(),
                                                       return_tensor_list=True)
        grad_buffer[i] = [t.clone().detach() for t in partition_views]  # Maybe not necessary

        index_in_partition = 0
        first_in_partition = True
        for p in group:
            param_id = optimizer.get_param_id(p)
            p.param_id = param_id

            if optimizer.is_param_in_current_partition[param_id]:
                buf = grad_buffer[i][index_in_partition]
                # first_offset[i] applies only to the first owned parameter of the
                # group; every later owned parameter is registered with offset 0.
                offset = optimizer.first_offset[i] if first_in_partition else 0
                dc.register_z1_param(param_id, p.shape, p, buf, int(offset))
                index_in_partition += 1
                first_in_partition = False
            else:
                # Not owned by this rank: register with an empty placeholder buffer.
                dc.register_z1_param(param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0)

    def set_grad_buffer():
        # Shallow copy: a fresh dict for the optimizer that shares the same buffer lists.
        optimizer.averaged_gradients = copy.copy(grad_buffer)

    add_pre_backward_hook(set_grad_buffer)

    if schedule is None:
        schedule = [(0, [zero1_compile.add_z1_reduce])]
    else:
        for opt in schedule:
            # Avoid a typical misconfiguration: a ZeRO-3 pass in a ZeRO-1 schedule.
            if zero3_compile.add_z3_gather_release in opt[1]:
                raise ValueError(f"A pass for ZeRO3 (add_z3_gather_release) is specified at step {opt[0]} "
                                 "though ZeRO1 is enabled")

    init_schedule(schedule)

    engine.launch_compile_passes = launch_compile_passes
    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=False,
                        debug_log=compile_config.debug_log)