# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
from torchmetrics.collections import MetricCollection
from torchmetrics.metric import Metric
from torchmetrics.wrappers.abstract import WrapperMetric
class MetricInputTransformer(WrapperMetric):
"""Abstract base class for metric input transformations.
Input transformations are characterized by them applying a transformation to the input data of a metric, and then
forwarding all calls to the wrapped metric with modifications applied.
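
    Example:
        A minimal, illustrative sketch of a custom subclass (the `InvertPreds` name is hypothetical and assumes
        `MetricInputTransformer` is in scope, e.g. imported from this module):

        >>> import torch
        >>> from torchmetrics.classification import BinaryAccuracy
        >>> class InvertPreds(MetricInputTransformer):
        ...     def transform_pred(self, pred: torch.Tensor) -> torch.Tensor:
        ...         return 1 - pred
        >>> metric = InvertPreds(BinaryAccuracy())
        >>> metric.update(torch.tensor([0.2, 0.8]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor(1.)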
"""
def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: dict[str, Any]) -> None:
super().__init__(**kwargs)
if not isinstance(wrapped_metric, (Metric, MetricCollection)):
raise TypeError(
f"Expected wrapped metric to be an instance of `torchmetrics.Metric` or "
f"`torchmetrics.MetricsCollection`but received {wrapped_metric}"
)
self.wrapped_metric = wrapped_metric
def transform_pred(self, pred: torch.Tensor) -> torch.Tensor:
"""Define transform operations on the prediction data.
Overridden by subclasses. Identity by default.
"""
return pred
def transform_target(self, target: torch.Tensor) -> torch.Tensor:
"""Define transform operations on the target data.
Overridden by subclasses. Identity by default.
"""
return target
def _wrap_transform(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""Wrap transformation functions to dispatch args to their individual transform functions."""
if len(args) == 1:
return (self.transform_pred(args[0]),)
if len(args) == 2:
return self.transform_pred(args[0]), self.transform_target(args[1])
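        # With more than two positional args, only the first two (predictions and target) are transformed; any
        # remaining positional args are forwarded unchanged.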
return self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:]
def update(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> None:
"""Wrap the update call of the underlying metric."""
args = self._wrap_transform(*args)
self.wrapped_metric.update(*args, **kwargs)
def compute(self) -> Any:
"""Wrap the compute call of the underlying metric."""
return self.wrapped_metric.compute()
def forward(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> Any:
"""Wrap the forward call of the underlying metric."""
args = self._wrap_transform(*args)
return self.wrapped_metric.forward(*args, **kwargs)
class LambdaInputTransformer(MetricInputTransformer):
"""Wrapper class for transforming a metrics' inputs given a user-defined lambda function.
Args:
wrapped_metric:
The underlying `Metric` or `MetricCollection`.
transform_pred:
The function to apply to the predictions before computing the metric.
transform_target:
The function to apply to the target before computing the metric.
Raises:
TypeError:
If `transform_pred` is not a Callable.
TypeError:
If `transform_target` is not a Callable.
Example:
>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> from torchmetrics.wrappers import LambdaInputTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1,0,0,0,0,1,1,0,0,0])
>>>
>>> metric = LambdaInputTransformer(BinaryAccuracy(), lambda preds: 1 - preds)
>>> metric.update(preds, targets)
>>> metric.compute()
tensor(0.6000)
"""
def __init__(
self,
wrapped_metric: Metric,
transform_pred: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
transform_target: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
**kwargs: Any,
) -> None:
super().__init__(wrapped_metric, **kwargs)
if transform_pred is not None:
if not callable(transform_pred):
raise TypeError(f"Expected `transform_pred` to be of type `Callable` but received `{transform_pred}`")
self.transform_pred = transform_pred # type: ignore[assignment,method-assign]
if transform_target is not None:
if not callable(transform_target):
raise TypeError(
f"Expected `transform_target` to be of type `Callable` but received `{transform_target}`"
)
self.transform_target = transform_target # type: ignore[assignment,method-assign]
class BinaryTargetTransformer(MetricInputTransformer):
"""Wrapper class for computing a metric on binarized targets.
Useful when the given ground-truth targets are continuous, but the metric requires binary targets.
Args:
wrapped_metric:
The underlying `Metric` or `MetricCollection`.
threshold:
The binarization threshold for the targets. Targets values `t` are cast to binary with `t > threshold`.
Raises:
TypeError:
If `threshold` is not an `int` or `float`.
Example:
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> from torchmetrics.wrappers import BinaryTargetTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1,0,0,0,0,2,1,0,0,0])
>>> topics = torch.tensor([0,0,0,0,0,1,1,1,1,1])
>>>
>>> metric = BinaryTargetTransformer(RetrievalMRR())
>>> metric.update(preds, targets, indexes=topics)
>>> metric.compute()
tensor(0.7500)
"""
def __init__(self, wrapped_metric: Union[Metric, MetricCollection], threshold: float = 0, **kwargs: Any) -> None:
super().__init__(wrapped_metric, **kwargs)
if not isinstance(threshold, (int, float)):
raise TypeError(f"Expected `threshold` to be of type `int` or `float` but received `{threshold}`")
self.threshold = threshold
def transform_target(self, target: torch.Tensor) -> torch.Tensor:
"""Cast the target tensor to binary values according to the threshold.
Output assumes same type as input.
"""
return target.gt(self.threshold).to(target.dtype)