|
import datetime
import io
import logging
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional, Sequence

import wandb
from wandb.sdk.data_types import trace_tree
from wandb.sdk.integration_utils.auto_logging import Response
|
|
|
# Module-level logger: resolver failures are reported through it instead of
# being raised to the caller.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class UsageMetrics:
    """Token usage and latency for a single OpenAI API call.

    Every field defaults to ``None`` so an instance can be built even when the
    response carries no ``usage`` section (or only part of one).
    """

    # Duration of the call, in seconds.
    elapsed_time: Optional[float] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
|
|
|
|
|
@dataclass
class Metrics:
    """Everything produced for W&B logging from one resolved API call."""

    # Token counts and elapsed time for the call.
    usage: Optional[UsageMetrics] = None
    # Table with one row per returned choice (request, response, model,
    # timing, ids and usage columns).
    stats: Optional[wandb.Table] = None
    # Trace tree wrapping a single LLM span that describes the call.
    trace: Optional[trace_tree.WBTraceTree] = None
|
|
|
|
|
# Metric names of the form "usage/<field>" (one per UsageMetrics field); these
# are the keys registered with wandb.define_metric before the first log.
usage_metric_keys = {
    "usage/" + field_name for field_name in asdict(UsageMetrics())
}
|
|
|
|
|
class OpenAIRequestResponseResolver:
    """Turns intercepted OpenAI request/response pairs into W&B metrics.

    Edit, text-completion and chat-completion responses are supported; any
    other response ``object`` type is logged at INFO level and ignored.
    """

    def __init__(self):
        # Usage metric definitions only need to be registered once.
        self.define_metrics_called = False

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Resolves one API call into a dict of metrics for logging.

        params:
            args: Positional arguments of the intercepted call (unused).
            kwargs: Keyword arguments of the intercepted call; used as the
                request payload.
            response: The dict-like response object returned by the API.
            start_time: Start time of the call (unused; timestamps are derived
                from ``response["created"]`` instead).
            time_elapsed: The time elapsed in seconds.
        returns:
            A dict with ``stats``, ``trace`` and ``usage/*`` entries, or
            ``None`` if the response type is unsupported or resolution fails.
        """
        request = kwargs

        if not self.define_metrics_called:
            # Use `_timestamp` as the step metric for every usage metric.
            for key in usage_metric_keys:
                wandb.define_metric(key, step_metric="_timestamp")
            self.define_metrics_called = True

        try:
            object_type = response.get("object")
            resolvers = {
                "edit": self._resolve_edit,
                "text_completion": self._resolve_completion,
                "chat.completion": self._resolve_chat_completion,
            }
            resolver = resolvers.get(object_type)
            if resolver is None:
                logger.info(f"Unsupported OpenAI response object: {object_type}")
                return None
            return resolver(request, response, time_elapsed)
        except Exception as e:
            # Best effort: resolution errors are logged and swallowed rather
            # than propagated to the caller.
            logger.warning(f"Failed to resolve request/response: {e}")
            return None

    @staticmethod
    def results_to_trace_tree(
        request: Dict[str, Any],
        response: Response,
        results: List[trace_tree.Result],
        time_elapsed: float,
    ) -> trace_tree.WBTraceTree:
        """Converts the request, response, and results into a trace tree.

        params:
            request: The request dictionary
            response: The response object
            results: A list of results object
            time_elapsed: The time elapsed in seconds
        returns:
            A wandb trace tree object with a single LLM span.
        """
        # `created` is in seconds; trace spans use integer milliseconds.
        start_time_ms = int(round(response["created"] * 1000))
        end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
        span = trace_tree.Span(
            name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
            attributes=dict(response),
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
            span_kind=trace_tree.SpanKind.LLM,
            results=results,
        )
        model_obj = {"request": request, "response": response, "_kind": "openai"}
        return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)

    def _resolve_edit(
        self,
        request: Dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> Dict[str, Any]:
        """Resolves the request and response objects for `openai.Edit`."""
        # `input` is optional for edit requests, so don't require the key.
        request_str = (
            f"\n\n**Instruction**: {request['instruction']}\n\n"
            f"**Input**: {request.get('input', '')}\n"
        )
        choices = [
            f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
        ]

        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_completion(
        self,
        request: Dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> Dict[str, Any]:
        """Resolves the request and response objects for `openai.Completion`."""
        # `prompt` is optional for completion requests, so don't require it.
        request_str = f"\n\n**Prompt**: {request.get('prompt', '')}\n"
        choices = [
            f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
        ]

        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_chat_completion(
        self,
        request: Dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> Dict[str, Any]:
        """Resolves the request and response objects for `openai.ChatCompletion`."""
        # Render the whole conversation as "**role**: content" paragraphs.
        prompt = io.StringIO()
        for message in request["messages"]:
            prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
        request_str = prompt.getvalue()

        choices = [
            f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
            for choice in response["choices"]
        ]

        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_metrics(
        self,
        request: Dict[str, Any],
        response: Response,
        request_str: str,
        choices: List[str],
        time_elapsed: float,
    ) -> Dict[str, Any]:
        """Builds one trace result per rendered choice and returns the metrics dict.

        params:
            request: The request dictionary
            response: The response object
            request_str: Markdown rendering of the request
            choices: Markdown rendering of each returned choice
            time_elapsed: The time elapsed in seconds
        returns:
            The dict produced by `_convert_metrics_to_dict`.
        """
        results = [
            trace_tree.Result(
                inputs={"request": request_str},
                outputs={"response": choice},
            )
            for choice in choices
        ]
        metrics = self._get_metrics_to_log(request, response, results, time_elapsed)
        return self._convert_metrics_to_dict(metrics)

    @staticmethod
    def _get_usage_metrics(response: Response, time_elapsed: float) -> UsageMetrics:
        """Gets the usage stats from the response object.

        Keys in ``response["usage"]`` that are not `UsageMetrics` fields are
        ignored; missing fields stay ``None``.
        """
        usage = response.get("usage") or {}
        # The usage dict may carry more keys than the four tracked here; passing
        # them straight to the dataclass constructor would raise TypeError.
        known = {f.name for f in fields(UsageMetrics)}
        usage_stats = UsageMetrics(**{k: v for k, v in usage.items() if k in known})
        usage_stats.elapsed_time = time_elapsed
        return usage_stats

    def _get_metrics_to_log(
        self,
        request: Dict[str, Any],
        response: Response,
        results: List[Any],
        time_elapsed: float,
    ) -> Metrics:
        """Assembles the stats table, trace tree and usage metrics for one call.

        The stats table has one row per result; it is still created (with its
        columns but no rows) when there are no results.
        """
        model = response.get("model") or request.get("model")
        usage_metrics = self._get_usage_metrics(response, time_elapsed)
        usage_row = asdict(usage_metrics)

        # These values are the same for every row, so compute them once.
        start_time = datetime.datetime.fromtimestamp(response["created"])
        end_time = datetime.datetime.fromtimestamp(response["created"] + time_elapsed)
        request_id = response.get("id", None)
        api_type = response.get("api_type", "openai")
        session_id = wandb.run.id

        # Columns are defined up front so an empty result list does not break
        # table construction.
        columns = [
            "request",
            "response",
            "model",
            "start_time",
            "end_time",
            "request_id",
            "api_type",
            "session_id",
            *usage_row,
        ]
        data = [
            [
                result.inputs["request"],
                result.outputs["response"],
                model,
                start_time,
                end_time,
                request_id,
                api_type,
                session_id,
                *usage_row.values(),
            ]
            for result in results
        ]
        usage_table = wandb.Table(columns=columns, data=data)

        trace = self.results_to_trace_tree(request, response, results, time_elapsed)

        return Metrics(stats=usage_table, trace=trace, usage=usage_metrics)

    @staticmethod
    def _convert_metrics_to_dict(metrics: Metrics) -> Dict[str, Any]:
        """Converts metrics to a flat dict; usage fields get a ``usage/`` prefix."""
        metrics_dict = {
            "stats": metrics.stats,
            "trace": metrics.trace,
        }
        usage_stats = {f"usage/{k}": v for k, v in asdict(metrics.usage).items()}
        metrics_dict.update(usage_stats)
        return metrics_dict
|
|