# EarthLoc2 / models / mixvpr.py
# Pawel Piwowarski
# init commit
# 0a82b18
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
class FeatureMixerLayer(nn.Module):
    """Residual MLP block acting on the last dimension of its input.

    Computes ``x + MLP(LayerNorm(x))`` where the MLP is
    ``Linear -> ReLU -> Linear`` and its hidden width is
    ``int(in_dim * mlp_ratio)``.

    Args:
        in_dim: Size of the last dimension of the input (and of the output).
        mlp_ratio: Multiplier giving the hidden width relative to ``in_dim``.
    """

    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        hidden_dim = int(in_dim * mlp_ratio)
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_dim),
        )
        self._init_linear_layers()

    def _init_linear_layers(self):
        """Re-initialise every linear layer: truncated-normal weights, zero bias."""
        for layer in self.mix:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.trunc_normal_(layer.weight, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``x`` plus the mixed version of ``x`` (same shape as ``x``)."""
        residual = self.mix(x)
        return x + residual
class MixVPR(nn.Module):
    """Aggregate a feature map into a single L2-normalised descriptor.

    The input is flattened from dimension 2 onwards, passed through
    ``mix_depth`` stacked ``FeatureMixerLayer`` blocks, projected along the
    channel dimension to ``out_channels`` and along the flattened spatial
    dimension to ``out_rows``, then flattened and L2-normalised.

    The layers are sized so that dimension 1 of the input must equal
    ``in_channels`` and the flattened trailing dimensions must total
    ``in_h * in_w``.

    Args:
        in_channels: Depth of the input feature maps.
        in_h: Height of the input feature maps.
        in_w: Width of the input feature maps.
        out_channels: Output size of the channel-wise projection.
        mix_depth: Number of stacked ``FeatureMixerLayer`` blocks.
        mlp_ratio: Hidden-width ratio forwarded to each mixer block.
        out_rows: Output size of the row-wise projection.
    """

    def __init__(self,
                 in_channels=1024,
                 in_h=20,
                 in_w=20,
                 out_channels=512,
                 mix_depth=1,
                 mlp_ratio=1,
                 out_rows=4,
                 ) -> None:
        super().__init__()

        # Keep the full configuration available on the instance.
        self.in_channels = in_channels    # depth of input feature maps
        self.in_h = in_h                  # height of input feature maps
        self.in_w = in_w                  # width of input feature maps
        self.out_channels = out_channels  # depth-wise projection dimension
        self.out_rows = out_rows          # row-wise projection dimension
        self.mix_depth = mix_depth        # number of stacked feature mixers
        self.mlp_ratio = mlp_ratio        # hidden-width ratio inside each mixer

        spatial_size = in_h * in_w

        # Submodules are created in the order mix -> channel_proj -> row_proj.
        self.mix = nn.Sequential(*(
            FeatureMixerLayer(in_dim=spatial_size, mlp_ratio=mlp_ratio)
            for _ in range(mix_depth)
        ))
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(spatial_size, out_rows)

    def forward(self, x):
        """Return descriptors of shape ``(batch, out_channels * out_rows)``."""
        feats = self.mix(x.flatten(2))
        # Linear layers act on the last dimension, so swap the channel and
        # spatial axes around the channel projection.
        feats = self.channel_proj(feats.transpose(1, 2))
        feats = self.row_proj(feats.transpose(1, 2))
        return F.normalize(feats.flatten(1), p=2, dim=1)
# -------------------------------------------------------------------------------
def print_nb_params(m):
    """Print the number of trainable parameters of ``m``, in millions.

    Only parameters whose ``requires_grad`` flag is set are counted.

    Args:
        m: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
    """
    params = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f'Trainable parameters: {params/1e6:.3}M')
def main():
    """Smoke test: build a MixVPR aggregator, print its size and output shape."""
    batch, channels, height, width = 1, 1024, 20, 20
    x = torch.randn(batch, channels, height, width)

    config = dict(
        in_channels=channels,
        in_h=height,
        in_w=width,
        out_channels=1024,
        mix_depth=4,
        mlp_ratio=1,
        out_rows=4,
    )
    agg = MixVPR(**config)

    print_nb_params(agg)
    descriptor = agg(x)
    print(descriptor.shape)


if __name__ == '__main__':
    main()