import torch
import torch.nn as nn

from .transcription import SpeechEncoder
from .sentiment import TextEncoder


class MultimodalSentimentClassifier(nn.Module):
    """Late-fusion sentiment classifier: concatenates wav2vec2 audio features
    and BERT text features, then predicts one of `n_classes` sentiment labels."""

    def __init__(
        self,
        wav2vec_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        # wav2vec_name: str = "alec228/audio-sentiment/tree/main/wav2vec2",
        bert_name: str = "nlptown/bert-base-multilingual-uncased-sentiment",
        # bert_name: str = "alec228/audio-sentiment/tree/main/bert-sentiment",
        # cache_dir: str = "./models",
        hidden_dim: int = 256,
        n_classes: int = 3,
    ):
        super().__init__()
        self.speech_encoder = SpeechEncoder(
            model_name=wav2vec_name,
            # cache_dir=cache_dir,
        )
        self.text_encoder = TextEncoder(
            model_name=bert_name,
            # cache_dir=cache_dir,
        )

        # Fusion head: input size is the sum of both encoders' hidden sizes.
        dim_a = self.speech_encoder.model.config.hidden_size
        dim_t = self.text_encoder.model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(dim_a + dim_t, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, audio_path: str, text: str) -> torch.Tensor:
        # Extract one pooled feature vector per modality, concatenate, classify.
        a_feat = self.speech_encoder.extract_features(audio_path)
        t_feat = self.text_encoder.extract_features([text])
        fused = torch.cat([a_feat, t_feat], dim=1)
        return self.classifier(fused)
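

# Usage sketch (illustrative, not part of the original module). It assumes that
# SpeechEncoder.extract_features and TextEncoder.extract_features each return a
# pooled (1, hidden_size) tensor, as implied by the torch.cat along dim=1 above.
# The audio path and example sentence below are hypothetical placeholders.
if __name__ == "__main__":
    model = MultimodalSentimentClassifier()
    model.eval()
    with torch.no_grad():
        logits = model("example.wav", "Je suis très content du service.")
    probs = torch.softmax(logits, dim=1)
    print(probs)  # probabilities over the n_classes sentiment labels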