File size: 8,256 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
import itertools
import logging
import re
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple
import mlflow
from packaging.version import Version # type: ignore
import wandb
from wandb import Artifact
from .internals import internal
from .internals.util import Namespace, for_each
# Parsed once at import time; several methods branch on this to handle
# MLflow API differences across versions (run_name, artifact download, etc.).
mlflow_version = Version(mlflow.__version__)
# Shared logger for the import pipeline.
logger = logging.getLogger("import_logger")
class MlflowRun:
    """Adapter exposing a single MLflow run through the interface the W&B
    importer (``internal.send_run``) consumes.

    Each method maps an MLflow concept (params, tags, metrics, artifacts)
    onto its W&B counterpart. Methods whose body is ``...`` are intentionally
    unimplemented: MLflow does not record that information.
    """

    def __init__(self, run, mlflow_client):
        # `run` is an mlflow.entities.Run; the client is kept for lazy
        # lookups (metric history, artifact downloads).
        self.run = run
        self.mlflow_client: mlflow.MlflowClient = mlflow_client

    def run_id(self) -> str:
        return self.run.info.run_id

    def entity(self) -> str:
        # MLflow has no entity concept; use the recorded user id.
        return self.run.info.user_id

    def project(self) -> str:
        # All imported runs land in one fixed project.
        return "imported-from-mlflow"

    def config(self) -> Dict[str, Any]:
        """Return MLflow params plus user tags as the W&B run config.

        Tags go into config (not W&B tags) because MLflow supports very long
        tag names while W&B tags are limited to 64 chars. Internal
        ``mlflow.*`` tags are dropped.
        """
        conf = self.run.data.params
        tags = {
            k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
        }
        return {**conf, "imported_mlflow_tags": tags}

    def summary(self) -> Dict[str, float]:
        return self.run.data.metrics

    def metrics(self) -> Iterable[Dict[str, float]]:
        """Yield one history row per step, in ascending step order.

        MLflow stores each metric's history independently; pivot them into
        step-keyed rows. Fix: iterate steps in sorted order so the emitted
        ``_step`` values are monotonically increasing, as W&B history
        expects — metric histories may interleave steps in arbitrary order.
        """
        rows: Dict[int, Dict[str, float]] = defaultdict(dict)
        for name in self.run.data.metrics:
            for item in self.mlflow_client.get_metric_history(
                self.run.info.run_id, name
            ):
                rows[item.step][item.key] = item.value
        for step in sorted(rows):
            yield {"_step": step, **rows[step]}

    def run_group(self) -> Optional[str]:
        # this is nesting? Parent at `run.info.tags.get("mlflow.parentRunId")`
        return f"Experiment {self.run.info.experiment_id}"

    def job_type(self) -> Optional[str]:
        # Is this the right approach?
        return f"User {self.run.info.user_id}"

    def display_name(self) -> str:
        # run_name moved from a tag to run.info in mlflow 1.30.0.
        if mlflow_version < Version("1.30.0"):
            return self.run.data.tags["mlflow.runName"]
        return self.run.info.run_name

    def notes(self) -> Optional[str]:
        return self.run.data.tags.get("mlflow.note.content")

    def tags(self) -> Optional[List[str]]:
        ...
        # W&B tags are different than mlflow tags.
        # The full mlflow tags are added to config under key `imported_mlflow_tags` instead

    def artifacts(self) -> Optional[Iterable["Artifact"]]:  # type: ignore
        """Download all run artifacts and wrap them in one W&B artifact."""
        # download_artifacts moved off the client in mlflow 2.0.
        if mlflow_version < Version("2.0.0"):
            dir_path = self.mlflow_client.download_artifacts(
                run_id=self.run.info.run_id,
                path="",
            )
        else:
            dir_path = mlflow.artifacts.download_artifacts(run_id=self.run.info.run_id)
        # Since mlflow doesn't have extra metadata about the artifacts,
        # we just lump them all together into a single wandb.Artifact
        artifact_name = self._handle_incompatible_strings(self.display_name())
        art = wandb.Artifact(artifact_name, "imported-artifacts")
        art.add_dir(dir_path)
        return [art]

    def used_artifacts(self) -> Optional[Iterable["Artifact"]]:  # type: ignore
        ...  # pragma: no cover

    def os_version(self) -> Optional[str]: ...  # pragma: no cover
    def python_version(self) -> Optional[str]: ...  # pragma: no cover
    def cuda_version(self) -> Optional[str]: ...  # pragma: no cover
    def program(self) -> Optional[str]: ...  # pragma: no cover
    def host(self) -> Optional[str]: ...  # pragma: no cover
    def username(self) -> Optional[str]: ...  # pragma: no cover
    def executable(self) -> Optional[str]: ...  # pragma: no cover
    def gpus_used(self) -> Optional[str]: ...  # pragma: no cover

    def cpus_used(self) -> Optional[int]:  # can we get the model?
        ...  # pragma: no cover

    def memory_used(self) -> Optional[int]: ...  # pragma: no cover

    def runtime(self) -> Optional[int]:
        """Wall-clock duration in seconds; 0 if the run never ended."""
        end_time = (
            self.run.info.end_time // 1000  # mlflow timestamps are in ms
            if self.run.info.end_time is not None
            else self.start_time()
        )
        return end_time - self.start_time()

    def start_time(self) -> Optional[int]:
        # ms -> s
        return self.run.info.start_time // 1000

    def code_path(self) -> Optional[str]: ...  # pragma: no cover
    def cli_version(self) -> Optional[str]: ...  # pragma: no cover
    def files(self) -> Optional[Iterable[Tuple[str, str]]]: ...  # pragma: no cover
    def logs(self) -> Optional[Iterable[str]]: ...  # pragma: no cover

    @staticmethod
    def _handle_incompatible_strings(s: str) -> str:
        """Replace chars W&B artifact names disallow with double underscore."""
        valid_chars = r"[^a-zA-Z0-9_\-\.]"
        replacement = "__"
        return re.sub(valid_chars, replacement, s)
