# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import Callable, List, Optional

import torch
from einops import rearrange
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs
# (dim: int, emb_dim: int)
ada_layer_type = Callable[[int, int], nn.Module]


def get_ada_layer(ada_layer: str) -> ada_layer_type:
    if ada_layer == "single":
        return AdaSingle
    raise NotImplementedError(f"{ada_layer} is not supported")
def expand_dims(x: torch.Tensor, dim: int, ndim: int):
    """
    Expand tensor "x" to "ndim" by adding empty dims at "dim".
    Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d).
    """
    shape = x.shape
    shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
    return x.reshape(shape)


class AdaSingle(nn.Module):
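    """
    Per-layer adaptive modulation ("AdaLN-single"-style). The conditioning
    embedding is split into (shift, scale, gate) triples, one per named layer,
    and combined with learnable per-layer parameters registered in __init__
    (shift/gate initialized near zero, scale initialized near one). Mode "in"
    applies scale-and-shift to the hidden states; mode "out" applies the gate.
    """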
    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
        modes: List[str] = ["in", "out"],
    ):
        assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers
        for l in layers:
            if "in" in modes:
                self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
                self.register_parameter(
                    f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
                )
            if "out" in modes:
                self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
    def forward(
        self,
        hid: torch.FloatTensor,  # b ... c
        emb: torch.FloatTensor,  # b d
        layer: str,
        mode: str,
        cache: Cache = Cache(disable=True),
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,  # b
    ) -> torch.FloatTensor:
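        """
        Apply the "layer"/"mode" modulation to "hid" (shape: b ... c) using the
        conditioning embedding "emb" (shape: b d, holding the shift/scale/gate
        triples for every entry of self.layers).

        When "hid_len" is given, each sample's embedding is repeated hid_len[i]
        times, concatenated along dim 0 (presumably for inputs packed along the
        token dimension), and passed through slice_inputs; the result is
        memoized in "cache" under a key derived from the layer index and
        "branch_tag".
        """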
        idx = self.layers.index(layer)
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        emb = expand_dims(emb, 1, hid.ndim + 1)

        if hid_len is not None:
            emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: slice_inputs(
                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
                    dim=0,
                ),
            )

        shiftA, scaleA, gateA = emb.unbind(-1)
        shiftB, scaleB, gateB = (
            getattr(self, f"{layer}_shift", None),
            getattr(self, f"{layer}_scale", None),
            getattr(self, f"{layer}_gate", None),
        )

        if mode == "in":
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            return hid.mul_(gateA + gateB)
        raise NotImplementedError
    def extra_repr(self) -> str:
        return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"