Spaces:
Runtime error
Runtime error
import glob | |
import os | |
import pathlib | |
import shutil | |
from typing import Union | |
from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile | |
from fastapi.params import File | |
from fastapi.responses import FileResponse, StreamingResponse | |
from loguru import logger | |
from app.config import config | |
from app.controllers import base | |
from app.controllers.manager.memory_manager import InMemoryTaskManager | |
from app.controllers.manager.redis_manager import RedisTaskManager | |
from app.controllers.v1.base import new_router | |
from app.models.exception import HttpException | |
from app.models.schema import ( | |
AudioRequest, | |
BgmRetrieveResponse, | |
BgmUploadResponse, | |
SubtitleRequest, | |
TaskDeletionResponse, | |
TaskQueryRequest, | |
TaskQueryResponse, | |
TaskResponse, | |
TaskVideoRequest, | |
) | |
from app.services import state as sm | |
from app.services import task as tm | |
from app.utils import utils | |
# Every route registered on this router runs the base.verify_token dependency.
router = new_router(dependencies=[Depends(base.verify_token)])

_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)


def _build_redis_url(host, port, db, password=None):
    """Return a ``redis://host:port/db`` URL, adding credentials only if a password is set.

    Formatting the password unconditionally would produce ``redis://:None@...``
    when no password is configured, i.e. authenticate with the literal string
    "None". The password is percent-encoded so characters such as ``@``, ``:``
    or ``/`` cannot corrupt the URL.
    """
    from urllib.parse import quote

    auth = f":{quote(str(password), safe='')}@" if password else ""
    return f"redis://{auth}{host}:{port}/{db}"


redis_url = _build_redis_url(_redis_host, _redis_port, _redis_db, _redis_password)

# Select the task manager from configuration: RedisTaskManager when
# enable_redis is set, InMemoryTaskManager otherwise.
if _enable_redis:
    task_manager = RedisTaskManager(
        max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
    )
else:
    task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
def create_video(
    background_tasks: BackgroundTasks,
    request: Request,
    body: TaskVideoRequest,
):
    """Create a video task by delegating to ``create_task`` with ``stop_at="video"``.

    ``background_tasks`` is part of the handler signature but is not used;
    scheduling is done by ``create_task`` through the module's task manager.
    """
    stage = "video"
    return create_task(request=request, body=body, stop_at=stage)
def create_subtitle(
    background_tasks: BackgroundTasks,
    request: Request,
    body: SubtitleRequest,
):
    """Create a subtitle task by delegating to ``create_task`` with ``stop_at="subtitle"``.

    ``background_tasks`` is part of the handler signature but is not used;
    scheduling is done by ``create_task`` through the module's task manager.
    """
    stage = "subtitle"
    return create_task(request=request, body=body, stop_at=stage)
def create_audio(
    background_tasks: BackgroundTasks,
    request: Request,
    body: AudioRequest,
):
    """Create an audio task by delegating to ``create_task`` with ``stop_at="audio"``.

    ``background_tasks`` is part of the handler signature but is not used;
    scheduling is done by ``create_task`` through the module's task manager.
    """
    stage = "audio"
    return create_task(request=request, body=body, stop_at=stage)
def create_task(
    request: Request,
    body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
    stop_at: str,
):
    """Register a new task in the state store and queue it on ``task_manager``.

    Args:
        request: Incoming request; used to derive the request id.
        body: Task parameters, passed to ``tm.start`` as ``params``.
        stop_at: Stage name forwarded to ``tm.start`` (the endpoints in this
            module pass "video", "subtitle" or "audio").

    Returns:
        A 200 response containing ``task_id`` and ``request_id``.

    Raises:
        HttpException: 400 when registering or queueing the task raises
            ``ValueError``; the original error is chained as ``__cause__``.
    """
    task_id = utils.get_uuid()
    request_id = base.get_task_id(request)
    try:
        sm.state.update_task(task_id)
        task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at)
    except ValueError as e:
        # Log only the task id and error type; the exception message could
        # echo request parameters, which are deliberately kept out of logs.
        logger.warning(f"Task creation rejected: {task_id} ({type(e).__name__})")
        raise HttpException(
            task_id=task_id, status_code=400, message="Invalid request parameters"
        ) from e

    # Parameters are redacted so full request payloads never reach the logs.
    task = {"task_id": task_id, "request_id": request_id, "params": "[REDACTED]"}
    logger.success(f"Task created: {utils.to_json(task)}")
    return utils.get_response(200, {"task_id": task_id, "request_id": request_id})
