jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from __future__ import annotations
import json
import os
import re
import warnings
from typing import Any
from . import files as sm_files
def is_using_sagemaker() -> bool:
"""Returns whether we're in a SageMaker environment."""
return (
os.path.exists(sm_files.SM_PARAM_CONFIG) #
or "SM_TRAINING_ENV" in os.environ
)
def parse_sm_config() -> dict[str, Any]:
"""Parses SageMaker configuration.
Returns:
A dictionary of SageMaker config keys/values
or an empty dict if not found.
SM_TRAINING_ENV is a json string of the
training environment variables set by SageMaker
and is only available when running in SageMaker,
but not in local mode.
SM_TRAINING_ENV is set by the SageMaker container and
contains arguments such as hyperparameters
and arguments passed to the training job.
"""
conf = {}
if os.path.exists(sm_files.SM_PARAM_CONFIG):
conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
# Hyperparameter searches quote configs...
with open(sm_files.SM_PARAM_CONFIG) as fid:
for key, val in json.load(fid).items():
cast = val.strip('"')
if re.match(r"^-?[\d]+$", cast):
cast = int(cast)
elif re.match(r"^-?[.\d]+$", cast):
cast = float(cast)
conf[key] = cast
if env := os.environ.get("SM_TRAINING_ENV"):
try:
conf.update(json.loads(env))
except json.JSONDecodeError:
warnings.warn(
"Failed to parse SM_TRAINING_ENV not valid JSON string",
stacklevel=2,
)
return conf