File size: 1,748 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import json
import os
import re
import warnings
from typing import Any

from . import files as sm_files


def is_using_sagemaker() -> bool:
    """Report whether the current process looks like a SageMaker job.

    Detection succeeds when either the hyperparameter file at
    ``sm_files.SM_PARAM_CONFIG`` exists, or the ``SM_TRAINING_ENV``
    environment variable is defined (even if it is empty).
    """
    # Check the on-disk param file first, then fall back to the env var.
    param_file_present = os.path.exists(sm_files.SM_PARAM_CONFIG)
    if param_file_present:
        return True
    return "SM_TRAINING_ENV" in os.environ


def parse_sm_config() -> dict[str, Any]:
    """Parses SageMaker configuration.

    Two sources are merged in this order; later entries win on key
    collisions:

    1. The hyperparameter file at ``sm_files.SM_PARAM_CONFIG``, if it exists.
       When it does, ``"sagemaker_training_job_name"`` is set from the
       ``TRAINING_JOB_NAME`` environment variable (``None`` if unset).
       Each hyperparameter value is then normalized by
       ``_cast_hyperparameter``.
    2. The ``SM_TRAINING_ENV`` environment variable, if set and non-empty.
       It is parsed as a JSON object and merged in. If it is not valid JSON,
       or is valid JSON that is not an object, a warning is emitted and it
       is ignored.

    SM_TRAINING_ENV is a JSON string describing the training environment.
    Per the original notes, SageMaker sets it for training jobs, but not in
    local mode. It contains items such as hyperparameters and arguments
    passed to the training job.

    Returns:
        A dictionary of SageMaker config keys/values,
        or an empty dict if neither source is found.
    """
    conf: dict[str, Any] = {}

    if os.path.exists(sm_files.SM_PARAM_CONFIG):
        conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")

        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(sm_files.SM_PARAM_CONFIG, encoding="utf-8") as fid:
            params = json.load(fid)
        for key, val in params.items():
            conf[key] = _cast_hyperparameter(val)

    if env := os.environ.get("SM_TRAINING_ENV"):
        try:
            training_env = json.loads(env)
        except json.JSONDecodeError:
            warnings.warn(
                "Failed to parse SM_TRAINING_ENV not valid JSON string",
                stacklevel=2,
            )
        else:
            # Only a JSON object can be merged into the config dict; any
            # other JSON value would raise in dict.update or merge garbage.
            if isinstance(training_env, dict):
                conf.update(training_env)
            else:
                warnings.warn(
                    "SM_TRAINING_ENV is not a JSON object; ignoring it",
                    stacklevel=2,
                )

    return conf


def _cast_hyperparameter(val: Any) -> Any:
    """Convert one raw hyperparameter value to int/float when it looks numeric.

    Hyperparameter searches may quote config values, so surrounding double
    quotes are stripped first. Non-string values are returned unchanged.
    Strings that look numeric but are not valid floats (e.g. ``"1.2.3"`` or
    ``"."``) are kept as the quote-stripped string instead of raising.
    """
    if not isinstance(val, str):
        return val
    cast = val.strip('"')
    if re.match(r"^-?[\d]+$", cast):
        return int(cast)
    if re.match(r"^-?[.\d]+$", cast):
        try:
            return float(cast)
        except ValueError:
            # The loose pattern above admits strings like "1.2.3" or "-.".
            return cast
    return cast