File size: 8,376 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
# ruff: noqa: UP007 # Avoid using `X | Y` for union fields, as this can cause issues with pydantic < 2.6
from __future__ import annotations
from typing import Any, Collection, Final, Optional, Protocol, TypedDict
from pydantic import Field
from typing_extensions import Annotated, Self, Unpack
from wandb._pydantic import GQLBase, GQLId, computed_field, model_validator, to_json
from ._filters import MongoLikeFilter
from ._generated import (
CreateFilterTriggerInput,
QueueJobActionInput,
TriggeredActionConfig,
UpdateFilterTriggerInput,
)
from ._validators import to_input_action
from .actions import (
ActionType,
DoNothing,
InputAction,
SavedAction,
SendNotification,
SendWebhook,
)
from .automations import Automation, NewAutomation
from .events import EventType, InputEvent, RunMetricFilter, _WrappedSavedEventFilter
from .scopes import AutomationScope, ScopeType
# Legacy event/action types. Both sets are checked by the `ValidatedCreateInput`
# validators, which reject any automation that uses one of these types.
EXCLUDED_INPUT_EVENTS: Final[Collection[EventType]] = frozenset(
    (EventType.UPDATE_ARTIFACT_ALIAS,)
)
"""Event types that should not be assigned when creating/updating automations."""

EXCLUDED_INPUT_ACTIONS: Final[Collection[ActionType]] = frozenset(
    (ActionType.QUEUE_JOB,)
)
"""Action types that should not be assigned when creating/updating automations."""
# Baseline capabilities: these event/action types are assumed to be available
# on every contemporary server, without any version/feature probing.
ALWAYS_SUPPORTED_EVENTS: Final[Collection[EventType]] = frozenset(
    (
        EventType.CREATE_ARTIFACT,
        EventType.LINK_ARTIFACT,
        EventType.ADD_ARTIFACT_ALIAS,
    )
)
"""Event types that we can safely assume all contemporary server versions support."""

ALWAYS_SUPPORTED_ACTIONS: Final[Collection[ActionType]] = frozenset(
    (
        ActionType.NOTIFICATION,
        ActionType.GENERIC_WEBHOOK,
    )
)
"""Action types that we can safely assume all contemporary server versions support."""
class HasId(Protocol):
    """Structural type for any object that exposes a string `id` attribute."""

    id: str
def extract_id(obj: HasId | str) -> str:
    """Return the ID of `obj`, or `obj` itself when it has no `id` attribute.

    Lets callers pass either an object carrying an `id` or a raw ID string.
    """
    # `getattr` with a default falls back exactly when attribute lookup raises
    # `AttributeError`, which is the same condition under which `hasattr` is False.
    return getattr(obj, "id", obj)
# ---------------------------------------------------------------------------
# Maps each action type to the name of the `InputActionConfig` /
# `TriggeredActionConfig` field that holds that action's input.
# `prepare_action_config_input` uses it to nest an action under the right key.
ACTION_CONFIG_KEYS: Final[dict[ActionType, str]] = {
    ActionType.NOTIFICATION: "notification_action_input",
    ActionType.GENERIC_WEBHOOK: "generic_webhook_action_input",
    ActionType.NO_OP: "no_op_action_input",
    ActionType.QUEUE_JOB: "queue_job_action_input",
}
class InputActionConfig(TriggeredActionConfig):
    """A `TriggeredActionConfig` that prepares the action config for saving an automation.

    Each field holds the input for exactly one action type. The field name for a
    given `ActionType` is looked up in `ACTION_CONFIG_KEYS`, and
    `prepare_action_config_input` populates only that one field.
    """

    # NOTE: `QueueJobActionInput` for defining a Launch job is deprecated,
    # so while it's allowed here to update EXISTING automations, we don't
    # currently expose it through the public API to create NEW automations.
    queue_job_action_input: Optional[QueueJobActionInput] = None
    notification_action_input: Optional[SendNotification] = None
    generic_webhook_action_input: Optional[SendWebhook] = None
    no_op_action_input: Optional[DoNothing] = None
def prepare_action_config_input(obj: SavedAction | InputAction) -> dict[str, Any]:
    """Build the `TriggeredActionConfig` payload for an automation's action.

    The action is first converted to its input form (if it is a saved action).
    It is then nested under the config key that matches its `action_type`,
    as required by the schemas for:
    - CreateFilterTriggerInput
    - UpdateFilterTriggerInput

    Args:
        obj: A saved or input action.

    Returns:
        The `model_dump()` of an `InputActionConfig` with only that action's key set.

    Raises:
        KeyError: If the action type has no entry in `ACTION_CONFIG_KEYS`.
    """
    # The inner validators handle any SavedAction -> InputAction conversion.
    input_action = to_input_action(obj)
    config_field = ACTION_CONFIG_KEYS[input_action.action_type]
    action_config = InputActionConfig(**{config_field: input_action})
    return action_config.model_dump()
def prepare_event_filter_input(
    obj: _WrappedSavedEventFilter | MongoLikeFilter | RunMetricFilter,
) -> str:
    """Serialize an event filter into the JSON string expected by the backend.

    This is necessary to conform to the schemas for:
    - CreateFilterTriggerInput
    - UpdateFilterTriggerInput

    Args:
        obj: The filter to serialize. A `_WrappedSavedEventFilter` is unwrapped
            to its `.filter` first. Other filters, including run and run-metric
            filters, are serialized as-is.

    Returns:
        The JSON-serialized filter.
    """
    if isinstance(obj, _WrappedSavedEventFilter):
        # Input event filters sit one level deeper than saved event filters, so
        # the wrapper is peeled off here. This mirrors backend-side schemas and
        # logic; it is confusing but required.
        return to_json(obj.filter)

    # Run/run-metric filters (and plain mongo-like filters) carry no wrapper.
    return to_json(obj)
