Skip to content

Commit

Permalink
feature: Support for environment variables in the HPO (#3614)
Browse files Browse the repository at this point in the history
Co-authored-by: Anton Repushko <[email protected]>
  • Loading branch information
repushko and Anton Repushko committed Jan 27, 2023
1 parent 75d1f2c commit cee70dc
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,7 @@ def tune( # noqa: C901
checkpoint_s3_uri=None,
checkpoint_local_path=None,
random_seed=None,
environment=None,
):
"""Create an Amazon SageMaker hyperparameter tuning job.
Expand Down Expand Up @@ -2283,6 +2284,8 @@ def tune( # noqa: C901
random_seed (int): An initial value used to initialize a pseudo-random number generator.
Setting a random seed will make the hyperparameter tuning search strategies to
produce more consistent configurations for the same tuning job. (default: ``None``).
environment (dict[str, str]) : Environment variables to be set for
use during training jobs (default: ``None``)
"""

tune_request = {
Expand Down Expand Up @@ -2315,6 +2318,7 @@ def tune( # noqa: C901
use_spot_instances=use_spot_instances,
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
environment=environment,
),
}

Expand Down Expand Up @@ -2558,6 +2562,7 @@ def _map_training_config(
checkpoint_s3_uri=None,
checkpoint_local_path=None,
max_retry_attempts=None,
environment=None,
):
"""Construct a dictionary of training job configuration from the arguments.
Expand Down Expand Up @@ -2612,6 +2617,8 @@ def _map_training_config(
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
be one of three types: Continuous, Integer, or Categorical.
max_retry_attempts (int): The number of times to retry the job.
environment (dict[str, str]) : Environment variables to be set for
use during training jobs (default: ``None``)
Returns:
A dictionary of training job configuration. For format details, please refer to
Expand Down Expand Up @@ -2674,6 +2681,9 @@ def _map_training_config(

if max_retry_attempts is not None:
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}

if environment is not None:
training_job_definition["Environment"] = environment
return training_job_definition

def stop_tuning_job(self, name):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,6 +1892,9 @@ def _prepare_training_config(
if estimator.max_retry_attempts is not None:
training_config["max_retry_attempts"] = estimator.max_retry_attempts

if estimator.environment is not None:
training_config["environment"] = estimator.environment

return training_config

def stop(self):
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
}

Expand Down Expand Up @@ -957,6 +958,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
{
"DefinitionName": "estimator_2",
Expand All @@ -973,6 +975,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
],
}
Expand Down Expand Up @@ -1032,6 +1035,7 @@ def assert_create_tuning_job_request(**kwrags):
warm_start_config=WarmStartConfig(
warm_start_type=WarmStartTypes(warm_start_type), parents=parents
).to_input_req(),
environment=ENV_INPUT,
)


Expand Down Expand Up @@ -1122,6 +1126,7 @@ def assert_create_tuning_job_request(**kwrags):
"output_config": SAMPLE_OUTPUT,
"resource_config": RESOURCE_CONFIG,
"stop_condition": SAMPLE_STOPPING_CONDITION,
"environment": ENV_INPUT,
},
tags=None,
warm_start_config=None,
Expand Down Expand Up @@ -1163,6 +1168,7 @@ def assert_create_tuning_job_request(**kwrags):
"objective_type": "Maximize",
"objective_metric_name": "val-score",
"parameter_ranges": SAMPLE_PARAM_RANGES,
"environment": ENV_INPUT,
},
{
"static_hyperparameters": STATIC_HPs_2,
Expand All @@ -1178,6 +1184,7 @@ def assert_create_tuning_job_request(**kwrags):
"objective_type": "Maximize",
"objective_metric_name": "value-score",
"parameter_ranges": SAMPLE_PARAM_RANGES_2,
"environment": ENV_INPUT,
},
],
tags=None,
Expand Down Expand Up @@ -1218,6 +1225,7 @@ def assert_create_tuning_job_request(**kwrags):
stop_condition=SAMPLE_STOPPING_CONDITION,
tags=None,
warm_start_config=None,
environment=ENV_INPUT,
)


Expand Down Expand Up @@ -1259,6 +1267,7 @@ def assert_create_tuning_job_request(**kwrags):
tags=None,
warm_start_config=None,
strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
environment=ENV_INPUT,
)


Expand Down
8 changes: 8 additions & 0 deletions tests/unit/tuner_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
ESTIMATOR_NAME = "estimator_name"
ESTIMATOR_NAME_TWO = "estimator_name_two"

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

SAGEMAKER_SESSION = Mock()

ESTIMATOR = Estimator(
Expand All @@ -78,13 +80,15 @@
INSTANCE_TYPE,
output_path="s3://bucket/prefix",
sagemaker_session=SAGEMAKER_SESSION,
environment=ENV_INPUT,
)
ESTIMATOR_TWO = PCA(
ROLE,
INSTANCE_COUNT,
INSTANCE_TYPE,
NUM_COMPONENTS,
sagemaker_session=SAGEMAKER_SESSION,
environment=ENV_INPUT,
)

WARM_START_CONFIG = WarmStartConfig(
Expand Down Expand Up @@ -148,6 +152,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
"TrainingJobCounters": {
"ClientError": 0,
Expand Down Expand Up @@ -212,6 +217,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
{
"DefinitionName": ESTIMATOR_NAME_TWO,
Expand Down Expand Up @@ -252,6 +258,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
],
"TrainingJobCounters": {
Expand Down Expand Up @@ -291,6 +298,7 @@
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA},
"Environment": ENV_INPUT,
}

ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}
Expand Down

0 comments on commit cee70dc

Please sign in to comment.