tky823 committed
Commit 0c81dc9 (verified) · 1 parent: dda975e

Upload modeling_mulan.py with huggingface_hub

Files changed (1)
  1. modeling_mulan.py +229 -0
modeling_mulan.py ADDED
@@ -0,0 +1,229 @@
+# Copyright 2025 LY Corporation
+# ported from https://huggingface.co/line-corporation/clip-japanese-base/blob/main/modeling_clyp.py
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sentence_transformers import SentenceTransformer
+from torch.nn.modules.utils import _pair
+from transformers import PreTrainedModel
+from transformers.tokenization_utils_base import BatchEncoding
+
+from .configuration_mulan import (
+    JapaneseMuLanConfig,
+    JapaneseMuLanMusicEncoderConfig,
+    JapaneseMuLanTextEncoderConfig,
+)
+from .modeling_ast import (
+    AudioSpectrogramTransformer,
+    HeadTokenAggregator,
+    PositionalPatchEmbedding,
+)
+
+
+class MuLanPreTrainedModel(PreTrainedModel):
+    config_class = JapaneseMuLanConfig
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def _init_weights(self, module: Any) -> None:
+        pass
+
+
+class MuLanModel(MuLanPreTrainedModel):
+    def __init__(self, config: JapaneseMuLanConfig) -> None:
+        super().__init__(config)
+
+        self.music_encoder = create_music_encoder(config.music_encoder_config)
+        self.text_encoder = create_text_encoder(config.text_encoder_config)
+
+    def get_music_features(
+        self, spectrogram: torch.Tensor, batch_mean: bool = True
+    ) -> torch.Tensor:
+        if batch_mean is None:
+            if self.training:
+                batch_mean = False
+            else:
+                batch_mean = True
+
+        music_embedding = self.music_encoder(spectrogram, batch_mean=batch_mean)
+
+        return music_embedding
+
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_embedding = self.text_encoder(
+            {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+            },
+            batch_mean=False,
+        )
+
+        return text_embedding
+
+
+class ModalEncoderWrapper(nn.Module):
+    """Wrapper class of modal tower."""
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        out_channels: int,
+        hidden_channels: Optional[int] = None,
+        freeze_backbone: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.backbone = backbone
+
+        if hidden_channels is None:
+            if isinstance(backbone, AudioSpectrogramTransformer):
+                backbone: AudioSpectrogramTransformer
+                hidden_channels = backbone.embedding.embedding_dim
+            elif isinstance(backbone, SentenceTransformer):
+                backbone: SentenceTransformer
+                hidden_channels = backbone[-1].word_embedding_dimension
+            else:
+                raise NotImplementedError(
+                    f"{type(backbone)} is not supported as backbone network."
+                )
+
+        self.linear = nn.Linear(hidden_channels, out_channels)
+
+        self.freeze_backbone = freeze_backbone
+
+        if self.freeze_backbone:
+            for p in self.backbone.parameters():
+                p.requires_grad = False
+
+        self.out_channels = out_channels
+
+    def forward(self, *args, batch_mean: bool = None, **kwargs) -> torch.Tensor:
+        """Forward pass of tower wrapper.
+
+        Args:
+            args (tuple): Positional arguments given to backbone.
+            kwargs (dict): Keyword arguments given to backbone.
+
+        Returns:
+            torch.Tensor: Embedding of shape (*, out_channels).
+
+        """
+        embed = self.backbone(*args, **kwargs)
+
+        if isinstance(self.backbone, SentenceTransformer):
+            if isinstance(embed, (dict, BatchEncoding)):
+                embed = embed["sentence_embedding"]
+            else:
+                raise ValueError(
+                    f"Invalid type {type(embed)} is detected as sentence transformer output."
+                )
+        else:
+            assert isinstance(embed, torch.Tensor), (
+                f"Invalid type {type(embed)} is detected."
+            )
+
+        x = self.linear(embed)
+        output = F.normalize(x, p=2, dim=-1)
+
+        if self.training:
+            assert not batch_mean
+        else:
+            if batch_mean is None:
+                batch_mean = False
+
+            if batch_mean:
+                output = output.mean(dim=0, keepdim=True)
+
+        return output
+
+
+class MusicEncoder(ModalEncoderWrapper):
+    """Alias of ModalEncoderWrapper for music modal."""
+
+
+class TextEncoder(ModalEncoderWrapper):
+    """Alias of ModalEncoderWrapper for text modal."""
+
+
+def create_music_encoder(config: JapaneseMuLanMusicEncoderConfig) -> MusicEncoder:
+    stride = _pair(config.stride)
+    n_bins = config.n_bins
+    n_frames = config.n_pretrained_frames
+    model_name = config.model_name
+    out_channels = config.out_channels
+
+    ast_prefix = "ast-"
+
+    if model_name.startswith(ast_prefix):
+        model_size = model_name[len(ast_prefix) :]
+
+        assert model_size == "base384", "Only base384 is supported as model_size."
+
+        kernel_size = (16, 16)
+        embedding_dim = 768
+        nhead = 12
+        dim_feedforward = 3072
+        activation = "gelu"
+        num_layers = 12
+        layer_norm_eps = 1e-6
+
+        embedding = PositionalPatchEmbedding(
+            embedding_dim=embedding_dim,
+            kernel_size=kernel_size,
+            stride=stride,
+            insert_cls_token=True,
+            insert_dist_token=True,
+            n_bins=n_bins,
+            n_frames=n_frames,
+        )
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            activation=activation,
+            batch_first=True,
+            norm_first=True,
+            layer_norm_eps=layer_norm_eps,
+        )
+        norm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps)
+        backbone = nn.TransformerEncoder(
+            encoder_layer, num_layers=num_layers, norm=norm
+        )
+        aggregator = HeadTokenAggregator(position=0)
+        backbone = AudioSpectrogramTransformer(
+            embedding,
+            backbone,
+            aggregator=aggregator,
+        )
+    else:
+        raise NotImplementedError(
+            f"{model_name} is not supported as model_name of MusicEncoder."
+        )
+
+    return MusicEncoder(backbone, out_channels)
+
+
+def create_text_encoder(config: JapaneseMuLanTextEncoderConfig) -> TextEncoder:
+    model_name = config.model_name
+    out_channels = config.out_channels
+
+    if model_name == "pkshatech/GLuCoSE-base-ja":
+        # NOTE: hack to avoid meta tensor error
+        backbone = SentenceTransformer(
+            model_name_or_path=model_name,
+            device="meta",
+        )
+        backbone.to_empty(device="cpu")
+    else:
+        raise NotImplementedError(
+            f"{model_name} is not supported as model_name of TextEncoder."
+        )
+
+    return TextEncoder(backbone, out_channels)
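
For context, here is a minimal usage sketch, not part of the uploaded file. It assumes the hosting repository exposes MuLanModel via auto_map so it can be loaded with trust_remote_code, that the music tower expects spectrograms shaped (batch_size, n_bins, n_frames), and that text is tokenized with the pkshatech/GLuCoSE-base-ja tokenizer named in create_text_encoder. The repository id below is a placeholder.

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder repository id; replace with the repo that actually hosts modeling_mulan.py.
repo_id = "<namespace>/<japanese-mulan-repo>"

model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()  # get_music_features defaults to batch_mean=True, which forward() only allows in eval mode

# The text backbone is a GLuCoSE SentenceTransformer, so its tokenizer is the natural choice here.
tokenizer = AutoTokenizer.from_pretrained("pkshatech/GLuCoSE-base-ja")

# Dummy log-mel spectrogram; the (n_bins, n_frames) layout is an assumption based on
# PositionalPatchEmbedding(n_bins=..., n_frames=...) above.
spectrogram = torch.randn(1, 128, 1024)

batch = tokenizer(["静かなピアノ曲"], padding=True, return_tensors="pt")  # "a quiet piano piece"

with torch.no_grad():
    music_embedding = model.get_music_features(spectrogram)
    text_embedding = model.get_text_features(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )

# Both towers L2-normalize their outputs, so the dot product is a cosine similarity.
similarity = music_embedding @ text_embedding.T
print(similarity)

The spectrogram shape and repository id are illustrative only; consult the repository card for the intended feature-extraction pipeline.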