import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch

import ffmpy
import numpy as np
import pytest
from gradio_client import media_data
from PIL import Image, ImageCms

from gradio import components, data_classes, processing_utils, utils
from gradio.route_utils import API_PREFIX


class TestTempFileManagement:
    """Cache-directory behavior for files, base64 payloads and downloaded URLs."""

    @staticmethod
    def _cached_file_count(cache_dir):
        """Return how many regular files exist anywhere below ``cache_dir``."""
        return sum(1 for entry in cache_dir.glob("**/*") if entry.is_file())

    def test_hash_file(self):
        original, duplicate, different = (
            processing_utils.hash_file(f"gradio/test_data/{name}")
            for name in ("cheetah1.jpg", "cheetah1-copy.jpg", "cheetah2.jpg")
        )
        assert original == duplicate
        assert original != different

    def test_make_temp_copy_if_needed(self, gradio_temp_dir):
        source = "gradio/test_data/cheetah1.jpg"

        cached = processing_utils.save_file_to_cache(source, cache_dir=gradio_temp_dir)
        try:  # Remove any copy left over from an earlier run
            os.remove(cached)
        except OSError:
            pass

        cached = processing_utils.save_file_to_cache(source, cache_dir=gradio_temp_dir)
        assert self._cached_file_count(gradio_temp_dir) == 1
        assert Path(cached).name == "cheetah1.jpg"

        # Caching the same source again must not create another file.
        processing_utils.save_file_to_cache(source, cache_dir=gradio_temp_dir)
        assert self._cached_file_count(gradio_temp_dir) == 1

        # A second source path gets its own entry, keeping its own file name.
        cached = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1-copy.jpg", cache_dir=gradio_temp_dir
        )
        assert self._cached_file_count(gradio_temp_dir) == 2
        assert Path(cached).name == "cheetah1-copy.jpg"

    def test_save_b64_to_cache(self, gradio_temp_dir):
        image_b64 = media_data.BASE64_IMAGE
        audio_b64 = media_data.BASE64_AUDIO["data"]

        cached = processing_utils.save_base64_to_cache(
            image_b64, cache_dir=gradio_temp_dir
        )
        try:  # Remove any copy left over from an earlier run
            os.remove(cached)
        except OSError:
            pass

        # Saving the same payload twice keeps exactly one cached file.
        for _ in range(2):
            processing_utils.save_base64_to_cache(image_b64, cache_dir=gradio_temp_dir)
            assert self._cached_file_count(gradio_temp_dir) == 1

        processing_utils.save_base64_to_cache(audio_b64, cache_dir=gradio_temp_dir)
        assert self._cached_file_count(gradio_temp_dir) == 2

    @pytest.mark.flaky
    def test_ssrf_protected_download(self, gradio_temp_dir):
        png_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
        jpg_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"

        downloaded = processing_utils.save_url_to_cache(
            png_url, cache_dir=gradio_temp_dir
        )
        try:  # Remove any copy left over from an earlier run
            os.remove(downloaded)
        except OSError:
            pass

        # Downloading the same URL twice keeps exactly one cached file.
        for _ in range(2):
            processing_utils.save_url_to_cache(png_url, cache_dir=gradio_temp_dir)
            assert self._cached_file_count(gradio_temp_dir) == 1

        processing_utils.save_url_to_cache(jpg_url, cache_dir=gradio_temp_dir)
        assert self._cached_file_count(gradio_temp_dir) == 2

    @pytest.mark.flaky
    def test_ssrf_protected_download_with_redirect(self, gradio_temp_dir):
        redirecting_url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
        processing_utils.save_url_to_cache(redirecting_url, cache_dir=gradio_temp_dir)
        assert self._cached_file_count(gradio_temp_dir) == 1


