import json
import logging
import math
import re
import warnings
from typing import Any, List, Optional, Tuple, Union

import librosa
import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torch import nn
from torchvision.transforms.functional import InterpolationMode
from transformers import GenerationConfig, Qwen3ForCausalLM, WhisperFeatureExtractor
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.hub import cached_file

from .configuration_interactiveomni import InteractiveOmniConfig
from .conversation import get_conv_template
from .modeling_flow import CausalMaskedDiffWithXvec
from .modeling_hifigan import HiFTGenerator
from .modeling_intern_vit import InternVisionModel
from .modeling_voicelm import VoiceLM
from .modeling_whisper import AudioWhisperModel

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
AUDIO_START_TOKEN = '<audio>'
AUDIO_END_TOKEN = '</audio>'
AUDIO_CONTEXT_TOKEN = '<AUDIO_CONTEXT>'


class InteractiveOmniModel(PreTrainedModel):
    config_class = InteractiveOmniConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['InternVisionModel', 'AudioWhisperModel', 'Qwen3DecoderLayer', 'Qwen2DecoderLayer']

    def __init__(self, config: InteractiveOmniConfig, vision_model=None, language_model=None, audio_model=None):
        super().__init__(config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        self.audio_feature_extractor = WhisperFeatureExtractor(**config.audio_preprocessor_config)
        self.transform = self.build_transform(input_size=image_size)

        # Speaker-embedding resources are attached later in from_pretrained().
        self.campplus_session = None
        self.default_speaker_embedding = None
        self.default_wav_path = None

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if audio_model is not None:
            self.audio_model = audio_model
        else:
            self.audio_model = AudioWhisperModel(config.audio_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            self.language_model = Qwen3ForCausalLM(config.llm_config)

        # Speech generation stack: speech-token LM, flow-matching mel decoder and HiFi-GAN vocoder.
        self.voicelm_model = VoiceLM(config.voicelm_config)
        self.flow_model = CausalMaskedDiffWithXvec(config.flow_config).float()
        self.hifigan_model = HiFTGenerator(config.hifigan_config).float()

        vit_hidden_size = config.vision_config.hidden_size
        audio_hidden_size = config.audio_config.d_model
        llm_hidden_size = config.llm_config.hidden_size

        # Projects pixel-shuffled ViT features into the LLM embedding space.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )
        # Projects Whisper encoder features into the LLM embedding space.
        self.mlp2 = nn.Sequential(
            nn.LayerNorm(audio_hidden_size),
            nn.Linear(audio_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Projects LLM hidden states into the VoiceLM input space.
        self.mlp_llm2voicelm = nn.Sequential(
            nn.LayerNorm(llm_hidden_size),
            nn.Linear(llm_hidden_size, config.voicelm_config.llm_input_size),
            nn.GELU(),
            nn.Linear(config.voicelm_config.llm_input_size, config.voicelm_config.llm_input_size)
        )
        # Gate used by fusion() to blend hidden states with token embeddings.
        self.gate = nn.Sequential(
            nn.Linear(2 * llm_hidden_size, llm_hidden_size),
            nn.Sigmoid()
        )

        self.img_context_token_id = None
        self.audio_context_token_id = None
        self.neftune_alpha = None

        self.post_init()

    def fusion(self, rep, emb):
        '''Gated fusion of hidden states and token embeddings.'''
        gate = self.gate(torch.cat([rep, emb], dim=-1))
        return rep * gate + emb * (1 - gate)

    def __load_campplus_session(self, campplus_path: str):
        '''Load the campplus speaker-embedding model as an ONNX Runtime session.'''
        logger.info(f"load campplus session: {campplus_path}")
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        campplus_session = onnxruntime.InferenceSession(
            campplus_path,
            sess_options=option,
            providers=["CPUExecutionProvider"],
        )
        self.campplus_session = campplus_session
        return campplus_session

    def extract_speaker_embedding(self, prompt_wav: str):
        '''extract speaker embedding tensor'''
        logger.info(f"extract speaker embedding: {prompt_wav}")
        target_sr = 16000
        prompt_speech_16k, sample_rate = torchaudio.load(prompt_wav)
        prompt_speech_16k = prompt_speech_16k.mean(dim=0, keepdim=True)
        if sample_rate != target_sr:
            assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
            prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(prompt_speech_16k)

        feat = kaldi.fbank(
            prompt_speech_16k,
            num_mel_bins=80,
            dither=0,
            sample_frequency=target_sr,
        )
        feat = feat - feat.mean(dim=0, keepdim=True)
        speaker_embedding = self.campplus_session.run(
            None,
            {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()},
        )[0].flatten().tolist()
        speaker_embedding = torch.tensor([speaker_embedding])
        return speaker_embedding

    def build_transform(self, input_size):
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
        return transform

    def find_closest_aspect_ratio(self, image, min_num=1, max_num=6, image_size=448):
        assert min_num == 1
        original_width, original_height = image.size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (image_size * image_size)
        multiple = min(math.ceil(ratio), max_num)
        if multiple <= 1:
            return [1, 1]
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i > max_num:
                continue
            candidate_split_grids_nums.append(i)

        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1
        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error

        return best_grid

    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
        target_aspect_ratio = self.find_closest_aspect_ratio(image, min_num, max_num, image_size)
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images

    def load_image(self, image, input_size=448, max_num=12):
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert('RGB')
        images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        return images
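
    # Added illustration (not part of the original code): with these defaults a
    # 1920x1080 image has area ratio ~10.3 relative to a 448x448 tile, so the
    # candidate grid counts are {10, 11, 12}; the grid whose aspect ratio is
    # closest in log-space to 1920/1080 is 4x3, i.e. 12 tiles, plus one
    # thumbnail crop because use_thumbnail=True (13 crops in total).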

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # (N, W, H, C) -> (N, W, H*scale, C/scale)
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # (N, W, H*scale, C/scale) -> (N, H*scale, W, C/scale)
        x = x.permute(0, 2, 1, 3).contiguous()
        # (N, H*scale, W, C/scale) -> (N, H*scale, W*scale, C/(scale^2))
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]

        if self.training and self.neftune_alpha is not None:
            vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
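
    # Added note: assuming the default 448x448 tiles and a 14x14 ViT patch
    # size, each tile yields 32x32 = 1024 patch embeddings; pixel_shuffle with
    # downsample_ratio=0.5 regroups them into 16x16 = 256 tokens with 4x the
    # channel width, and mlp1 projects those into the LLM embedding space,
    # matching num_image_token computed in __init__.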

    def get_T_after_cnn(self, L_in, dilation=1):
        # Two convolution layers of the audio front-end: (padding, kernel, stride) = (1, 3, 1) then (1, 3, 2).
        for (padding, kernel_size, stride) in [(1, 3, 1), (1, 3, 2)]:
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
            L_in = L_out
        return L_out

    def process_audio(self, audio, return_tensors, sampling_rate=16000):
        # Cap the audio at 480000 samples (30 s at 16 kHz), the Whisper feature window.
        L = (audio.shape[0] if audio.shape[0] <= 480000 else 480000)
        mel_len = L // 160
        audio_len_after_cnn = self.get_T_after_cnn(mel_len)
        audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
        inputs = self.audio_feature_extractor(audio, return_tensors=return_tensors, sampling_rate=sampling_rate)
        inputs['audio_len_after_cnn'] = torch.tensor(audio_len_after_cnn, dtype=torch.long)
        inputs['audio_token_num'] = torch.tensor(audio_token_num, dtype=torch.long)
        return inputs
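
    # Added worked example: 30 s of 16 kHz audio -> 480000 samples -> mel
    # length 480000 // 160 = 3000 -> 3000 frames after the stride-1 conv and
    # 1500 after the stride-2 conv (get_T_after_cnn) -> (1500 - 2) // 2 + 1
    # = 750 audio context tokens.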

    def load_audio(self, audio_file, sampling_rate=16000):
        audio_values, _ = librosa.load(audio_file, sr=sampling_rate)

        audio_process_values = self.process_audio(audio_values, sampling_rate=sampling_rate, return_tensors="pt")
        input_features = audio_process_values['input_features']
        audio_len_after_cnn = audio_process_values['audio_len_after_cnn']
        audio_token_num = audio_process_values['audio_token_num']

        audio_input_dict = {'audio_values': input_features,
                            'audio_len_after_cnn': audio_len_after_cnn,
                            'audio_token_num': audio_token_num,
                            }
        return audio_input_dict

    def extract_audio_feature(self, audio_values, audio_len_after_cnn):
        audio_values = audio_values.squeeze(1)
        max_len_in_batch = int(torch.max(audio_len_after_cnn).item())
        padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(dtype=audio_values.dtype,
                                                                               device=audio_values.device)
        for index in range(len(audio_values)):
            padding_mask[index, :int(audio_len_after_cnn[index].item())] = 0

        last_hidden_state = self.audio_model(audio_values, padding_mask, audio_len_after_cnn)
        audio_embeds = self.mlp2(last_hidden_state)
        return audio_embeds

    def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32):
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(num_segments)
        ])
        return frame_indices

    def load_video(self, video_path, bound=None, num_segments=32):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
        frames = list()
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
            frames.append(img)
        return frames
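
    # Added note: frames are sampled at the mid-points of num_segments equal
    # intervals across the clip (see get_index); in chat() the `frame`
    # argument sets num_segments, and each sampled frame is then tiled by
    # load_image like a still image.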

    def find_second_last_occurrence(self, input_ids_list, target_id):
        '''Return the index of the second-to-last occurrence of target_id, or -1 if there is none.'''
        reversed_list = list(reversed(input_ids_list))
        first_occurrence = -1
        second_occurrence = -1
        for idx, val in enumerate(reversed_list):
            if val == target_id:
                if first_occurrence == -1:
                    first_occurrence = idx
                elif second_occurrence == -1:
                    second_occurrence = idx
                    break

        if second_occurrence == -1:
            return -1
        return len(input_ids_list) - second_occurrence - 1

    def decode_speech_tokens(
        self,
        speech_tokens,
        speaker_embedding=None,
        flow_prompt_speech_token=None,
        prompt_speech_feat=None,
        finalize=True,
        token_offset=0,
    ):
        if speaker_embedding is None:
            speaker_embedding = torch.zeros(1, 192)
        if flow_prompt_speech_token is None:
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
        if prompt_speech_feat is None:
            prompt_speech_feat = torch.zeros(1, 0, 80)

        self.flow_model.encoder.static_chunk_size = 2 * self.flow_model.input_frame_rate
        self.flow_model.decoder.estimator.static_chunk_size = 2 * self.flow_model.input_frame_rate * self.flow_model.token_mel_ratio
        device = speech_tokens.device

        tts_mel, _ = self.flow_model.inference(
            token=speech_tokens.to(device),
            token_len=torch.tensor([speech_tokens.shape[1]], dtype=torch.int32).to(device),
            prompt_token=flow_prompt_speech_token.to(device),
            prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(device),
            prompt_feat=prompt_speech_feat.to(device),
            prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(device),
            embedding=speaker_embedding.to(device),
            finalize=finalize,
        )
        tts_mel = tts_mel[:, :, token_offset * self.config.flow_config.token_mel_ratio:]

        hift_cache_source = torch.zeros(1, 1, 0)
        tts_speech, tts_source = self.hifigan_model.inference(speech_feat=tts_mel, cache_source=hift_cache_source)

        return tts_speech
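
    # Added note: decode_speech_tokens is the two-stage vocoder path -- the
    # causal flow-matching model maps discrete speech tokens to a mel
    # spectrogram (token_mel_ratio mel frames per token, skipping the first
    # token_offset tokens), and HiFTGenerator (HiFi-GAN) converts that mel
    # into the output waveform.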

    @torch.no_grad()
    def generate(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor,
            attention_mask: torch.LongTensor,
            visual_features: Optional[torch.FloatTensor] = None,
            audio_values: Optional[torch.FloatTensor] = None,
            audio_len_after_cnn: Optional[torch.Tensor] = None,
            audio_token_num: Optional[torch.Tensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            start_token_id: int = 151644,
            generate_audio: bool = False,
            speaker_embedding: torch.Tensor = torch.zeros(1, 192),
            mix_ratio: list = [5, 25],
            **generate_kwargs,
    ) -> Tuple[torch.LongTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        assert self.img_context_token_id is not None
        assert self.audio_context_token_id is not None

        vit_embeds = None
        if visual_features is not None:
            vit_embeds = visual_features
        elif pixel_values is not None:
            vit_embeds = self.extract_feature(pixel_values)
        cur_conv_start_id = self.find_second_last_occurrence(input_ids.tolist()[0], start_token_id)

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        input_ids = input_ids.reshape(B * N)

        if vit_embeds is not None:
            selected = (input_ids == self.img_context_token_id)
            input_embeds[selected] = vit_embeds.reshape(-1, C)

        if audio_values is not None and audio_len_after_cnn is not None and audio_token_num is not None:
            audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn)
            output_audios = []
            for i in range(len(audio_token_num)):
                token_num = int(audio_token_num[i].item())
                audio = audio_embeds[i][:token_num]
                output_audios.append(audio)
            output_audios = torch.cat(output_audios, dim=0)
            selected = (input_ids == self.audio_context_token_id)
            input_embeds[selected] = output_audios.reshape(-1, C)

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states or generate_audio,
            return_dict_in_generate=generate_audio,
            use_cache=True,
            **generate_kwargs,
        )
        if not generate_audio:
            return outputs, None, None

        # Hidden states aligned with the sampled answer tokens, fused with their token embeddings.
        hidden_states = torch.cat(
            [outputs.hidden_states[0][-1][:, -1:, :]] + [outputs.hidden_states[i][-1] for i in range(1, len(outputs.hidden_states))],
            dim=1,
        )
        sampled_token = outputs.sequences
        if sampled_token.shape[1] == hidden_states.shape[1] + 1:
            sampled_token = sampled_token[:, 1:]
        sampled_token_embeddings = self.language_model.get_input_embeddings()(sampled_token)
        target_text_token_hidden_states = self.fusion(hidden_states, sampled_token_embeddings)

        # Hidden states of the current turn's prompt tokens, fused with the corresponding input embeddings.
        input_token_hidden_states = outputs.hidden_states[0][-1][:, cur_conv_start_id:-1, :]
        question_input_embeddings = input_embeds[:, cur_conv_start_id + 1:, :]
        input_token_hidden_states = self.fusion(input_token_hidden_states, question_input_embeddings)

        input_feature = self.mlp_llm2voicelm(input_token_hidden_states)
        target_text_feature = self.mlp_llm2voicelm(target_text_token_hidden_states)

        try:
            speech_tokens = self.voicelm_model.inference_bistream(input_feature, target_text_feature, mix_ratio=mix_ratio)
            speech_tokens = torch.LongTensor([speech_tokens]).to(input_feature.device)
            tts_speech = self.decode_speech_tokens(
                speech_tokens,
                speaker_embedding=speaker_embedding,
            )
        except Exception as e:
            logger.warning(f"voicelm inference failed: {e}")
            return outputs.sequences, None, None
        return outputs.sequences, speech_tokens, tts_speech

    def chat(
            self,
            tokenizer,
            generation_config,
            messages,
            max_patch_num=12,
            frame=8,
            generate_audio=False,
            speaker_embedding=torch.zeros(1, 192),
            print_flag=True,
    ):
        if self.flow_model.dtype != torch.float32 or self.hifigan_model.dtype != torch.float32:
            logger.info("reset flow model and hifigan model dtype to float32")
            self.reset_vocoder()
        if messages is None or len(messages) == 0:
            raise RuntimeError('no messages')
        role_transfer_dict = {
            'system': ['user'],
            'user': ['assistant'],
            'assistant': ['user'],
        }

        first_role = ['system', 'user']
        last_role = ['user']
        if messages[-1]['role'] not in last_role:
            raise RuntimeError(f"last role error, expect {last_role}, but got {messages[-1]}")

        current_role = None
        dynamic_images = list()
        dynamic_nums = list()
        audio_values = list()
        audio_len_after_cnn = list()
        audio_token_num = list()
        template = get_conv_template(self.template)
        for index in range(len(messages)):
            text = ''
            audios = list()
            images = list()
            message = messages[index]
            if index == 0:
                if message['role'] not in first_role:
                    raise RuntimeError(f'first role error, expect {first_role}, but got {message}')
            else:
                if message['role'] not in current_role:
                    raise RuntimeError(f'role error, expect {current_role}, but got {message}')
            current_role = message['role']
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if item['type'] == 'text':
                        if item.get('text', None) is None:
                            continue
                        text += item['text']
                    elif item['type'] == 'audio':
                        if item.get('audio', None) is None:
                            continue
                        if type(item['audio']) is list:
                            assert len(item['audio']) == 1, f'only support 1 audio file in round, but got {item["audio"]}'
                            audio = item['audio'][0]
                        else:
                            audio = item['audio']
                        audios.append(audio)
                    elif item['type'] == 'image':
                        if item.get('image', None) is None:
                            continue
                        if type(item['image']) is not list:
                            images.append(item['image'])
                        else:
                            images.extend(item['image'])
                    elif item['type'] == 'video':
                        if item.get('video', None) is None:
                            continue
                        if type(item['video']) is list:
                            assert len(item['video']) == 1, f'only support 1 video file in round, but got {item["video"]}'
                            video = item['video'][0]
                        else:
                            video = item['video']
                        frames = self.load_video(video, num_segments=frame)
                        images.extend(frames)
            else:
                assert isinstance(message["content"], str), message["content"]
                text = message["content"]

            if len(audios) != 0:
                assert len(audios) == 1, f'only support 1 audio file in round, but got {audios}'
                if '<audio>' in text:
                    matches = re.findall(r"<audio>", text)
                    assert len(matches) == len(audios), f'<audio> error {text} {len(audios)}' + text
                    text = re.sub(r'(<audio>)(?!\n)', r'\1\n', text)
                else:
                    text = '<audio>\n' * len(audios) + text

                audio_path = audios[0]
                audio_input_dict = self.load_audio(audio_path)
                assert audio_input_dict['audio_token_num'].item() != 0, f'audio_token_num of {audio_path} is 0.'
                audio_values.append(audio_input_dict['audio_values'])
                audio_len_after_cnn.append(audio_input_dict['audio_len_after_cnn'])
                audio_token_num.append(audio_input_dict['audio_token_num'])

            if len(images) != 0:
                if '<image>' in text:
                    matches = re.findall(r"<image>", text)
                    assert len(matches) == len(images), f'<image> error {text} {len(images)}' + text
                    text = re.sub(r'(<image>)(?!\n)', r'\1\n', text)
                else:
                    text = '<image>\n' * len(images) + text

                for image in images:
                    dynamic_image = self.load_image(image, max_num=max_patch_num)
                    dynamic_images += dynamic_image
                    dynamic_nums.append(len(dynamic_image))

            if message['role'] == 'system':
                template.set_system_message(text)
            elif message['role'] == 'user':
                template.append_message(template.roles[0], text)
            elif message['role'] == 'assistant':
                template.append_message(template.roles[1], text)
            else:
                raise ValueError('unexpected role')

            current_role = role_transfer_dict[current_role]

        template.append_message(template.roles[1], None)

        if len(audio_values) != 0:
            audio_values = torch.cat(audio_values, dim=0).to(dtype=self.dtype).cuda()
            audio_len_after_cnn = torch.stack(audio_len_after_cnn, dim=0)
            audio_token_num = torch.stack(audio_token_num, dim=0)
        else:
            audio_values = None
            audio_len_after_cnn = None
            audio_token_num = None

        if len(dynamic_images) != 0:
            pixel_values = [self.transform(image) for image in dynamic_images]
            pixel_values = torch.stack(pixel_values)
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
        else:
            pixel_values = None
            dynamic_nums = None

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id
        audio_context_token_id = tokenizer.convert_tokens_to_ids(AUDIO_CONTEXT_TOKEN)
        self.audio_context_token_id = audio_context_token_id

        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
        start_token_id = tokenizer.convert_tokens_to_ids(["<|im_start|>"])[0]

        query = template.get_prompt()

        if audio_values is not None:
            if print_flag:
                logger.info(f'audio num: {len(audio_token_num)}')
            audio_tokens_list = list()
            for index in range(len(audio_token_num)):
                audio_token_num_i = audio_token_num[index]
                if print_flag:
                    logger.info(f'audio_token_num: {audio_token_num_i}')
                audio_tokens = AUDIO_START_TOKEN + AUDIO_CONTEXT_TOKEN * audio_token_num_i + AUDIO_END_TOKEN
                audio_tokens_list.append(audio_tokens)

            audio_tokens_iter = iter(audio_tokens_list)
            query = re.sub(r"<audio>", lambda match: next(audio_tokens_iter), query)

        if pixel_values is not None:
            if print_flag:
                logger.info(f'image num: {len(dynamic_nums)}')
            image_tokens_list = list()
            total_dynamic_num = 0
            for index in range(len(dynamic_nums)):
                dynamic_num = dynamic_nums[index]
                total_dynamic_num += dynamic_num
                if print_flag:
                    logger.info(f'dynamic ViT batch size: {dynamic_num}')
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * dynamic_num + IMG_END_TOKEN
                image_tokens_list.append(image_tokens)
            assert total_dynamic_num == pixel_values.shape[0], f'dynamic num not equal, {total_dynamic_num}, {pixel_values.shape[0]}'

            image_tokens_iter = iter(image_tokens_list)
            query = re.sub(r"<image>", lambda match: next(image_tokens_iter), query)

        model_inputs = tokenizer(query, return_tensors='pt', add_special_tokens=False)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        generation_config['eos_token_id'] = eos_token_id
        generation_output, speech_token, audio_bytes = self.generate(
            pixel_values=pixel_values,
            audio_values=audio_values,
            audio_len_after_cnn=audio_len_after_cnn,
            audio_token_num=audio_token_num,
            input_ids=input_ids,
            attention_mask=attention_mask,
            generate_audio=generate_audio,
            start_token_id=start_token_id,
            speaker_embedding=speaker_embedding,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=False)[0]
        response = response.split("<|im_end|>")[0].replace('<|endoftext|>', '').strip()
        query_to_print = query
        if pixel_values is not None:
            query_to_print = query_to_print.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
        if audio_values is not None:
            query_to_print = query_to_print.replace(AUDIO_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{AUDIO_START_TOKEN}{AUDIO_END_TOKEN}', '<audio>')
        if print_flag:
            logger.info('query: ' + json.dumps(query_to_print, ensure_ascii=False))
            logger.info('response: ' + response)

        if generate_audio:
            return response, audio_bytes
        return response

    def __cache_file(self, pretrained_model_name_or_path: str, filename: str, **kw):
        '''Resolve (and, if needed, download) an auxiliary file from the checkpoint.'''
        full_path = cached_file(
            pretrained_model_name_or_path,
            filename,
            subfolder=kw.pop("subfolder", None),
            cache_dir=kw.pop("cache_dir", None),
            force_download=kw.pop("force_download", False),
            proxies=kw.pop("proxies", None),
            resume_download=kw.pop("resume_download", None),
            local_files_only=kw.pop("local_files_only", False),
            token=kw.pop("use_auth_token", None),
            revision=kw.pop("revision", None),
        )
        if full_path is None:
            raise ValueError(f"{pretrained_model_name_or_path}/{filename} does not exist")
        return full_path

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path,
            *model_args,
            config=None,
            cache_dir=None,
            ignore_mismatched_sizes=False,
            force_download=False,
            local_files_only=False,
            token=None,
            revision="main",
            use_safetensors=None,
            weights_only=True,
            **kwargs,
    ):
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )
        # Attach the campplus speaker-embedding ONNX session and the default speaker embedding shipped with the checkpoint.
        campplus_path = model.__cache_file(pretrained_model_name_or_path, "campplus.onnx", **kwargs)
        model.__load_campplus_session(campplus_path)
        default_wav_path = model.__cache_file(pretrained_model_name_or_path, "taozi.wav", **kwargs)
        model.default_wav_path = default_wav_path
        model.default_speaker_embedding = model.extract_speaker_embedding(default_wav_path)

        return model
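

# ---------------------------------------------------------------------------
# Example usage (sketch, not part of the original module).  It assumes a local
# checkpoint directory that ships this modeling code together with
# campplus.onnx and taozi.wav, plus a tokenizer saved alongside it; the paths
# and generation settings below are illustrative only.
#
#   from transformers import AutoTokenizer
#
#   ckpt_dir = '/path/to/InteractiveOmni'
#   tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
#   model = InteractiveOmniModel.from_pretrained(
#       ckpt_dir, torch_dtype=torch.bfloat16
#   ).cuda().eval()
#
#   messages = [{
#       'role': 'user',
#       'content': [
#           {'type': 'image', 'image': 'example.jpg'},
#           {'type': 'text', 'text': 'Describe this image.'},
#       ],
#   }]
#   generation_config = {'max_new_tokens': 256, 'do_sample': False}
#   response = model.chat(tokenizer, generation_config, messages)
#
#   # With generate_audio=True, chat() also returns the synthesized waveform:
#   # response, tts_speech = model.chat(tokenizer, generation_config, messages,
#   #                                   generate_audio=True,
#   #                                   speaker_embedding=model.default_speaker_embedding)
# ---------------------------------------------------------------------------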