Skip to content

Commit

Permalink
feature: support of the intelligent stopping in the tuner (aws#3652)
Browse files Browse the repository at this point in the history
Co-authored-by: Anton Repushko <[email protected]>
  • Loading branch information
repushko and Anton Repushko committed Feb 13, 2023
1 parent 4046ae4 commit 5cf3e44
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 4 deletions.
20 changes: 20 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,9 @@ def tune( # noqa: C901
stop_condition,
tags,
warm_start_config,
max_runtime_in_seconds=None,
strategy_config=None,
completion_criteria_config=None,
enable_network_isolation=False,
image_uri=None,
algorithm_arn=None,
Expand Down Expand Up @@ -2256,6 +2258,10 @@ def tune( # noqa: C901
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
warm_start_config (dict): Configuration defining the type of warm start and
other required configurations.
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
    that the hyperparameter tuning job can run before being stopped
    (maps to ``ResourceLimits.MaxRuntimeInSeconds``).
completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
configuration for the completion criteria.
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
attempted. If set to 'Auto', early stopping of some training jobs may happen, but
Expand Down Expand Up @@ -2311,12 +2317,14 @@ def tune( # noqa: C901
strategy=strategy,
max_jobs=max_jobs,
max_parallel_jobs=max_parallel_jobs,
max_runtime_in_seconds=max_runtime_in_seconds,
objective_type=objective_type,
objective_metric_name=objective_metric_name,
parameter_ranges=parameter_ranges,
early_stopping_type=early_stopping_type,
random_seed=random_seed,
strategy_config=strategy_config,
completion_criteria_config=completion_criteria_config,
),
"TrainingJobDefinition": self._map_training_config(
static_hyperparameters=static_hyperparameters,
Expand Down Expand Up @@ -2470,12 +2478,14 @@ def _map_tuning_config(
strategy,
max_jobs,
max_parallel_jobs,
max_runtime_in_seconds=None,
early_stopping_type="Off",
objective_type=None,
objective_metric_name=None,
parameter_ranges=None,
random_seed=None,
strategy_config=None,
completion_criteria_config=None,
):
"""Construct tuning job configuration dictionary.
Expand All @@ -2484,6 +2494,8 @@ def _map_tuning_config(
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
tuning job.
max_parallel_jobs (int): Maximum number of parallel training jobs to start.
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
    that the hyperparameter tuning job can run before being stopped
    (maps to ``ResourceLimits.MaxRuntimeInSeconds``).
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
attempted. If set to 'Auto', early stopping of some training jobs may happen,
Expand All @@ -2498,6 +2510,8 @@ def _map_tuning_config(
produce more consistent configurations for the same tuning job.
strategy_config (dict): A configuration for the hyperparameter tuning job optimization
    strategy.
completion_criteria_config (dict): A configuration
for the completion criteria.
Returns:
A dictionary of tuning job configuration. For format details, please refer to
Expand All @@ -2514,6 +2528,9 @@ def _map_tuning_config(
"TrainingJobEarlyStoppingType": early_stopping_type,
}

if max_runtime_in_seconds is not None:
tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds

if random_seed is not None:
tuning_config["RandomSeed"] = random_seed

Expand All @@ -2526,6 +2543,9 @@ def _map_tuning_config(

if strategy_config is not None:
tuning_config["StrategyConfig"] = strategy_config

if completion_criteria_config is not None:
tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
return tuning_config

@classmethod
Expand Down
Loading

0 comments on commit 5cf3e44

Please sign in to comment.