class MlflowImporter:
    """Copies runs (and optionally their artifacts) from an MLflow tracking
    server into a W&B instance."""

    def __init__(
        self,
        dst_base_url: str,
        dst_api_key: str,
        mlflow_tracking_uri: str,
        mlflow_registry_uri: Optional[str] = None,
        *,
        custom_api_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.dst_base_url = dst_base_url
        self.dst_api_key = dst_api_key

        # Importer requests can be slow, so default to a generous timeout.
        api_kwargs = {"timeout": 600} if custom_api_kwargs is None else custom_api_kwargs
        self.dst_api = wandb.Api(
            api_key=dst_api_key,
            overrides={"base_url": dst_base_url},
            **api_kwargs,
        )

        self.mlflow_tracking_uri = mlflow_tracking_uri
        mlflow.set_tracking_uri(mlflow_tracking_uri)
        if mlflow_registry_uri:
            mlflow.set_registry_uri(mlflow_registry_uri)
        self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)

    def __repr__(self):
        return f"<MlflowImporter src={self.mlflow_tracking_uri}>"

    def collect_runs(self, *, limit: Optional[int] = None) -> Iterable[MlflowRun]:
        """Lazily yield up to `limit` runs across all MLflow experiments."""
        # list_experiments was replaced by search_experiments in mlflow 1.28.
        list_experiments = (
            self.mlflow_client.list_experiments
            if mlflow_version < Version("1.28.0")
            else self.mlflow_client.search_experiments
        )
        all_runs = (
            MlflowRun(r, self.mlflow_client)
            for exp in list_experiments()
            for r in self.mlflow_client.search_runs(exp.experiment_id)
        )
        yield from itertools.islice(all_runs, limit)

    def _import_run(
        self,
        run: MlflowRun,
        *,
        artifacts: bool = True,
        namespace: Optional[Namespace] = None,
        config: Optional[internal.SendManagerConfig] = None,
    ) -> None:
        """Send one MLflow run (and optionally its artifacts) to W&B."""
        ns = Namespace(run.entity(), run.project()) if namespace is None else namespace
        send_config = config
        if send_config is None:
            send_config = internal.SendManagerConfig(
                metadata=True,
                files=True,
                media=True,
                code=True,
                history=True,
                summary=True,
                terminal_output=True,
            )

        settings_override = {
            "api_key": self.dst_api_key,
            "base_url": self.dst_base_url,
            "resume": "true",
            "resumed": True,
        }

        # Re-point mlflow at our tracking server (worker threads/processes
        # may not inherit the setting from __init__).
        mlflow.set_tracking_uri(self.mlflow_tracking_uri)
        internal.send_run(
            run,
            overrides=ns.send_manager_overrides,
            settings_override=settings_override,
            config=send_config,
        )

        if not artifacts:
            return
        # in mlflow, the artifacts come with the runs, so import them together
        arts = list(run.artifacts())
        logger.debug(f"Importing history artifacts, {run=}")
        internal.send_run(
            run,
            extra_arts=arts,
            overrides=ns.send_manager_overrides,
            settings_override=settings_override,
            config=internal.SendManagerConfig(log_artifacts=True),
        )

    def import_runs(
        self,
        runs: Iterable[MlflowRun],
        *,
        artifacts: bool = True,
        namespace: Optional[Namespace] = None,
        parallel: bool = True,
        max_workers: Optional[int] = None,
    ) -> None:
        """Import many runs, optionally in parallel via `for_each`."""

        def _one(run):
            self._import_run(run, namespace=namespace, artifacts=artifacts)

        for_each(_one, runs, parallel=parallel, max_workers=max_workers)
|