File size: 13,813 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple
import wandb
from wandb.sdk.integration_utils.auto_logging import Response
from wandb.sdk.lib.runid import generate_id
logger = logging.getLogger(__name__)
def subset_dict(
    original_dict: Dict[str, Any], keys_subset: Sequence[str]
) -> Dict[str, Any]:
    """Extract the entries of a dictionary whose keys appear in a given subset.

    :param original_dict: The dictionary to filter.
    :param keys_subset: The keys to keep; keys absent from the dictionary are silently skipped.
    :return: A new dictionary restricted to the requested keys.
    """
    filtered: Dict[str, Any] = {}
    for key in keys_subset:
        if key in original_dict:
            filtered[key] = original_dict[key]
    return filtered
def reorder_and_convert_dict_list_to_table(
    data: List[Dict[str, Any]], order: List[str]
) -> Tuple[List[str], List[List[Any]]]:
    """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.

    :param data: A list of dictionaries.
    :param order: A list of keys specifying the desired order for specific dictionaries. The remaining dictionaries will be ordered based on their original order.
    :return: A pair of column names and corresponding values.
    """
    # Ordered de-duplication: keys from `order` come first, followed by any
    # remaining keys in the order they first appear across the dictionaries.
    # A dict is used as an insertion-ordered set.
    seen: Dict[str, None] = dict.fromkeys(order)
    for record in data:
        for key in record:
            seen.setdefault(key, None)
    final_columns = list(seen)
    # One row per dictionary; missing keys are filled with None.
    values = [[record.get(column) for column in final_columns] for record in data]
    return final_columns, values
def flatten_dict(
    dictionary: Dict[str, Any], parent_key: str = "", sep: str = "-"
) -> Dict[str, Any]:
    """Flatten a nested dictionary, joining keys using a specified separator.

    :param dictionary: The dictionary to flatten.
    :param parent_key: The base key to prepend to each key.
    :param sep: The separator to use when joining keys.
    :return: A flattened dictionary.
    """
    # Collect (key, value) pairs depth-first so insertion order matches the
    # order nested values are encountered, then build the dict in one pass.
    pairs: List[Tuple[str, Any]] = []
    for key, value in dictionary.items():
        compound_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if isinstance(value, dict):
            pairs.extend(flatten_dict(value, compound_key, sep=sep).items())
        else:
            pairs.append((compound_key, value))
    return dict(pairs)