from fastapi import Query | |
def get_all_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
):
    """Return one page of tasks from the state store.

    ``page`` and ``page_size`` must both be >= 1. The response contains
    ``tasks`` and ``total`` as returned by ``sm.state.get_all_tasks``, and
    echoes back ``page`` and ``page_size``.
    """
    _request_id = base.get_task_id(request)  # resolved like other handlers; not used here
    page_items, total_count = sm.state.get_all_tasks(page, page_size)
    return utils.get_response(
        200,
        {
            "tasks": page_items,
            "total": total_count,
            "page": page,
            "page_size": page_size,
        },
    )
def get_task(
    request: Request,
    task_id: str = Path(..., description="Task ID"),
    query: TaskQueryRequest = Depends(),
):
    """Return a task's state, with its video file paths rewritten as URLs.

    Entries of ``videos`` / ``combined_videos`` that do not already start with
    the public endpoint have the tasks directory replaced by ``tasks`` and are
    prefixed with the endpoint (``config.app["endpoint"]`` or, if unset, the
    request's base URL).

    Raises:
        HttpException: 404 when the task does not exist.
    """
    endpoint = config.app.get("endpoint", "")
    if not endpoint:
        endpoint = str(request.base_url)
    endpoint = endpoint.rstrip("/")

    request_id = base.get_task_id(request)
    task = sm.state.get_task(task_id)
    if not task:
        raise HttpException(
            task_id=task_id, status_code=404, message=f"{request_id}: task not found"
        )

    task_dir = utils.task_dir()

    def file_to_uri(file):
        # Already a public URL: leave it untouched.
        if file.startswith(endpoint):
            return file
        uri_path = file.replace(task_dir, "tasks").replace("\\", "/")
        return f"{endpoint}/{uri_path}"

    # Rewrite on a shallow copy so the URL conversion does not modify the dict
    # returned by the state store (in case it hands out a shared reference).
    task = dict(task)
    for key in ("videos", "combined_videos"):
        if key in task:
            task[key] = [file_to_uri(path) for path in task[key]]
    return utils.get_response(200, task)
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
    """Delete a task's working directory (if present) and its state entry.

    Raises:
        HttpException: 404 when the task is unknown to the state store.
    """
    request_id = base.get_task_id(request)
    if not sm.state.get_task(task_id):
        raise HttpException(
            task_id=task_id, status_code=404, message="Task not found"
        )

    task_folder = os.path.join(utils.task_dir(), task_id)
    if os.path.exists(task_folder):
        shutil.rmtree(task_folder)
    sm.state.delete_task(task_id)
    # Log only the id, never the full task record.
    logger.success(f"Task deleted: {task_id}")
    return utils.get_response(200)
def get_bgm_list(request: Request):
    """List the ``*.mp3`` files in the song directory.

    Each entry carries the file's base ``name``, its ``size`` in bytes and
    the full ``file`` path, wrapped as ``{"files": [...]}``.
    """
    pattern = os.path.join(utils.song_dir(), "*.mp3")
    entries = [
        {
            "name": os.path.basename(path),
            "size": os.path.getsize(path),
            "file": path,
        }
        for path in glob.glob(pattern)
    ]
    return utils.get_response(200, {"files": entries})
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
    """Save an uploaded ``.mp3`` file into the song directory.

    An existing file with the same name is overwritten.

    Returns:
        A 200 response with the saved file's path under ``file``.

    Raises:
        HttpException: 400 when the upload's name does not end in ``.mp3``.
    """
    request_id = base.get_task_id(request)
    # The client controls the filename: keep only its last path component so
    # names like "../../x.mp3" (or Windows-style "..\\x.mp3") cannot escape
    # the song directory. A missing filename is treated as empty and rejected.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    # Require a real ".mp3" extension and a non-empty stem, matching the
    # "*.mp3" pattern that get_bgm_list uses to list uploads.
    if filename.endswith(".mp3") and filename != ".mp3":
        save_path = os.path.join(utils.song_dir(), filename)
        file.file.seek(0)
        with open(save_path, "wb") as buffer:
            # Stream in chunks instead of reading the whole upload into memory.
            shutil.copyfileobj(file.file, buffer)
        return utils.get_response(200, {"file": save_path})
    raise HttpException(
        "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
    )
