# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Per-layer profilers."""
import copy
import time
from typing import Generator, List, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

from ..microbatch import Batch

__all__: List[str] = []

Device = Union[torch.device, int, str]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]


def layerwise_sandbox(
    module: nn.Sequential,
    device: torch.device,
) -> Generator[nn.Module, None, None]:
    """Yields a copy of each layer for isolated profiling. The given module
    itself is not modified.
    """
    for layer in module:
        # Deep-copy so profiling never touches the original parameters,
        # buffers, or gradients.
        layer_copy = copy.deepcopy(layer)
        layer_copy.to(device)
        layer_copy.train()
        yield layer_copy


def detach(batch: Batch) -> None:
    """Detaches each tensor in the batch from the autograd graph while
    preserving its ``requires_grad`` flag.
    """
    for i, x in enumerate(batch):
        batch[i] = x.detach().requires_grad_(x.requires_grad)


def profile_times(
    module: nn.Sequential,
    sample: TensorOrTensors,
    timeout: float,
    device: torch.device,
) -> List[int]:
    """Profiles elapsed time per layer, in microseconds."""
    if any(p.grad is not None for p in module.parameters()):
        raise ValueError("some parameter already has gradient")

    _batch = Batch(sample, 0)
    for i, x in enumerate(_batch):
        _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)

    time_bufs: List[List[float]] = [[] for _ in module]
    begun_at = time.time()

    # Run repeated forward/backward sweeps until ``timeout`` seconds elapse,
    # accumulating per-layer timings.
    while time.time() - begun_at < timeout:
        batch = _batch

        for i, layer in enumerate(layerwise_sandbox(module, device)):
            detach(batch)

            if device.type == "cuda":
                torch.cuda.synchronize(device)
            tick = time.time()

            # Forward
            batch = batch.call(layer)

            # Backward. The outputs double as their own gradients, which is
            # enough to exercise the backward pass without allocating extra
            # tensors; only the relative per-layer cost matters here.
            backward_tensors = tuple(y for y in batch if y.requires_grad)
            if backward_tensors:
                torch.autograd.backward(backward_tensors, backward_tensors)

            if device.type == "cuda":
                torch.cuda.synchronize(device)
            tock = time.time()

            time_bufs[i].append(tock - tick)

    us = 1_000_000
    return [sum(int(t * us) for t in buf) for buf in time_bufs]
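
# A minimal usage sketch for ``profile_times`` (the model and sample below
# are illustrative only, not part of this module):
#
#   model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
#   times = profile_times(model, torch.rand(32, 64), timeout=1.0,
#                         device=torch.device("cpu"))
#   # times[i] is the accumulated forward+backward time of layer i, in us.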


def profile_sizes(
    module: nn.Sequential,
    input: TensorOrTensors,
    chunks: int,
    param_scale: float,
    device: torch.device,
) -> List[int]:
    """Profiles CUDA memory usage per layer, in bytes."""
    if device.type != "cuda":
        raise ValueError("size profiler supports only CUDA device")

    batch = Batch(input, 0)
    sizes: List[int] = []

    # Only a single sample per tensor is profiled (see the slicing below), so
    # the measured activation memory is scaled back up to the per-microbatch
    # size: batch size divided by the number of chunks.
    latent_scale = batch[0].size(0) / chunks
    for i, x in enumerate(batch):
        batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)

    for layer in layerwise_sandbox(module, device):
        detach(batch)

        # Measure memory allocated during the forward pass.
        memory_before = torch.cuda.memory_allocated(device)
        batch = batch.call(layer)
        memory_after = torch.cuda.memory_allocated(device)
        latent_size = memory_after - memory_before

        # Analyze size of parameters.
        param_size = sum(
            p.storage().size() * p.storage().element_size()
            for p in layer.parameters()
        )

        # Combine the sizes of parameters and activations with normalized
        # scales.
        size = latent_size * latent_scale + param_size * param_scale
        sizes.append(int(size))

    return sizes
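
# A minimal usage sketch for ``profile_sizes`` (illustrative only; requires
# a CUDA device):
#
#   model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
#   sizes = profile_sizes(model, torch.rand(32, 64), chunks=4,
#                         param_scale=2.0, device=torch.device("cuda"))
#   # sizes[i] estimates the memory footprint of layer i, in bytes.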