diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 1a24d31..0514767 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -2,6 +2,8 @@ # Main context - used in every durable function # Helper decorators - commonly used for step functions +# Concurrency +from aws_durable_execution_sdk_python.concurrency.models import BatchResult from aws_durable_execution_sdk_python.context import ( DurableContext, durable_step, @@ -20,7 +22,7 @@ from aws_durable_execution_sdk_python.execution import durable_execution # Essential context types - passed to user functions -from aws_durable_execution_sdk_python.types import BatchResult, StepContext +from aws_durable_execution_sdk_python.types import StepContext __all__ = [ "BatchResult", diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 8efaed0..46d48a1 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -2,6 +2,7 @@ import hashlib import logging +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar from aws_durable_execution_sdk_python.config import ( @@ -74,6 +75,20 @@ PASS_THROUGH_SERDES: SerDes[Any] = PassThroughSerDes() +@dataclass(frozen=True) +class ExecutionContext: + """Readonly metadata about the current durable execution context. + + This class provides immutable access to execution-level metadata. + + Attributes: + durable_execution_arn: The Amazon Resource Name (ARN) of the current + durable execution. + """ + + durable_execution_arn: str + + def durable_step( func: Callable[Concatenate[StepContext, Params], T], ) -> Callable[Params, Callable[[StepContext], T]]: @@ -218,11 +233,13 @@ class DurableContext(DurableContextProtocol): def __init__( self, state: ExecutionState, + execution_context: ExecutionContext, lambda_context: LambdaContext | None = None, parent_id: str | None = None, logger: Logger | None = None, ) -> None: self.state: ExecutionState = state + self.execution_context: ExecutionContext = execution_context self.lambda_context = lambda_context self._parent_id: str | None = parent_id self._step_counter: OrderedCounter = OrderedCounter() @@ -245,6 +262,9 @@ def from_lambda_context( ): return DurableContext( state=state, + execution_context=ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ), lambda_context=lambda_context, parent_id=None, ) @@ -254,6 +274,7 @@ def create_child_context(self, parent_id: str) -> DurableContext: logger.debug("Creating child context for parent %s", parent_id) return DurableContext( state=self.state, + execution_context=self.execution_context, lambda_context=self.lambda_context, parent_id=parent_id, logger=self.logger.with_log_info( diff --git a/tests/context_test.py b/tests/context_test.py index 4e43347..507cfc5 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -16,7 +16,11 @@ ParallelConfig, StepConfig, ) -from aws_durable_execution_sdk_python.context import Callback, DurableContext +from aws_durable_execution_sdk_python.context import ( + Callback, + DurableContext, + ExecutionContext, +) from aws_durable_execution_sdk_python.exceptions import ( CallbackError, SuspendExecution, @@ -39,6 +43,24 @@ from tests.test_helpers import operation_id_sequence +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_durable_context(): """Test the context module.""" assert DurableContext is not None @@ -250,7 +272,7 @@ def test_create_callback_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -282,7 +304,7 @@ def test_create_callback_with_name_and_config(mock_executor_class): ) config = CallbackConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() [next(operation_ids) for _ in range(5)] # Skip 5 IDs expected_operation_id = next(operation_ids) # Get the 6th ID @@ -315,7 +337,7 @@ def test_create_callback_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") operation_ids = operation_id_sequence("parent123") [next(operation_ids) for _ in range(2)] # Skip 2 IDs expected_operation_id = next(operation_ids) # Get the 3rd ID @@ -345,7 +367,7 @@ def test_create_callback_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 callback1 = context.create_callback() @@ -383,7 +405,7 @@ def test_step_basic(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -418,7 +440,7 @@ def test_step_with_name_and_config(mock_executor_class): ) # Ensure Mock doesn't have _original_name config = StepConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 result = context.step(mock_callable, config=config) @@ -456,7 +478,7 @@ def test_step_with_parent_id(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.step(mock_callable) @@ -493,7 +515,7 @@ def test_step_increments_counter(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.step(mock_callable) @@ -529,7 +551,7 @@ def test_step_with_original_name(mock_executor_class): mock_callable = Mock() mock_callable._original_name = "original_function" # noqa: SLF001 - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.step(mock_callable, name="override_name") @@ -564,7 +586,7 @@ def test_invoke_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -596,7 +618,7 @@ def test_invoke_with_name_and_config(mock_executor_class): ) config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 result = context.invoke( @@ -632,7 +654,7 @@ def test_invoke_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.invoke("test_function", None) @@ -664,7 +686,7 @@ def test_invoke_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.invoke("function1", "payload1") @@ -697,7 +719,7 @@ def test_invoke_with_none_payload(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", None) @@ -737,7 +759,7 @@ def test_invoke_with_custom_serdes(mock_executor_class): timeout=Duration.from_minutes(1), ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke( "test_function", @@ -778,7 +800,7 @@ def test_wait_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -804,7 +826,7 @@ def test_wait_with_name(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 context.wait(Duration.from_minutes(1), name="test_wait") @@ -833,7 +855,7 @@ def test_wait_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.wait(Duration.from_seconds(45)) @@ -862,7 +884,7 @@ def test_wait_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.wait(Duration.from_seconds(15)) @@ -894,7 +916,7 @@ def test_wait_returns_none(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait(Duration.from_seconds(10)) @@ -913,7 +935,7 @@ def test_wait_with_time_less_than_one(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) with pytest.raises(ValidationError): context.wait(Duration.from_seconds(0)) @@ -936,7 +958,7 @@ def test_run_in_child_context_basic(mock_handler): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -967,7 +989,7 @@ def test_run_in_child_context_with_name_and_config(mock_handler): config = ChildConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(3)] # Set counter to 3 # noqa: SLF001 result = context.run_in_child_context(mock_callable, config=config) @@ -1001,7 +1023,7 @@ def test_run_in_child_context_with_parent_id(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure Mock doesn't have _original_name - context = DurableContext(state=mock_state, parent_id="parent456") + context = create_test_context(state=mock_state, parent_id="parent456") [context._create_step_id() for _ in range(1)] # Set counter to 1 # noqa: SLF001 context.run_in_child_context(mock_callable) @@ -1037,7 +1059,7 @@ def capture_child_context(child_context): mock_callable = Mock(side_effect=capture_child_context) mock_executor_class.side_effect = lambda func, **kwargs: func() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.run_in_child_context(mock_callable) @@ -1062,7 +1084,7 @@ def test_run_in_child_context_increments_counter(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 context.run_in_child_context(mock_callable) @@ -1097,7 +1119,7 @@ def test_run_in_child_context_resolves_name_from_callable(mock_executor_class): mock_callable = Mock() mock_callable._original_name = "original_function_name" # noqa: SLF001 - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.run_in_child_context(mock_callable) @@ -1128,7 +1150,7 @@ def test_wait_for_callback_basic(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter) @@ -1158,7 +1180,7 @@ def test_wait_for_callback_with_name_and_config(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "configured_callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter, config=config) @@ -1186,7 +1208,7 @@ def test_wait_for_callback_resolves_name_from_submitter(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "named_callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.wait_for_callback(mock_submitter) @@ -1214,11 +1236,11 @@ def capture_handler_call(context, submitter, name, config): def run_child_context(callable_func, name): # Execute the child context callable - child_context = DurableContext(state=mock_state, parent_id="test") + child_context = create_test_context(state=mock_state, parent_id="test") return callable_func(child_context) mock_run_in_child.side_effect = run_child_context - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter) @@ -1244,7 +1266,7 @@ def test_function(context, item, index, items): inputs = [1, 2, 3] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1273,7 +1295,7 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function, name="custom_map", config=config) @@ -1298,7 +1320,7 @@ def test_function(context, item, index, items): inputs = ["hello", "world"] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1322,7 +1344,7 @@ def test_function(context, item, index, items): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "empty_map_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1345,7 +1367,7 @@ def test_function(context, item, index, items): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "mixed_map_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1373,7 +1395,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1403,7 +1425,7 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables, name="custom_parallel", config=config) @@ -1435,7 +1457,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) # Use _resolve_step_name to test name resolution resolved_name = context._resolve_step_name(None, mock_callable) # noqa: SLF001 @@ -1466,7 +1488,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1487,7 +1509,7 @@ def test_parallel_with_empty_callables(mock_handler): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "empty_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1510,7 +1532,7 @@ def single_task(context): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "single_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1536,7 +1558,7 @@ def task(context): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "many_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1562,7 +1584,7 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function, config=config) @@ -1588,7 +1610,7 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables, config=config) @@ -1603,7 +1625,7 @@ def test_wait_for_condition_validation_errors(): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) def dummy_wait_strategy(state, attempt): return None @@ -1640,7 +1662,7 @@ def test_function(context, item, index, items): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Mock the handlers to track calls with patch( @@ -1680,7 +1702,7 @@ def test_callable_2(context): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Mock the handlers to track calls with patch( @@ -1719,7 +1741,7 @@ def test_wait_strategy(state, attempt): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Create config config = WaitForConditionConfig( @@ -1820,7 +1842,7 @@ def test_invoke_with_explicit_tenant_id(mock_executor_class): ) config = InvokeConfig(tenant_id="explicit-tenant") - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", "payload", config=config) @@ -1842,7 +1864,7 @@ def test_invoke_without_tenant_id_defaults_to_none(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", "payload") @@ -1851,3 +1873,89 @@ def test_invoke_without_tenant_id_defaults_to_none(mock_executor_class): call_args = mock_executor_class.call_args[1] assert isinstance(call_args["config"], InvokeConfig) assert call_args["config"].tenant_id is None + + +# region ExecutionContext tests + + +def test_execution_context_exists_on_durable_context(): + """Test that DurableContext has execution_context attribute.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test-execution" + ) + + context = create_test_context(state=mock_state) + + assert hasattr(context, "execution_context") + assert context.execution_context is not None + + +def test_execution_context_has_correct_arn(): + """Test that ExecutionContext contains the correct durable_execution_arn.""" + expected_arn = "arn:aws:durable:us-west-2:987654321098:execution/my-execution" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = expected_arn + + context = create_test_context(state=mock_state) + + assert context.execution_context.durable_execution_arn == expected_arn + + +def test_execution_context_is_immutable(): + """Test that ExecutionContext is frozen and immutable.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = create_test_context(state=mock_state) + + # Attempt to modify should raise FrozenInstanceError for frozen dataclass + with pytest.raises(AttributeError, match="cannot assign to field"): + context.execution_context.durable_execution_arn = "new-arn" + + +def test_execution_context_propagates_to_child_context(): + """Test that child contexts inherit the same execution_context.""" + parent_arn = "arn:aws:durable:eu-west-1:111222333444:execution/parent-exec" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = parent_arn + + parent_context = create_test_context(state=mock_state) + child_context = parent_context.create_child_context(parent_id="parent-op-123") + + assert child_context.execution_context is not None + assert child_context.execution_context.durable_execution_arn == parent_arn + # Should be the same instance (not a copy) + assert child_context.execution_context is parent_context.execution_context + + +def test_from_lambda_context_creates_execution_context(): + """Test that from_lambda_context factory creates ExecutionContext.""" + expected_arn = "arn:aws:durable:ap-south-1:555666777888:execution/lambda-exec" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = expected_arn + mock_lambda_context = Mock() + + context = DurableContext.from_lambda_context( + state=mock_state, lambda_context=mock_lambda_context + ) + + assert context.execution_context is not None + assert context.execution_context.durable_execution_arn == expected_arn + + +def test_execution_context_type(): + """Test that execution_context is of type ExecutionContext.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = create_test_context(state=mock_state) + + assert isinstance(context.execution_context, ExecutionContext) + + +# endregion ExecutionContext tests diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index 5c5a5a1..69d2f31 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -19,15 +19,34 @@ ItemBatcher, MapConfig, ) -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation import child # PLC0415 from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from aws_durable_execution_sdk_python.serdes import serialize +from aws_durable_execution_sdk_python.state import ExecutionState from tests.serdes_test import CustomStrSerDes +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_map_executor_init(): """Test MapExecutor initialization.""" executables = [Executable(index=0, func=lambda: None)] @@ -808,7 +827,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.map( ["a", "b"], lambda ctx, item, idx, items: item, @@ -870,7 +889,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.map( ["a", "b"], lambda ctx, item, idx, items: item, @@ -970,7 +989,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) assert len(mock_serdes_serialize.call_args_list) == 3 @@ -1022,7 +1041,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) assert isinstance(result, BatchResult) @@ -1078,7 +1097,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map( ["a", "b"], lambda ctx, item, idx, items: item, diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index c43be7e..5021788 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -17,7 +17,7 @@ Executable, ) from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation import child @@ -26,9 +26,28 @@ parallel_handler, ) from aws_durable_execution_sdk_python.serdes import serialize +from aws_durable_execution_sdk_python.state import ExecutionState from tests.serdes_test import CustomStrSerDes +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_parallel_executor_init(): """Test ParallelExecutor initialization.""" executables = [Executable(index=0, func=lambda x: x)] @@ -791,7 +810,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), @@ -852,7 +871,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), @@ -964,7 +983,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) assert len(mock_serdes_serialize.call_args_list) == 3 @@ -1015,7 +1034,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) assert isinstance(result, BatchResult) @@ -1071,7 +1090,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=custom_serdes), diff --git a/tests/test_helpers.py b/tests/test_helpers.py index dca15a0..77611a3 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.execution import ExecutionState @@ -11,7 +11,10 @@ def operation_id_sequence(parent_id: str | None = None): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test-arn" - context = DurableContext(state=mock_state, parent_id=parent_id) + execution_context = ExecutionContext(durable_execution_arn="test-arn") + context = DurableContext( + state=mock_state, execution_context=execution_context, parent_id=parent_id + ) while True: yield context._create_step_id() # noqa: SLF001