File size: 36,825 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 |
""" Cross-Covariance Image Transformer (XCiT) in PyTorch
Paper:
- https://arxiv.org/abs/2106.09681
Same as the official implementation, with some minor adaptations, original copyright below
- https://github.com/facebookresearch/xcit/blob/master/xcit.py
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
from functools import partial
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .vision_transformer import _cfg, Mlp
from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn
from .fx_features import register_notrace_module
def _cfg(url='', **kwargs):
    """Build the default pretrained config dict for an XCiT variant.

    Args:
        url: download URL of the pretrained checkpoint ('' when none)
        **kwargs: entries that override existing defaults or are appended as new keys

    Returns:
        dict with the shared XCiT defaults (224x224 input, bicubic, ImageNet mean/std, ...) merged with `kwargs`.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': 1.0,
        'interpolation': 'bicubic',
        'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj.0.0',
        'classifier': 'head',
    }
    # later keys win: caller-supplied kwargs replace defaults in place, new keys go at the end
    cfg.update(kwargs)
    return cfg
# Pretrained configs keyed by registered model name. Every entry is built with `_cfg`, pointing at a checkpoint
# hosted on dl.fbaipublicfiles.com/xcit. The `*_384_dist` entries override `input_size` to (3, 384, 384); all
# other entries keep the `_cfg` default of (3, 224, 224).
default_cfgs = {
    # Patch size 16
    'xcit_nano_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'),
    'xcit_nano_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'),
    'xcit_nano_12_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_tiny_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'),
    'xcit_tiny_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'),
    'xcit_tiny_12_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_tiny_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'),
    'xcit_tiny_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'),
    'xcit_tiny_24_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_small_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'),
    'xcit_small_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'),
    'xcit_small_12_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_small_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'),
    'xcit_small_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'),
    'xcit_small_24_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_medium_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'),
    'xcit_medium_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'),
    'xcit_medium_24_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_large_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'),
    'xcit_large_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'),
    'xcit_large_24_p16_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)),

    # Patch size 8
    'xcit_nano_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'),
    'xcit_nano_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'),
    'xcit_nano_12_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_tiny_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'),
    'xcit_tiny_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'),
    'xcit_tiny_12_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_tiny_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'),
    'xcit_tiny_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'),
    'xcit_tiny_24_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_small_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'),
    'xcit_small_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'),
    'xcit_small_12_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_small_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'),
    'xcit_small_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'),
    'xcit_small_24_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_medium_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'),
    'xcit_medium_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'),
    'xcit_medium_24_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)),
    'xcit_large_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'),
    'xcit_large_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'),
    'xcit_large_24_p8_384_dist': _cfg(
        url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)),
}
@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
    """
    Fourier-feature positional encoding using the sine/cosine scheme of the "Attention Is All You Need" paper,
    followed by a 1x1 convolution that projects the features to `dim` channels.
    Based on the official XCiT code
    - https://github.com/facebookresearch/xcit/blob/master/xcit.py

    Args:
        hidden_dim: number of sinusoidal features generated for each spatial axis (y and x)
        dim: number of output channels of the projection
        temperature: base of the geometric progression of frequencies
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        # y- and x-features are concatenated before projection, hence 2 * hidden_dim input channels
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        """Return a (B, dim, H, W) encoding for an H x W grid; every batch element gets the same encoding."""
        dev = self.token_projection.weight.device

        # 1-based row / column coordinates laid out on a (1, H, W) grid
        rows = torch.arange(1, H + 1, dtype=torch.float32, device=dev).unsqueeze(1).repeat(1, 1, W)
        cols = torch.arange(1, W + 1, dtype=torch.float32, device=dev).repeat(1, H, 1)
        # divide by the last row / column value so coordinates lie in (0, ~2*pi]
        rows = rows / (rows[:, -1:, :] + self.eps) * self.scale
        cols = cols / (cols[:, :, -1:] + self.eps) * self.scale

        # feature i uses frequency temperature ** (2 * floor(i / 2) / hidden_dim): sin/cos pairs share one
        freqs = torch.arange(self.hidden_dim, dtype=torch.float32, device=dev)
        freqs = self.temperature ** (2 * torch.div(freqs, 2, rounding_mode='floor') / self.hidden_dim)

        phase_y = rows[:, :, :, None] / freqs
        phase_x = cols[:, :, :, None] / freqs
        # even feature slots take sin, odd slots take cos, interleaved back along the last axis
        feat_y = torch.stack([phase_y[:, :, :, 0::2].sin(), phase_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
        feat_x = torch.stack([phase_x[:, :, :, 0::2].sin(), phase_x[:, :, :, 1::2].cos()], dim=4).flatten(3)

        # (1, H, W, 2 * hidden_dim) -> (1, 2 * hidden_dim, H, W) for the 1x1 conv
        pos = self.token_projection(torch.cat((feat_y, feat_x), dim=3).permute(0, 3, 1, 2))
        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution (padding 1) followed by BatchNorm2d, as an nn.Sequential.

    Args:
        in_planes: input channels
        out_planes: output channels (also the BatchNorm2d width)
        stride: convolution stride
    """
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_planes),
    ]
    return nn.Sequential(*layers)
