diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index 34289beb30..90a272c4d7 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -26,94 +26,147 @@ class RedshiftDatasetDefinition(ApiObject): """DatasetDefinition for Redshift. With this input, SQL queries will be executed using Redshift to generate datasets to S3. - - Parameters: - cluster_id (str): The Redshift cluster Identifier. - database (str): The name of the Redshift database used in Redshift query execution. - db_user (str): The database user name used in Redshift query execution. - query_string (str): The SQL query statements to be executed. - cluster_role_arn (str): The IAM role attached to your Redshift cluster that - Amazon SageMaker uses to generate datasets. - output_s3_uri (str): The location in Amazon S3 where the Redshift query - results are stored. - kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon - SageMaker uses to encrypt data from a Redshift execution. - output_format (str): The data storage format for Redshift query results. - Valid options are "PARQUET", "CSV" - output_compression (str): The compression used for Redshift query results. - Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2" """ - cluster_id = None - database = None - db_user = None - query_string = None - cluster_role_arn = None - output_s3_uri = None - kms_key_id = None - output_format = None - output_compression = None + def __init__( + self, + cluster_id=None, + database=None, + db_user=None, + query_string=None, + cluster_role_arn=None, + output_s3_uri=None, + kms_key_id=None, + output_format=None, + output_compression=None, + ): + """Initialize RedshiftDatasetDefinition. + + Args: + cluster_id (str, default=None): The Redshift cluster Identifier. + database (str, default=None): + The name of the Redshift database used in Redshift query execution. + db_user (str, default=None): The database user name used in Redshift query execution. + query_string (str, default=None): The SQL query statements to be executed. + cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster + that Amazon SageMaker uses to generate datasets. + output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query + results are stored. + kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon + SageMaker uses to encrypt data from a Redshift execution. + output_format (str, default=None): The data storage format for Redshift query results. + Valid options are "PARQUET", "CSV" + output_compression (str, default=None): The compression used for Redshift query results. + Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2" + """ + super(RedshiftDatasetDefinition, self).__init__( + cluster_id=cluster_id, + database=database, + db_user=db_user, + query_string=query_string, + cluster_role_arn=cluster_role_arn, + output_s3_uri=output_s3_uri, + kms_key_id=kms_key_id, + output_format=output_format, + output_compression=output_compression, + ) class AthenaDatasetDefinition(ApiObject): """DatasetDefinition for Athena. With this input, SQL queries will be executed using Athena to generate datasets to S3. - - Parameters: - catalog (str): The name of the data catalog used in Athena query execution. - database (str): The name of the database used in the Athena query execution. - query_string (str): The SQL query statements, to be executed. - output_s3_uri (str): The location in Amazon S3 where Athena query results are stored. - work_group (str): The name of the workgroup in which the Athena query is being started. - kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon - SageMaker uses to encrypt data generated from an Athena query execution. - output_format (str): The data storage format for Athena query results. - Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE" - output_compression (str): The compression used for Athena query results. - Valid options are "GZIP", "SNAPPY", "ZLIB" """ - catalog = None - database = None - query_string = None - output_s3_uri = None - work_group = None - kms_key_id = None - output_format = None - output_compression = None + def __init__( + self, + catalog=None, + database=None, + query_string=None, + output_s3_uri=None, + work_group=None, + kms_key_id=None, + output_format=None, + output_compression=None, + ): + """Initialize AthenaDatasetDefinition. + + Args: + catalog (str, default=None): The name of the data catalog used in Athena query + execution. + database (str, default=None): The name of the database used in the Athena query + execution. + query_string (str, default=None): The SQL query statements, to be executed. + output_s3_uri (str, default=None): + The location in Amazon S3 where Athena query results are stored. + work_group (str, default=None): + The name of the workgroup in which the Athena query is being started. + kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon + SageMaker uses to encrypt data generated from an Athena query execution. + output_format (str, default=None): The data storage format for Athena query results. + Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE" + output_compression (str, default=None): The compression used for Athena query results. + Valid options are "GZIP", "SNAPPY", "ZLIB" + """ + super(AthenaDatasetDefinition, self).__init__( + catalog=catalog, + database=database, + query_string=query_string, + output_s3_uri=output_s3_uri, + work_group=work_group, + kms_key_id=kms_key_id, + output_format=output_format, + output_compression=output_compression, + ) class DatasetDefinition(ApiObject): - """DatasetDefinition input. - - Parameters: - data_distribution_type (str): Whether the generated dataset is FullyReplicated or - ShardedByS3Key (default). - input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon - SageMaker copies the data from the input source onto the local Amazon Elastic Block - Store (Amazon EBS) volumes before starting your training algorithm. This is the most - commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the - source directly to your algorithm without using the EBS volume. - local_path (str): The local path where you want Amazon SageMaker to download the Dataset - Definition inputs to run a processing job. LocalPath is an absolute path to the input - data. This is a required parameter when `AppManaged` is False (default). - redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`): - Configuration for Redshift Dataset Definition input. - athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`): - Configuration for Athena Dataset Definition input. - """ + """DatasetDefinition input.""" _custom_boto_types = { "redshift_dataset_definition": (RedshiftDatasetDefinition, True), "athena_dataset_definition": (AthenaDatasetDefinition, True), } - data_distribution_type = "ShardedByS3Key" - input_mode = "File" - local_path = None - redshift_dataset_definition = None - athena_dataset_definition = None + def __init__( + self, + data_distribution_type="ShardedByS3Key", + input_mode="File", + local_path=None, + redshift_dataset_definition=None, + athena_dataset_definition=None, + ): + """Initialize DatasetDefinition. + + Parameters: + data_distribution_type (str, default="ShardedByS3Key"): + Whether the generated dataset is FullyReplicated or ShardedByS3Key (default). + input_mode (str, default="File"): + Whether to use File or Pipe input mode. In File (default) mode, Amazon + SageMaker copies the data from the input source onto the local Amazon Elastic Block + Store (Amazon EBS) volumes before starting your training algorithm. This is the most + commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the + source directly to your algorithm without using the EBS volume. + local_path (str, default=None): + The local path where you want Amazon SageMaker to download the Dataset + Definition inputs to run a processing job. LocalPath is an absolute path to the + input data. This is a required parameter when `AppManaged` is False (default). + redshift_dataset_definition + (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`, + default=None): + Configuration for Redshift Dataset Definition input. + athena_dataset_definition + (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`, + default=None): + Configuration for Athena Dataset Definition input. + """ + super(DatasetDefinition, self).__init__( + data_distribution_type=data_distribution_type, + input_mode=input_mode, + local_path=local_path, + redshift_dataset_definition=redshift_dataset_definition, + athena_dataset_definition=athena_dataset_definition, + ) class S3Input(ApiObject): @@ -124,20 +177,35 @@ class S3Input(ApiObject): Note: Strong consistency is not guaranteed if S3Prefix is provided here. S3 list operations are not strongly consistent. Use ManifestFile if strong consistency is required. - - Parameters: - s3_uri (str): the path to a specific S3 object or a S3 prefix - local_path (str): the path to a local directory. If not provided, skips data download - by SageMaker platform. - s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix". - s3_input_mode (str): Valid options are "Pipe" or "File". - s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key". - s3_compression_type (str): Valid options are "None" or "Gzip". """ - s3_uri = None - local_path = None - s3_data_type = "S3Prefix" - s3_input_mode = "File" - s3_data_distribution_type = "FullyReplicated" - s3_compression_type = None + def __init__( + self, + s3_uri=None, + local_path=None, + s3_data_type="S3Prefix", + s3_input_mode="File", + s3_data_distribution_type="FullyReplicated", + s3_compression_type=None, + ): + """Initialize S3Input. + + Parameters: + s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix + local_path (str, default=None): + the path to a local directory. If not provided, skips data download + by SageMaker platform. + s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix". + s3_input_mode (str, default="File"): Valid options are "Pipe" or "File". + s3_data_distribution_type (str, default="FullyReplicated"): + Valid options are "FullyReplicated" or "ShardedByS3Key". + s3_compression_type (str, default=None): Valid options are "None" or "Gzip". + """ + super(S3Input, self).__init__( + s3_uri=s3_uri, + local_path=local_path, + s3_data_type=s3_data_type, + s3_input_mode=s3_input_mode, + s3_data_distribution_type=s3_data_distribution_type, + s3_compression_type=s3_compression_type, + ) diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 60b74d66c7..6ce7e41831 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -62,15 +62,15 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): with tarfile.open(name=local_path, mode="r:gz") as tf: tf.extractall(path=src_dir) - # copy the custom inference script to code/ - entry_point = os.path.join("/opt/ml/code", inference_script) - shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script)) - - # copy source_dir to code/ if source_dir: + # copy /opt/ml/code to code/ if os.path.exists(code_dir): shutil.rmtree(code_dir) - shutil.copytree(source_dir, code_dir) + shutil.copytree("/opt/ml/code", code_dir) + else: + # copy the custom inference script to code/ + entry_point = os.path.join("/opt/ml/code", inference_script) + shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) # copy any dependencies to code/lib/ if dependencies: @@ -79,13 +79,16 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): lib_dir = os.path.join(code_dir, "lib") if not os.path.exists(lib_dir): os.mkdir(lib_dir) - if os.path.isdir(actual_dependency_path): - shutil.copytree( - actual_dependency_path, - os.path.join(lib_dir, os.path.basename(actual_dependency_path)), - ) - else: + if os.path.isfile(actual_dependency_path): shutil.copy2(actual_dependency_path, lib_dir) + else: + if os.path.exists(lib_dir): + shutil.rmtree(lib_dir) + # a directory is in the dependencies. we have to copy + # all of /opt/ml/code into the lib dir because the original directory + # was flattened by the SDK training job upload.. + shutil.copytree("/opt/ml/code", lib_dir) + break # copy the "src" dir, which includes the previous training job's model and the # custom inference script, to the output of this training job diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 6975c6ca97..dd81553a02 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -14,8 +14,10 @@ from __future__ import absolute_import import abc +import warnings from enum import Enum from typing import Dict, List, Union +from urllib.parse import urlparse import attr @@ -270,6 +272,16 @@ def __init__( ) self.cache_config = cache_config + if self.cache_config is not None and not self.estimator.disable_profiler: + msg = ( + "Profiling is enabled on the provided estimator. " + "The default profiler rule includes a timestamp " + "which will change each time the pipeline is " + "upserted, causing cache misses. If profiling " + "is not needed, set disable_profiler to True on the estimator." + ) + warnings.warn(msg) + @property def arguments(self) -> RequestType: """The arguments dict that is used to call `create_training_job`. @@ -498,6 +510,7 @@ def __init__( self.job_arguments = job_arguments self.code = code self.property_files = property_files + self.job_name = None # Examine why run method in sagemaker.processing.Processor mutates the processor instance # by setting the instance's arguments attribute. Refactor Processor.run, if possible. @@ -508,6 +521,17 @@ def __init__( ) self.cache_config = cache_config + if code: + code_url = urlparse(code) + if code_url.scheme == "" or code_url.scheme == "file": + # By default, Processor will upload the local code to an S3 path + # containing a timestamp. This causes cache misses whenever a + # pipeline is updated, even if the underlying script hasn't changed. + # To avoid this, hash the contents of the script and include it + # in the job_name passed to the Processor, which will be used + # instead of the timestamped path. + self.job_name = self._generate_code_upload_path() + @property def arguments(self) -> RequestType: """The arguments dict that is used to call `create_processing_job`. @@ -516,6 +540,7 @@ def arguments(self) -> RequestType: ProcessingJobName and ExperimentConfig cannot be included in the arguments. """ normalized_inputs, normalized_outputs = self.processor._normalize_args( + job_name=self.job_name, arguments=self.job_arguments, inputs=self.inputs, outputs=self.outputs, @@ -546,6 +571,13 @@ def to_request(self) -> RequestType: ] return request_dict + def _generate_code_upload_path(self) -> str: + """Generate an upload path for local processing scripts based on its contents""" + from sagemaker.workflow.utilities import hash_file + + code_hash = hash_file(self.code) + return f"{self.name}-{code_hash}"[:1024] + class TuningStep(ConfigurableRetryStep): """Tuning step for workflow.""" diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 069894d761..3e77465ff6 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import List, Sequence, Union +import hashlib from sagemaker.workflow.entities import ( Entity, @@ -37,3 +38,23 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R elif isinstance(entity, StepCollection): request_dicts.extend(entity.request_dicts()) return request_dicts + + +def hash_file(path: str) -> str: + """Get the MD5 hash of a file. + + Args: + path (str): The local path for the file. + Returns: + str: The MD5 hash of the file. + """ + BUF_SIZE = 65536 # read in 64KiB chunks + md5 = hashlib.md5() + with open(path, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + + return md5.hexdigest() diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index 337d88af59..8ceb3f2195 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -747,6 +747,14 @@ def _get_processing_inputs_with_all_parameters(bucket): destination="/opt/ml/processing/input/data/", input_name="my_dataset", ), + ProcessingInput( + input_name="s3_input_wo_defaults", + s3_input=S3Input( + s3_uri=f"s3://{bucket}", + local_path="/opt/ml/processing/input/s3_input_wo_defaults", + s3_data_type="S3Prefix", + ), + ), ProcessingInput( input_name="s3_input", s3_input=S3Input( @@ -822,6 +830,17 @@ def _get_processing_job_inputs_and_outputs(bucket, output_kms_key): "S3CompressionType": "None", }, }, + { + "InputName": "s3_input_wo_defaults", + "AppManaged": False, + "S3Input": { + "S3Uri": f"s3://{bucket}", + "LocalPath": "/opt/ml/processing/input/s3_input_wo_defaults", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + }, + }, { "InputName": "s3_input", "AppManaged": False, diff --git a/tests/unit/sagemaker/workflow/test_repack_model_script.py b/tests/unit/sagemaker/workflow/test_repack_model_script.py index 67c8231dcc..69c9e7b740 100644 --- a/tests/unit/sagemaker/workflow/test_repack_model_script.py +++ b/tests/unit/sagemaker/workflow/test_repack_model_script.py @@ -94,7 +94,7 @@ def test_repack_with_dependencies(tmp): _repack_model.repack( inference_script="inference.py", model_archive=model_tar_name, - dependencies=["dependencies/a", "bb", "dependencies/some/dir"], + dependencies="dependencies/a bb dependencies/some/dir", ) # /opt/ml/model should now have the original model and the inference script @@ -145,7 +145,7 @@ def test_repack_with_source_dir_and_dependencies(tmp): _repack_model.repack( inference_script="inference.py", model_archive=model_tar_name, - dependencies=["dependencies/a", "bb", "dependencies/some/dir"], + dependencies="dependencies/a bb dependencies/some/dir", source_dir="sourcedir", ) diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 42c3bed7b6..3c2adc7bd9 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -16,6 +16,7 @@ import pytest import sagemaker import os +import warnings from mock import ( Mock, @@ -63,8 +64,7 @@ ) from tests.unit import DATA_DIR -SCRIPT_FILE = "dummy_script.py" -SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE) +DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") REGION = "us-west-2" BUCKET = "my-bucket" @@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client): ) +@pytest.fixture +def script_processor(sagemaker_session): + return ScriptProcessor( + role=ROLE, + image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", + command=["python3"], + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=100, + volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", + output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", + max_runtime_in_seconds=3600, + base_job_name="my_sklearn_processor", + env={"my_env_variable": "my_env_variable_value"}, + tags=[{"Key": "my-tag", "Value": "my-tag-value"}], + network_config=NetworkConfig( + subnets=["my_subnet_id"], + security_group_ids=["my_security_group_id"], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ), + sagemaker_session=sagemaker_session, + ) + + def test_custom_step(): step = CustomStep( name="MyStep", display_name="CustomStepDisplayName", description="CustomStepDescription" @@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session): training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5) training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500) estimator = TensorFlow( - entry_point=os.path.join(DATA_DIR, SCRIPT_FILE), + entry_point=DUMMY_SCRIPT_PATH, role=ROLE, model_dir=False, image_uri=IMAGE_URI, @@ -403,6 +428,75 @@ def test_training_step_tensorflow(sagemaker_session): assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} +def test_training_step_profiler_warning(sagemaker_session): + estimator = TensorFlow( + entry_point=DUMMY_SCRIPT_PATH, + role=ROLE, + model_dir=False, + image_uri=IMAGE_URI, + source_dir="s3://mybucket/source", + framework_version="2.4.1", + py_version="py37", + disable_profiler=False, + instance_count=1, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + debugger_hook_config=False, + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ) + + inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + with warnings.catch_warnings(record=True) as w: + TrainingStep( + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config + ) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Profiling is enabled on the provided estimator" in str(w[-1].message) + + +def test_training_step_no_profiler_warning(sagemaker_session): + estimator = TensorFlow( + entry_point=DUMMY_SCRIPT_PATH, + role=ROLE, + model_dir=False, + image_uri=IMAGE_URI, + source_dir="s3://mybucket/source", + framework_version="2.4.1", + py_version="py37", + disable_profiler=True, + instance_count=1, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + debugger_hook_config=False, + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ) + + inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + with warnings.catch_warnings(record=True) as w: + # profiler disabled, cache config not None + TrainingStep( + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config + ) + assert len(w) == 0 + + with warnings.catch_warnings(record=True) as w: + # profiler enabled, cache config is None + estimator.disable_profiler = False + TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None) + assert len(w) == 0 + + def test_processing_step(sagemaker_session): processing_input_data_uri_parameter = ParameterString( name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest" @@ -473,28 +567,42 @@ def test_processing_step(sagemaker_session): @patch("sagemaker.processing.ScriptProcessor._normalize_args") -def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session): - processor = ScriptProcessor( - role=ROLE, - image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", - command=["python3"], - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=100, - volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", - output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", - max_runtime_in_seconds=3600, - base_job_name="my_sklearn_processor", - env={"my_env_variable": "my_env_variable_value"}, - tags=[{"Key": "my-tag", "Value": "my-tag-value"}], - network_config=NetworkConfig( - subnets=["my_subnet_id"], - security_group_ids=["my_security_group_id"], - enable_network_isolation=True, - encrypt_inter_container_traffic=True, - ), - sagemaker_session=sagemaker_session, +def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor): + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + inputs = [ + ProcessingInput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + outputs = [ + ProcessingOutput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + step = ProcessingStep( + name="MyProcessingStep", + processor=script_processor, + code=DUMMY_SCRIPT_PATH, + inputs=inputs, + outputs=outputs, + job_arguments=["arg1", "arg2"], + cache_config=cache_config, ) + mock_normalize_args.return_value = [step.inputs, step.outputs] + step.to_request() + mock_normalize_args.assert_called_with( + job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db", + arguments=step.job_arguments, + inputs=step.inputs, + outputs=step.outputs, + code=step.code, + ) + + +@patch("sagemaker.processing.ScriptProcessor._normalize_args") +def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, script_processor): cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") inputs = [ ProcessingInput( @@ -510,8 +618,8 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) ] step = ProcessingStep( name="MyProcessingStep", - processor=processor, - code="foo.py", + processor=script_processor, + code="s3://foo", inputs=inputs, outputs=outputs, job_arguments=["arg1", "arg2"], @@ -520,6 +628,7 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) mock_normalize_args.return_value = [step.inputs, step.outputs] step.to_request() mock_normalize_args.assert_called_with( + job_name=None, arguments=step.job_arguments, inputs=step.inputs, outputs=step.outputs, @@ -527,6 +636,40 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) ) +@patch("sagemaker.processing.ScriptProcessor._normalize_args") +def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor): + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + inputs = [ + ProcessingInput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + outputs = [ + ProcessingOutput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + step = ProcessingStep( + name="MyProcessingStep", + processor=script_processor, + inputs=inputs, + outputs=outputs, + job_arguments=["arg1", "arg2"], + cache_config=cache_config, + ) + mock_normalize_args.return_value = [step.inputs, step.outputs] + step.to_request() + mock_normalize_args.assert_called_with( + job_name=None, + arguments=step.job_arguments, + inputs=step.inputs, + outputs=step.outputs, + code=None, + ) + + def test_create_model_step(sagemaker_session): model = Model( image_uri=IMAGE_URI,