File size: 4,289 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence

from wandb.sdk.artifacts._validators import (
    REGISTRY_PREFIX,
    validate_artifact_types_list,
)

if TYPE_CHECKING:
    from wandb_gql import Client

from wandb_gql import gql


class _Visibility(str, Enum):
    # names are what users see/pass into Python methods
    # values are what's expected by backend API
    organization = "PRIVATE"
    restricted = "RESTRICTED"

    @classmethod
    def _missing_(cls, value: object) -> Any:
        return next(
            (e for e in cls if e.name == value),
            None,
        )


def _format_gql_artifact_types_input(
    artifact_types: Optional[List[str]] = None,
):
    """Format the artifact types for the GQL input.

    Args:
        artifact_types: The artifact types to add to the registry.

    Returns:
        The artifact types for the GQL input.
    """
    if artifact_types is None:
        return []
    new_types = validate_artifact_types_list(artifact_types)
    return [{"name": type} for type in new_types]


def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Convert a visibility string from the GQL API to the registry visibility.

    Args:
        visibility: The visibility string returned by the backend (a
            ``_Visibility`` value such as ``"PRIVATE"``). A ``_Visibility``
            member name is also accepted, via ``_Visibility._missing_``.

    Returns:
        The user-facing registry visibility, i.e. the ``_Visibility`` member
        name.

    Raises:
        ValueError: If ``visibility`` matches no ``_Visibility`` member.
    """
    try:
        member = _Visibility(visibility)
    except ValueError:
        # Drop the internal Enum lookup error from the traceback; this message
        # is the one callers should see.
        raise ValueError(
            f"Invalid visibility: {visibility!r} from backend"
        ) from None
    return member.name


def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Convert the registry visibility to the GQL visibility.

    Args:
        visibility: A user-facing ``_Visibility`` member name, e.g.
            ``"organization"``.

    Returns:
        The backend value of the matching ``_Visibility`` member.

    Raises:
        ValueError: If ``visibility`` is not a ``_Visibility`` member name.
            This includes unhashable inputs, which would otherwise surface as
            a ``TypeError`` from the enum lookup.
    """
    try:
        member = _Visibility[visibility]
    except (KeyError, TypeError):
        valid_names = ", ".join(repr(e.name) for e in _Visibility)
        # `from None`: the lookup failure is fully described by this message.
        raise ValueError(
            f"Invalid visibility: {visibility!r}. Must be one of: {valid_names}"
        ) from None
    return member.value


def _ensure_registry_prefix_on_names(query, in_name=False):
    """Traverse a filter and prepend the registry prefix to ``name`` values.

    Walks nested mappings and sequences recursively. Every string found
    beneath a ``"name"`` key gets ``REGISTRY_PREFIX`` prepended, unless it
    already starts with that prefix. This applies whether the string is the
    key's direct value or nested deeper inside it. Values of a ``"$regex"``
    key are never transformed.

    Args:
        query: The filter, or sub-filter, to transform.
        in_name: True if we are under a ``"name"`` key (or propagating from
            one).

    Returns:
        A transformed copy of ``query``:

        - mappings become new ``dict`` objects;
        - sequences (other than ``str``/``bytes``/``bytearray``) become
          ``list`` objects;
        - all other values are returned unchanged.

    EX: {"name": "model"} -> {"name": "wandb-registry-model"}
    """
    if isinstance(query, str):
        if in_name and not query.startswith(REGISTRY_PREFIX):
            return f"{REGISTRY_PREFIX}{query}"
        return query
    if isinstance(query, Mapping):
        new_dict = {}
        for key, obj in query.items():
            if key == "$regex":
                # Regex patterns are passed through unchanged, even beneath a
                # "name" key.
                new_dict[key] = obj
            else:
                # A "name" key turns prefixing on for everything beneath it.
                # Any other key propagates the current `in_name` state as-is.
                new_dict[key] = _ensure_registry_prefix_on_names(
                    obj, in_name=in_name or key == "name"
                )
        return new_dict
    # Bytes-like objects are Sequences too, but iterating them yields ints.
    # Treat them as opaque scalars instead of exploding them into lists.
    if isinstance(query, Sequence) and not isinstance(query, (bytes, bytearray)):
        return [
            _ensure_registry_prefix_on_names(item, in_name=in_name) for item in query
        ]
    return query


@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Fetch the org entity name for an organization.

    Results are memoized by ``lru_cache`` (at most 10 entries), keyed on the
    ``(client, organization)`` arguments. Calls that raise are not cached.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        The name of the organization's org entity.

    Raises:
        ValueError: If the query fails, or if the response has no org entity
            name for ``organization``.
    """
    query = gql(
        """
        query FetchOrgEntityFromOrganization($organization: String!) {
            organization(name: $organization) {
                    orgEntity {
                        name
                    }
                }
            }
        """
    )
    try:
        response = client.execute(query, variable_values={"organization": organization})
    except Exception as e:
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from e

    # Walk the response with `.get` so that a missing key or null level
    # surfaces as the documented ValueError, not a KeyError/TypeError.
    org = (response or {}).get("organization")
    org_entity = (org or {}).get("orgEntity")
    org_name = (org_entity or {}).get("name")
    if not org_name:
        raise ValueError(f"Organization entity for {organization} not found.")

    return org_name