jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from typing import Any, Dict, Sequence
from wandb.sdk.integration_utils.auto_logging import Response
from .resolvers import (
SUPPORTED_MULTIMODAL_PIPELINES,
DiffusersMultiModalPipelineResolver,
)
class DiffusersPipelineResolver:
    """Resolver for `DiffusionPipeline` requests and responses from [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index).

    Provides the data transformations, formatting, and logging needed to record
    pipeline calls. This is based on
    `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
    """

    def __init__(self) -> None:
        # Not read or written anywhere else in this class; kept as a public
        # attribute in case other code relies on it.
        self.wandb_table = None
        # 1-based count of calls seen by this resolver. The current value is
        # handed to each per-call multimodal resolver.
        self.pipeline_call_count = 1

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Any:
        """Main call method for the `DiffusersPipelineResolver` class.

        Args:
            args: (Sequence[Any]) List of arguments. The class name of
                `args[0]` is used as the pipeline name (presumably the
                pipeline instance itself -- confirm against the caller).
            kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
            response: (wandb.sdk.integration_utils.auto_logging.Response) The
                response from the request.
            start_time: (float) Time when request started.
            time_elapsed: (float) Time elapsed for the request.

        Returns:
            Whatever `DiffusersMultiModalPipelineResolver` returns for the call
            (packed data as a dictionary for logging to wandb), or None when the
            pipeline class is not in `SUPPORTED_MULTIMODAL_PIPELINES`.
        """
        pipeline_name = args[0].__class__.__name__
        # Capture the index for this call, then advance the counter. Every
        # call is counted, supported or not, matching the original ordering.
        call_index = self.pipeline_call_count
        self.pipeline_call_count += 1

        if pipeline_name not in SUPPORTED_MULTIMODAL_PIPELINES:
            # No resolver exists for this pipeline. Previously this fell
            # through to calling `None`, raising TypeError.
            return None

        resolver = DiffusersMultiModalPipelineResolver(pipeline_name, call_index)
        return resolver(args, kwargs, response, start_time, time_elapsed)