# tts-api / coqui_api.py
# Divax
# test
# 71905d8
import os
import torch
import tempfile
import uuid
import logging
from typing import Optional
from huggingface_hub import snapshot_download
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from TTS.api import TTS
# Flag Coqui's terms of service as agreed via the environment
# (presumably this avoids an interactive prompt on model download — confirm).
os.environ["COQUI_TOS_AGREED"] = "1"

# Module-wide logging at INFO level; every component below logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; route handlers further down attach to it.
app = FastAPI(
    title="Coqui TTS C-3PO API",
    description="Text-to-Speech API using Coqui TTS with C-3PO fine-tuned voice model",
    version="1.0.0",
)
class TTSRequest(BaseModel):
    """JSON request body carrying the text to synthesize and its language."""

    # Text to convert to speech.
    text: str
    # Language code handed to the model; "en" when the client omits it.
    language: str = "en"
class CoquiTTSService:
    """Load a Coqui XTTS model and synthesize speech into temporary WAV files.

    On construction the C-3PO fine-tuned XTTS v2 checkpoint is downloaded from
    Hugging Face and loaded. If any step of that fails, the stock XTTS v2
    model is loaded instead.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        tts: The loaded ``TTS`` instance.
        model_path: Directory of the downloaded checkpoint, or ``None`` when
            the download/load failed and the fallback model is in use.
        default_speaker: First preset speaker exposed by the model, or
            ``None`` when the model has no preset speakers.
    """

    def __init__(self):
        """Download and load the model, falling back to stock XTTS v2.

        Raises:
            Exception: Whatever the fallback model load raised, when both the
                fine-tuned and the fallback model fail to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        try:
            logger.info("Downloading C-3PO fine-tuned XTTS model from Hugging Face...")
            model_path = snapshot_download(
                repo_id="Borcherding/XTTS-v2_C3PO",
                local_dir="./models/XTTS-v2_C3PO",
                local_dir_use_symlinks=False,
            )
            logger.info(f"Model downloaded to: {model_path}")

            config_path = os.path.join(model_path, "config.json")
            if os.path.exists(config_path):
                logger.info("Loading C-3PO fine-tuned model...")
                self.tts = TTS(
                    model_path=model_path,
                    config_path=config_path,
                    progress_bar=False,
                    gpu=torch.cuda.is_available(),
                ).to(self.device)
                logger.info("C-3PO fine-tuned model loaded successfully!")
            else:
                # The snapshot has no config.json, so load the stock model by
                # name; the downloaded directory is still kept below so any
                # reference audio in it can be used.
                logger.info("Config not found, trying to load by repo ID...")
                self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(self.device)
                logger.info("Fallback XTTS v2 model loaded!")

            # Remember where the snapshot lives; reference audio is looked up there.
            self.model_path = model_path

            speakers = getattr(self.tts, "speakers", None)
            if speakers:
                logger.info(f"Available speakers: {len(speakers)}")
                self.default_speaker = speakers[0]
            else:
                logger.info("No preset speakers available - voice cloning mode")
                self.default_speaker = None
        except Exception as e:
            logger.error(f"Failed to load C-3PO model: {e}")
            logger.info("Falling back to standard XTTS v2 model...")
            try:
                self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(self.device)
                self.model_path = None
                self.default_speaker = None
                logger.info("Fallback XTTS v2 model loaded!")
            except Exception as fallback_error:
                logger.error(f"Fallback model also failed: {fallback_error}")
                raise

    def get_c3po_reference_audio(self):
        """Return the path of a C-3PO reference WAV in the model directory.

        Returns:
            The first existing candidate file under ``self.model_path``, or
            ``None`` when ``self.model_path`` is unset or no candidate exists.
        """
        if not self.model_path:
            return None

        # Candidate file names are probed in this order; the first hit wins.
        possible_ref_files = (
            "reference.wav", "speaker.wav", "c3po.wav",
            "sample.wav", "reference_audio.wav",
        )
        for ref_file in possible_ref_files:
            ref_path = os.path.join(self.model_path, ref_file)
            if os.path.exists(ref_path):
                logger.info(f"Found C-3PO reference audio: {ref_path}")
                return ref_path
        return None

    def generate_speech(self, text: str, speaker_wav_path: Optional[str] = None,
                        language: str = "en", use_c3po_voice: bool = True) -> str:
        """Synthesize ``text`` into a uniquely named WAV file in the temp directory.

        Speaker selection, in priority order:
        1. ``speaker_wav_path`` when given (voice cloning);
        2. the C-3PO reference audio when ``use_c3po_voice`` is true and a
           reference file exists (voice cloning);
        3. ``self.default_speaker`` when the model has preset speakers;
        4. no speaker at all.

        Args:
            text: Text to synthesize; must be 2 to 500 characters long.
            speaker_wav_path: Optional reference audio file for voice cloning.
            language: Language code passed to the model.
            use_c3po_voice: Whether to fall back to the C-3PO reference audio
                when no ``speaker_wav_path`` is given.

        Returns:
            Path of the generated WAV file. The caller owns the file and is
            responsible for deleting it.

        Raises:
            HTTPException: 400 when ``text`` is too short or too long; 500
                when synthesis fails or produces no file.
        """
        try:
            if len(text) < 2:
                raise HTTPException(status_code=400, detail="Text too short")
            if len(text) > 500:
                raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

            output_filename = f"c3po_tts_output_{uuid.uuid4().hex}.wav"
            output_path = os.path.join(tempfile.gettempdir(), output_filename)

            final_speaker_wav = speaker_wav_path
            if not final_speaker_wav and use_c3po_voice:
                c3po_ref = self.get_c3po_reference_audio()
                if c3po_ref:
                    final_speaker_wav = c3po_ref
                    logger.info("Using C-3PO reference audio for voice synthesis")

            # Every branch goes through tts_to_file so the library writes the
            # WAV itself. The cloning branch used to save the raw samples with
            # a hard-coded 22050 Hz rate and called .dim() on whatever tts()
            # returned, which breaks for non-tensor results.
            synthesis_args = {"text": text, "language": language, "file_path": output_path}
            if final_speaker_wav:
                logger.info("Generating speech with voice cloning...")
                synthesis_args["speaker_wav"] = final_speaker_wav
            elif self.default_speaker:
                logger.info(f"Generating speech with preset speaker: {self.default_speaker}")
                synthesis_args["speaker"] = self.default_speaker
            else:
                # Some models can synthesize without any speaker information.
                logger.info("Generating speech without specific speaker...")

            self.tts.tts_to_file(**synthesis_args)

            if not os.path.exists(output_path):
                raise HTTPException(status_code=500, detail="Failed to generate audio file")

            logger.info(f"Speech generated successfully: {output_path}")
            return output_path
        except HTTPException as http_error:
            # Already carries the right status code; log and pass it through.
            logger.error(f"Error generating speech: {http_error}")
            raise
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            raise HTTPException(status_code=500, detail=f"Speech generation failed: {str(e)}") from e
# Build the single shared service instance at import time. A failure is
# logged and leaves `tts_service` as None, which the route handlers check
# before doing any work.
logger.info("Initializing Coqui TTS service...")
try:
    tts_service = CoquiTTSService()
except Exception as init_error:
    logger.error(f"Failed to initialize TTS service: {init_error}")
    tts_service = None
else:
    logger.info("TTS service initialized successfully")
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # "error" means the service failed to initialise at import time.
    service_state = "healthy" if tts_service else "error"
    return {
        "message": "Coqui TTS C-3PO API",
        "status": service_state,
        "model": "XTTS v2",
        "voice_cloning": True,
    }
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Without a service instance there is nothing to report on.
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")

    # The C-3PO voice counts as available when a reference audio file is found.
    reference_audio = tts_service.get_c3po_reference_audio()
    report = {
        "status": "healthy",
        "device": tts_service.device,
        "model": "C-3PO Fine-tuned XTTS v2 (Coqui TTS)",
        "default_speaker": tts_service.default_speaker,
        "voice_cloning_available": True,
        "c3po_voice_available": reference_audio is not None,
        "model_path": getattr(tts_service, 'model_path', None),
    }
    return report
@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(None),
    use_c3po_voice: bool = Form(True),
):
    """
    Convert text to speech using Coqui TTS
    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")
    - **speaker_file**: Reference audio file for voice cloning (optional)
    - **use_c3po_voice**: Use C-3PO voice if no speaker file provided (default: True)
    """
    # Local import: only this handler needs it, to delete the generated file
    # once the response has been sent.
    from fastapi import BackgroundTasks

    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    speaker_temp_path = None
    try:
        if speaker_file is not None:
            content_type = speaker_file.content_type
            if not content_type or not content_type.startswith("audio/"):
                raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

            # Persist the upload so the model can read it from disk.
            speaker_temp_path = os.path.join(
                tempfile.gettempdir(),
                f"speaker_{uuid.uuid4().hex}.wav",
            )
            content = await speaker_file.read()
            with open(speaker_temp_path, "wb") as buffer:
                buffer.write(content)
            logger.info(f"Speaker file saved: {speaker_temp_path}")

        output_path = tts_service.generate_speech(text, speaker_temp_path, language, use_c3po_voice)
    except HTTPException as http_error:
        logger.error(f"Error in TTS endpoint: {http_error}")
        raise
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded reference audio is only needed during synthesis, so it
        # is removed on success and failure alike. Cleanup stays best-effort.
        if speaker_temp_path:
            try:
                os.remove(speaker_temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove speaker file {speaker_temp_path}: {cleanup_error}")

    voice_type = "custom" if speaker_file is not None else ("c3po" if use_c3po_voice else "default")

    # Delete the generated WAV after it has been streamed; otherwise every
    # request leaves a file behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header here: FileResponse builds
    # `attachment; filename="..."` from `filename`, and an explicit header
    # would take precedence and drop the file name.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{voice_type}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
@app.post("/tts-c3po")
async def text_to_speech_c3po(
    text: str = Form(...),
    language: str = Form("en"),
):
    """
    Convert text to speech using C-3PO voice specifically
    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")
    """
    # Local import: only needed to delete the generated file after sending.
    from fastapi import BackgroundTasks

    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # This endpoint promises the C-3PO voice, so refuse to fall back to
    # another speaker when no reference audio is present.
    c3po_ref = tts_service.get_c3po_reference_audio()
    if not c3po_ref:
        raise HTTPException(status_code=503, detail="C-3PO reference audio not available")

    try:
        output_path = tts_service.generate_speech(text, None, language, use_c3po_voice=True)
    except HTTPException as http_error:
        logger.error(f"Error in C-3PO TTS endpoint: {http_error}")
        raise
    except Exception as e:
        logger.error(f"Error in C-3PO TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated WAV after it has been streamed; otherwise every
    # request leaves a file behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header here: FileResponse builds
    # `attachment; filename="..."` from `filename`, and an explicit header
    # would take precedence and drop the file name.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_voice_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
@app.post("/tts-json")
async def text_to_speech_json(request: TTSRequest):
    """
    Convert text to speech using JSON request with C-3PO voice
    - **request**: TTSRequest containing text and language
    """
    # Local import: only needed to delete the generated file after sending.
    from fastapi import BackgroundTasks

    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # The C-3PO voice is requested by default; generate_speech picks
        # another speaker when no reference audio exists.
        output_path = tts_service.generate_speech(request.text, None, request.language, use_c3po_voice=True)
    except HTTPException as http_error:
        logger.error(f"Error in TTS JSON endpoint: {http_error}")
        raise
    except Exception as e:
        logger.error(f"Error in TTS JSON endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated WAV after it has been streamed; otherwise every
    # request leaves a file behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header here: FileResponse builds
    # `attachment; filename="..."` from `filename`, and an explicit header
    # would take precedence and drop the file name.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{request.language}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
@app.get("/models")
async def list_models():
    """List available TTS models"""
    try:
        # A throwaway TTS instance is enough to query the model catalogue.
        models = TTS().list_models()
        # Depending on the installed Coqui TTS release, list_models() returns
        # either the list of names or a manager object whose own
        # list_models() yields that list. Slicing the manager object raises.
        if not isinstance(models, (list, tuple)):
            models = models.list_models()
        # Cap the response at the first 20 entries.
        return {"models": list(models)[:20]}
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models") from e
if __name__ == "__main__":
    # Running the file directly serves the app on every interface, port 7860.
    # uvicorn is imported here so it is only needed in that case.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )