|
|
|
|
|
from __future__ import annotations |
|
|
|
from typing import Any, Collection, Final, Optional, Protocol, TypedDict |
|
|
|
from pydantic import Field |
|
from typing_extensions import Annotated, Self, Unpack |
|
|
|
from wandb._pydantic import GQLBase, GQLId, computed_field, model_validator, to_json |
|
|
|
from ._filters import MongoLikeFilter |
|
from ._generated import ( |
|
CreateFilterTriggerInput, |
|
QueueJobActionInput, |
|
TriggeredActionConfig, |
|
UpdateFilterTriggerInput, |
|
) |
|
from ._validators import to_input_action |
|
from .actions import ( |
|
ActionType, |
|
DoNothing, |
|
InputAction, |
|
SavedAction, |
|
SendNotification, |
|
SendWebhook, |
|
) |
|
from .automations import Automation, NewAutomation |
|
from .events import EventType, InputEvent, RunMetricFilter, _WrappedSavedEventFilter |
|
from .scopes import AutomationScope, ScopeType |
|
|
|
EXCLUDED_INPUT_EVENTS: Final[Collection[EventType]] = frozenset(
    (EventType.UPDATE_ARTIFACT_ALIAS,)
)
"""Event types that must not be assigned when creating or updating automations."""

EXCLUDED_INPUT_ACTIONS: Final[Collection[ActionType]] = frozenset(
    (ActionType.QUEUE_JOB,)
)
"""Action types that must not be assigned when creating or updating automations."""

ALWAYS_SUPPORTED_EVENTS: Final[Collection[EventType]] = frozenset(
    (
        EventType.CREATE_ARTIFACT,
        EventType.LINK_ARTIFACT,
        EventType.ADD_ARTIFACT_ALIAS,
    )
)
"""Event types assumed to be supported by every contemporary server version."""

ALWAYS_SUPPORTED_ACTIONS: Final[Collection[ActionType]] = frozenset(
    (
        ActionType.NOTIFICATION,
        ActionType.GENERIC_WEBHOOK,
    )
)
"""Action types assumed to be supported by every contemporary server version."""
|
|
|
|
|
class HasId(Protocol):
    """Structural type for any object that exposes a string `id` attribute."""

    id: str
|
|
|
|
|
def extract_id(obj: HasId | str) -> str:
    """Return the ID referenced by `obj`.

    Args:
        obj: Either an object carrying an `id` attribute, or a raw ID string.

    Returns:
        `obj.id` when `obj` has an `id` attribute; otherwise `obj` unchanged.
    """
    if not hasattr(obj, "id"):
        # Already a bare ID (e.g. a string), so pass it through untouched.
        return obj
    return obj.id
|
|
|
|
|
|
|
# Maps each action type to the `InputActionConfig` field under which an action
# of that type is nested when building the `TriggeredActionConfig` input.
# Annotated `Final` (like the other module constants) because it is a fixed
# lookup table and must not be rebound.
ACTION_CONFIG_KEYS: Final[dict[ActionType, str]] = {
    ActionType.NOTIFICATION: "notification_action_input",
    ActionType.GENERIC_WEBHOOK: "generic_webhook_action_input",
    ActionType.NO_OP: "no_op_action_input",
    ActionType.QUEUE_JOB: "queue_job_action_input",
}
|
|
|
|
|
class InputActionConfig(TriggeredActionConfig):
    """A `TriggeredActionConfig` that prepares the action config for saving an automation.

    `prepare_action_config_input` sets only the one field matching the action's
    type, so a single field is expected to be populated per instance.
    """

    # NOTE(review): these presumably redeclare same-named fields on the generated
    # parent `TriggeredActionConfig` so they accept the SDK's own action types;
    # confirm against `._generated`. Field order here determines the key order
    # of the serialized dict.

    # Queue-job actions keep the generated input type.
    queue_job_action_input: Optional[QueueJobActionInput] = None

    # SDK action types, nested directly as the action input.
    notification_action_input: Optional[SendNotification] = None
    generic_webhook_action_input: Optional[SendWebhook] = None
    no_op_action_input: Optional[DoNothing] = None
|
|
|
|
|
def prepare_action_config_input(obj: SavedAction | InputAction) -> dict[str, Any]:
    """Wrap an action in the `TriggeredActionConfig` shape expected by the API.

    The action is normalized to its input form, then nested under the key that
    corresponds to its action type. This is the shape required by both
    `CreateFilterTriggerInput` and `UpdateFilterTriggerInput`.

    Args:
        obj: A saved action or an input action.

    Returns:
        The serialized action config as a plain dict.

    Raises:
        KeyError: If the action's type has no entry in `ACTION_CONFIG_KEYS`.
    """
    # `to_input_action` turns a saved action into its input counterpart.
    input_action = to_input_action(obj)
    config_key = ACTION_CONFIG_KEYS[input_action.action_type]
    config = InputActionConfig(**{config_key: input_action})
    return config.model_dump()
|
|
|
|
|
def prepare_event_filter_input(
    obj: _WrappedSavedEventFilter | MongoLikeFilter | RunMetricFilter,
) -> str:
    """Serialize an event filter to the JSON string expected by the API.

    A `_WrappedSavedEventFilter` is unwrapped first, so only its inner `filter`
    is serialized. Any other filter is serialized as-is. This is the shape
    required by both `CreateFilterTriggerInput` and `UpdateFilterTriggerInput`.

    Args:
        obj: The (possibly wrapped) event filter.

    Returns:
        The JSON-encoded filter.
    """
    if isinstance(obj, _WrappedSavedEventFilter):
        # Only the nested filter belongs in the request payload.
        return to_json(obj.filter)
    return to_json(obj)
