diff --git a/doc/overview.rst b/doc/overview.rst index 39f5f6ecae..df320e3b47 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -746,6 +746,7 @@ see `Model diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py class JumpStartECRSpecs(JumpStartDataHolderType): """Data class for JumpStart ECR specs.""" - __slots__ = { + __slots__ = [ "framework", "framework_version", "py_version", "huggingface_transformers_version", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartECRSpecs object from its json representation. @@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]: class JumpStartHyperparameter(JumpStartDataHolderType): """Data class for JumpStart hyperparameter definition in the training container.""" - __slots__ = { + __slots__ = [ "name", "type", "options", @@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartHyperparameter object from its json representation. @@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]: class JumpStartEnvironmentVariable(JumpStartDataHolderType): """Data class for JumpStart environment variable definitions in the hosting container.""" - __slots__ = { + __slots__ = [ "name", "type", "default", "scope", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartEnvironmentVariable object from its json representation. diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 16bdd9fc4f..c59966d1b5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -13,6 +13,7 @@ """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import import logging +import os from typing import Dict, List, Optional from urllib.parse import urlparse from packaging.version import Version @@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str: Raises: RuntimeError: If JumpStart is not launched in ``region``. """ + + if ( + constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE] + LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override) + return bucket_override try: return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket except KeyError: diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 10c5b38a81..65268388c3 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -49,7 +49,7 @@ def _validate_hyperparameter( if len(hyperparameter_spec) > 1: raise JumpStartHyperparametersError( - f"Unable to perform validation -- found multiple hyperparameter " + "Unable to perform validation -- found multiple hyperparameter " f"'{hyperparameter_name}' in model specs." ) @@ -76,35 +76,35 @@ def _validate_hyperparameter( if hyperparameter_value not in hyperparameter_spec.options: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have one of the following " - f"values: {', '.join(hyperparameter_spec.options)}" + f"values: {', '.join(hyperparameter_spec.options)}."
) if hasattr(hyperparameter_spec, "min"): if len(hyperparameter_value) < hyperparameter_spec.min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length no less than " - f"{hyperparameter_spec.min}" + f"{hyperparameter_spec.min}." ) if hasattr(hyperparameter_spec, "exclusive_min"): if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length greater than " - f"{hyperparameter_spec.exclusive_min}" + f"{hyperparameter_spec.exclusive_min}." ) if hasattr(hyperparameter_spec, "max"): if len(hyperparameter_value) > hyperparameter_spec.max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length no greater than " - f"{hyperparameter_spec.max}" + f"{hyperparameter_spec.max}." ) if hasattr(hyperparameter_spec, "exclusive_max"): if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length less than " - f"{hyperparameter_spec.exclusive_max}" + f"{hyperparameter_spec.exclusive_max}." ) # validate numeric types @@ -125,35 +125,35 @@ def _validate_hyperparameter( if not hyperparameter_value_str[start_index:].isdigit(): raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be integer type " - "('{hyperparameter_value}')." + f"('{hyperparameter_value}')." ) if hasattr(hyperparameter_spec, "min"): if numeric_hyperparam_value < hyperparameter_spec.min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' can be no less than " - "{hyperparameter_spec.min}." + f"{hyperparameter_spec.min}." ) if hasattr(hyperparameter_spec, "max"): if numeric_hyperparam_value > hyperparameter_spec.max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' can be no greater than " - "{hyperparameter_spec.max}." + f"{hyperparameter_spec.max}." ) if hasattr(hyperparameter_spec, "exclusive_min"): if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be greater than " - "{hyperparameter_spec.exclusive_min}." + f"{hyperparameter_spec.exclusive_min}." ) if hasattr(hyperparameter_spec, "exclusive_max"): if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be less than " - "{hyperparameter_spec.exclusive_max}." + f"{hyperparameter_spec.exclusive_max}." ) diff --git a/src/sagemaker/lineage/lineage_trial_component.py b/src/sagemaker/lineage/lineage_trial_component.py index f8bc0e53b4..1e02e83657 100644 --- a/src/sagemaker/lineage/lineage_trial_component.py +++ b/src/sagemaker/lineage/lineage_trial_component.py @@ -130,8 +130,15 @@ def pipeline_execution_arn(self) -> str: Returns: str: A pipeline execution ARN. 
""" + trial_component = self.load( + trial_component_name=self.trial_component_name, sagemaker_session=self.sagemaker_session + ) + + if trial_component.source is None or trial_component.source["SourceArn"] is None: + return None + tags = self.sagemaker_session.sagemaker_client.list_tags( - ResourceArn=self.trial_component_arn + ResourceArn=trial_component.source["SourceArn"] )["Tags"] for tag in tags: if tag["Key"] == "sagemaker:pipeline-execution-arn": diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 00a04a3199..2d01bb4c0f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -303,6 +303,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -328,6 +329,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -355,6 +358,7 @@ def register( description=description, container_def_list=[container_def], drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index df0dd31a28..0a10cbf3c1 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -158,6 +158,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +184,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -211,6 +214,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 3a0c3a283c..0f51788626 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -157,6 +157,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -182,6 +183,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -210,6 +213,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91b89ea4c9..c50a22d3f8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2778,6 +2778,7 @@ def create_model_package_from_containers( approval_status="PendingManualApproval", description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -2803,6 +2804,9 @@ or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + """ request = get_create_model_package_request( @@ -2819,7 +2823,17 @@ approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) + if model_package_group_name is not None: + try: + self.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) + except ClientError: + self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5): @@ -4120,6 +4134,7 @@ def get_model_package_args( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get arguments for create_model_package method. @@ -4148,6 +4163,8 @@ (default: None). container_def_list (list): A list of container definitions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: dict: A dictionary of method argument names and values. """ @@ -4185,6 +4202,8 @@ model_package_args["description"] = description if tags is not None: model_package_args["tags"] = tags + if customer_metadata_properties is not None: + model_package_args["customer_metadata_properties"] = customer_metadata_properties return model_package_args @@ -4203,6 +4222,7 @@ def get_create_model_package_request( description=None, tags=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -4229,6 +4249,8 @@ tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None).
""" if all([model_package_name, model_package_group_name]): @@ -4250,6 +4272,8 @@ def get_create_model_package_request( request_dict["DriftCheckBaselines"] = drift_check_baselines if metadata_properties: request_dict["MetadataProperties"] = metadata_properties + if customer_metadata_properties is not None: + request_dict["CustomerMetadataProperties"] = customer_metadata_properties if containers is not None: if not all([content_types, response_types, inference_instances, transform_instances]): raise ValueError( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 0b8d2f7235..9f6a7841d5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -201,6 +201,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -226,6 +227,9 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + Returns: A `sagemaker.model.ModelPackage` instance. @@ -254,6 +258,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def deploy( diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 6ce7e41831..f98f170f39 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -34,7 +34,7 @@ from distutils.dir_util import copy_tree -def repack(inference_script, model_archive, dependencies=None, source_dir=None): +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive Args: @@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): copy_tree(src_dir, "/opt/ml/model") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover parser = argparse.ArgumentParser() parser.add_argument("--inference_script", type=str, default="inference.py") parser.add_argument("--dependencies", type=str, default=None) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ca078fe7ea..fbbb6acba9 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -80,7 +80,7 @@ def __init__( artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file (default: None). + ``.tar.gz`` file. entry_point (str): Path (absolute or relative) to the local Python source file which should be executed as the entry point to inference. If ``source_dir`` is specified, then ``entry_point`` @@ -310,6 +310,7 @@ def __init__( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Constructor of a register model step. @@ -347,6 +348,8 @@ def __init__( this step depends on retry_policies (List[RetryPolicy]): The list of retry policies for the current step drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). 
+ customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -362,6 +365,7 @@ def __init__( self.tags = tags self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines + self.customer_metadata_properties = customer_metadata_properties self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -435,6 +439,7 @@ def arguments(self) -> RequestType: description=self.description, tags=self.tags, container_def_list=self.container_def_list, + customer_metadata_properties=self.customer_metadata_properties, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index a34330d94d..a2597c07f9 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -95,7 +95,7 @@ def properties(self): @attr.s -class JsonGet(Expression): +class JsonGet(Expression): # pragma: no cover """Get JSON properties from PropertyFiles. Attributes: diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 2e2849cc80..065cf01315 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -79,7 +79,7 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "Type": self.condition_type.value, - "LeftValue": self.left.expr, + "LeftValue": primitive_or_expr(self.left), "RightValue": primitive_or_expr(self.right), } diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 03ac099d18..e0076322de 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -75,8 +75,8 @@ class JsonGet(Expression): @property def expr(self): """The expression dict for a `JsonGet` function.""" - if not isinstance(self.step_name, str): - raise ValueError("Please give step name as a string") + if not isinstance(self.step_name, str) or not self.step_name: + raise ValueError("Please give a valid step name as a string") if isinstance(self.property_file, PropertyFile): name = self.property_file.name diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 0446a0b46c..5240ae60b9 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -161,8 +161,8 @@ def _get_function_arn(self): partition = "aws" if self.lambda_func.function_arn is None: + account_id = self.lambda_func.session.account_id() try: - account_id = self.lambda_func.session.account_id() response = self.lambda_func.create() return response["FunctionArn"] except ValueError as error: diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index f4606488b2..1280637006 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -75,6 +75,7 @@ def __init__( tags=None, model: Union[Model, PipelineModel] = None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. 
@@ -95,7 +96,7 @@ def __init__( for the repack model step register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for register model step - model_package_group_name (str): The Model Package Group name, exclusive to + model_package_group_name (str): The Model Package Group name or Arn, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). @@ -113,6 +114,9 @@ def __init__( model (object or Model): A PipelineModel object that comprises a list of models which gets executed as a serial inference pipeline or a Model object. drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -229,6 +233,7 @@ def __init__( tags=tags, container_def_list=self.container_def_list, retry_policies=register_model_step_retry_policies, + customer_metadata_properties=customer_metadata_properties, **kwargs, ) if not repack_model: @@ -318,15 +323,15 @@ def __init__( """ steps = [] if "entry_point" in kwargs: - entry_point = kwargs["entry_point"] - source_dir = kwargs.get("source_dir") - dependencies = kwargs.get("dependencies") + entry_point = kwargs.get("entry_point", None) + source_dir = kwargs.get("source_dir", None) + dependencies = kwargs.get("dependencies", None) repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, - role=estimator.sagemaker_session, + role=estimator.role, model_data=model_data, entry_point=entry_point, source_dir=source_dir, @@ -352,7 +357,11 @@ def predict_wrapper(endpoint, session): vpc_config=None, sagemaker_session=estimator.sagemaker_session, role=estimator.role, - **kwargs, + env=kwargs.get("env", None), + name=kwargs.get("name", None), + enable_network_isolation=kwargs.get("enable_network_isolation", None), + model_kms_key=kwargs.get("model_kms_key", None), + image_config=kwargs.get("image_config", None), ) model_step = CreateModelStep( name=f"{name}CreateModelStep", diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 0139a5b658..4ede5c193d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -233,7 +233,7 @@ def upstream_trial_associated_artifact( sagemaker_session=sagemaker_session, ) trial_obj.add_trial_component(trial_component_obj) - time.sleep(3) + time.sleep(4) yield artifact_obj trial_obj.remove_trial_component(trial_component_obj) assntn.delete() @@ -561,14 +561,14 @@ def static_approval_action( @pytest.fixture -def static_model_deployment_action(sagemaker_session, static_endpoint_context): +def static_model_deployment_action(sagemaker_session, static_processing_job_trial_component): query_filter = LineageFilter( entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_processing_job_trial_component.trial_component_arn], query_filter=query_filter, - direction=LineageQueryDirectionEnum.ASCENDANTS, + direction=LineageQueryDirectionEnum.DESCENDANTS, include_edges=False, ) model_approval_actions = [] @@ -579,14 
+579,14 @@ def static_model_deployment_action(sagemaker_session, static_endpoint_context): @pytest.fixture def static_processing_job_trial_component( - sagemaker_session, static_endpoint_context + sagemaker_session, static_dataset_artifact ) -> LineageTrialComponent: query_filter = LineageFilter( entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_dataset_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, @@ -600,14 +600,14 @@ def static_processing_job_trial_component( @pytest.fixture def static_training_job_trial_component( - sagemaker_session, static_endpoint_context + sagemaker_session, static_model_artifact ) -> LineageTrialComponent: query_filter = LineageFilter( entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_model_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, @@ -738,12 +738,12 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session): @pytest.fixture -def static_image_artifact(static_model_artifact, sagemaker_session): +def static_image_artifact(static_dataset_artifact, sagemaker_session): query_filter = LineageFilter( entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.IMAGE] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_model_artifact.artifact_arn], + start_arns=[static_dataset_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 160f9f934b..dd24149ca4 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -67,7 +67,6 @@ ConditionLessThanOrEqualTo, ) from sagemaker.workflow.condition_step import ConditionStep -from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet from sagemaker.workflow.callback_step import ( CallbackStep, CallbackOutput, @@ -1952,6 +1951,7 @@ def test_model_registration_with_drift_check_baselines( content_type="application/json", ), ) + customer_metadata_properties = {"key1": "value1"} estimator = XGBoost( entry_point="training.py", source_dir=os.path.join(DATA_DIR, "sip"), @@ -1973,6 +1973,7 @@ def test_model_registration_with_drift_check_baselines( model_package_group_name="testModelPackageGroup", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) pipeline = Pipeline( @@ -2043,6 +2044,7 @@ def test_model_registration_with_drift_check_baselines( response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"] == "application/json" ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties break finally: try: @@ -2832,8 +2834,8 @@ def test_end_to_end_pipeline_successful_execution( # define condition step cond_lte = ConditionLessThanOrEqualTo( - left=ConditionStepJsonGet( - step=step_eval, + left=JsonGet( + step_name=step_eval.name, property_file=evaluation_report, json_path="regression_metrics.mse.value", ), diff --git a/tests/integ/test_workflow_with_clarify.py b/tests/integ/test_workflow_with_clarify.py index 
0c41b2212a..486abab89b 100644 --- a/tests/integ/test_workflow_with_clarify.py +++ b/tests/integ/test_workflow_with_clarify.py @@ -33,7 +33,8 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.session import get_execution_role from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo -from sagemaker.workflow.condition_step import ConditionStep, JsonGet +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.functions import JsonGet from sagemaker.workflow.parameters import ( ParameterInteger, ParameterString, @@ -237,7 +238,7 @@ def test_workflow_with_clarify( ) cond_left = JsonGet( - step=step_process, + step_name=step_process.name, property_file="BiasOutput", json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value", ) diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index ddeeccba1d..83092f74e5 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -147,49 +147,54 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) hyperparameter_to_test["batch-size"] = "0" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' " "can be no less than 1.") hyperparameter_to_test["batch-size"] = "-1" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' can be no " "less than 1.") hyperparameter_to_test["batch-size"] = "-1.5" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' must be " "integer type ('-1.5').") hyperparameter_to_test["batch-size"] = "1.5" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' must be integer " "type ('1.5').") hyperparameter_to_test["batch-size"] = "99999" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' can be no greater " "than 1024.") hyperparameter_to_test["batch-size"] = 5 hyperparameters.validate( @@ -210,13 +215,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [None, "", 5, "Truesday", "Falsehood"]: hyperparameter_to_test["test_bool_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, 
hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Expecting boolean valued hyperparameter, " f"but got '{str(val)}'." + ) + hyperparameter_to_test["test_bool_param"] = original_bool_val original_exclusive_min_val = hyperparameter_to_test["test_exclusive_min_param"] @@ -230,13 +239,16 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [1, 1 - 1e-99, -99]: hyperparameter_to_test["test_exclusive_min_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_exclusive_min_param' must " "be greater than 1." + ) hyperparameter_to_test["test_exclusive_min_param"] = original_exclusive_min_val original_exclusive_max_val = hyperparameter_to_test["test_exclusive_max_param"] @@ -250,13 +262,15 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [4, 5, 99]: hyperparameter_to_test["test_exclusive_max_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == "Hyperparameter 'test_exclusive_max_param' must be less than 4." + hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val original_exclusive_max_text_val = hyperparameter_to_test["test_exclusive_max_param_text"] @@ -270,13 +284,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["123456", "123456789"]: hyperparameter_to_test["test_exclusive_max_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert ( + str(e.value) + == "Hyperparameter 'test_exclusive_max_param_text' must have length less than 6." + ) hyperparameter_to_test["test_exclusive_max_param_text"] = original_exclusive_max_text_val original_max_text_val = hyperparameter_to_test["test_max_param_text"] @@ -290,13 +308,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["1234567", "123456789"]: hyperparameter_to_test["test_max_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert ( + str(e.value) + == "Hyperparameter 'test_max_param_text' must have length no greater than 6." + ) hyperparameter_to_test["test_max_param_text"] = original_max_text_val original_exclusive_min_text_val = hyperparameter_to_test["test_exclusive_min_param_text"] @@ -310,13 +332,16 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["1", "d", ""]: hyperparameter_to_test["test_exclusive_min_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_exclusive_min_param_text' must have length greater " "than 1." 
+ ) hyperparameter_to_test["test_exclusive_min_param_text"] = original_exclusive_min_text_val original_min_text_val = hyperparameter_to_test["test_min_param_text"] @@ -330,24 +355,31 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [""]: hyperparameter_to_test["test_min_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_min_param_text' " "must have length no less than 1." + ) hyperparameter_to_test["test_min_param_text"] = original_min_text_val del hyperparameter_to_test["batch-size"] hyperparameter_to_test["penalty"] = "blah" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'penalty' must have one of the following values: l1, l2, elasticnet," + " none." + ) hyperparameter_to_test["penalty"] = "elasticnet" hyperparameters.validate( @@ -411,7 +443,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) del hyperparameter_to_test["adam-learning-rate"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -419,6 +451,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, ) + assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -454,7 +487,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): del hyperparameter_to_test["sagemaker_submit_directory"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -462,13 +495,14 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, ) + assert str(e.value) == "Cannot find hyperparameter for 'sagemaker_submit_directory'." hyperparameter_to_test[ "sagemaker_submit_directory" ] = "/opt/ml/input/data/code/sourcedir.tar.gz" del hyperparameter_to_test["epochs"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -476,6 +510,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, ) + assert str(e.value) == "Cannot find hyperparameter for 'epochs'." 
hyperparameter_to_test["epochs"] = "3" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 761b53d469..93e8114185 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -151,17 +151,41 @@ def test_jumpstart_cache_get_header(): semantic_version_str="3.*", ) assert ( - "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 " - "with version 3.* compatible with your SageMaker version (2.68.3). Consider upgrading " - "your SageMaker library to at least version 4.49.0 so you can use version 3.0.0 of " - "tensorflow-ic-imagenet-inception-v3-classification-4." in str(e.value) + "Unable to find model manifest for 'tensorflow-ic-imagenet-inception-v3-classification-4' " + "with version '3.*' compatible with your SageMaker version ('2.68.3'). Consider upgrading " + "your SageMaker library to at least version '4.49.0' so you can use version '3.0.0' of " + "'tensorflow-ic-imagenet-inception-v3-classification-4'." in str(e.value) ) with pytest.raises(KeyError) as e: cache.get_header( model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="3.*" ) - assert "Consider upgrading" not in str(e.value) + assert ( + "Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with " + "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-" + "classification-4' with version '2.0.0'." + ) in str(e.value) + + with pytest.raises(KeyError) as e: + cache.get_header(model_id="pytorch-ic-", semantic_version_str="*") + assert ( + "Unable to find model manifest for 'pytorch-ic-' with version '*'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. " + "Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?" + ) in str(e.value) + + with pytest.raises(KeyError) as e: + cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*") + assert ( + "Unable to find model manifest for 'tensorflow-ic-' with version '*'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. " + "Did you mean to use model ID 'tensorflow-ic-imagenet-inception-" + "v3-classification-4'?" + ) in str(e.value) with pytest.raises(KeyError): cache.get_header( diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe494eb459..04eddced08 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -11,11 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import os from mock.mock import Mock, patch import pytest import random from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, JUMPSTART_BUCKET_NAME_SET, JUMPSTART_REGION_NAME_SET, JumpStartScriptScope, @@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket(): utils.get_jumpstart_content_bucket(bad_region) +def test_get_jumpstart_content_bucket_override(): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart bucket override: '%s'", + "some-val", + ) + + def test_get_jumpstart_launched_regions_message(): with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): diff --git a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py index 9b466832a1..5755f512f9 100644 --- a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py +++ b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py @@ -114,9 +114,28 @@ def test_pipeline_execution_arn(sagemaker_session): trial_component_arn = ( "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" ) - obj = lineage_trial_component.LineageTrialComponent( - sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj sagemaker_session.sagemaker_client.list_tags.return_value = { "Tags": [ @@ -124,9 +143,10 @@ def test_pipeline_execution_arn(sagemaker_session): ], } expected_calls = [ - unittest.mock.call(ResourceArn=trial_component_arn), + unittest.mock.call(ResourceArn=training_job_arn), ] - pipeline_execution_arn_result = obj.pipeline_execution_arn() + pipeline_execution_arn_result = context.pipeline_execution_arn() + assert pipeline_execution_arn_result == "tag1" assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls @@ -135,9 +155,28 @@ def test_no_pipeline_execution_arn(sagemaker_session): trial_component_arn = ( "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" ) - obj = lineage_trial_component.LineageTrialComponent( - sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": 
training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj sagemaker_session.sagemaker_client.list_tags.return_value = { "Tags": [ @@ -145,9 +184,48 @@ def test_no_pipeline_execution_arn(sagemaker_session): ], } expected_calls = [ - unittest.mock.call(ResourceArn=trial_component_arn), + unittest.mock.call(ResourceArn=training_job_arn), ] - pipeline_execution_arn_result = obj.pipeline_execution_arn() + pipeline_execution_arn_result = context.pipeline_execution_arn() + expected_result = None + assert pipeline_execution_arn_result == expected_result + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls + + +def test_no_source_arn_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" + ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": None, + "SourceType": None, + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "abcd", "Value": "efg"}, + ], + } + expected_calls = [] + pipeline_execution_arn_result = context.pipeline_execution_arn() expected_result = None assert pipeline_execution_arn_result == expected_result assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls diff --git a/tests/unit/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py similarity index 100% rename from tests/unit/test_airflow.py rename to tests/unit/sagemaker/workflow/test_airflow.py diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index d473b36121..f4bea55b6e 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -165,3 +165,12 @@ def test_condition_or(): }, ], } + + +def test_left_and_right_primitives(): + cond = ConditionEquals(left=2, right=1) + assert cond.to_request() == { + "Type": "Equals", + "LeftValue": 2, + "RightValue": 1, + } diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 8e5d6b6d31..9b07a41d09 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import pytest + from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.parameters import ( @@ -97,3 +99,23 @@ def test_json_get_expressions(): "Path": "my-json-path", }, } + + +def test_json_get_expressions_with_invalid_step_name(): + with pytest.raises(ValueError) as err: + JsonGet( + step_name="", + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) + + with pytest.raises(ValueError) as err: + JsonGet( + step_name=ParameterString(name="MyString"), + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index 0566e39318..bdaa781b1c 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -22,6 +22,7 @@ from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from sagemaker.lambda_helper import Lambda +from sagemaker.workflow.steps import CacheConfig @pytest.fixture() @@ -38,10 +39,25 @@ def sagemaker_session(): return session_mock +@pytest.fixture() +def sagemaker_session_cn(): + boto_mock = Mock(name="boto_session", region_name="cn-north-1") + session_mock = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="cn-north-1", + config=None, + local_mode=False, + ) + session_mock.account_id.return_value = "234567890123" + return session_mock + + def test_lambda_step(sagemaker_session): param = ParameterInteger(name="MyInt") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") lambda_step = LambdaStep( name="MyLambdaStep", depends_on=["TestStep"], @@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session): display_name="MyLambdaStep", description="MyLambdaStepDescription", inputs={"arg1": "foo", "arg2": 5, "arg3": param}, - outputs=[outputParam1, outputParam2], + outputs=[output_param1, output_param2], + cache_config=cache_config, ) lambda_step.add_depends_on(["SecondTestStep"]) - assert lambda_step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[param], + steps=[lambda_step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyLambdaStep", "Type": "Lambda", "DependsOn": ["TestStep", "SecondTestStep"], @@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session): {"OutputName": "output1", "OutputType": "String"}, {"OutputName": "output2", "OutputType": "Boolean"}, ], - "Arguments": {"arg1": "foo", "arg2": 5, "arg3": param}, + "Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } @@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session): def test_pipeline_interpolates_lambda_outputs(sagemaker_session): parameter = 
ParameterString("MyStr") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) lambda_step1 = LambdaStep( name="MyLambdaStep1", depends_on=["TestStep"], @@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): session=sagemaker_session, ), inputs={"arg1": "foo"}, - outputs=[outputParam1], + outputs=[output_param1], ) lambda_step2 = LambdaStep( name="MyLambdaStep2", @@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, ), - inputs={"arg1": outputParam1}, - outputs=[outputParam2], + inputs={"arg1": output_param1}, + outputs=[output_param2], ) pipeline = Pipeline( @@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session): ) lambda_step._get_function_arn() sagemaker_session.account_id.assert_called_once() + + +def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn): + lambda_func = MagicMock( + function_arn=None, + function_name="name", + execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role", + zipped_code_dir="", + handler="", + session=sagemaker_session_cn, + ) + # The raised ValueError contains ResourceConflictException + lambda_func.create.side_effect = ValueError("ResourceConflictException") + lambda_step1 = LambdaStep( + name="MyLambdaStep1", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + function_arn = lambda_step1._get_function_arn() + assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name" + + # The raised ValueError does not contain ResourceConflictException + lambda_func.create.side_effect = ValueError() + lambda_step2 = LambdaStep( + name="MyLambdaStep2", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + with pytest.raises(ValueError): + lambda_step2._get_function_arn() diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 6c78412b22..d2f1f07059 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -19,6 +19,7 @@ import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR import sagemaker @@ -206,6 +207,16 @@ def test_step_collection(): ] +def test_step_collection_with_list_to_request(): + step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")]) + custom_step = CustomStep("MyStep3") + assert list_to_request([step_collection, custom_step]) == [ + {"Name": "MyStep1", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep2", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep3", "Type": "Training", "Arguments": dict()}, + ] + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel( @@ -216,6 +227,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): response_types=["response_type"], 
inference_instances=["inference_instance"], transform_instances=["transform_instance"], + image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", model_package_group_name="mpg", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, @@ -236,7 +248,10 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "Arguments": { "InferenceSpecification": { "Containers": [ - {"Image": "fakeimage", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz"} + { + "Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", + "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", + } ], "SupportedContentTypes": ["content_type"], "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"], @@ -865,3 +880,117 @@ def test_estimator_transformer(estimator): } else: raise Exception("A step exists in the collection of an invalid type.") + + +def test_estimator_transformer_with_model_repack_with_estimator(estimator): + model_data = f"s3://{BUCKET}/model.tar.gz" + model_inputs = CreateModelInput( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) + transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") + estimator_transformer = EstimatorTransformer( + name="EstimatorTransformerStep", + estimator=estimator, + model_data=model_data, + model_inputs=model_inputs, + instance_count=1, + instance_type="ml.c4.4xlarge", + transform_inputs=transform_inputs, + depends_on=["TestStep"], + model_step_retry_policies=[service_fault_retry_policy], + transform_step_retry_policies=[service_fault_retry_policy], + repack_model_step_retry_policies=[service_fault_retry_policy], + entry_point=f"{DATA_DIR}/dummy_script.py", + ) + request_dicts = estimator_transformer.request_dicts() + assert len(request_dicts) == 3 + + for request_dict in request_dicts: + if request_dict["Type"] == "Training": + assert request_dict["Name"] == "EstimatorTransformerStepRepackModel" + assert request_dict["DependsOn"] == ["TestStep"] + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + # pop out the dynamic generated fields + arguments["HyperParameters"].pop("sagemaker_submit_directory") + arguments["HyperParameters"].pop("sagemaker_job_name") + assert arguments == { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" + + "sagemaker-scikit-learn:0.23-1-cpu-py3", + }, + "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 30, + }, + "RoleArn": "DummyRole", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://my-bucket", + "S3DataDistributionType": "FullyReplicated", + } + }, + "ChannelName": "training", + } + ], + "HyperParameters": { + "inference_script": '"dummy_script.py"', + "model_archive": '"model.tar.gz"', + "dependencies": "null", + "source_dir": "null", + "sagemaker_program": '"_repack_model.py"', + "sagemaker_container_log_level": "20", + "sagemaker_region": '"us-west-2"', + }, + "VpcConfig": {"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]}, + "DebugHookConfig": { + "S3OutputPath": "s3://my-bucket/", + "CollectionConfigurations": [], + }, + } + elif 
request_dict["Type"] == "Model": + assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties) + arguments["PrimaryContainer"].pop("ModelDataUrl") + assert "DependsOn" not in request_dict + assert arguments == { + "ExecutionRoleArn": "DummyRole", + "PrimaryContainer": { + "Environment": {}, + "Image": "fakeimage", + }, + } + elif request_dict["Type"] == "Transform": + assert request_dict["Name"] == "EstimatorTransformerStepTransformStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["ModelName"], Properties) + arguments.pop("ModelName") + assert "DependsOn" not in request_dict + assert arguments == { + "TransformInput": { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{BUCKET}/transform_manifest", + } + } + }, + "TransformOutput": {"S3OutputPath": None}, + "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"}, + } + else: + raise Exception("A step exists in the collection of an invalid type.") diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index e3dc10e23e..674c715617 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json + import pytest import sagemaker import os @@ -43,7 +45,8 @@ ) from sagemaker.network import NetworkConfig from sagemaker.transformer import Transformer -from sagemaker.workflow.properties import Properties +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.properties import Properties, PropertyFile from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.retry import ( StepRetryPolicy, @@ -535,6 +538,9 @@ def test_processing_step(sagemaker_session): ) ] cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + evaluation_report = PropertyFile( + name="EvaluationReport", output_name="evaluation", path="evaluation.json" + ) step = ProcessingStep( name="MyProcessingStep", description="ProcessingStep description", @@ -544,9 +550,20 @@ def test_processing_step(sagemaker_session): inputs=inputs, outputs=[], cache_config=cache_config, + property_files=[evaluation_report], ) step.add_depends_on(["ThirdTestStep"]) - assert step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[ + processing_input_data_uri_parameter, + instance_type_parameter, + instance_count_parameter, + ], + steps=[step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyProcessingStep", "Description": "ProcessingStep description", "DisplayName": "MyProcessingStep", @@ -564,20 +581,27 @@ def test_processing_step(sagemaker_session): "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", "S3InputMode": "File", - "S3Uri": processing_input_data_uri_parameter, + "S3Uri": {"Get": "Parameters.ProcessingInputDataUri"}, }, } ], "ProcessingResources": { "ClusterConfig": { - "InstanceCount": instance_count_parameter, - "InstanceType": instance_type_parameter, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, + 
"InstanceType": {"Get": "Parameters.InstanceType"}, "VolumeSizeInGB": 30, } }, "RoleArn": "DummyRole", }, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + "PropertyFiles": [ + { + "FilePath": "evaluation.json", + "OutputName": "evaluation", + "PropertyFileName": "EvaluationReport", + } + ], } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 5a2a9497f8..e534aa531e 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -26,6 +26,7 @@ ) from sagemaker.estimator import Estimator +from sagemaker.workflow import Properties from sagemaker.workflow._utils import _RepackModelStep from tests.unit import DATA_DIR @@ -156,7 +157,7 @@ def test_repack_model_step(estimator): def test_repack_model_step_with_source_dir(estimator, source_dir): - model_data = f"s3://{BUCKET}/model.tar.gz" + model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" step = _RepackModelStep( name="MyRepackModelStep", @@ -189,7 +190,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir): "S3DataSource": { "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", - "S3Uri": f"s3://{BUCKET}", + "S3Uri": model_data, } }, } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8604835890..4523253a7f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): marketplace_cert = (True,) approval_status = ("Approved",) description = "description" + customer_metadata_properties = {"key1": "value1"} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): approval_status=approval_status, description=description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) expected_args = { "ModelPackageName": model_package_name, @@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)