class TestImagePreprocessing:
    """Tests for image encoding, caching and resizing helpers in processing_utils."""

    def test_encode_plot_to_base64(self):
        """encode_plot_to_base64 returns a base64 image data URI with a real payload."""
        import base64

        with utils.MatplotlibBackendMananger():
            import matplotlib.pyplot as plt

            plt.plot([1, 2, 3, 4])
            output_base64 = processing_utils.encode_plot_to_base64(plt)
        # The previous check, ``startswith("")``, held for every string and so
        # tested nothing. Verify the data-URI header and that the payload is
        # non-empty, valid base64.
        header, separator, payload = output_base64.partition(",")
        assert separator == ","
        assert header.startswith("data:image/")
        assert header.endswith(";base64")
        assert base64.b64decode(payload, validate=True)

    def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir):
        """PNG text metadata in ``Image.info`` survives save_pil_to_cache."""
        input_img = Image.open("gradio/test_data/test_image.png")
        input_img = input_img.convert("RGB")
        input_img.info = {"key1": "value1", "key2": "value2"}
        input_img.save(gradio_temp_dir / "test_test_image.png")

        file_obj = processing_utils.save_pil_to_cache(
            input_img, cache_dir=gradio_temp_dir, format="png"
        )
        output_img = Image.open(file_obj)

        assert output_img.info == input_img.info

    def test_save_pil_to_file_keeps_all_gif_frames(self, gradio_temp_dir):
        """Every frame of an animated GIF is preserved when cached."""
        input_img = Image.open("gradio/test_data/rectangles.gif")
        file_obj = processing_utils.save_pil_to_cache(
            input_img, cache_dir=gradio_temp_dir, format="gif"
        )
        output_img = Image.open(file_obj)
        assert output_img.n_frames == input_img.n_frames == 3  # type: ignore

    def test_np_pil_encode_to_the_same(self, gradio_temp_dir):
        """A PIL image and its source array map to the same cache path."""
        arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
        pil = Image.fromarray(arr)
        assert processing_utils.save_pil_to_cache(
            pil, cache_dir=gradio_temp_dir
        ) == processing_utils.save_img_array_to_cache(arr, cache_dir=gradio_temp_dir)

    def test_encode_pil_to_temp_file_metadata_color_profile(self, gradio_temp_dir):
        """Images differing only in metadata or ICC profile get distinct cache paths."""
        # Read image
        img = Image.open("gradio/test_data/test_image.png")
        img_metadata = Image.open("gradio/test_data/test_image.png")
        img_metadata.info = {"key1": "value1", "key2": "value2"}

        # Creating sRGB profile
        profile = ImageCms.createProfile("sRGB")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile.png", icc_profile=profile2.tobytes()
        )
        img_cp1 = Image.open(str(gradio_temp_dir / "img_color_profile.png"))

        # Creating XYZ profile
        profile = ImageCms.createProfile("XYZ")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile_2.png", icc_profile=profile2.tobytes()
        )
        img_cp2 = Image.open(str(gradio_temp_dir / "img_color_profile_2.png"))

        img_path = processing_utils.save_pil_to_cache(
            img, cache_dir=gradio_temp_dir, format="png"
        )
        img_metadata_path = processing_utils.save_pil_to_cache(
            img_metadata, cache_dir=gradio_temp_dir, format="png"
        )
        img_cp1_path = processing_utils.save_pil_to_cache(
            img_cp1, cache_dir=gradio_temp_dir, format="png"
        )
        img_cp2_path = processing_utils.save_pil_to_cache(
            img_cp2, cache_dir=gradio_temp_dir, format="png"
        )

        assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4

    def test_resize_and_crop(self):
        """resize_and_crop hits the target size and rejects unknown crop types."""
        img = Image.open("gradio/test_data/test_image.png")
        new_img = processing_utils.resize_and_crop(img, (20, 20))
        assert new_img.size == (20, 20)
        with pytest.raises(ValueError):
            # NOTE(review): kwargs presumably unpacked from a dict to sidestep
            # static checking of ``crop_type`` — confirm.
            processing_utils.resize_and_crop(
                **{"img": img, "size": (20, 20), "crop_type": "test"}
            )


class TestAudioPreprocessing:
    """Tests for reading, writing and 16-bit conversion of audio data."""

    def test_audio_from_file(self):
        """audio_from_file returns the sample rate and a numpy sample array."""
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        assert audio[0] == 22050
        assert isinstance(audio[1], np.ndarray)

    def test_audio_to_file(self, tmp_path):
        """audio_to_file writes the samples to the requested path."""
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        # Write into pytest's per-test directory instead of the CWD so the output
        # is cleaned up even when the assertion fails.
        output_path = tmp_path / "test_audio_to_file"
        processing_utils.audio_to_file(audio[0], audio[1], str(output_path))
        assert output_path.exists()

    def test_convert_to_16_bit_wav(self):
        """Float64, float32 and int16 inputs all convert back to the same int16 data."""
        # Random int16 samples, with the first two pushed close to the int16 limits
        # so the full amplitude range is exercised.
        audio = np.random.randint(-100, 100, size=(100), dtype="int16")
        audio[0] = -32767
        audio[1] = 32766

        for source in (audio.astype("float64"), audio.astype("float32"), audio):
            converted = processing_utils.convert_to_16_bit_wav(source)
            assert np.allclose(audio, converted)
            assert converted.dtype == "int16"


