# tryon/app.py
import os
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from gradio_client import Client, handle_file
import io
import base64
app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
example_path = os.path.join(os.path.dirname(__file__), 'example')
garm_list = os.listdir(os.path.join(example_path, "cloth"))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
human_list = os.listdir(os.path.join(example_path, "human"))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=True)
mp_drawing = mp.solutions.drawing_utils
mp_pose_landmark = mp_pose.PoseLandmark

def detect_pose(image):
    # Convert to RGB for MediaPipe
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Run pose detection
    result = pose.process(image_rgb)

    keypoints = {}
    if result.pose_landmarks:
        # Draw landmarks on image
        mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        # Get image dimensions
        height, width, _ = image.shape

        # Extract specific landmarks
        landmark_indices = {
            'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
            'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
            'left_hip': mp_pose_landmark.LEFT_HIP,
            'right_hip': mp_pose_landmark.RIGHT_HIP
        }

        for name, index in landmark_indices.items():
            lm = result.pose_landmarks.landmark[index]
            x, y = int(lm.x * width), int(lm.y * height)
            keypoints[name] = (x, y)

            # Draw a circle + label for debug
            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    return image
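
# Quick way to exercise detect_pose on a local file (a sketch; "person.jpg"
# is a hypothetical path, not shipped with this repo):
#
#   img = cv2.imread("person.jpg")             # OpenCV loads as BGR
#   annotated = detect_pose(img)               # draws landmarks + labels in place
#   cv2.imwrite("person_pose.jpg", annotated)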

def align_clothing(body_img, clothing_img):
    image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
    result = pose.process(image_rgb)
    output = body_img.copy()

    if result.pose_landmarks:
        h, w, _ = output.shape

        # Extract key points
        def get_point(landmark_id):
            lm = result.pose_landmarks.landmark[landmark_id]
            return int(lm.x * w), int(lm.y * h)

        left_shoulder = get_point(mp_pose_landmark.LEFT_SHOULDER)
        right_shoulder = get_point(mp_pose_landmark.RIGHT_SHOULDER)
        left_hip = get_point(mp_pose_landmark.LEFT_HIP)
        right_hip = get_point(mp_pose_landmark.RIGHT_HIP)

        # Destination box (torso region), ordered to match the source corners below
        dst_pts = np.array([
            left_shoulder,
            right_shoulder,
            right_hip,
            left_hip
        ], dtype=np.float32)

        # Source box (clothing image corners)
        src_h, src_w = clothing_img.shape[:2]
        src_pts = np.array([
            [0, 0],
            [src_w, 0],
            [src_w, src_h],
            [0, src_h]
        ], dtype=np.float32)

        # Compute perspective transform and warp
        matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
        warped_clothing = cv2.warpPerspective(clothing_img, matrix, (w, h), borderMode=cv2.BORDER_TRANSPARENT)

        # Handle transparency: alpha-blend RGBA garments, otherwise use a fixed blend
        if clothing_img.shape[2] == 4:
            alpha = warped_clothing[:, :, 3] / 255.0
            for c in range(3):
                output[:, :, c] = (1 - alpha) * output[:, :, c] + alpha * warped_clothing[:, :, c]
        else:
            output = cv2.addWeighted(output, 0.8, warped_clothing, 0.5, 0)

    return output
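
# Sketch of the torso-warp helper on local files (hypothetical paths; a PNG
# with an alpha channel takes the per-pixel blending branch above):
#
#   body = cv2.imread("person.jpg")
#   cloth = cv2.imread("shirt.png", cv2.IMREAD_UNCHANGED)  # keep alpha if present
#   composite = align_clothing(body, cloth)
#   cv2.imwrite("composite.jpg", composite)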

def process_image(human_img_path, garm_img_path):
    # Call the hosted Leffa virtual try-on Space via the Gradio client
    client = Client("franciszzj/Leffa")
    result = client.predict(
        src_image_path=handle_file(human_img_path),
        ref_image_path=handle_file(garm_img_path),
        ref_acceleration=False,
        step=30,
        scale=2.5,
        seed=42,
        vt_model_type="viton_hd",
        vt_garment_type="upper_body",
        vt_repaint=False,
        api_name="/leffa_predict_vt"
    )
    print(result)
    generated_image_path = result[0]
    print("generated_image_path: " + generated_image_path)
    generated_image = Image.open(generated_image_path)
    return generated_image
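
# Note: the Gradio Client is rebuilt on every request. If that becomes a
# bottleneck, one option (a sketch, behavior otherwise unchanged) is a
# module-level client created once at import time:
#
#   leffa_client = Client("franciszzj/Leffa")
#   ...then call leffa_client.predict(...) inside process_image.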

@app.post("/")
async def try_on_api(human_image: UploadFile = File(...), garment_image: UploadFile = File(...)):
    try:
        # Read the uploaded files
        human_content = await human_image.read()
        garment_content = await garment_image.read()

        # Convert to PIL images (force RGB so saving as JPEG cannot fail on RGBA uploads)
        human_img = Image.open(io.BytesIO(human_content)).convert("RGB")
        garment_img = Image.open(io.BytesIO(garment_content)).convert("RGB")

        # Save temporarily to process
        human_path = "temp_human.jpg"
        garment_path = "temp_garment.jpg"
        human_img.save(human_path)
        garment_img.save(garment_path)

        # Process the images
        result = process_image(human_path, garment_path)

        # Convert result to base64
        img_byte_arr = io.BytesIO()
        result.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
        base64_image = base64.b64encode(img_byte_arr).decode('utf-8')

        # Clean up temporary files
        os.remove(human_path)
        os.remove(garment_path)

        return {
            "status": "success",
            "image": base64_image,
            "format": "base64"
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

# Create the Gradio interface
image_blocks = gr.Blocks().queue()
with image_blocks as demo:
    gr.HTML("<center><h1>Virtual Try-On</h1></center>")
    gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
    with gr.Row():
        with gr.Column():
            human_img = gr.Image(type="filepath", label='Human', interactive=True)
            example = gr.Examples(
                inputs=human_img,
                examples_per_page=10,
                examples=human_list_path
            )
        with gr.Column():
            garm_img = gr.Image(label="Garment", type="filepath", interactive=True)
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=8,
                examples=garm_list_path
            )
        with gr.Column():
            image_out = gr.Image(label="Processed image", type="pil")
    with gr.Row():
        try_button = gr.Button(value="Try-on", variant='primary')

    # Link the button to the processing function
    try_button.click(fn=process_image, inputs=[human_img, garm_img], outputs=image_out)
# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/gradio")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)