def collect_common_keys(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries.

    :param list_of_dicts: The list of dictionaries to inspect.
    :return: A dictionary with each common key and its corresponding list of values.
        Empty if ``list_of_dicts`` is empty.
    """
    # Guard the empty case: `set.intersection()` with no arguments raises
    # TypeError ("descriptor 'intersection' ... needs an argument"), so an
    # empty input list must be handled explicitly.
    if not list_of_dicts:
        return {}
    common_keys = set.intersection(*map(set, list_of_dicts))
    common_dict: Dict[str, List[Any]] = {key: [] for key in common_keys}
    for d in list_of_dicts:
        for key in common_keys:
            common_dict[key].append(d[key])
    return common_dict
class CohereRequestResponseResolver:
    """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged.

    Each supported Cohere endpoint returns a differently shaped response
    object; the per-endpoint ``_resolve_*`` helpers normalize them into a list
    of row dictionaries that ``_resolve`` packs into a ``wandb.Table``.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Process the response from the Cohere API and convert it to a dictionary that can be logged.

        :param args: The arguments of the original function.
        :param kwargs: The keyword arguments of the original function.
        :param response: The response from the Cohere API.
        :param start_time: The start time of the request.
        :param time_elapsed: The time elapsed for the request.
        :return: A dictionary containing the parsed response and timing information,
            or None if the response type is unsupported or resolution fails.
        """
        try:
            # Each of the different endpoints map to one specific response type
            # We want to 'type check' the response without directly importing the packages type
            # It may make more sense to pass the invoked symbol from the AutologAPI instead
            response_type = str(type(response)).split("'")[1].split(".")[-1]
            if response_type == "Generations":
                parsed_response = self._resolve_generate_response(response)
                # TODO: Remove hard-coded default model name
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "prompt",
                    "text",
                    "token_likelihoods",
                    "likelihood",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Chat":
                parsed_response = self._resolve_chat_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "conversation_id",
                    "response_id",
                    "query",
                    "text",
                    "prompt",
                    "preamble",
                    "chat_history",
                    "chatlog",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Classifications":
                parsed_response = self._resolve_classify_response(response)
                kwargs = self._resolve_classify_kwargs(kwargs)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "input",
                    "prediction",
                    "confidence",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "embed-english-v2.0"
            elif response_type == "SummarizeResponse":
                parsed_response = self._resolve_summarize_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "response_id",
                    "text",
                    "additional_command",
                    "summary",
                    "time_elapsed_(seconds)",
                    "end_time",
                    "length",
                    "format",
                ]
                default_model = "summarize-xlarge"
            elif response_type == "Reranking":
                parsed_response = self._resolve_rerank_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "query",
                    "top_n",
                    # This is a nested dict key that got flattened
                    "document-text",
                    "relevance_score",
                    "index",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "rerank-english-v2.0"
            else:
                logger.info(f"Unsupported Cohere response object: {response}")
                # Bail out immediately: no table layout or default model is
                # defined for unsupported response types. Previously this
                # branch fell through to `self._resolve(...)`, which raised a
                # NameError on `table_column_order` that the broad `except`
                # below swallowed as a misleading "Failed to resolve" warning.
                return None
            return self._resolve(
                args,
                kwargs,
                parsed_response,
                start_time,
                time_elapsed,
                response_type,
                table_column_order,
                default_model,
            )
        except Exception as e:
            logger.warning(f"Failed to resolve request/response: {e}")
        return None

    # These helper functions process the response from different endpoints of the Cohere API.
    # Since the response objects for different endpoints have different structures,
    # we need different logic to process them.
    def _resolve_generate_response(self, response: Response) -> List[Dict[str, Any]]:
        """Convert a `Generations` response into row dictionaries, rendering token likelihoods as HTML."""
        return_list = []
        for _response in response:
            # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data
            _response_dict = _response._visualize_helper()
            try:
                _response_dict["token_likelihoods"] = wandb.Html(
                    _response_dict["token_likelihoods"]
                )
            except (KeyError, ValueError):
                # Best-effort: leave token_likelihoods as-is when missing or
                # not renderable as HTML.
                pass
            return_list.append(_response_dict)
        return return_list

    def _resolve_chat_response(self, response: Response) -> List[Dict[str, Any]]:
        """Convert a `Chat` response into a single-row list of its loggable attributes."""
        return [
            subset_dict(
                response.__dict__,
                [
                    "response_id",
                    "generation_id",
                    "query",
                    "text",
                    "conversation_id",
                    "prompt",
                    "chatlog",
                    "preamble",
                ],
            )
        ]

    def _resolve_classify_response(self, response: Response) -> List[Dict[str, Any]]:
        """Convert a `Classifications` response into one flattened row per classification."""
        # The labels key is a dict returning the scores for the classification probability for each label provided
        # We flatten this nested dict for ease of consumption in the wandb UI
        return [flatten_dict(_response.__dict__) for _response in response]

    def _resolve_classify_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Replace the `examples` kwarg with parallel `example_texts`/`example_labels` lists (mutates `kwargs`)."""
        # Example texts look strange when rendered in Wandb UI as it is a list of text and label
        # We extract each value into its own column
        example_texts = []
        example_labels = []
        for example in kwargs["examples"]:
            example_texts.append(example.text)
            example_labels.append(example.label)
        kwargs.pop("examples")
        kwargs["example_texts"] = example_texts
        kwargs["example_labels"] = example_labels
        return kwargs

    def _resolve_summarize_response(self, response: Response) -> List[Dict[str, Any]]:
        """Convert a `SummarizeResponse` into a single-row list with its id and summary text."""
        return [{"response_id": response.id, "summary": response.summary}]

    def _resolve_rerank_response(self, response: Response) -> List[Dict[str, Any]]:
        """Convert a `Reranking` response into a single aggregated row."""
        # The documents key contains a dict containing the content of the document which is at least "text"
        # We flatten this nested dict for ease of consumption in the wandb UI
        flattened_response_dicts = [
            flatten_dict(_response.__dict__) for _response in response
        ]
        # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
        # As opposed to each row being one of the top_n responses
        return_dict = collect_common_keys(flattened_response_dicts)
        return_dict["id"] = response.id
        return [return_dict]

    def _resolve(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        parsed_response: List[Dict[str, Any]],
        start_time: float,
        time_elapsed: float,
        response_type: str,
        table_column_order: List[str],
        default_model: str,
    ) -> Dict[str, Any]:
        """Pack the parsed response rows, request kwargs, client metadata, and timing into a `wandb.Table`.

        :param args: The arguments passed to the API client.
        :param kwargs: The keyword arguments passed to the API client.
        :param parsed_response: The parsed response from the API.
        :param start_time: The start time of the API request.
        :param time_elapsed: The time elapsed during the API request.
        :param response_type: The type of the API response.
        :param table_column_order: The desired order of columns in the resulting table.
        :param default_model: The default model to use if not specified in the response.
        :return: A dictionary mapping the response type name to the formatted table.
        """
        # Args[0] is the client object where we can grab specific metadata about the underlying API status
        query_id = generate_id(length=16)
        parsed_args = subset_dict(
            args[0].__dict__,
            ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
        )
        start_time_dt = datetime.fromtimestamp(start_time)
        end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)
        timings = {
            "start_time": start_time_dt,
            "end_time": end_time_dt,
            "time_elapsed_(seconds)": time_elapsed,
        }
        packed_data = []
        for _parsed_response in parsed_response:
            # Later dicts win on key collisions: response values override
            # kwargs, timings override both, client metadata overrides all.
            _packed_dict = {
                "query_id": query_id,
                **kwargs,
                **_parsed_response,
                **timings,
                **parsed_args,
            }
            if "model" not in _packed_dict:
                _packed_dict["model"] = default_model
            packed_data.append(_packed_dict)
        columns, data = reorder_and_convert_dict_list_to_table(
            packed_data, table_column_order
        )
        request_response_table = wandb.Table(data=data, columns=columns)
        return {f"{response_type}": request_response_table}
|