|
import os |
|
import json |
|
import asyncio |
|
import aiofiles |
|
from time import time |
|
import json |
|
from pprint import pprint |
|
from smallestai.waves import WavesClient, AsyncWavesClient |
|
|
|
|
|
class SmallestAITTS: |
|
def __init__( |
|
self, |
|
model_name: str, |
|
api_key: str, |
|
provider: str, |
|
endpoint_url: str, |
|
voice_id: str = None, |
|
sample_rate: int = 24000, |
|
speed: float = 1.0, |
|
is_async: bool = False, |
|
): |
|
if is_async: |
|
self.client = AsyncWavesClient(api_key=api_key) |
|
else: |
|
self.client = WavesClient(api_key=api_key) |
|
|
|
self.model_name = model_name |
|
self.api_key = api_key |
|
self.provider = provider |
|
self.endpoint_url = endpoint_url |
|
self.voice_id = voice_id |
|
self.sample_rate = sample_rate |
|
self.speed = speed |
|
self.tts = self._async_tts if is_async else self._tts |
|
self.is_async = is_async |
|
|
|
def load_voice(self, voice_id: str): |
|
""" |
|
Used for loading voices (Optional) |
|
""" |
|
self.voice_id = voice_id |
|
|
|
|
|
def synthesize(self, text: str, output_filepath: str): |
|
""" |
|
Unified interface for text-to-speech synthesis. |
|
Will automatically use async or sync implementation based on initialization. |
|
|
|
Args: |
|
text: The text to synthesize |
|
output_filepath: Path to save the audio file |
|
""" |
|
if self.is_async: |
|
|
|
try: |
|
return asyncio.get_event_loop().run_until_complete( |
|
self._async_tts(text, output_filepath) |
|
) |
|
except RuntimeError: |
|
|
|
return asyncio.run(self._async_tts(text, output_filepath)) |
|
else: |
|
return self._tts(text, output_filepath) |
|
|
|
def _tts(self, text: str, output_filepath: str): |
|
|
|
assert self.voice_id is not None, "Please set a voice style." |
|
self.client.synthesize( |
|
text, |
|
save_as=output_filepath, |
|
model=self.model_name, |
|
voice_id=self.voice_id, |
|
speed=self.speed, |
|
sample_rate=self.sample_rate, |
|
) |
|
|
|
async def _async_tts(self, text: str, output_filepath: str): |
|
|
|
assert self.voice_id is not None, "Please set a voice style." |
|
async with self.client: |
|
audio_bytes = await self.client.synthesize( |
|
text, |
|
model=self.model_name, |
|
voice_id=self.voice_id, |
|
speed=self.speed, |
|
sample_rate=self.sample_rate, |
|
) |
|
async with aiofiles.open(output_filepath, "wb") as f: |
|
await f.write(audio_bytes) |
|
|
|
|
|
def get_languages(self): |
|
return self.client.get_languages() |
|
|
|
def get_voices(self, model="lightning", voiceId=None, **kwargs) -> list: |
|
voices = json.loads(self.client.get_voices(model))["voices"] |
|
|
|
if voiceId is not None: |
|
voices = [voice for voice in voices if voice["voiceId"] == voiceId] |
|
else: |
|
for key in kwargs: |
|
voices = [ |
|
voice for voice in voices if voice["tags"][key] == kwargs[key] |
|
] |
|
|
|
return voices |
|
|
|
def get_models(self): |
|
return self.client.get_models() |