class TestOutputPreprocessing:
    """Tests for ``processing_utils._convert`` between floating-point dtypes."""

    # Spellings of float dtypes passed to ``_convert``. The numpy aliases
    # (np.double is np.float64, np.single is np.float32) are presumably listed to
    # cover alternative spellings. The duplicate ``float`` entry is removed
    # because it only repeated cases.
    float_dtype_list = [
        float,
        np.double,
        np.single,
        np.float32,
        np.float64,
        "float32",
        "float64",
    ]

    def test_float_conversion_dtype(self):
        """Test any conversion from a float dtype to another."""
        import itertools

        x = np.array([-1, 1])
        # Every ordered (input, output) pair of dtypes. Same pairs, in the same
        # order, as the previous meshgrid-of-object-arrays construction.
        for dtype_in, dtype_out in itertools.product(
            TestOutputPreprocessing.float_dtype_list, repeat=2
        ):
            x = x.astype(dtype_in)
            y = processing_utils._convert(x, dtype_out)
            assert y.dtype == np.dtype(dtype_out)

    def test_subclass_conversion(self):
        """Converting to the abstract ``np.floating`` keeps the concrete input dtype."""
        x = np.array([-1, 1])
        for dtype in TestOutputPreprocessing.float_dtype_list:
            x = x.astype(dtype)
            y = processing_utils._convert(x, np.floating)
            assert y.dtype == x.dtype


class TestVideoProcessing:
    """Tests for video playability detection and conversion to playable mp4."""

    def test_video_has_playable_codecs(self, test_file_dir):
        """Known-good samples are playable; the bad sample is not."""
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.mp4")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.ogg")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.webm")
        )
        assert not processing_utils.video_is_playable(
            str(test_file_dir / "bad_video_sample.mp4")
        )

    def raise_ffmpy_runtime_exception(*args, **kwargs):
        # Deliberately has no ``self``: it is used as a plain function in the
        # decorators below (as a mock ``side_effect``) and accepts any arguments.
        raise ffmpy.FFRuntimeError("", "", "", "")  # type: ignore

    @pytest.mark.parametrize(
        "exception_to_raise", [raise_ffmpy_runtime_exception, KeyError(), IndexError()]
    )
    def test_video_has_playable_codecs_catches_exceptions(
        self, exception_to_raise, test_file_dir, tmp_path
    ):
        """If probing raises, video_is_playable still returns True."""
        # Copy into pytest's per-test directory; the previous
        # NamedTemporaryFile(delete=False) was never removed.
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
            assert processing_utils.video_is_playable(str(not_playable_vid))

    def test_convert_video_to_playable_mp4(self, test_file_dir, tmp_path):
        """Conversion yields a playable video and removes its intermediate file."""
        # Using tmp_path also keeps the converted output (written next to the
        # input) out of the system temp directory.
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        with patch("os.remove", wraps=os.remove) as mock_remove:
            playable_vid = processing_utils.convert_video_to_playable_mp4(
                str(not_playable_vid)
            )
        # check tempfile got deleted
        assert not Path(mock_remove.call_args[0][0]).exists()
        assert processing_utils.video_is_playable(playable_vid)

    @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
    def test_video_conversion_returns_original_video_if_fails(
        self, mock_run, test_file_dir, tmp_path
    ):
        """When ffmpeg fails, the original (non-mp4) path is returned."""
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        playable_vid = processing_utils.convert_video_to_playable_mp4(
            str(not_playable_vid)
        )
        # If the conversion succeeded it'd be .mp4
        assert Path(playable_vid).suffix == ".avi"


def test_add_root_url():
    """add_root_url prefixes relative file URLs and can swap one root for another."""

    def file_payload(path, url):
        # Fresh dicts on every call so no nested object is shared between cases.
        return {"path": path, "url": url, "meta": {"_type": "gradio.FileData"}}

    relative_url = f"{API_PREFIX}/file=path"
    absolute_url = "https://www.gradio.app"
    root_url = "http://localhost:7860"
    new_root_url = "https://1234.gradio.live"

    data = {
        "file": file_payload("path", relative_url),
        "file2": file_payload("path2", absolute_url),
    }
    expected = {
        "file": file_payload("path", f"{root_url}{API_PREFIX}/file=path"),
        "file2": file_payload("path2", absolute_url),
    }
    new_expected = {
        "file": file_payload("path", f"{new_root_url}{API_PREFIX}/file=path"),
        "file2": file_payload("path2", absolute_url),
    }

    # Relative URLs gain the root; already-absolute URLs are left alone.
    assert processing_utils.add_root_url(data, root_url, None) == expected
    # Re-rooting replaces the previous root with the new one.
    assert (
        processing_utils.add_root_url(expected, new_root_url, root_url) == new_expected
    )


