"""Wrappers for retrieval metrics reached through a deprecated import path.

Each ``_Retrieval*`` class below subclasses the metric of the same public name
from ``torchmetrics.retrieval``. Its constructor first calls
``_deprecated_root_import_class(<public class name>, "retrieval")`` and then
forwards every argument unchanged to the parent constructor, so apart from that
call the wrappers behave exactly like the metrics they wrap.
"""
from typing import Any, Optional

from torchmetrics.retrieval.average_precision import RetrievalMAP
from torchmetrics.retrieval.fall_out import RetrievalFallOut
from torchmetrics.retrieval.hit_rate import RetrievalHitRate
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG
from torchmetrics.retrieval.precision import RetrievalPrecision
from torchmetrics.retrieval.precision_recall_curve import RetrievalPrecisionRecallCurve, RetrievalRecallAtFixedPrecision
from torchmetrics.retrieval.r_precision import RetrievalRPrecision
from torchmetrics.retrieval.recall import RetrievalRecall
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR
from torchmetrics.utilities.prints import _deprecated_root_import_class


class _RetrievalFallOut(RetrievalFallOut):
    """Wrapper for deprecated import of ``RetrievalFallOut``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> rfo = _RetrievalFallOut(top_k=2)
    >>> rfo(preds, target, indexes=indexes)
    tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "pos",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalFallOut", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)


class _RetrievalHitRate(RetrievalHitRate):
    """Wrapper for deprecated import of ``RetrievalHitRate``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([True, False, False, False, True, False, True])
    >>> hr2 = _RetrievalHitRate(top_k=2)
    >>> hr2(preds, target, indexes=indexes)
    tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalHitRate", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)


class _RetrievalMAP(RetrievalMAP):
    """Wrapper for deprecated import of ``RetrievalMAP``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> rmap = _RetrievalMAP()
    >>> rmap(preds, target, indexes=indexes)
    tensor(0.7917)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalMAP", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)


class _RetrievalRecall(RetrievalRecall):
    """Wrapper for deprecated import of ``RetrievalRecall``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> r2 = _RetrievalRecall(top_k=2)
    >>> r2(preds, target, indexes=indexes)
    tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRecall", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)


class _RetrievalRPrecision(RetrievalRPrecision):
    """Wrapper for deprecated import of ``RetrievalRPrecision``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> p2 = _RetrievalRPrecision()
    >>> p2(preds, target, indexes=indexes)
    tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRPrecision", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs)


class _RetrievalNormalizedDCG(RetrievalNormalizedDCG):
    """Wrapper for deprecated import of ``RetrievalNormalizedDCG``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> ndcg = _RetrievalNormalizedDCG()
    >>> ndcg(preds, target, indexes=indexes)
    tensor(0.8467)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalNormalizedDCG", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)


class _RetrievalPrecision(RetrievalPrecision):
    """Wrapper for deprecated import of ``RetrievalPrecision``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> p2 = _RetrievalPrecision(top_k=2)
    >>> p2(preds, target, indexes=indexes)
    tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        adaptive_k: bool = False,
        **kwargs: Any,
    ) -> None:
        # Pass the public class name like every sibling wrapper does; an empty
        # string would leave the deprecation notice without a class to name.
        _deprecated_root_import_class("RetrievalPrecision", "retrieval")
        super().__init__(
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            top_k=top_k,
            adaptive_k=adaptive_k,
            **kwargs,
        )


class _RetrievalPrecisionRecallCurve(RetrievalPrecisionRecallCurve):
    """Wrapper for deprecated import of ``RetrievalPrecisionRecallCurve``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
    >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
    >>> target = tensor([True, False, False, True, True, False, True])
    >>> r = _RetrievalPrecisionRecallCurve(max_k=4)
    >>> precisions, recalls, top_k = r(preds, target, indexes=indexes)
    >>> precisions
    tensor([1.0000, 0.5000, 0.6667, 0.5000])
    >>> recalls
    tensor([0.5000, 0.5000, 1.0000, 1.0000])
    >>> top_k
    tensor([1, 2, 3, 4])

    """

    def __init__(
        self,
        max_k: Optional[int] = None,
        adaptive_k: bool = False,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Pass the public class name like every sibling wrapper does; an empty
        # string would leave the deprecation notice without a class to name.
        _deprecated_root_import_class("RetrievalPrecisionRecallCurve", "retrieval")
        super().__init__(
            max_k=max_k,
            adaptive_k=adaptive_k,
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            **kwargs,
        )


class _RetrievalRecallAtFixedPrecision(RetrievalRecallAtFixedPrecision):
    """Wrapper for deprecated import of ``RetrievalRecallAtFixedPrecision``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
    >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
    >>> target = tensor([True, False, False, True, True, False, True])
    >>> r = _RetrievalRecallAtFixedPrecision(min_precision=0.8)
    >>> r(preds, target, indexes=indexes)
    (tensor(0.5000), tensor(1))

    """

    def __init__(
        self,
        min_precision: float = 0.0,
        max_k: Optional[int] = None,
        adaptive_k: bool = False,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRecallAtFixedPrecision", "retrieval")
        super().__init__(
            min_precision=min_precision,
            max_k=max_k,
            adaptive_k=adaptive_k,
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            **kwargs,
        )


class _RetrievalMRR(RetrievalMRR):
    """Wrapper for deprecated import of ``RetrievalMRR``.

    >>> from torch import tensor
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> mrr = _RetrievalMRR()
    >>> mrr(preds, target, indexes=indexes)
    tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Pass the public class name like every sibling wrapper does; an empty
        # string would leave the deprecation notice without a class to name.
        _deprecated_root_import_class("RetrievalMRR", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs)