File size: 2,359 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from collections.abc import Collection
from typing import Any
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.prints import _deprecated_root_import_class
class _ModifiedPanopticQuality(ModifiedPanopticQuality):
    """Deprecated root-level alias of ``torchmetrics.detection.ModifiedPanopticQuality``.

    Constructing this class reports the deprecated import location through
    ``_deprecated_root_import_class`` and then forwards every argument, unchanged,
    to :class:`~torchmetrics.detection.ModifiedPanopticQuality`.

    Args:
        things: category ids forwarded as ``things`` to the parent metric.
        stuffs: category ids forwarded as ``stuffs`` to the parent metric.
        allow_unknown_preds_category: forwarded as-is to the parent metric.
        kwargs: any further keyword arguments, passed through to the parent.

    Example:
        >>> from torch import tensor
        >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
        >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
        >>> pq_modified = _ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
        >>> pq_modified(preds, target)
        tensor(0.7667, dtype=torch.float64)
    """

    def __init__(
        self,
        things: Collection[int],
        stuffs: Collection[int],
        allow_unknown_preds_category: bool = False,
        **kwargs: Any,
    ) -> None:
        # Runs before the parent constructor, so the deprecation report happens
        # even if the parent later rejects its arguments.
        _deprecated_root_import_class("ModifiedPanopticQuality", "detection")
        # The named parameters can never also appear in ``kwargs`` (Python binds
        # them to the named slots), so merging the two mappings is unambiguous.
        named_args = {
            "things": things,
            "stuffs": stuffs,
            "allow_unknown_preds_category": allow_unknown_preds_category,
        }
        super().__init__(**named_args, **kwargs)
class _PanopticQuality(PanopticQuality):
    """Deprecated root-level alias of ``torchmetrics.detection.PanopticQuality``.

    Constructing this class reports the deprecated import location through
    ``_deprecated_root_import_class`` and then forwards every argument, unchanged,
    to :class:`~torchmetrics.detection.PanopticQuality`.

    Args:
        things: category ids forwarded as ``things`` to the parent metric.
        stuffs: category ids forwarded as ``stuffs`` to the parent metric.
        allow_unknown_preds_category: forwarded as-is to the parent metric.
        kwargs: any further keyword arguments, passed through to the parent.

    Example:
        >>> from torch import tensor
        >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
        ...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
        ...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
        ...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
        ...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
        >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
        ...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
        ...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
        ...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
        ...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
        >>> panoptic_quality = _PanopticQuality(things = {0, 1}, stuffs = {6, 7})
        >>> panoptic_quality(preds, target)
        tensor(0.5463, dtype=torch.float64)
    """

    def __init__(
        self,
        things: Collection[int],
        stuffs: Collection[int],
        allow_unknown_preds_category: bool = False,
        **kwargs: Any,
    ) -> None:
        # Runs before the parent constructor, so the deprecation report happens
        # even if the parent later rejects its arguments.
        _deprecated_root_import_class("PanopticQuality", "detection")
        # The named parameters can never also appear in ``kwargs`` (Python binds
        # them to the named slots), so merging the two mappings is unambiguous.
        named_args = {
            "things": things,
            "stuffs": stuffs,
            "allow_unknown_preds_category": allow_unknown_preds_category,
        }
        super().__init__(**named_args, **kwargs)
|