def test_hash_url_encodes_url():
    """hash_url on a URL matches hash_bytes on that URL's encoded bytes."""
    url = "https://www.gradio.app/image 1.jpg"
    expected = processing_utils.hash_bytes(url.encode())
    assert processing_utils.hash_url(url) == expected


@pytest.mark.asyncio
async def test_json_data_not_moved_to_cache():
    """JsonData passes through sync and async move_files_to_cache unchanged."""
    file_like = {
        "path": "path",
        "url": f"{API_PREFIX}/file=path",
        "meta": {"_type": "gradio.FileData"},
    }
    data = data_classes.JsonData(root={"file": file_like})

    for flag in (False, True):
        assert (
            processing_utils.move_files_to_cache(data, components.Number(), flag)
            == data
        )

    for flag in (False, True):
        moved = await processing_utils.async_move_files_to_cache(
            data, components.Number(), flag
        )
        assert moved == data


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url",
    (
        "https://localhost",
        "http://127.0.0.1/file/a/b/c",
        "http://[::1]",
        "https://192.168.0.1",
        "http://10.0.0.1?q=a",
        "http://192.168.1.250.nip.io",
    ),
)
async def test_local_urls_fail(url):
    """Loopback and private-network URLs are rejected by async_validate_url."""
    expected_failure = pytest.raises(ValueError, match="failed validation")
    with expected_failure:
        await processing_utils.async_validate_url(url)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url",
    (
        "https://google.com",
        "https://8.8.8.8/",
        "http://93.184.215.14.nip.io/",
        "https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/luigi/luigi.ply",
    ),
)
async def test_public_urls_pass(url):
    """Public URLs validate without raising."""
    validate = processing_utils.async_validate_url
    await validate(url)


def test_public_request_pass():
    """ssrf_protected_download fetches a public URL to a local file."""
    # The context manager guarantees the directory is removed. Previously the
    # TemporaryDirectory was never closed, so cleanup depended on garbage collection.
    with tempfile.TemporaryDirectory() as tempdir:
        file = processing_utils.ssrf_protected_download(
            "https://en.wikipedia.org/static/images/icons/wikipedia.png", tempdir
        )
        assert os.path.exists(file)
        assert os.path.getsize(file) == 13444


@pytest.mark.asyncio
async def test_async_public_request_pass():
    """async_ssrf_protected_download fetches a public URL to a local file."""
    # The context manager guarantees the directory is removed. Previously the
    # TemporaryDirectory was never closed, so cleanup depended on garbage collection.
    with tempfile.TemporaryDirectory() as tempdir:
        file = await processing_utils.async_ssrf_protected_download(
            "https://en.wikipedia.org/static/images/icons/wikipedia.png", tempdir
        )
        assert os.path.exists(file)
        assert os.path.getsize(file) == 13444


def test_private_request_fail():
    """Downloading from a host embedding a private address fails validation."""
    # The temp directory is created before (not inside) the expected-failure scope
    # and is always cleaned up. NOTE(review): the nip.io hostname presumably
    # resolves to 192.168.1.250.
    with (
        tempfile.TemporaryDirectory() as tempdir,
        pytest.raises(ValueError, match="failed validation"),
    ):
        processing_utils.ssrf_protected_download(
            "http://192.168.1.250.nip.io/image.png", tempdir
        )


@pytest.mark.asyncio
async def test_async_private_request_fail():
    """Async download from a host embedding a private address fails validation."""
    # The temp directory is created before (not inside) the expected-failure scope
    # and is always cleaned up. NOTE(review): the nip.io hostname presumably
    # resolves to 192.168.1.250.
    with (
        tempfile.TemporaryDirectory() as tempdir,
        pytest.raises(ValueError, match="failed validation"),
    ):
        await processing_utils.async_ssrf_protected_download(
            "http://192.168.1.250.nip.io/image.png", tempdir
        )