|
|
|
|
|
from transformers import Wav2Vec2Model, Wav2Vec2Processor
|
|
|
|
from .model import FantasyTalkingAudioConditionModel
|
|
from .utils import get_audio_features
|
|
import gc, torch
|
|
|
|
def parse_audio(
    audio_path,
    num_frames,
    fps=23,
    device="cuda",
    proj_model_path="ckpts/fantasy_proj_model.safetensors",
    wav2vec_model_dir="ckpts/wav2vec",
):
    """Encode an audio file into split FantasyTalking audio-projection features.

    Loads a wav2vec processor/model and an ``AudioProjModel``, extracts
    wav2vec features for ``audio_path`` with ``get_audio_features``, projects
    them, then splits the projected sequence with
    ``FantasyTalkingAudioConditionModel.split_audio_sequence`` /
    ``split_tensor_with_padding``. Every model and intermediate tensor is
    released and the CUDA cache emptied before returning -- also when an
    exception is raised part-way through.

    Args:
        audio_path: Audio file path, forwarded to ``get_audio_features``.
        num_frames: Frame count forwarded to ``get_audio_features`` and
            ``split_audio_sequence``.
        fps: Frame rate forwarded to ``get_audio_features`` (presumably used
            to align audio features to video frames -- confirm in utils).
        device: Torch device the models are moved to for inference.
        proj_model_path: Checkpoint loaded into ``AudioProjModel`` through
            ``mmgp.offload.load_model_data``.
        wav2vec_model_dir: Directory passed to ``from_pretrained`` for both
            the wav2vec processor and the wav2vec model.

    Returns:
        Tuple ``(audio_proj_split, audio_context_lens)`` unpacked from
        ``split_tensor_with_padding``.

    Note:
        Calls ``torch.set_grad_enabled(False)``, which turns autograd off
        for the whole process and is not restored on return.
    """
    from mmgp import offload
    from accelerate import init_empty_weights
    # Relative import, consistent with the module-level ``from .model import``,
    # so the class resolves from this package however the package is imported.
    from .model import AudioProjModel

    # NOTE(review): global side effect, never undone; kept because callers
    # may rely on autograd being disabled after this call.
    torch.set_grad_enabled(False)

    fantasytalking = proj_model = wav2vec = wav2vec_processor = None
    audio_wav2vec_fea = audio_proj_fea = None
    try:
        # Only split_audio_sequence / split_tensor_with_padding are used from
        # this object below.
        fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)

        # Build the projection module under init_empty_weights (no real
        # weight allocation), then populate its weights from the checkpoint.
        with init_empty_weights():
            proj_model = AudioProjModel(768, 2048)
        offload.load_model_data(proj_model, proj_model_path)
        proj_model.to("cpu").eval().requires_grad_(False)

        wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
        wav2vec = Wav2Vec2Model.from_pretrained(
            wav2vec_model_dir, device_map="cpu"
        ).eval().requires_grad_(False)
        wav2vec.to(device)
        proj_model.to(device)

        audio_wav2vec_fea = get_audio_features(
            wav2vec, wav2vec_processor, audio_path, fps, num_frames
        )
        audio_proj_fea = proj_model(audio_wav2vec_fea)

        # NOTE(review): dim 1 is presumably the audio sequence/time axis --
        # confirm against split_audio_sequence.
        pos_idx_ranges = fantasytalking.split_audio_sequence(
            audio_proj_fea.size(1), num_frames=num_frames
        )
        audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(
            audio_proj_fea, pos_idx_ranges, expand_length=4
        )
        return audio_proj_split, audio_context_lens
    finally:
        # Drop every model/intermediate reference *before* collecting so the
        # CUDA caching allocator can actually hand their memory back. The
        # returned tuple was already evaluated, so it is unaffected. This
        # also runs when loading or inference raises.
        fantasytalking = proj_model = wav2vec = wav2vec_processor = None
        audio_wav2vec_fea = audio_proj_fea = None
        gc.collect()
        torch.cuda.empty_cache()