Skip to content

Commit

Permalink
fix: Create workflow module scoped sagemaker_session to resolve test …
Browse files Browse the repository at this point in the history
…race condition (aws#4518)
  • Loading branch information
qidewenwhen committed Mar 21, 2024
1 parent 345381e commit b82fb74
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 240 deletions.
4 changes: 2 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __init__(
sagemaker_runtime_client=None,
sagemaker_featurestore_runtime_client=None,
default_bucket=None,
settings=SessionSettings(),
settings=None,
sagemaker_metrics_client=None,
sagemaker_config: dict = None,
default_bucket_prefix: str = None,
Expand Down Expand Up @@ -260,7 +260,7 @@ def __init__(
self.resource_group_tagging_client = None
self._config = None
self.lambda_client = None
self.settings = settings
self.settings = settings if settings else SessionSettings()

self._initialize(
boto_session=boto_session,
Expand Down
65 changes: 65 additions & 0 deletions tests/integ/sagemaker/workflow/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import pytest
from botocore.config import Config

from tests.integ import DATA_DIR
from sagemaker import Session, get_execution_role

CUSTOM_S3_OBJECT_KEY_PREFIX = "session-default-prefix"


# The workflow tests get their own module-scoped session so they do not race
# with other tests, some of which may change the session `settings`.
@pytest.fixture(scope="module")
def sagemaker_session_for_pipeline(
    sagemaker_client_config,
    boto_session,
):
    """Build a ``Session`` dedicated to the workflow integration tests.

    Args:
        sagemaker_client_config: Keyword arguments forwarded to
            ``boto_session.client("sagemaker", ...)``. May be empty or None.
        boto_session: Session used to create the SageMaker client and
            handed to the returned ``Session``.

    Returns:
        A ``Session`` with an empty ``sagemaker_config`` and
        ``CUSTOM_S3_OBJECT_KEY_PREFIX`` as its default bucket prefix.
    """
    # Work on a copy: the incoming dict is supplied by another fixture, so
    # adding the default retry config in place would leak into every other
    # consumer of that fixture.
    client_kwargs = dict(sagemaker_client_config or {})
    client_kwargs.setdefault("config", Config(retries=dict(max_attempts=10)))

    # The kwargs always contain at least "config" at this point, so the
    # client is created unconditionally.
    sagemaker_client = boto_session.client("sagemaker", **client_kwargs)

    return Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_config={},
        default_bucket_prefix=CUSTOM_S3_OBJECT_KEY_PREFIX,
    )


@pytest.fixture(scope="module")
def smclient(sagemaker_session):
    """Return a "sagemaker" client created from the session's boto session."""
    boto_session = sagemaker_session.boto_session
    return boto_session.client("sagemaker")


@pytest.fixture(scope="module")
def role(sagemaker_session_for_pipeline):
    """Return the execution role looked up through the workflow session."""
    execution_role = get_execution_role(sagemaker_session_for_pipeline)
    return execution_role


@pytest.fixture(scope="module")
def region_name(sagemaker_session_for_pipeline):
    """Return the region name of the workflow session's boto session."""
    boto_session = sagemaker_session_for_pipeline.boto_session
    return boto_session.region_name


@pytest.fixture(scope="module")
def script_dir():
    """Return the path of the ``sklearn_processing`` folder under ``DATA_DIR``."""
    sklearn_processing_dir = os.path.join(DATA_DIR, "sklearn_processing")
    return sklearn_processing_dir
37 changes: 8 additions & 29 deletions tests/integ/sagemaker/workflow/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution
from sagemaker.processing import ProcessingInput
from sagemaker.session import get_execution_role
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
from sagemaker.workflow.execution_variables import ExecutionVariables
Expand All @@ -33,33 +32,13 @@
from tests.integ import DATA_DIR


@pytest.fixture(scope="module")
def region_name(sagemaker_session):
return sagemaker_session.boto_session.region_name


@pytest.fixture(scope="module")
def role(sagemaker_session):
return get_execution_role(sagemaker_session)


@pytest.fixture(scope="module")
def script_dir():
return os.path.join(DATA_DIR, "sklearn_processing")


@pytest.fixture
def pipeline_name():
    """Return a pipeline name whose suffix is derived from the current time."""
    # Scaling the float timestamp by 10**7 before truncating keeps
    # sub-second digits in the integer suffix.
    time_suffix = int(time.time() * 10**7)
    return f"my-pipeline-{time_suffix}"


@pytest.fixture
def smclient(sagemaker_session):
return sagemaker_session.boto_session.client("sagemaker")


@pytest.fixture
def athena_dataset_definition(sagemaker_session):
def athena_dataset_definition(sagemaker_session_for_pipeline):
return DatasetDefinition(
local_path="/opt/ml/processing/input/add",
data_distribution_type="FullyReplicated",
Expand All @@ -69,15 +48,15 @@ def athena_dataset_definition(sagemaker_session):
database="default",
work_group="workgroup",
query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";',
output_s3_uri=f"s3://{sagemaker_session.default_bucket()}/add",
output_s3_uri=f"s3://{sagemaker_session_for_pipeline.default_bucket()}/add",
output_format="JSON",
output_compression="GZIP",
),
)


def test_pipeline_execution_with_default_experiment_config(
sagemaker_session,
sagemaker_session_for_pipeline,
smclient,
role,
sklearn_latest_version,
Expand All @@ -99,7 +78,7 @@ def test_pipeline_execution_with_default_experiment_config(
instance_type=cpu_instance_type,
instance_count=instance_count,
command=["python3"],
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_session_for_pipeline,
base_job_name="test-sklearn",
)

Expand All @@ -113,7 +92,7 @@ def test_pipeline_execution_with_default_experiment_config(
name=pipeline_name,
parameters=[instance_count],
steps=[step_sklearn],
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_session_for_pipeline,
)

try:
Expand Down Expand Up @@ -142,7 +121,7 @@ def test_pipeline_execution_with_default_experiment_config(


def test_pipeline_execution_with_custom_experiment_config(
sagemaker_session,
sagemaker_session_for_pipeline,
smclient,
role,
sklearn_latest_version,
Expand All @@ -164,7 +143,7 @@ def test_pipeline_execution_with_custom_experiment_config(
instance_type=cpu_instance_type,
instance_count=instance_count,
command=["python3"],
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_session_for_pipeline,
base_job_name="test-sklearn",
)

Expand All @@ -185,7 +164,7 @@ def test_pipeline_execution_with_custom_experiment_config(
trial_name=Join(on="-", values=["my-trial", ExecutionVariables.PIPELINE_EXECUTION_ID]),
),
steps=[step_sklearn],
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_session_for_pipeline,
)

try:
Expand Down
Loading

0 comments on commit b82fb74

Please sign in to comment.