File size: 18,809 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 |
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
MultilabelPrecisionRecallCurve,
)
from torchmetrics.functional.classification.specificity_sensitivity import (
_binary_specificity_at_sensitivity_arg_validation,
_binary_specificity_at_sensitivity_compute,
_multiclass_specificity_at_sensitivity_arg_validation,
_multiclass_specificity_at_sensitivity_compute,
_multilabel_specificity_at_sensitivity_arg_validation,
_multilabel_specificity_at_sensitivity_compute,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat as _cat
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
# The ``plot`` doctests require matplotlib, so mark them as skipped when it is not installed.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = [
        f"{_metric_name}.plot"
        for _metric_name in (
            "BinarySpecificityAtSensitivity",
            "MulticlassSpecificityAtSensitivity",
            "MultilabelSpecificityAtSensitivity",
        )
    ]
class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve):
    r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.

    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
    then finding the specificity for a given sensitivity level.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        min_sensitivity: float value specifying minimum sensitivity threshold.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Returns:
        (tuple): a tuple of 2 tensors containing:

        - specificity: a scalar tensor with the maximum specificity for the given sensitivity level
        - threshold: a scalar tensor with the corresponding threshold level

    Example:
        >>> from torchmetrics.classification import BinarySpecificityAtSensitivity
        >>> from torch import tensor
        >>> preds = tensor([0, 0.5, 0.4, 0.1])
        >>> target = tensor([0, 1, 1, 1])
        >>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=None)
        >>> metric(preds, target)
        (tensor(1.), tensor(0.4000))
        >>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=5)
        >>> metric(preds, target)
        (tensor(1.), tensor(0.2500))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(
        self,
        min_sensitivity: float,
        thresholds: Optional[Union[int, list[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent curve is built with validation disabled; the arguments are instead checked below by the
        # specificity-specific validator, which also covers ``min_sensitivity``.
        super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs)
        if validate_args:
            _binary_specificity_at_sensitivity_arg_validation(min_sensitivity, thresholds, ignore_index)
        # Stored explicitly because the parent was constructed with ``validate_args=False``.
        self.validate_args = validate_args
        self.min_sensitivity = min_sensitivity

    def compute(self) -> tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute metric."""
        # Non-binned mode (``thresholds is None``) concatenates the stored preds/targets; binned mode uses the
        # accumulated ``confmat`` state instead.
        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
        return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity)
class MulticlassSpecificityAtSensitivity(MulticlassPrecisionRecallCurve):
    r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.

    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
    then finding the specificity for a given sensitivity level.

    For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
    classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by
    this metric.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      softmax per sample.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).

    Args:
        num_classes: Integer specifying the number of classes
        min_sensitivity: float value specifying minimum sensitivity threshold.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Returns:
        (tuple): a tuple of either 2 tensors or 2 lists containing:

        - specificity: a 1d tensor of size (n_classes, ) with the maximum specificity for the given
          sensitivity level per class
        - thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class

    Example:
        >>> from torchmetrics.classification import MulticlassSpecificityAtSensitivity
        >>> from torch import tensor
        >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = tensor([0, 1, 3, 2])
        >>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=None)
        >>> metric(preds, target)
        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
        >>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=5)
        >>> metric(preds, target)
        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Class"

    def __init__(
        self,
        num_classes: int,
        min_sensitivity: float,
        thresholds: Optional[Union[int, list[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent curve is built with validation disabled; the arguments are instead checked below by the
        # specificity-specific validator, which also covers ``min_sensitivity``.
        super().__init__(
            num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
        )
        if validate_args:
            _multiclass_specificity_at_sensitivity_arg_validation(num_classes, min_sensitivity, thresholds, ignore_index)
        # Stored explicitly because the parent was constructed with ``validate_args=False``.
        self.validate_args = validate_args
        self.min_sensitivity = min_sensitivity

    def compute(self) -> tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute metric."""
        # Non-binned mode (``thresholds is None``) concatenates the stored preds/targets; binned mode uses the
        # accumulated ``confmat`` state instead.
        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
        return _multiclass_specificity_at_sensitivity_compute(
            state, self.num_classes, self.thresholds, self.min_sensitivity
        )
class MultilabelSpecificityAtSensitivity(MultilabelPrecisionRecallCurve):
    r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.

    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
    then finding the specificity for a given sensitivity level.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).

    Args:
        num_labels: Integer specifying the number of labels
        min_sensitivity: float value specifying minimum sensitivity threshold.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Returns:
        (tuple): a tuple of either 2 tensors or 2 lists containing:

        - specificity: a 1d tensor of size (n_labels, ) with the maximum specificity for the given
          sensitivity level per label
        - thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label

    Example:
        >>> from torchmetrics.classification import MultilabelSpecificityAtSensitivity
        >>> from torch import tensor
        >>> preds = tensor([[0.75, 0.05, 0.35],
        ...                 [0.45, 0.75, 0.05],
        ...                 [0.05, 0.55, 0.75],
        ...                 [0.05, 0.65, 0.05]])
        >>> target = tensor([[1, 0, 1],
        ...                  [0, 0, 0],
        ...                  [0, 1, 1],
        ...                  [1, 1, 1]])
        >>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=None)
        >>> metric(preds, target)
        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
        >>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=5)
        >>> metric(preds, target)
        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))

    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Label"

    def __init__(
        self,
        num_labels: int,
        min_sensitivity: float,
        thresholds: Optional[Union[int, list[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # The parent curve is built with validation disabled; the arguments are instead checked below by the
        # specificity-specific validator, which also covers ``min_sensitivity``.
        super().__init__(
            num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
        )
        if validate_args:
            _multilabel_specificity_at_sensitivity_arg_validation(num_labels, min_sensitivity, thresholds, ignore_index)
        # Stored explicitly because the parent was constructed with ``validate_args=False``.
        self.validate_args = validate_args
        self.min_sensitivity = min_sensitivity

    def compute(self) -> tuple[Tensor, Tensor]:  # type: ignore[override]
        """Compute metric."""
        # Non-binned mode (``thresholds is None``) concatenates the stored preds/targets; binned mode uses the
        # accumulated ``confmat`` state instead.
        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
        # Unlike the binary/multiclass variants, ``ignore_index`` is also forwarded here; presumably the helper
        # masks ignored entries per label at compute time -- TODO confirm against the functional implementation.
        return _multilabel_specificity_at_sensitivity_compute(
            state, self.num_labels, self.thresholds, self.ignore_index, self.min_sensitivity
        )
class SpecificityAtSensitivity(_ClassificationTaskWrapper):
    r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.

    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
    then finding the specificity for a given sensitivity level.

    This class is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :class:`~torchmetrics.classification.BinarySpecificityAtSensitivity`,
    :class:`~torchmetrics.classification.MulticlassSpecificityAtSensitivity` and
    :class:`~torchmetrics.classification.MultilabelSpecificityAtSensitivity` for the specific details of each argument
    influence and examples.

    """

    def __new__(  # type: ignore[misc]
        cls: type["SpecificityAtSensitivity"],
        task: Literal["binary", "multiclass", "multilabel"],
        min_sensitivity: float,
        thresholds: Optional[Union[int, list[float], Tensor]] = None,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize task metric.

        Returns an instance of the task-specific metric class instead of ``SpecificityAtSensitivity`` itself.

        Raises:
            ValueError:
                If ``task`` is ``'multiclass'`` and ``num_classes`` is not an ``int``, if ``task`` is
                ``'multilabel'`` and ``num_labels`` is not an ``int``, or if the task is not supported.

        """
        task = ClassificationTask.from_str(task)
        # Arguments are forwarded by keyword so this dispatch does not silently depend on the positional order
        # of each task-specific constructor.
        if task == ClassificationTask.BINARY:
            return BinarySpecificityAtSensitivity(
                min_sensitivity=min_sensitivity,
                thresholds=thresholds,
                ignore_index=ignore_index,
                validate_args=validate_args,
                **kwargs,
            )
        if task == ClassificationTask.MULTICLASS:
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
            return MulticlassSpecificityAtSensitivity(
                num_classes=num_classes,
                min_sensitivity=min_sensitivity,
                thresholds=thresholds,
                ignore_index=ignore_index,
                validate_args=validate_args,
                **kwargs,
            )
        if task == ClassificationTask.MULTILABEL:
            if not isinstance(num_labels, int):
                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
            return MultilabelSpecificityAtSensitivity(
                num_labels=num_labels,
                min_sensitivity=min_sensitivity,
                thresholds=thresholds,
                ignore_index=ignore_index,
                validate_args=validate_args,
                **kwargs,
            )
        raise ValueError(f"Task {task} not supported!")
|