# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

import torch
from torch import Tensor

from torchmetrics.functional.multimodal.clip_iqa import (
    _clip_iqa_compute,
    _clip_iqa_format_prompts,
    _clip_iqa_get_anchor_vectors,
    _clip_iqa_update,
    _get_clip_iqa_model_and_processor,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import (
    _MATPLOTLIB_AVAILABLE,
    _PIQ_GREATER_EQUAL_0_8,
    _TRANSFORMERS_GREATER_EQUAL_4_10,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _PIQ_GREATER_EQUAL_0_8:
    __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["CLIPImageQualityAssessment.plot"]

if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor

    def _download_clip_iqa_metric() -> None:
        _CLIPModel.from_pretrained("openai/clip-vit-large-patch14", resume_download=True)
        _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", resume_download=True)

    if not _try_proceed_with_timeout(_download_clip_iqa_metric):
        __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
else:
    __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]


class CLIPImageQualityAssessment(Metric):
    """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images.

    The metric is based on the `CLIP`_ model, a neural network trained on a variety of (image, text) pairs to
    produce vector representations of images and text that are similar whenever the image and the text are
    semantically similar.

    The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The
    prompts always come in pairs of "positive" and "negative", such as "Good photo." and "Bad photo.". By comparing
    the similarity between the image embedding and both the "positive" and the "negative" prompt, the metric
    determines which prompt the image is more similar to, and returns the probability that the image matches the
    first (positive) prompt rather than the second (negative) prompt.
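
    Conceptually, the score for a single prompt pair can be sketched as below (an illustrative sketch only, where
    ``img_emb``, ``pos_emb`` and ``neg_emb`` are placeholder embeddings and the real implementation also applies
    CLIP's logit scaling)::

        import torch

        # cosine similarity between the image embedding and each of the two anchor text embeddings
        sim_pos = torch.cosine_similarity(img_emb, pos_emb, dim=-1)
        sim_neg = torch.cosine_similarity(img_emb, neg_emb, dim=-1)
        # softmax over the pair of similarities gives the probability of the positive prompt
        score = torch.stack([sim_pos, sim_neg], dim=-1).softmax(dim=-1)[..., 0]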

    Built-in prompts are:
        * quality: "Good photo." vs "Bad photo."
        * brightness: "Bright photo." vs "Dark photo."
        * noisiness: "Clean photo." vs "Noisy photo."
        * colorfullness: "Colorful photo." vs "Dull photo."
        * sharpness: "Sharp photo." vs "Blurry photo."
        * contrast: "High contrast photo." vs "Low contrast photo."
        * complexity: "Complex photo." vs "Simple photo."
        * natural: "Natural photo." vs "Synthetic photo."
        * happy: "Happy photo." vs "Sad photo."
        * scary: "Scary photo." vs "Peaceful photo."
        * new: "New photo." vs "Old photo."
        * warm: "Warm photo." vs "Cold photo."
        * real: "Real photo." vs "Abstract photo."
        * beautiful: "Beautiful photo." vs "Ugly photo."
        * lonely: "Lonely photo." vs "Sociable photo."
        * relaxing: "Relaxing photo." vs "Stressful photo."

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``images`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor with shape ``(N,C,H,W)``

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``clip_iqa`` (:class:`~torch.Tensor` or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is
      provided, a single tensor with shape ``(N,)`` is returned. If a list of prompts is provided, a dict of tensors
      is returned with the prompt as key and the tensor with shape ``(N,)`` as value.

    Args:
        model_name_or_path: string indicating the version of the CLIP model to use. Available models are:

            - `"clip_iqa"`, model corresponding to the CLIP-IQA paper.
            - `"openai/clip-vit-base-patch16"`
            - `"openai/clip-vit-base-patch32"`
            - `"openai/clip-vit-large-patch14-336"`
            - `"openai/clip-vit-large-patch14"`

        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
            data_range should be 255. The images are normalized by this value.
        prompts: A string, a tuple of strings, or a nested tuple of strings. If a single string is provided, it
            must be one of the available built-in prompts (see above). Otherwise, the input is expected to be a
            tuple where each element is either a string (one of the available built-in prompts) or a tuple of
            length 2, in which case the first string must be a positive prompt and the second a negative prompt.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    .. hint::
        If using the default ``clip_iqa`` model, the package ``piq`` must be installed. Either install with
        ``pip install piq`` or ``pip install torchmetrics[image]``.

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0
        ValueError:
            If `prompts` is a tuple and it is not of length 2
        ValueError:
            If `prompts` is a string and it is not one of the available prompts
        ValueError:
            If `prompts` is a list of strings and not all strings are one of the available prompts

    Example::
        Single prompt:

        >>> from torch import randint
        >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
        >>> imgs = randint(255, (2, 3, 224, 224)).float()
        >>> metric = CLIPImageQualityAssessment()
        >>> metric(imgs)
        tensor([0.8894, 0.8902])

    Example::
        Multiple prompts:

        >>> from torch import randint
        >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
        >>> imgs = randint(255, (2, 3, 224, 224)).float()
        >>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
        >>> metric(imgs)
        {'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])}

    Example::
        Custom prompts. Each custom prompt must be a tuple of length 2, with a positive and a negative prompt.

        >>> from torch import randint
        >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
        >>> imgs = randint(255, (2, 3, 224, 224)).float()
        >>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness"))
        >>> metric(imgs)
        {'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])}

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True
    plot_lower_bound = 0.0
    plot_upper_bound = 100.0

    anchors: Tensor
    probs_list: List[Tensor]
    feature_network: str = "model"

    def __init__(
        self,
        model_name_or_path: Literal[
            "clip_iqa",
            "openai/clip-vit-base-patch16",
            "openai/clip-vit-base-patch32",
            "openai/clip-vit-large-patch14-336",
            "openai/clip-vit-large-patch14",
        ] = "clip_iqa",
        data_range: float = 1.0,
        prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not (isinstance(data_range, (int, float)) and data_range > 0):
            raise ValueError("Argument `data_range` should be a positive number.")
        self.data_range = data_range

        prompts_list, prompts_name = _clip_iqa_format_prompts(prompts)
        self.prompts_list = prompts_list
        self.prompts_name = prompts_name

        self.model, self.processor = _get_clip_iqa_model_and_processor(model_name_or_path)
        self.model_name_or_path = model_name_or_path

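        # pre-compute the anchor (prompt) text embeddings once, so that `update` only needs to embed the images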
        with torch.inference_mode():
            anchors = _clip_iqa_get_anchor_vectors(
                model_name_or_path, self.model, self.processor, self.prompts_list, self.device
            )
        self.register_buffer("anchors", anchors)

        self.add_state("probs_list", [], dist_reduce_fx="cat")

    def update(self, images: Tensor) -> None:
        """Update metric state with new data."""
        with torch.inference_mode():
            img_features = _clip_iqa_update(
                self.model_name_or_path, images, self.model, self.processor, self.data_range, self.device
            )
            probs = _clip_iqa_compute(img_features, self.anchors, self.prompts_name, format_as_dict=False)
            if not isinstance(probs, Tensor):
                raise ValueError("Output probs should be a tensor")
            self.probs_list.append(probs)

    def compute(self) -> Union[Tensor, dict[str, Tensor]]:
        """Compute metric."""
        probs = dim_zero_cat(self.probs_list)
        if len(self.prompts_name) == 1:
            return probs.squeeze()
        return {p: probs[:, i] for i, p in enumerate(self.prompts_name)}

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
            >>> metric = CLIPImageQualityAssessment()
            >>> metric.update(torch.rand(1, 3, 224, 224))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
            >>> metric = CLIPImageQualityAssessment()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(1, 3, 224, 224)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


if TYPE_CHECKING:
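    # evaluated only by static type checkers: exercises the prompt argument forms accepted by the signature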
    f = CLIPImageQualityAssessment
    f(prompts=("colorfullness",))
    f(
        prompts=("quality", "brightness", "noisiness"),
    )
    f(
        prompts=("quality", "brightness", "noisiness", "colorfullness"),
    )
    f(prompts=(("Photo of a cat", "Photo of a dog"), "quality", ("Colorful photo", "Black and white photo")))