from typing import Any, Dict, Sequence

from wandb.sdk.integration_utils.auto_logging import Response

from .resolvers import (
    SUPPORTED_MULTIMODAL_PIPELINES,
    DiffusersMultiModalPipelineResolver,
)


class DiffusersPipelineResolver:
    """Resolver for `DiffusionPipeline` requests and responses.

    Handles pipelines from
    [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index),
    providing the necessary data transformations, formatting, and logging.

    This is based on `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
    """

    def __init__(self) -> None:
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.wandb_table = None
        # Index handed to each per-call multimodal resolver. It starts at 1 and
        # is advanced only when a supported pipeline call is actually resolved.
        self.pipeline_call_count = 1

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Any:
        """Main call method for the `DiffusersPipelineResolver` class.

        Args:
            args: (Sequence[Any]) List of positional arguments of the intercepted
                call; the first element is used to identify the pipeline class.
            kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
            response: (wandb.sdk.integration_utils.auto_logging.Response) The
                response from the request.
            start_time: (float) Time when request started.
            time_elapsed: (float) Time elapsed for the request.

        Returns:
            Packed data as a dictionary for logging to wandb, as produced by
            `DiffusersMultiModalPipelineResolver`; `None` if `args` is empty or
            the pipeline class is not in `SUPPORTED_MULTIMODAL_PIPELINES`.
        """
        if not args:
            # Without a first positional argument there is no pipeline to identify.
            return None

        pipeline_name = args[0].__class__.__name__
        if pipeline_name not in SUPPORTED_MULTIMODAL_PIPELINES:
            # Previously this fell through and called `None(...)`, raising
            # TypeError; unsupported pipelines are simply not logged.
            return None

        resolver = DiffusersMultiModalPipelineResolver(
            pipeline_name, self.pipeline_call_count
        )
        self.pipeline_call_count += 1
        return resolver(args, kwargs, response, start_time, time_elapsed)