# Scraped page banner (Hugging Face Spaces): "Spaces: Running"
import logging
import os
import tempfile
import uuid
from typing import Optional

import torch
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from TTS.api import TTS
# Environment flag consumed by Coqui TTS (terms-of-service agreement marker).
os.environ.update({"COQUI_TOS_AGREED": "1"})

# Module-wide logging: INFO and above, one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Metadata shown in the generated OpenAPI documentation.
_API_METADATA = {
    "title": "Coqui TTS C-3PO API",
    "description": "Text-to-Speech API using Coqui TTS with C-3PO fine-tuned voice model",
    "version": "1.0.0",
}

app = FastAPI(**_API_METADATA)
class TTSRequest(BaseModel):
    """Request body for the JSON text-to-speech endpoint.

    Fields:
        text: The text to synthesize (required).
        language: Language code passed to the model; defaults to ``"en"``.
    """

    text: str
    language: str = "en"
class CoquiTTSService:
    """Load a Coqui TTS model and synthesize speech files with it.

    On construction the service downloads the C-3PO fine-tuned XTTS checkpoint
    from Hugging Face and loads it. If any step of that fails, it falls back to
    the stock XTTS v2 model.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when torch reports a GPU, otherwise ``"cpu"``.
        tts: The loaded ``TTS`` instance.
        model_path: Directory of the downloaded checkpoint, or ``None`` when
            the fallback-after-error path was taken.
        default_speaker: First preset speaker exposed by the model, or
            ``None`` when the model exposes none.
    """

    C3PO_REPO_ID = "Borcherding/XTTS-v2_C3PO"
    C3PO_LOCAL_DIR = "./models/XTTS-v2_C3PO"
    FALLBACK_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

    # File names probed (in this order) inside the model directory when
    # looking for a reference recording of the C-3PO voice.
    REFERENCE_AUDIO_CANDIDATES = (
        "reference.wav", "speaker.wav", "c3po.wav",
        "sample.wav", "reference_audio.wav",
    )

    # Inclusive bounds on the length of the text accepted for synthesis.
    MIN_TEXT_LENGTH = 2
    MAX_TEXT_LENGTH = 500

    def __init__(self):
        """Pick a device and load the C-3PO model, falling back to stock XTTS v2.

        Raises:
            Exception: Whatever the fallback load raised, when both the
                fine-tuned model and the fallback model fail to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        try:
            self._load_c3po_model()
        except Exception as e:
            logger.exception(f"Failed to load C-3PO model: {e}")
            logger.info("Falling back to standard XTTS v2 model...")
            try:
                self.tts = self._load_stock_model()
            except Exception as fallback_error:
                logger.exception(f"Fallback model also failed: {fallback_error}")
                raise
            # No downloaded checkpoint directory and no preset speaker in
            # this mode.
            self.model_path = None
            self.default_speaker = None
            logger.info("Fallback XTTS v2 model loaded!")

    def _load_stock_model(self):
        """Load the stock XTTS v2 model by name and move it to ``self.device``."""
        return TTS(self.FALLBACK_MODEL_NAME).to(self.device)

    def _load_c3po_model(self):
        """Download the fine-tuned checkpoint, load it, and record speakers."""
        logger.info("Downloading C-3PO fine-tuned XTTS model from Hugging Face...")
        model_path = snapshot_download(
            repo_id=self.C3PO_REPO_ID,
            local_dir=self.C3PO_LOCAL_DIR,
            local_dir_use_symlinks=False,
        )
        logger.info(f"Model downloaded to: {model_path}")

        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            logger.info("Loading C-3PO fine-tuned model...")
            # Device placement is done with .to(); the separate `gpu=` flag is
            # redundant with it and is therefore not passed.
            self.tts = TTS(
                model_path=model_path,
                config_path=config_path,
                progress_bar=False,
            ).to(self.device)
            logger.info("C-3PO fine-tuned model loaded successfully!")
        else:
            # Without a config file the checkpoint cannot be loaded by path.
            logger.info("Config not found, trying to load by repo ID...")
            self.tts = self._load_stock_model()
            logger.info("Fallback XTTS v2 model loaded!")

        # Kept even in the "config not found" case so that reference audio
        # shipped in the download directory can still be located.
        self.model_path = model_path

        speakers = getattr(self.tts, "speakers", None)
        if speakers:
            logger.info(f"Available speakers: {len(speakers)}")
            self.default_speaker = speakers[0]
        else:
            logger.info("No preset speakers available - voice cloning mode")
            self.default_speaker = None

    def get_c3po_reference_audio(self):
        """Return the path of a C-3PO reference recording, or ``None``.

        Probes ``REFERENCE_AUDIO_CANDIDATES`` inside ``self.model_path`` and
        returns the first one that exists. Returns ``None`` when there is no
        model directory or none of the candidates is present.
        """
        if not self.model_path:
            return None
        for ref_file in self.REFERENCE_AUDIO_CANDIDATES:
            ref_path = os.path.join(self.model_path, ref_file)
            if os.path.exists(ref_path):
                logger.info(f"Found C-3PO reference audio: {ref_path}")
                return ref_path
        return None

    def generate_speech(self, text: str, speaker_wav_path: Optional[str] = None,
                        language: str = "en", use_c3po_voice: bool = True) -> str:
        """Synthesize ``text`` into a WAV file and return the file's path.

        Speaker selection, in priority order:
        1. ``speaker_wav_path`` if given (voice cloning from that recording);
        2. the C-3PO reference recording, if ``use_c3po_voice`` is true and
           one is found in the model directory;
        3. the model's first preset speaker, if it has one;
        4. no speaker at all (only works for models that allow it).

        Args:
            text: Text to speak; must be 2-500 characters long.
            speaker_wav_path: Optional path to a reference recording to clone.
            language: Language code passed to the model.
            use_c3po_voice: Whether to look for the C-3PO reference recording
                when no ``speaker_wav_path`` is given.

        Returns:
            Path of the generated WAV file inside the system temp directory.
            The caller owns the file and is responsible for deleting it.

        Raises:
            HTTPException: 400 when the text length is out of bounds, 500 when
                synthesis fails or produces no file.
        """
        if len(text) < self.MIN_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail="Text too short")
        if len(text) > self.MAX_TEXT_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Text too long (max {self.MAX_TEXT_LENGTH} characters)",
            )

        # A random name avoids collisions between concurrent requests.
        output_path = os.path.join(
            tempfile.gettempdir(), f"c3po_tts_output_{uuid.uuid4().hex}.wav"
        )

        final_speaker_wav = speaker_wav_path
        if not final_speaker_wav and use_c3po_voice:
            final_speaker_wav = self.get_c3po_reference_audio()
            if final_speaker_wav:
                logger.info("Using C-3PO reference audio for voice synthesis")

        try:
            if final_speaker_wav:
                # Voice cloning. tts_to_file lets the library write the file
                # itself, so no sample rate has to be assumed here and the
                # return type of tts() does not matter.
                logger.info("Generating speech with voice cloning...")
                self.tts.tts_to_file(
                    text=text,
                    speaker_wav=final_speaker_wav,
                    language=language,
                    file_path=output_path,
                )
            elif self.default_speaker:
                logger.info(f"Generating speech with preset speaker: {self.default_speaker}")
                self.tts.tts_to_file(
                    text=text,
                    speaker=self.default_speaker,
                    language=language,
                    file_path=output_path,
                )
            else:
                # Some models can synthesize without any speaker reference.
                logger.info("Generating speech without specific speaker...")
                self.tts.tts_to_file(
                    text=text,
                    language=language,
                    file_path=output_path,
                )

            if not os.path.exists(output_path):
                raise HTTPException(status_code=500, detail="Failed to generate audio file")
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(f"Error generating speech: {e}")
            # Do not leave a partially written file behind.
            if os.path.exists(output_path):
                try:
                    os.remove(output_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove partial output {output_path}: {cleanup_error}")
            raise HTTPException(
                status_code=500, detail=f"Speech generation failed: {str(e)}"
            ) from e

        logger.info(f"Speech generated successfully: {output_path}")
        return output_path
# Build the module-level service once at import time. A failure is logged and
# leaves `tts_service` as None so the endpoints can answer with 503.
tts_service = None
logger.info("Initializing Coqui TTS service...")
try:
    tts_service = CoquiTTSService()
except Exception as e:
    logger.error(f"Failed to initialize TTS service: {e}")
else:
    logger.info("TTS service initialized successfully")
async def root():
    """Return basic API information and whether the TTS service loaded."""
    service_state = "healthy" if tts_service else "error"
    info = {
        "message": "Coqui TTS C-3PO API",
        "status": service_state,
        "model": "XTTS v2",
        "voice_cloning": True,
    }
    return info
async def health_check():
    """Report service health details.

    Raises:
        HTTPException: 503 when the TTS service failed to initialize.
    """
    service = tts_service
    if not service:
        raise HTTPException(status_code=503, detail="TTS service not available")

    has_c3po_reference = service.get_c3po_reference_audio() is not None
    report = {
        "status": "healthy",
        "device": service.device,
        "model": "C-3PO Fine-tuned XTTS v2 (Coqui TTS)",
        "default_speaker": service.default_speaker,
        "voice_cloning_available": True,
        "c3po_voice_available": has_c3po_reference,
        "model_path": getattr(service, 'model_path', None),
    }
    return report
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(None),
    use_c3po_voice: bool = Form(True)
):
    """
    Convert text to speech using Coqui TTS

    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")
    - **speaker_file**: Reference audio file for voice cloning (optional)
    - **use_c3po_voice**: Use C-3PO voice if no speaker file provided (default: True)

    Returns the generated WAV as a file download. Raises 503 when the service
    is unavailable, 400 for empty text or a non-audio speaker file, and 500
    when synthesis fails.
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    speaker_temp_path = None
    try:
        if speaker_file is not None:
            content_type = speaker_file.content_type
            if not content_type or not content_type.startswith('audio/'):
                raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

            # Persist the upload so the TTS library can read it from disk.
            speaker_temp_path = os.path.join(
                tempfile.gettempdir(),
                f"speaker_{uuid.uuid4().hex}.wav"
            )
            content = await speaker_file.read()
            with open(speaker_temp_path, "wb") as buffer:
                buffer.write(content)
            logger.info(f"Speaker file saved: {speaker_temp_path}")

        output_path = tts_service.generate_speech(text, speaker_temp_path, language, use_c3po_voice)
    except HTTPException as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise
    except Exception as e:
        logger.exception(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded reference is only needed during synthesis; remove it on
        # both the success and the error path.
        if speaker_temp_path is not None:
            try:
                os.remove(speaker_temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove {speaker_temp_path}: {cleanup_error}")

    voice_type = "custom" if speaker_file else ("c3po" if use_c3po_voice else "default")

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse derives an
    # "attachment" disposition from `filename`, and a caller-supplied header
    # would take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{voice_type}_{uuid.uuid4().hex}.wav",
        background=cleanup
    )
async def text_to_speech_c3po(
    text: str = Form(...),
    language: str = Form("en")
):
    """
    Convert text to speech using C-3PO voice specifically

    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")

    Returns the generated WAV as a file download. Raises 503 when the service
    or the C-3PO reference audio is unavailable, 400 for empty text, and 500
    when synthesis fails.
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # This endpoint promises the C-3PO voice, so refuse rather than silently
    # falling back to another speaker.
    if not tts_service.get_c3po_reference_audio():
        raise HTTPException(status_code=503, detail="C-3PO reference audio not available")

    try:
        output_path = tts_service.generate_speech(text, None, language, use_c3po_voice=True)
    except HTTPException as e:
        logger.error(f"Error in C-3PO TTS endpoint: {e}")
        raise
    except Exception as e:
        logger.exception(f"Error in C-3PO TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse derives an
    # "attachment" disposition from `filename`, and a caller-supplied header
    # would take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_voice_{uuid.uuid4().hex}.wav",
        background=cleanup
    )
async def text_to_speech_json(request: TTSRequest):
    """
    Convert text to speech using JSON request with C-3PO voice

    - **request**: TTSRequest containing text and language

    Returns the generated WAV as a file download. Raises 503 when the service
    is unavailable, 400 for empty text, and 500 when synthesis fails.
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # C-3PO voice is requested by default; generate_speech falls back to
        # other speakers when no reference audio is found.
        output_path = tts_service.generate_speech(request.text, None, request.language, use_c3po_voice=True)
    except HTTPException as e:
        logger.error(f"Error in TTS JSON endpoint: {e}")
        raise
    except Exception as e:
        logger.exception(f"Error in TTS JSON endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse derives an
    # "attachment" disposition from `filename`, and a caller-supplied header
    # would take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{request.language}_{uuid.uuid4().hex}.wav",
        background=cleanup
    )
async def list_models():
    """List available TTS models (at most the first 20).

    Raises:
        HTTPException: 500 when the model catalogue cannot be retrieved.
    """
    try:
        # A throwaway TTS instance is used only to reach the catalogue.
        catalog = TTS().list_models()
        # Depending on the installed Coqui TTS release, list_models() yields
        # either the list of names or a manager object that exposes the list
        # through its own list_models(); normalise to a plain list.
        if hasattr(catalog, "list_models"):
            catalog = catalog.list_models()
        models = list(catalog)
    except Exception as e:
        logger.exception(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models") from e
    return {"models": models[:20]}  # Return first 20 models
# Script entry point: serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)