From 648bdc82af2f35f020479e3d6337d8e4ee4862dd Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Wed, 1 Jul 2026 13:54:02 -0700 Subject: [PATCH] feat(feature-store): add BatchWriteRecord and ListRecords to FeatureGroup Add two new FeatureStore Runtime operations to the FeatureGroup resource: - batch_write_record: writes a batch of records to one or more FeatureGroups - list_records: lists RecordIdentifier values stored in a FeatureGroup's OnlineStore Changes: - Update service-2.json from botocore 1.43.38 - Add shape classes (BatchWriteRecordEntry, BatchWriteRecordResponse, etc.) - Add shape_dag entries and additional_operations mapping - Add exclude_resource_attributes support to codegen to prevent FeatureGroup.next_token from being incorrectly mapped to ListRecords.NextToken (different pagination contexts) - Regenerate resources.py and shapes.py via codegen --- X-AI-Prompt: add batch_write_record and list_records to sagemaker-core FeatureGroup, fix codegen next_token bug X-AI-Tool: kiro-cli --- .../2020-07-01/service-2.json | 190 +++++++++- .../src/sagemaker/core/resources.py | 116 +++++++ .../src/sagemaker/core/shapes/shapes.py | 78 ++++- .../core/tools/additional_operations.json | 17 + .../sagemaker/core/tools/resources_codegen.py | 16 +- .../core/utils/code_injection/shape_dag.py | 71 ++++ .../test_feature_store_operations.py | 324 ++++++++++++++++++ 7 files changed, 800 insertions(+), 12 deletions(-) create mode 100644 sagemaker-core/tests/unit/generated/test_feature_store_operations.py diff --git a/sagemaker-core/sample/sagemaker-featurestore-runtime/2020-07-01/service-2.json b/sagemaker-core/sample/sagemaker-featurestore-runtime/2020-07-01/service-2.json index 305a3abf79..9819e70f98 100644 --- a/sagemaker-core/sample/sagemaker-featurestore-runtime/2020-07-01/service-2.json +++ b/sagemaker-core/sample/sagemaker-featurestore-runtime/2020-07-01/service-2.json @@ -5,11 +5,13 @@ "endpointPrefix":"featurestore-runtime.sagemaker", "jsonVersion":"1.1", 
"protocol":"rest-json", + "protocols":["rest-json"], "serviceFullName":"Amazon SageMaker Feature Store Runtime", "serviceId":"SageMaker FeatureStore Runtime", "signatureVersion":"v4", "signingName":"sagemaker", - "uid":"sagemaker-featurestore-runtime-2020-07-01" + "uid":"sagemaker-featurestore-runtime-2020-07-01", + "auth":["aws.auth#sigv4"] }, "operations":{ "BatchGetRecord":{ @@ -28,6 +30,23 @@ ], "documentation":"

Retrieves a batch of Records from a FeatureGroup.

" }, + "BatchWriteRecord":{ + "name":"BatchWriteRecord", + "http":{ + "method":"POST", + "requestUri":"/BatchWriteRecord" + }, + "input":{"shape":"BatchWriteRecordRequest"}, + "output":{"shape":"BatchWriteRecordResponse"}, + "errors":[ + {"shape":"ValidationError"}, + {"shape":"ResourceNotFound"}, + {"shape":"InternalFailure"}, + {"shape":"ServiceUnavailable"}, + {"shape":"AccessForbidden"} + ], + "documentation":"

Writes a batch of Records to one or more FeatureGroups. Use this API for bulk ingestion of records into the OnlineStore and OfflineStore.

You can set the ingested records to expire at a given time to live (TTL) duration after the record's event time by specifying the TtlDuration parameter. A request level TtlDuration applies to all entries that do not specify their own TtlDuration.

" + }, "DeleteRecord":{ "name":"DeleteRecord", "http":{ @@ -41,7 +60,7 @@ {"shape":"ServiceUnavailable"}, {"shape":"AccessForbidden"} ], - "documentation":"

Deletes a Record from a FeatureGroup in the OnlineStore. Feature Store supports both SoftDelete and HardDelete. For SoftDelete (default), feature columns are set to null and the record is no longer retrievable by GetRecord or BatchGetRecord. For HardDelete, the complete Record is removed from the OnlineStore. In both cases, Feature Store appends the deleted record marker to the OfflineStore. The deleted record marker is a record with the same RecordIdentifer as the original, but with is_deleted value set to True, EventTime set to the delete input EventTime, and other feature values set to null.

Note that the EventTime specified in DeleteRecord should be set later than the EventTime of the existing record in the OnlineStore for that RecordIdentifer. If it is not, the deletion does not occur:

When a record is deleted from the OnlineStore, the deleted record marker is appended to the OfflineStore. If you have the Iceberg table format enabled for your OfflineStore, you can remove all history of a record from the OfflineStore using Amazon Athena or Apache Spark. For information on how to hard delete a record from the OfflineStore with the Iceberg table format enabled, see Delete records from the offline store.

" + "documentation":"

Deletes a Record from a FeatureGroup in the OnlineStore. Feature Store supports both SoftDelete and HardDelete. For SoftDelete (default), feature columns are set to null and the record is no longer retrievable by GetRecord or BatchGetRecord. For HardDelete, the complete Record is removed from the OnlineStore. In both cases, Feature Store appends the deleted record marker to the OfflineStore. The deleted record marker is a record with the same RecordIdentifier as the original, but with is_deleted value set to True, EventTime set to the delete input EventTime, and other feature values set to null.

Note that the EventTime specified in DeleteRecord should be set later than the EventTime of the existing record in the OnlineStore for that RecordIdentifier. If it is not, the deletion does not occur:

When a record is deleted from the OnlineStore, the deleted record marker is appended to the OfflineStore. If you have the Iceberg table format enabled for your OfflineStore, you can remove all history of a record from the OfflineStore using Amazon Athena or Apache Spark. For information on how to hard delete a record from the OfflineStore with the Iceberg table format enabled, see Delete records from the offline store.

" }, "GetRecord":{ "name":"GetRecord", @@ -60,6 +79,24 @@ ], "documentation":"

Use for OnlineStore serving from a FeatureStore. Only the latest records stored in the OnlineStore can be retrieved. If no Record with RecordIdentifierValue is found, then an empty result is returned.

" }, + "ListRecords":{ + "name":"ListRecords", + "http":{ + "method":"POST", + "requestUri":"/FeatureGroup/{FeatureGroupName}/ListRecords" + }, + "input":{"shape":"ListRecordsRequest"}, + "output":{"shape":"ListRecordsResponse"}, + "errors":[ + {"shape":"ValidationError"}, + {"shape":"ResourceNotFound"}, + {"shape":"InternalFailure"}, + {"shape":"ServiceUnavailable"}, + {"shape":"AccessForbidden"} + ], + "documentation":"

Lists the RecordIdentifier values of all records stored in a FeatureGroup's OnlineStore. This enables you to discover which records exist without retrieving the full record data.

", + "readonly":true + }, "PutRecord":{ "name":"PutRecord", "http":{ @@ -216,6 +253,98 @@ "member":{"shape":"BatchGetRecordResultDetail"}, "min":0 }, + "BatchWriteRecordEntries":{ + "type":"list", + "member":{"shape":"BatchWriteRecordEntry"}, + "max":25, + "min":1 + }, + "BatchWriteRecordEntry":{ + "type":"structure", + "required":[ + "FeatureGroupName", + "Record" + ], + "members":{ + "FeatureGroupName":{ + "shape":"FeatureGroupNameOrArn", + "documentation":"

The name or Amazon Resource Name (ARN) of the FeatureGroup to write the record to.

" + }, + "Record":{ + "shape":"Record", + "documentation":"

List of FeatureValues to be inserted. This will be a full over-write.

" + }, + "TargetStores":{ + "shape":"TargetStores", + "documentation":"

A list of stores to which you're adding the record. By default, Feature Store adds the record to all of the stores that you're using for the FeatureGroup.

" + }, + "TtlDuration":{ + "shape":"TtlDuration", + "documentation":"

Time to live duration for this entry, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. This overrides the request level TtlDuration.

" + } + }, + "documentation":"

An entry to write as part of a BatchWriteRecord request.

" + }, + "BatchWriteRecordError":{ + "type":"structure", + "required":[ + "Entry", + "ErrorCode", + "ErrorMessage" + ], + "members":{ + "Entry":{ + "shape":"BatchWriteRecordEntry", + "documentation":"

The entry that failed to be written.

" + }, + "ErrorCode":{ + "shape":"ValueAsString", + "documentation":"

The error code for the failed record write.

" + }, + "ErrorMessage":{ + "shape":"Message", + "documentation":"

The error message for the failed record write.

" + } + }, + "documentation":"

The error that has occurred when attempting to write a record in a batch.

" + }, + "BatchWriteRecordErrors":{ + "type":"list", + "member":{"shape":"BatchWriteRecordError"}, + "min":0 + }, + "BatchWriteRecordRequest":{ + "type":"structure", + "required":["Entries"], + "members":{ + "Entries":{ + "shape":"BatchWriteRecordEntries", + "documentation":"

A list of records to write. Each entry specifies the FeatureGroup, the record data, and optionally target stores and a TTL duration.

" + }, + "TtlDuration":{ + "shape":"TtlDuration", + "documentation":"

Time to live duration applied to all entries in the batch that do not specify their own TtlDuration; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide.

" + } + } + }, + "BatchWriteRecordResponse":{ + "type":"structure", + "required":[ + "Errors", + "UnprocessedEntries" + ], + "members":{ + "Errors":{ + "shape":"BatchWriteRecordErrors", + "documentation":"

A list of errors that occurred when writing records in the batch.

" + }, + "UnprocessedEntries":{ + "shape":"UnprocessedBatchWriteRecordEntries", + "documentation":"

A list of entries that were not processed. These entries can be retried.

" + } + } + }, + "Boolean":{"type":"boolean"}, "DeleteRecordRequest":{ "type":"structure", "required":[ @@ -364,6 +493,54 @@ "fault":true, "synthetic":true }, + "ListRecordsMaxResults":{ + "type":"integer", + "max":100, + "min":1 + }, + "ListRecordsNextToken":{ + "type":"string", + "max":8192, + "min":1 + }, + "ListRecordsRequest":{ + "type":"structure", + "required":["FeatureGroupName"], + "members":{ + "FeatureGroupName":{ + "shape":"FeatureGroupNameOrArn", + "documentation":"

The name or Amazon Resource Name (ARN) of the feature group to list records from.

", + "location":"uri", + "locationName":"FeatureGroupName" + }, + "MaxResults":{ + "shape":"ListRecordsMaxResults", + "documentation":"

The maximum number of record identifiers to return in a single page of results. For the InMemory tier, this value is a hint and not a strict requirement. The response may contain more or fewer results than the specified MaxResults.

" + }, + "NextToken":{ + "shape":"ListRecordsNextToken", + "documentation":"

A token to resume pagination of ListRecords results.

" + }, + "IncludeSoftDeletedRecords":{ + "shape":"Boolean", + "documentation":"

If set to true, the result includes records that have been soft deleted.

" + } + } + }, + "ListRecordsResponse":{ + "type":"structure", + "required":["RecordIdentifiers"], + "members":{ + "RecordIdentifiers":{ + "shape":"RecordIdentifierList", + "documentation":"

A list of record identifier values for the records stored in the OnlineStore.

" + }, + "NextToken":{ + "shape":"ListRecordsNextToken", + "documentation":"

A token to resume pagination if the response includes more record identifiers than MaxResults.

" + } + } + }, "Message":{ "type":"string", "max":2048 @@ -400,6 +577,10 @@ "member":{"shape":"FeatureValue"}, "min":1 }, + "RecordIdentifierList":{ + "type":"list", + "member":{"shape":"ValueAsString"} + }, "RecordIdentifiers":{ "type":"list", "member":{"shape":"ValueAsString"}, @@ -471,6 +652,11 @@ "type":"integer", "min":1 }, + "UnprocessedBatchWriteRecordEntries":{ + "type":"list", + "member":{"shape":"BatchWriteRecordEntry"}, + "min":0 + }, "UnprocessedIdentifiers":{ "type":"list", "member":{"shape":"BatchGetRecordIdentifier"}, diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py index 9093fe7f89..ff462080bb 100644 --- a/sagemaker-core/src/sagemaker/core/resources.py +++ b/sagemaker-core/src/sagemaker/core/resources.py @@ -12394,6 +12394,122 @@ def batch_get_record( transformed_response = transform(response, "BatchGetRecordResponse") return BatchGetRecordResponse(**transformed_response) + @Base.add_validate_call + def batch_write_record( + self, + entries: List[BatchWriteRecordEntry], + ttl_duration: Optional[TtlDuration] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[BatchWriteRecordResponse]: + """ + Writes a batch of Records to one or more FeatureGroups. + + Parameters: + entries: A list of records to write. Each entry specifies the FeatureGroup, the record data, and optionally target stores and a TTL duration. + ttl_duration: Time to live duration applied to all entries in the batch that do not specify their own TtlDuration; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide. + session: Boto3 session. + region: Region name. + + Returns: + BatchWriteRecordResponse + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ResourceNotFound: Resource being access is not found. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. + """ + + operation_input_args = { + "Entries": entries, + "TtlDuration": ttl_duration, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + ) + + logger.debug(f"Calling batch_write_record API") + response = client.batch_write_record(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "BatchWriteRecordResponse") + return BatchWriteRecordResponse(**transformed_response) + + @Base.add_validate_call + def list_records( + self, + max_results: Optional[int] = Unassigned(), + next_token: Optional[StrPipeVar] = Unassigned(), + include_soft_deleted_records: Optional[bool] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[ListRecordsResponse]: + """ + Lists the RecordIdentifier values of all records stored in a FeatureGroup's OnlineStore. + + Parameters: + max_results: The maximum number of record identifiers to return in a single page of results. For the InMemory tier, this value is a hint and not a strict requirement. 
The response may contain more or fewer results than the specified MaxResults. + next_token: A token to resume pagination of ListRecords results. + include_soft_deleted_records: If set to true, the result includes records that have been soft deleted. + session: Boto3 session. + region: Region name. + + Returns: + ListRecordsResponse + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ResourceNotFound: Resource being access is not found. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. 
+ """ + + operation_input_args = { + "FeatureGroupName": self.feature_group_name, + "MaxResults": max_results, + "NextToken": next_token, + "IncludeSoftDeletedRecords": include_soft_deleted_records, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + ) + + logger.debug(f"Calling list_records API") + response = client.list_records(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "ListRecordsResponse") + return ListRecordsResponse(**transformed_response) + class FeatureMetadata(Base): """ diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index ce25c890dd..3eb547ec1c 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -325,6 +325,71 @@ class BatchGetRecordResponse(Base): unprocessed_identifiers: List[BatchGetRecordIdentifier] +class TtlDuration(Base): + """ + TtlDuration + Time to live duration, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide. + + Attributes + ---------------------- + unit: TtlDuration time unit. + value: TtlDuration time value. + """ + + unit: Optional[StrPipeVar] = Unassigned() + value: Optional[int] = Unassigned() + + +class BatchWriteRecordEntry(Base): + """ + BatchWriteRecordEntry + An entry to write as part of a BatchWriteRecord request. + + Attributes + ---------------------- + feature_group_name: The name or Amazon Resource Name (ARN) of the FeatureGroup to write the record to. + record: List of FeatureValues to be inserted. This will be a full over-write. 
+ target_stores: A list of stores to which you're adding the record. By default, Feature Store adds the record to all of the stores that you're using for the FeatureGroup. + ttl_duration: Time to live duration for this entry, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. This overrides the request level TtlDuration. + """ + + feature_group_name: Union[StrPipeVar, object] + record: List[FeatureValue] + target_stores: Optional[List[StrPipeVar]] = Unassigned() + ttl_duration: Optional[TtlDuration] = Unassigned() + + +class BatchWriteRecordError(Base): + """ + BatchWriteRecordError + The error that has occurred when attempting to write a record in a batch. + + Attributes + ---------------------- + entry: The entry that failed to be written. + error_code: The error code for the failed record write. + error_message: The error message for the failed record write. + """ + + entry: BatchWriteRecordEntry + error_code: StrPipeVar + error_message: StrPipeVar + + +class BatchWriteRecordResponse(Base): + """ + BatchWriteRecordResponse + + Attributes + ---------------------- + errors: A list of errors that occurred when writing records in the batch. + unprocessed_entries: A list of entries that were not processed. These entries can be retried. + """ + + errors: List[BatchWriteRecordError] + unprocessed_entries: List[BatchWriteRecordEntry] + + class GetRecordResponse(Base): """ GetRecordResponse @@ -339,19 +404,18 @@ class GetRecordResponse(Base): expires_at: Optional[StrPipeVar] = Unassigned() -class TtlDuration(Base): +class ListRecordsResponse(Base): """ - TtlDuration - Time to live duration, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide. + ListRecordsResponse Attributes ---------------------- - unit: TtlDuration time unit. - value: TtlDuration time value. 
+ record_identifiers: A list of record identifier values for the records stored in the OnlineStore. + next_token: A token to resume pagination if the response includes more record identifiers than MaxResults. """ - unit: Optional[StrPipeVar] = Unassigned() - value: Optional[int] = Unassigned() + record_identifiers: List[StrPipeVar] + next_token: Optional[StrPipeVar] = Unassigned() class ResourceNotFound(Base): diff --git a/sagemaker-core/src/sagemaker/core/tools/additional_operations.json b/sagemaker-core/src/sagemaker/core/tools/additional_operations.json index 8579612806..25ba6b0fce 100644 --- a/sagemaker-core/src/sagemaker/core/tools/additional_operations.json +++ b/sagemaker-core/src/sagemaker/core/tools/additional_operations.json @@ -474,6 +474,23 @@ "return_type": "BatchGetRecordResponse", "method_type": "object", "service_name": "sagemaker-featurestore-runtime" + }, + "BatchWriteRecord": { + "operation_name": "BatchWriteRecord", + "resource_name": "FeatureGroup", + "method_name": "batch_write_record", + "return_type": "BatchWriteRecordResponse", + "method_type": "object", + "service_name": "sagemaker-featurestore-runtime" + }, + "ListRecords": { + "operation_name": "ListRecords", + "resource_name": "FeatureGroup", + "method_name": "list_records", + "return_type": "ListRecordsResponse", + "method_type": "object", + "service_name": "sagemaker-featurestore-runtime", + "exclude_resource_attributes": ["next_token"] } } } diff --git a/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py b/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py index eb8f536d11..0d796ceb24 100644 --- a/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py +++ b/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py @@ -1493,13 +1493,23 @@ def generate_method(self, method: Method, resource_attributes: list): else: decorator = "" method_args = add_indent("self,\n", 4) + # Allow operations to exclude specific resource attributes from self-mapping + # This is 
needed when a resource attribute name collides with an unrelated + # operation input (e.g., FeatureGroup.next_token vs ListRecords.NextToken) + exclude_resource_attrs_override = getattr(method, "exclude_resource_attributes", []) + effective_resource_attributes = [ + attr for attr in resource_attributes if attr not in exclude_resource_attrs_override + ] method_args += ( - self._generate_method_args(operation_input_shape_name, resource_attributes) + "\n" + self._generate_method_args( + operation_input_shape_name, effective_resource_attributes + ) + + "\n" ) operation_input_args = self._generate_operation_input_args_updated( - operation_metadata, False, resource_attributes + operation_metadata, False, effective_resource_attributes ) - exclude_resource_attrs = resource_attributes + exclude_resource_attrs = effective_resource_attributes method_args += add_indent("session: Optional[Session] = None,\n", 4) method_args += add_indent("region: Optional[str] = None,", 4) diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index 5d0de63efd..61d7a1b223 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -1597,6 +1597,51 @@ ], "type": "structure", }, + "BatchWriteRecordEntries": { + "member_shape": "BatchWriteRecordEntry", + "member_type": "structure", + "type": "list", + }, + "BatchWriteRecordEntry": { + "members": [ + {"name": "FeatureGroupName", "shape": "FeatureGroupNameOrArn", "type": "string"}, + {"name": "Record", "shape": "Record", "type": "list"}, + {"name": "TargetStores", "shape": "TargetStores", "type": "list"}, + {"name": "TtlDuration", "shape": "TtlDuration", "type": "structure"}, + ], + "type": "structure", + }, + "BatchWriteRecordError": { + "members": [ + {"name": "Entry", "shape": "BatchWriteRecordEntry", "type": "structure"}, + {"name": "ErrorCode", "shape": 
"ValueAsString", "type": "string"}, + {"name": "ErrorMessage", "shape": "Message", "type": "string"}, + ], + "type": "structure", + }, + "BatchWriteRecordErrors": { + "member_shape": "BatchWriteRecordError", + "member_type": "structure", + "type": "list", + }, + "BatchWriteRecordRequest": { + "members": [ + {"name": "Entries", "shape": "BatchWriteRecordEntries", "type": "list"}, + {"name": "TtlDuration", "shape": "TtlDuration", "type": "structure"}, + ], + "type": "structure", + }, + "BatchWriteRecordResponse": { + "members": [ + {"name": "Errors", "shape": "BatchWriteRecordErrors", "type": "list"}, + { + "name": "UnprocessedEntries", + "shape": "UnprocessedBatchWriteRecordEntries", + "type": "list", + }, + ], + "type": "structure", + }, "BedrockCustomModelDeploymentMetadata": { "members": [{"name": "Arn", "shape": "String1024", "type": "string"}], "type": "structure", @@ -12221,6 +12266,22 @@ ], "type": "structure", }, + "ListRecordsRequest": { + "members": [ + {"name": "FeatureGroupName", "shape": "FeatureGroupNameOrArn", "type": "string"}, + {"name": "MaxResults", "shape": "ListRecordsMaxResults", "type": "integer"}, + {"name": "NextToken", "shape": "ListRecordsNextToken", "type": "string"}, + {"name": "IncludeSoftDeletedRecords", "shape": "Boolean", "type": "boolean"}, + ], + "type": "structure", + }, + "ListRecordsResponse": { + "members": [ + {"name": "RecordIdentifiers", "shape": "RecordIdentifierList", "type": "list"}, + {"name": "NextToken", "shape": "ListRecordsNextToken", "type": "string"}, + ], + "type": "structure", + }, "ListResourceCatalogsRequest": { "members": [ {"name": "NameContains", "shape": "ResourceCatalogName", "type": "string"}, @@ -15583,6 +15644,11 @@ "type": "structure", }, "Record": {"member_shape": "FeatureValue", "member_type": "structure", "type": "list"}, + "RecordIdentifierList": { + "member_shape": "ValueAsString", + "member_type": "string", + "type": "list", + }, "RecordIdentifiers": {"member_shape": "ValueAsString", 
"member_type": "string", "type": "list"}, "RedshiftDatasetDefinition": { "members": [ @@ -17743,6 +17809,11 @@ ], "type": "structure", }, + "UnprocessedBatchWriteRecordEntries": { + "member_shape": "BatchWriteRecordEntry", + "member_type": "structure", + "type": "list", + }, "UnprocessedIdentifiers": { "member_shape": "BatchGetRecordIdentifier", "member_type": "structure", diff --git a/sagemaker-core/tests/unit/generated/test_feature_store_operations.py b/sagemaker-core/tests/unit/generated/test_feature_store_operations.py new file mode 100644 index 0000000000..9b54acb34d --- /dev/null +++ b/sagemaker-core/tests/unit/generated/test_feature_store_operations.py @@ -0,0 +1,324 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Unit tests for FeatureGroup batch_write_record and list_records methods.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import patch, MagicMock + +from sagemaker.core.resources import FeatureGroup +from sagemaker.core.shapes.shapes import ( + BatchWriteRecordEntry, + BatchWriteRecordResponse, + FeatureValue, + ListRecordsResponse, + TtlDuration, +) + + +@pytest.fixture +def mock_feature_group(): + """Create a FeatureGroup instance with mocked internals.""" + fg = FeatureGroup.model_construct( + feature_group_name="test-feature-group", + next_token=None, + ) + return fg + + +class TestBatchWriteRecord: + """Tests for FeatureGroup.batch_write_record method.""" + + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_batch_write_record_success(self, mock_get_client, mock_transform, mock_feature_group): + """Test that batch_write_record calls the client with correct arguments.""" + mock_client = MagicMock() + mock_client.batch_write_record.return_value = { + "Errors": [], + "UnprocessedEntries": [], + } + mock_get_client.return_value = mock_client + mock_transform.return_value = { + "errors": [], + "unprocessed_entries": [], + } + + entries = [ + BatchWriteRecordEntry( + feature_group_name="test-feature-group", + record=[FeatureValue(feature_name="feature1", value_as_string="value1")], + ) + ] + + mock_feature_group.batch_write_record(entries=entries) + + mock_get_client.assert_called_once_with( + session=None, region_name=None, service_name="sagemaker-featurestore-runtime" + ) + mock_client.batch_write_record.assert_called_once() + call_kwargs = mock_client.batch_write_record.call_args[1] + assert "Entries" in call_kwargs + mock_transform.assert_called_once() + + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_batch_write_record_with_ttl_duration( + self, mock_get_client, mock_transform, 
mock_feature_group
+    ):
+        """Test batch_write_record with optional ttl_duration parameter."""
+        mock_client = MagicMock()
+        mock_client.batch_write_record.return_value = {
+            "Errors": [],
+            "UnprocessedEntries": [],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "errors": [],
+            "unprocessed_entries": [],
+        }
+
+        entries = [
+            BatchWriteRecordEntry(
+                feature_group_name="test-feature-group",
+                record=[FeatureValue(feature_name="feature1", value_as_string="value1")],
+            )
+        ]
+        ttl = TtlDuration(unit="Hours", value=24)
+
+        mock_feature_group.batch_write_record(entries=entries, ttl_duration=ttl)
+
+        mock_get_client.assert_called_once_with(
+            session=None, region_name=None, service_name="sagemaker-featurestore-runtime"
+        )
+        mock_client.batch_write_record.assert_called_once()
+        call_kwargs = mock_client.batch_write_record.call_args[1]
+        assert "Entries" in call_kwargs
+        assert "TtlDuration" in call_kwargs
+
+    @patch("sagemaker.core.resources.transform")  # stacked patches inject bottom-up: second mock argument
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")  # first mock argument
+    def test_batch_write_record_returns_response(
+        self, mock_get_client, mock_transform, mock_feature_group
+    ):
+        """Test that batch_write_record returns a BatchWriteRecordResponse with empty errors and unprocessed_entries."""
+        mock_client = MagicMock()
+        mock_client.batch_write_record.return_value = {
+            "Errors": [],
+            "UnprocessedEntries": [],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "errors": [],
+            "unprocessed_entries": [],
+        }
+
+        entries = [
+            BatchWriteRecordEntry(
+                feature_group_name="test-feature-group",
+                record=[FeatureValue(feature_name="feature1", value_as_string="value1")],
+            )
+        ]
+
+        result = mock_feature_group.batch_write_record(entries=entries)
+
+        assert isinstance(result, BatchWriteRecordResponse)
+        assert result.errors == []
+        assert result.unprocessed_entries == []
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_batch_write_record_returns_response_with_nested_data(
+        self, mock_get_client, mock_transform, mock_feature_group
+    ):
+        """Test that batch_write_record returns one error and one unprocessed entry when transform yields nested data."""
+        mock_client = MagicMock()
+        mock_client.batch_write_record.return_value = {
+            "Errors": [
+                {
+                    "Entry": {
+                        "FeatureGroupName": "test-feature-group",
+                        "Record": [{"FeatureName": "f1", "ValueAsString": "v1"}],
+                    },
+                    "ErrorCode": "ValidationError",
+                    "ErrorMessage": "Invalid feature value",
+                }
+            ],
+            "UnprocessedEntries": [
+                {
+                    "FeatureGroupName": "test-feature-group",
+                    "Record": [{"FeatureName": "f2", "ValueAsString": "v2"}],
+                }
+            ],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {  # NOTE(review): transform is patched; nested field values are not asserted below
+            "errors": [
+                {
+                    "entry": {
+                        "feature_group_name": "test-feature-group",
+                        "record": [{"feature_name": "f1", "value_as_string": "v1"}],
+                    },
+                    "error_code": "ValidationError",
+                    "error_message": "Invalid feature value",
+                }
+            ],
+            "unprocessed_entries": [
+                {
+                    "feature_group_name": "test-feature-group",
+                    "record": [{"feature_name": "f2", "value_as_string": "v2"}],
+                }
+            ],
+        }
+
+        entries = [
+            BatchWriteRecordEntry(
+                feature_group_name="test-feature-group",
+                record=[FeatureValue(feature_name="feature1", value_as_string="value1")],
+            )
+        ]
+
+        result = mock_feature_group.batch_write_record(entries=entries)
+
+        assert isinstance(result, BatchWriteRecordResponse)
+        assert len(result.errors) == 1  # only the list lengths are checked
+        assert len(result.unprocessed_entries) == 1
+
+
+class TestListRecords:
+    """Tests for the FeatureGroup.list_records method."""
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_list_records_success(self, mock_get_client, mock_transform, mock_feature_group):
+        """Test that list_records calls the client once with FeatureGroupName and invokes transform once."""
+        mock_client = MagicMock()
+        mock_client.list_records.return_value = {
+            "RecordIdentifiers": ["id1", "id2"],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "record_identifiers": ["id1", "id2"],
+        }
+
+        mock_feature_group.list_records()
+
+        mock_get_client.assert_called_once_with(
+            session=None, region_name=None, service_name="sagemaker-featurestore-runtime"
+        )
+        mock_client.list_records.assert_called_once()
+        call_kwargs = mock_client.list_records.call_args[1]  # keyword arguments of the recorded call
+        assert call_kwargs["FeatureGroupName"] == "test-feature-group"
+        mock_transform.assert_called_once()
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_list_records_with_parameters(
+        self, mock_get_client, mock_transform, mock_feature_group
+    ):
+        """Test that max_results and include_soft_deleted_records are sent as MaxResults and IncludeSoftDeletedRecords."""
+        mock_client = MagicMock()
+        mock_client.list_records.return_value = {
+            "RecordIdentifiers": ["id1"],
+            "NextToken": "token123",
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "record_identifiers": ["id1"],
+            "next_token": "token123",
+        }
+
+        mock_feature_group.list_records(
+            max_results=10, include_soft_deleted_records=True
+        )
+
+        mock_get_client.assert_called_once_with(
+            session=None, region_name=None, service_name="sagemaker-featurestore-runtime"
+        )
+        mock_client.list_records.assert_called_once()
+        call_kwargs = mock_client.list_records.call_args[1]  # keyword arguments of the recorded call
+        assert call_kwargs["FeatureGroupName"] == "test-feature-group"
+        assert call_kwargs["MaxResults"] == 10
+        assert call_kwargs["IncludeSoftDeletedRecords"] is True
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_list_records_returns_response(
+        self, mock_get_client, mock_transform, mock_feature_group
+    ):
+        """Test that list_records returns a ListRecordsResponse with the expected record_identifiers."""
+        mock_client = MagicMock()
+        mock_client.list_records.return_value = {
+            "RecordIdentifiers": ["id1", "id2", "id3"],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "record_identifiers": ["id1", "id2", "id3"],
+        }
+
+        result = mock_feature_group.list_records()
+
+        assert isinstance(result, ListRecordsResponse)
+        assert result.record_identifiers == ["id1", "id2", "id3"]
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_list_records_does_not_use_self_next_token(
+        self, mock_get_client, mock_transform
+    ):
+        """Test that list_records does NOT pass self.next_token (from DescribeFeatureGroup) to ListRecords."""
+        fg = FeatureGroup.model_construct(  # model_construct: pydantic constructor that skips validation
+            feature_group_name="test-feature-group",
+            next_token="describe-pagination-token",
+        )
+        mock_client = MagicMock()
+        mock_client.list_records.return_value = {
+            "RecordIdentifiers": ["id1"],
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "record_identifiers": ["id1"],
+        }
+
+        fg.list_records()
+
+        call_kwargs = mock_client.list_records.call_args[1]
+        # self.next_token should NOT be passed to ListRecords
+        assert "NextToken" not in call_kwargs
+
+    @patch("sagemaker.core.resources.transform")
+    @patch("sagemaker.core.resources.Base.get_sagemaker_client")
+    def test_list_records_accepts_next_token_parameter(
+        self, mock_get_client, mock_transform
+    ):
+        """Test that an explicit next_token argument is sent as NextToken instead of self.next_token."""
+        fg = FeatureGroup.model_construct(  # instance carries a different next_token than the one passed below
+            feature_group_name="test-feature-group",
+            next_token="describe-pagination-token",
+        )
+        mock_client = MagicMock()
+        mock_client.list_records.return_value = {
+            "RecordIdentifiers": ["id1"],
+            "NextToken": "next-page-token",
+        }
+        mock_get_client.return_value = mock_client
+        mock_transform.return_value = {
+            "record_identifiers": ["id1"],
+            "next_token": "next-page-token",
+        }
+
+        fg.list_records(next_token="list-records-page-2-token")
+
+        call_kwargs = mock_client.list_records.call_args[1]
+        # The explicitly passed next_token should be used, not self.next_token
+        assert call_kwargs["NextToken"] == "list-records-page-2-token"