File size: 4,376 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.mse import _mean_squared_error_update


def _normalized_root_mean_squared_error_update(
    preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean"
) -> tuple[Tensor, int, Tensor]:
    """Gather the statistics needed for NRMSE: squared-error sum, observation count and normalization denominator.

    The squared-error sum and the observation count are taken as returned by ``_mean_squared_error_update``. The
    denominator is derived from ``target`` only, reducing over dimension 0.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        num_outputs: Number of outputs in multioutput setting. When it equals 1, ``target`` is flattened before the
            denominator is computed.
        normalization: statistic of ``target`` used as denominator. Choose from "mean" (mean of the target),
            "range" (maximum minus minimum), "std" (standard deviation computed with ``correction=0``) or
            "l2" (L2 norm).

    Returns:
        Tuple ``(sum_squared_error, num_obs, denom)``.

    Raises:
        ValueError:
            If ``normalization`` is not one of "mean", "range", "std" or "l2".

    """
    sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs)

    if num_outputs == 1:
        # Single-output setting: collapse the target to 1D so the dim-0 reductions below span every element.
        target = target.view(-1)

    if normalization == "mean":
        scale = target.mean(dim=0)
    elif normalization == "range":
        highest = target.max(dim=0).values
        lowest = target.min(dim=0).values
        scale = highest - lowest
    elif normalization == "std":
        # correction=0 divides by N rather than N - 1.
        scale = target.std(dim=0, correction=0)
    elif normalization == "l2":
        scale = target.norm(p=2, dim=0)
    else:
        raise ValueError(
            f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2' but got {normalization}"
        )
    return sum_squared_error, num_obs, scale


def _normalized_root_mean_squared_error_compute(
    sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor
) -> Tensor:
    """Turn accumulated squared errors into the normalized root mean squared error.

    Args:
        sum_squared_error: Sum of squared errors
        num_obs: Number of observations the sum was accumulated over
        denom: Normalization term the RMSE is divided by

    Returns:
        Tensor holding ``sqrt(sum_squared_error / num_obs) / denom``.

    """
    mean_squared_error = sum_squared_error / num_obs
    root_mean_squared_error = mean_squared_error.sqrt()
    return root_mean_squared_error / denom


def normalized_root_mean_squared_error(
    preds: Tensor,
    target: Tensor,
    normalization: Literal["mean", "range", "std", "l2"] = "mean",
    num_outputs: int = 1,
) -> Tensor:
    """Compute the `Normalized Root Mean Squared Error`_ (NRMSE), also known as the scatter index.

    The RMSE between ``preds`` and ``target`` is divided by a statistic of ``target`` selected through
    ``normalization``.

    Args:
        preds: estimated labels
        target: ground truth labels
        normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2", which
          correspond to normalizing the RMSE by the mean of the target, the range of the target, the standard
          deviation of the target or the L2 norm of the target, respectively.
        num_outputs: Number of outputs in multioutput setting

    Return:
        Tensor with the NRMSE score

    Raises:
        ValueError:
            If ``normalization`` is not one of "mean", "range", "std" or "l2".

    Example:
        >>> import torch
        >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
        >>> preds = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> normalized_root_mean_squared_error(preds, target, normalization="mean")
        tensor(0.4000)
        >>> normalized_root_mean_squared_error(preds, target, normalization="range")
        tensor(0.2500)
        >>> normalized_root_mean_squared_error(preds, target, normalization="std")
        tensor(0.6030)
        >>> normalized_root_mean_squared_error(preds, target, normalization="l2")
        tensor(0.1667)

    Example (multioutput):
        >>> import torch
        >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
        >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]])
        >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]])
        >>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2)
        tensor([0.2981, 0.2222])

    """
    # The update step yields (sum_squared_error, num_obs, denom), exactly the positional arguments of compute.
    statistics = _normalized_root_mean_squared_error_update(
        preds, target, num_outputs=num_outputs, normalization=normalization
    )
    return _normalized_root_mean_squared_error_compute(*statistics)