|
|
|
|
|
class WriteAutomationsKwargs(TypedDict, total=False):
    """Keyword arguments that can be passed to create or update an automation.

    All keys are optional (`total=False`). Any key that is given overrides the
    corresponding value of the automation object passed alongside it.
    """

    name: str
    description: str
    enabled: bool
    # NOTE(review): `ValidatedCreateInput` declares no `scope` field and uses
    # `extra="forbid"`, so passing `scope` to `prepare_to_create` looks like it
    # would fail validation (on creation the scope is derived from `event`).
    # Confirm whether this key is meant for updates only.
    scope: AutomationScope
    event: InputEvent
    action: InputAction
|
|
|
|
|
class ValidatedCreateInput(GQLBase, extra="forbid", frozen=True):
    """Validated automation parameters, prepared for creating a new automation.

    `event` and `action` are accepted as inputs but excluded from the output.
    The request fields that depend on them are provided by the computed fields
    below instead.

    Note: Users should never need to instantiate this class directly.
    """

    name: str
    description: Optional[str] = None
    enabled: bool = True

    # Required on instantiation and used to derive the computed fields below,
    # but deliberately EXCLUDED from the serialized output.
    event: Annotated[InputEvent, Field(exclude=True)]
    action: Annotated[InputAction, Field(exclude=True)]

    # Derived fields, named to line up with the create-automation request input
    # (`prepare_to_create` validates this model into a `CreateFilterTriggerInput`).
    @computed_field
    def scope_type(self) -> ScopeType:
        # The automation's scope is taken from its triggering event.
        return self.event.scope.scope_type

    @computed_field
    def scope_id(self) -> GQLId:
        return self.event.scope.id

    @computed_field
    def triggering_event_type(self) -> EventType:
        return self.event.event_type

    @computed_field
    def event_filter(self) -> str:
        # Sent to the API as a JSON string, not a nested object.
        return prepare_event_filter_input(self.event.filter)

    @computed_field
    def triggered_action_type(self) -> ActionType:
        return self.action.action_type

    @computed_field
    def triggered_action_config(self) -> dict[str, Any]:
        # The action input nested under its type-specific key.
        return prepare_action_config_input(self.action)

    # Post-validation checks (mode="after") that reject event/action types
    # which may not be assigned to new automations.
    @model_validator(mode="after")
    def _forbid_legacy_event_types(self) -> Self:
        """Raise `ValueError` if the event type is in `EXCLUDED_INPUT_EVENTS`."""
        if (type_ := self.event.event_type) in EXCLUDED_INPUT_EVENTS:
            raise ValueError(f"{type_!r} events cannot be assigned to automations.")
        return self

    @model_validator(mode="after")
    def _forbid_legacy_action_types(self) -> Self:
        """Raise `ValueError` if the action type is in `EXCLUDED_INPUT_ACTIONS`."""
        if (type_ := self.action.action_type) in EXCLUDED_INPUT_ACTIONS:
            raise ValueError(f"{type_!r} actions cannot be assigned to automations.")
        return self
|
|
|
|
|
def prepare_to_create(
    obj: NewAutomation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> CreateFilterTriggerInput:
    """Build the GraphQL input payload for creating a new automation.

    Args:
        obj: Optional `NewAutomation` supplying the base parameter values.
        **kwargs: Parameter values that override those on `obj`, or that fully
            specify the automation when `obj` is omitted.

    Returns:
        A `CreateFilterTriggerInput` ready to be sent as request variables.

    Raises:
        A validation error if the merged parameters are invalid, e.g. a missing
        `event`/`action` or an excluded event/action type.
    """
    if not obj:
        params = kwargs
    else:
        # Fields still `None` on `obj` are dropped, so `None` is treated as
        # "not provided" and never overrides or reaches validation.
        params = obj.model_dump(exclude_none=True)
        params.update(kwargs)

    validated = ValidatedCreateInput(**params)
    return CreateFilterTriggerInput.model_validate(validated)
|
|
|
|
|
def prepare_to_update(
    obj: Automation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> UpdateFilterTriggerInput:
    """Build the GraphQL input payload for updating an existing automation.

    Args:
        obj: Optional existing automation whose fields serve as base values.
        **kwargs: Field values that take precedence over those of `obj`.

    Returns:
        An `UpdateFilterTriggerInput` ready to be sent as request variables.
    """
    # Keyword arguments win over the existing object's fields. The merged
    # values are then re-validated as a complete `Automation`.
    merged: dict[str, Any] = dict(obj or {})
    merged.update(kwargs)
    automation = Automation(**merged)

    scope = automation.scope
    event = automation.event
    action = automation.action
    return UpdateFilterTriggerInput(
        id=automation.id,
        name=automation.name,
        description=automation.description,
        enabled=automation.enabled,
        scope_type=scope.scope_type,
        scope_id=scope.id,
        triggering_event_type=event.event_type,
        event_filter=prepare_event_filter_input(event.filter),
        triggered_action_type=action.action_type,
        triggered_action_config=prepare_action_config_input(action),
    )
|
|