diff --git a/.github/workflows/durabletask-azuremanaged-dev.yml b/.github/workflows/durabletask-azuremanaged-dev.yml
new file mode 100644
index 00000000..0ba1ece0
--- /dev/null
+++ b/.github/workflows/durabletask-azuremanaged-dev.yml
@@ -0,0 +1,52 @@
+name: Durable Task Scheduler SDK (durabletask-azuremanaged) Dev Release
+
+on:
+  workflow_run:
+    workflows: ["Durable Task Scheduler SDK (durabletask-azuremanaged)"]
+    types:
+      - completed
+    branches:
+      - main
+
+jobs:
+  publish-dev:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Extract version from tag
+        run: echo "VERSION=${GITHUB_REF#refs/tags/azuremanaged-v}" >> $GITHUB_ENV # Extract version from the tag
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.14" # Adjust Python version as needed
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Append dev to version in pyproject.toml
+        working-directory: durabletask-azuremanaged
+        run: |
+          sed -i 's/^version = "\(.*\)"/version = "\1.dev${{ github.run_number }}"/' pyproject.toml
+
+      - name: Build package from directory durabletask-azuremanaged
+        working-directory: durabletask-azuremanaged
+        run: |
+          python -m build
+
+      - name: Check package
+        working-directory: durabletask-azuremanaged
+        run: |
+          twine check dist/*
+
+      - name: Publish package to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN_AZUREMANAGED }} # Store your PyPI API token in GitHub Secrets
+        working-directory: durabletask-azuremanaged
+        run: |
+          twine upload dist/*
\ No newline at end of file
diff --git a/.github/workflows/durabletask-azuremanaged-experimental.yml b/.github/workflows/durabletask-azuremanaged-experimental.yml
new file mode 100644
index 00000000..444b7f96
--- /dev/null
+++ b/.github/workflows/durabletask-azuremanaged-experimental.yml
@@ -0,0 +1,51 @@
+name: Durable Task Scheduler SDK (durabletask-azuremanaged) Experimental Release
+
+on:
+  push:
+    branches-ignore:
+      - main
+      - release/*
+
+jobs:
+  publish-experimental:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Extract version from tag
+        run: echo "VERSION=${GITHUB_REF#refs/tags/azuremanaged-v}" >> $GITHUB_ENV # Extract version from the tag
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.14" # Adjust Python version as needed
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Change the version in pyproject.toml to 0.0.0.dev{github.run_number}
+        working-directory: durabletask-azuremanaged
+        run: |
+          sed -i 's/^version = ".*"/version = "0.0.0.dev${{ github.run_number }}"/' pyproject.toml
+          sed -i 's/"durabletask>=.*"/"durabletask>=0.0.0.dev1"/' pyproject.toml
+
+      - name: Build package from directory durabletask-azuremanaged
+        working-directory: durabletask-azuremanaged
+        run: |
+          python -m build
+
+      - name: Check package
+        working-directory: durabletask-azuremanaged
+        run: |
+          twine check dist/*
+
+      - name: Publish package to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN_AZUREMANAGED }} # Store your PyPI API token in GitHub Secrets
+        working-directory: durabletask-azuremanaged
+        run: |
+          twine upload dist/*
\ No newline at end of file
diff --git a/.github/workflows/durabletask-azuremanaged.yml b/.github/workflows/durabletask-azuremanaged.yml
index c2c40aee..852b06d7 100644
--- a/.github/workflows/durabletask-azuremanaged.yml
+++ b/.github/workflows/durabletask-azuremanaged.yml
@@ -86,7 +86,7 @@ jobs:
       run: |
         pytest -m "dts" --verbose
 
-  publish:
+  publish-release:
    if: startsWith(github.ref, 'refs/tags/azuremanaged-v') # Only run if a matching tag is pushed
    needs: run-docker-tests
    runs-on: ubuntu-latest
diff --git a/.github/workflows/durabletask-dev.yml b/.github/workflows/durabletask-dev.yml
new file mode 100644
index 00000000..09ee4be4
--- /dev/null
+++ b/.github/workflows/durabletask-dev.yml
@@ -0,0 +1,49 @@
+name: Durable Task SDK (durabletask) Dev Release
+
+on:
+  workflow_run:
+    workflows: ["Durable Task SDK (durabletask)"]
+    types:
+      - completed
+    branches:
+      - main
+
+jobs:
+  publish-dev:
+    # needs: run-tests
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Extract version from tag
+        run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV # Extract version from the tag
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.14" # Adjust Python version as needed
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Append dev to version in pyproject.toml
+        run: |
+          sed -i 's/^version = "\(.*\)"/version = "\1.dev${{ github.run_number }}"/' pyproject.toml
+
+      - name: Build package from root directory
+        run: |
+          python -m build
+
+      - name: Check package
+        run: |
+          twine check dist/*
+
+      - name: Publish package to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} # Store your PyPI API token in GitHub Secrets
+        run: |
+          twine upload dist/*
\ No newline at end of file
diff --git a/.github/workflows/durabletask-experiment.yml b/.github/workflows/durabletask-experiment.yml
new file mode 100644
index 00000000..a9d440a5
--- /dev/null
+++ b/.github/workflows/durabletask-experiment.yml
@@ -0,0 +1,47 @@
+name: Durable Task SDK (durabletask) Experimental Release
+
+on:
+  push:
+    branches-ignore:
+      - main
+      - release/*
+
+jobs:
+  publish-experimental:
+    # needs: run-tests
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Extract version from tag
+        run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV # Extract version from the tag
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.14" # Adjust Python version as needed
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Change the version in pyproject.toml to 0.0.0.dev{github.run_number}
+        run: |
+          sed -i 's/^version = ".*"/version = "0.0.0.dev${{ github.run_number }}"/' pyproject.toml
+
+      - name: Build package from root directory
+        run: |
+          python -m build
+
+      - name: Check package
+        run: |
+          twine check dist/*
+
+      - name: Publish package to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} # Store your PyPI API token in GitHub Secrets
+        run: |
+          twine upload dist/*
\ No newline at end of file
diff --git a/.github/workflows/durabletask.yml b/.github/workflows/durabletask.yml
index 2f417d9b..e7465ef8 100644
--- a/.github/workflows/durabletask.yml
+++ b/.github/workflows/durabletask.yml
@@ -2,7 +2,7 @@ name: Durable Task SDK (durabletask)
 
 on:
   push:
-    branches: 
+    branches:
       - "main"
     tags:
       - "v*" # Only run for tags starting with "v"
@@ -71,7 +71,7 @@ jobs:
         durabletask-go --port 4001 &
         pytest -m "e2e and not dts" --verbose
 
-  publish:
+  publish-release:
    if: startsWith(github.ref, 'refs/tags/v') # Only run if a matching tag is pushed
    needs: run-tests
    runs-on: ubuntu-latest
@@ -105,4 +105,4 @@ jobs:
         TWINE_USERNAME: __token__
         TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} # Store your PyPI API token in GitHub Secrets
       run: |
-        twine upload dist/*
\ No newline at end of file
+        twine upload dist/*
\ No newline at end of file
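The dev and experimental workflows above lean on PEP 440 pre-release ordering: a `.devN` suffix sorts *below* the release it precedes, so publishing these builds to PyPI cannot shadow a stable release for default installs. A minimal sketch (not part of the workflows) illustrating the ordering with the `packaging` library the SDK already uses:

```python
# Sketch: why the ".devN" suffixes written by the sed steps are safe to publish.
from packaging.version import parse

release = parse("1.2.0")
dev_build = parse("1.2.0.dev42")    # shape produced by the dev workflows
experimental = parse("0.0.0.dev7")  # shape produced by the experimental workflows

assert dev_build < release           # dev builds sort below the real release
assert experimental < dev_build      # experimental builds sort below everything
# PEP 440 normalization treats "0.0.0dev1" and "0.0.0.dev1" as the same version:
assert parse("0.0.0dev1") == parse("0.0.0.dev1")
```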
diff --git a/CHANGELOG.md b/CHANGELOG.md
index daffc504..a2c3e598 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,9 +5,30 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## v1.2.0
+
+ADDED:
+
+- Added new_uuid method to orchestration contexts, allowing generation of replay-safe UUIDs
+- Added ProtoTaskHubSidecarServiceStub class to allow passing self-generated stubs to the worker
+- Added support for new event types needed for specific durable backend setups:
+  - orchestratorCompleted
+  - eventSent
+  - eventRaised modified to support entity events
+
+CHANGED:
+
+- Added py.typed marker file to the durabletask module
+- Updated type hinting on EntityInstanceId.parse() to reflect its behavior
+- Entity operations now use UUIDs generated with new_uuid
+
+FIXED:
+
+- Mismatched parameter names in call_entity/signal_entity relative to the interface definition
+
 ## v1.1.0
 
-ADDED: 
+ADDED:
 
 - Allow retrieving entity metadata from the client, with or without state
diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md
index efc31e0f..8d88678e 100644
--- a/durabletask-azuremanaged/CHANGELOG.md
+++ b/durabletask-azuremanaged/CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## v1.2.0
+
+- Updated the base dependency to durabletask v1.2.0
+  - See the durabletask changelog for more details
+
 ## v1.1.0
 
 CHANGED:
diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml
index 5c502461..f013a565 100644
--- a/durabletask-azuremanaged/pyproject.toml
+++ b/durabletask-azuremanaged/pyproject.toml
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "durabletask.azuremanaged"
-version = "1.1.0"
+version = "1.2.0"
 description = "Durable Task Python SDK provider implementation for the Azure Durable Task Scheduler"
 keywords = [
     "durable",
@@ -26,7 +26,7 @@ requires-python = ">=3.10"
 license = {file = "LICENSE"}
 readme = "README.md"
 dependencies = [
-    "durabletask>=1.1.0",
+    "durabletask>=1.2.0",
     "azure-identity>=1.19.0"
 ]
diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py
index c3b76c13..02a2595a 100644
--- a/durabletask/entities/entity_instance_id.py
+++ b/durabletask/entities/entity_instance_id.py
@@ -1,6 +1,3 @@
-from typing import Optional
-
-
 class EntityInstanceId:
     def __init__(self, entity: str, key: str):
         self.entity = entity
@@ -20,7 +17,7 @@ def __lt__(self, other):
         return str(self) < str(other)
 
     @staticmethod
-    def parse(entity_id: str) -> Optional["EntityInstanceId"]:
+    def parse(entity_id: str) -> "EntityInstanceId":
        """Parse a string representation of an entity ID into an EntityInstanceId object.
 
        Parameters
@@ -30,8 +27,13 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]:
 
        Returns
        -------
-       Optional[EntityInstanceId]
-           The parsed EntityInstanceId object, or None if the input is None.
+       EntityInstanceId
+           The parsed EntityInstanceId object.
+
+       Raises
+       ------
+       ValueError
+           If the input string is not in the correct format.
        """
        try:
            _, entity, key = entity_id.split("@", 2)
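Since `EntityInstanceId.parse` now raises instead of returning `None`, call sites branch with `try`/`except` rather than a `None` check. A minimal usage sketch; the `@entity@key` string shape is inferred from the `split("@", 2)` call above:

```python
# Sketch: the new parse contract — malformed IDs fail loudly with ValueError.
from durabletask.entities import EntityInstanceId

entity_id = EntityInstanceId.parse("@counter@my-key")  # "@<entity>@<key>"
assert entity_id.entity == "counter" and entity_id.key == "my-key"

try:
    EntityInstanceId.parse("not-an-entity-id")
except ValueError:
    pass  # previously this path silently produced None
```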
diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py
index 68009392..3e04206d 100644
--- a/durabletask/entities/entity_metadata.py
+++ b/durabletask/entities/entity_metadata.py
@@ -44,8 +44,9 @@ def __init__(self,
 
     @staticmethod
     def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
-        entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
-        if not entity_id:
+        try:
+            entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
+        except ValueError:
             raise ValueError("Invalid entity instance ID in entity response.")
         entity_state = None
         if includes_state:
diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py
index ccd8558b..0b1f655d 100644
--- a/durabletask/internal/helpers.py
+++ b/durabletask/internal/helpers.py
@@ -20,6 +20,11 @@ def new_orchestrator_started_event(timestamp: Optional[datetime] = None) -> pb.H
     return pb.HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent())
 
 
+def new_orchestrator_completed_event() -> pb.HistoryEvent:
+    return pb.HistoryEvent(eventId=-1, timestamp=timestamp_pb2.Timestamp(),
+                           orchestratorCompleted=pb.OrchestratorCompletedEvent())
+
+
 def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None,
                                 tags: Optional[dict[str, str]] = None) -> pb.HistoryEvent:
     return pb.HistoryEvent(
@@ -119,6 +124,18 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails:
     )
 
 
+def new_event_sent_event(event_id: int, instance_id: str, input: str):
+    return pb.HistoryEvent(
+        eventId=event_id,
+        timestamp=timestamp_pb2.Timestamp(),
+        eventSent=pb.EventSentEvent(
+            name="",
+            input=get_string_value(input),
+            instanceId=instance_id
+        )
+    )
+
+
 def new_event_raised_event(name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent:
     return pb.HistoryEvent(
         eventId=-1,
@@ -196,9 +213,14 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str],
     ))
 
 
-def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]):
+def new_call_entity_action(id: int,
+                           parent_instance_id: str,
+                           entity_id: EntityInstanceId,
+                           operation: str,
+                           encoded_input: Optional[str],
+                           request_id: str) -> pb.OrchestratorAction:
     return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent(
-        requestId=f"{parent_instance_id}:{id}",
+        requestId=request_id,
         operation=operation,
         scheduledTime=None,
         input=get_string_value(encoded_input),
@@ -208,9 +230,13 @@ def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityIn
     )))
 
 
-def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]):
+def new_signal_entity_action(id: int,
+                             entity_id: EntityInstanceId,
+                             operation: str,
+                             encoded_input: Optional[str],
+                             request_id: str) -> pb.OrchestratorAction:
     return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent(
-        requestId=f"{entity_id}:{id}",
+        requestId=request_id,
         operation=operation,
         scheduledTime=None,
         input=get_string_value(encoded_input),
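For context, the reworked helpers no longer derive a request ID from the sequence number (`f"{parent_instance_id}:{id}"`); the caller supplies a replay-safe `request_id` instead. A hedged sketch mirroring the worker's call site — the literal `request_id` below stands in for `ctx.new_uuid()`:

```python
# Sketch: driving new_call_entity_action with an explicit, replay-safe request_id.
import durabletask.internal.helpers as ph
from durabletask.entities import EntityInstanceId

entity_id = EntityInstanceId("Counter", "myCounter")
action = ph.new_call_entity_action(
    id=1,
    parent_instance_id="instance-123",
    entity_id=entity_id,
    operation="add",
    encoded_input="5",
    request_id="7f8d8df2-2f42-5a41-9e2a-1f1f3f6f0c11",  # ctx.new_uuid() in practice
)
assert action.sendEntityMessage.entityOperationCalled.requestId == \
    "7f8d8df2-2f42-5a41-9e2a-1f1f3f6f0c11"
```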
diff --git a/durabletask/internal/proto_task_hub_sidecar_service_stub.py b/durabletask/internal/proto_task_hub_sidecar_service_stub.py
new file mode 100644
index 00000000..8f51123b
--- /dev/null
+++ b/durabletask/internal/proto_task_hub_sidecar_service_stub.py
@@ -0,0 +1,33 @@
+from typing import Any, Callable, Protocol
+
+
+class ProtoTaskHubSidecarServiceStub(Protocol):
+    """A stub class matching the TaskHubSidecarServiceStub generated from the .proto file.
+    Allows the use of TaskHubGrpcWorker methods when a real sidecar stub is not available.
+    """
+    Hello: Callable[..., Any]
+    StartInstance: Callable[..., Any]
+    GetInstance: Callable[..., Any]
+    RewindInstance: Callable[..., Any]
+    WaitForInstanceStart: Callable[..., Any]
+    WaitForInstanceCompletion: Callable[..., Any]
+    RaiseEvent: Callable[..., Any]
+    TerminateInstance: Callable[..., Any]
+    SuspendInstance: Callable[..., Any]
+    ResumeInstance: Callable[..., Any]
+    QueryInstances: Callable[..., Any]
+    PurgeInstances: Callable[..., Any]
+    GetWorkItems: Callable[..., Any]
+    CompleteActivityTask: Callable[..., Any]
+    CompleteOrchestratorTask: Callable[..., Any]
+    CompleteEntityTask: Callable[..., Any]
+    StreamInstanceHistory: Callable[..., Any]
+    CreateTaskHub: Callable[..., Any]
+    DeleteTaskHub: Callable[..., Any]
+    SignalEntity: Callable[..., Any]
+    GetEntity: Callable[..., Any]
+    QueryEntities: Callable[..., Any]
+    CleanEntityStorage: Callable[..., Any]
+    AbandonTaskActivityWorkItem: Callable[..., Any]
+    AbandonTaskOrchestratorWorkItem: Callable[..., Any]
+    AbandonTaskEntityWorkItem: Callable[..., Any]
diff --git a/durabletask/py.typed b/durabletask/py.typed
new file mode 100644
index 00000000..e69de29b
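Because `ProtoTaskHubSidecarServiceStub` is a `typing.Protocol`, compatibility is structural: any object exposing the same callables can stand in for the generated gRPC stub. A sketch of what that buys in tests — `RecordingStub` is hypothetical, not part of the SDK:

```python
# Sketch: a hand-rolled object satisfying the protocol at runtime, no grpc needed.
from typing import Any

from durabletask.internal.proto_task_hub_sidecar_service_stub import (
    ProtoTaskHubSidecarServiceStub,
)


class RecordingStub:
    """Synthesizes every sidecar method; fill in only what a test exercises."""

    def __getattr__(self, name: str):
        def _call(*args: Any, **kwargs: Any) -> None:
            print(f"sidecar call: {name}")
        return _call


stub: ProtoTaskHubSidecarServiceStub = RecordingStub()
stub.Hello()  # prints "sidecar call: Hello"
```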
+ """ + pass + @abstractmethod def _exit_critical_section(self) -> None: pass diff --git a/durabletask/worker.py b/durabletask/worker.py index fae345c0..48c2e442 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -12,7 +12,8 @@ from threading import Event, Thread from types import GeneratorType from enum import Enum -from typing import Any, Generator, Optional, Sequence, TypeVar, Union +from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union +import uuid from packaging.version import InvalidVersion, parse import grpc @@ -23,6 +24,7 @@ from durabletask.internal.helpers import new_timestamp from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext +from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe import durabletask.internal.orchestrator_service_pb2 as pb @@ -33,6 +35,7 @@ TInput = TypeVar("TInput") TOutput = TypeVar("TOutput") +DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' class ConcurrencyOptions: @@ -629,7 +632,7 @@ def stop(self): def _execute_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): try: @@ -677,7 +680,7 @@ def _execute_orchestrator( def _cancel_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskOrchestratorWorkItem( @@ -690,7 +693,7 @@ def _cancel_orchestrator( def _execute_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): instance_id = req.orchestrationInstance.instanceId @@ -723,7 +726,7 @@ def _execute_activity( def _cancel_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskActivityWorkItem( @@ -736,7 +739,7 @@ def _cancel_activity( def _execute_entity_batch( self, req: Union[pb.EntityBatchRequest, pb.EntityRequest], - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): if isinstance(req, pb.EntityRequest): @@ -750,9 +753,10 @@ def _execute_entity_batch( for operation in req.operations: start_time = datetime.now(timezone.utc) executor = _EntityExecutor(self._registry, self._logger) - entity_instance_id = EntityInstanceId.parse(instance_id) - if not entity_instance_id: - raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.") + try: + entity_instance_id = EntityInstanceId.parse(instance_id) + except ValueError: + raise RuntimeError(f"Invalid entity instance ID '{instance_id}' in entity operation request.") operation_result = None @@ -804,7 +808,7 @@ def _execute_entity_batch( def _cancel_entity_batch( self, req: Union[pb.EntityBatchRequest, pb.EntityRequest], - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskEntityWorkItem( @@ -828,9 +832,11 @@ def __init__(self, instance_id: str, 
@@ -828,9 +832,11 @@ def __init__(self, instance_id: str,
                  registry: _Registry):
         self._pending_tasks: dict[int, task.CompletableTask] = {}
         # Maps entity ID to task ID
         self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
+        self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
         # Maps criticalSectionId to task ID
         self._entity_lock_id_map: dict[str, int] = {}
         self._sequence_number = 0
+        self._new_uuid_counter = 0
         self._current_utc_datetime = datetime(1000, 1, 1)
         self._instance_id = instance_id
         self._registry = registry
@@ -1038,14 +1044,14 @@ def call_activity(
 
     def call_entity(
         self,
-        entity_id: EntityInstanceId,
+        entity: EntityInstanceId,
         operation: str,
         input: Optional[TInput] = None,
     ) -> task.Task:
         id = self.next_sequence_number()
 
         self.call_entity_function_helper(
-            id, entity_id, operation, input=input
+            id, entity, operation, input=input
         )
         return self._pending_tasks.get(id, task.CompletableTask())
 
@@ -1053,13 +1059,13 @@ def call_entity(
     def signal_entity(
         self,
         entity_id: EntityInstanceId,
-        operation: str,
+        operation_name: str,
         input: Optional[TInput] = None
     ) -> None:
         id = self.next_sequence_number()
 
         self.signal_entity_function_helper(
-            id, entity_id, operation, input
+            id, entity_id, operation_name, input
         )
 
     def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]:
@@ -1165,7 +1171,7 @@ def call_entity_function_helper(
             raise RuntimeError(error_message)
 
         encoded_input = shared.to_json(input) if input is not None else None
-        action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input)
+        action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input, self.new_uuid())
         self._pending_actions[id] = action
 
         fn_task = task.CompletableTask()
@@ -1188,7 +1194,7 @@ def signal_entity_function_helper(
 
         encoded_input = shared.to_json(input) if input is not None else None
 
-        action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input)
+        action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid())
         self._pending_actions[id] = action
 
     def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
@@ -1199,7 +1205,7 @@ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId
         if not transition_valid:
             raise RuntimeError(error_message)
 
-        critical_section_id = f"{self.instance_id}:{id:04x}"
+        critical_section_id = self.new_uuid()
 
         request, target = self._entity_context.emit_acquire_message(critical_section_id, entities)
@@ -1251,6 +1257,17 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
 
         self.set_continued_as_new(new_input, save_events)
 
+    def new_uuid(self) -> str:
+        NAMESPACE_UUID: str = "9e952958-5e33-4daf-827f-2fa12937b875"
+
+        uuid_name_value = \
+            f"{self._instance_id}" \
+            f"_{self.current_utc_datetime.strftime(DATETIME_STRING_FORMAT)}" \
+            f"_{self._new_uuid_counter}"
+        self._new_uuid_counter += 1
+        namespace_uuid = uuid.uuid5(uuid.NAMESPACE_OID, NAMESPACE_UUID)
+        return str(uuid.uuid5(namespace_uuid, uuid_name_value))
+
 
 class ExecutionResults:
     actions: list[pb.OrchestratorAction]
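The `new_uuid` implementation above is worth unpacking: it is a pure function of the instance ID, the orchestration's current UTC datetime, and a per-execution counter, so a replayed execution regenerates the exact same sequence of IDs. A standalone sketch of the same derivation (constants copied from the diff):

```python
# Sketch: the uuid5 derivation behind OrchestrationContext.new_uuid, shown
# outside the context class to make the determinism property explicit.
import uuid
from datetime import datetime

DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
NAMESPACE_UUID = "9e952958-5e33-4daf-827f-2fa12937b875"


def replay_safe_uuid(instance_id: str, current_utc: datetime, counter: int) -> str:
    name = f"{instance_id}_{current_utc.strftime(DATETIME_STRING_FORMAT)}_{counter}"
    namespace = uuid.uuid5(uuid.NAMESPACE_OID, NAMESPACE_UUID)
    return str(uuid.uuid5(namespace, name))


ts = datetime(2024, 1, 1)
# Same (instance, time, counter) inputs -> same UUID, which is what replay needs:
assert replay_safe_uuid("abc", ts, 0) == replay_safe_uuid("abc", ts, 0)
# The internally managed counter is what makes consecutive calls distinct:
assert replay_safe_uuid("abc", ts, 0) != replay_safe_uuid("abc", ts, 1)
```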
@@ -1590,33 +1607,40 @@ def process_event(
             else:
                 raise TypeError("Unexpected sub-orchestration task type")
         elif event.HasField("eventRaised"):
-            # event names are case-insensitive
-            event_name = event.eventRaised.name.casefold()
-            if not ctx.is_replaying:
-                self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
-            task_list = ctx._pending_events.get(event_name, None)
-            decoded_result: Optional[Any] = None
-            if task_list:
-                event_task = task_list.pop(0)
-                if not ph.is_empty(event.eventRaised.input):
-                    decoded_result = shared.from_json(event.eventRaised.input.value)
-                event_task.complete(decoded_result)
-                if not task_list:
-                    del ctx._pending_events[event_name]
-                ctx.resume()
-            else:
-                # buffer the event
-                event_list = ctx._received_events.get(event_name, None)
-                if not event_list:
-                    event_list = []
-                    ctx._received_events[event_name] = event_list
-                if not ph.is_empty(event.eventRaised.input):
-                    decoded_result = shared.from_json(event.eventRaised.input.value)
-                event_list.append(decoded_result)
-                if not ctx.is_replaying:
-                    self._logger.info(
-                        f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
-                    )
+            if event.eventRaised.name in ctx._entity_task_id_map:
+                entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
+                self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
+            elif event.eventRaised.name in ctx._entity_lock_task_id_map:
+                entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
+                self._handle_entity_event_raised(ctx, event, entity_id, task_id, True)
+            else:
+                # event names are case-insensitive
+                event_name = event.eventRaised.name.casefold()
+                if not ctx.is_replaying:
+                    self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
+                task_list = ctx._pending_events.get(event_name, None)
+                decoded_result: Optional[Any] = None
+                if task_list:
+                    event_task = task_list.pop(0)
+                    if not ph.is_empty(event.eventRaised.input):
+                        decoded_result = shared.from_json(event.eventRaised.input.value)
+                    event_task.complete(decoded_result)
+                    if not task_list:
+                        del ctx._pending_events[event_name]
+                    ctx.resume()
+                else:
+                    # buffer the event
+                    event_list = ctx._received_events.get(event_name, None)
+                    if not event_list:
+                        event_list = []
+                        ctx._received_events[event_name] = event_list
+                    if not ph.is_empty(event.eventRaised.input):
+                        decoded_result = shared.from_json(event.eventRaised.input.value)
+                    event_list.append(decoded_result)
+                    if not ctx.is_replaying:
+                        self._logger.info(
+                            f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
+                        )
         elif event.HasField("executionSuspended"):
             if not self._is_suspended and not ctx.is_replaying:
                 self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -1656,8 +1680,9 @@ def process_event(
                 raise _get_wrong_action_type_error(
                     entity_call_id, expected_method_name, action
                 )
-            entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
-            if not entity_id:
+            try:
+                entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
+            except ValueError:
                 raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
             ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
         elif event.HasField("entityOperationSignaled"):
@@ -1743,6 +1768,21 @@ def process_event(
                 self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
                 self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
             pass
+        elif event.HasField("orchestratorCompleted"):
+            # Emitted only by the Durable Functions backend; it does not affect orchestrator flow
+            pass
+        elif event.HasField("eventSent"):
+            # Check if this eventSent corresponds to an entity operation call after being translated to the old
+            # entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
+            # entityOperationCalled and remove the pending action. Also store the entity id and event id for later
+            action = ctx._pending_actions.pop(event.eventId, None)
+            if action and action.HasField("sendEntityMessage"):
+                if action.sendEntityMessage.HasField("entityOperationCalled"):
+                    entity_id, event_id = self._parse_entity_event_sent_input(event)
+                    ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
+                elif action.sendEntityMessage.HasField("entityLockRequested"):
+                    entity_id, event_id = self._parse_entity_event_sent_input(event)
+                    ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
         else:
             eventType = event.WhichOneof("eventType")
             raise task.OrchestrationStateError(
@@ -1752,6 +1792,44 @@ def process_event(
             # The orchestrator generator function completed
             ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
 
+    def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]:
+        try:
+            entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
+        except ValueError:
+            raise RuntimeError(f"Could not parse entity ID from instanceId '{event.eventSent.instanceId}'")
+        try:
+            event_id = json.loads(event.eventSent.input.value)["id"]
+        except (json.JSONDecodeError, KeyError, TypeError) as ex:
+            raise RuntimeError(f"Could not parse event ID from eventSent input '{event.eventSent.input.value}'") from ex
+        return entity_id, event_id
+
+    def _handle_entity_event_raised(self,
+                                    ctx: _RuntimeOrchestrationContext,
+                                    event: pb.HistoryEvent,
+                                    entity_id: Optional[EntityInstanceId],
+                                    task_id: Optional[int],
+                                    is_lock_event: bool):
+        # This eventRaised represents the result of an entity operation after being translated to the old
+        # entity protocol by the Durable WebJobs extension
+        if entity_id is None:
+            raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
+        if task_id is None:
+            raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
+        entity_task = ctx._pending_tasks.pop(task_id, None)
+        if not entity_task:
+            raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
+        result = None
+        if not ph.is_empty(event.eventRaised.input):
+            # TODO: Investigate why the event result is wrapped in a dict with "result" key
+            result = shared.from_json(event.eventRaised.input.value)["result"]
+        if is_lock_event:
+            ctx._entity_context.complete_acquire(event.eventRaised.name)
+            entity_task.complete(EntityLock(ctx))
+        else:
+            ctx._entity_context.recover_lock_after_call(entity_id)
+            entity_task.complete(result)
+        ctx.resume()
+
     def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
         if versioning is None:
             return None
diff --git a/pyproject.toml b/pyproject.toml
index 111693c7..d9700894 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "durabletask"
-version = "1.1.0"
+version = "1.2.0"
 description = "A Durable Task Client SDK for Python"
 keywords = [
     "durable",
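The payload shapes assumed by `_parse_entity_event_sent_input` and `_handle_entity_event_raised` are worth making explicit. The sketch below is inferred from those two methods; the wire format itself is owned by the Durable WebJobs extension, not this SDK:

```python
# Hedged sketch of the old-protocol payloads the worker now recognizes.
import json

# eventSent carries the entity request ID under an "id" key...
event_sent_input = json.dumps({"id": "7f8d8df2-2f42-5a41-9e2a-1f1f3f6f0c11"})
assert json.loads(event_sent_input)["id"]

# ...and the matching eventRaised wraps the operation result in a "result" key
# (the TODO in _handle_entity_event_raised notes this wrapping is unexplained).
event_raised_input = json.dumps({"result": 42})
assert json.loads(event_raised_input)["result"] == 42
```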
diff --git a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
index 4a963fc7..7a7232e9 100644
--- a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
+++ b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
@@ -5,6 +5,7 @@
 import os
 import threading
 from datetime import timedelta
+import uuid
 
 import pytest
@@ -532,3 +533,39 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
     assert state.serialized_input is None
     assert state.serialized_output is None
     assert state.serialized_custom_status == "\"foobaz\""
+
+
+def test_new_uuid():
+    def noop(_: task.ActivityContext, _1):
+        pass
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        # Assert that two new_uuid calls return different values
+        results = [ctx.new_uuid(), ctx.new_uuid()]
+        yield ctx.call_activity("noop")
+        # Assert that new_uuid still returns a unique value after replay
+        results.append(ctx.new_uuid())
+        return results
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_activity(noop)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(empty_orchestrator)
+        assert state.instance_id == id
+        assert state.failure_details is None
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        results = json.loads(state.serialized_output or "\"\"")
+        assert isinstance(results, list) and len(results) == 3
+        assert uuid.UUID(results[0]) != uuid.UUID(results[1])
+        assert uuid.UUID(results[0]) != uuid.UUID(results[2])
+        assert uuid.UUID(results[1]) != uuid.UUID(results[2])
diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py
index 3db608dc..997bc504 100644
--- a/tests/durabletask/test_orchestration_e2e.py
+++ b/tests/durabletask/test_orchestration_e2e.py
@@ -5,6 +5,7 @@
 import threading
 import time
 from datetime import timedelta
+import uuid
 
 import pytest
@@ -499,3 +500,37 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
     assert state.serialized_input is None
     assert state.serialized_output is None
     assert state.serialized_custom_status == "\"foobaz\""
+
+
+def test_new_uuid():
+    def noop(_: task.ActivityContext, _1):
+        pass
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        # Assert that two new_uuid calls return different values
+        results = [ctx.new_uuid(), ctx.new_uuid()]
+        yield ctx.call_activity("noop")
+        # Assert that new_uuid still returns a unique value after replay
+        results.append(ctx.new_uuid())
+        return results
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_activity(noop)
+        w.start()
+
+        c = client.TaskHubGrpcClient()
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(empty_orchestrator)
+        assert state.instance_id == id
+        assert state.failure_details is None
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        results = json.loads(state.serialized_output or "\"\"")
+        assert isinstance(results, list) and len(results) == 3
+        assert uuid.UUID(results[0]) != uuid.UUID(results[1])
+        assert uuid.UUID(results[0]) != uuid.UUID(results[2])
+        assert uuid.UUID(results[1]) != uuid.UUID(results[2])
diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py
index 5646f07b..8c728124 100644
--- a/tests/durabletask/test_orchestration_executor.py
+++ b/tests/durabletask/test_orchestration_executor.py
@@ -9,7 +9,7 @@
 import durabletask.internal.helpers as helpers
 import durabletask.internal.orchestrator_service_pb2 as pb
-from durabletask import task, worker
+from durabletask import task, worker, entities
 
 logging.basicConfig(
     format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s',
@@ -1183,6 +1183,77 @@ def orchestrator(ctx: task.OrchestrationContext, _):
     assert str(ex) in complete_action.failureDetails.errorMessage
 
 
+def test_orchestrator_completed_no_effect():
+    def dummy_activity(ctx, _):
+        pass
+
+    def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
+        yield ctx.call_activity(dummy_activity, input=orchestrator_input)
+
+    registry = worker._Registry()
+    name = registry.add_orchestrator(orchestrator)
+
+    encoded_input = json.dumps(42)
+    new_events = [
+        helpers.new_orchestrator_started_event(),
+        helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input),
+        helpers.new_orchestrator_completed_event()]
+    executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
+    result = executor.execute(TEST_INSTANCE_ID, [], new_events)
+    actions = result.actions
+
+    assert len(actions) == 1
+    assert type(actions[0]) is pb.OrchestratorAction
+    assert actions[0].id == 1
+    assert actions[0].HasField("scheduleTask")
+    assert actions[0].scheduleTask.name == task.get_name(dummy_activity)
+    assert actions[0].scheduleTask.input.value == encoded_input
+
+
+def test_entity_lock_created_as_event():
+    test_entity_id = entities.EntityInstanceId("Counter", "myCounter")
+
+    def orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = test_entity_id
+        with (yield ctx.lock_entities([entity_id])):
+            return (yield ctx.call_entity(entity_id, "set", 1))
+
+    registry = worker._Registry()
+    name = registry.add_orchestrator(orchestrator)
+
+    new_events = [
+        helpers.new_orchestrator_started_event(),
+        helpers.new_execution_started_event(name, TEST_INSTANCE_ID, None),
+    ]
+
+    executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
+    result1 = executor.execute(TEST_INSTANCE_ID, [], new_events)
+    actions = result1.actions
+    assert len(actions) == 1
+    assert type(actions[0]) is pb.OrchestratorAction
+    assert actions[0].id == 1
+    assert actions[0].HasField("sendEntityMessage")
+    assert actions[0].sendEntityMessage.HasField("entityLockRequested")
+
+    old_events = new_events
+    event_sent_input = {
+        "id": actions[0].sendEntityMessage.entityLockRequested.criticalSectionId,
+    }
+    new_events = [
+        helpers.new_event_sent_event(1, str(test_entity_id), json.dumps(event_sent_input)),
+        helpers.new_event_raised_event(event_sent_input["id"], None),
+    ]
+    result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
+    actions = result.actions
+
+    assert len(actions) == 1
+    assert type(actions[0]) is pb.OrchestratorAction
+    assert actions[0].id == 2
+    assert actions[0].HasField("sendEntityMessage")
+    assert actions[0].sendEntityMessage.HasField("entityOperationCalled")
+    assert actions[0].sendEntityMessage.entityOperationCalled.targetInstanceId.value == str(test_entity_id)
+
+
 def get_and_validate_complete_orchestration_action_list(expected_action_count: int, actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction:
     assert len(actions) == expected_action_count
     assert type(actions[-1]) is pb.OrchestratorAction
diff --git a/tests/durabletask/test_proto_task_hub_shim.py b/tests/durabletask/test_proto_task_hub_shim.py
new file mode 100644
index 00000000..8bd3a659
--- /dev/null
+++ b/tests/durabletask/test_proto_task_hub_shim.py
@@ -0,0 +1,25 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from typing import get_type_hints
+
+from durabletask.internal.orchestrator_service_pb2_grpc import TaskHubSidecarServiceStub
+from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
+
+
+def test_proto_task_hub_shim_is_compatible():
+    """Test that ProtoTaskHubSidecarServiceStub is compatible with TaskHubSidecarServiceStub."""
+    protocol_attrs = set(get_type_hints(ProtoTaskHubSidecarServiceStub).keys())
+
+    # Instantiate TaskHubSidecarServiceStub with a dummy channel to get its attributes
+    class TestChannel():
+        def unary_unary(self, *args, **kwargs):
+            pass
+
+        def unary_stream(self, *args, **kwargs):
+            pass
+    impl_attrs = TaskHubSidecarServiceStub(TestChannel()).__dict__.keys()
+
+    # Check for missing or extra methods on either side
+    assert protocol_attrs == impl_attrs
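One possible follow-up, not part of this change: decorating the protocol with `@typing.runtime_checkable` would allow the same structural check via `isinstance` at runtime. A sketch under that assumption, with the protocol abbreviated to two members:

```python
# Sketch: a runtime-checkable variant of the sidecar protocol (hypothetical).
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class _RuntimeCheckableStub(Protocol):
    Hello: Callable[..., Any]
    GetWorkItems: Callable[..., Any]  # ...remaining sidecar methods elided


class FakeStub:
    def Hello(self, *args: Any, **kwargs: Any) -> None: ...
    def GetWorkItems(self, *args: Any, **kwargs: Any) -> None: ...


# isinstance only checks that the named attributes exist, not their signatures.
assert isinstance(FakeStub(), _RuntimeCheckableStub)
```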