class ConvPatchEmbed(nn.Module):
    """Image to Patch Embedding using multiple convolutional layers.

    A stack of stride-2 `conv3x3` (conv + BatchNorm) stages separated by activations. There are four stages
    for `patch_size=16` and three for `patch_size=8`, so the total spatial stride equals `patch_size`.

    Args:
        img_size: input image size (int or (H, W)); stored and used to compute `num_patches`
        patch_size: 8 or 16
        in_chans: number of input channels
        embed_dim: output embedding dimension (intermediate widths are fractions of it: // 8, // 4, // 2)
        act_layer: activation module class inserted between stages

    Raises:
        ValueError: if `patch_size` is not 8 or 16
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU):
        super().__init__()
        img_size = to_2tuple(img_size)
        num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if patch_size == 16:
            self.proj = torch.nn.Sequential(
                conv3x3(in_chans, embed_dim // 8, 2),
                act_layer(),
                conv3x3(embed_dim // 8, embed_dim // 4, 2),
                act_layer(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                act_layer(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        elif patch_size == 8:
            self.proj = torch.nn.Sequential(
                conv3x3(in_chans, embed_dim // 4, 2),
                act_layer(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                act_layer(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        else:
            # raising a plain str is itself an error (TypeError); raise a real exception carrying the message
            raise ValueError('For convolutional projection, patch size has to be in [8, 16]')

    def forward(self, x):
        """Embed image batch `x` and return `(tokens, (Hp, Wp))`, tokens shaped (B, Hp * Wp, embed_dim)."""
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
        return x, (Hp, Wp)
class LPI(nn.Module):
    """
    Local Patch Interaction module that allows explicit communication between tokens in local windows, to augment
    the implicit communication performed by the block diagonal scatter attention.

    Implemented as two grouped convolutions (conv1 is depthwise; conv2 is depthwise when
    `out_features == in_features`), with an activation and BatchNorm2d in between.

    Args:
        in_features: token channels C of the input
        out_features: output channels (defaults to `in_features`); `in_features` must be divisible by it,
            because conv2 uses `groups=out_features`
        act_layer: activation module class
        kernel_size: spatial kernel size of both convolutions (padding keeps H, W unchanged for odd sizes)
    """

    def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3):
        super().__init__()
        out_features = out_features or in_features

        padding = kernel_size // 2

        self.conv1 = torch.nn.Conv2d(
            in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features)
        self.act = act_layer()
        self.bn = nn.BatchNorm2d(in_features)
        self.conv2 = torch.nn.Conv2d(
            in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features)

    def forward(self, x, H: int, W: int):
        """Mix tokens locally.

        Args:
            x: tokens of shape (B, N, C) with N == H * W, laid out row-major over the H x W grid
            H, W: token grid height and width

        Returns:
            Tensor of shape (B, N, out_features).
        """
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv1(x)
        x = self.act(x)
        x = self.bn(x)
        x = self.conv2(x)
        # conv2 emits out_features channels, which need not equal C, so flatten the spatial dims
        # instead of reshaping to (B, C, N)
        x = x.flatten(2).permute(0, 2, 1)
        return x
class ClassAttentionBlock(nn.Module):
    """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239

    Only the token at index 0 (the class token) is passed through `ClassAttn` output and the MLP; see `forward`
    for how the remaining tokens are combined in the residual paths.

    Args:
        dim: token embedding dimension
        num_heads: number of attention heads
        mlp_ratio: MLP hidden width as a multiple of `dim`
        qkv_bias: enable bias in the attention projections
        drop: dropout rate of the attention output projection and of the MLP
        attn_drop: attention-weight dropout rate
        drop_path: stochastic depth rate (DropPath is only used when > 0)
        act_layer: activation used in the MLP
        norm_layer: normalization layer class (used for norm1 and norm2)
        eta: LayerScale init value; None disables LayerScale (gammas become the scalar 1.0)
        tokens_norm: if True, norm2 is applied to all tokens; otherwise only to the class token
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
        super().__init__()
        self.norm1 = norm_layer(dim)

        self.attn = ClassAttn(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        if eta is not None:  # LayerScale Initialization (no layerscale when None)
            self.gamma1 = nn.Parameter(eta * torch.ones(dim))
            self.gamma2 = nn.Parameter(eta * torch.ones(dim))
        else:
            # plain floats: multiplying by 1.0 leaves the branch outputs unscaled
            self.gamma1, self.gamma2 = 1.0, 1.0

        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
        self.tokens_norm = tokens_norm

    def forward(self, x):
        # x: token sequence with the class token at index 0 (XCiT prepends it before these blocks);
        # assumed layout (B, 1 + N, C) -- TODO confirm for other callers
        x_norm1 = self.norm1(x)
        # attention output (presumably (B, 1, C) for the class token) goes to position 0; positions 1: carry
        # the *normalized* patch tokens, so they also receive a gamma1-scaled residual update below
        x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
        x = x + self.drop_path(self.gamma1 * x_attn)
        if self.tokens_norm:
            x = self.norm2(x)
        else:
            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
        # NOTE: the residual is taken *after* norm2, and the patch tokens appear in both x_res and x below, so they
        # end up as x_res[:, 1:] + drop_path(x[:, 1:]). Presumably kept as-is to match the official implementation
        # and its weights (see the PR link in __init__) -- do not "fix" without re-validating pretrained accuracy.
        x_res = x
        cls_token = x[:, 0:1]
        cls_token = self.gamma2 * self.mlp(cls_token)
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        x = x_res + self.drop_path(x)
        return x
class XCA(nn.Module):
    """Cross-Covariance Attention (XCA).

    Attention is computed across channels instead of tokens. For each head:
    - queries and keys are L2-normalised along the token axis;
    - they form a (head_dim x head_dim) cross-covariance matrix;
    - that matrix is multiplied by a learnable per-head temperature and softmax-normalised;
    - the resulting weights mix the value channels.

    Args:
        dim: token embedding dimension C (split evenly across heads)
        num_heads: number of heads
        qkv_bias: enable bias in the fused qkv projection
        attn_drop: dropout on the attention weights
        proj_drop: dropout after the output projection
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, head_dim, N): channels become the attended axis, tokens the last axis.
        # unbind (rather than tuple-unpacking a tensor) keeps TorchScript happy.
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 4, 1).unbind(0)

        # paper section 3.2: l2-normalisation and temperature scaling
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        weights = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # (B, heads, head_dim, N) -> (B, N, heads, head_dim) -> (B, N, C)
        mixed = (weights @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(mixed))

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature'}
class XCABlock(nn.Module):
    """One XCiT layer: three pre-norm residual branches applied in order.

    The branches are XCA (norm1/gamma1), Local Patch Interaction (norm3/gamma3) and MLP (norm2/gamma2).
    Each branch output is scaled by a per-channel LayerScale vector initialised to `eta` and passed through
    DropPath.

    Args:
        dim: token embedding dimension
        num_heads: number of XCA heads
        mlp_ratio: MLP hidden width as a multiple of `dim`
        qkv_bias: enable bias in the XCA qkv projection
        drop: dropout for the XCA output projection and the MLP
        attn_drop: XCA attention-weight dropout
        drop_path: stochastic depth rate (DropPath is only used when > 0)
        act_layer: activation for LPI and MLP
        norm_layer: normalization layer class
        eta: LayerScale init value
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm3 = norm_layer(dim)
        self.local_mp = LPI(in_features=dim, act_layer=act_layer)

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # LayerScale vectors, created in the same order the branches are applied in forward (1, 3, 2)
        self.gamma1 = nn.Parameter(eta * torch.ones(dim))
        self.gamma3 = nn.Parameter(eta * torch.ones(dim))
        self.gamma2 = nn.Parameter(eta * torch.ones(dim))

    def forward(self, x, H: int, W: int):
        # NOTE official code applies branch "3" (LPI) before "2" (MLP); names kept so pretrained weights load.
        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(self.gamma1 * attn_out)

        local_out = self.local_mp(self.norm3(x), H, W)
        x = x + self.drop_path(self.gamma3 * local_out)

        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma2 * mlp_out)
        return x
