File size: 31,397 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.runtime.zero.utils import is_zero_param
from abc import ABC, abstractmethod
from typing import Iterable, Any, Optional, List, Tuple
from .fusedqkv_utils import shard_value_with_share_qk, shard_chunk_mlp, prepare_tp_fused_qkvw
from deepspeed.runtime.tensor_parallel import AUTOTP_MODE
from copy import deepcopy
from typing import Union

# Public API exported by `from ... import *`.
__all__ = [
    "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
    "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
]

# Module-wide AutoTP mode; starts in inference mode and is switched by set_autotp_mode().
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
# Attribute name that TensorParallel_Layer.config_tp_params sets to True on the parameters it configures;
# GatherReplacedLayerParams looks for it to decide whether to gather/re-partition.
DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
# Attribute name that TensorParallel_Layer.config_tp_params sets to True on tensor-parallel parameters.
DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'


def get_auto_tp_mode():
    """Return the module-wide AutoTP mode (an ``AUTOTP_MODE`` member)."""
    current_mode = DEEPSPEED_AUTOTP_MODE
    return current_mode


def is_autotp_training_mode():
    """Report whether the module-wide AutoTP mode is ``AUTOTP_MODE.TRAINING``."""
    training_mode = AUTOTP_MODE.TRAINING
    return DEEPSPEED_AUTOTP_MODE == training_mode


def set_autotp_mode(training=False):
    """
    Switch the module-wide AutoTP mode.

    Args:
        training (bool): Truthy selects ``AUTOTP_MODE.TRAINING``; falsy selects ``AUTOTP_MODE.INFERENCE``.
    """
    global DEEPSPEED_AUTOTP_MODE
    DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.TRAINING if training else AUTOTP_MODE.INFERENCE


def add_bias(input, bias):
    """
    Add ``bias`` to ``input`` and return the result.

    A ``None`` bias returns ``input`` unchanged. In training mode the addition is
    out of place (a new tensor is returned); otherwise ``input`` is updated in place
    and returned.
    """
    if bias is None:
        return input
    if not is_autotp_training_mode():
        input += bias
        return input
    # Training mode: avoid in-place ops so autograd records a correct graph.
    return input + bias


