|
from typing import Any, Optional |
|
|
|
from torchmetrics.retrieval.average_precision import RetrievalMAP |
|
from torchmetrics.retrieval.fall_out import RetrievalFallOut |
|
from torchmetrics.retrieval.hit_rate import RetrievalHitRate |
|
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG |
|
from torchmetrics.retrieval.precision import RetrievalPrecision |
|
from torchmetrics.retrieval.precision_recall_curve import RetrievalPrecisionRecallCurve, RetrievalRecallAtFixedPrecision |
|
from torchmetrics.retrieval.r_precision import RetrievalRPrecision |
|
from torchmetrics.retrieval.recall import RetrievalRecall |
|
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR |
|
from torchmetrics.utilities.prints import _deprecated_root_import_class |
|
|
|
|
|
class _RetrievalFallOut(RetrievalFallOut): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> rfo = _RetrievalFallOut(top_k=2) |
|
>>> rfo(preds, target, indexes=indexes) |
|
tensor(0.5000) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "pos", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalFallOut", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs) |
|
|
|
|
|
class _RetrievalHitRate(RetrievalHitRate): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([True, False, False, False, True, False, True]) |
|
>>> hr2 = _RetrievalHitRate(top_k=2) |
|
>>> hr2(preds, target, indexes=indexes) |
|
tensor(0.5000) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalHitRate", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs) |
|
|
|
|
|
class _RetrievalMAP(RetrievalMAP): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> rmap = _RetrievalMAP() |
|
>>> rmap(preds, target, indexes=indexes) |
|
tensor(0.7917) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalMAP", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs) |
|
|
|
|
|
class _RetrievalRecall(RetrievalRecall): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> r2 = _RetrievalRecall(top_k=2) |
|
>>> r2(preds, target, indexes=indexes) |
|
tensor(0.7500) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalRecall", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs) |
|
|
|
|
|
class _RetrievalRPrecision(RetrievalRPrecision): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> p2 = _RetrievalRPrecision() |
|
>>> p2(preds, target, indexes=indexes) |
|
tensor(0.7500) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalRPrecision", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs) |
|
|
|
|
|
class _RetrievalNormalizedDCG(RetrievalNormalizedDCG): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> ndcg = _RetrievalNormalizedDCG() |
|
>>> ndcg(preds, target, indexes=indexes) |
|
tensor(0.8467) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalNormalizedDCG", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs) |
|
|
|
|
|
class _RetrievalPrecision(RetrievalPrecision): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> p2 = _RetrievalPrecision(top_k=2) |
|
>>> p2(preds, target, indexes=indexes) |
|
tensor(0.5000) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
top_k: Optional[int] = None, |
|
adaptive_k: bool = False, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("", "retrieval") |
|
super().__init__( |
|
empty_target_action=empty_target_action, |
|
ignore_index=ignore_index, |
|
top_k=top_k, |
|
adaptive_k=adaptive_k, |
|
**kwargs, |
|
) |
|
|
|
|
|
class _RetrievalPrecisionRecallCurve(RetrievalPrecisionRecallCurve): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 0, 1, 1, 1]) |
|
>>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5]) |
|
>>> target = tensor([True, False, False, True, True, False, True]) |
|
>>> r = _RetrievalPrecisionRecallCurve(max_k=4) |
|
>>> precisions, recalls, top_k = r(preds, target, indexes=indexes) |
|
>>> precisions |
|
tensor([1.0000, 0.5000, 0.6667, 0.5000]) |
|
>>> recalls |
|
tensor([0.5000, 0.5000, 1.0000, 1.0000]) |
|
>>> top_k |
|
tensor([1, 2, 3, 4]) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
max_k: Optional[int] = None, |
|
adaptive_k: bool = False, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("", "retrieval") |
|
super().__init__( |
|
max_k=max_k, |
|
adaptive_k=adaptive_k, |
|
empty_target_action=empty_target_action, |
|
ignore_index=ignore_index, |
|
**kwargs, |
|
) |
|
|
|
|
|
class _RetrievalRecallAtFixedPrecision(RetrievalRecallAtFixedPrecision): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 0, 1, 1, 1]) |
|
>>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5]) |
|
>>> target = tensor([True, False, False, True, True, False, True]) |
|
>>> r = _RetrievalRecallAtFixedPrecision(min_precision=0.8) |
|
>>> r(preds, target, indexes=indexes) |
|
(tensor(0.5000), tensor(1)) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
min_precision: float = 0.0, |
|
max_k: Optional[int] = None, |
|
adaptive_k: bool = False, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("RetrievalRecallAtFixedPrecision", "retrieval") |
|
super().__init__( |
|
min_precision=min_precision, |
|
max_k=max_k, |
|
adaptive_k=adaptive_k, |
|
empty_target_action=empty_target_action, |
|
ignore_index=ignore_index, |
|
**kwargs, |
|
) |
|
|
|
|
|
class _RetrievalMRR(RetrievalMRR): |
|
"""Wrapper for deprecated import. |
|
|
|
>>> from torch import tensor |
|
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) |
|
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) |
|
>>> target = tensor([False, False, True, False, True, False, True]) |
|
>>> mrr = _RetrievalMRR() |
|
>>> mrr(preds, target, indexes=indexes) |
|
tensor(0.7500) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
empty_target_action: str = "neg", |
|
ignore_index: Optional[int] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
_deprecated_root_import_class("", "retrieval") |
|
super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs) |
|
|