from typing import Optional

import torch.distributed as dist


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    '''Rank of the current process, or 0 when torch.distributed is not initialized.'''
    return dist.get_rank(group) if dist.is_initialized() else 0


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    '''Size of the process group, or 1 when torch.distributed is not initialized.'''
    return dist.get_world_size(group) if dist.is_initialized() else 1


def barrier(group: Optional[dist.ProcessGroup] = None):
    '''Synchronize all ranks; a no-op when torch.distributed is not initialized.'''
    if dist.is_initialized():
        dist.barrier(group)

class rank_gate:
    '''
    Run a function on rank 0 first, then on all other ranks. Useful when the
    first call populates a cache (e.g. downloads a file) that the remaining
    ranks can then read instead of racing to create it. Works both as a
    decorator and as a context manager.
    '''
    def __init__(self, func=None):
        self.func = func

    def __call__(self, *args, **kwargs):
        rank = get_rank()
        if rank == 0:
            result = self.func(*args, **kwargs)
        # Rank 0 is done; release the other ranks to make their calls.
        barrier()
        if rank > 0:
            result = self.func(*args, **kwargs)
        return result

    def __enter__(self):
        # Non-zero ranks wait here until rank 0 has left the gated block.
        if get_rank() > 0:
            barrier()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Rank 0 leaves the block and releases the waiting ranks.
        if get_rank() == 0:
            barrier()