class XCiT(nn.Module):
    """
    Cross-Covariance Image Transformer (XCiT).

    Based on timm and DeiT code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int): patch size (the convolutional patch embedding supports 8 or 16)
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head (0 -> nn.Identity head)
            global_pool (str): head pooling: 'token' (class token), 'avg' (mean of patch tokens) or
                '' (no pooling; forward_head returns all tokens)
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate (constant across all XCA blocks; not used in CA blocks)
            act_layer: (nn.Module): activation layer, defaults to nn.GELU
            norm_layer: (nn.Module): normalization layer, defaults to LayerNorm with eps=1e-6
            cls_attn_layers: (int) Depth of Class attention layers
            use_pos_embed: (bool) whether to use positional encoding
            eta: (float) layerscale initialization value
            tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA

        Notes:
            - Although `norm_layer` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
              interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        img_size = to_2tuple(img_size)
        # both height and width must be multiples of the patch size
        assert (img_size[0] % patch_size == 0) and (img_size[1] % patch_size == 0), \
            '`patch_size` should divide image dimensions evenly'
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.global_pool = global_pool
        self.grad_checkpointing = False

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            XCABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
            for _ in range(depth)])

        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
            for _ in range(cls_attn_layers)])

        # Classifier head
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Init weights
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std .02) init for Linear weights, zero for Linear biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=r'^blocks\.(\d+)',
            cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classifier head.

        Args:
            num_classes: new number of classes (0 -> nn.Identity head)
            global_pool: new pooling mode ('', 'avg', 'token'); None keeps the current mode. Note the default ''
                disables pooling, in which case forward_head applies the head to every token.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Return normalized tokens (B, 1 + N, embed_dim) with the class token at index 0."""
        B = x.shape[0]
        # x is (B, N, C). (Hp, Wp) is (height in units of patches, width in units of patches)
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding
        x = self.pos_drop(x)

        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, Hp, Wp)
            else:
                x = blk(x, Hp, Wp)

        # the class token is only introduced for the class-attention stage
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        for blk in self.cls_attn_blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool tokens per `global_pool` ('avg': mean of patch tokens, 'token': class token) and apply the head."""
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Adapt an official XCiT checkpoint's state dict to this implementation.

    - Unwraps a top-level 'model' entry if present.
    - Renames `pos_embeder.*` keys to `pos_embed.*`, or drops every `pos_embed*` key when `model` has no
      (or a None) `pos_embed` attribute.
    - If the checkpoint stores fused class-attention `qkv` projections while `model` expects separate
      `q`/`k`/`v` ones, splits each fused weight (and bias, when present) into three equal chunks.
      timm's CaiT-style class attention only computes the query for the class token, hence separate q/k/v.

    The state dict is modified in place and returned.
    """
    state_dict = state_dict.get('model', state_dict)

    keep_pos_embed = getattr(model, 'pos_embed', None) is not None
    # snapshot the matching keys first: the dict is mutated inside the loop
    for old_key in [k for k in state_dict if k.startswith('pos_embed')]:
        tensor = state_dict.pop(old_key)
        if keep_pos_embed:
            state_dict[old_key.replace('pos_embeder.', 'pos_embed.')] = tensor

    needs_qkv_split = (
        'cls_attn_blocks.0.attn.qkv.weight' in state_dict
        and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict())
    if not needs_qkv_split:
        return state_dict

    for block_idx in range(len(model.cls_attn_blocks)):
        prefix = f'cls_attn_blocks.{block_idx}.attn.'
        fused_w = state_dict.pop(prefix + 'qkv.weight')
        fused_w = fused_w.reshape(3, -1, fused_w.shape[-1])
        for part_name, part in zip('qkv', fused_w.unbind(0)):
            state_dict[prefix + part_name + '.weight'] = part
        fused_b = state_dict.pop(prefix + 'qkv.bias', None)
        if fused_b is not None:
            for part_name, part in zip('qkv', fused_b.reshape(3, -1).unbind(0)):
                state_dict[prefix + part_name + '.bias'] = part
    return state_dict
