diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 4d33c9c064..ecdd4b95eb 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1985,7 +1985,75 @@ def expand_role(self, role): if "/" in role: return role return self.boto_session.resource("iam").Role(role).arn - + + # ======================================== + # Hub Operations + # ======================================== + + def describe_hub_content( + self, hub_name, hub_content_name, hub_content_version, hub_content_type, **kwargs + ): + """Describe hub content in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_name (str): The name of the hub content. + hub_content_version (str): The version of the hub content. + hub_content_type (str): The type of hub content (Model, ModelReference, Notebook). + + Returns: + dict: Response from the DescribeHubContent API. + """ + return self.sagemaker_client.describe_hub_content( + HubName=hub_name, + HubContentName=hub_content_name, + HubContentVersion=hub_content_version, + HubContentType=hub_content_type, + **kwargs, + ) + + def list_hub_content_versions(self, hub_name, hub_content_name, hub_content_type, **kwargs): + """List versions of hub content in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_name (str): The name of the hub content. + hub_content_type (str): The type of hub content. + **kwargs: Additional arguments (e.g., next_token for pagination). + + Returns: + dict: Response from the ListHubContentVersions API. + """ + request = { + "HubName": hub_name, + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + } + next_token = kwargs.get("next_token") + if next_token: + request["NextToken"] = next_token + return self.sagemaker_client.list_hub_content_versions(**request) + + def list_hub_contents(self, hub_name, hub_content_type, **kwargs): + """List hub contents in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_type (str): The type of hub content to list. + **kwargs: Additional arguments (e.g., next_token for pagination). + + Returns: + dict: Response from the ListHubContents API. + """ + request = { + "HubName": hub_name, + "HubContentType": hub_content_type, + } + next_token = kwargs.get("next_token") + if next_token: + request["NextToken"] = next_token + return self.sagemaker_client.list_hub_contents(**request) + def _expand_container_def(c_def): """Placeholder docstring""" diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index b7dc98b768..b7be863d9b 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -3895,6 +3895,16 @@ def from_jumpstart_config( mb_instance.resource_requirements = resource_requirements mb_instance.model_kms_key = model_kms_key mb_instance.hub_name = jumpstart_config.hub_name + if mb_instance.hub_name and not getattr(mb_instance, "hub_arn", None): + from sagemaker.core.jumpstart.hub.utils import ( + generate_hub_arn_for_init_kwargs, + ) + + mb_instance.hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=mb_instance.hub_name, + region=mb_instance.region, + session=mb_instance.sagemaker_session, + ) mb_instance.config_name = jumpstart_config.inference_config_name mb_instance.accept_eula = jumpstart_config.accept_eula mb_instance.tolerate_vulnerable_model = tolerate_vulnerable_model diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index ecbc270540..3ca6b40f6d 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -975,15 +975,20 @@ def _build_for_jumpstart(self) -> Model: self.secret_key = "" # Get JumpStart model configuration - init_kwargs = get_init_kwargs( + init_kwargs_params = dict( model_id=self.model, model_version=self.model_version or "*", region=self.region, instance_type=self.instance_type, + sagemaker_session=self.sagemaker_session, tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), config_name=getattr(self, "config_name", None), ) + hub_arn = getattr(self, "hub_arn", None) + if hub_arn: + init_kwargs_params["hub_arn"] = hub_arn + init_kwargs = get_init_kwargs(**init_kwargs_params) # Configure image URI and environment variables self.image_uri = self.image_uri or init_kwargs.image_uri diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 2ec9ef6475..e58ea4d7ad 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -629,14 +629,19 @@ def _detect_jumpstart_image(self) -> None: ValueError: If image URI cannot be determined or JumpStart lookup fails. """ try: - init_kwargs = get_init_kwargs( + detect_kwargs = dict( model_id=self.model, model_version=getattr(self, "model_version", None) or "*", region=self.region, instance_type=getattr(self, "instance_type", None), + sagemaker_session=getattr(self, "sagemaker_session", None), tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) + hub_arn = getattr(self, "hub_arn", None) + if hub_arn: + detect_kwargs["hub_arn"] = hub_arn + init_kwargs = get_init_kwargs(**detect_kwargs) self.image_uri = init_kwargs.get("image_uri") if not self.image_uri: diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index 053b14f416..4fc1a558ba 100644 --- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -948,6 +948,7 @@ def test_build_passes_config_name_to_get_init_kwargs( model_version="*", region=self.builder.region, instance_type=self.builder.instance_type, + sagemaker_session=self.builder.sagemaker_session, tolerate_vulnerable_model=None, tolerate_deprecated_model=None, config_name="lmi-optimized", @@ -979,6 +980,7 @@ def test_build_passes_none_config_name_when_not_set( model_version="*", region=self.builder.region, instance_type=self.builder.instance_type, + sagemaker_session=self.builder.sagemaker_session, tolerate_vulnerable_model=None, tolerate_deprecated_model=None, config_name=None, diff --git a/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py b/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py new file mode 100644 index 0000000000..4fdee7e6ef --- /dev/null +++ b/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py @@ -0,0 +1,371 @@ +""" +Unit tests for private hub artifact resolution fix. + +Tests two defects: +1. from_jumpstart_config sets hub_name after __init__ already ran + _initialize_jumpstart_config(), leaving hub_arn as None. +2. _build_for_jumpstart does not forward hub_arn to get_init_kwargs, + so model data resolves from the public catalog instead of the private hub. +""" + +import unittest +from unittest.mock import Mock, patch + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.core.training.configs import Compute +from sagemaker.core.jumpstart.configs import JumpStartConfig + + +MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole" +MOCK_HUB_NAME = "my-private-hub" +MOCK_HUB_ARN = "arn:aws:sagemaker:us-east-1:123456789012:hub/my-private-hub" +MOCK_MODEL_ID = "huggingface-llm-phi-4-mini-instruct" +MOCK_MODEL_VERSION = "1.1.0" + + +def _mock_session(): + """Create a mock session that won't trigger real AWS calls.""" + session = Mock() + session.boto_region_name = "us-east-1" + session.sagemaker_config = None + session.boto_session = Mock() + session.boto_session.region_name = "us-east-1" + return session + + +# Common patch to prevent __init__ from making real S3/API calls during +# instance type auto-detection and model ID validation. +_PATCH_IS_JS = patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + + +class TestFromJumpStartConfigHubArnDerivation(unittest.TestCase): + """Test that from_jumpstart_config correctly derives hub_arn from hub_name.""" + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + def test_hub_arn_derived_when_hub_name_set( + self, mock_generate_arn, mock_validate, mock_deploy_kwargs, mock_is_js + ): + """hub_arn must be derived after hub_name is assigned in from_jumpstart_config.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + ) + + # The key assertion: hub_arn is populated, proving _initialize_jumpstart_config + # ran after hub_name was set in from_jumpstart_config + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + # generate_hub_arn_for_init_kwargs must have been called with the hub_name + mock_generate_arn.assert_called() + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + def test_hub_arn_populated_end_to_end( + self, mock_validate, mock_generate_arn, mock_deploy_kwargs, mock_is_js + ): + """End-to-end: hub_arn is correctly populated when hub_name is specified.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + ) + + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + mock_generate_arn.assert_called() + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + def test_hub_arn_is_none_when_no_hub_name(self, mock_validate, mock_deploy_kwargs, mock_is_js): + """hub_arn should remain None when hub_name is not provided.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + ) + + self.assertIsNone(mb.hub_name) + self.assertIsNone(mb.hub_arn) + + +class TestBuildForJumpStartForwardsHubArn(unittest.TestCase): + """Test that _build_for_jumpstart forwards hub_arn to get_init_kwargs.""" + + def setUp(self): + self.mock_session = _mock_session() + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_hub_arn_forwarded_to_get_init_kwargs( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """get_init_kwargs must receive hub_arn so model data resolves via private hub.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": "s3://my-private-hub-bucket/artifacts/model.tar.gz", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.hub_name = MOCK_HUB_NAME + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Verify hub_arn was passed to get_init_kwargs + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_hub_arn_none_when_no_private_hub( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """When no private hub is configured, hub_arn should be None (public catalog).""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache-prod-us-east-1/models/model.tar.gz" + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Verify hub_arn is NOT passed when no private hub (public catalog) + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") + self.assertIsNone(actual_hub_arn) + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_private_hub_resolves_non_public_model_data( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """With hub_arn set, model_data should resolve to private hub bucket, not public cache.""" + private_s3_uri = "s3://my-private-hub-bucket/hub-content/artifacts/model.tar.gz" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": private_s3_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.hub_name = MOCK_HUB_NAME + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Confirm model data does NOT point to public JumpStart cache + self.assertNotIn("jumpstart-cache-prod", builder.s3_model_data_url) + self.assertEqual(builder.s3_model_data_url, private_s3_uri) + + +class TestDetectJumpStartImageForwardsHubArn(unittest.TestCase): + """Test that _detect_jumpstart_image forwards hub_arn to get_init_kwargs.""" + + def setUp(self): + self.mock_session = _mock_session() + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.serve.model_builder_utils.get_init_kwargs") + def test_hub_arn_forwarded_in_detect_jumpstart_image( + self, mock_get_kwargs, mock_validate, mock_is_js + ): + """_detect_jumpstart_image must pass hub_arn so private hub images resolve correctly.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.get = lambda k: mock_init_kwargs.image_uri if k == "image_uri" else None + mock_get_kwargs.return_value = mock_init_kwargs + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + ) + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._detect_jumpstart_image() + + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + +class TestEndToEndPrivateHubFlow(unittest.TestCase): + """Integration-style test: from_jumpstart_config with hub_name -> _build_for_jumpstart.""" + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + def test_from_jumpstart_config_then_build_uses_private_hub( + self, + mock_generate_arn, + mock_validate, + mock_deploy_kwargs, + mock_prepare, + mock_create, + mock_get_kwargs, + mock_is_js, + ): + """Full flow: from_jumpstart_config with hub_name -> build -> hub_arn passed through.""" + private_s3_uri = "s3://private-hub-bucket/content/model.tar.gz" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": private_s3_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + compute=Compute(instance_type="ml.g5.xlarge"), + ) + + # Verify hub_arn was derived + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + + # Now trigger build + mb.mode = Mode.SAGEMAKER_ENDPOINT + mb._optimizing = False + mb._build_for_jumpstart() + + # Verify hub_arn was forwarded to get_init_kwargs + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + # Verify model data points to private hub, not public cache + self.assertEqual(mb.s3_model_data_url, private_s3_uri) + self.assertNotIn("jumpstart-cache-prod", mb.s3_model_data_url) + + +if __name__ == "__main__": + unittest.main()