""" Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023) # Copyright (c) 2023 Satoshi Ikehata # All rights reserved. """ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange import math import torch.nn.init as init class MultiHeadAttentionBlock(nn.Module): def __init__(self, dim_in, dim_out, num_heads, ln=False, attention_dropout = 0.1, dim_feedforward = 512, q_bucket_size = 1024, k_bucket_size = 2048, attn_mode = 'Normal'): super(MultiHeadAttentionBlock, self).__init__() if attn_mode == 'Efficient': self.q_bucket_size = q_bucket_size self.k_bucket_size = k_bucket_size self.attn_mode = attn_mode self.dim_V = dim_out self.dim_Q = dim_in self.dim_K = dim_in self.num_heads = num_heads self.fc_q = nn.Linear(self.dim_Q, self.dim_V, bias=False) # dimin -> dimhidden self.fc_k = nn.Linear(self.dim_K, self.dim_V, bias=False) # dimin -> dimhidden self.fc_v = nn.Linear(self.dim_K, self.dim_V, bias=False) # dimhidden -> dim if ln: self.ln0 = nn.LayerNorm(self.dim_Q) self.ln1 = nn.LayerNorm(self.dim_V) self.dropout_attn = nn.Dropout(attention_dropout) self.fc_o1 = nn.Linear(self.dim_V, dim_feedforward, bias=False) self.fc_o2 = nn.Linear(dim_feedforward, self.dim_V, bias=False) self.dropout1 = nn.Dropout(attention_dropout) self.dropout2 = nn.Dropout(attention_dropout) # memory efficient attention related parameters # can be overriden on forward self.q_bucket_size = q_bucket_size self.k_bucket_size = k_bucket_size # memory efficient attention def summarize_qkv_chunk(self, q, k, v): weight = torch.einsum('b h i d, b h j d -> b h i j', q, k) weight_max = weight.amax(dim = -1, keepdim = True).detach() weight = weight - weight_max exp_weight = self.dropout_attn(weight.exp()) # attention_dropout weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v) return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...') def memory_efficient_attention( self, q, k, v, q_bucket_size = 512, k_bucket_size = 1024, eps = 1e-8, ): scale = q.shape[-1] ** -0.5 q = q * scale summarize_qkv_fn = self.summarize_qkv_chunk # chunk all the inputs q_chunks = q.split(q_bucket_size, dim = -2) k_chunks = k.split(k_bucket_size, dim = -2) v_chunks = v.split(k_bucket_size, dim = -2) # loop through all chunks and accumulate values = [] weights = [] for q_chunk in q_chunks: exp_weights = [] weighted_values = [] weight_maxes = [] for (k_chunk, v_chunk) in zip(k_chunks, v_chunks): exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( q_chunk, k_chunk, v_chunk ) exp_weights.append(exp_weight_chunk) weighted_values.append(weighted_value_chunk) weight_maxes.append(weight_max_chunk) weight_maxes = torch.stack(weight_maxes, dim = -1) weighted_values = torch.stack(weighted_values, dim = -1) exp_weights = torch.stack(exp_weights, dim = -1) global_max = weight_maxes.amax(dim = -1, keepdim = True) renorm_factor = (weight_maxes - global_max).exp().detach() exp_weights = exp_weights * renorm_factor weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c') all_values = weighted_values.sum(dim = -1) all_weights = exp_weights.sum(dim = -1) values.append(all_values) weights.append(all_weights) values = torch.cat(values, dim=2) weights = torch.cat(weights, dim=2) # (rearrange(weights, '... -> ... 1') normalized_values = values / (rearrange(weights, '... -> ... 
    def forward(
        self,
        x, y,
    ):
        x = x if getattr(self, 'ln0', None) is None else self.ln0(x)  # pre-normalization
        Q = self.fc_q(x)  # dim_in -> embed dim
        K, V = self.fc_k(y), self.fc_v(y)  # dim_in -> embed dim

        dim_split = self.dim_V // self.num_heads  # per-head dimension for multi-head attention
        if self.attn_mode == 'Efficient':
            q_bucket_size = self.q_bucket_size
            k_bucket_size = self.k_bucket_size
            Q_ = torch.stack(Q.split(int(dim_split), 2), 1)
            K_ = torch.stack(K.split(int(dim_split), 2), 1)
            V_ = torch.stack(V.split(int(dim_split), 2), 1)
            A = self.memory_efficient_attention(Q_, K_, V_, q_bucket_size=q_bucket_size, k_bucket_size=k_bucket_size)
            A = A.reshape(-1, A.shape[2], A.shape[3])
            Q_ = Q_.reshape(-1, Q_.shape[2], Q_.shape[3])
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)
        else:  # Basic
            Q_ = torch.cat(Q.split(int(dim_split), 2), 0)
            K_ = torch.cat(K.split(int(dim_split), 2), 0)
            V_ = torch.cat(V.split(int(dim_split), 2), 0)
            # NOTE: scales by sqrt(dim_V) rather than the per-head sqrt(dim_split);
            # kept as released, although this may not be the standard multi-head scaling
            A = self.dropout_attn(torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2))
            A = A.bmm(V_)  # A(Q, K, V) attention output
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)

        O_ = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        O = O + self.dropout2(self.fc_o2(self.dropout1(F.gelu(self.fc_o1(O_)))))
        return O


class SAB(nn.Module):  # self-attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout=0.1, dim_feedforward=512, attn_mode='Normal'):
        super(SAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(X, X)


class CAB(nn.Module):  # cross-attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout=0.1, dim_feedforward=512, attn_mode='Normal'):
        super(CAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)

    def forward(self, q, kv):
        return self.mab(q, kv)


class PMA(nn.Module):  # pooling by multi-head attention over learned seed tokens
    def __init__(self, dim, num_heads, num_seeds, ln=False, attn_mode='Normal'):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        init.xavier_uniform_(self.S)
        self.mab = MultiHeadAttentionBlock(dim, dim, num_heads, ln=ln, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


class CommunicationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CommunicationBlock, self).__init__()
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)

    def forward(self, x):
        x = self.enc(x)
        return x


class CrossAttentionBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CrossAttentionBlock, self).__init__()
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(CAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(CAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.layers = nn.ModuleList(modules_enc)

    def forward(self, q, kv):
        for k in range(len(self.layers)):
            q = self.layers[k](q, kv)  # q: query tokens, kv: key/value tokens
        return q  # output token length matches that of the query q


class AggregationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, num_outputs=1, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(AggregationBlock, self).__init__()
        self.num_outputs = num_outputs
        self.dim_hidden = dim_hidden
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)
        modules_dec = []
        modules_dec.append(PMA(dim_hidden, num_heads, num_outputs, attn_mode=attn_mode))  # no dropout should follow the PMA
        self.dec = nn.Sequential(*modules_dec)

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        x = x.view(-1, self.num_outputs * self.dim_hidden)
        return x
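

# ----------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# release). All tensor sizes below are arbitrary assumptions. It runs the
# three public blocks on random inputs to check output shapes, and verifies
# that memory_efficient_attention matches dense softmax attention once
# dropout is disabled via eval().
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    B, N, C = 2, 64, 384  # batch size, token count, channel width (arbitrary)
    x = torch.randn(B, N, C)   # tokens to be mixed / aggregated
    q = torch.randn(B, 16, C)  # query tokens for cross attention

    comm = CommunicationBlock(dim_input=C, num_enc_sab=1, dim_hidden=C,
                              num_heads=8, ln=True).eval()
    print(comm(x).shape)       # -> torch.Size([2, 64, 384])

    cross = CrossAttentionBlock(dim_input=C, num_enc_sab=1, dim_hidden=C,
                                num_heads=8, ln=True).eval()
    print(cross(q, x).shape)   # -> torch.Size([2, 16, 384])

    agg = AggregationBlock(dim_input=C, num_enc_sab=1, num_outputs=1,
                           dim_hidden=C, num_heads=8, ln=True,
                           use_efficient_attention=True).eval()
    print(agg(x).shape)        # -> torch.Size([2, 384])

    # chunked attention vs. dense softmax attention on one (batch, head) layout
    mab = MultiHeadAttentionBlock(C, C, num_heads=8, attn_mode='Efficient').eval()
    qh = torch.randn(B, 8, N, C // 8)
    kh = torch.randn(B, 8, N, C // 8)
    vh = torch.randn(B, 8, N, C // 8)
    with torch.no_grad():
        chunked = mab.memory_efficient_attention(qh, kh, vh,
                                                 q_bucket_size=16, k_bucket_size=16)
        dense = torch.softmax(qh @ kh.transpose(-2, -1) / math.sqrt(C // 8),
                              dim=-1) @ vh
    print(torch.allclose(chunked, dense, atol=1e-5))  # -> True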