class WriteAutomationsKwargs(TypedDict, total=False):
    """Keyword arguments that can be passed to create or update an automation.

    All keys are optional (`total=False`). `prepare_to_create` and
    `prepare_to_update` accept these as `**kwargs`. When an existing object is
    also passed, these values override the object's corresponding fields.
    """

    name: str
    description: str
    enabled: bool
    # NOTE(review): `ValidatedCreateInput` declares no `scope` field and uses
    # `extra="forbid"`. Passing `scope=` to `prepare_to_create` therefore looks
    # like it would be rejected (on create, the scope is derived from `event`).
    # Confirm this is intended.
    scope: AutomationScope
    event: InputEvent
    action: InputAction
class ValidatedCreateInput(GQLBase, extra="forbid", frozen=True):
    """Validated automation parameters, prepared for creating a new automation.

    `event` and `action` are accepted on instantiation but excluded from
    serialization. The computed fields below derive the flattened values
    (scope, event type/filter, action type/config) from them, matching the
    `CreateFilterTriggerInput` schema.

    Note: Users should never need to instantiate this class directly.
    """

    name: str
    description: Optional[str] = None
    enabled: bool = True

    # ------------------------------------------------------------------------------
    # Set on instantiation, but used to derive other fields and deliberately
    # EXCLUDED from the final GraphQL request vars
    event: Annotated[InputEvent, Field(exclude=True)]
    action: Annotated[InputAction, Field(exclude=True)]

    # ------------------------------------------------------------------------------
    # Derived fields to match the input schemas.
    # (Intentionally no docstrings on these: pydantic would use them as the
    # computed fields' schema descriptions.)

    # Scope type, taken from the scope attached to the triggering event.
    @computed_field
    def scope_type(self) -> ScopeType:
        return self.event.scope.scope_type

    # Scope ID, taken from the scope attached to the triggering event.
    @computed_field
    def scope_id(self) -> GQLId:
        return self.event.scope.id

    @computed_field
    def triggering_event_type(self) -> EventType:
        return self.event.event_type

    # JSON string, produced by `prepare_event_filter_input`.
    @computed_field
    def event_filter(self) -> str:
        return prepare_event_filter_input(self.event.filter)

    @computed_field
    def triggered_action_type(self) -> ActionType:
        return self.action.action_type

    # Action input nested under its type-specific key, via `prepare_action_config_input`.
    @computed_field
    def triggered_action_config(self) -> dict[str, Any]:
        return prepare_action_config_input(self.action)

    # ------------------------------------------------------------------------------
    # Custom validation

    # Reject legacy event types listed in `EXCLUDED_INPUT_EVENTS`.
    @model_validator(mode="after")
    def _forbid_legacy_event_types(self) -> Self:
        if (type_ := self.event.event_type) in EXCLUDED_INPUT_EVENTS:
            raise ValueError(f"{type_!r} events cannot be assigned to automations.")
        return self

    # Reject legacy action types listed in `EXCLUDED_INPUT_ACTIONS`.
    @model_validator(mode="after")
    def _forbid_legacy_action_types(self) -> Self:
        if (type_ := self.action.action_type) in EXCLUDED_INPUT_ACTIONS:
            raise ValueError(f"{type_!r} actions cannot be assigned to automations.")
        return self
def prepare_to_create(
    obj: NewAutomation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> CreateFilterTriggerInput:
    """Prepare the payload to create an automation in a GraphQL request.

    Args:
        obj: Optional `NewAutomation` whose fields seed the payload (positional-only).
        **kwargs: Field values that take precedence over those on `obj`. If
            `obj` is omitted, the payload is built from these alone.

    Returns:
        The validated `CreateFilterTriggerInput`.

    Failures from `ValidatedCreateInput`, including the legacy event/action
    type checks, propagate to the caller.
    """
    if obj:
        # NOTE: `exclude_none=True` drops fields that are still `None`, so
        # `None` acts as the "unset" sentinel. If that ever proves
        # insufficient, a dedicated sentinel type could replace it.
        params = obj.model_dump(exclude_none=True)
        # Explicit keyword args win over the object's own values.
        params.update(kwargs)
    else:
        params = kwargs

    validated = ValidatedCreateInput(**params)
    return CreateFilterTriggerInput.model_validate(validated)
def prepare_to_update(
    obj: Automation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> UpdateFilterTriggerInput:
    """Prepare the payload to update an automation in a GraphQL request.

    Args:
        obj: Optional existing `Automation` whose fields seed the payload
            (positional-only).
        **kwargs: Field values that override those on `obj`.

    Returns:
        An `UpdateFilterTriggerInput` built from the re-validated `Automation`.
    """
    # Merge the object's fields with the overrides, then re-validate as a whole.
    base_fields = dict(obj) if obj else {}
    automation = Automation(**{**base_fields, **kwargs})

    scope = automation.scope
    event = automation.event
    action = automation.action

    return UpdateFilterTriggerInput(
        id=automation.id,
        name=automation.name,
        description=automation.description,
        enabled=automation.enabled,
        scope_type=scope.scope_type,
        scope_id=scope.id,
        triggering_event_type=event.event_type,
        event_filter=prepare_event_filter_input(event.filter),
        triggered_action_type=action.action_type,
        triggered_action_config=prepare_action_config_input(action),
    )
|