jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from typing import Any, Optional
from torchmetrics.retrieval.average_precision import RetrievalMAP
from torchmetrics.retrieval.fall_out import RetrievalFallOut
from torchmetrics.retrieval.hit_rate import RetrievalHitRate
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG
from torchmetrics.retrieval.precision import RetrievalPrecision
from torchmetrics.retrieval.precision_recall_curve import RetrievalPrecisionRecallCurve, RetrievalRecallAtFixedPrecision
from torchmetrics.retrieval.r_precision import RetrievalRPrecision
from torchmetrics.retrieval.recall import RetrievalRecall
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR
from torchmetrics.utilities.prints import _deprecated_root_import_class
class _RetrievalFallOut(RetrievalFallOut):
    """Wrapper for deprecated import.

    Thin deprecated alias of ``RetrievalFallOut``. Constructing it first calls
    ``_deprecated_root_import_class`` with the class name and the ``"retrieval"``
    domain (presumably to warn that the class should be imported from
    ``torchmetrics.retrieval``), then forwards every argument unchanged to
    ``RetrievalFallOut.__init__``.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"pos"`` here.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> rfo = _RetrievalFallOut(top_k=2)
        >>> rfo(preds, target, indexes=indexes)
        tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "pos",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Report the deprecated import location before building the real metric.
        _deprecated_root_import_class("RetrievalFallOut", "retrieval")
        super().__init__(
            top_k=top_k,
            ignore_index=ignore_index,
            empty_target_action=empty_target_action,
            **kwargs,
        )
class _RetrievalHitRate(RetrievalHitRate):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalHitRate``. Instantiation calls
    ``_deprecated_root_import_class("RetrievalHitRate", "retrieval")``, which is
    presumably a deprecation notice pointing at ``torchmetrics.retrieval``. It then
    initialises the parent class with exactly the arguments it received.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([True, False, False, False, True, False, True])
        >>> hr2 = _RetrievalHitRate(top_k=2)
        >>> hr2(preds, target, indexes=indexes)
        tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalHitRate", "retrieval")
        # The named parameters cannot also appear in ``kwargs``, so merging is safe.
        forwarded = {
            "empty_target_action": empty_target_action,
            "ignore_index": ignore_index,
            "top_k": top_k,
        }
        super().__init__(**forwarded, **kwargs)
class _RetrievalMAP(RetrievalMAP):
    """Wrapper for deprecated import.

    Backwards-compatible alias of ``RetrievalMAP``. The constructor calls
    ``_deprecated_root_import_class`` with ``"RetrievalMAP"`` and the
    ``"retrieval"`` domain (presumably to emit a deprecation notice). It then
    delegates every argument to ``RetrievalMAP.__init__`` without modification.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> rmap = _RetrievalMAP()
        >>> rmap(preds, target, indexes=indexes)
        tensor(0.7917)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalMAP", "retrieval")
        super().__init__(
            ignore_index=ignore_index,
            top_k=top_k,
            empty_target_action=empty_target_action,
            **kwargs,
        )
class _RetrievalRecall(RetrievalRecall):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalRecall``. Building an instance calls
    ``_deprecated_root_import_class("RetrievalRecall", "retrieval")`` (presumably a
    deprecation notice). It then passes all arguments straight through to the
    parent constructor.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> r2 = _RetrievalRecall(top_k=2)
        >>> r2(preds, target, indexes=indexes)
        tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRecall", "retrieval")
        named = dict(
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            top_k=top_k,
        )
        super().__init__(**named, **kwargs)
class _RetrievalRPrecision(RetrievalRPrecision):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalRPrecision``. On construction it calls
    ``_deprecated_root_import_class`` with ``"RetrievalRPrecision"`` and
    ``"retrieval"`` (presumably a deprecation notice) before forwarding its
    arguments to the parent. Unlike most sibling wrappers it takes no ``top_k``
    parameter.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> p2 = _RetrievalRPrecision()
        >>> p2(preds, target, indexes=indexes)
        tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRPrecision", "retrieval")
        super().__init__(
            ignore_index=ignore_index,
            empty_target_action=empty_target_action,
            **kwargs,
        )
class _RetrievalNormalizedDCG(RetrievalNormalizedDCG):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalNormalizedDCG``. The constructor calls
    ``_deprecated_root_import_class("RetrievalNormalizedDCG", "retrieval")``
    (presumably a deprecation notice). It then initialises the parent metric
    with the same arguments.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> ndcg = _RetrievalNormalizedDCG()
        >>> ndcg(preds, target, indexes=indexes)
        tensor(0.8467)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalNormalizedDCG", "retrieval")
        parent_options = {
            "top_k": top_k,
            "ignore_index": ignore_index,
            "empty_target_action": empty_target_action,
        }
        super().__init__(**parent_options, **kwargs)
