# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import pytz
import wandb
from wandb.sdk.integration_utils.auto_logging import Response
from wandb.sdk.lib.runid import generate_id
logger = logging.getLogger(__name__)

# Tasks whose results are governed by `top_k`. Single-element result lists for
# these tasks are kept as lists (not unpacked) when rows are formatted, so that
# tables of the same task concatenate consistently in the wandb UI.
PIPELINES_WITH_TOP_K = [
    "text-classification",
    "sentiment-analysis",
    "question-answering",
]

# Every pipeline task that is turned into a wandb Table. Translation pipelines
# are also matched by the `translation` prefix (e.g. `translation_x_to_y`).
# NOTE: "conversational" is currently not enabled.
SUPPORTED_PIPELINE_TASKS = [
    *PIPELINES_WITH_TOP_K,
    "summarization",
    "translation",
    "text2text-generation",
    "text-generation",
]
class HuggingFacePipelineRequestResponseResolver:
    """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting.

    Turns one intercepted pipeline call (pipeline, inputs, response, timing) into a
    ``wandb.Table`` keyed by a table name.
    This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver)
    """

    # ID shared by every row of the most recently created table; None until
    # `_create_table` runs for the first time.
    autolog_id = None

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Main call method for this class.

        :param args: positional arguments of the call; ``args[0]`` must be the
            pipeline and ``args[1]`` its input data
        :param kwargs: keyword arguments of the call, logged verbatim in every row
        :param response: the response from the request
        :param start_time: time when request started (not used here)
        :param time_elapsed: time elapsed for the request
        :returns: ``{table_name: wandb.Table}`` ready for logging to wandb, or None
            if the task is unsupported, no model was found, or an exception occurred
        """
        try:
            pipe, input_data = args[:2]
            task = pipe.task

            # Translation tasks are in the form of `translation_x_to_y`
            if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"):
                model = self._get_model(pipe)
                if model is None:
                    return None
                model_alias = model.name_or_path
                timestamp = datetime.now(pytz.utc)

                input_data, response = self._transform_task_specific_data(
                    task, input_data, response
                )
                formatted_data = self._format_data(task, input_data, response, kwargs)
                packed_data = self._create_table(
                    formatted_data, model_alias, timestamp, time_elapsed
                )
                # TODO: Let users decide the name in a way that does not use an environment variable
                table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", task)
                # packed_data[0] is the header row; packed_data[1:] are the data rows.
                return {
                    table_name: wandb.Table(
                        columns=packed_data[0], data=packed_data[1:]
                    )
                }

            logger.warning(
                f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task"
            )
        except Exception as e:
            # Deliberately best-effort: any failure is logged and None is returned
            # instead of propagating the exception to the caller.
            logger.warning(e)
        return None

    # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
    # from transformers.modeling_utils import PreTrainedModel
    # We do not want this dependency explicitly in our codebase so we make a very general
    # assumption about the structure of the pipeline which may have unintended consequences
    def _get_model(self, pipe) -> Optional[Any]:
        """Extracts model from the pipeline.

        :param pipe: the HuggingFace pipeline
        :returns: ``pipe.model.model`` if that attribute exists, otherwise ``pipe.model``
        """
        model = pipe.model
        try:
            return model.model
        except AttributeError:
            logger.info(
                "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model."
            )
            return model

    @staticmethod
    def _transform_task_specific_data(
        task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any]
    ) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]:
        """Transform input and response data based on specific tasks.

        The batch shape is preserved: lists are transformed element-wise and stay
        lists, single items are transformed directly and stay single items, so
        `_format_data` can still tell a single call from a batched one.

        :param task: the task name
        :param input_data: the input data (a single item or a list of items)
        :param response: the response data (a single item or a list of items)
        :returns: tuple of transformed input_data and response
        """

        def apply(func, data):
            # Map over a batch, or transform a single item directly.
            if isinstance(data, list):
                return [func(item) for item in data]
            return func(data)

        if task == "question-answering":
            # QA inputs are objects; log their attributes as a plain dict.
            input_data = apply(lambda data: data.__dict__, input_data)
        elif task == "conversational":
            # We only grab the latest input/output pair from the conversation
            # Logging the whole conversation renders strangely.
            input_data = apply(
                lambda data: data.__dict__["past_user_inputs"][-1], input_data
            )
            response = apply(
                lambda data: data.__dict__["generated_responses"][-1], response
            )
        return input_data, response

    def _format_data(
        self,
        task: str,
        input_data: Union[List[Any], Any],
        response: Union[List[Any], Any],
        kwargs: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Formats input data, response, and kwargs into a list of dictionaries.

        A non-list ``input_data`` is a single call, so the whole ``response``
        (even when it is a list, e.g. ``top_k > 1`` results) belongs to that one
        input. A list ``input_data`` is a batch and is zipped element-wise with
        ``response``.

        :param task: the task name
        :param input_data: the input data
        :param response: the response data
        :param kwargs: dictionary of keyword arguments
        :returns: list of dictionaries with keys "input", "response" and "kwargs",
            one per logged row
        """
        if not isinstance(input_data, list):
            # Single call: wrap both sides so the entire response stays attached to
            # its input instead of being truncated to its first element by zip.
            input_data, response = [input_data], [response]
        elif not isinstance(response, list):
            response = [response]

        formatted_data = []
        for i_text, r_text in zip(input_data, response):
            # Unpack single element responses for better rendering in wandb UI when it is a task without top_k
            # top_k = 1 would unpack the response into a single element while top_k > 1 would be a list
            # this would cause the UI to not properly concatenate the tables of the same task by omitting the elements past the first
            if (
                isinstance(r_text, list)
                and len(r_text) == 1
                and task not in PIPELINES_WITH_TOP_K
            ):
                r_text = r_text[0]
            formatted_data.append(
                {"input": i_text, "response": r_text, "kwargs": kwargs}
            )
        return formatted_data

    def _create_table(
        self,
        formatted_data: List[Dict[str, Any]],
        model_alias: str,
        timestamp: datetime,
        time_elapsed: float,
    ) -> List[List[Any]]:
        """Creates a table from formatted data, model alias, timestamp, and elapsed time.

        Every row gets the same ID from ``generate_id(length=16)``; that ID is
        also stored on ``self.autolog_id``.

        :param formatted_data: list of dictionaries containing formatted data
        :param model_alias: alias of the model
        :param timestamp: time at which the call was logged
        :param time_elapsed: time elapsed for the request
        :returns: list of lists representing a table: element [0] is the header
            (column names) and elements [1:] are the data rows
        """
        header = [
            "ID",
            "Model Alias",
            "Timestamp",
            "Elapsed Time",
            "Input",
            "Response",
            "Kwargs",
        ]
        autolog_id = generate_id(length=16)
        rows = [
            [
                autolog_id,
                model_alias,
                timestamp,
                time_elapsed,
                data["input"],
                data["response"],
                data["kwargs"],
            ]
            for data in formatted_data
        ]
        self.autolog_id = autolog_id
        return [header, *rows]

    def get_latest_id(self):
        """Return the ID of the most recently created table, or None if none exists yet."""
        return self.autolog_id