|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Optional, Union |
|
|
|
import torch |
|
|
|
from torchmetrics.collections import MetricCollection |
|
from torchmetrics.metric import Metric |
|
from torchmetrics.wrappers.abstract import WrapperMetric |
|
|
|
|
|
class MetricInputTransformer(WrapperMetric):
    """Abstract base class for metric input transformations.

    Input transformations apply a transformation to the input data of a metric and then forward all calls to the
    wrapped metric with the transformed inputs.

    Subclasses customize behaviour by overriding ``transform_pred`` and/or ``transform_target``; both default to the
    identity. Only the first two positional arguments (predictions and target) are transformed; any further
    positional arguments and all keyword arguments are forwarded to the wrapped metric unchanged.

    """

    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Any) -> None:
        """Store the metric (or collection) whose inputs will be transformed.

        Args:
            wrapped_metric: The underlying `Metric` or `MetricCollection`.
            kwargs: Additional keyword arguments forwarded to the ``WrapperMetric`` base constructor.

        Raises:
            TypeError: If ``wrapped_metric`` is neither a `Metric` nor a `MetricCollection`.

        """
        super().__init__(**kwargs)
        if not isinstance(wrapped_metric, (Metric, MetricCollection)):
            raise TypeError(
                "Expected wrapped metric to be an instance of `torchmetrics.Metric` or "
                f"`torchmetrics.MetricCollection` but received {wrapped_metric}"
            )
        self.wrapped_metric = wrapped_metric

    def transform_pred(self, pred: torch.Tensor) -> torch.Tensor:
        """Define transform operations on the prediction data.

        Overridden by subclasses. Identity by default.

        """
        return pred

    def transform_target(self, target: torch.Tensor) -> torch.Tensor:
        """Define transform operations on the target data.

        Overridden by subclasses. Identity by default.

        """
        return target

    def _wrap_transform(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Wrap transformation functions to dispatch args to their individual transform functions.

        The first positional argument goes through ``transform_pred``, the second (if present) through
        ``transform_target``, and any remaining positional arguments are passed through untouched.

        """
        # NOTE(review): with zero positional args (e.g. inputs passed only as keywords) ``args[0]`` below raises
        # IndexError; keyword inputs are never transformed.
        if len(args) == 1:
            return (self.transform_pred(args[0]),)
        if len(args) == 2:
            return self.transform_pred(args[0]), self.transform_target(args[1])
        return self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:]

    def update(self, *args: torch.Tensor, **kwargs: Any) -> None:
        """Wrap the update call of the underlying metric, transforming the positional inputs first."""
        args = self._wrap_transform(*args)
        self.wrapped_metric.update(*args, **kwargs)

    def compute(self) -> Any:
        """Wrap the compute call of the underlying metric."""
        return self.wrapped_metric.compute()

    def forward(self, *args: torch.Tensor, **kwargs: Any) -> Any:
        """Wrap the forward call of the underlying metric, transforming the positional inputs first."""
        args = self._wrap_transform(*args)
        return self.wrapped_metric.forward(*args, **kwargs)
|
|
|
|
|
class LambdaInputTransformer(MetricInputTransformer):
    """Wrapper class for transforming a metric's inputs given user-defined functions.

    Args:
        wrapped_metric:
            The underlying `Metric` or `MetricCollection`.
        transform_pred:
            The function to apply to the predictions before computing the metric. If ``None``, the predictions are
            passed through unchanged.
        transform_target:
            The function to apply to the target before computing the metric. If ``None``, the target is passed
            through unchanged.

    Raises:
        TypeError:
            If `wrapped_metric` is not a `Metric` or `MetricCollection`.
        TypeError:
            If `transform_pred` is not a Callable.
        TypeError:
            If `transform_target` is not a Callable.

    Example:
        >>> import torch
        >>> from torchmetrics.classification import BinaryAccuracy
        >>> from torchmetrics.wrappers import LambdaInputTransformer
        >>>
        >>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
        >>> targets = torch.tensor([1,0,0,0,0,1,1,0,0,0])
        >>>
        >>> metric = LambdaInputTransformer(BinaryAccuracy(), lambda preds: 1 - preds)
        >>> metric.update(preds, targets)
        >>> metric.compute()
        tensor(0.6000)

    """

    def __init__(
        self,
        wrapped_metric: Union[Metric, MetricCollection],
        transform_pred: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        transform_target: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(wrapped_metric, **kwargs)
        if transform_pred is not None:
            if not callable(transform_pred):
                raise TypeError(f"Expected `transform_pred` to be of type `Callable` but received `{transform_pred}`")
            # Instance attribute shadows the identity ``transform_pred`` method inherited from the base class.
            self.transform_pred = transform_pred

        if transform_target is not None:
            if not callable(transform_target):
                raise TypeError(
                    f"Expected `transform_target` to be of type `Callable` but received `{transform_target}`"
                )
            # Instance attribute shadows the identity ``transform_target`` method inherited from the base class.
            self.transform_target = transform_target
|
|
|
|
|
class BinaryTargetTransformer(MetricInputTransformer):
    """Wrapper class for computing a metric on binarized targets.

    Useful when the given ground-truth targets are continuous (or graded), but the metric requires binary targets.

    Args:
        wrapped_metric:
            The underlying `Metric` or `MetricCollection`.
        threshold:
            The binarization threshold for the targets. Target values `t` become ``1`` where ``t > threshold`` and
            ``0`` otherwise.

    Raises:
        TypeError:
            If `threshold` is not an `int` or `float`.

    Example:
        >>> import torch
        >>> from torchmetrics.retrieval import RetrievalMRR
        >>> from torchmetrics.wrappers import BinaryTargetTransformer
        >>>
        >>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
        >>> targets = torch.tensor([1,0,0,0,0,2,1,0,0,0])
        >>> topics = torch.tensor([0,0,0,0,0,1,1,1,1,1])
        >>>
        >>> metric = BinaryTargetTransformer(RetrievalMRR())
        >>> metric.update(preds, targets, indexes=topics)
        >>> metric.compute()
        tensor(0.7500)

    """

    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], threshold: float = 0, **kwargs: Any) -> None:
        super().__init__(wrapped_metric, **kwargs)
        is_number = isinstance(threshold, (int, float))
        if is_number:
            self.threshold = threshold
            return
        raise TypeError(f"Expected `threshold` to be of type `int` or `float` but received `{threshold}`")

    def transform_target(self, target: torch.Tensor) -> torch.Tensor:
        """Binarize ``target`` against ``self.threshold``.

        Elements strictly greater than the threshold map to 1, all others to 0. The result keeps the dtype of the
        input tensor.

        """
        above_threshold = target > self.threshold
        return above_threshold.to(dtype=target.dtype)
|
|