class _RetrievalPrecision(RetrievalPrecision):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalPrecision``. Construction calls
    ``_deprecated_root_import_class`` with the public class name and the
    ``"retrieval"`` domain (presumably to warn that the class should be imported
    from ``torchmetrics.retrieval``). It then forwards every argument unchanged
    to ``RetrievalPrecision.__init__``.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        top_k: Forwarded to the parent.
        adaptive_k: Forwarded to the parent; defaults to ``False``.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> p2 = _RetrievalPrecision(top_k=2)
        >>> p2(preds, target, indexes=indexes)
        tensor(0.5000)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        adaptive_k: bool = False,
        **kwargs: Any,
    ) -> None:
        # Pass the real class name (it was previously an empty string) so the
        # deprecation notice can identify which class is affected.
        _deprecated_root_import_class("RetrievalPrecision", "retrieval")
        super().__init__(
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            top_k=top_k,
            adaptive_k=adaptive_k,
            **kwargs,
        )
class _RetrievalPrecisionRecallCurve(RetrievalPrecisionRecallCurve):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalPrecisionRecallCurve``. Construction calls
    ``_deprecated_root_import_class`` with the public class name and the
    ``"retrieval"`` domain (presumably to warn that the class should be imported
    from ``torchmetrics.retrieval``). It then forwards every argument unchanged
    to the parent constructor.

    Args:
        max_k: Forwarded to the parent.
        adaptive_k: Forwarded to the parent; defaults to ``False``.
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
        >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
        >>> target = tensor([True, False, False, True, True, False, True])
        >>> r = _RetrievalPrecisionRecallCurve(max_k=4)
        >>> precisions, recalls, top_k = r(preds, target, indexes=indexes)
        >>> precisions
        tensor([1.0000, 0.5000, 0.6667, 0.5000])
        >>> recalls
        tensor([0.5000, 0.5000, 1.0000, 1.0000])
        >>> top_k
        tensor([1, 2, 3, 4])

    """

    def __init__(
        self,
        max_k: Optional[int] = None,
        adaptive_k: bool = False,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Pass the real class name (it was previously an empty string) so the
        # deprecation notice can identify which class is affected.
        _deprecated_root_import_class("RetrievalPrecisionRecallCurve", "retrieval")
        super().__init__(
            max_k=max_k,
            adaptive_k=adaptive_k,
            empty_target_action=empty_target_action,
            ignore_index=ignore_index,
            **kwargs,
        )
class _RetrievalRecallAtFixedPrecision(RetrievalRecallAtFixedPrecision):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalRecallAtFixedPrecision``. Its constructor calls
    ``_deprecated_root_import_class("RetrievalRecallAtFixedPrecision", "retrieval")``
    (presumably a deprecation notice). It then hands all arguments to the parent
    class unchanged.

    Args:
        min_precision: Forwarded to the parent; defaults to ``0.0``.
        max_k: Forwarded to the parent.
        adaptive_k: Forwarded to the parent; defaults to ``False``.
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
        >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
        >>> target = tensor([True, False, False, True, True, False, True])
        >>> r = _RetrievalRecallAtFixedPrecision(min_precision=0.8)
        >>> r(preds, target, indexes=indexes)
        (tensor(0.5000), tensor(1))

    """

    def __init__(
        self,
        min_precision: float = 0.0,
        max_k: Optional[int] = None,
        adaptive_k: bool = False,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("RetrievalRecallAtFixedPrecision", "retrieval")
        # The named parameters cannot also appear in ``kwargs``, so merging is safe.
        settings = {
            "min_precision": min_precision,
            "max_k": max_k,
            "adaptive_k": adaptive_k,
            "empty_target_action": empty_target_action,
            "ignore_index": ignore_index,
        }
        super().__init__(**settings, **kwargs)
class _RetrievalMRR(RetrievalMRR):
    """Wrapper for deprecated import.

    Deprecated alias of ``RetrievalMRR``. Construction calls
    ``_deprecated_root_import_class`` with the public class name and the
    ``"retrieval"`` domain (presumably to warn that the class should be imported
    from ``torchmetrics.retrieval``). It then forwards every argument unchanged
    to ``RetrievalMRR.__init__``.

    Args:
        empty_target_action: Forwarded to the parent; defaults to ``"neg"``.
        ignore_index: Forwarded to the parent.
        kwargs: Extra keyword arguments, forwarded unchanged.

    Example:
        >>> from torch import tensor
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> mrr = _RetrievalMRR()
        >>> mrr(preds, target, indexes=indexes)
        tensor(0.7500)

    """

    def __init__(
        self,
        empty_target_action: str = "neg",
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Pass the real class name (it was previously an empty string) so the
        # deprecation notice can identify which class is affected.
        _deprecated_root_import_class("RetrievalMRR", "retrieval")
        super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs)