chemistrymath commited on
Commit
da054be
·
verified ·
1 Parent(s): 810d904

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +71 -67
main.py CHANGED
@@ -1,67 +1,71 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import json
4
- import tempfile
5
- import shutil
6
- import os
7
- from single_person_processor import SinglePersonPredictor
8
-
9
- app = FastAPI()
10
-
11
- app.add_middleware(
12
- CORSMiddleware,
13
- allow_origins=["http://localhost:3000"], # Change this to match your frontend URL
14
- allow_credentials=True,
15
- allow_methods=["*"], # Allow all methods (POST, GET, etc.)
16
- allow_headers=["*"], # Allow all headers
17
- )
18
-
19
-
20
- # Initialize the predictor
21
- predictor = SinglePersonPredictor(model_path="best_model.keras")
22
-
23
- @app.post("/predict/")
24
- async def predict(
25
- front_image: UploadFile = File(...),
26
- side_image: UploadFile = File(...),
27
- input_data: str = Form(...)
28
- ):
29
- try:
30
- try:
31
- input_dict = json.loads(input_data)
32
- except json.JSONDecodeError:
33
- return {"error": "Invalid JSON format in input_data"}
34
-
35
- # Validate file types
36
- if front_image.content_type not in ["image/jpeg", "image/png"]:
37
- return {"error": "Front image must be a JPEG or PNG file"}
38
- if side_image.content_type not in ["image/jpeg", "image/png"]:
39
- return {"error": "Side image must be a JPEG or PNG file"}
40
-
41
- # Create temporary files for images
42
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as front_temp:
43
- shutil.copyfileobj(front_image.file, front_temp)
44
- front_temp_path = front_temp.name
45
-
46
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as side_temp:
47
- shutil.copyfileobj(side_image.file, side_temp)
48
- side_temp_path = side_temp.name
49
-
50
- # Perform predictions
51
- results = predictor.predict_measurements(
52
- front_img_path=front_temp_path,
53
- side_img_path=side_temp_path,
54
- gender=input_dict.get("gender"),
55
- height_cm=input_dict.get("height_cm"),
56
- weight_kg=input_dict.get("weight_kg"),
57
- apparel_type=input_dict.get("apparel_type")
58
- )
59
-
60
- # Clean up temporary files
61
- os.remove(front_temp_path)
62
- os.remove(side_temp_path)
63
-
64
- return {"results": results}
65
-
66
- except Exception as e:
67
- return {"error": str(e)}
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import json
4
+ import tempfile
5
+ import shutil
6
+ import os
7
+ from single_person_processor import SinglePersonPredictor
8
+
9
+ @app.get("/")
10
+ async def root():
11
+ return {"message": "FastAPI is running!"}
12
+
13
+ app = FastAPI()
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["http://localhost:3000"], # Change this to match your frontend URL
18
+ allow_credentials=True,
19
+ allow_methods=["*"], # Allow all methods (POST, GET, etc.)
20
+ allow_headers=["*"], # Allow all headers
21
+ )
22
+
23
+
24
+ # Initialize the predictor
25
+ predictor = SinglePersonPredictor(model_path="best_model.keras")
26
+
27
+ @app.post("/predict/")
28
+ async def predict(
29
+ front_image: UploadFile = File(...),
30
+ side_image: UploadFile = File(...),
31
+ input_data: str = Form(...)
32
+ ):
33
+ try:
34
+ try:
35
+ input_dict = json.loads(input_data)
36
+ except json.JSONDecodeError:
37
+ return {"error": "Invalid JSON format in input_data"}
38
+
39
+ # Validate file types
40
+ if front_image.content_type not in ["image/jpeg", "image/png"]:
41
+ return {"error": "Front image must be a JPEG or PNG file"}
42
+ if side_image.content_type not in ["image/jpeg", "image/png"]:
43
+ return {"error": "Side image must be a JPEG or PNG file"}
44
+
45
+ # Create temporary files for images
46
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as front_temp:
47
+ shutil.copyfileobj(front_image.file, front_temp)
48
+ front_temp_path = front_temp.name
49
+
50
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as side_temp:
51
+ shutil.copyfileobj(side_image.file, side_temp)
52
+ side_temp_path = side_temp.name
53
+
54
+ # Perform predictions
55
+ results = predictor.predict_measurements(
56
+ front_img_path=front_temp_path,
57
+ side_img_path=side_temp_path,
58
+ gender=input_dict.get("gender"),
59
+ height_cm=input_dict.get("height_cm"),
60
+ weight_kg=input_dict.get("weight_kg"),
61
+ apparel_type=input_dict.get("apparel_type")
62
+ )
63
+
64
+ # Clean up temporary files
65
+ os.remove(front_temp_path)
66
+ os.remove(side_temp_path)
67
+
68
+ return {"results": results}
69
+
70
+ except Exception as e:
71
+ return {"error": str(e)}