diff --git a/docs/features.md b/docs/features.md index d3fcc56..cd28b2c 100644 --- a/docs/features.md +++ b/docs/features.md @@ -48,6 +48,88 @@ Orchestrations can schedule durable timers using the `create_timer` API. These t Orchestrations can start child orchestrations using the `call_sub_orchestrator` API. Child orchestrations are useful for encapsulating complex logic and for breaking up large orchestrations into smaller, more manageable pieces. Sub-orchestrations can also be versioned in a similar manner to their parent orchestrations; however, they do not inherit the parent orchestrator's version. Instead, they will use the default_version defined in the current worker's VersioningOptions unless otherwise specified during `call_sub_orchestrator`. +### Entities + +#### Concepts + +Durable Entities provide a way to model small, stateful objects within your orchestration workflows. Each entity has a unique identity and maintains its own state, which is persisted durably. You interact with an entity by sending it operations (messages) that mutate or query its state. These operations are processed sequentially, ensuring consistency. Examples of uses for durable entities include counters, accumulators, or any other scenario that requires state to persist across orchestrations. + +Entities can be invoked from durable clients directly, or from durable orchestrators. They support automatic state persistence and concurrency control, and they can be locked for exclusive access during critical operations. + +Entities are accessed by a unique ID, implemented here as EntityInstanceId. This ID is composed of two parts: an entity name, which refers to the function or class that defines the behavior of the entity, and a key, which is any string defined in your code. Each entity instance, represented by a distinct EntityInstanceId, has its own state. + +#### Syntax + +##### Defining entities + +Entities can be defined using either function-based or class-based syntax. + +```python +# Function-based entity +def counter(ctx: entities.EntityContext, input: int): + state = ctx.get_state(int, 0) + if ctx.operation == "add": + state += input + ctx.set_state(state) + elif ctx.operation == "get": + return state + +# Class-based entity +class Counter(entities.DurableEntity): + def add(self, amount: int): + self.set_state(self.get_state(int, 0) + amount) + + def get(self): + return self.get_state(int, 0) +``` + +> Note that the object properties of class-based entities may not be preserved across invocations. Use the inherited `get_state` and `set_state` methods to access the persisted entity state. + +##### Invoking entities + +Entities are invoked using the `signal_entity` or `call_entity` APIs: `signal_entity` is one-way (fire-and-forget), while `call_entity` returns a task that completes with the operation's result.
The Durable Client only allows `signal_entity`: + +```python +c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +c.signal_entity(entity_id, "do_nothing") +``` + +Orchestrators, in contrast, can use either `signal_entity` or `call_entity`: + +```python +# Signal an entity (fire-and-forget) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +ctx.signal_entity(entity_id, operation_name="add", input=5) + +# Call an entity (wait for result) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +result = yield ctx.call_entity(entity_id, "get") +``` + +##### Entity actions + +Entities can perform actions such as signaling other entities or starting new orchestrations: + +- `ctx.signal_entity(entity_id, operation, input)` +- `ctx.schedule_new_orchestration(orchestrator_name, input)` + +##### Locking and concurrency + +Because entities can be accessed from multiple running orchestrations at the same time, an orchestrator can also lock entities, ensuring exclusive access for the duration of the lock (also known as a critical section). Think semaphores: + +```python +with (yield ctx.lock_entities([entity_id_1, entity_id_2])): + # Perform entity call operations that require exclusive access + ... +``` + +Note that locked entities may not be signaled, and every call to a locked entity must return a result before another call to the same entity may be made from within the critical section. For more details and advanced usage, see the examples and API documentation. + ### External events Orchestrations can wait for external events using the `wait_for_external_event` API. External events are useful for implementing human interaction patterns, such as waiting for a user to approve an order before continuing. diff --git a/durabletask/client.py b/durabletask/client.py index bc3abed..c150822 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -4,13 +4,14 @@ import logging import uuid from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 +from durabletask.entities import EntityInstanceId import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs @@ -227,3 +228,16 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True): req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") self._stub.PurgeInstances(req) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None): + req = pb.SignalEntityRequest( + instanceId=str(entity_instance_id), + name=operation_name, + input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None, + requestId=str(uuid.uuid4()), + scheduledTime=None, + parentTraceContext=None, + requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) + ) + self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") + self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
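For reference, a minimal client-side usage sketch of the new `signal_entity` API (the host address and the entity/operation names below are illustrative, and `TaskHubGrpcClient` is assumed to be the plain gRPC client that `durabletask/client.py` defines):

```python
from durabletask.client import TaskHubGrpcClient
from durabletask.entities import EntityInstanceId

# Illustrative sidecar address; substitute your own endpoint.
c = TaskHubGrpcClient(host_address="localhost:4001")

# Fire-and-forget: the call returns as soon as the signal is enqueued;
# the "add" operation is delivered to the entity asynchronously.
c.signal_entity(EntityInstanceId("counter", "myCounter"), "add", input=5)
```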
diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py new file mode 100644 index 0000000..4ab03c0 --- /dev/null +++ b/durabletask/entities/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Durable Task SDK for Python entities component""" + +from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities.durable_entity import DurableEntity +from durabletask.entities.entity_lock import EntityLock +from durabletask.entities.entity_context import EntityContext + +__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext"] + +PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py new file mode 100644 index 0000000..31e3488 --- /dev/null +++ b/durabletask/entities/durable_entity.py @@ -0,0 +1,35 @@ +from typing import Any, Optional, Type, TypeVar, overload + +from durabletask.entities.entity_context import EntityContext +from durabletask.entities.entity_instance_id import EntityInstanceId + +TState = TypeVar("TState") + + +class DurableEntity: + def _initialize_entity_context(self, context: EntityContext): + self.entity_context = context + + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... + + @overload + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... + + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + return self.entity_context.get_state(intended_type, default) + + def set_state(self, state: Any): + self.entity_context.set_state(state) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + self.entity_context.signal_entity(entity_instance_id, operation, input) + + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + return self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py new file mode 100644 index 0000000..b861030 --- /dev/null +++ b/durabletask/entities/entity_context.py @@ -0,0 +1,106 @@ + +from typing import Any, Optional, Type, TypeVar, overload +import uuid +from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.internal import helpers, shared +from durabletask.internal.entity_state_shim import StateShim +import durabletask.internal.orchestrator_service_pb2 as pb + +TState = TypeVar("TState") + + +class EntityContext: + def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId): + self._orchestration_id = orchestration_id + self._operation = operation + self._state = state + self._entity_id = entity_id + + @property + def orchestration_id(self) -> str: + """Get the ID of the orchestration instance that scheduled this entity. + + Returns + ------- + str + The ID of the current orchestration instance. + """ + return self._orchestration_id + + @property + def operation(self) -> str: + """Get the operation associated with this entity invocation. + + The operation is a string that identifies the specific action being + performed on the entity. 
It can be used to distinguish between + the different operations that the entity supports. + + Returns + ------- + str + The operation associated with this entity invocation. + """ + return self._operation + + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... + + @overload + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... + + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + return self._state.get_state(intended_type, default) + + def set_state(self, new_state: Any): + self._state.set_state(new_state) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + encoded_input = shared.to_json(input) if input is not None else None + self._state.add_operation_action( + pb.OperationAction( + sendSignal=pb.SendSignalAction( + instanceId=str(entity_instance_id), + name=operation, + input=helpers.get_string_value(encoded_input), + scheduledTime=None, + requestTime=None, + parentTraceContext=None, + ) + ) + ) + + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + encoded_input = shared.to_json(input) if input is not None else None + if not instance_id: + instance_id = uuid.uuid4().hex + self._state.add_operation_action( + pb.OperationAction( + startNewOrchestration=pb.StartNewOrchestrationAction( + instanceId=instance_id, + name=orchestration_name, + input=helpers.get_string_value(encoded_input), + version=None, + scheduledTime=None, + requestTime=None, + parentTraceContext=None + ) + ) + ) + return instance_id + + @property + def entity_id(self) -> EntityInstanceId: + """Get the ID of the entity instance. + + Returns + ------- + EntityInstanceId + The ID of the current entity instance. + """ + return self._entity_id diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py new file mode 100644 index 0000000..1fee44f --- /dev/null +++ b/durabletask/entities/entity_instance_id.py @@ -0,0 +1,28 @@ +from typing import Optional + + +class EntityInstanceId: + def __init__(self, entity: str, key: str): + self.entity = entity + self.key = key + + def __str__(self) -> str: + return f"@{self.entity}@{self.key}" + + def __eq__(self, other): + if not isinstance(other, EntityInstanceId): + return False + return self.entity == other.entity and self.key == other.key + + def __lt__(self, other): + if not isinstance(other, EntityInstanceId): + return NotImplemented + return str(self) < str(other) + + @staticmethod + def parse(entity_id: str) -> Optional["EntityInstanceId"]: + try: + _, entity, key = entity_id.split("@", 2) + return EntityInstanceId(entity=entity, key=key) + except ValueError as ex: + raise ValueError(f"Invalid entity ID format: '{entity_id}'") from ex diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py new file mode 100644 index 0000000..5cbf7ea --- /dev/null +++ b/durabletask/entities/entity_lock.py @@ -0,0 +1,16 @@ +import durabletask.internal.orchestrator_service_pb2 as pb + + +class EntityLock: + def __init__(self, context): + self._context = context + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions?
+ print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}") + for entity_unlock_message in self._context._entity_context.emit_lock_release_messages(): + task_id = self._context.next_sequence_number() + action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) + self._context._pending_actions[task_id] = action diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py new file mode 100644 index 0000000..f27edc5 --- /dev/null +++ b/durabletask/internal/entity_state_shim.py @@ -0,0 +1,66 @@ +from typing import Any, TypeVar +from typing import Optional, Type, overload + +import durabletask.internal.orchestrator_service_pb2 as pb + +TState = TypeVar("TState") + + +class StateShim: + def __init__(self, start_state): + self._current_state: Any = start_state + self._checkpoint_state: Any = start_state + self._operation_actions: list[pb.OperationAction] = [] + self._actions_checkpoint_state: int = 0 + + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... + + @overload + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... + + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + if self._current_state is None and default is not None: + return default + + if intended_type is None: + return self._current_state + + if isinstance(self._current_state, intended_type): + return self._current_state + + try: + return intended_type(self._current_state) # type: ignore[call-arg] + except Exception as ex: + raise TypeError( + f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" + ) from ex + + def set_state(self, state): + self._current_state = state + + def add_operation_action(self, action: pb.OperationAction): + self._operation_actions.append(action) + + def get_operation_actions(self) -> list[pb.OperationAction]: + return self._operation_actions[:self._actions_checkpoint_state] + + def commit(self): + self._checkpoint_state = self._current_state + self._actions_checkpoint_state = len(self._operation_actions) + + def rollback(self): + self._current_state = self._checkpoint_state + self._operation_actions = self._operation_actions[:self._actions_checkpoint_state] + + def reset(self): + self._current_state = None + self._checkpoint_state = None + self._operation_actions = [] + self._actions_checkpoint_state = 0 diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 6140dec..ccd8558 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -7,6 +7,7 @@ from google.protobuf import timestamp_pb2, wrappers_pb2 +from durabletask.entities import EntityInstanceId import durabletask.internal.orchestrator_service_pb2 as pb # TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere @@ -159,6 +160,12 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]: return wrappers_pb2.StringValue(value=val) +def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue: + if val is None: + return wrappers_pb2.StringValue(value="") + return wrappers_pb2.StringValue(value=val) + + def new_complete_orchestration_action( id: int, status: pb.OrchestrationStatus, @@ -189,6 +196,57 @@ def new_schedule_task_action(id: int, name: 
str, encoded_input: Optional[str], )) +def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): + return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent( + requestId=f"{parent_instance_id}:{id}", + operation=operation, + scheduledTime=None, + input=get_string_value(encoded_input), + parentInstanceId=get_string_value(parent_instance_id), + parentExecutionId=None, + targetInstanceId=get_string_value(str(entity_id)), + ))) + + +def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): + return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( + requestId=f"{entity_id}:{id}", + operation=operation, + scheduledTime=None, + input=get_string_value(encoded_input), + targetInstanceId=get_string_value(str(entity_id)), + ))) + + +def new_lock_entities_action(id: int, entity_message: pb.SendEntityMessageAction): + return pb.OrchestratorAction(id=id, sendEntityMessage=entity_message) + + +def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBatchRequest, list[pb.OperationInfo]]: + batch_request = pb.EntityBatchRequest(entityState=req.entityState, instanceId=req.instanceId, operations=[]) + + operation_infos: list[pb.OperationInfo] = [] + + for op in req.operationRequests: + if op.HasField("entityOperationSignaled"): + batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationSignaled.requestId, + operation=op.entityOperationSignaled.operation, + input=op.entityOperationSignaled.input)) + operation_infos.append(pb.OperationInfo(requestId=op.entityOperationSignaled.requestId, + responseDestination=None)) + elif op.HasField("entityOperationCalled"): + batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationCalled.requestId, + operation=op.entityOperationCalled.operation, + input=op.entityOperationCalled.input)) + operation_infos.append(pb.OperationInfo(requestId=op.entityOperationCalled.requestId, + responseDestination=pb.OrchestrationInstance( + instanceId=op.entityOperationCalled.parentInstanceId.value, + executionId=op.entityOperationCalled.parentExecutionId + ))) + + return batch_request, operation_infos + + def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: ts = timestamp_pb2.Timestamp() ts.FromDatetime(dt) diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py new file mode 100644 index 0000000..1cb4619 --- /dev/null +++ b/durabletask/internal/orchestration_entity_context.py @@ -0,0 +1,117 @@ +from datetime import datetime +from typing import Generator, List, Optional, Tuple, Union + +from durabletask.internal.helpers import get_string_value +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.entities import EntityInstanceId + + +class OrchestrationEntityContext: + def __init__(self, instance_id: str): + self.instance_id = instance_id + + self.lock_acquisition_pending = False + + self.critical_section_id = None + self.critical_section_locks: list[EntityInstanceId] = [] + self.available_locks: list[EntityInstanceId] = [] + + @property + def is_inside_critical_section(self) -> bool: + return self.critical_section_id is not None + + def get_available_entities(self) -> Generator[EntityInstanceId, None, None]: + if self.is_inside_critical_section: 
+ for available_lock in self.available_locks: + yield available_lock + + def validate_suborchestration_transition(self) -> Tuple[bool, str]: + if self.is_inside_critical_section: + return False, "While holding locks, cannot call suborchestrators." + return True, "" + + def validate_operation_transition(self, target_instance_id: EntityInstanceId, one_way: bool) -> Tuple[bool, str]: + if self.is_inside_critical_section: + lock_to_use = target_instance_id + if one_way: + if target_instance_id in self.critical_section_locks: + return False, "Must not signal a locked entity from a critical section." + else: + try: + self.available_locks.remove(lock_to_use) + except ValueError: + if self.lock_acquisition_pending: + return False, "Must await the completion of the lock request prior to calling any entity." + if lock_to_use in self.critical_section_locks: + return False, "Must not call an entity from a critical section while a prior call to the same entity is still pending." + else: + return False, "Must not call an entity from a critical section if it is not one of the locked entities." + return True, "" + + def validate_acquire_transition(self) -> Tuple[bool, str]: + if self.is_inside_critical_section: + return False, "Must not enter another critical section from within a critical section." + return True, "" + + def recover_lock_after_call(self, target_instance_id: EntityInstanceId): + if self.is_inside_critical_section: + self.available_locks.append(target_instance_id) + + def emit_lock_release_messages(self): + if self.is_inside_critical_section: + for entity_id in self.critical_section_locks: + unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent( + criticalSectionId=self.critical_section_id, + targetInstanceId=get_string_value(str(entity_id)), + parentInstanceId=get_string_value(self.instance_id) + )) + yield unlock_event + + # TODO: Emit the actual release messages (?) 
+ self.critical_section_locks = [] + self.available_locks = [] + self.critical_section_id = None + + def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str, + scheduled_time_utc: datetime, input: Optional[str], + request_time: Optional[datetime] = None, create_trace: bool = False): + raise NotImplementedError() + + def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None], Tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]]: + if not entities: + return None, None + + # Acquire the locks in a globally fixed order to avoid deadlocks + # Also remove duplicates - this can be optimized for perf if necessary + entity_ids = sorted(entities) + entity_ids_dedup = [] + for i, entity_id in enumerate(entity_ids): + if i == 0 or entity_id != entity_ids[i - 1]: + entity_ids_dedup.append(entity_id) + + target = pb.OrchestrationInstance(instanceId=str(entity_ids_dedup[0])) + request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( + criticalSectionId=critical_section_id, + parentInstanceId=get_string_value(self.instance_id), + lockSet=[str(eid) for eid in entity_ids_dedup], + position=0, + )) + + self.critical_section_id = critical_section_id + self.critical_section_locks = entity_ids_dedup + self.lock_acquisition_pending = True + + return request, target + + def complete_acquire(self, critical_section_id): + # TODO: HashSet or equivalent + if self.critical_section_id != critical_section_id: + raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')") + self.available_locks = self.critical_section_locks + self.lock_acquisition_pending = False + + def adjust_outgoing_message(self, instance_id: str, request_message, capped_time: datetime) -> str: + raise NotImplementedError() + + def deserialize_entity_response_event(self, event_content: str): + raise NotImplementedError() diff --git a/durabletask/task.py b/durabletask/task.py index 14f5fac..645354e 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union +from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -137,6 +138,64 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, """ pass + @abstractmethod + def call_entity(self, entity_id: EntityInstanceId, + operation: str, + input: Optional[TInput] = None) -> Task: + """Schedule an entity operation for execution. + + Parameters + ---------- + entity_id: EntityInstanceId + The ID of the entity instance to call. + operation: str + The name of the operation to invoke on the entity. + input: Optional[TInput] + The optional JSON-serializable input to pass to the entity operation. + + Returns + ------- + Task + A Durable Task that completes when the entity operation completes or fails. + """ + pass + + @abstractmethod + def signal_entity( + self, + entity_id: EntityInstanceId, + operation_name: str, + input: Optional[TInput] = None + ) -> None: + """Signal an entity operation, without waiting for a result. + + Parameters + ---------- + entity_id: EntityInstanceId + The ID of the entity instance to signal. + operation_name: str + The name of the operation to invoke on the entity.
+ input: Optional[TInput] + The optional JSON-serializable input to pass to the entity operation. + """ + pass + + @abstractmethod + def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]: + """Lock the specified entity instances for the duration of a critical section. + + Parameters + ---------- + entities: list[EntityInstanceId] + The list of entity instance IDs to lock. + + Returns + ------- + Task[EntityLock] + A Durable Task that completes with an EntityLock once all requested locks are acquired; leaving the EntityLock's `with` block releases them. + """ + pass + @abstractmethod def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, input: Optional[TInput] = None, @@ -458,6 +517,8 @@ def task_id(self) -> int: # Activities are simple functions that can be scheduled by orchestrators Activity = Callable[[ActivityContext, TInput], TOutput] +Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]] + class RetryPolicy: """Represents the retry policy for an orchestration or activity function.""" diff --git a/durabletask/worker.py b/durabletask/worker.py index ba5f0ba..7d4c8d6 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -3,11 +3,12 @@ import asyncio import inspect +import json import logging import os import random from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from threading import Event, Thread from types import GeneratorType from enum import Enum @@ -17,6 +18,11 @@ import grpc from google.protobuf import empty_pb2 +from durabletask.internal import helpers +from durabletask.internal.entity_state_shim import StateShim +from durabletask.internal.helpers import new_timestamp +from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext +from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe import durabletask.internal.orchestrator_service_pb2 as pb @@ -40,6 +46,7 @@ def __init__( self, maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_concurrent_entity_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, ): """Initialize concurrency options.
@@ -68,6 +75,12 @@ def __init__( else default_concurrency ) + self.maximum_concurrent_entity_work_items = ( + maximum_concurrent_entity_work_items + if maximum_concurrent_entity_work_items is not None + else default_concurrency + ) + self.maximum_thread_pool_workers = ( maximum_thread_pool_workers if maximum_thread_pool_workers is not None @@ -124,11 +137,15 @@ def __init__(self, version: Optional[str] = None, class _Registry: orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] + entities: dict[str, task.Entity] + entity_instances: dict[str, DurableEntity] versioning: Optional[VersioningOptions] = None def __init__(self): self.orchestrators = {} self.activities = {} + self.entities = {} + self.entity_instances = {} def add_orchestrator(self, fn: task.Orchestrator) -> str: if fn is None: @@ -168,6 +185,29 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None: def get_activity(self, name: str) -> Optional[task.Activity]: return self.activities.get(name) + def add_entity(self, fn: task.Entity) -> str: + if fn is None: + raise ValueError("An entity function argument is required.") + + if isinstance(fn, type) and issubclass(fn, DurableEntity): + name = fn.__name__ + self.add_named_entity(name, fn) + else: + name = task.get_name(fn) + self.add_named_entity(name, fn) + return name + + def add_named_entity(self, name: str, fn: task.Entity) -> None: + if not name: + raise ValueError("A non-empty entity name is required.") + if name in self.entities: + raise ValueError(f"A '{name}' entity already exists.") + + self.entities[name] = fn + + def get_entity(self, name: str) -> Optional[task.Entity]: + return self.entities.get(name) + class OrchestratorNotRegisteredError(ValueError): """Raised when attempting to start an orchestration that is not registered""" @@ -181,6 +221,12 @@ class ActivityNotRegisteredError(ValueError): pass +class EntityNotRegisteredError(ValueError): + """Raised when attempting to call an entity that is not registered""" + + pass + + class TaskHubGrpcWorker: """A gRPC-based worker for processing durable task orchestrations and activities. @@ -329,6 +375,14 @@ def add_activity(self, fn: task.Activity) -> str: ) return self._registry.add_activity(fn) + def add_entity(self, fn: task.Entity) -> str: + """Registers an entity function with the worker.""" + if self._is_running: + raise RuntimeError( + "Entities cannot be added while the worker is running." 
+ ) + return self._registry.add_entity(fn) + def use_versioning(self, version: VersioningOptions) -> None: """Initializes versioning options for sub-orchestrators and activities.""" if self._is_running: @@ -490,6 +544,20 @@ def stream_reader(): stub, work_item.completionToken, ) + elif work_item.HasField("entityRequest"): + self._async_worker_manager.submit_entity_batch( + self._execute_entity_batch, + work_item.entityRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("entityRequestV2"): + self._async_worker_manager.submit_entity_batch( + self._execute_entity_batch, + work_item.entityRequestV2, + stub, + work_item.completionToken + ) elif work_item.HasField("healthPing"): pass else: @@ -635,22 +703,95 @@ def _execute_activity( f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" ) + def _execute_entity_batch( + self, + req: Union[pb.EntityBatchRequest, pb.EntityRequest], + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + operation_infos: list[pb.OperationInfo] = [] + if isinstance(req, pb.EntityRequest): + req, operation_infos = helpers.convert_to_entity_batch_request(req) + + entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None) + + instance_id = req.instanceId + + executor = _EntityExecutor(self._registry, self._logger) + entity_instance_id = EntityInstanceId.parse(instance_id) + if not entity_instance_id: + raise RuntimeError(f"Invalid entity instance ID '{instance_id}' in entity batch request.") + + results: list[pb.OperationResult] = [] + for operation in req.operations: + start_time = datetime.now(timezone.utc) + + operation_result = None + + try: + entity_result = executor.execute( + instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value + ) + + entity_result = ph.get_string_value_or_empty(entity_result) + operation_result = pb.OperationResult(success=pb.OperationResultSuccess( + result=entity_result, + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.commit() + except Exception as ex: + self._logger.exception(ex) + operation_result = pb.OperationResult(failure=pb.OperationResultFailure( + failureDetails=ph.new_failure_details(ex), + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.rollback() + + batch_result = pb.EntityBatchResult( + results=results, + actions=entity_state.get_operation_actions(), + entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state is not None else None, + failureDetails=None, + completionToken=completionToken, + operationInfos=operation_infos, + ) + + try: + stub.CompleteEntityTask(batch_result) + except Exception as ex: + self._logger.exception( + f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) + + # TODO: Reset context + + return batch_result + class _RuntimeOrchestrationContext(task.OrchestrationContext): _generator: Optional[Generator[task.Task, Any, Any]] _previous_task: Optional[task.Task] - def __init__(self, instance_id: str, registry: _Registry): + def __init__(self, instance_id: str, registry: _Registry, entity_context: OrchestrationEntityContext): self._generator = None self._is_replaying = True self._is_complete = False self._result = None self._pending_actions: dict[int,
pb.OrchestratorAction] = {} self._pending_tasks: dict[int, task.CompletableTask] = {} + # Maps entity request ID to (entity ID, task ID) + self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} + # Maps criticalSectionId to task ID + self._entity_lock_id_map: dict[str, int] = {} self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id self._registry = registry + self._entity_context = entity_context self._version: Optional[str] = None self._completion_status: Optional[pb.OrchestrationStatus] = None self._received_events: dict[str, list[Any]] = {} @@ -833,6 +974,40 @@ def call_activity( ) return self._pending_tasks.get(id, task.CompletableTask()) + def call_entity( + self, + entity_id: EntityInstanceId, + operation: str, + input: Optional[TInput] = None, + ) -> task.Task: + id = self.next_sequence_number() + + self.call_entity_function_helper( + id, entity_id, operation, input=input + ) + + return self._pending_tasks.get(id, task.CompletableTask()) + + def signal_entity( + self, + entity_id: EntityInstanceId, + operation_name: str, + input: Optional[TInput] = None + ) -> None: + id = self.next_sequence_number() + + self.signal_entity_function_helper( + id, entity_id, operation_name, input + ) + + def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]: + id = self.next_sequence_number() + + self.lock_entities_function_helper( + id, entities + ) + return self._pending_tasks.get(id, task.CompletableTask()) + def call_sub_orchestrator( self, orchestrator: task.Orchestrator[TInput, TOutput], @@ -909,6 +1084,69 @@ def call_activity_function_helper( ) self._pending_tasks[id] = fn_task + def call_entity_function_helper( + self, + id: Optional[int], + entity_id: EntityInstanceId, + operation: str, + *, + input: Optional[TInput] = None, + ): + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False) + if not transition_valid: + raise RuntimeError(error_message) + + encoded_input = shared.to_json(input) if input is not None else None + action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input) + self._pending_actions[id] = action + + fn_task = task.CompletableTask() + self._pending_tasks[id] = fn_task + + def signal_entity_function_helper( + self, + id: Optional[int], + entity_id: EntityInstanceId, + operation: str, + input: Optional[TInput] + ) -> None: + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True) + + if not transition_valid: + raise RuntimeError(error_message) + + encoded_input = shared.to_json(input) if input is not None else None + + action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input) + self._pending_actions[id] = action + + def lock_entities_function_helper(self, id: Optional[int], entities: list[EntityInstanceId]) -> None: + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_acquire_transition() + if not transition_valid: + raise RuntimeError(error_message) + + critical_section_id = f"{self.instance_id}:{id:04x}" + + request, target = self._entity_context.emit_acquire_message(critical_section_id, entities) + + if not request or not target: + raise RuntimeError("Failed to create entity lock request.") + + action = ph.new_lock_entities_action(id, request) + self._pending_actions[id] = action + + fn_task =
task.CompletableTask[EntityLock]() + self._pending_tasks[id] = fn_task + def wait_for_external_event(self, name: str) -> task.Task: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an @@ -957,6 +1195,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._logger = logger self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] + self._entity_state: Optional[OrchestrationEntityContext] = None def execute( self, @@ -964,12 +1203,14 @@ def execute( old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent], ) -> ExecutionResults: + self._entity_state = OrchestrationEntityContext(instance_id) + if not new_events: raise task.OrchestrationStateError( "The new history event list must have at least one event in it." ) - ctx = _RuntimeOrchestrationContext(instance_id, self._registry) + ctx = _RuntimeOrchestrationContext(instance_id, self._registry, self._entity_state) try: # Rebuild local state by replaying old history into the orchestrator function self._logger.debug( @@ -1316,6 +1557,108 @@ def process_event( pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True, ) + elif event.HasField("entityOperationCalled"): + # This history event confirms that the entity operation was successfully scheduled. + # Remove the entityOperationCalled event from the pending action list so we don't schedule it again + entity_call_id = event.eventId + action = ctx._pending_actions.pop(entity_call_id, None) + entity_task = ctx._pending_tasks.get(entity_call_id, None) + if not action: + raise _get_non_determinism_error( + entity_call_id, task.get_name(ctx.call_entity) + ) + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"): + expected_method_name = task.get_name(ctx.call_entity) + raise _get_wrong_action_type_error( + entity_call_id, expected_method_name, action + ) + entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value) + if not entity_id: + raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'") + ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id) + elif event.HasField("entityOperationSignaled"): + # This history event confirms that the entity signal was successfully scheduled. 
+ # Remove the entityOperationSignaled event from the pending action list so we don't schedule it + entity_signal_id = event.eventId + action = ctx._pending_actions.pop(entity_signal_id, None) + if not action: + raise _get_non_determinism_error( + entity_signal_id, task.get_name(ctx.signal_entity) + ) + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"): + expected_method_name = task.get_name(ctx.signal_entity) + raise _get_wrong_action_type_error( + entity_signal_id, expected_method_name, action + ) + elif event.HasField("entityLockRequested"): + section_id = event.entityLockRequested.criticalSectionId + task_id = event.eventId + action = ctx._pending_actions.pop(task_id, None) + entity_task = ctx._pending_tasks.get(task_id, None) + if not action: + raise _get_non_determinism_error( + task_id, task.get_name(ctx.lock_entities) + ) + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"): + expected_method_name = task.get_name(ctx.lock_entities) + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) + ctx._entity_lock_id_map[section_id] = task_id + elif event.HasField("entityUnlockSent"): + # Remove the unlock tasks as they have already been processed + tasks_to_remove = [] + for task_id, action in ctx._pending_actions.items(): + if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"): + if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId: + tasks_to_remove.append(task_id) + for task_to_remove in tasks_to_remove: + ctx._pending_actions.pop(task_to_remove, None) + elif event.HasField("entityLockGranted"): + section_id = event.entityLockGranted.criticalSectionId + task_id = ctx._entity_lock_id_map.pop(section_id, None) + if not task_id: + # TODO: Should this be an error? When would it ever happen? + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." + ) + return + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." + ) + return + ctx._entity_context.complete_acquire(section_id) + entity_task.complete(EntityLock(ctx)) + ctx.resume() + elif event.HasField("entityOperationCompleted"): + request_id = event.entityOperationCompleted.requestId + entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None)) + if not entity_id: + raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") + if not task_id: + raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}." 
+ ) + return + result = None + if not ph.is_empty(event.entityOperationCompleted.output): + result = shared.from_json(event.entityOperationCompleted.output.value) + ctx._entity_context.recover_lock_after_call(entity_id) + entity_task.complete(result) + ctx.resume() + elif event.HasField("entityOperationFailed"): + # NOTE: the failure is currently only logged; the corresponding pending entity task is not completed or failed. + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity operation failed.") + self._logger.info(f"Data: {event.entityOperationFailed}") else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError( @@ -1406,6 +1749,60 @@ def execute( return encoded_output +class _EntityExecutor: + def __init__(self, registry: _Registry, logger: logging.Logger): + self._registry = registry + self._logger = logger + + def execute( + self, + orchestration_id: str, + entity_id: EntityInstanceId, + operation: str, + state: StateShim, + encoded_input: Optional[str], + ) -> Optional[str]: + """Executes an entity function and returns the serialized result, if any.""" + self._logger.debug( + f"{orchestration_id}: Executing entity '{entity_id}'..." + ) + fn = self._registry.get_entity(entity_id.entity) + if not fn: + raise EntityNotRegisteredError( + f"Entity function named '{entity_id.entity}' was not registered!" + ) + + entity_input = shared.from_json(encoded_input) if encoded_input else None + ctx = EntityContext(orchestration_id, operation, state, entity_id) + + if isinstance(fn, type) and issubclass(fn, DurableEntity): + entity_instance = self._registry.entity_instances.get(str(entity_id)) + if entity_instance is None: + entity_instance = fn() + self._registry.entity_instances[str(entity_id)] = entity_instance + if not hasattr(entity_instance, operation): + raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'") + method = getattr(entity_instance, operation) + if not callable(method): + raise TypeError(f"Entity operation '{operation}' is not callable") + # Execute the entity method + entity_instance._initialize_entity_context(ctx) + entity_output = method(entity_input) + else: + # Execute the entity function + entity_output = fn(ctx, entity_input) + + encoded_output = ( + shared.to_json(entity_output) if entity_output is not None else None + ) + chars = len(encoded_output) if encoded_output else 0 + self._logger.debug( + f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output."
+ ) + return encoded_output + + def _get_non_determinism_error( task_id: int, action_name: str ) -> task.NonDeterminismError: @@ -1497,13 +1894,16 @@ def __init__(self, concurrency_options: ConcurrencyOptions): self.concurrency_options = concurrency_options self.activity_semaphore = None self.orchestration_semaphore = None + self.entity_semaphore = None # Don't create queues here - defer until we have an event loop self.activity_queue: Optional[asyncio.Queue] = None self.orchestration_queue: Optional[asyncio.Queue] = None + self.entity_batch_queue: Optional[asyncio.Queue] = None self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None # Store work items when no event loop is available self._pending_activity_work: list = [] self._pending_orchestration_work: list = [] + self._pending_entity_batch_work: list = [] self.thread_pool = ThreadPoolExecutor( max_workers=concurrency_options.maximum_thread_pool_workers, thread_name_prefix="DurableTask", @@ -1520,7 +1920,7 @@ def _ensure_queues_for_current_loop(self): # Check if queues are already properly set up for current loop if self._queue_event_loop is current_loop: - if self.activity_queue is not None and self.orchestration_queue is not None: + if self.activity_queue is not None and self.orchestration_queue is not None and self.entity_batch_queue is not None: # Queues are already bound to the current loop and exist return @@ -1528,6 +1928,7 @@ def _ensure_queues_for_current_loop(self): # First, preserve any existing work items existing_activity_items = [] existing_orchestration_items = [] + existing_entity_batch_items = [] if self.activity_queue is not None: try: @@ -1545,9 +1946,19 @@ def _ensure_queues_for_current_loop(self): except Exception: pass + if self.entity_batch_queue is not None: + try: + while not self.entity_batch_queue.empty(): + existing_entity_batch_items.append( + self.entity_batch_queue.get_nowait() + ) + except Exception: + pass + # Create fresh queues for the current event loop self.activity_queue = asyncio.Queue() self.orchestration_queue = asyncio.Queue() + self.entity_batch_queue = asyncio.Queue() self._queue_event_loop = current_loop # Restore the work items to the new queues @@ -1555,16 +1966,21 @@ def _ensure_queues_for_current_loop(self): self.activity_queue.put_nowait(item) for item in existing_orchestration_items: self.orchestration_queue.put_nowait(item) + for item in existing_entity_batch_items: + self.entity_batch_queue.put_nowait(item) # Move pending work items to the queues for item in self._pending_activity_work: self.activity_queue.put_nowait(item) for item in self._pending_orchestration_work: self.orchestration_queue.put_nowait(item) + for item in self._pending_entity_batch_work: + self.entity_batch_queue.put_nowait(item) # Clear the pending work lists self._pending_activity_work.clear() self._pending_orchestration_work.clear() + self._pending_entity_batch_work.clear() async def run(self): # Reset shutdown flag in case this manager is being reused @@ -1580,14 +1996,21 @@ async def run(self): self.orchestration_semaphore = asyncio.Semaphore( self.concurrency_options.maximum_concurrent_orchestration_work_items ) + self.entity_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_entity_work_items + ) # Start background consumers for each work type - if self.activity_queue is not None and self.orchestration_queue is not None: + if self.activity_queue is not None and self.orchestration_queue is not None \ + and self.entity_batch_queue is not None: await asyncio.gather( 
self._consume_queue(self.activity_queue, self.activity_semaphore), self._consume_queue( self.orchestration_queue, self.orchestration_semaphore ), + self._consume_queue( + self.entity_batch_queue, self.entity_semaphore + ) ) async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): @@ -1657,6 +2080,15 @@ def submit_orchestration(self, func, *args, **kwargs): # No event loop running, store in pending list self._pending_orchestration_work.append(work_item) + + def submit_entity_batch(self, func, *args, **kwargs): + work_item = (func, args, kwargs) + self._ensure_queues_for_current_loop() + if self.entity_batch_queue is not None: + self.entity_batch_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_entity_batch_work.append(work_item) + def shutdown(self): self._shutdown = True self.thread_pool.shutdown(wait=True) diff --git a/examples/entities/class_based_entity.py b/examples/entities/class_based_entity.py new file mode 100644 index 0000000..f211b65 --- /dev/null +++ b/examples/entities/class_based_entity.py @@ -0,0 +1,65 @@ +"""End-to-end sample that demonstrates how to configure a class-based +durable entity and an orchestrator that signals and calls it.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = task.EntityInstanceId("Counter", "myCounter") + + # Initialize the entity with state 0 + ctx.signal_entity(entity_id, "set", 0) + # Increment the counter by 1 + yield ctx.call_entity(entity_id, "add", 1) + # Return the entity's current value (should be 1) + return (yield ctx.call_entity(entity_id, "get")) + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_entity(Counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed!
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/class_based_entity_actions.py b/examples/entities/class_based_entity_actions.py new file mode 100644 index 0000000..8a38218 --- /dev/null +++ b/examples/entities/class_based_entity_actions.py @@ -0,0 +1,85 @@ +"""End-to-end sample that demonstrates entity-to-entity signaling and +starting a new orchestration from a class-based entity.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def update_parent(self): + parent_entity_id = entities.EntityInstanceId("Counter", "parentCounter") + if self.entity_context.entity_id == parent_entity_id: + return # Prevent self-update + self.signal_entity(parent_entity_id, "set", self.get_state(int, 0)) + + def start_hello(self): + self.schedule_new_orchestration("hello_orchestrator") + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = task.EntityInstanceId("Counter", "myCounter") + parent_entity_id = task.EntityInstanceId("Counter", "parentCounter") + + # Use Counter to demonstrate starting an orchestration from an entity + ctx.signal_entity(entity_id, "start_hello") + + # Use Counter to demonstrate signaling an entity from another entity + # Initialize myCounter with state 0, increment it by 1, and set the state of parentCounter using + # update_parent on myCounter. Retrieve and return the state of parentCounter (should be 1). + ctx.signal_entity(entity_id, "set", 0) + yield ctx.call_entity(entity_id, "add", 1) + yield ctx.call_entity(entity_id, "update_parent") + + return (yield ctx.call_entity(parent_entity_id, "get")) + + +def hello_orchestrator(ctx: task.OrchestrationContext, _): + return "Hello world!"
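+# NOTE: "start_hello" reaches the entity as a one-way signal, so counter_orchestrator +# does not wait for hello_orchestrator to be scheduled or to complete.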
+ + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_orchestrator(hello_orchestrator) + w.add_entity(Counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/entity_locking.py b/examples/entities/entity_locking.py new file mode 100644 index 0000000..cdc25ab --- /dev/null +++ b/examples/entities/entity_locking.py @@ -0,0 +1,67 @@ +"""End-to-end sample that demonstrates entity locking: an orchestrator +locks an entity to get exclusive access for a critical section.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = entities.EntityInstanceId("Counter", "myCounter") + + # Initialize the entity with state 0, increment the counter by 1, and get the entity state using + # entity locking to ensure no other orchestrator can modify the entity state between the calls to call_entity + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "set", 0) + yield ctx.call_entity(entity_id, "add", 1) + result = yield ctx.call_entity(entity_id, "get") + # Return the entity's current value (will be 1) + return result + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080"
+with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
+                                taskhub=taskhub_name, token_credential=credential) as w:
+    w.add_orchestrator(counter_orchestrator)
+    w.add_entity(Counter)
+    w.start()
+
+    # Construct the client and run the orchestrations
+    c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
+                                   taskhub=taskhub_name, token_credential=credential)
+    instance_id = c.schedule_new_orchestration(counter_orchestrator)
+    state = c.wait_for_orchestration_completion(instance_id, timeout=60)
+    if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
+        print(f'Orchestration completed! Result: {state.serialized_output}')
+    elif state:
+        print(f'Orchestration failed: {state.failure_details}')
diff --git a/examples/entities/function_based_entity.py b/examples/entities/function_based_entity.py
new file mode 100644
index 0000000..a43b86d
--- /dev/null
+++ b/examples/entities/function_based_entity.py
@@ -0,0 +1,66 @@
+"""End-to-end sample that demonstrates how to define a function-based entity
+and invoke it from an orchestrator."""
+import os
+from typing import Optional
+
+from azure.identity import DefaultAzureCredential
+
+from durabletask import client, entities, task
+from durabletask.azuremanaged.client import DurableTaskSchedulerClient
+from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
+
+
+def counter(ctx: entities.EntityContext, input: int) -> Optional[int]:
+    if ctx.operation == "set":
+        ctx.set_state(input)
+    elif ctx.operation == "add":
+        current_state = ctx.get_state(int, 0)
+        new_state = current_state + (input or 1)
+        ctx.set_state(new_state)
+        return new_state
+    elif ctx.operation == "get":
+        return ctx.get_state(int, 0)
+    else:
+        raise ValueError(f"Unknown operation '{ctx.operation}'")
+
+
+def counter_orchestrator(ctx: task.OrchestrationContext, _):
+    """Orchestrator function that demonstrates the behavior of the counter entity"""
+
+    entity_id = entities.EntityInstanceId("counter", "myCounter")
+
+    # Initialize the entity with state 0
+    ctx.signal_entity(entity_id, "set", 0)
+    # Increment the counter by 1
+    yield ctx.call_entity(entity_id, "add", 1)
+    # Return the entity's current value (should be 1)
+    return (yield ctx.call_entity(entity_id, "get"))
+
+
+# Use environment variables if provided, otherwise use default emulator values
+taskhub_name = os.getenv("TASKHUB", "default")
+endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
+
+print(f"Using taskhub: {taskhub_name}")
+print(f"Using endpoint: {endpoint}")
+
+# Set credential to None for emulator, or DefaultAzureCredential for Azure
+credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
+
+# configure and start the worker - use secure_channel=False for emulator
+secure_channel = endpoint != "http://localhost:8080"
+with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
+                                taskhub=taskhub_name, token_credential=credential) as w:
+    w.add_orchestrator(counter_orchestrator)
+    w.add_entity(counter)
+    w.start()
+
+    # Construct the client and run the orchestrations
+    c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
+                                   taskhub=taskhub_name, token_credential=credential)
+    instance_id = c.schedule_new_orchestration(counter_orchestrator)
+    state = c.wait_for_orchestration_completion(instance_id, timeout=60)
+    if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
+        print(f'Orchestration 
completed! Result: {state.serialized_output}')
+    elif state:
+        print(f'Orchestration failed: {state.failure_details}')
diff --git a/examples/entities/function_based_entity_actions.py b/examples/entities/function_based_entity_actions.py
new file mode 100644
index 0000000..129eb6c
--- /dev/null
+++ b/examples/entities/function_based_entity_actions.py
@@ -0,0 +1,87 @@
+"""End-to-end sample that demonstrates entity actions: an entity signaling
+another entity and an entity starting a new orchestration."""
+import os
+from typing import Optional
+
+from azure.identity import DefaultAzureCredential
+
+from durabletask import client, entities, task
+from durabletask.azuremanaged.client import DurableTaskSchedulerClient
+from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
+
+
+def counter(ctx: entities.EntityContext, input: int) -> Optional[int]:
+    if ctx.operation == "set":
+        ctx.set_state(input)
+    elif ctx.operation == "add":
+        current_state = ctx.get_state(int, 0)
+        new_state = current_state + (input or 1)
+        ctx.set_state(new_state)
+        return new_state
+    elif ctx.operation == "get":
+        return ctx.get_state(int, 0)
+    elif ctx.operation == "update_parent":
+        parent_entity_id = entities.EntityInstanceId("counter", "parentCounter")
+        if ctx.entity_id == parent_entity_id:
+            return # Prevent self-update
+        ctx.signal_entity(parent_entity_id, "set", ctx.get_state(int, 0))
+    elif ctx.operation == "start_hello":
+        ctx.schedule_new_orchestration("hello_orchestrator")
+    else:
+        raise ValueError(f"Unknown operation '{ctx.operation}'")
+
+
+def counter_orchestrator(ctx: task.OrchestrationContext, _):
+    """Orchestrator function that demonstrates the behavior of the counter entity"""
+
+    entity_id = entities.EntityInstanceId("counter", "myCounter")
+    parent_entity_id = entities.EntityInstanceId("counter", "parentCounter")
+
+    # Use counter to demonstrate starting an orchestration from an entity
+    ctx.signal_entity(entity_id, "start_hello")
+
+    # Use counter to demonstrate signaling an entity from another entity:
+    # initialize myCounter with state 0, increment it by 1, and set the state of parentCounter using
+    # update_parent on myCounter. Retrieve and return the state of parentCounter (should be 1).
+    ctx.signal_entity(entity_id, "set", 0)
+    yield ctx.call_entity(entity_id, "add", 1)
+    yield ctx.call_entity(entity_id, "update_parent")
+
+    return (yield ctx.call_entity(parent_entity_id, "get"))
+
+
+def hello_orchestrator(ctx: task.OrchestrationContext, _):
+    return "Hello world!"
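+
+# Running this sample should print "Orchestration completed! Result: 1": myCounter is set to 0,
+# incremented to 1, and its value is then copied into parentCounter via update_parent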
+ + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_orchestrator(hello_orchestrator) + w.add_entity(counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py new file mode 100644 index 0000000..19e8e5b --- /dev/null +++ b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py @@ -0,0 +1,110 @@ +import os +import time + +import pytest + +from durabletask import client, entities, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# NOTE: These tests assume a sidecar process is running. Example command: +# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +pytestmark = pytest.mark.dts + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def test_client_signal_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + nonlocal invoked # don't do this in a real app! + invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(EmptyEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + + assert invoked + + +def test_orchestration_signal_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + nonlocal invoked # don't do this in a real app! 
+            invoked = True
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
+        ctx.signal_entity(entity_id, "do_nothing")
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_entity(EmptyEntity)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+        time.sleep(2) # wait for the signal to be processed - signals cannot be awaited from inside the orchestrator
+
+    assert invoked
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+    assert state.serialized_input is None
+    assert state.serialized_output is None
+    assert state.serialized_custom_status is None
+
+
+def test_orchestration_call_class_entity():
+    invoked = False
+
+    class EmptyEntity(entities.DurableEntity):
+        def do_nothing(self, _):
+            nonlocal invoked # don't do this in a real app!
+            invoked = True
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
+        yield ctx.call_entity(entity_id, "do_nothing")
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_entity(EmptyEntity)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+
+    assert invoked
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+    assert state.serialized_input is None
+    assert state.serialized_output is None
+    assert state.serialized_custom_status is None
diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py
new file mode 100644
index 0000000..e5018a5
--- /dev/null
+++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py
@@ -0,0 +1,262 @@
+import os
+import time
+
+import pytest
+
+from durabletask import client, entities, task
+from durabletask.azuremanaged.client import DurableTaskSchedulerClient
+from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
+
+# NOTE: These tests assume a sidecar process is running. 
Example command: +# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +pytestmark = pytest.mark.dts + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def test_client_signal_entity(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + nonlocal invoked # don't do this in a real app! + if ctx.operation == "do_nothing": + invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + + assert invoked + + +def test_orchestration_signal_entity(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! + invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + ctx.signal_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the signal to be processed - signals cannot be awaited from inside the orchestrator + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_orchestration_call_entity(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_orchestration_call_entity_with_lock(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! + invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + # Call this a second time to ensure the entity is still responsive after being locked and unlocked + id_2 = c.schedule_new_orchestration(empty_orchestrator) + state_2 = c.wait_for_orchestration_completion(id_2, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + assert state_2 is not None + assert state_2.name == task.get_name(empty_orchestrator) + assert state_2.instance_id == id_2 + assert state_2.failure_details is None + assert state_2.runtime_status == client.OrchestrationStatus.COMPLETED + assert state_2.serialized_input is None + assert state_2.serialized_output is None + assert state_2.serialized_custom_status is None + + +def test_orchestration_entity_signals_entity(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+            invoked = True
+        elif ctx.operation == "signal_other":
+            entity_id = entities.EntityInstanceId("empty_entity", "otherEntity")
+            ctx.signal_entity(entity_id, "do_nothing")
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
+        yield ctx.call_entity(entity_id, "signal_other")
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_entity(empty_entity)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+
+    assert invoked
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+    assert state.serialized_input is None
+    assert state.serialized_output is None
+    assert state.serialized_custom_status is None
+
+
+def test_entity_starts_orchestration():
+    invoked = False
+
+    def empty_entity(ctx: entities.EntityContext, _):
+        if ctx.operation == "start_orchestration":
+            ctx.schedule_new_orchestration("empty_orchestrator")
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        nonlocal invoked # don't do this in a real app!
+        invoked = True
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_entity(empty_entity)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        c.signal_entity(entities.EntityInstanceId("empty_entity", "testEntity"), "start_orchestration")
+        time.sleep(2) # wait for the signal and orchestration to be processed
+
+    assert invoked
+
+
+def test_entity_locking_behavior():
+    def empty_entity(ctx: entities.EntityContext, _):
+        pass
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
+        with (yield ctx.lock_entities([entity_id])):
+            # Cannot signal entities that have been locked
+            with pytest.raises(Exception):
+                ctx.signal_entity(entity_id, "do_nothing")
+            # The first call to the locked entity is allowed; it is deliberately not yielded,
+            # so it has not yet returned a result
+            ctx.call_entity(entity_id, "do_nothing")
+            # Cannot call a locked entity again while a previous call has not yet returned a result
+            with pytest.raises(Exception):
+                ctx.call_entity(entity_id, "do_nothing")
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.add_entity(empty_entity)
+        w.start()
+
+        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                       taskhub=taskhub_name, token_credential=None)
+        id = c.schedule_new_orchestration(empty_orchestrator)
+        state = c.wait_for_orchestration_completion(id, timeout=30)
+
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert 
state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None