use huggingface models
- app.py +56 -74
- requirements.txt +2 -1
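In short: instead of building the crop and bone-age networks from local skp configs and .pt checkpoints, the app now loads both from the Hugging Face Hub. A minimal sketch of the new loading pattern (repo ids and the trust_remote_code flag are taken from the diff below; the model classes themselves are defined by each repo's remote code):

    import torch
    from transformers import AutoModel

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # both repos ship custom model code, hence trust_remote_code=True
    crop_model = AutoModel.from_pretrained(
        "ianpan/bone-age-crop", trust_remote_code=True
    )
    model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
    crop_model, model = crop_model.eval().to(device), model.eval().to(device)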
app.py
CHANGED
@@ -7,20 +7,20 @@ import torch
 import torch.nn as nn
 
 from einops import rearrange
-from importlib import import_module
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 from skimage.exposure import match_histograms
-from
+from transformers import AutoModel
 
 
 class ModelForGradCAM(nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, female):
         super().__init__()
         self.model = model
+        self.female = female
 
     def forward(self, x):
-        return self.model({"x": x}, return_loss=False)["logits1"]
+        return self.model(x, self.female, return_logits=True)
 
 
 def convert_bone_age_to_string(bone_age: float):
@@ -47,67 +47,29 @@ def convert_bone_age_to_string(bone_age: float):
     return str_output
 
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-cfg_crop = import_module("skp.configs.boneage.cfg_crop_simple_resize").cfg
-crop_model = load_model_from_config(
-    cfg_crop, weights_path="crop.pt", device=device, eval_mode=True
-)
-
-cfg = import_module("skp.configs.boneage.cfg_female_channel_reg_cls_match_hist").cfg
-cfg.backbone = "convnextv2_tiny"
-
-model_list = load_kfold_ensemble_as_list(
-    cfg, [f"net{i}.pt" for i in range(3)], device=device, eval_mode=True
-)
-
-ref_img = rearrange(cv2.imread("ref_img.png", 0), "h w -> h w 1")
-
-with open("greulich_and_pyle_ages.json", "r") as f:
-    greulich_and_pyle_ages = json.load(f)["bone_ages"]
-
-greulich_and_pyle_ages = {k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()}
-
-model_grad_cam = ModelForGradCAM(model_list[0])
-target_layers = [model_grad_cam.model.backbone.stages[-1]]
-
-
 @spaces.GPU
 def predict_bone_age(Radiograph, Sex, Heatmap):
-    x =
-    x =
-    x = rearrange(x, "h w c -> 1 c h w")
+    x = crop_model.preprocess(Radiograph)
+    x = torch.from_numpy(x).float().to(device)
+    x = rearrange(x, "h w -> 1 1 h w")
     # crop
+    img_shape = torch.tensor([Radiograph.shape[:2]]).to(device)
     with torch.inference_mode():
-        box = crop_model(
-        box[[0, 2]] = box[[0, 2]] * x0.shape[1]
-        box[[1, 3]] = box[[1, 3]] * x0.shape[0]
-        box = box.numpy().astype("int")
-        x, y, w, h = box
-        x0 = x0[y : y + h, x : x + w]
+        box = crop_model(x, img_shape=img_shape).to("cpu").numpy()
+    x, y, w, h = box[0]
+    cropped = Radiograph[y : y + h, x : x + w]
     # histogram matching
-    x = np.concatenate([x, ch], axis=-1)
-    x = torch.from_numpy(x)
-    x = rearrange(x, "h w c -> 1 c h w")
+    x = match_histograms(cropped, ref_img)
+
+    x = model.preprocess(x)
+    x = torch.from_numpy(x).float().to(device)
+    x = rearrange(x, "h w -> 1 1 h w")
+    female = torch.tensor([Sex]).to(device)
     with torch.inference_mode():
-        bone_age = []
-        for each_model in model_list:
-            pred = each_model({"x": x.to(device).float()}, return_loss=False)[
-                "logits1"
-            ][0].cpu()
-            pred = (pred.softmax(0) * torch.arange(240)).sum().numpy()
-            bone_age.append(pred)
-        bone_age = np.mean(bone_age)
+        bone_age = model(x, female)[0].item()
 
+    # get closest G&P ages
+    # from: https://rad.esmil.com/Reference/G_P_BoneAge/
     gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
     diffs_gp = np.abs(bone_age - gp_ages)
     diffs_gp = np.argsort(diffs_gp)
@@ -119,29 +81,33 @@ def predict_bone_age(Radiograph, Sex, Heatmap):
     closest2 = convert_bone_age_to_string(closest2)
 
     if Heatmap:
+        # net1 and net2 give good GradCAMs
+        # net0 is bad for some reason
+        # because GradCAM expects 1 input tensor, need to
+        # pass female during class instantiation
+        model_grad_cam = ModelForGradCAM(model.net1, female)
+        target_layers = [model_grad_cam.model.backbone.stages[-1]]
         targets = [ClassifierOutputTarget(round(bone_age))]
         with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
-            grayscale_cam = cam(
-                input_tensor=x.to(device).float(), targets=targets, eigen_smooth=True
-            )
+            grayscale_cam = cam(input_tensor=x, targets=targets, eigen_smooth=True)
 
         heatmap = cv2.applyColorMap(
             (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
         )
-        image = cv2.cvtColor(
+        image = cv2.cvtColor(
+            x[0, 0].to("cpu").numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
+        )
         image_weight = 0.6
         grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image
         grad_cam_image = grad_cam_image
     else:
         # if no heatmap desired, just show image
-        grad_cam_image = cv2.cvtColor(
-            x[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
-        )
+        grad_cam_image = cv2.cvtColor(x[0, 0].to("cpu").numpy(), cv2.COLOR_GRAY2RGB)
 
     return (
         bone_age_str,
         f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
-        grad_cam_image,
+        grad_cam_image.astype("uint8"),
     )
@@ -157,11 +123,8 @@ with gr.Blocks() as demo:
         """
         # Deep Learning Model for Pediatric Bone Age
 
-        This model predicts the bone age from a single frontal view hand radiograph.
-
-        The model was trained using data from the [RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset.
-        The model achieves a mean absolute error of 4.26 months on the original test set comprising 200 multi-annotated hand radiographs,
-        which is competitive with [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge.
+        This model predicts the bone age from a single frontal view hand radiograph. Read more about the model here:
+        <https://huggingface.co/ianpan/bone-age>
 
         There is also an option to output a heatmap over the radiograph to show the regions the model focuses on
        to make its prediction. However, this takes extra computation and will increase the runtime.
@@ -172,7 +135,7 @@ with gr.Blocks() as demo:
 
         Created by: Ian Pan, <https://ianpan.me>
 
-        Last updated: December
+        Last updated: December 16, 2024
         """
     )
     gr.Interface(
@@ -184,8 +147,27 @@ with gr.Blocks() as demo:
             ["examples/10043.png", "Female", "No"],
             ["examples/8888.png", "Female", "Yes"],
         ],
-        cache_examples=
+        cache_examples="lazy",
    )
 
 if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device `{device}` ...")
+
+    crop_model = AutoModel.from_pretrained(
+        "ianpan/bone-age-crop", trust_remote_code=True
+    )
+    model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
+
+    crop_model, model = crop_model.eval().to(device), model.eval().to(device)
+
+    ref_img = cv2.imread("ref_img.png", 0)
+
+    with open("greulich_and_pyle_ages.json", "r") as f:
+        greulich_and_pyle_ages = json.load(f)["bone_ages"]
+
+    greulich_and_pyle_ages = {
+        k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()
+    }
+
     demo.launch(share=True)
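Taken together, the new predict path reduces to a short two-stage script. A hedged sketch (preprocess, the img_shape keyword, and the (x, female) call signature come from the repos' remote code as used in the diff above, not from a standard transformers API; hand_xr.png is a placeholder input):

    import cv2
    import torch
    from einops import rearrange
    from skimage.exposure import match_histograms
    from transformers import AutoModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    crop_model = AutoModel.from_pretrained(
        "ianpan/bone-age-crop", trust_remote_code=True
    ).eval().to(device)
    model = AutoModel.from_pretrained(
        "ianpan/bone-age", trust_remote_code=True
    ).eval().to(device)

    img = cv2.imread("hand_xr.png", 0)      # placeholder grayscale radiograph
    ref_img = cv2.imread("ref_img.png", 0)  # reference image shipped with the Space

    # stage 1: predict a bounding box and crop to the hand
    x = torch.from_numpy(crop_model.preprocess(img)).float().to(device)
    x = rearrange(x, "h w -> 1 1 h w")
    img_shape = torch.tensor([img.shape[:2]]).to(device)
    with torch.inference_mode():
        box = crop_model(x, img_shape=img_shape).to("cpu").numpy()
    x0, y0, w, h = box[0]
    img = img[y0 : y0 + h, x0 : x0 + w]

    # stage 2: histogram-match to the reference, then regress bone age in months
    x = match_histograms(img, ref_img)
    x = torch.from_numpy(model.preprocess(x)).float().to(device)
    x = rearrange(x, "h w -> 1 1 h w")
    female = torch.tensor([True]).to(device)
    with torch.inference_mode():
        bone_age = model(x, female)[0].item()
    print(f"predicted bone age: {bone_age:.1f} months")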
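On the heatmap branch: pytorch_grad_cam drives the model with a single input tensor, which is why the sex flag is bound at construction time in ModelForGradCAM. The removed ensemble decoded age as an expectation over 240 month-wise logits, so the rounded prediction doubles as the class index to explain; whether model.net1 keeps the same head is an assumption carried over from the diff. A sketch, reusing model, x, female, and bone_age from the predict path:

    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    # bind female so GradCAM sees a one-tensor forward()
    model_grad_cam = ModelForGradCAM(model.net1, female)
    # last ConvNeXt stage: coarsest spatial feature maps
    target_layers = [model_grad_cam.model.backbone.stages[-1]]
    # one logit per month of bone age; explain the predicted month
    targets = [ClassifierOutputTarget(round(bone_age))]

    with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=x, targets=targets, eigen_smooth=True)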
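One detail worth calling out: before inference, the input is histogram-matched to a fixed reference radiograph (ref_img.png), presumably so its intensity distribution resembles what the model saw in training. skimage's match_histograms handles 2-D grayscale arrays directly (file names here are placeholders):

    import cv2
    from skimage.exposure import match_histograms

    image = cv2.imread("input_xr.png", 0)    # grayscale, shape (H, W)
    reference = cv2.imread("ref_img.png", 0)

    # remap image intensities so its histogram matches the reference's;
    # no channel_axis needed for single-channel input
    matched = match_histograms(image, reference)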
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ gradio
 scikit-image
 spaces
 timm
 torch
+transformers