|
""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch |
|
|
|
This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch. |
|
|
|
99% of the implementation was done from the papers; however, some last-minute adjustments were made

based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
|
|
|
There are multiple sets of models defined for both architectures. Typically, names with a |
|
`_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit. |
|
These configs work well and appear to be a bit faster / lower resource than the paper models.
|
|
|
The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to

match the papers, but without any official pretrained weights it's difficult to confirm a 100% match.
|
|
|
# FIXME / WARNING |
|
This impl remains a WIP, some configs and models may vanish or change... |
|
|
|
Papers: |
|
|
|
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 |
|
@inproceedings{tu2022maxvit,

  title={MaxViT: Multi-Axis Vision Transformer},

  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},

  booktitle={ECCV},

  year={2022},

}
|
|
|
CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803 |
|
@article{DBLP:journals/corr/abs-2106-04803, |
|
author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan}, |
|
title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes}, |
|
journal = {CoRR}, |
|
volume = {abs/2106.04803}, |
|
year = {2021} |
|
} |
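
Example (minimal usage sketch; assumes this module is installed / registered w/ timm as usual):

    import timm
    model = timm.create_model('coatnet_nano_rw_224', pretrained=True).eval()
    # NOTE input size is fixed per config (fixed_input_size=True), 224x224 here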
|
|
|
Hacked together by / Copyright 2022, Ross Wightman |
|
""" |
|
|
|
import math |
|
from collections import OrderedDict |
|
from dataclasses import dataclass, replace, field |
|
from functools import partial |
|
from typing import Callable, Optional, Union, Tuple, List |
|
|
|
import torch |
|
from torch import nn |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from .helpers import build_model_with_cfg, checkpoint_seq, named_apply |
|
from .fx_features import register_notrace_function |
|
from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm |
|
from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d |
|
from .layers import to_2tuple, extend_tuple, make_divisible, _assert |
|
from .registry import register_model |
|
from .vision_transformer_relpos import RelPosMlp, RelPosBias |
|
|
|
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] |
|
|
|
|
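# Pretrained weight URLs + input / preprocessing metadata shared by the model variants below.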
|
def _cfg(url='', **kwargs): |
|
return { |
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), |
|
'crop_pct': 0.95, 'interpolation': 'bicubic', |
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), |
|
'first_conv': 'stem.conv1', 'classifier': 'head.fc', |
|
'fixed_input_size': True, |
|
**kwargs |
|
} |
|
|
|
|
|
default_cfgs = { |
|
|
|
'coatnet_pico_rw_224': _cfg(url=''), |
|
'coatnet_nano_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', |
|
crop_pct=0.9), |
|
'coatnet_0_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), |
|
'coatnet_1_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' |
|
), |
|
'coatnet_2_rw_224': _cfg(url=''), |
|
'coatnet_3_rw_224': _cfg(url=''), |
|
|
|
|
|
'coatnet_bn_0_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', |
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, |
|
crop_pct=0.95), |
|
'coatnet_rmlp_nano_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', |
|
crop_pct=0.9), |
|
'coatnet_rmlp_0_rw_224': _cfg(url=''), |
|
'coatnet_rmlp_1_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), |
|
'coatnet_rmlp_2_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), |
|
'coatnet_rmlp_3_rw_224': _cfg(url=''), |
|
'coatnet_nano_cc_224': _cfg(url=''), |
|
'coatnext_nano_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', |
|
crop_pct=0.9), |
|
|
|
|
|
'coatnet_0_224': _cfg(url=''), |
|
'coatnet_1_224': _cfg(url=''), |
|
'coatnet_2_224': _cfg(url=''), |
|
'coatnet_3_224': _cfg(url=''), |
|
'coatnet_4_224': _cfg(url=''), |
|
'coatnet_5_224': _cfg(url=''), |
|
|
|
|
|
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_nano_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_tiny_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), |
|
'maxvit_tiny_rw_256': _cfg( |
|
url='', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_rmlp_pico_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_rmlp_nano_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_rmlp_tiny_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxvit_rmlp_small_rw_224': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', |
|
crop_pct=0.9, |
|
), |
|
'maxvit_rmlp_small_rw_256': _cfg( |
|
url='', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
|
|
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), |
|
|
|
'maxxvit_rmlp_nano_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), |
|
'maxxvit_rmlp_small_rw_256': _cfg( |
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', |
|
input_size=(3, 256, 256), pool_size=(8, 8)), |
|
|
|
|
|
'maxvit_tiny_224': _cfg(url=''), |
|
'maxvit_small_224': _cfg(url=''), |
|
'maxvit_base_224': _cfg(url=''), |
|
'maxvit_large_224': _cfg(url=''), |
|
'maxvit_xlarge_224': _cfg(url=''), |
|
} |
|
|
|
|
|
@dataclass |
|
class MaxxVitTransformerCfg: |
|
dim_head: int = 32 |
|
expand_ratio: float = 4.0 |
|
expand_first: bool = True |
|
shortcut_bias: bool = True |
|
attn_bias: bool = True |
|
attn_drop: float = 0. |
|
proj_drop: float = 0. |
|
pool_type: str = 'avg2' |
|
rel_pos_type: str = 'bias' |
|
rel_pos_dim: int = 512 |
|
partition_ratio: int = 32 |
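    # NOTE window_size / grid_size are derived as input size // partition_ratio when left as None (see cfg_window_size)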
|
window_size: Optional[Tuple[int, int]] = None |
|
grid_size: Optional[Tuple[int, int]] = None |
|
init_values: Optional[float] = None |
|
act_layer: str = 'gelu' |
|
norm_layer: str = 'layernorm2d' |
|
norm_layer_cl: str = 'layernorm' |
|
norm_eps: float = 1e-6 |
|
|
|
def __post_init__(self): |
|
if self.grid_size is not None: |
|
self.grid_size = to_2tuple(self.grid_size) |
|
if self.window_size is not None: |
|
self.window_size = to_2tuple(self.window_size) |
|
if self.grid_size is None: |
|
self.grid_size = self.window_size |
|
|
|
|
|
@dataclass |
|
class MaxxVitConvCfg: |
|
block_type: str = 'mbconv' |
|
expand_ratio: float = 4.0 |
|
expand_output: bool = True |
|
kernel_size: int = 3 |
|
group_size: int = 1 |
|
pre_norm_act: bool = False |
|
output_bias: bool = True |
|
stride_mode: str = 'dw' |
|
pool_type: str = 'avg2' |
|
downsample_pool_type: str = 'avg2' |
|
attn_early: bool = False |
|
attn_layer: str = 'se' |
|
attn_act_layer: str = 'silu' |
|
attn_ratio: float = 0.25 |
|
init_values: Optional[float] = 1e-6 |
|
act_layer: str = 'gelu' |
|
norm_layer: str = '' |
|
norm_layer_cl: str = '' |
|
norm_eps: Optional[float] = None |
|
|
|
def __post_init__(self): |
|
|
|
assert self.block_type in ('mbconv', 'convnext') |
|
use_mbconv = self.block_type == 'mbconv' |
|
if not self.norm_layer: |
|
self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' |
|
if not self.norm_layer_cl and not use_mbconv: |
|
self.norm_layer_cl = 'layernorm' |
|
if self.norm_eps is None: |
|
self.norm_eps = 1e-5 if use_mbconv else 1e-6 |
|
self.downsample_pool_type = self.downsample_pool_type or self.pool_type |
|
|
|
|
|
@dataclass |
|
class MaxxVitCfg: |
|
embed_dim: Tuple[int, ...] = (96, 192, 384, 768) |
|
depths: Tuple[int, ...] = (2, 3, 5, 2) |
|
block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T') |
|
stem_width: Union[int, Tuple[int, int]] = 64 |
|
stem_bias: bool = True |
|
conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg) |
|
transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg) |
|
weight_init: str = 'vit_eff' |
|
|
|
|
|
def _rw_coat_cfg( |
|
stride_mode='pool', |
|
pool_type='avg2', |
|
conv_output_bias=False, |
|
conv_attn_early=False, |
|
conv_attn_act_layer='relu', |
|
conv_norm_layer='', |
|
transformer_shortcut_bias=True, |
|
transformer_norm_layer='layernorm2d', |
|
transformer_norm_layer_cl='layernorm', |
|
init_values=None, |
|
rel_pos_type='bias', |
|
rel_pos_dim=512, |
|
): |
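
    # Config overrides for the initial 'rw' timm CoAtNet variants, as encoded below:
    #  * pre-norm MBConv w/ the activation folded into the pre-norm layer (pre_norm_act=True)
    #  * MBConv expansion ratio applied to input chs (expand_output=False)
    #  * SE attention act layer defaults to relu rather than silu
    #  * transformer blocks expand dims in the output proj rather than the qkv proj (expand_first=False)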
return dict( |
|
conv_cfg=MaxxVitConvCfg( |
|
stride_mode=stride_mode, |
|
pool_type=pool_type, |
|
pre_norm_act=True, |
|
expand_output=False, |
|
output_bias=conv_output_bias, |
|
attn_early=conv_attn_early, |
|
attn_act_layer=conv_attn_act_layer, |
|
act_layer='silu', |
|
norm_layer=conv_norm_layer, |
|
), |
|
transformer_cfg=MaxxVitTransformerCfg( |
|
expand_first=False, |
|
shortcut_bias=transformer_shortcut_bias, |
|
pool_type=pool_type, |
|
init_values=init_values, |
|
norm_layer=transformer_norm_layer, |
|
norm_layer_cl=transformer_norm_layer_cl, |
|
rel_pos_type=rel_pos_type, |
|
rel_pos_dim=rel_pos_dim, |
|
), |
|
) |
|
|
|
|
|
def _rw_max_cfg( |
|
stride_mode='dw', |
|
pool_type='avg2', |
|
conv_output_bias=False, |
|
conv_attn_ratio=1 / 16, |
|
conv_norm_layer='', |
|
transformer_norm_layer='layernorm2d', |
|
transformer_norm_layer_cl='layernorm', |
|
window_size=None, |
|
dim_head=32, |
|
init_values=None, |
|
rel_pos_type='bias', |
|
rel_pos_dim=512, |
|
): |
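
    # Config overrides for the initial 'rw' timm MaxVit (and related CoAtNet) variants; similar to
    # _rw_coat_cfg but w/ no act in the MBConv pre-norm, a lighter SE (attn_ratio 1/16 by default),
    # and configurable window size / attention head dim for the partition attention.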
return dict( |
|
conv_cfg=MaxxVitConvCfg( |
|
stride_mode=stride_mode, |
|
pool_type=pool_type, |
|
expand_output=False, |
|
output_bias=conv_output_bias, |
|
attn_ratio=conv_attn_ratio, |
|
act_layer='silu', |
|
norm_layer=conv_norm_layer, |
|
), |
|
transformer_cfg=MaxxVitTransformerCfg( |
|
expand_first=False, |
|
pool_type=pool_type, |
|
dim_head=dim_head, |
|
window_size=window_size, |
|
init_values=init_values, |
|
norm_layer=transformer_norm_layer, |
|
norm_layer_cl=transformer_norm_layer_cl, |
|
rel_pos_type=rel_pos_type, |
|
rel_pos_dim=rel_pos_dim, |
|
), |
|
) |
|
|
|
|
|
def _next_cfg( |
|
stride_mode='dw', |
|
pool_type='avg2', |
|
conv_norm_layer='layernorm2d', |
|
conv_norm_layer_cl='layernorm', |
|
transformer_norm_layer='layernorm2d', |
|
transformer_norm_layer_cl='layernorm', |
|
window_size=None, |
|
init_values=1e-6, |
|
rel_pos_type='mlp', |
|
rel_pos_dim=512, |
|
): |
|
|
|
init_values = to_2tuple(init_values) |
|
return dict( |
|
conv_cfg=MaxxVitConvCfg( |
|
block_type='convnext', |
|
stride_mode=stride_mode, |
|
pool_type=pool_type, |
|
expand_output=False, |
|
init_values=init_values[0], |
|
norm_layer=conv_norm_layer, |
|
norm_layer_cl=conv_norm_layer_cl, |
|
), |
|
transformer_cfg=MaxxVitTransformerCfg( |
|
expand_first=False, |
|
pool_type=pool_type, |
|
window_size=window_size, |
|
init_values=init_values[1], |
|
norm_layer=transformer_norm_layer, |
|
norm_layer_cl=transformer_norm_layer_cl, |
|
rel_pos_type=rel_pos_type, |
|
rel_pos_dim=rel_pos_dim, |
|
), |
|
) |
|
|
|
|
|
model_cfgs = dict( |
|
|
|
coatnet_pico_rw_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 3, 5, 2), |
|
stem_width=(32, 64), |
|
**_rw_max_cfg( |
|
conv_output_bias=True, |
|
conv_attn_ratio=0.25, |
|
), |
|
), |
|
coatnet_nano_rw_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(3, 4, 6, 3), |
|
stem_width=(32, 64), |
|
**_rw_max_cfg( |
|
stride_mode='pool', |
|
conv_output_bias=True, |
|
conv_attn_ratio=0.25, |
|
), |
|
), |
|
coatnet_0_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 3, 7, 2), |
|
stem_width=(32, 64), |
|
**_rw_coat_cfg( |
|
conv_attn_early=True, |
|
transformer_shortcut_bias=False, |
|
), |
|
), |
|
coatnet_1_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(32, 64), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_early=True, |
|
transformer_shortcut_bias=False, |
|
) |
|
), |
|
coatnet_2_rw_224=MaxxVitCfg( |
|
embed_dim=(128, 256, 512, 1024), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(64, 128), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_act_layer='silu', |
|
init_values=1e-6, |
|
), |
|
), |
|
coatnet_3_rw_224=MaxxVitCfg( |
|
embed_dim=(192, 384, 768, 1536), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(96, 192), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_act_layer='silu', |
|
init_values=1e-6, |
|
), |
|
), |
|
|
|
|
|
coatnet_bn_0_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 3, 7, 2), |
|
stem_width=(32, 64), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_early=True, |
|
transformer_shortcut_bias=False, |
|
transformer_norm_layer='batchnorm2d', |
|
) |
|
), |
|
coatnet_rmlp_nano_rw_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(3, 4, 6, 3), |
|
stem_width=(32, 64), |
|
**_rw_max_cfg( |
|
conv_output_bias=True, |
|
conv_attn_ratio=0.25, |
|
rel_pos_type='mlp', |
|
rel_pos_dim=384, |
|
), |
|
), |
|
coatnet_rmlp_0_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 3, 7, 2), |
|
stem_width=(32, 64), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
rel_pos_type='mlp', |
|
), |
|
), |
|
coatnet_rmlp_1_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(32, 64), |
|
**_rw_coat_cfg( |
|
pool_type='max', |
|
conv_attn_early=True, |
|
transformer_shortcut_bias=False, |
|
rel_pos_type='mlp', |
|
rel_pos_dim=384, |
|
), |
|
), |
|
coatnet_rmlp_2_rw_224=MaxxVitCfg( |
|
embed_dim=(128, 256, 512, 1024), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(64, 128), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_act_layer='silu', |
|
init_values=1e-6, |
|
rel_pos_type='mlp' |
|
), |
|
), |
|
coatnet_rmlp_3_rw_224=MaxxVitCfg( |
|
embed_dim=(192, 384, 768, 1536), |
|
depths=(2, 6, 14, 2), |
|
stem_width=(96, 192), |
|
**_rw_coat_cfg( |
|
stride_mode='dw', |
|
conv_attn_act_layer='silu', |
|
init_values=1e-6, |
|
rel_pos_type='mlp' |
|
), |
|
), |
|
|
|
coatnet_nano_cc_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(3, 4, 6, 3), |
|
stem_width=(32, 64), |
|
block_type=('C', 'C', ('C', 'T'), ('C', 'T')), |
|
**_rw_coat_cfg(), |
|
), |
|
coatnext_nano_rw_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(3, 4, 6, 3), |
|
stem_width=(32, 64), |
|
weight_init='normal', |
|
**_next_cfg( |
|
rel_pos_type='bias', |
|
init_values=(1e-5, None) |
|
), |
|
), |
|
|
|
|
|
coatnet_0_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 3, 5, 2), |
|
stem_width=64, |
|
), |
|
coatnet_1_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 6, 14, 2), |
|
stem_width=64, |
|
), |
|
coatnet_2_224=MaxxVitCfg( |
|
embed_dim=(128, 256, 512, 1024), |
|
depths=(2, 6, 14, 2), |
|
stem_width=128, |
|
), |
|
coatnet_3_224=MaxxVitCfg( |
|
embed_dim=(192, 384, 768, 1536), |
|
depths=(2, 6, 14, 2), |
|
stem_width=192, |
|
), |
|
coatnet_4_224=MaxxVitCfg( |
|
embed_dim=(192, 384, 768, 1536), |
|
depths=(2, 12, 28, 2), |
|
stem_width=192, |
|
), |
|
coatnet_5_224=MaxxVitCfg( |
|
embed_dim=(256, 512, 1280, 2048), |
|
depths=(2, 12, 28, 2), |
|
stem_width=192, |
|
), |
|
|
|
|
|
maxvit_pico_rw_256=MaxxVitCfg( |
|
embed_dim=(32, 64, 128, 256), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(24, 32), |
|
**_rw_max_cfg(), |
|
), |
|
maxvit_nano_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(1, 2, 3, 1), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(), |
|
), |
|
maxvit_tiny_rw_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(), |
|
), |
|
maxvit_tiny_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(), |
|
), |
|
|
|
maxvit_rmlp_pico_rw_256=MaxxVitCfg( |
|
embed_dim=(32, 64, 128, 256), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(24, 32), |
|
**_rw_max_cfg(rel_pos_type='mlp'), |
|
), |
|
maxvit_rmlp_nano_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(1, 2, 3, 1), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(rel_pos_type='mlp'), |
|
), |
|
maxvit_rmlp_tiny_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(rel_pos_type='mlp'), |
|
), |
|
maxvit_rmlp_small_rw_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg( |
|
rel_pos_type='mlp', |
|
init_values=1e-6, |
|
), |
|
), |
|
maxvit_rmlp_small_rw_256=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg( |
|
rel_pos_type='mlp', |
|
init_values=1e-6, |
|
), |
|
), |
|
|
|
maxvit_tiny_pm_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('PM',) * 4, |
|
stem_width=(32, 64), |
|
**_rw_max_cfg(), |
|
), |
|
|
|
maxxvit_rmlp_nano_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(1, 2, 3, 1), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
weight_init='normal', |
|
**_next_cfg(), |
|
), |
|
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(32, 64), |
|
**_next_cfg(), |
|
), |
|
maxxvit_rmlp_small_rw_256=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=(48, 96), |
|
**_next_cfg(), |
|
), |
|
|
|
|
|
maxvit_tiny_224=MaxxVitCfg( |
|
embed_dim=(64, 128, 256, 512), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=64, |
|
), |
|
maxvit_small_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 2, 5, 2), |
|
block_type=('M',) * 4, |
|
stem_width=64, |
|
), |
|
maxvit_base_224=MaxxVitCfg( |
|
embed_dim=(96, 192, 384, 768), |
|
depths=(2, 6, 14, 2), |
|
block_type=('M',) * 4, |
|
stem_width=64, |
|
), |
|
maxvit_large_224=MaxxVitCfg( |
|
embed_dim=(128, 256, 512, 1024), |
|
depths=(2, 6, 14, 2), |
|
block_type=('M',) * 4, |
|
stem_width=128, |
|
), |
|
maxvit_xlarge_224=MaxxVitCfg( |
|
embed_dim=(192, 384, 768, 1536), |
|
depths=(2, 6, 14, 2), |
|
block_type=('M',) * 4, |
|
stem_width=192, |
|
), |
|
|
|
) |
|
|
|
|
|
class Attention2d(nn.Module): |
|
""" multi-head attention for 2D NCHW tensors""" |
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: Optional[int] = None, |
|
dim_head: int = 32, |
|
bias: bool = True, |
|
expand_first: bool = True, |
|
            rel_pos_cls: Optional[Callable] = None,
|
attn_drop: float = 0., |
|
proj_drop: float = 0. |
|
): |
|
super().__init__() |
|
dim_out = dim_out or dim |
|
dim_attn = dim_out if expand_first else dim |
|
self.num_heads = dim_attn // dim_head |
|
self.dim_head = dim_head |
|
self.scale = dim_head ** -0.5 |
|
|
|
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) |
|
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): |
|
B, C, H, W = x.shape |
|
|
|
q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) |
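        # -> q, k, v each of shape (B, num_heads, dim_head, H * W)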
|
|
|
attn = (q.transpose(-2, -1) @ k) * self.scale |
|
if self.rel_pos is not None: |
|
attn = self.rel_pos(attn) |
|
elif shared_rel_pos is not None: |
|
attn = attn + shared_rel_pos |
|
attn = attn.softmax(dim=-1) |
|
attn = self.attn_drop(attn) |
|
|
|
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) |
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
|
|
class AttentionCl(nn.Module): |
|
""" Channels-last multi-head attention (B, ..., C) """ |
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: Optional[int] = None, |
|
dim_head: int = 32, |
|
bias: bool = True, |
|
expand_first: bool = True, |
|
            rel_pos_cls: Optional[Callable] = None,
|
attn_drop: float = 0., |
|
proj_drop: float = 0. |
|
): |
|
super().__init__() |
|
dim_out = dim_out or dim |
|
dim_attn = dim_out if expand_first and dim_out > dim else dim |
|
assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' |
|
self.num_heads = dim_attn // dim_head |
|
self.dim_head = dim_head |
|
self.scale = dim_head ** -0.5 |
|
|
|
self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) |
|
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim_attn, dim_out, bias=bias) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): |
|
B = x.shape[0] |
|
restore_shape = x.shape[:-1] |
|
|
|
q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) |
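        # -> q, k, v each of shape (B, num_heads, N, dim_head), N = flattened spatial dims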
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale |
|
if self.rel_pos is not None: |
|
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) |
|
elif shared_rel_pos is not None: |
|
attn = attn + shared_rel_pos |
|
attn = attn.softmax(dim=-1) |
|
attn = self.attn_drop(attn) |
|
|
|
x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) |
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
|
|
class LayerScale(nn.Module): |
|
def __init__(self, dim, init_values=1e-5, inplace=False): |
|
super().__init__() |
|
self.inplace = inplace |
|
self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
|
def forward(self, x): |
|
gamma = self.gamma |
|
return x.mul_(gamma) if self.inplace else x * gamma |
|
|
|
|
|
class LayerScale2d(nn.Module): |
|
def __init__(self, dim, init_values=1e-5, inplace=False): |
|
super().__init__() |
|
self.inplace = inplace |
|
self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
|
def forward(self, x): |
|
gamma = self.gamma.view(1, -1, 1, 1) |
|
return x.mul_(gamma) if self.inplace else x * gamma |
|
|
|
|
|
class Downsample2d(nn.Module): |
|
""" A downsample pooling module supporting several maxpool and avgpool modes |
|
* 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 |
|
* 'max2' - MaxPool2d w/ kernel_size = stride = 2 |
|
* 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 |
|
* 'avg2' - AvgPool2d w/ kernel_size = stride = 2 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: int, |
|
pool_type: str = 'avg2', |
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
assert pool_type in ('max', 'max2', 'avg', 'avg2') |
|
if pool_type == 'max': |
|
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
|
elif pool_type == 'max2': |
|
self.pool = nn.MaxPool2d(2) |
|
elif pool_type == 'avg': |
|
self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) |
|
else: |
|
self.pool = nn.AvgPool2d(2) |
|
|
|
if dim != dim_out: |
|
self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) |
|
else: |
|
self.expand = nn.Identity() |
|
|
|
def forward(self, x): |
|
x = self.pool(x) |
|
x = self.expand(x) |
|
return x |
|
|
|
|
|
def _init_transformer(module, name, scheme=''): |
|
if isinstance(module, (nn.Conv2d, nn.Linear)): |
|
if scheme == 'normal': |
|
nn.init.normal_(module.weight, std=.02) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif scheme == 'trunc_normal': |
|
trunc_normal_tf_(module.weight, std=.02) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif scheme == 'xavier_normal': |
|
nn.init.xavier_normal_(module.weight) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
else: |
|
|
|
nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
if 'mlp' in name: |
|
nn.init.normal_(module.bias, std=1e-6) |
|
else: |
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
class TransformerBlock2d(nn.Module): |
|
""" Transformer block with 2D downsampling |
|
'2D' NCHW tensor layout |
|
|
|
    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW

    for spatial pooling the benefit is minimal, so this variant alone ended up being used for the CoAt configs.
|
|
|
This impl was faster on TPU w/ PT XLA than the 1D experiment. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: int, |
|
stride: int = 1, |
|
            rel_pos_cls: Optional[Callable] = None,
|
cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
drop_path: float = 0., |
|
): |
|
super().__init__() |
|
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) |
|
act_layer = get_act_layer(cfg.act_layer) |
|
|
|
if stride == 2: |
|
self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) |
|
self.norm1 = nn.Sequential(OrderedDict([ |
|
('norm', norm_layer(dim)), |
|
('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), |
|
])) |
|
else: |
|
assert dim == dim_out |
|
self.shortcut = nn.Identity() |
|
self.norm1 = norm_layer(dim) |
|
|
|
self.attn = Attention2d( |
|
dim, |
|
dim_out, |
|
dim_head=cfg.dim_head, |
|
expand_first=cfg.expand_first, |
|
bias=cfg.attn_bias, |
|
rel_pos_cls=rel_pos_cls, |
|
attn_drop=cfg.attn_drop, |
|
proj_drop=cfg.proj_drop |
|
) |
|
self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
self.norm2 = norm_layer(dim_out) |
|
self.mlp = ConvMlp( |
|
in_features=dim_out, |
|
hidden_features=int(dim_out * cfg.expand_ratio), |
|
act_layer=act_layer, |
|
drop=cfg.proj_drop) |
|
self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def init_weights(self, scheme=''): |
|
named_apply(partial(_init_transformer, scheme=scheme), self) |
|
|
|
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): |
|
x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) |
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) |
|
return x |
|
|
|
|
|
def _init_conv(module, name, scheme=''): |
|
if isinstance(module, nn.Conv2d): |
|
if scheme == 'normal': |
|
nn.init.normal_(module.weight, std=.02) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif scheme == 'trunc_normal': |
|
trunc_normal_tf_(module.weight, std=.02) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif scheme == 'xavier_normal': |
|
nn.init.xavier_normal_(module.weight) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
else: |
|
|
|
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels |
|
fan_out //= module.groups |
|
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
def num_groups(group_size, channels): |
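    # group_size 0 / None -> 1 group (dense conv), group_size 1 -> depthwise (groups == channels)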
|
if not group_size: |
|
return 1 |
|
else: |
|
|
|
assert channels % group_size == 0 |
|
return channels // group_size |
|
|
|
|
|
class MbConvBlock(nn.Module): |
|
""" Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) |
|
""" |
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: int, |
|
stride: int = 1, |
|
dilation: Tuple[int, int] = (1, 1), |
|
cfg: MaxxVitConvCfg = MaxxVitConvCfg(), |
|
drop_path: float = 0. |
|
): |
|
super(MbConvBlock, self).__init__() |
|
norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) |
|
mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) |
|
groups = num_groups(cfg.group_size, mid_chs) |
|
|
|
if stride == 2: |
|
self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) |
|
else: |
|
self.shortcut = nn.Identity() |
|
|
|
assert cfg.stride_mode in ('pool', '1x1', 'dw') |
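        # where the stride-2 is applied: 'pool' = pooling before the 1x1, '1x1' = in conv1_1x1, 'dw' = in the kxk conv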
|
stride_pool, stride_1, stride_2 = 1, 1, 1 |
|
if cfg.stride_mode == 'pool': |
|
|
|
stride_pool, dilation_2 = stride, dilation[1] |
|
|
|
elif cfg.stride_mode == '1x1': |
|
|
|
stride_1, dilation_2 = stride, dilation[1] |
|
else: |
|
stride_2, dilation_2 = stride, dilation[0] |
|
|
|
self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) |
|
if stride_pool > 1: |
|
self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) |
|
else: |
|
self.down = nn.Identity() |
|
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) |
|
self.norm1 = norm_act_layer(mid_chs) |
|
|
|
self.conv2_kxk = create_conv2d( |
|
mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) |
|
|
|
attn_kwargs = {} |
|
if isinstance(cfg.attn_layer, str): |
|
if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': |
|
attn_kwargs['act_layer'] = cfg.attn_act_layer |
|
attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) |
|
|
|
|
|
if cfg.attn_early: |
|
self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) |
|
self.norm2 = norm_act_layer(mid_chs) |
|
self.se = None |
|
else: |
|
self.se_early = None |
|
self.norm2 = norm_act_layer(mid_chs) |
|
self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) |
|
|
|
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def init_weights(self, scheme=''): |
|
named_apply(partial(_init_conv, scheme=scheme), self) |
|
|
|
def forward(self, x): |
|
shortcut = self.shortcut(x) |
|
x = self.pre_norm(x) |
|
x = self.down(x) |
|
|
|
|
|
x = self.conv1_1x1(x) |
|
x = self.norm1(x) |
|
|
|
|
|
x = self.conv2_kxk(x) |
|
if self.se_early is not None: |
|
x = self.se_early(x) |
|
x = self.norm2(x) |
|
if self.se is not None: |
|
x = self.se(x) |
|
|
|
|
|
x = self.conv3_1x1(x) |
|
x = self.drop_path(x) + shortcut |
|
return x |
|
|
|
|
|
class ConvNeXtBlock(nn.Module): |
|
""" ConvNeXt Block |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: Optional[int] = None, |
|
kernel_size: int = 7, |
|
stride: int = 1, |
|
dilation: Tuple[int, int] = (1, 1), |
|
cfg: MaxxVitConvCfg = MaxxVitConvCfg(), |
|
conv_mlp: bool = True, |
|
drop_path: float = 0. |
|
): |
|
super().__init__() |
|
out_chs = out_chs or in_chs |
|
act_layer = get_act_layer(cfg.act_layer) |
|
if conv_mlp: |
|
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) |
|
mlp_layer = ConvMlp |
|
else: |
|
assert 'layernorm' in cfg.norm_layer |
|
norm_layer = LayerNorm |
|
mlp_layer = Mlp |
|
self.use_conv_mlp = conv_mlp |
|
|
|
if stride == 2: |
|
self.shortcut = Downsample2d(in_chs, out_chs) |
|
elif in_chs != out_chs: |
|
self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) |
|
else: |
|
self.shortcut = nn.Identity() |
|
|
|
assert cfg.stride_mode in ('pool', 'dw') |
|
stride_pool, stride_dw = 1, 1 |
|
|
|
if cfg.stride_mode == 'pool': |
|
stride_pool = stride |
|
else: |
|
stride_dw = stride |
|
|
|
if stride_pool == 2: |
|
self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) |
|
else: |
|
self.down = nn.Identity() |
|
|
|
self.conv_dw = create_conv2d( |
|
in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], |
|
depthwise=True, bias=cfg.output_bias) |
|
self.norm = norm_layer(out_chs) |
|
self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) |
|
if conv_mlp: |
|
self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() |
|
else: |
|
self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def forward(self, x): |
|
shortcut = self.shortcut(x) |
|
x = self.down(x) |
|
x = self.conv_dw(x) |
|
if self.use_conv_mlp: |
|
x = self.norm(x) |
|
x = self.mlp(x) |
|
x = self.ls(x) |
|
else: |
|
x = x.permute(0, 2, 3, 1) |
|
x = self.norm(x) |
|
x = self.mlp(x) |
|
x = self.ls(x) |
|
x = x.permute(0, 3, 1, 2) |
|
|
|
x = self.drop_path(x) + shortcut |
|
return x |
|
|
|
|
|
def window_partition(x, window_size: List[int]): |
|
B, H, W, C = x.shape |
|
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') |
|
    _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
|
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) |
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) |
|
return windows |
|
|
|
|
|
@register_notrace_function |
|
def window_reverse(windows, window_size: List[int], img_size: List[int]): |
|
H, W = img_size |
|
C = windows.shape[-1] |
|
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) |
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) |
|
return x |
|
|
|
|
|
def grid_partition(x, grid_size: List[int]): |
|
B, H, W, C = x.shape |
|
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') |
|
    _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
|
x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) |
|
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) |
|
return windows |
|
|
|
|
|
@register_notrace_function |
|
def grid_reverse(windows, grid_size: List[int], img_size: List[int]): |
|
H, W = img_size |
|
C = windows.shape[-1] |
|
x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) |
|
x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) |
|
return x |
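
# NOTE window_partition/window_reverse operate on contiguous window_size[0] x window_size[1]
# blocks while grid_partition/grid_reverse gather tokens w/ stride (H // grid_size[0], W // grid_size[1]),
# giving MaxViT's dilated 'grid' attention. Output shapes match either way, e.g. a
# (2, 64, 64, 192) NHWC tensor w/ an 8x8 partition -> (128, 8, 8, 192) in both cases.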
|
|
|
|
|
def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): |
|
rel_pos_cls = None |
|
if cfg.rel_pos_type == 'mlp': |
|
rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) |
|
elif cfg.rel_pos_type == 'bias': |
|
rel_pos_cls = partial(RelPosBias, window_size=window_size) |
|
return rel_pos_cls |
|
|
|
|
|
class PartitionAttentionCl(nn.Module): |
|
""" Grid or Block partition + Attn + FFN. |
|
    NHWC 'channels last' tensor layout.
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
partition_type: str = 'block', |
|
cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
drop_path: float = 0., |
|
): |
|
super().__init__() |
|
norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) |
|
act_layer = get_act_layer(cfg.act_layer) |
|
|
|
self.partition_block = partition_type == 'block' |
|
self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) |
|
rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) |
|
|
|
self.norm1 = norm_layer(dim) |
|
self.attn = AttentionCl( |
|
dim, |
|
dim, |
|
dim_head=cfg.dim_head, |
|
bias=cfg.attn_bias, |
|
rel_pos_cls=rel_pos_cls, |
|
attn_drop=cfg.attn_drop, |
|
proj_drop=cfg.proj_drop, |
|
) |
|
self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
self.norm2 = norm_layer(dim) |
|
self.mlp = Mlp( |
|
in_features=dim, |
|
hidden_features=int(dim * cfg.expand_ratio), |
|
act_layer=act_layer, |
|
drop=cfg.proj_drop) |
|
self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def _partition_attn(self, x): |
|
img_size = x.shape[1:3] |
|
if self.partition_block: |
|
partitioned = window_partition(x, self.partition_size) |
|
else: |
|
partitioned = grid_partition(x, self.partition_size) |
|
|
|
partitioned = self.attn(partitioned) |
|
|
|
if self.partition_block: |
|
x = window_reverse(partitioned, self.partition_size, img_size) |
|
else: |
|
x = grid_reverse(partitioned, self.partition_size, img_size) |
|
return x |
|
|
|
def forward(self, x): |
|
x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) |
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) |
|
return x |
|
|
|
|
|
class ParallelPartitionAttention(nn.Module): |
|
""" Experimental. Grid and Block partition + single FFN |
|
    NHWC 'channels last' tensor layout.
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
drop_path: float = 0., |
|
): |
|
super().__init__() |
|
assert dim % 2 == 0 |
|
norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) |
|
act_layer = get_act_layer(cfg.act_layer) |
|
|
|
assert cfg.window_size == cfg.grid_size |
|
self.partition_size = to_2tuple(cfg.window_size) |
|
rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) |
|
|
|
self.norm1 = norm_layer(dim) |
|
self.attn_block = AttentionCl( |
|
dim, |
|
dim // 2, |
|
dim_head=cfg.dim_head, |
|
bias=cfg.attn_bias, |
|
rel_pos_cls=rel_pos_cls, |
|
attn_drop=cfg.attn_drop, |
|
proj_drop=cfg.proj_drop, |
|
) |
|
self.attn_grid = AttentionCl( |
|
dim, |
|
dim // 2, |
|
dim_head=cfg.dim_head, |
|
bias=cfg.attn_bias, |
|
rel_pos_cls=rel_pos_cls, |
|
attn_drop=cfg.attn_drop, |
|
proj_drop=cfg.proj_drop, |
|
) |
|
self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
self.norm2 = norm_layer(dim) |
|
self.mlp = Mlp( |
|
in_features=dim, |
|
hidden_features=int(dim * cfg.expand_ratio), |
|
out_features=dim, |
|
act_layer=act_layer, |
|
drop=cfg.proj_drop) |
|
self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def _partition_attn(self, x): |
|
img_size = x.shape[1:3] |
|
|
|
partitioned_block = window_partition(x, self.partition_size) |
|
partitioned_block = self.attn_block(partitioned_block) |
|
x_window = window_reverse(partitioned_block, self.partition_size, img_size) |
|
|
|
partitioned_grid = grid_partition(x, self.partition_size) |
|
partitioned_grid = self.attn_grid(partitioned_grid) |
|
x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) |
|
|
|
return torch.cat([x_window, x_grid], dim=-1) |
|
|
|
def forward(self, x): |
|
x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) |
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) |
|
return x |
|
|
|
|
|
def window_partition_nchw(x, window_size: List[int]): |
|
B, C, H, W = x.shape |
|
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') |
|
    _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
|
x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) |
|
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) |
|
return windows |
|
|
|
|
|
@register_notrace_function |
|
def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): |
|
H, W = img_size |
|
C = windows.shape[1] |
|
x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) |
|
x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) |
|
return x |
|
|
|
|
|
def grid_partition_nchw(x, grid_size: List[int]): |
|
B, C, H, W = x.shape |
|
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') |
|
    _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
|
x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) |
|
windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) |
|
return windows |
|
|
|
|
|
@register_notrace_function |
|
def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): |
|
H, W = img_size |
|
C = windows.shape[1] |
|
x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) |
|
x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) |
|
return x |
|
|
|
|
|
class PartitionAttention2d(nn.Module): |
|
""" Grid or Block partition + Attn + FFN |
|
|
|
'2D' NCHW tensor layout. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
partition_type: str = 'block', |
|
cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
drop_path: float = 0., |
|
): |
|
super().__init__() |
|
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) |
|
act_layer = get_act_layer(cfg.act_layer) |
|
|
|
self.partition_block = partition_type == 'block' |
|
self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) |
|
rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) |
|
|
|
self.norm1 = norm_layer(dim) |
|
self.attn = Attention2d( |
|
dim, |
|
dim, |
|
dim_head=cfg.dim_head, |
|
bias=cfg.attn_bias, |
|
rel_pos_cls=rel_pos_cls, |
|
attn_drop=cfg.attn_drop, |
|
proj_drop=cfg.proj_drop, |
|
) |
|
self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
self.norm2 = norm_layer(dim) |
|
self.mlp = ConvMlp( |
|
in_features=dim, |
|
hidden_features=int(dim * cfg.expand_ratio), |
|
act_layer=act_layer, |
|
drop=cfg.proj_drop) |
|
self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() |
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def _partition_attn(self, x): |
|
img_size = x.shape[-2:] |
|
if self.partition_block: |
|
partitioned = window_partition_nchw(x, self.partition_size) |
|
else: |
|
partitioned = grid_partition_nchw(x, self.partition_size) |
|
|
|
partitioned = self.attn(partitioned) |
|
|
|
if self.partition_block: |
|
x = window_reverse_nchw(partitioned, self.partition_size, img_size) |
|
else: |
|
x = grid_reverse_nchw(partitioned, self.partition_size, img_size) |
|
return x |
|
|
|
def forward(self, x): |
|
x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) |
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) |
|
return x |
|
|
|
|
|
class MaxxVitBlock(nn.Module): |
|
""" MaxVit conv, window partition + FFN , grid partition + FFN |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: int, |
|
stride: int = 1, |
|
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), |
|
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
use_nchw_attn: bool = False, |
|
drop_path: float = 0., |
|
): |
|
super().__init__() |
|
|
|
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock |
|
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) |
|
|
|
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) |
|
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl |
|
self.nchw_attn = use_nchw_attn |
|
self.attn_block = partition_layer(**attn_kwargs) |
|
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) |
|
|
|
def init_weights(self, scheme=''): |
|
named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) |
|
named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) |
|
named_apply(partial(_init_conv, scheme=scheme), self.conv) |
|
|
|
def forward(self, x): |
|
|
|
x = self.conv(x) |
|
|
|
if not self.nchw_attn: |
|
x = x.permute(0, 2, 3, 1) |
|
x = self.attn_block(x) |
|
x = self.attn_grid(x) |
|
if not self.nchw_attn: |
|
x = x.permute(0, 3, 1, 2) |
|
return x |
|
|
|
|
|
class ParallelMaxxVitBlock(nn.Module): |
|
""" MaxVit block with parallel cat(window + grid), one FF |
|
Experimental timm block. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim, |
|
dim_out, |
|
stride=1, |
|
num_conv=2, |
|
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), |
|
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
drop_path=0., |
|
): |
|
super().__init__() |
|
|
|
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock |
|
if num_conv > 1: |
|
convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] |
|
            convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path) for _ in range(num_conv - 1)]  # distinct modules, not repeated refs to one
|
self.conv = nn.Sequential(*convs) |
|
else: |
|
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) |
|
self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) |
|
|
|
def init_weights(self, scheme=''): |
|
named_apply(partial(_init_transformer, scheme=scheme), self.attn) |
|
named_apply(partial(_init_conv, scheme=scheme), self.conv) |
|
|
|
def forward(self, x): |
|
x = self.conv(x) |
|
x = x.permute(0, 2, 3, 1) |
|
x = self.attn(x) |
|
x = x.permute(0, 3, 1, 2) |
|
return x |
|
|
|
|
|
class MaxxVitStage(nn.Module): |
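    """ A stack of 'C' (conv), 'T' (transformer), 'M' (MaxVit), or 'PM' (parallel MaxVit) blocks;
    the stage stride is applied in the first block.
    """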
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: int, |
|
stride: int = 2, |
|
depth: int = 4, |
|
feat_size: Tuple[int, int] = (14, 14), |
|
            block_types: Union[str, Tuple[str, ...]] = 'C',
|
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), |
|
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), |
|
drop_path: Union[float, List[float]] = 0., |
|
): |
|
super().__init__() |
|
self.grad_checkpointing = False |
|
|
|
block_types = extend_tuple(block_types, depth) |
|
blocks = [] |
|
for i, t in enumerate(block_types): |
|
block_stride = stride if i == 0 else 1 |
|
assert t in ('C', 'T', 'M', 'PM') |
|
if t == 'C': |
|
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock |
|
blocks += [conv_cls( |
|
in_chs, |
|
out_chs, |
|
stride=block_stride, |
|
cfg=conv_cfg, |
|
drop_path=drop_path[i], |
|
)] |
|
elif t == 'T': |
|
rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) |
|
blocks += [TransformerBlock2d( |
|
in_chs, |
|
out_chs, |
|
stride=block_stride, |
|
rel_pos_cls=rel_pos_cls, |
|
cfg=transformer_cfg, |
|
drop_path=drop_path[i], |
|
)] |
|
elif t == 'M': |
|
blocks += [MaxxVitBlock( |
|
in_chs, |
|
out_chs, |
|
stride=block_stride, |
|
conv_cfg=conv_cfg, |
|
transformer_cfg=transformer_cfg, |
|
drop_path=drop_path[i], |
|
)] |
|
elif t == 'PM': |
|
blocks += [ParallelMaxxVitBlock( |
|
in_chs, |
|
out_chs, |
|
stride=block_stride, |
|
conv_cfg=conv_cfg, |
|
transformer_cfg=transformer_cfg, |
|
drop_path=drop_path[i], |
|
)] |
|
in_chs = out_chs |
|
self.blocks = nn.Sequential(*blocks) |
|
|
|
def forward(self, x): |
|
if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
x = checkpoint_seq(self.blocks, x) |
|
else: |
|
x = self.blocks(x) |
|
return x |
|
|
|
|
|
class Stem(nn.Module): |
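    """ Two-conv stem: 3x3 stride-2 conv -> norm + act -> 3x3 conv, overall stride 2. """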
|
|
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: int, |
|
kernel_size: int = 3, |
|
act_layer: str = 'gelu', |
|
norm_layer: str = 'batchnorm2d', |
|
norm_eps: float = 1e-5, |
|
): |
|
super().__init__() |
|
if not isinstance(out_chs, (list, tuple)): |
|
out_chs = to_2tuple(out_chs) |
|
|
|
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) |
|
self.out_chs = out_chs[-1] |
|
self.stride = 2 |
|
|
|
self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) |
|
self.norm1 = norm_act_layer(out_chs[0]) |
|
self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) |
|
|
|
def init_weights(self, scheme=''): |
|
named_apply(partial(_init_conv, scheme=scheme), self) |
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
x = self.norm1(x) |
|
x = self.conv2(x) |
|
return x |
|
|
|
|
|
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): |
|
if cfg.window_size is not None: |
|
assert cfg.grid_size |
|
return cfg |
|
partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio |
|
cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) |
|
return cfg |
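
# NOTE w/ the default partition_ratio of 32, a 224x224 input gives window_size = grid_size = (7, 7),
# and a 256x256 input gives (8, 8).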
|
|
|
|
|
class MaxxVit(nn.Module): |
|
""" CoaTNet + MaxVit base model. |
|
|
|
Highly configurable for different block compositions, tensor layouts, pooling types. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
cfg: MaxxVitCfg, |
|
img_size: Union[int, Tuple[int, int]] = 224, |
|
in_chans: int = 3, |
|
num_classes: int = 1000, |
|
global_pool: str = 'avg', |
|
drop_rate: float = 0., |
|
drop_path_rate: float = 0. |
|
): |
|
super().__init__() |
|
img_size = to_2tuple(img_size) |
|
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) |
|
self.num_classes = num_classes |
|
self.global_pool = global_pool |
|
self.num_features = cfg.embed_dim[-1] |
|
self.embed_dim = cfg.embed_dim |
|
self.drop_rate = drop_rate |
|
self.grad_checkpointing = False |
|
|
|
self.stem = Stem( |
|
in_chs=in_chans, |
|
out_chs=cfg.stem_width, |
|
act_layer=cfg.conv_cfg.act_layer, |
|
norm_layer=cfg.conv_cfg.norm_layer, |
|
norm_eps=cfg.conv_cfg.norm_eps, |
|
) |
|
|
|
stride = self.stem.stride |
|
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) |
|
|
|
num_stages = len(cfg.embed_dim) |
|
assert len(cfg.depths) == num_stages |
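        # stochastic depth (drop path) rates increase linearly across all blocks, chunked per stage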
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] |
|
in_chs = self.stem.out_chs |
|
stages = [] |
|
for i in range(num_stages): |
|
stage_stride = 2 |
|
out_chs = cfg.embed_dim[i] |
|
feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) |
|
stages += [MaxxVitStage( |
|
in_chs, |
|
out_chs, |
|
depth=cfg.depths[i], |
|
block_types=cfg.block_type[i], |
|
conv_cfg=cfg.conv_cfg, |
|
transformer_cfg=transformer_cfg, |
|
feat_size=feat_size, |
|
drop_path=dpr[i], |
|
)] |
|
stride *= stage_stride |
|
in_chs = out_chs |
|
self.stages = nn.Sequential(*stages) |
|
|
|
final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) |
|
self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) |
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) |
|
|
|
|
|
assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') |
|
if cfg.weight_init: |
|
named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) |
|
|
|
def _init_weights(self, module, name, scheme=''): |
|
if hasattr(module, 'init_weights'): |
|
try: |
|
module.init_weights(scheme=scheme) |
|
except TypeError: |
|
module.init_weights() |
|
|
|
@torch.jit.ignore |
|
def no_weight_decay(self): |
|
return { |
|
k for k, _ in self.named_parameters() |
|
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} |
|
|
|
@torch.jit.ignore |
|
def group_matcher(self, coarse=False): |
|
matcher = dict( |
|
stem=r'^stem', |
|
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] |
|
) |
|
return matcher |
|
|
|
@torch.jit.ignore |
|
def set_grad_checkpointing(self, enable=True): |
|
for s in self.stages: |
|
s.grad_checkpointing = enable |
|
|
|
@torch.jit.ignore |
|
def get_classifier(self): |
|
return self.head.fc |
|
|
|
def reset_classifier(self, num_classes, global_pool=None): |
|
self.num_classes = num_classes |
|
if global_pool is None: |
|
global_pool = self.head.global_pool.pool_type |
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) |
|
|
|
def forward_features(self, x): |
|
x = self.stem(x) |
|
x = self.stages(x) |
|
x = self.norm(x) |
|
return x |
|
|
|
def forward_head(self, x, pre_logits: bool = False): |
|
return self.head(x, pre_logits=pre_logits) |
|
|
|
def forward(self, x): |
|
x = self.forward_features(x) |
|
x = self.forward_head(x) |
|
return x |
|
|
|
|
|
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): |
|
return build_model_with_cfg( |
|
MaxxVit, variant, pretrained, |
|
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], |
|
feature_cfg=dict(flatten_sequential=True), |
|
**kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_pico_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_nano_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_0_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_1_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_2_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_3_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_bn_0_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_nano_cc_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnext_nano_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_0_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_1_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_2_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_3_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_4_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def coatnet_5_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_pico_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_nano_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_tiny_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_tiny_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_tiny_pm_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_tiny_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_small_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_small_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_base_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_large_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs) |
|
|
|
|
|
@register_model |
|
def maxvit_xlarge_224(pretrained=False, **kwargs): |
|
return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs) |
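
# Quick shape check (illustrative):
#   model = maxvit_tiny_224()
#   model(torch.randn(2, 3, 224, 224)).shape  # -> torch.Size([2, 1000])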
|
|
|
|