class RowParallel(torch.autograd.Function):
    """
    A custom autograd function for performing row-wise parallelism.

    Forward all-reduces the per-rank partial results across ``group`` (skipped when
    ``group`` is ``None``); backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def symbolic(graph, input):
        """Symbolic function for tracing."""
        return input

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, is_inference_mode: bool) -> torch.Tensor:
        """
        Forward pass.

        Args:
            group: Tensor-parallel process group; ``None`` disables communication.
            input: This rank's partial result.
            is_inference_mode: Use ``dist.inference_all_reduce`` when True, ``dist.all_reduce`` otherwise.

        Returns:
            The reduced tensor.
        """
        ctx.group = group
        if group is None:
            return input
        if is_inference_mode:
            dist.inference_all_reduce(input, group=group)
            return input
        # all_reduce works in place, so reduce exactly the tensor that is returned. Calling
        # .contiguous() inline would discard the reduced result whenever it had to copy.
        input = input.contiguous()
        dist.all_reduce(input, group=group)
        return input

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None]:
        """
        Backward pass: the gradient flows through unchanged; ``group`` and
        ``is_inference_mode`` receive no gradient.
        """
        return None, grad_output, None


class AsyncColumnParallel(torch.autograd.Function):
    """
    Column-parallel linear op (``input @ weight.T (+ bias)``) whose backward overlaps the
    asynchronous all-reduce of the input gradient with the weight/bias gradient computation.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
        """
        Forward pass.

        Args:
            group: Tensor-parallel process group used in backward; ``None`` skips communication.
            input: Activations whose last dimension is contracted with ``weight``.
            weight: Local weight shard, used as ``(out, in)`` (it is transposed for the matmul).
            bias: Optional local bias shard, or ``None``.
        """
        ctx.use_bias = bias is not None
        ctx.group = group
        output = torch.matmul(input, weight.transpose(-1, -2))
        if bias is not None:
            output = add_bias(output, bias)

        ctx.save_for_backward(input, weight)

        return output

    @staticmethod
    def backward(ctx: Any,
                 grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Backward pass: returns gradients for ``(group, input, weight, bias)``.
        """
        input, weight = ctx.saved_tensors
        # all_reduce is in place; make sure the reduced tensor is the one we return.
        grad_input = grad_output.matmul(weight).contiguous()
        handle = None
        if ctx.group is not None:
            # Launch the input-gradient reduction and overlap it with the local work below.
            handle = dist.all_reduce(grad_input, group=ctx.group, async_op=True)
        # reshape (not view) tolerates non-contiguous gradients/activations.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        grad_weight = grad_output_2d.t().matmul(input.reshape(-1, input.shape[-1]))
        # Sum over every leading (batch/sequence) dimension so the result matches the bias shape.
        grad_bias = grad_output_2d.sum(0) if ctx.use_bias else None
        if handle is not None:
            handle.wait()
        return None, grad_input, grad_weight, grad_bias


class ColumnParallel(torch.autograd.Function):
    """
    Custom autograd function for column-wise parallelism.

    Identity in forward; in backward the incoming gradient is all-reduced across
    ``group`` (skipped when ``group`` is ``None``).
    """

    @staticmethod
    def symbolic(graph, input):
        """Symbolic function for tracing."""
        return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group())

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: remember ``group`` for backward and return ``input`` unchanged.
        """
        ctx.group = group
        return input

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
        """
        Backward pass: all-reduce the gradient over ``ctx.group``; ``group`` gets no gradient.
        """
        if ctx.group is None:
            return None, grad_output

        # all_reduce is in place: reduce the very tensor we return, otherwise a copy made by
        # .contiguous() for a non-contiguous gradient would be reduced and then discarded.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, group=ctx.group)
        return None, grad_output


class TensorParallel_Layer(nn.Module, ABC):
    """
    A base class for model layers with tensor parallelism support.
    This class is designed to be extended by specific layers that require distributed
    operations and parameter gather/partitioning during inference or training.

    Attributes:
        mp_group (dist.ProcessGroup): The process group used for model parallelism
            (only set when a group is passed to ``__init__``).
        tp_world_size (int): Number of ranks in ``mp_group`` (only set together with ``mp_group``).
        tp_index (int): This rank's index within ``mp_group`` (only set together with ``mp_group``).
        world_size, rank: Backward-compatible aliases of ``tp_world_size`` / ``tp_index``.
        support_training (bool): Whether the layer implements a training (backward) path (default: False).
        name (Optional[str]): The layer name, if provided; subclasses pass it to the shard-size helpers.

    Inference vs. training is a module-wide setting; see ``is_training_mode``.
    """
    ##### Initialize Parameter List #####

    # keep_module_on_host determines whether to keep the module on the host.
    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
    # so an additional copy is unnecessary.
    keep_module_on_host: bool = False

    ##### Runtime Parameter List #####
    # Whether to overlap communication with computation. Currently, only allreduce supports overlap.
    tp_overlap_comm: bool = False

    def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
        """
        Initializes the TensorParallel_Layer with optional model parallelism group and layer name.

        Args:
            mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism.
                                                    If None, no model parallelism is set.
            **kwargs: ``name`` (optional) overrides any name already set on the instance.
        """
        super().__init__()
        self.support_training: bool = False
        if mp_group is not None:
            self.mp_group = mp_group
            self.tp_world_size: int = dist.get_world_size(self.mp_group)
            self.tp_index: int = dist.get_rank(mp_group)

            # backward compatibility
            self.world_size = self.tp_world_size
            self.rank = self.tp_index

        # A subclass may set `name` before calling this __init__; an explicit `name` kwarg wins.
        self.name = getattr(self, 'name', None)
        if kwargs.get('name') is not None:
            self.name = kwargs.get('name')  # Set the layer name if provided.

    @classmethod
    def set_keep_module_on_host(cls, value: bool):
        """
        Set the static variable keep_module_on_host.

        Args:
            value (bool): The new value for keep_module_on_host.
        """
        cls.keep_module_on_host = value

    @abstractmethod
    def forward(self, input):
        """
        Forward pass method. Must be implemented by subclasses to define layer-specific operations.
        """
        pass

    @abstractmethod
    def gather_params(self, params_list):
        """
        Gathers parameters across devices for distributed training. Must be implemented by subclasses in "TRAINING" mode.
        """
        pass

    @abstractmethod
    def _tp_partition(self, params_list: List[torch.Tensor]):
        """
        Partitions the parameters for tensor parallelism.
        It is necessary to ensure that this function only involves the logic of params partitioning.
        """
        pass

    def config_requires_grad(self, weight):
        """
        Set ``requires_grad`` on ``weight`` according to the current AutoTP mode.

        Outside training mode gradients are disabled. In training mode the existing flag is
        left untouched: ``requires_grad`` is always a bool, so there is no "unset" state to
        default, and a parameter the user froze stays frozen.
        """
        if weight is None:
            return
        if not self.is_training_mode():
            weight.requires_grad = False

    def config_tp_params(self, weight):
        """
        Configures the weight tensor for training with tensor parallelism. This includes enabling gradients
        and associating necessary methods for parameter gathering and partitioning.

        Args:
            weight (Optional[torch.Tensor]): The weight tensor to configure for tensor parallelism.
                                              If None, no action is taken.
        """
        # # The RNG states have already been synchronized in init_inference.
        if self.is_training_mode():
            assert self.support_training, "No implementation of backward."
        if weight is not None:
            self.config_requires_grad(weight)
            weight.gather_params = self.gather_params
            weight._tp_partition = self._tp_partition
            setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
            setattr(weight, DS_IS_REPLACED_MODULE, True)

    def is_training_mode(self):
        """Return True when the module-wide AutoTP mode is ``AUTOTP_MODE.TRAINING``."""
        return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING

    def __deepcopy__(self, memo):
        # 'mp_group' (a 'ProcessGroup') cannot be pickled during deepcopy in some usage,
        # so it is shared by reference; every other attribute is deep-copied.
        cls = self.__class__
        new_obj = cls.__new__(cls)
        # Register the copy before recursing so any reference cycle back to this layer
        # resolves to `new_obj` instead of re-entering __deepcopy__ indefinitely.
        memo[id(self)] = new_obj

        for key, value in vars(self).items():
            if key == 'mp_group':
                new_obj.mp_group = self.mp_group
            else:
                setattr(new_obj, key, deepcopy(value, memo))

        return new_obj

    def extra_repr(self):
        out_features, in_features = None, None
        if self.weight is not None:
            out_features, in_features = self.weight.ds_shape[-2:] if is_zero_param(
                self.weight) else self.weight.shape[-2:]
        dtype = self.weight.dtype if self.weight is not None else None
        return "in_features={}, out_features={}, bias={}, dtype={}".format(in_features, out_features, self.bias
                                                                           is not None, dtype)

    def move(self, tensor):
        """
        Return ``tensor`` placed where this layer's parameters should live.

        Meta tensors are returned unchanged. Otherwise the target is the host when
        ``keep_module_on_host`` is set (no copy is forced), or the current accelerator device
        (a copy is always made and the source tensor's ``.data`` is replaced with an empty
        tensor so it no longer references the original storage).
        """
        # TODO: consider the timing of deletion
        # to save host resources when DP > 1.

        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
        if tensor.is_meta:
            # Keep tensor in meta device if tensor is meta.
            return tensor
        else:
            device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
            return_new_copy = not self.__class__.keep_module_on_host

            # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
            # Using copy=True instead of clone() will help in case of cpu --> cpu.
            # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
            cloned_tensor = tensor.to(device, copy=return_new_copy)

            if return_new_copy:
                # free the memory of the original tensor to reduce memory peak
                # Equivalent to directly deleting the tensor reference outside the function.
                # see https://github.com/microsoft/DeepSpeed/pull/4353
                tensor.data = torch.empty(0, device=tensor.device)
            return cloned_tensor


def configure_tensor_parallel_runtime(config):
    """
    Copy runtime options from ``config`` onto ``TensorParallel_Layer`` as class attributes.

    Only the options listed below are copied, and only when ``config`` has them.
    """
    for option in ('tp_overlap_comm', ):
        if not hasattr(config, option):
            continue
        setattr(TensorParallel_Layer, option, getattr(config, option))


class GatherReplacedLayerParams:
    """
    A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
    based on the configuration of the model.

    On enter, the first parameter's ``gather_params`` is called with all managed parameters;
    on exit, its ``_tp_partition`` re-partitions them. The manager is a no-op when disabled or
    when none of the parameters carries the ``DS_IS_REPLACED_MODULE`` marker.
    """

    def __init__(self,
                 params: Union[Iterable[torch.Tensor], torch.Tensor],
                 module: torch.nn.Module,
                 enabled: bool = True):
        """
        Initialize the context manager to handle parameter gathering and partitioning for a replaced layer.

        Args:
            params (Iterable or torch.Tensor): A collection or single parameter to manage.
            module (torch.nn.Module): The module that these parameters belong to.
            enabled (bool): Flag indicating whether the parameter management is enabled (default: True).
        """
        self.enabled = enabled
        self.module = module
        if not enabled:
            return

        # Ensure params is a list, whether it's a single param or iterable (e.g., model.parameters())
        if isinstance(params, Iterable) and not isinstance(params, torch.Tensor):
            self.params: List[torch.Tensor] = list(params)  # Convert generators to a list for multiple iterations
        else:
            self.params = [params]  # Wrap single parameter in a list for uniform processing

        # Check the materialized list, not `params`: a generator has already been exhausted by
        # list() above, and iterating a single tensor would walk its rows instead of the parameter.
        if not any(self._is_replaced_module_weight(p) for p in self.params):
            self.enabled = False
            return

    def _is_replaced_module_weight(self, param: torch.Tensor) -> bool:
        """
        Helper function to determine if a parameter belongs to a replaced module.

        Args:
            param (torch.Tensor): The parameter to check.

        Returns:
            bool: True if the parameter belongs to a replaced module, False otherwise.
        """
        return getattr(param, DS_IS_REPLACED_MODULE, False)

    def __enter__(self) -> None:
        """
        Enter the context manager. If enabled, gather parameters for the replaced module.
        """
        if self.enabled:
            self.params[0].gather_params(self.params)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """
        Exit the context manager. If enabled, partition the parameters for the replaced module.
        """
        #TODO : Check whether there are any missing attributes.
        if self.enabled:
            self.params[0]._tp_partition(self.params)


class LinearAllreduce(TensorParallel_Layer):
    """
    Row-parallel linear layer.

    The weight is sharded along its last dimension (the input features); each rank
    computes a partial product that is all-reduced via ``RowParallel``. The bias is not
    sharded and is added after the reduction.
    """

    def __init__(self, module, mp_group, **kwargs):
        super().__init__(mp_group, **kwargs)
        self.weight = module.weight
        self.bias = module.bias

        self._tp_partition([self.weight, self.bias])
        self.support_training = True
        self.config_tp_params(self.weight)
        # The bias is replicated rather than tensor-parallel, so only its grad flag is configured.
        if self.bias is not None:
            self.config_requires_grad(self.bias)

    def forward(self, input):
        partial = torch.matmul(input, self.weight.transpose(-1, -2))
        reduced = RowParallel.apply(self.mp_group, partial, not self.is_training_mode())
        return reduced if self.bias is None else add_bias(reduced, self.bias)

    @torch.no_grad()
    def gather_params(self, params_list):
        # Only the weight (first entry) is gathered; the bias is never sharded.
        weight = next(iter(params_list), None)
        if weight is None:
            return
        weight.data_partition = weight.data
        # Gather along dim 0 of the transposed shard, i.e. along the weight's input-feature dim.
        local_t = weight.transpose(0, 1).contiguous()
        gathered = torch.empty(self.tp_world_size * local_t.shape[0],
                               local_t.shape[1],
                               dtype=local_t.dtype,
                               device=local_t.device)
        dist.all_gather_into_tensor(gathered, local_t, group=self.mp_group)
        weight.data = gathered.transpose(0, 1).contiguous()

    @torch.no_grad()
    def _tp_partition(self, params_list):
        if not self.is_training_mode():
            self.uneven_partition(params_list)
            return

        for position, tensor in enumerate(params_list):
            if tensor is None:
                # nothing (more) to partition
                return
            if position == 0:
                # Weight: keep this rank's even chunk of the last (input-feature) dimension.
                local = torch.chunk(tensor, self.tp_world_size, dim=-1)[self.tp_index]
                tensor.data = self.move(local).detach()
                continue
            # Bias: not split, only moved to the target device at initialization.
            tensor.data = self.move(tensor).detach()
            return

    def uneven_partition(self, params_list):
        # Only the weight (first entry) is split; the bias is left as-is.
        weight = next(iter(params_list), None)
        if weight is None:
            return
        assert self.name is not None, "The module name must be provided in the initialization."
        shard_sizes = get_shard_size_list(weight.shape[1], self.tp_world_size, self.name)
        local = weight.split(shard_sizes, dim=1)[self.tp_index]
        weight.data = self.move(local).detach()


#remove kwargs from partition.
class LinearLayer(TensorParallel_Layer):
    """
    Column-parallel linear layer.

    The weight and the bias (if any) are sharded along dim 0 (output features), so each
    rank produces its own slice of the output. In backward the input gradient is
    all-reduced across ``mp_group``.
    """

    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
        super(LinearLayer, self).__init__(mp_group, **kwargs)
        self.weight = module.weight
        self.bias = module.bias
        if not skip_partition:
            self._tp_partition([self.weight, self.bias])
        self.support_training = True
        self.config_tp_params(self.weight)
        if self.bias is not None:
            self.config_tp_params(self.bias)

    def forward(self, input):
        # mp_group is only set when a group was given (e.g. from_weights builds layers without one).
        mp_group = getattr(self, 'mp_group', None)
        if self.__class__.tp_overlap_comm and mp_group is not None:
            # Matmul whose backward overlaps the input-gradient all-reduce with local compute.
            return AsyncColumnParallel.apply(mp_group, input, self.weight, self.bias)

        if mp_group is not None:
            input = ColumnParallel.apply(mp_group, input)
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.bias is not None:
            output = add_bias(output, self.bias)
        return output

    @torch.no_grad()
    def gather_params(self, params_list):
        # Does not support uneven shards: every rank must hold a shard of identical shape.
        for param in params_list:
            if param is None:
                # e.g. an absent bias: nothing to gather
                continue
            param.data_partition = param.data
            output_param = torch.empty((self.tp_world_size * param.shape[0], *param.shape[1:]),
                                       dtype=param.dtype,
                                       device=param.device)
            dist.all_gather_into_tensor(output_param, param, group=self.mp_group)
            param.data = output_param.contiguous()

    @torch.no_grad()
    def _tp_partition(self, params_list):

        if not self.is_training_mode():
            self.uneven_partition(params_list)
            return
        for idx, param in enumerate(params_list):
            if param is None:
                return
            # Split weight and (if provided) bias evenly along the output dimension.
            _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]

            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

    def uneven_partition(self, params_list):

        for idx, param in enumerate(params_list):
            if param is None:
                # no bias provided
                return
            assert self.name is not None, "The module name must be provided in the initialization."
            _partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[0], self.tp_world_size,
                                                                    self.name),
                                                dim=0)[self.tp_index]

            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

    # for bwc
    @classmethod
    def from_weights(cls, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        """
        Build an unpartitioned layer (no process group) from explicit tensors or a shape.

        Args:
            weight_shape: ``(out_features, in_features)``; used only when ``weight`` is None,
                in which case the weight is allocated uninitialized-by-us with ``dtype``.
            dtype: dtype of a freshly allocated layer (ignored when ``weight`` is given).
            weight: Optional ``(out_features, in_features)`` weight tensor.
            bias: Optional bias tensor; its values are copied in whenever it is given.
        """
        if weight is not None:
            out_features, in_features = weight.shape[0], weight.shape[1]
            linear = nn.Linear(in_features, out_features, bias=(bias is not None))
            linear.weight.data = weight
        else:
            out_features, in_features = weight_shape[0], weight_shape[1]
            linear = nn.Linear(in_features, out_features, bias=(bias is not None), dtype=dtype)
        if bias is not None:
            linear.bias.data = bias
        return cls(linear, skip_partition=True)


class FusedModuleWrapper:
    """
    Plain holder for a fused module.

    Storing the fused module behind this non-``nn.Module`` wrapper keeps it from being
    registered as a submodule of the layer that holds it (avoiding circular module
    references). Any non-dunder attribute lookup (e.g. ``wrapper.module``) returns the
    wrapped module itself.
    """

    def __init__(self, fused_module: nn.Module):
        self.fused_module = fused_module

    def __getattr__(self, module):
        # Dunder lookups (``__deepcopy__``, ``__setstate__``, ...) are protocol probes made by
        # copy/pickle; answering them with the fused module breaks those protocols.
        if module.startswith('__') and module.endswith('__'):
            raise AttributeError(module)
        # Read through __dict__ so a half-constructed instance (e.g. during copy/unpickling)
        # raises AttributeError instead of recursing into __getattr__ forever.
        try:
            return self.__dict__['fused_module']
        except KeyError:
            raise AttributeError(module) from None


class fused_LinearLayer(LinearLayer):
    """
    Column-parallel layer for fused weights; per-rank sharding is delegated to
    ``prepare_tp_fused_qkvw`` together with the owning fused module.

    Requires the ``fused_module`` keyword argument.
    """

    def __init__(self, module, mp_group, skip_partition=False, **kwargs):
        owner = kwargs.get('fused_module')
        assert owner is not None, "'fused_module' is required but not provided"
        # Wrap the fused module so it is not registered as a submodule (avoids circular module references).
        self.fused_module = FusedModuleWrapper(owner)
        super().__init__(module, mp_group, skip_partition, **kwargs)

    @torch.no_grad()
    def _tp_partition(self, params_list):
        for tensor in params_list:
            if tensor is None:
                return
            shard = prepare_tp_fused_qkvw(self.fused_module.module, tensor, self.tp_world_size, self.tp_index)
            tensor.data = self.move(shard).detach()


class conv_LinearLayer(LinearLayer):
    """
    Column-parallel variant whose weight is split along dim 1 and whose bias is split
    along dim 0 with the same shard sizes. Dim 1 of the weight is therefore treated as the
    output-feature dimension (a transposed ``(in, out)`` layout — presumably HF Conv1D; confirm).
    """

    @torch.no_grad()
    def _tp_partition(self, params_list):
        weight = None
        bias = None
        if len(params_list) == 1:
            weight = params_list[0]
        elif len(params_list) == 2:
            weight, bias = params_list[0], params_list[1]
        if weight is None:
            return

        # Size the shards from the full extent of the dimension actually being split, computed
        # before the weight is replaced by its shard, and reuse them for the bias so both align.
        shard_sizes = get_shard_size_list(weight.shape[1], self.tp_world_size, self.name)

        _partition = weight.data.split(shard_sizes, dim=1)[self.tp_index]
        weight.data = self.move(_partition).detach()

        if bias is not None:
            _partition = bias.data.split(shard_sizes, dim=0)[self.tp_index]
            bias.data = self.move(_partition).detach()


# The subclasses below override only the weight-splitting logic (_tp_partition).
class Yuan_LinearAllreduce(LinearAllreduce):
    """Row-parallel layer for Yuan2; sharding is delegated to ``shard_value_with_share_qk`` (shard_value=False)."""

    @torch.no_grad()
    def _tp_partition(self, params_list):
        weight_param, bias_param = params_list[0], params_list[1]
        sharded_weight, sharded_bias = shard_value_with_share_qk(weight_param.data, bias_param, self.tp_index,
                                                                 self.tp_world_size, False)
        weight_param.data = sharded_weight
        if sharded_bias is not None:
            bias_param.data = sharded_bias


class Yuan_LinearLayer(LinearLayer):
    """Column-parallel layer for Yuan2; sharding is delegated to ``shard_value_with_share_qk`` (shard_value=True)."""

    @torch.no_grad()
    def _tp_partition(self, params_list):
        weight_param, bias_param = params_list[0], params_list[1]
        sharded_weight, sharded_bias = shard_value_with_share_qk(weight_param.data, bias_param, self.tp_index,
                                                                 self.tp_world_size, True)
        weight_param.data = self.move(sharded_weight).detach()
        if sharded_bias is not None:
            bias_param.data = self.move(sharded_bias).detach()


class GateUpPack_LinearLayer(LinearLayer):
    """
    Column-parallel layer for packed gate/up MLP weights (chatGLM2; the original note listed
    it twice — TODO confirm which other model uses it). Sharding is delegated to ``shard_chunk_mlp``.
    """

    @torch.no_grad()
    def _tp_partition(self, params_list):
        weight_param, bias_param = params_list[0], params_list[1]
        sharded_weight, sharded_bias = shard_chunk_mlp(weight_param.data, bias_param, self.tp_index,
                                                       self.tp_world_size)
        weight_param.data = self.move(sharded_weight).detach()
        if sharded_bias is not None:
            bias_param.data = self.move(sharded_bias).detach()


class Conv_LinearALlreduce(LinearAllreduce):
    """
    Row-parallel variant for a weight stored transposed relative to ``nn.Linear``.

    The weight is transposed back and then split along dim 1 (input features). The bias is
    replicated: it is moved to the target device but never transposed or split.
    """

    @torch.no_grad()
    def _tp_partition(self, params_list):
        for idx, param in enumerate(params_list):
            if param is None:
                return
            if idx > 0:
                # Bias: not split (a 1-D bias would make transpose(-1, -2) fail); only moved.
                params_list[idx].data = self.move(param).detach()
                return
            param.data = param.data.transpose(-1, -2).contiguous()

            # Size the shards from the dimension actually being split (dim 1 after the transpose).
            _partition = param.split(get_shard_size_list(param.shape[1], self.tp_world_size, self.name),
                                     dim=1)[self.tp_index]

            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition


# The subclass below overrides the forward pass (and construction).
class LmHeadLinearAllreduce(LinearAllreduce):
    """
    Row-parallel LM head.

    Each rank multiplies its slice of the input's last dimension with its weight shard, and
    the partial results are all-reduced over ``mp_group`` before the bias is added.
    """

    def __init__(self, module, mp_group, **kwargs):
        # set the fixed name before partition
        self.name = "lm_head"
        # NOTE(review): a `name` kwarg overrides this fixed name in the base __init__ (and thus in
        # uneven_partition), while forward() always slices the input with "lm_head". Confirm
        # callers never pass a different name.

        # In some tied_embedding cases, only the lm head is sharded, while the word embedding is not.
        # Reinitialization is used to decouple them and prevent the word embedding from being sharded.
        # This should also be effective for cases where both are sharded in tied_embedding scenarios.

        # TODO: Training scenario-related tests, is it necessary to re-implement the vocab parallel module?
        module.weight = nn.Parameter(module.weight.clone().detach())
        if hasattr(module, 'bias') and module.bias is not None:
            module.bias = nn.Parameter(module.bias.clone().detach())
        super().__init__(module, mp_group, **kwargs)

    def forward(self, input):
        input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
        # Slice only the last dimension so inputs of any rank work (not just 3-D ones).
        output = torch.matmul(input[..., input_shard_offset:input_shard_offset + input_shard_size],
                              self.weight.transpose(-1, -2))
        if self.mp_group is not None:
            dist.inference_all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output = add_bias(output, self.bias)
        return output


class TensorParallelConv2d(nn.Module):
    """
    ``nn.Conv2d`` sharded across ``world_size`` ranks; this instance holds shard ``rank``.

    With ``shard_by_oc`` the output channels (and the bias) are split. Otherwise the input
    channels are split and the full bias is divided by ``world_size``, so summing the
    per-rank outputs adds the bias back exactly once.
    """

    def __init__(self, conv, rank, world_size, shard_by_oc):
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.shard_by_oc = shard_by_oc
        self.shard_weights(conv)

    # Split along the input/output channel depending on whether it is the last conv layer.
    def shard_weights(self, conv):
        split_dim = 0 if self.shard_by_oc else 1
        total_size = conv.weight.shape[split_dim]

        # Cumulative shard boundaries. Ranks are walked from last to first and a rank index
        # below the remainder gets one extra channel, so the extras land on the last ranks.
        base, remainder = divmod(total_size, self.world_size)
        bounds = [0]
        for r in reversed(range(self.world_size)):
            bounds.append(bounds[-1] + base + (1 if r < remainder else 0))
        lo, hi = bounds[self.rank], bounds[self.rank + 1]

        has_bias = conv.bias is not None
        local_bias = None
        if self.shard_by_oc:
            # not last conv layer, split output channel
            local_weight = conv.weight.data[lo:hi]
            if has_bias:
                local_bias = conv.bias.data[lo:hi]
        else:
            # last conv layer, split input channel
            local_weight = conv.weight.data[:, lo:hi]
            if has_bias:
                local_bias = conv.bias.data / float(self.world_size)

        out_channels, in_channels = local_weight.shape[0], local_weight.shape[1]
        self.conv = nn.Conv2d(in_channels, out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation,
                              conv.groups, has_bias, conv.padding_mode)
        self.conv.weight = torch.nn.Parameter(local_weight)
        if has_bias:
            self.conv.bias = torch.nn.Parameter(local_bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.conv(input)


class TensorParallelOcShardConv2d(TensorParallelConv2d):
    """Tensor-parallel Conv2d that shards ``conv`` along its output channels.

    Each rank produces only its slice of the output channels; no reduction is done here.
    """

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, shard_by_oc=True)


class TensorParallelIcShardConv2d(TensorParallelConv2d):
    """Tensor-parallel Conv2d that shards ``conv`` along its input channels.

    Each rank computes a partial output from its slice of input channels. ``forward``
    then all-reduces the partial result across ranks when more than one rank exists.
    """

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, shard_by_oc=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        partial = self.conv(input)
        # A lone rank already holds the complete result, so skip the collective.
        if self.world_size <= 1:
            return partial
        # The return value is not used, so presumably this reduces `partial` in place — confirm.
        dist.inference_all_reduce(partial)
        return partial


class Normalize(nn.Module):
    """Layer normalization over the last dimension of the input.

    When ``weight`` is given, the supplied ``weight``/``bias`` are used as-is (``bias`` may be
    None). Otherwise an ``nn.LayerNorm(dim, eps=eps)`` is built, cast to ``dtype``, moved to the
    current accelerator device, and its affine parameters are reused.
    """

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
        super().__init__()
        if weight is None:
            self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
            weight, bias = self.norm.weight, self.norm.bias
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def forward(self, input):
        # Normalize over the trailing dimension only, whatever its size is at call time.
        normalized_shape = input.shape[-1:]
        return nn.functional.layer_norm(input, normalized_shape, self.weight, self.bias, eps=self.eps)


class EmbeddingLayer(nn.Module):
    """Plain embedding lookup into a 2-D weight table.

    If ``weight`` is given it is used directly. Otherwise an uninitialised
    ``weight_shape[0] x weight_shape[1]`` Parameter of ``dtype`` is allocated on the current
    accelerator device and is expected to be filled in later. ``bias`` is accepted but unused.
    """

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super().__init__()
        if weight is not None:
            self.weight = weight
            return
        num_rows, num_cols = weight_shape[0], weight_shape[1]
        table = torch.empty(num_rows, num_cols, dtype=dtype, device=get_accelerator().current_device_name())
        self.weight = Parameter(table)

    def forward(self, input):
        looked_up = F.embedding(input, self.weight)
        return looked_up


class OPTEmbedding(EmbeddingLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Position ids are derived from the attention mask and shifted by ``self.offset`` (2)
    before the embedding lookup.
    """

    def __init__(self, weight_shape=None, weight=None, bias=None):
        super().__init__(weight_shape, weight=weight)
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        # Assigned after nn.Module.__init__ so module state is fully set up first.
        self.offset = 2

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
        """Return positional embeddings computed from ``attention_mask``.

        Args:
            attention_mask: mask of attended (1) / padded (0) tokens. It is cast to int64,
                cumulatively summed along dim 1 and sliced on dim 1, so it is assumed to be
                (batch, seq_len) — confirm against callers.
            past_key_values_length: number of leading positions to drop, e.g. positions
                already covered by cached keys/values.
            position_ids: accepted for call-signature compatibility but ignored.
        """
        attention_mask = attention_mask.long()

        # cumsum gives 1-based positions for attended tokens and masking zeroes padded slots.
        # The -1 makes real tokens 0-based and padded slots -1 (id 1 once the offset is added).
        positions = torch.cumsum(attention_mask, dim=1) * attention_mask - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)


class RMSNormalize(nn.Module):
    """Root-mean-square normalization over the last dimension, scaled by ``weight``.

    When ``weight`` is None, a ones-initialised Parameter of size ``dim`` and ``dtype`` is
    created on the current accelerator device.
    """

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
        super().__init__()
        if weight is None:
            weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=get_accelerator().current_device_name()))
        self.weight = weight
        self.eps = eps

    def forward(self, hidden_states):
        # The mean square is computed in float32 for stability, regardless of input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)
        # Cast back down only when the scale itself is half precision.
        low_precision = (torch.float16, torch.bfloat16)
        if self.weight.dtype in low_precision:
            normalized = normalized.to(self.weight.dtype)
        return normalized * self.weight