|
import logging |
|
import os |
|
from datetime import datetime |
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
|
|
|
import pytz |
|
|
|
import wandb |
|
from wandb.sdk.integration_utils.auto_logging import Response |
|
from wandb.sdk.lib.runid import generate_id |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Pipeline task names the autologger knows how to resolve.  Translation tasks
# additionally match by prefix (e.g. "translation_en_to_fr") in the resolver,
# so only the bare "translation" name is listed here.
SUPPORTED_PIPELINE_TASKS = [
    "text-classification",
    "sentiment-analysis",
    "question-answering",
    "summarization",
    "translation",
    "text2text-generation",
    "text-generation",
]

# Tasks whose responses honor a `top_k`-style argument: a single-element list
# returned by these tasks is kept as a list instead of being unwrapped.
PIPELINES_WITH_TOP_K = [
    "text-classification",
    "sentiment-analysis",
    "question-answering",
]
|
|
|
|
|
class HuggingFacePipelineRequestResponseResolver: |
|
"""Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting. |
|
|
|
This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver) |
|
""" |
|
|
|
autolog_id = None |
|
|
|
def __call__( |
|
self, |
|
args: Sequence[Any], |
|
kwargs: Dict[str, Any], |
|
response: Response, |
|
start_time: float, |
|
time_elapsed: float, |
|
) -> Optional[Dict[str, Any]]: |
|
"""Main call method for this class. |
|
|
|
:param args: list of arguments |
|
:param kwargs: dictionary of keyword arguments |
|
:param response: the response from the request |
|
:param start_time: time when request started |
|
:param time_elapsed: time elapsed for the request |
|
:returns: packed data as a dictionary for logging to wandb, None if an exception occurred |
|
""" |
|
try: |
|
pipe, input_data = args[:2] |
|
task = pipe.task |
|
|
|
|
|
if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"): |
|
model = self._get_model(pipe) |
|
if model is None: |
|
return None |
|
model_alias = model.name_or_path |
|
timestamp = datetime.now(pytz.utc) |
|
|
|
input_data, response = self._transform_task_specific_data( |
|
task, input_data, response |
|
) |
|
formatted_data = self._format_data(task, input_data, response, kwargs) |
|
packed_data = self._create_table( |
|
formatted_data, model_alias, timestamp, time_elapsed |
|
) |
|
table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}") |
|
|
|
|
|
return { |
|
table_name: wandb.Table( |
|
columns=packed_data[0], data=packed_data[1:] |
|
) |
|
} |
|
|
|
logger.warning( |
|
f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task" |
|
) |
|
except Exception as e: |
|
logger.warning(e) |
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def _get_model(self, pipe) -> Optional[Any]: |
|
"""Extracts model from the pipeline. |
|
|
|
:param pipe: the HuggingFace pipeline |
|
:returns: Model if available, None otherwise |
|
""" |
|
model = pipe.model |
|
try: |
|
return model.model |
|
except AttributeError: |
|
logger.info( |
|
"Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model." |
|
) |
|
return model |
|
|
|
@staticmethod |
|
def _transform_task_specific_data( |
|
task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any] |
|
) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]: |
|
"""Transform input and response data based on specific tasks. |
|
|
|
:param task: the task name |
|
:param input_data: the input data |
|
:param response: the response data |
|
:returns: tuple of transformed input_data and response |
|
""" |
|
if task == "question-answering": |
|
input_data = input_data if isinstance(input_data, list) else [input_data] |
|
input_data = [data.__dict__ for data in input_data] |
|
elif task == "conversational": |
|
|
|
|
|
input_data = input_data if isinstance(input_data, list) else [input_data] |
|
input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data] |
|
|
|
response = response if isinstance(response, list) else [response] |
|
response = [data.__dict__["generated_responses"][-1] for data in response] |
|
return input_data, response |
|
|
|
def _format_data( |
|
self, |
|
task: str, |
|
input_data: Union[List[Any], Any], |
|
response: Union[List[Any], Any], |
|
kwargs: Dict[str, Any], |
|
) -> List[Dict[str, Any]]: |
|
"""Formats input data, response, and kwargs into a list of dictionaries. |
|
|
|
:param task: the task name |
|
:param input_data: the input data |
|
:param response: the response data |
|
:param kwargs: dictionary of keyword arguments |
|
:returns: list of dictionaries containing formatted data |
|
""" |
|
input_data = input_data if isinstance(input_data, list) else [input_data] |
|
response = response if isinstance(response, list) else [response] |
|
|
|
formatted_data = [] |
|
for i_text, r_text in zip(input_data, response): |
|
|
|
|
|
|
|
if ( |
|
(isinstance(r_text, list)) |
|
and (len(r_text) == 1) |
|
and task not in PIPELINES_WITH_TOP_K |
|
): |
|
r_text = r_text[0] |
|
formatted_data.append( |
|
{"input": i_text, "response": r_text, "kwargs": kwargs} |
|
) |
|
return formatted_data |
|
|
|
def _create_table( |
|
self, |
|
formatted_data: List[Dict[str, Any]], |
|
model_alias: str, |
|
timestamp: float, |
|
time_elapsed: float, |
|
) -> List[List[Any]]: |
|
"""Creates a table from formatted data, model alias, timestamp, and elapsed time. |
|
|
|
:param formatted_data: list of dictionaries containing formatted data |
|
:param model_alias: alias of the model |
|
:param timestamp: timestamp of the data |
|
:param time_elapsed: time elapsed from the beginning |
|
:returns: list of lists, representing a table of data. [0]th element = columns. [1]st element = data |
|
""" |
|
header = [ |
|
"ID", |
|
"Model Alias", |
|
"Timestamp", |
|
"Elapsed Time", |
|
"Input", |
|
"Response", |
|
"Kwargs", |
|
] |
|
table = [header] |
|
autolog_id = generate_id(length=16) |
|
|
|
for data in formatted_data: |
|
row = [ |
|
autolog_id, |
|
model_alias, |
|
timestamp, |
|
time_elapsed, |
|
data["input"], |
|
data["response"], |
|
data["kwargs"], |
|
] |
|
table.append(row) |
|
|
|
self.autolog_id = autolog_id |
|
|
|
return table |
|
|
|
def get_latest_id(self): |
|
return self.autolog_id |
|
|