import os
import torch
import torchvision
import huggingface_hub
from torchvision.transforms import InterpolationMode
from network.models.facexformer import FaceXFormer
from dataclasses import dataclass
import numpy as np

# import mediapipe as mp
# import cv2


# device = "cuda:0"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
weights_path = "ckpts/model.pt"
# weights_path = "ckpts/pytorch_model.bin"
# face_model_path = "ckpts/blaze_face_short_range.tflite"

# BaseOptions = mp.tasks.BaseOptions
# FaceDetector = mp.tasks.vision.FaceDetector
# FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
# FaceDetectorResult = mp.tasks.vision.FaceDetectorResult
# VisionRunningMode = mp.tasks.vision.RunningMode

# options = FaceDetectorOptions(
#     base_options=BaseOptions(model_asset_path=face_model_path),
#     running_mode=VisionRunningMode.LIVE_STREAM,
# )
# face_detector = FaceDetector.create_from_options(options)

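# Preprocessing for FaceXFormer: 224x224 bicubic resize plus normalization
# with the standard ImageNet mean/std.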
transforms_image = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(
            size=(224, 224), interpolation=InterpolationMode.BICUBIC
        ),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


def load_model(weights_path):
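    """Build FaceXFormer, fetching the checkpoint from the Hugging Face Hub
    on first run, and return it in eval mode."""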
    model = FaceXFormer().to(device)
    if not os.path.exists(weights_path):
        huggingface_hub.hf_hub_download(
            "kartiknarayan/facexformer",
            "ckpts/model.pt",
            repo_type="model",
            local_dir=".",
        )
    checkpoint = torch.load(weights_path, map_location=device)
    # model.load_state_dict(checkpoint)
    model.load_state_dict(checkpoint["state_dict_backbone"])
    model = model.eval()
    model = model.to(dtype=dtype)
    # model = torch.compile(model, mode="reduce-overhead")
    return model


model = load_model(weights_path)


def adjust_bbox(
    x_min, y_min, x_max, y_max, image_width, image_height, margin_percentage=50
):
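    """Grow a bounding box by `margin_percentage` of its width/height
    (split evenly per side), clamped to the image bounds."""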
    width = x_max - x_min
    height = y_max - y_min

    increase_width = width * (margin_percentage / 100.0) / 2
    increase_height = height * (margin_percentage / 100.0) / 2

    x_min_adjusted = int(max(0, x_min - increase_width))
    y_min_adjusted = int(max(0, y_min - increase_height))
    x_max_adjusted = int(min(image_width, x_max + increase_width))
    y_max_adjusted = int(min(image_height, y_max + increase_height))

    return x_min_adjusted, y_min_adjusted, x_max_adjusted, y_max_adjusted


def denorm_points(points, h, w, align_corners=False):
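    """Map points from the model's normalized [-1, 1] space to pixel
    coordinates, following grid_sample's align_corners conventions."""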
    if align_corners:
        denorm_points = (
            (points + 1) / 2 * torch.tensor([w - 1, h - 1]).to(points).view(1, 1, 2)
        )
    else:
        denorm_points = (
            (points + 1) * torch.tensor([w, h]).to(points).view(1, 1, 2) - 1
        ) / 2

    return denorm_points


@dataclass
class BoundingBox:
    x_min: int
    y_min: int
    x_max: int
    y_max: int


@dataclass
class FaceImg:
    image: np.ndarray
    x_min: int
    y_min: int


def get_faces_img(img: np.ndarray, boxes: list[BoundingBox] | None) -> list[FaceImg]:
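    """Crop each padded face box out of `img`, keeping the crop's top-left
    corner so results can later be mapped back to full-image coordinates."""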
    if boxes is None or len(boxes) == 0:
        return []
    results: list[FaceImg] = []
    for box in boxes:
        x_min, y_min, x_max, y_max = box.x_min, box.y_min, box.x_max, box.y_max

        # Padding
        x_min, y_min, x_max, y_max = adjust_bbox(
            x_min, y_min, x_max, y_max, img.shape[1], img.shape[0]
        )
        image = img[y_min:y_max, x_min:x_max]
        results.append(FaceImg(image, int(x_min), int(y_min)))

    return results


@dataclass
class Face:
    image: torch.Tensor
    x_min: int
    y_min: int
    original_w: int
    original_h: int


def get_faces(img: np.ndarray, boxes: list[BoundingBox] | None) -> list[Face]:
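    """Crop faces and preprocess them for the model, recording each crop's
    offset and original size for landmark denormalization."""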
    images = get_faces_img(img, boxes)
    images = [
        Face(
            transforms_image(face_image.image),
            face_image.x_min,
            face_image.y_min,
            face_image.image.shape[1],
            face_image.image.shape[0],
        )
        for face_image in images
    ]
    return images


def get_landmarks(faces: list[Face]):
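    """Run FaceXFormer's landmark task on a batch of face crops and return,
    per face, 68 (x, y) points in full-image pixel coordinates."""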
    if len(faces) == 0:
        return []

    images = torch.stack([face.image for face in faces]).to(device=device, dtype=dtype)

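    # Task id 1 appears to select the landmark head in FaceXFormer's
    # multi-task interface (landmarks are the first output below).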
    tasks = torch.tensor([1] * len(faces), device=device, dtype=dtype)
    with torch.inference_mode():
        # with torch.amp.autocast("cuda"):
        (
            batch_landmarks,
            headposes,
            attributes,
            visibilities,
            ages,
            genders,
            races,
            segs,
        ) = model.predict(images, None, tasks)
    batch_denormed = [
        denorm_points(landmarks, face.original_h, face.original_w)[0]
        for landmarks, face in zip(batch_landmarks.view(-1, 68, 2), faces)
    ]

    results = []
    for landmarks, face in zip(batch_denormed, faces):
        results.append(
            [(int(x + face.x_min), int(y + face.y_min)) for x, y in landmarks]
        )

    return results
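

# Minimal usage sketch. Assumptions not in this module: real face boxes would
# come from an external detector (e.g. the commented-out mediapipe setup
# above); here a hard-coded box on a dummy frame stands in for detections.
if __name__ == "__main__":
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame
    boxes = [BoundingBox(x_min=200, y_min=120, x_max=440, y_max=360)]
    faces = get_faces(frame, boxes)
    for points in get_landmarks(faces):
        print(f"{len(points)} landmarks, first three: {points[:3]}")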