def _parse_byte_range(range_header, size):
    """Parse a single ``Range: bytes=...`` header against a file of ``size`` bytes.

    Returns:
        ``(start, end)`` inclusive offsets for a satisfiable range, or ``None``
        when the header is absent, uses another unit, lists several ranges or
        is malformed. In those cases the caller serves the whole file.

    Raises:
        ValueError: the range is well-formed but unsatisfiable (it starts at
            or beyond the end of the file, or is an empty suffix).
    """
    if not range_header:
        return None
    unit, has_eq, spec = range_header.strip().partition("=")
    if not has_eq or unit.strip().lower() != "bytes" or "," in spec:
        return None
    first, has_dash, last = spec.strip().partition("-")
    first, last = first.strip(), last.strip()
    if not has_dash or (first and not first.isdecimal()) or (last and not last.isdecimal()):
        return None

    if not first:
        # Suffix form "bytes=-N": the final N bytes of the file.
        if not last:
            return None
        suffix = int(last)
        if suffix == 0 or size == 0:
            raise ValueError("unsatisfiable suffix range")
        return max(size - suffix, 0), size - 1

    start = int(first)
    end = int(last) if last else size - 1
    if end < start:
        return None
    if start >= size:
        raise ValueError("range starts beyond end of file")
    return start, min(end, size - 1)


def _iter_file_chunks(path, offset, length, chunk_size=4096):
    """Yield exactly ``length`` bytes of ``path`` starting at ``offset`` (fewer if the file is shorter)."""
    with open(path, "rb") as f:
        f.seek(offset, os.SEEK_SET)
        remaining = length
        while remaining > 0:
            chunk = f.read(min(chunk_size, remaining))
            if not chunk:
                break
            remaining -= len(chunk)
            yield chunk


async def stream_video(request: Request, file_path: str):
    """Stream a video from the tasks directory, honouring HTTP Range requests.

    ``file_path`` is resolved relative to the tasks directory; leading slashes
    are ignored. Paths that resolve outside that directory, and missing files,
    yield 404.

    Responses:
        * 200 with the whole file when no usable ``Range`` header is sent.
        * 206 with ``Content-Range`` for a satisfiable single byte range.
        * 416 with ``Content-Range: bytes */<size>`` for an unsatisfiable range.
    """
    tasks_dir = os.path.abspath(utils.task_dir())
    # Strip leading separators: os.path.join discards tasks_dir when the second
    # argument is absolute, which would allow reading arbitrary files.
    video_path = os.path.abspath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    try:
        inside = os.path.commonpath([tasks_dir, video_path]) == tasks_dir
    except ValueError:  # e.g. paths on different Windows drives
        inside = False
    if not inside or not os.path.isfile(video_path):
        request_id = base.get_task_id(request)
        raise HttpException("", status_code=404, message=f"{request_id}: file not found")

    video_size = os.path.getsize(video_path)
    try:
        byte_range = _parse_byte_range(request.headers.get("Range"), video_size)
    except ValueError:
        return StreamingResponse(
            iter(()),
            status_code=416,
            headers={"Content-Range": f"bytes */{video_size}"},
        )

    headers = {"Accept-Ranges": "bytes"}
    if byte_range is None:
        start, end, status_code = 0, video_size - 1, 200
    else:
        start, end = byte_range
        status_code = 206  # Partial Content
        headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
    length = end - start + 1  # 0 for an empty file served whole
    headers["Content-Length"] = str(length)

    return StreamingResponse(
        _iter_file_chunks(video_path, start, length),
        status_code=status_code,
        media_type="video/mp4",
        headers=headers,
    )
async def download_video(_: Request, file_path: str):
    """
    Download a video from the tasks directory as an attachment.

    :param _: Request request (unused)
    :param file_path: video file path relative to the tasks directory,
        eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4 (leading slashes are ignored)
    :return: video file
    :raises HttpException: 404 when the path resolves outside the tasks
        directory or the file does not exist
    """
    tasks_dir = os.path.abspath(utils.task_dir())
    # Strip leading separators: os.path.join discards tasks_dir when the second
    # argument is absolute, which would allow downloading arbitrary files.
    video_path = os.path.abspath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    try:
        inside = os.path.commonpath([tasks_dir, video_path]) == tasks_dir
    except ValueError:  # e.g. paths on different Windows drives
        inside = False
    if not inside or not os.path.isfile(video_path):
        raise HttpException("", status_code=404, message="file not found")

    video = pathlib.Path(video_path)
    extension = video.suffix
    media_type = f"video/{extension[1:]}" if extension else "application/octet-stream"
    # Passing filename= lets FileResponse build a correctly quoted/encoded
    # "attachment" Content-Disposition header; a hand-written unquoted header
    # would override it.
    return FileResponse(
        path=video_path,
        filename=video.name,
        media_type=media_type,
    )