def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
    """Instantiate `XCiT` for `variant` through `build_model_with_cfg`.

    Uses `checkpoint_filter_fn` as the pretrained-weight filter. `default_cfg` is accepted but not forwarded;
    all other keyword arguments are passed on to `build_model_with_cfg`.
    """
    return build_model_with_cfg(
        XCiT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
@register_model
def xcit_nano_12_p16_224(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 16x16 patches (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    return _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 16x16 patches (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    return _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 16x16 patches, img_size=384 (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False,
        img_size=384, **kwargs)
    return _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p16_224(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p16_224(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p16_224_dist(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p16_384_dist(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p16_224(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 16x16 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p16_224(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p16_224_dist(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p16_384_dist(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 16x16 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p16_224(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 16x16 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 16x16 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 16x16 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p16_224(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 16x16 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p16_224_dist(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 16x16 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p16_384_dist(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 16x16 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **arch_args)
# Patch size 8x8 models
@register_model
def xcit_nano_12_p8_224(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 8x8 patches (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    return _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 8x8 patches (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    return _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs):
    """XCiT nano, 12 layers, 8x8 patches (embed_dim=128, num_heads=4, eta=1.0, tokens_norm=False)."""
    arch_args = dict(patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    return _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p8_224(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs):
    """XCiT tiny, 12 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p8_224(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p8_224_dist(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_12_p8_384_dist(pretrained=False, **kwargs):
    """XCiT small, 12 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1.0, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p8_224(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs):
    """XCiT tiny, 24 layers, 8x8 patches (embed_dim=192, num_heads=4, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p8_224(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p8_224_dist(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_small_24_p8_384_dist(pretrained=False, **kwargs):
    """XCiT small, 24 layers, 8x8 patches (embed_dim=384, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p8_224(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 8x8 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 8x8 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs):
    """XCiT medium, 24 layers, 8x8 patches (embed_dim=512, num_heads=8, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p8_224(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 8x8 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p8_224_dist(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 8x8 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **arch_args)
@register_model
def xcit_large_24_p8_384_dist(pretrained=False, **kwargs):
    """XCiT large, 24 layers, 8x8 patches (embed_dim=768, num_heads=16, eta=1e-5, tokens_norm=True)."""
    arch_args = dict(patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    return _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **arch_args)
|