|
from fastapi import FastAPI, Request |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import JSONResponse, FileResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from optimum.neuron import utils |
|
import logging |
|
import sys |
|
import os |
|
import httpx |
|
|
|
# Root logging configuration: INFO level and above, written to stdout with a
# timestamp, logger name and level prefix on every record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_stdout_handler],
)

# Module-level logger used by every endpoint in this file.
logger = logging.getLogger(__name__)
|
|
|
# The ASGI application object; all routes below are registered on it.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and header.
#
# Credentials are deliberately disabled: a wildcard origin must not be
# combined with allow_credentials=True. With credentials enabled, any website
# could issue cookie/authorization-bearing requests against this API.
# To support credentialed cross-origin calls, replace the "*" wildcard with an
# explicit list of trusted origins and only then re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Asset directories are resolved relative to this source file, so they do not
# depend on the process's current working directory.
_module_dir = os.path.dirname(os.path.abspath(__file__))

static_dir = os.path.join(_module_dir, "static")
logger.info(f"Static directory path: {static_dir}")

templates_dir = os.path.join(_module_dir, "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Serve the static directory under /static and load Jinja2 templates from the
# templates directory.
app.mount(
    "/static",
    StaticFiles(directory=static_dir),
    name="static",
)
templates = Jinja2Templates(directory=templates_dir)
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: log the call and return a fixed ``healthy`` status."""
    logger.info("Health check endpoint called")
    status_payload = {"status": "healthy"}
    return status_payload
|
|
|
@app.get("/")
async def home(request: Request):
    """Render ``index.html`` with the request's base URL.

    Args:
        request: The incoming request; its ``base_url`` is passed to the
            template and the request itself is included in the context.

    Returns:
        The rendered ``index.html`` template response.
    """
    logger.info("Home page requested")

    page_base_url = str(request.base_url)

    # When the SPACE_ID environment variable is set (presumably a Hugging Face
    # Spaces deployment -- confirm), advertise the base URL as https instead
    # of http.
    if os.getenv("SPACE_ID") is not None:
        page_base_url = page_base_url.replace("http://", "https://")

    context = {"request": request, "base_url": page_base_url}
    return templates.TemplateResponse("index.html", context)
|
|
|
@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of models reported by the hub cache.

    Each entry returned by ``utils.get_hub_cached_models(mode="inference")``
    is unpacked as ``(architecture, org, model_id)`` and exposed as
    ``{"id": "<org>/<model_id>", "name": "<org>/<model_id>", "type": <architecture>}``.
    Only the first entry seen for a given ``<org>/<model_id>`` is kept.

    Returns:
        A JSON array of model descriptors on success, or a 500 JSON response
        ``{"error": <message>, "type": <exception class name>}`` on failure.
    """
    # Function-scope import keeps this block self-contained; fastapi is
    # already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    logger.info("Fetching model list")
    try:
        # Log only whether a token is configured, never its value.
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")

        # get_hub_cached_models is a synchronous call (presumably doing hub
        # network I/O -- confirm); run it in a worker thread so it does not
        # stall the event loop for every other request.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append(
                {
                    "id": full_model_id,
                    "name": full_model_id,
                    "type": architecture,
                }
            )

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level boundary for this endpoint: log with traceback and report
        # the failure to the client as a 500 instead of propagating.
        logger.exception(f"Error fetching models: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
|
|
|
@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Look up the cached configurations for ``model_id`` on the Hugging Face API.

    Args:
        model_id: Model identifier taken from the URL path (may contain
            slashes, e.g. ``org/name``). It is client-supplied, so it is
            validated and percent-encoded before being placed in the
            upstream URL.

    Returns:
        ``{"configurations": [...]}`` built from the upstream response's
        ``cached_configs`` field (empty list when absent). On failure a JSON
        ``{"error": ...}`` body is returned with status:
        400 for a model id containing ``.``/``..`` path segments,
        504 when the upstream request times out,
        404 when upstream answers 404, 502 for any other upstream error
        status, and 500 for transport or unexpected errors.
    """
    from urllib.parse import quote

    logger.info(f"Fetching configurations for model: {model_id}")

    # Dot segments would let the caller climb out of the lookup path once the
    # URL is normalised, so refuse them outright.
    if any(segment in (".", "..") for segment in model_id.split("/")):
        logger.warning(f"Rejected invalid model id: {model_id}")
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid model id"}
        )

    try:
        base_url = "https://huggingface.co/api/integrations/aws/v1/lookup"
        # Percent-encode everything except "/" so characters such as "?" or
        # "#" cannot turn part of the id into a query string or fragment.
        api_url = f"{base_url}/{quote(model_id, safe='/')}"

        # 15s overall budget, 5s to establish the connection.
        timeout = httpx.Timeout(15.0, connect=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
            response.raise_for_status()

            data = response.json()
            configs = data.get("cached_configs", [])

            logger.info(f"Found {len(configs)} configurations for model {model_id}")
            return JSONResponse(content={"configurations": configs})
    except httpx.TimeoutException as e:
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,
            content={"error": "Request timed out while fetching model configurations"}
        )
    except httpx.HTTPStatusError as e:
        # Raised by raise_for_status(): surface "not found" as 404 and every
        # other upstream failure as 502 rather than a generic 500.
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        upstream_status = e.response.status_code
        return JSONResponse(
            status_code=404 if upstream_status == 404 else 502,
            content={"error": f"Failed to fetch model configurations: {str(e)}"}
        )
    except httpx.HTTPError as e:
        # Remaining httpx errors (connection failures, protocol errors, ...).
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"}
        )
    except Exception as e:
        # Endpoint boundary: e.g. a non-JSON or unexpectedly shaped body.
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
|
|
|
@app.get("/static/{path:path}")
async def static_files(path: str, request: Request):
    """Serve a file from the static directory with an explicit content type.

    NOTE(review): a ``StaticFiles`` app is mounted on ``/static`` earlier in
    this file; since routes are matched in registration order, that mount is
    expected to handle these requests first -- confirm whether this handler
    is reachable at all.

    Args:
        path: Client-supplied path relative to the static directory.
        request: The incoming request (unused here).

    Returns:
        A ``FileResponse`` for an existing regular file inside the static
        directory (with ``content-type`` forced for ``.css`` and ``.js``), or
        a 404 JSON response ``{"error": "File not found"}`` otherwise.
    """
    logger.info(f"Static file requested: {path}")

    # Resolve both paths (following symlinks and collapsing "..") so the
    # containment check below cannot be bypassed by a crafted path.
    static_root = os.path.realpath(static_dir)
    file_path = os.path.realpath(os.path.join(static_root, path))

    try:
        inside_root = os.path.commonpath([static_root, file_path]) == static_root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different
        # drives on Windows); such a path is certainly outside the root.
        inside_root = False

    # Anything outside the static directory, and anything that is not a
    # regular file (missing path, directory), is reported as not found.
    if not inside_root or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response