# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear

import torch

try:
    import transformers
except ImportError:
    transformers = None


def init_lora(model):
    """Freeze all base weights and initialize LoRA weights on every LoRAOptimizedLinear module in the model."""
    model.requires_grad_(False)
    for m in model.modules():
        if isinstance(m, LoRAOptimizedLinear):
            m.init_lora()


class Init(object):
    """
    Init context wrapper, similar in style to zero.Init. Allows OptimizedLinear to be injected during model
    construction, which shards base weights and reduces overall memory usage during model init. Primarily
    useful when initializing a model via transformers.AutoModelForCausalLM.

    Example usage:
        lora_config = deepspeed.linear.LoRAConfig(..)
        quant_config = deepspeed.linear.QuantizationConfig(..)
        with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
            model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B")

    """

    def __init__(self, lora_config=None, quant_config=None):
        self._orig_nn_linear = torch.nn.Linear
        self._orig_causallm_pretrained = None
        self._orig_causallm_config = None
        if transformers is not None:
            self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained
            self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self._post_init_complete = False

    def __enter__(self):

        class OptLinearWrapper:
            _orig_nn_linear = self._orig_nn_linear
            _lora_config = self.lora_config
            _quant_config = self.quant_config

            def __new__(cls, *args, **kwargs):
                # Every torch.nn.Linear constructed inside the context becomes an OptimizedLinear
                if cls._lora_config is not None:
                    cls._lora_config.delay_lora_init = True
                kwargs['lora_config'] = cls._lora_config
                kwargs['quantization_config'] = cls._quant_config
                kwargs['linear_cls'] = cls._orig_nn_linear
                return OptimizedLinear(*args, **kwargs)

        def _model_init(model):
            if self.lora_config is not None:
                init_lora(model)
            self._post_init_complete = True
            return model

        # ensures non-lora params are frozen and lora weights are initialized
        def from_pretrained(*args, **kwargs):
            model = self._orig_causallm_pretrained(*args, **kwargs)
            return _model_init(model)

        def from_config(*args, **kwargs):
            model = self._orig_causallm_config(*args, **kwargs)
            return _model_init(model)

        torch.nn.Linear = OptLinearWrapper
        if transformers is not None:
            transformers.AutoModelForCausalLM.from_pretrained = from_pretrained
            transformers.AutoModelForCausalLM.from_config = from_config

    def __exit__(self, *args, **kwargs):
        # Always restore the original constructors, even if model init never completed
        torch.nn.Linear = self._orig_nn_linear
        if transformers is not None:
            transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained
            transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config
        if not self._post_init_complete:
            print('WARNING: LoRA modules were not initialized. This is normally done automatically when the model '
                  'is created via transformers AutoModelForCausalLM.from_pretrained or from_config. '
                  'You must call `init_lora(model)` yourself in order to use DeepSpeed LoRA, otherwise '
                  'you will error out at runtime.')
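

# Illustrative sketch (kept as comments, not executed): using Init without transformers'
# AutoModelForCausalLM helpers. It assumes LoRAConfig is constructible with defaults and that the
# model is built from plain torch.nn.Linear layers inside the context; since no AutoModelForCausalLM
# hook runs in that case, the `init_lora` helper defined above must be called on the model manually,
# as the warning in __exit__ describes.
#
#   import deepspeed
#   import torch
#
#   lora_config = deepspeed.linear.LoRAConfig()
#   with deepspeed.linear.Init(lora_config=lora_config):
#       # torch.nn.Linear is patched here, so these layers are created as OptimizedLinear modules
#       model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024))
#
#   # Freeze the base weights and initialize the LoRA weights by hand.
#   init_lora(model)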