|
|
|
|
|
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata |
|
|
|
|
|
# Empty ``__all__``: a star-import of this module exports no names. The
# helpers defined below are private (underscore-prefixed).
__all__: list[str] = []
|
|
|
|
|
def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
) -> bool:
    """Check if two shards overlap.

    Each shard is treated as a box that, in every dimension ``i``, spans the
    half-open interval ``[offsets[i], offsets[i] + sizes[i])``. Two boxes
    share at least one element only if their intervals intersect in every
    dimension.

    Args:
        shard1: First chunk. The dimensions checked are those of
            ``shard1.offsets``.
        shard2: Second chunk.

    Returns:
        ``True`` if the shards share at least one element, ``False``
        otherwise. A shard with size 0 in any dimension holds no elements
        and therefore never overlaps another shard.
    """
    for i in range(len(shard1.offsets)):
        # Intersection of two half-open intervals is [max(starts), min(ends)).
        # Comparing max(start) against min(end) (rather than only each start
        # against the *other* shard's end) also rejects empty intervals, so a
        # zero-sized shard lying inside another one is not reported as
        # overlapping.
        overlap_start = max(shard1.offsets[i], shard2.offsets[i])
        overlap_end = min(
            shard1.offsets[i] + shard1.sizes[i],
            shard2.offsets[i] + shard2.sizes[i],
        )
        if overlap_start >= overlap_end:
            return False

    return True
|
|
|
|
|
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> list[tuple[int, int, int, int]]:
    """
    Return the overlapping region between saved_shard and current_shard.

    The returned list has one entry per dimension that both shards describe.
    Each entry is a tuple of:

        (dimension, `saved_shard` offset, `current_shard` offset, length)

    Offsets are relative to each shard. The shards are expected to overlap;
    if they do not, the reported length for a disjoint dimension is not
    positive.
    """
    regions: list[tuple[int, int, int, int]] = []
    per_dim = zip(
        saved_shard.offsets,
        saved_shard.sizes,
        current_shard.offsets,
        current_shard.sizes,
    )
    for dim, (saved_start, saved_len, cur_start, cur_len) in enumerate(per_dim):
        # The overlap begins at the later of the two starts and ends at the
        # earlier of the two ends (global coordinates).
        overlap_start = max(saved_start, cur_start)
        overlap_end = min(saved_start + saved_len, cur_start + cur_len)
        # Re-express the overlap start relative to each shard's own origin.
        regions.append(
            (
                dim,
                overlap_start - saved_start,
                overlap_start - cur_start,
                overlap_end - overlap_start,
            )
        )
    return regions
|
|