alessandro trinca tornidor
refactor: sam-quantized submodule now referrenced without symlinks, update dockerfile
606a7d7
import json | |
import os | |
from pathlib import Path | |
import structlog.stdlib | |
import uvicorn | |
from asgi_correlation_id import CorrelationIdMiddleware | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.responses import FileResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from pydantic import ValidationError | |
from samgis_core.utilities import create_folders_if_not_exists | |
from samgis_web.utilities import frontend_builder | |
from samgis_core.utilities.session_logger import setup_logging | |
from samgis_web.prediction_api.predictors import samexporter_predict | |
from samgis_web.utilities.type_hints import ApiRequestBody | |
from starlette.responses import JSONResponse | |
load_dotenv() | |
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent | |
workdir = os.getenv("WORKDIR", project_root_folder) | |
model_folder = Path(project_root_folder / "sam-quantized" / "machine_learning_models") | |
log_level = os.getenv("LOG_LEVEL", "INFO") | |
setup_logging(log_level=log_level) | |
app_logger = structlog.stdlib.get_logger() | |
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.") | |
folders_map = os.getenv("FOLDERS_MAP", "{}") | |
markdown_text = os.getenv("MARKDOWN_TEXT", "") | |
examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n") | |
example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}")) | |
mount_gradio_app = bool(os.getenv("MOUNT_GRADIO_APP", "")) | |
static_dist_folder = Path(project_root_folder) / "static" / "dist" | |
input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css") | |
vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio") | |
vite_index_url = os.getenv("VITE_INDEX_URL", "/") | |
vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis") | |
fastapi_title = "samgis" | |
app = FastAPI(title=fastapi_title, version="1.0") | |
async def request_middleware(request, call_next): | |
from samgis_web.web.middlewares import logging_middleware | |
return await logging_middleware(request, call_next) | |
async def health() -> JSONResponse: | |
from onnxruntime import __version__ as ort_version | |
from samgis_web.__version__ import __version__ as version_web | |
from samgis_core.__version__ import __version__ as version_core | |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME | |
msg_model_folder_error = f"health_check: model_folder:'{model_folder}' is not a directory." | |
msg_model_file_error = f"health_check: model_file:'{model_folder}' not found." | |
try: | |
assert model_folder.is_dir(), msg_model_folder_error | |
encoder_model_path = Path(model_folder) / MODEL_ENCODER_NAME | |
decoder_model_path = Path(model_folder) / MODEL_DECODER_NAME | |
assert encoder_model_path.is_file(), msg_model_file_error | |
assert decoder_model_path.is_file(), msg_model_file_error | |
app_logger.info(f"still alive, version_onnxruntime:{ort_version}, version_web:{version_web}, version_core:{version_core}.") | |
app_logger.info(f"still alive, encoder_model:{encoder_model_path}, decoder_model:{decoder_model_path}.") | |
return JSONResponse(status_code=200, content={"msg": "still alive..."}) | |
except AssertionError as ae: | |
app_logger.error(f"health_check: AssertionError:{ae}.") | |
raise HTTPException(500, detail=msg_model_folder_error) | |
def infer_samgis_fn(request_input: ApiRequestBody | str) -> str | JSONResponse: | |
from samgis_web.web.web_helpers import get_parsed_bbox_points_with_dictlist_prompt | |
app_logger.info("starting inference request...") | |
try: | |
import time | |
time_start_run = time.time() | |
body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input) | |
app_logger.info(f"body_request:{body_request}.") | |
try: | |
app_logger.info(f"source_name = {body_request['source_name']}.") | |
output = samexporter_predict( | |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"], | |
source=body_request["source"], source_name=body_request['source_name'], model_folder=model_folder | |
) | |
duration_run = time.time() - time_start_run | |
app_logger.info(f"duration_run:{duration_run}.") | |
body = { | |
"duration_run": duration_run, | |
"output": output | |
} | |
dumped = json.dumps(body) | |
app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.") | |
app_logger.debug(f"complete json.dumps(body):{dumped}.") | |
return dumped | |
except Exception as inference_exception: | |
app_logger.error(f"inference_exception:{inference_exception}.") | |
app_logger.error(f"inference_exception, request_input:{request_input}.") | |
raise HTTPException(status_code=500, detail="Internal Server Error") | |
except ValidationError as va1: | |
app_logger.error(f"validation error: {str(va1)}.") | |
app_logger.error(f"ValidationError, request_input:{request_input}.") | |
raise RequestValidationError("Unprocessable Entity") | |
def infer_samgis(request_input: ApiRequestBody) -> JSONResponse: | |
dumped = infer_samgis_fn(request_input=request_input) | |
app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.") | |
app_logger.debug(f"complete json.dumps(body):{dumped}.") | |
return JSONResponse(status_code=200, content={"body": dumped}) | |
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: | |
from samgis_web.web import exception_handlers | |
return exception_handlers.request_validation_exception_handler(request, exc) | |
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: | |
from samgis_web.web import exception_handlers | |
return exception_handlers.http_exception_handler(request, exc) | |
create_folders_if_not_exists.folders_creation(folders_map) | |
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "") | |
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.") | |
if bool(write_tmp_on_disk): | |
try: | |
assert Path(write_tmp_on_disk).is_dir() | |
app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output") | |
templates = Jinja2Templates(directory=str(project_root_folder / "static")) | |
def list_files(request: Request): | |
files = os.listdir(write_tmp_on_disk) | |
files_paths = sorted([f"{request.url._url}/{f}" for f in files]) | |
print(files_paths) | |
return templates.TemplateResponse( | |
"list_files.html", {"request": request, "files": files_paths} | |
) | |
except (AssertionError, RuntimeError) as rerr: | |
app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...") | |
raise rerr | |
frontend_builder.build_frontend( | |
project_root_folder=workdir, | |
input_css_path=input_css_path, | |
output_dist_folder=static_dist_folder | |
) | |
app_logger.info("build_frontend ok!") | |
# eventually needed for tailwindcss output.css | |
app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static") | |
app.mount(vite_index_url, StaticFiles(directory=static_dist_folder, html=True), name="index") | |
app.mount(vite_gradio_url, StaticFiles(directory=static_dist_folder, html=True), name="gradio") | |
async def index() -> FileResponse: | |
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html") | |
app_logger.info(f"Mounted index on url path {vite_index_url} .") | |
app_logger.info(f"There is need to create and mount gradio app interface? {mount_gradio_app}...") | |
if mount_gradio_app: | |
try: | |
import gradio as gr | |
from samgis_web.web.gradio_helpers import get_gradio_interface_geojson | |
app_logger.info(f"creating gradio interface...") | |
gr_interface = get_gradio_interface_geojson( | |
infer_samgis_fn, | |
markdown_text, | |
examples_text_list, | |
example_body | |
) | |
app_logger.info(f"gradio interface created, mounting gradio app on url path {vite_gradio_url} within FastAPI.") | |
app_logger.debug(f"gr_interface vars:{vars(gr_interface)}.") | |
app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url) | |
app = gr.mount_gradio_app(app, gr_interface, path="/gradio") | |
app_logger.info(f"mounted gradio app within fastapi, url path {vite_gradio_url} .") | |
except (ModuleNotFoundError, ImportError) as mnfe: | |
app_logger.error("cannot import gradio, have you installed it if you want to mount a gradio app?") | |
app_logger.error(mnfe) | |
raise mnfe | |
# add the CorrelationIdMiddleware AFTER the @app.middleware("http") decorated function to avoid missing request id | |
app.add_middleware(CorrelationIdMiddleware) | |
if __name__ == '__main__': | |
try: | |
uvicorn.run("app:app", host="0.0.0.0", port=7860) | |
except Exception as ex: | |
app_logger.error(f"fastapi/gradio application {fastapi_title}, exception:{ex}.") | |
print(f"fastapi/gradio application {fastapi_title}, exception:{ex}.") | |
raise ex | |