From 5555dfe7679df98245c93c9f76e85326d5df8ad5 Mon Sep 17 00:00:00 2001 From: Tanmay patil Date: Wed, 4 Jun 2025 15:11:14 +0530 Subject: [PATCH 1/3] first draft --- modelq/app/backends/rabbitmq/__init__.py | 0 modelq/app/backends/redis/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 modelq/app/backends/rabbitmq/__init__.py create mode 100644 modelq/app/backends/redis/__init__.py diff --git a/modelq/app/backends/rabbitmq/__init__.py b/modelq/app/backends/rabbitmq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modelq/app/backends/redis/__init__.py b/modelq/app/backends/redis/__init__.py new file mode 100644 index 0000000..e69de29 From a5cf631e9948d20cbcbe7ca9b89934a1bcdd2b71 Mon Sep 17 00:00:00 2001 From: Tanmay patil Date: Sat, 7 Jun 2025 01:32:22 +0530 Subject: [PATCH 2/3] some basic restructure --- modelq/app/backends/base.py | 119 ++++ modelq/app/backends/redis/__init__.py | 3 + modelq/app/backends/redis/backend.py | 138 ++++ modelq/app/base.py | 935 ++++++-------------------- modelq/app/tasks/base.py | 254 ++++--- test.py | 63 ++ 6 files changed, 645 insertions(+), 867 deletions(-) create mode 100644 modelq/app/backends/base.py create mode 100644 modelq/app/backends/redis/backend.py create mode 100644 test.py diff --git a/modelq/app/backends/base.py b/modelq/app/backends/base.py new file mode 100644 index 0000000..597651f --- /dev/null +++ b/modelq/app/backends/base.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any + + +class QueueBackend(ABC): + """Abstract interface for ModelQ queue backends.""" + + # ─── Task Enqueue/Dequeue ────────────────────────────────────────────── + + @abstractmethod + def enqueue_task(self, task_data: dict) -> None: + """Push a new task to the task queue.""" + pass + + @abstractmethod + def dequeue_task(self, timeout: Optional[int] = None) -> Optional[dict]: + """Pop the next task from the queue (blocking or timed).""" + pass + + @abstractmethod + def requeue_task(self, task_data: dict) -> None: + """Re-queue an existing task (e.g., after failure or rejection).""" + pass + + @abstractmethod + def enqueue_delayed_task(self, task_data: dict, delay_seconds: int) -> None: + """Push task to delayed queue (sorted by run timestamp).""" + pass + + @abstractmethod + def dequeue_ready_delayed_tasks(self) -> list: + """Get all delayed tasks ready to run now (score <= time.time()).""" + pass + + @abstractmethod + def flush_queue(self) -> None: + """Empty all tasks from the main task queue (for tests/dev reset).""" + pass + + # ─── Task Status Management ──────────────────────────────────────────── + + @abstractmethod + def save_task_state(self, task_id: str, task_data: dict, result: bool) -> None: + """Save or update the final state of a task (completed/failed/etc).""" + pass + + @abstractmethod + def load_task_state(self, task_id: str) -> Optional[dict]: + """Fetch a task's full state from storage.""" + pass + + @abstractmethod + def remove_task_from_queue(self, task_id: str) -> bool: + """Remove task from queue if still queued.""" + pass + + @abstractmethod + def mark_processing(self, task_id: str) -> bool: + """Add task to 'processing' set; return False if already processing.""" + pass + + @abstractmethod + def unmark_processing(self, task_id: str) -> None: + """Remove task from processing set.""" + pass + + @abstractmethod + def get_all_processing_tasks(self) -> list: + """Return list of currently 'processing' task IDs.""" + pass + + @abstractmethod + def 
get_all_queued_tasks(self) -> list: + """Return list of all tasks in the main queue.""" + pass + + # ─── Server Registry ─────────────────────────────────────────────────── + + @abstractmethod + def register_server(self, server_id: str, task_names: list) -> None: + """Register a worker with allowed task names and heartbeat.""" + pass + + @abstractmethod + def update_server_status(self, server_id: str, status: str) -> None: + """Update current server status and heartbeat time.""" + pass + + @abstractmethod + def get_all_server_ids(self) -> list: + """Return all currently registered server IDs.""" + pass + + @abstractmethod + def get_server_data(self, server_id: str) -> Optional[dict]: + """Get full data object for a server.""" + pass + + @abstractmethod + def prune_dead_servers(self, timeout: int) -> list: + """Remove any servers whose heartbeat is older than `timeout` seconds.""" + pass + + # ─── Metrics + Maintenance ───────────────────────────────────────────── + + @abstractmethod + def prune_old_results(self, older_than_seconds: int) -> int: + """Delete old task results beyond TTL.""" + pass + + @abstractmethod + def queue_length(self) -> int: + """Return the length of the main task queue.""" + pass + + @abstractmethod + def cleanup_dlq(self) -> None: + """Clear all items from dead letter queue.""" + pass diff --git a/modelq/app/backends/redis/__init__.py b/modelq/app/backends/redis/__init__.py index e69de29..50d7bb7 100644 --- a/modelq/app/backends/redis/__init__.py +++ b/modelq/app/backends/redis/__init__.py @@ -0,0 +1,3 @@ +from modelq.app.backends.redis.backend import RedisQueueBackend + +__all__ = ["RedisQueueBackend"] \ No newline at end of file diff --git a/modelq/app/backends/redis/backend.py b/modelq/app/backends/redis/backend.py new file mode 100644 index 0000000..c691620 --- /dev/null +++ b/modelq/app/backends/redis/backend.py @@ -0,0 +1,138 @@ +import time +import json +import redis +from typing import Optional +from modelq.app.backends.base import QueueBackend + + +class RedisQueueBackend(QueueBackend): + def __init__(self, redis_client: redis.Redis): + self.redis = redis_client + + # ─────────────────────── Task Queue ─────────────────────────────── + + def enqueue_task(self, task_data: dict) -> None: + task_data["status"] = "queued" + self.redis.rpush("ml_tasks", json.dumps(task_data)) + self.redis.zadd("queued_requests", {task_data["task_id"]: task_data["queued_at"]}) + + def dequeue_task(self, timeout: Optional[int] = None) -> Optional[dict]: + data = self.redis.blpop("ml_tasks", timeout=timeout or 5) + if data: + _, task_json = data + return json.loads(task_json) + return None + + def requeue_task(self, task_data: dict) -> None: + self.redis.rpush("ml_tasks", json.dumps(task_data)) + + def enqueue_delayed_task(self, task_data: dict, delay_seconds: int) -> None: + run_at = time.time() + delay_seconds + self.redis.zadd("delayed_tasks", {json.dumps(task_data): run_at}) + + def dequeue_ready_delayed_tasks(self) -> list: + now = time.time() + tasks = self.redis.zrangebyscore("delayed_tasks", 0, now) + for task_json in tasks: + self.redis.zrem("delayed_tasks", task_json) + self.redis.lpush("ml_tasks", task_json) + return [json.loads(t) for t in tasks] + + def flush_queue(self) -> None: + self.redis.ltrim("ml_tasks", 1, 0) + + # ─────────────────────── Task State ─────────────────────────────── + + def save_task_state(self, task_id: str, task_data: dict, result: bool) -> None: + task_data["finished_at"] = time.time() + self.redis.set(f"task_result:{task_id}", 
json.dumps(task_data), ex=3600) + self.redis.set(f"task:{task_id}", json.dumps(task_data), ex=86400) + + def load_task_state(self, task_id: str) -> Optional[dict]: + data = self.redis.get(f"task:{task_id}") + return json.loads(data) if data else None + + def remove_task_from_queue(self, task_id: str) -> bool: + tasks = self.redis.lrange("ml_tasks", 0, -1) + for task_json in tasks: + task_dict = json.loads(task_json) + if task_dict.get("task_id") == task_id: + self.redis.lrem("ml_tasks", 1, task_json) + self.redis.zrem("queued_requests", task_id) + return True + return False + + def mark_processing(self, task_id: str) -> bool: + return self.redis.sadd("processing_tasks", task_id) == 1 + + def unmark_processing(self, task_id: str) -> None: + self.redis.srem("processing_tasks", task_id) + + def get_all_processing_tasks(self) -> list: + return [pid.decode() for pid in self.redis.smembers("processing_tasks")] + + def get_all_queued_tasks(self) -> list: + raw = self.redis.lrange("ml_tasks", 0, -1) + return [json.loads(task) for task in raw if json.loads(task).get("status") == "queued"] + + # ─────────────────────── Server State ─────────────────────────────── + + def register_server(self, server_id: str, task_names: list) -> None: + self.redis.hset("servers", server_id, json.dumps({ + "allowed_tasks": task_names, + "status": "idle", + "last_heartbeat": time.time() + })) + + def update_server_status(self, server_id: str, status: str) -> None: + raw = self.redis.hget("servers", server_id) + if raw: + data = json.loads(raw) + data["status"] = status + data["last_heartbeat"] = time.time() + self.redis.hset("servers", server_id, json.dumps(data)) + + def get_all_server_ids(self) -> list: + return [k.decode("utf-8") for k in self.redis.hkeys("servers")] + + def get_server_data(self, server_id: str) -> Optional[dict]: + raw = self.redis.hget("servers", server_id) + return json.loads(raw) if raw else None + + def prune_dead_servers(self, timeout: int) -> list: + now = time.time() + pruned = [] + for sid, raw in self.redis.hgetall("servers").items(): + try: + sid_str = sid.decode() + data = json.loads(raw) + if now - data.get("last_heartbeat", 0) > timeout: + self.redis.hdel("servers", sid_str) + pruned.append(sid_str) + except: + continue + return pruned + + # ─────────────────────── Miscellaneous ───────────────────────────── + + def prune_old_results(self, older_than_seconds: int) -> int: + now = time.time() + deleted = 0 + for key in self.redis.scan_iter("task_result:*"): + raw = self.redis.get(key) + if not raw: + continue + data = json.loads(raw) + timestamp = data.get("finished_at") or data.get("started_at") + if timestamp and now - timestamp > older_than_seconds: + task_id = key.decode().split(":")[-1] + self.redis.delete(key) + self.redis.delete(f"task:{task_id}") + deleted += 1 + return deleted + + def queue_length(self) -> int: + return self.redis.llen("ml_tasks") + + def cleanup_dlq(self) -> None: + self.redis.delete("dlq") diff --git a/modelq/app/base.py b/modelq/app/base.py index 21c6e93..16b8f53 100644 --- a/modelq/app/base.py +++ b/modelq/app/base.py @@ -1,805 +1,286 @@ -import redis -import json -import functools import threading import time -import uuid -import logging -import traceback -from typing import Optional, Dict, Any +import json import socket +import logging +from typing import Optional, Type, Any, Generator -import requests # For sending error payloads to a webhook +import redis +from pydantic import BaseModel, ValidationError from modelq.app.tasks import Task -from 
modelq.exceptions import TaskProcessingError, TaskTimeoutError,RetryTaskException from modelq.app.middleware import Middleware +from modelq.exceptions import ( + TaskProcessingError, + TaskTimeoutError, + RetryTaskException, +) +from modelq.app.backends.base import QueueBackend +from modelq.app.backends.redis import RedisQueueBackend -from pydantic import BaseModel, ValidationError -from typing import Optional, Dict, Any, Type -import os - +# ─────────────────────── Logger Setup ──────────────────────────────── +logger = logging.getLogger("modelq") logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", ) -logger = logging.getLogger(__name__) class ModelQ: - # Constants for heartbeat and pruning intervals - HEARTBEAT_INTERVAL = 30 # seconds: how often this server updates its heartbeat - PRUNE_TIMEOUT = 300 # seconds: how long before a server is considered stale - PRUNE_CHECK_INTERVAL = 60 # seconds: how often to check for stale servers - TASK_RESULT_RETENTION = 86400 + """Core orchestrator class with pluggable queue backend (Redis by default).""" + + HEARTBEAT_INTERVAL = 30 + PRUNE_TIMEOUT = 300 + PRUNE_CHECK_INTERVAL = 60 + TASK_RESULT_RETENTION = 86_400 # 24 h + # ─────────────────────── Init / Backend ────────────────────────── def __init__( self, + backend: Optional[QueueBackend] = None, host: str = "localhost", - server_id: Optional[str] = None, - username: str = None, port: int = 6379, db: int = 0, - password: str = None, + password: str | None = None, + username: str | None = None, ssl: bool = False, ssl_cert_reqs: Any = None, - redis_client: Any = None, - max_connections: int = 50, # Limit max connections to avoid "too many clients" - webhook_url: Optional[str] = None, # Optional webhook for error logging - requeue_threshold : Optional[int] = None , + max_connections: int = 50, + server_id: str | None = None, + webhook_url: str | None = None, + requeue_threshold: int | None = None, delay_seconds: int = 30, - **kwargs, ): - if redis_client: - self.redis_client = redis_client - else: - self.redis_client = self._connect_to_redis( + if backend is None: + pool = redis.ConnectionPool( host=host, port=port, db=db, password=password, username=username, - ssl=ssl, - ssl_cert_reqs=ssl_cert_reqs, max_connections=max_connections, - **kwargs, ) + backend = RedisQueueBackend(redis.Redis(connection_pool=pool)) + logger.info("Using default Redis backend @ %s:%s", host, port) + else: + logger.info("Using custom backend: %s", backend.__class__.__name__) + self.backend = backend - self.worker_threads = [] - if server_id is None: - # Attempt to load the server_id from a local file: - server_id = self._get_or_create_server_id_file() - self.server_id = server_id - self.allowed_tasks = set() - self.middleware: Middleware = None + self.server_id: str = server_id or socket.gethostname() + self.allowed_tasks: set[str] = set() + self.middleware: Middleware | None = None self.webhook_url = webhook_url self.requeue_threshold = requeue_threshold self.delay_seconds = delay_seconds + self.worker_threads: list[threading.Thread] = [] - # Register this server in Redis (with an initial heartbeat) - self.register_server() - - def _connect_to_redis( - self, - host: str, - port: int, - db: int, - password: str, - ssl: bool, - ssl_cert_reqs: Any, - username: str, - max_connections: int = 50, - **kwargs, - ) -> redis.Redis: - pool = redis.ConnectionPool( - host=host, - port=port, - db=db, - password=password, 
- username=username, - # Enable TLS/SSL if needed: - # ssl=ssl, - # ssl_cert_reqs=ssl_cert_reqs, - max_connections=max_connections, - ) - return redis.Redis(connection_pool=pool) - - def _get_or_create_server_id_file(self) -> str: - return str(socket.gethostname()) - - def register_server(self): - """ - Registers this server in the 'servers' hash, including allowed tasks, - current status, and last_heartbeat timestamp. - """ - server_data = { - "allowed_tasks": list(self.allowed_tasks), - "status": "idle", - "last_heartbeat": time.time(), - } - self.redis_client.hset("servers", self.server_id, json.dumps(server_data)) - - def requeue_stuck_processing_tasks(self, threshold: float = 180.0): - """ - Re-queues any tasks that have been in 'processing' for more than 'threshold' seconds. - """ - - if self.requeue_threshold : - threshold = self.requeue_threshold - - processing_task_ids = self.redis_client.smembers("processing_tasks") - now = time.time() - - for pid in processing_task_ids: - task_id = pid.decode("utf-8") - task_data = self.redis_client.get(f"task:{task_id}") - if not task_data: - # If there's no data in Redis for that task, remove it from processing set. - self.redis_client.srem("processing_tasks", task_id) - logger.warning( - f"No record found for in-progress task {task_id}. Removing from 'processing_tasks'." - ) - continue - - task_dict = json.loads(task_data) - started_at = task_dict.get("started_at", 0) - if started_at: - if now - started_at > threshold: - logger.info( - f"Re-queuing stuck task {task_id} which has been 'processing' for {now - started_at:.2f} seconds." - ) - # Update status, queued_at, etc. - task_dict["status"] = "queued" - task_dict["queued_at"] = now - - # Store the updated dict back in Redis - self.redis_client.set(f"task:{task_id}", json.dumps(task_dict),ex=86400) - - # Push it back into ml_tasks - self.redis_client.rpush("ml_tasks", json.dumps(task_dict)) - - # Remove from processing set - self.redis_client.srem("processing_tasks", task_id) - - def prune_old_task_results(self, older_than_seconds: int = None): - """ - Deletes task result keys (stored with the prefix 'task_result:') whose - finished_at (or started_at if finished_at is not available) timestamp is older - than `older_than_seconds`. In addition, it also removes the corresponding - task key (stored with the prefix 'task:'). - """ - if older_than_seconds is None: - older_than_seconds = self.TASK_RESULT_RETENTION - - now = time.time() - keys_deleted = 0 - - # Use scan_iter to avoid blocking Redis - for key in self.redis_client.scan_iter("task_result:*"): - try: - task_json = self.redis_client.get(key) - if not task_json: - continue - task_data = json.loads(task_json) - # Use finished_at if available; otherwise fallback to started_at - timestamp = task_data.get("finished_at") or task_data.get("started_at") - if timestamp and (now - timestamp > older_than_seconds): - # Delete the task_result key - self.redis_client.delete(key) - # Extract the task id from the key and delete the corresponding task key. 
- key_str = key.decode("utf-8") if isinstance(key, bytes) else key - task_id = key_str.split("task_result:")[-1] - task_key = f"task:{task_id}" - self.redis_client.delete(task_key) - keys_deleted += 1 - logger.info(f"Deleted old keys: {key_str} and {task_key}") - except Exception as e: - key_str = key.decode("utf-8") if isinstance(key, bytes) else key - logger.error(f"Error processing key {key_str}: {e}") - - if keys_deleted: - logger.info(f"Pruned {keys_deleted} task(s) older than {older_than_seconds} seconds.") - - def update_server_status(self, status: str): - """ - Updates the server's status in Redis. - """ - raw_data = self.redis_client.hget("servers", self.server_id) - if not raw_data: - self.register_server() - return - server_data = json.loads(raw_data) - server_data["status"] = status - server_data["last_heartbeat"] = time.time() - self.redis_client.hset("servers", self.server_id, json.dumps(server_data)) - - def get_registered_server_ids(self) -> list: - """ - Returns a list of server_ids that are currently registered in Redis under the 'servers' hash. - """ - keys = self.redis_client.hkeys("servers") # returns raw bytes for each key - return [k.decode("utf-8") for k in keys] - - def heartbeat(self): - """ - Periodically update this server's 'last_heartbeat' in Redis. - """ - raw_data = self.redis_client.hget("servers", self.server_id) - if not raw_data: - self.register_server() - return - - data = json.loads(raw_data) - data["last_heartbeat"] = time.time() - self.redis_client.hset("servers", self.server_id, json.dumps(data)) - - def prune_inactive_servers(self, timeout_seconds: int = None): - """ - Removes servers from the 'servers' hash if they haven't sent - a heartbeat within 'timeout_seconds' seconds. - """ - if timeout_seconds is None: - timeout_seconds = self.PRUNE_TIMEOUT - - all_servers = self.redis_client.hgetall("servers") - now = time.time() - removed_count = 0 - - for server_id_bytes, data_bytes in all_servers.items(): - server_id_str = server_id_bytes.decode("utf-8") - try: - data = json.loads(data_bytes.decode("utf-8")) - last_heartbeat = data.get("last_heartbeat", 0) - if (now - last_heartbeat) > timeout_seconds: - self.redis_client.hdel("servers", server_id_str) - removed_count += 1 - logger.info(f"[Prune] Removed stale server: {server_id_str}") - except Exception as e: - logger.warning(f"[Prune] Could not parse server data for {server_id_str}: {e}") - - if removed_count > 0: - logger.info(f"[Prune] Total {removed_count} inactive servers pruned.") - - def enqueue_task(self, task_data: dict, payload: dict): - """ - Pushes a task into the 'ml_tasks' list with status=queued. - We assume 'task_data' may already have 'created_at' set. - Here, we optionally set 'queued_at' if not present. - """ - # Ensure status is 'queued' - task_data["status"] = "queued" - self.check_middleware("before_enqueue") - # If the decorator didn’t set queued_at, set it now - if "queued_at" not in task_data: - task_data["queued_at"] = time.time() - - self.redis_client.rpush("ml_tasks", json.dumps(task_data)) - self.redis_client.zadd("queued_requests", {task_data["task_id"]: task_data["queued_at"]}) - self.check_middleware("after_enqueue") - - def delete_queue(self): - self.redis_client.ltrim("ml_tasks", 1, 0) - - def enqueue_delayed_task(self, task_dict: dict, delay_seconds: int): - """ - Enqueues a task into a Redis sorted set ('delayed_tasks') to be processed later. 
- """ - run_at = time.time() + delay_seconds - task_json = json.dumps(task_dict) - self.redis_client.zadd("delayed_tasks", {task_json: run_at}) - logger.info(f"Delayed task {task_dict.get('task_id')} by {delay_seconds} seconds.") - - def requeue_delayed_tasks(self): - """ - Thread that periodically checks 'delayed_tasks' for tasks whose run_at time has passed, - then moves them into 'ml_tasks' for immediate processing. - """ - while True: - now = time.time() - ready_tasks = self.redis_client.zrangebyscore("delayed_tasks", 0, now) - for task_json in ready_tasks: - self.redis_client.zrem("delayed_tasks", task_json) - self.redis_client.lpush("ml_tasks", task_json) - time.sleep(1) - - def requeue_inprogress_tasks(self): - """ - On server startup, re-queue tasks that were marked 'processing' but never finished. - """ - logger.info("Checking for in-progress tasks to re-queue on startup...") - processing_task_ids = self.redis_client.smembers("processing_tasks") - for pid in processing_task_ids: - task_id = pid.decode("utf-8") - task_data = self.redis_client.get(f"task:{task_id}") - if not task_data: - self.redis_client.srem("processing_tasks", task_id) - logger.warning(f"No record found for in-progress task {task_id}. Removing it.") - continue - - task_dict = json.loads(task_data) - if task_dict.get("status") == "processing": - logger.info(f"Re-queuing task {task_id} which was in progress.") - task_dict["payload"] = task_dict.original_payload - self.redis_client.rpush("ml_tasks", json.dumps(task_dict)) - self.redis_client.srem("processing_tasks", task_id) - - def get_all_queued_tasks(self) -> list: - """ - Returns a list of tasks currently in the 'ml_tasks' list with a status of 'queued'. - """ - queued_tasks = [] - tasks_in_list = self.redis_client.lrange("ml_tasks", 0, -1) - - for t_json in tasks_in_list: - try: - t_dict = json.loads(t_json) - if t_dict.get("status") == "queued": - queued_tasks.append(t_dict) - except Exception as e: - logger.error(f"Error deserializing task from ml_tasks: {e}") - - return queued_tasks + self.backend.register_server(self.server_id, list(self.allowed_tasks)) + logger.info("Server registered with id=%s", self.server_id) + # ─────────────────────── Decorator API ─────────────────────────── def task( self, - task_class=Task, - timeout: Optional[int] = None, + task_class: Type[Task] = Task, + timeout: int | None = None, stream: bool = False, retries: int = 0, - schema: Optional[Type[BaseModel]] = None, # ▶ pydantic - returns: Optional[Type[BaseModel]] = None, # ▶ pydantic + schema: Type[BaseModel] | None = None, + returns: Type[BaseModel] | None = None, ): def decorator(func): - # make the schema classes discoverable at run time - func._mq_schema = schema # ▶ pydantic - func._mq_returns = returns # ▶ pydantic + func._mq_schema = schema + func._mq_returns = returns - @functools.wraps(func) def wrapper(*args, **kwargs): - # --------------------------- PRODUCER-SIDE VALIDATION - if schema is not None: # ▶ pydantic + # ── Input validation ── + if schema is not None: try: - # allow either a ready-made model instance - # or raw kwargs/args that build one - if len(args) == 1 and isinstance(args[0], schema): - validated = args[0] - else: - validated = schema(*args, **kwargs) - except ValidationError as ve: - raise TaskProcessingError( - func.__name__, f"Input validation failed – {ve}" + instance = ( + args[0] if len(args) == 1 and isinstance(args[0], schema) else schema(*args, **kwargs) ) - payload_data = validated.model_dump(mode="json") # zero-copy - args, kwargs = (), {} 
# we’ll carry payload in kwargs only + payload_data = instance.model_dump(mode="json") + args, kwargs = (), {} + except ValidationError as ve: + logger.error("Validation failed for task %s: %s", func.__name__, ve) + raise TaskProcessingError(func.__name__, f"Input validation failed – {ve}") else: payload_data = {"args": args, "kwargs": kwargs} payload = { - "data": payload_data, # ▶ pydantic – typed or raw + "data": payload_data, "timeout": timeout, "stream": stream, "retries": retries, } - task = task_class(task_name=func.__name__, payload=payload) - if stream: - task.stream = True - - task_dict = task.to_dict() - now_ts = time.time() - task_dict["created_at"] = now_ts - task_dict["queued_at"] = now_ts - - self.enqueue_task(task_dict, payload=payload) - self.redis_client.set(f"task:{task.task_id}", - json.dumps(task_dict), - ex=86400) + now = time.time() + task_dict = task.to_dict() | {"created_at": now, "queued_at": now, "status": "queued"} + + self.backend.enqueue_task(task_dict) + logger.info("Enqueued task %s (id=%s)", task.task_name, task.task_id) + + # attach helpers for convenient retrieval + task.backend = self.backend + task._modelq_ref = self + task.result_blocking = lambda timeout=None, returns=None: task.get_result( + self.backend, + timeout=timeout, + returns=returns, + modelq_ref=self, + ) return task + setattr(self, func.__name__, func) self.allowed_tasks.add(func.__name__) - self.register_server() + self.backend.register_server(self.server_id, list(self.allowed_tasks)) + logger.debug("Task registered: %s", func.__name__) return wrapper + return decorator + # ─────────────────────── Worker Management ─────────────────────── def start_workers(self, no_of_workers: int = 1): - """ - Starts worker threads to pop tasks from 'ml_tasks' and process them. - Also starts: - - a thread for re-queuing delayed tasks - - a heartbeat thread - - a pruning thread - """ - # Avoid restarting if workers are already running - if any(thread.is_alive() for thread in self.worker_threads): + if any(t.is_alive() for t in self.worker_threads): + logger.warning("Workers already running — skipping start") return - self.check_middleware("before_worker_boot") - - # 1) Delayed re-queue thread - requeue_thread = threading.Thread(target=self.requeue_delayed_tasks, daemon=True) - requeue_thread.start() - self.worker_threads.append(requeue_thread) - - # 2) Heartbeat thread - heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True) - heartbeat_thread.start() - self.worker_threads.append(heartbeat_thread) - - # 3) Pruning thread - pruning_thread = threading.Thread(target=self._pruning_loop, daemon=True) - pruning_thread.start() - self.worker_threads.append(pruning_thread) - - # 4) Worker threads - def worker_loop(worker_id): - self.check_middleware("after_worker_boot") - while True: - try: - self.update_server_status(f"worker_{worker_id}: idle") - task_data = self.redis_client.blpop("ml_tasks") # blocks until a task is available - if not task_data: - continue - - self.update_server_status(f"worker_{worker_id}: busy") - _, task_json = task_data - task_dict = json.loads(task_json) - task = Task.from_dict(task_dict) - - # Mark task as 'processing' - added = self.redis_client.sadd("processing_tasks", task.task_id) - if added == 0: - logger.warning( - f"Task {task.task_id} is already being processed. Skipping duplicate." 
- ) - continue - task.status = "processing" - - # Set started_at - task_dict["started_at"] = time.time() - - # Update in Redis - self.redis_client.set(f"task:{task.task_id}", json.dumps(task_dict),ex=86400) - - if task.task_name in self.allowed_tasks: - try: - logger.info(f"Worker {worker_id} started processing: {task.task_name}") - start_time = time.time() - self.process_task(task) - end_time = time.time() - logger.info( - f"Worker {worker_id} finished {task.task_name} " - f"in {end_time - start_time:.2f} seconds" - ) - - except TaskProcessingError as e: - logger.error( - f"Worker {worker_id} encountered a TaskProcessingError: {e}" - ) - if task.payload.get("retries", 0) > 0: - new_task_dict = task.to_dict() - new_task_dict["payload"] = task.original_payload - new_task_dict["payload"]["retries"] -= 1 - self.enqueue_delayed_task(new_task_dict, delay_seconds=self.delay_seconds) - - except Exception as e: - logger.error( - f"Worker {worker_id} encountered an unexpected error: {e}" - ) - if task.payload.get("retries", 0) > 0: - new_task_dict = task.to_dict() - new_task_dict["payload"] = task.original_payload - new_task_dict["payload"]["retries"] -= 1 - self.enqueue_delayed_task(new_task_dict, delay_seconds=self.delay_seconds) - else: - # If task is not allowed on this server, re-queue it - logger.warning( - f"Worker {worker_id} cannot process task {task.task_name}, re-queueing..." - ) - self.redis_client.rpush("ml_tasks", task_json) - - except Exception as e: - logger.error( - f"Worker {worker_id} crashed with error: {e}. Restarting worker..." - ) - finally: - self.check_middleware("before_worker_shutdown") - self.check_middleware("after_worker_shutdown") - - for i in range(no_of_workers): - worker_thread = threading.Thread(target=worker_loop, args=(i,), daemon=True) - worker_thread.start() - self.worker_threads.append(worker_thread) - - task_names = ", ".join(self.allowed_tasks) if self.allowed_tasks else "No tasks registered" - logger.info( - f"ModelQ workers started with {no_of_workers} worker(s). " - f"Connected to Redis at {self.redis_client.connection_pool.connection_kwargs['host']}:" - f"{self.redis_client.connection_pool.connection_kwargs['port']}. 
" - f"Registered tasks: {task_names}" - ) + logger.info("Booting %d worker(s)…", no_of_workers) + + # supportive threads + t_delay = threading.Thread(target=self._delay_loop, daemon=True, name="mq-delay") + t_delay.start(); self.worker_threads.append(t_delay) + t_hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="mq-heartbeat") + t_hb.start(); self.worker_threads.append(t_hb) + t_prune = threading.Thread(target=self._prune_loop, daemon=True, name="mq-prune") + t_prune.start(); self.worker_threads.append(t_prune) + + # worker threads + for wid in range(no_of_workers): + th = threading.Thread(target=self._worker_loop, args=(wid,), daemon=True, name=f"mq-worker-{wid}") + th.start(); self.worker_threads.append(th) + + logger.info("ModelQ online — workers=%d server_id=%s", no_of_workers, self.server_id) + + # ─────────────────────── Internal Loops ────────────────────────── + def _worker_loop(self, wid: int): + logger.debug("Worker %d spawned", wid) + while True: + task_dict = self.backend.dequeue_task() + if not task_dict: + continue + task = Task.from_dict(task_dict,backend=self.backend) + if not self.backend.mark_processing(task.task_id): + continue # duplicate + logger.info("[W%d] Started %s (id=%s)", wid, task.task_name, task.task_id) + try: + self._execute_task(task) + logger.info("[W%d] Completed %s (id=%s)", wid, task.task_name, task.task_id) + except Exception as exc: + logger.error("[W%d] Failed %s (id=%s): %s", wid, task.task_name, task.task_id, exc) + finally: + self.backend.unmark_processing(task.task_id) + + def _delay_loop(self): + while True: + released = self.backend.dequeue_ready_delayed_tasks() + if released: + logger.debug("Moved %d delayed task(s) to queue", len(released)) + time.sleep(1) def _heartbeat_loop(self): - """ - Continuously updates the heartbeat for this server. - """ while True: - self.heartbeat() + self.backend.update_server_status(self.server_id, "alive") + logger.debug("Heartbeat sent for %s", self.server_id) time.sleep(self.HEARTBEAT_INTERVAL) - def _pruning_loop(self): - """ - Continuously prunes servers that have not updated their heartbeat in a while. - """ + def _prune_loop(self): while True: - self.prune_inactive_servers(timeout_seconds=self.PRUNE_TIMEOUT) - self.requeue_stuck_processing_tasks(threshold=180) - self.prune_old_task_results(older_than_seconds=self.TASK_RESULT_RETENTION) + pruned = self.backend.prune_dead_servers(self.PRUNE_TIMEOUT) + if pruned: + logger.warning("Pruned stale servers: %s", pruned) + removed = self.backend.prune_old_results(self.TASK_RESULT_RETENTION) + if removed: + logger.info("Pruned %d old task result(s)", removed) time.sleep(self.PRUNE_CHECK_INTERVAL) - def check_middleware(self, middleware_event: str,task: Optional[Task] = None, error: Optional[Exception] = None): - """ - Hooks into the Middleware lifecycle if a Middleware instance is attached. - """ - if self.middleware: - self.middleware.execute(event=middleware_event,task=task, error=error) - - def process_task(self, task: Task) -> None: - """ - Processes the task by invoking the registered function. Handles timeouts, streaming, - error logging, and error reporting to a webhook. - We'll set finished_at on success or fail. - """ - try: - if task.task_name not in self.allowed_tasks: - task.status = "failed" - task.result = "Task not allowed on this server." 
- self._store_final_task_state(task) - logger.error(f"Task {task.task_name} is not allowed on this server.") - raise TaskProcessingError(task.task_name, "Task not allowed") - - task_function = getattr(self, task.task_name, None) - if not task_function: - task.status = "failed" - task.result = "Task function not found" - self._store_final_task_state(task) - logger.error(f"Task {task.task_name} failed - function not found.") - raise TaskProcessingError(task.task_name, "Task function not found") - - # ---- New: Check for Pydantic schema - schema_cls = getattr(task_function, "_mq_schema", None) - return_cls = getattr(task_function, "_mq_returns", None) - - # ---- Prepare args/kwargs based on schema - if schema_cls is not None: - try: - # Accept either dict or JSON-serialized dict - payload_data = task.payload["data"] - if isinstance(payload_data, str): - import json - payload_data = json.loads(payload_data) - validated_in = schema_cls(**payload_data) - except Exception as ve: - task.status = "failed" - task.result = f"Input validation failed – {ve}" - self._store_final_task_state(task, success=False) - logger.error(f"[ModelQ] Input validation failed: {ve}") - raise TaskProcessingError(task.task_name, f"Input validation failed: {ve}") - call_args = (validated_in,) - call_kwargs = {} - else: - # Legacy: no schema - call_args = tuple(task.payload['data'].get("args", ())) - call_kwargs = dict(task.payload['data'].get("kwargs", {})) - timeout = task.payload.get("timeout", None) - stream = task.payload.get("stream", False) - - logger.info( - f"Processing task: {task.task_name} " - f"with args: {call_args}, kwargs: {call_kwargs}" - ) - - if stream: - # Stream results - for result in task_function(*call_args, **call_kwargs): - import json - task.status = "in_progress" - self.redis_client.xadd( - f"task_stream:{task.task_id}", - {"result": json.dumps(result, default=str)} - ) - # Once streaming is done - task.status = "completed" - self.redis_client.expire(f"task_stream:{task.task_id}", 3600) - self._store_final_task_state(task, success=True) - else: - # Standard execution with optional timeout - if timeout: - result = self._run_with_timeout( - task_function, timeout, - *call_args, **call_kwargs - ) - else: - result = task_function( - *call_args, **call_kwargs - ) - - # ---- New: Output validation for standard result - if return_cls is not None: - try: - if not isinstance(result, return_cls): - result = return_cls(**(result if isinstance(result, dict) else result.__dict__)) - except Exception as ve: - task.status = "failed" - task.result = f"Output validation failed – {ve}" - self._store_final_task_state(task, success=False) - logger.error(f"[ModelQ] Output validation failed: {ve}") - raise TaskProcessingError(task.task_name, f"Output validation failed: {ve}") - - # When you set `task.result` (in process_task), use this logic: - if isinstance(result, BaseModel): - # Pydantic object: store as dict, not string! 
- task.result = result.model_dump(mode="json") - elif isinstance(result, (dict, list, int, float, bool)): - task.result = result - # only images as base64 string - else: - task.result = str(result) - - task.status = "completed" - self._store_final_task_state(task, success=True) - - logger.info(f"Task {task.task_name} completed successfully.") - - except RetryTaskException as e: - logger.warning(f"Task {task.task_name} requested retry: {e}") - new_task_dict = task.to_dict() - new_task_dict["payload"] = task.original_payload - self.enqueue_delayed_task(new_task_dict, delay_seconds=self.delay_seconds) - except Exception as e: - # Mark as failed - task.status = "failed" - task.result = str(e) - self._store_final_task_state(task, success=False) - - # 1) Log to file - self.log_task_error_to_file(task, e) - self.check_middleware("on_error", task=task, error=e) - - # 2) Webhook (if configured) - self.post_error_to_webhook(task, e) - logger.error(f"Task {task.task_name} failed with error: {e}") - raise TaskProcessingError(task.task_name, str(e)) - - finally: - self.redis_client.srem("processing_tasks", task.task_id) - - - def _store_final_task_state(self, task: Task, success: bool): - """ - Persists the final status/result of the task in Redis, adding finished_at. - """ - task_dict = task.to_dict() - - # Mark finished_at - task_dict["finished_at"] = time.time() - - self.redis_client.set( - f"task_result:{task.task_id}", - json.dumps(task_dict), - ex=3600, - ) - self.redis_client.set( - f"task:{task.task_id}", - json.dumps(task_dict), - ex=86400 - ) - - - def _run_with_timeout(self, func, timeout, *args, **kwargs): - """ - Runs the given function with a threading-based timeout. - If still alive after `timeout` seconds, raises TaskTimeoutError. - """ - result = [None] - exception = [None] + # ─────────────────────── Task Execution ────────────────────────── + def _execute_task(self, task: Task): + func = getattr(self, task.task_name, None) + if not func: + raise TaskProcessingError(task.task_name, "Task function not found") + + schema_cls = getattr(func, "_mq_schema", None) + return_cls = getattr(func, "_mq_returns", None) + + # prepare args / kwargs + if schema_cls: + payload_data = task.payload["data"] + if isinstance(payload_data, str): + payload_data = json.loads(payload_data) + call_args = (schema_cls(**payload_data),) + call_kwargs = {} + else: + call_args = tuple(task.payload["data"].get("args", ())) + call_kwargs = dict(task.payload["data"].get("kwargs", {})) - def target(): - try: - result[0] = func(*args, **kwargs) - except Exception as ex: - exception[0] = ex - - thread = threading.Thread(target=target) - thread.start() - thread.join(timeout) - - if thread.is_alive(): - logger.error(f"Task exceeded timeout of {timeout} seconds.") - raise TaskTimeoutError(f"Task exceeded timeout of {timeout} seconds") - if exception[0]: - raise exception[0] - return result[0] - - def get_task_status(self, task_id: str) -> Optional[str]: - """ - Returns the stored status of a given task_id. - """ - task_data = self.redis_client.get(f"task:{task_id}") - if task_data: - return json.loads(task_data).get("status") - return None - - def log_task_error_to_file(self, task: Task, exc: Exception, file_path="modelq_errors.log"): - """ - Logs detailed error info to a specified file, with dashes before and after. 
- """ - error_trace = traceback.format_exc() - log_data = { - "task_id": task.task_id, - "task_name": task.task_name, - "payload": task.payload, - "error_message": str(exc), - "traceback": error_trace, - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), - } - with open(file_path, "a", encoding="utf-8") as f: - f.write("----\n") - f.write(json.dumps(log_data, indent=2)) - f.write("\n----\n") - - def post_error_to_webhook(self, task: Task, exc: Exception): - """ - Non-blocking method to POST a detailed error message to the configured webhook. - """ - if not self.webhook_url: + # handle streaming tasks + if task.payload.get("stream"): + self._run_streaming_task(task, func, call_args, call_kwargs) return - full_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - payload_str = json.dumps(task.payload, indent=2) - - content_str = ( - f"**Task Name**: {task.task_name}\n" - f"**Task ID**: {task.task_id}\n" - f"**Payload**:\n```json\n{payload_str}\n```\n" - f"**Error Message**: {exc}\n" - f"**Traceback**:\n```{full_tb}```" - ) - - t = threading.Thread( - target=self._post_error_to_webhook_sync, - args=(content_str,), - daemon=True - ) - t.start() - - def _post_error_to_webhook_sync(self, content_str: str): - payload = {"content": content_str} + # run with optional timeout try: - resp = requests.post(self.webhook_url, json=payload, timeout=10) - if resp.status_code >= 400: - logger.error( - f"Failed to POST error to webhook. " - f"Status code: {resp.status_code}, Response: {resp.text}" - ) - except Exception as e2: - logger.error(f"Exception while sending error to webhook: {e2}") - - def remove_task_from_queue(self, task_id: str) -> bool: - """ - Removes a task from the 'ml_tasks' queue using its task_id. - Returns True if the task was found and removed, False otherwise. 
- """ - tasks = self.redis_client.lrange("ml_tasks", 0, -1) - removed = False - for task_json in tasks: + if t := task.payload.get("timeout"): + result = self._run_with_timeout(func, t, *call_args, **call_kwargs) + else: + result = func(*call_args, **call_kwargs) + + if return_cls and not isinstance(result, return_cls): + result = return_cls(**(result if isinstance(result, dict) else result.__dict__)) + + serialized_result = result.model_dump(mode="json") if isinstance(result, BaseModel) else result + task.status = "completed"; task.result = serialized_result + self.backend.save_task_state(task.task_id, task.to_dict(), result=True) + except RetryTaskException: + logger.warning("Task %s requested retry", task.task_id) + task.payload["retries"] -= 1 + self.backend.enqueue_delayed_task(task.to_dict(), self.delay_seconds) + except Exception as exc: + task.status = "failed"; task.result = str(exc) + self.backend.save_task_state(task.task_id, task.to_dict(), result=False) + logger.exception("Task %s failed: %s", task.task_id, exc) + raise + + # ─────────────────────── Streaming Helper ──────────────────────── + def _run_streaming_task(self, task: Task, func, call_args, call_kwargs): + if not isinstance(self.backend, RedisQueueBackend): + raise NotImplementedError("Streaming tasks currently require Redis backend") + redis_client = self.backend.redis + stream_key = f"task_stream:{task.task_id}" + try: + for chunk in func(*call_args, **call_kwargs): + redis_client.xadd(stream_key, {"result": json.dumps(chunk, default=str)}) + task.status = "completed"; task.result = ""; + redis_client.expire(stream_key, 3600) + self.backend.save_task_state(task.task_id, task.to_dict(), result=True) + except Exception as exc: + task.status = "failed"; task.result = str(exc) + self.backend.save_task_state(task.task_id, task.to_dict(), result=False) + logger.exception("Streaming task %s failed: %s", task.task_id, exc) + raise + + # ─────────────────────── Helpers ───────────────────────────────── + @staticmethod + def _run_with_timeout(func, timeout, *args, **kwargs): + res = [None]; exc = [None] + + def target(): try: - task_dict = json.loads(task_json) - if task_dict.get("task_id") == task_id: - self.redis_client.lrem("ml_tasks", 1, task_json) - self.redis_client.zrem("queued_requests", task_id) - removed = True - logger.info(f"Removed task {task_id} from queue.") - break + res[0] = func(*args, **kwargs) except Exception as e: - logger.error(f"Failed to process task while trying to remove: {e}") - return removed - \ No newline at end of file + exc[0] = e + + th = threading.Thread(target=target) + th.start(); th.join(timeout) + if th.is_alive(): + raise TaskTimeoutError(func.__name__, timeout) \ No newline at end of file diff --git a/modelq/app/tasks/base.py b/modelq/app/tasks/base.py index b4cc9eb..cf19881 100644 --- a/modelq/app/tasks/base.py +++ b/modelq/app/tasks/base.py @@ -1,34 +1,52 @@ import uuid import time import json -import redis import base64 -from typing import Any, Optional, Generator -from modelq.exceptions import TaskTimeoutError, TaskProcessingError -from PIL import Image, PngImagePlugin import io import copy -from typing import Type +import logging +from typing import Any, Optional, Generator, Type + +from PIL import Image, PngImagePlugin + +from modelq.exceptions import TaskTimeoutError, TaskProcessingError +from modelq.app.backends.redis import RedisQueueBackend # for stream helper + +logger = logging.getLogger("modelq.task") + class Task: - def __init__(self, task_name: str, payload: dict, 
timeout: int = 15): - self.task_id = str(uuid.uuid4()) - self.task_name = task_name - self.payload = payload + """Light‑weight DTO representing a single ModelQ job.""" + + def __init__( + self, + task_name: str, + payload: dict, + timeout: int = 15, + *, + backend: Optional[Any] = None, + ): + self.task_id: str = str(uuid.uuid4()) + self.task_name: str = task_name + self.payload: dict = payload self.original_payload = copy.deepcopy(payload) - self.status = "queued" - self.result = None - - # New timestamps: - self.created_at = time.time() # When Task object is instantiated - self.queued_at = None # When task is enqueued in Redis - self.started_at = None # When worker actually starts it - self.finished_at = None # When task finishes (success or fail) - - self.timeout = timeout - self.stream = False - self.combined_result = "" + # runtime / bookkeeping + self.status: str = "queued" + self.result: Any = None + self.created_at: float = time.time() + self.queued_at: Optional[float] = None + self.started_at: Optional[float] = None + self.finished_at: Optional[float] = None + + self.timeout: int = timeout + self.stream: bool = False + self.combined_result: str = "" + + # store queue backend reference for later use (get_result / get_stream) + self.backend: Optional[Any] = backend + + # ─────────────────────── Serialization ────────────────────────── def to_dict(self): return { "task_id": self.task_id, @@ -44,142 +62,98 @@ def to_dict(self): } @staticmethod - def from_dict(data: dict) -> "Task": - task = Task(task_name=data["task_name"], payload=data["payload"]) - task.task_id = data["task_id"] - task.status = data["status"] - task.result = data.get("result") - - # Load timestamps if present - task.created_at = data.get("created_at") - task.queued_at = data.get("queued_at") - task.started_at = data.get("started_at") - task.finished_at = data.get("finished_at") + def from_dict(data: dict, *, backend: Optional[Any] = None) -> "Task": + t = Task(data["task_name"], data["payload"], backend=backend) + t.task_id = data["task_id"] + t.status = data.get("status", "queued") + t.result = data.get("result") + t.created_at = data.get("created_at") + t.queued_at = data.get("queued_at") + t.started_at = data.get("started_at") + t.finished_at = data.get("finished_at") + t.stream = data.get("stream", False) + return t + + # ─────────────────────── Streaming (Redis only for now) ────────── + def get_stream(self, backend: Any | None = None) -> Generator[Any, None, None]: + """Yield incremental results for streaming tasks (Redis backend only).""" + backend = backend or self.backend + if backend is None: + raise ValueError("Backend reference required for streaming") + if not isinstance(backend, RedisQueueBackend): + raise NotImplementedError("Streaming supported only on Redis backend right now") + redis_client = backend.redis - task.stream = data.get("stream", False) - return task - - def _convert_to_string(self, data: Any) -> str: - """ - Converts data to a string representation. If the data is a PIL image, - encode it as a base64 PNG. 
- """ - try: - if isinstance(data, (dict, list, int, float, bool)): - return json.dumps(data) - elif isinstance(data, (Image.Image, PngImagePlugin.PngImageFile)): - buffered = io.BytesIO() - data.save(buffered, format="PNG") - return "data:image/png;base64," + base64.b64encode( - buffered.getvalue() - ).decode("utf-8") - return str(data) - except TypeError: - return str(data) - - def get_stream(self, redis_client: redis.Redis) -> Generator[Any, None, None]: - """ - Generator to yield results from a streaming task. - Continuously reads from a Redis stream and stops when - the task is completed or failed. - """ stream_key = f"task_stream:{self.task_id}" last_id = "0" - completed = False - - while not completed: - # block=1000 => block for up to 1s, count=10 => max 10 messages + while True: results = redis_client.xread({stream_key: last_id}, block=1000, count=10) - if results: - for _, messages in results: - for message_id, message_data in messages: - # print(message_data) - result = json.loads(message_data[b"result"].decode("utf-8")) - yield result - last_id = message_id - # Append to combined_result - self.combined_result += result - - # Check if the task is finished or failed - task_json = redis_client.get(f"task_result:{self.task_id}") - if task_json: - task_data = json.loads(task_json) - if task_data.get("status") == "completed": - completed = True - # Update local fields - self.status = "completed" - self.result = self.combined_result - elif task_data.get("status") == "failed": - error_message = task_data.get("result", "Task failed without an error message") - raise TaskProcessingError( - task_data.get("task_name", self.task_name), - error_message - ) - - return - + for _, msgs in results or []: + for msg_id, data in msgs: + payload = json.loads(data[b"result"].decode()) + self.combined_result += payload if isinstance(payload, str) else json.dumps(payload) + yield payload + last_id = msg_id + state = backend.load_task_state(self.task_id) + if state and state.get("status") in {"completed", "failed"}: + if state["status"] == "failed": + raise TaskProcessingError(self.task_name, state.get("result")) + self.status = "completed"; self.result = self.combined_result + break + + # ─────────────────────── Result Retrieval ─────────────────────── def get_result( self, - redis_client: redis.Redis, - timeout: int = None, + backend: Any | None = None, + timeout: Optional[int] = None, returns: Optional[Type[Any]] = None, modelq_ref: Any = None, ) -> Any: - """ - Waits for the result of the task until the timeout. - Raises TaskProcessingError if the task failed, - or TaskTimeoutError if it never completes within the timeout. - Optionally validates/deserializes the result using a Pydantic model. 
- """ - if not timeout: - timeout = self.timeout - - start_time = time.time() - while time.time() - start_time < timeout: - task_json = redis_client.get(f"task_result:{self.task_id}") - if task_json: - task_data = json.loads(task_json) - self.result = task_data.get("result") - self.status = task_data.get("status") - - if self.status == "failed": - error_message = self.result or "Task failed without an error message" - raise TaskProcessingError( - task_data.get("task_name", self.task_name), - error_message - ) - elif self.status == "completed": - raw_result = self.result - - # Auto-detect returns schema if not given + backend = backend or self.backend + if backend is None: + raise ValueError("Backend reference required to fetch result") + timeout = timeout or self.timeout + start = time.time() + while time.time() - start < timeout: + state = backend.load_task_state(self.task_id) + if state: + status = state.get("status"); result = state.get("result") + if status == "failed": + raise TaskProcessingError(self.task_name, result or "Task failed") + if status == "completed": + # optional pydantic coercion if returns is None and modelq_ref is not None: - task_function = getattr(modelq_ref, self.task_name, None) - returns = getattr(task_function, "_mq_returns", None) - - if returns is not None: + func = getattr(modelq_ref, self.task_name, None) + returns = getattr(func, "_mq_returns", None) + if returns: try: - if isinstance(raw_result, str): + if isinstance(result, str): try: - result_data = json.loads(raw_result) + result_data = json.loads(result) except Exception: - result_data = raw_result + result_data = result else: - result_data = raw_result - + result_data = result if isinstance(result_data, dict): return returns(**result_data) - elif isinstance(result_data, returns): + if isinstance(result_data, returns): return result_data - else: - return returns.parse_obj(result_data) + return returns.parse_obj(result_data) except Exception as ve: - raise TaskProcessingError( - self.task_name, - f"Result validation failed: {ve}" - ) - else: - return raw_result - + raise TaskProcessingError(self.task_name, f"Result validation failed: {ve}") + return result time.sleep(1) - raise TaskTimeoutError(self.task_id) + + # ─────────────────────── Helpers ───────────────────────────────── + @staticmethod + def _encode_image(img: Image.Image) -> str: + buf = io.BytesIO(); img.save(buf, format="PNG") + return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode() + + def _convert_to_string(self, data: Any) -> str: + if isinstance(data, (dict, list, int, float, bool)): + return json.dumps(data) + if isinstance(data, (Image.Image, PngImagePlugin.PngImageFile)): + return self._encode_image(data) + return str(data) diff --git a/test.py b/test.py new file mode 100644 index 0000000..30f035c --- /dev/null +++ b/test.py @@ -0,0 +1,63 @@ +from pydantic import BaseModel, Field +from modelq import ModelQ +from redis import Redis +from modelq.app.backends.redis import RedisQueueBackend + +class AddIn(BaseModel): + a: int = Field(ge=0) + b: int = Field(ge=0) + +class AddOut(BaseModel): + total: int + +redis_client = Redis(host="localhost", port=6379, db=0) + +backend = RedisQueueBackend(redis_client) +mq = ModelQ(backend=backend) + + +@mq.task(schema=AddIn, returns=AddOut) +def add(payload: AddIn) -> AddOut: + print(f"Processing addition: {payload.a} + {payload.b}") + time.sleep(10) # Simulate some processing time + return AddOut(total=payload.a + payload.b) + +@mq.task() +def sub(a: int, b: int): + 
print(f"Processing subtraction: {a} - {b}") + return a - b + +@mq.task() +def image_task(params: dict): + print(f"Processing image task with params: {params}") + # Simulate image processing + return "Image processed successfully" + +job = add(a=3, b=4) # ✨ validated on the spot + +job2 = sub(a=10, b=5) # ✨ no schema validation, just a simple task + +task = image_task({"image": "example.png"}) # ✨ no schema validation, just a simple task +task2 = image_task(params={"image": "example.png"}) +import time + +if __name__ == "__main__": + mq.start_workers() + + # Keep the worker running indefinitely + try: + while True: + output = job.get_result(returns=AddOut) + + print(f"Result of addition: {output}") + print(type(output)) + print(f"Result of addition (total): {output.total}") + + output2 = job2.get_result() + print(f"Result of subtraction: {output2}") + + output3 = task.get_result(mq.redis_client) + print(f"Result of image task: {output3}") + time.sleep(1) + except KeyboardInterrupt: + print("\nGracefully shutting down...") \ No newline at end of file From e4431b920419f6bfb5470e98c1aa31de57ce4e5a Mon Sep 17 00:00:00 2001 From: Tanmay patil Date: Sat, 7 Jun 2025 15:02:43 +0530 Subject: [PATCH 3/3] some optimisation to redis backend --- modelq/app/backends/rabbitmq/backend.py | 0 modelq/app/backends/redis/backend.py | 94 ++++++++++++++++++------- modelq/app/base.py | 5 +- test.py | 4 +- 4 files changed, 76 insertions(+), 27 deletions(-) create mode 100644 modelq/app/backends/rabbitmq/backend.py diff --git a/modelq/app/backends/rabbitmq/backend.py b/modelq/app/backends/rabbitmq/backend.py new file mode 100644 index 0000000..e69de29 diff --git a/modelq/app/backends/redis/backend.py b/modelq/app/backends/redis/backend.py index c691620..918cec3 100644 --- a/modelq/app/backends/redis/backend.py +++ b/modelq/app/backends/redis/backend.py @@ -6,23 +6,71 @@ class RedisQueueBackend(QueueBackend): + def __init__(self, redis_client: redis.Redis): self.redis = redis_client + self._register_scripts() # ─────────────────────── Task Queue ─────────────────────────────── + def _register_scripts(self) -> None: + # enqueue_task (list + sorted-set in one shot) + self._enqueue_sha = self.redis.script_load(""" + -- KEYS[1] = ml_tasks, KEYS[2] = queued_requests + -- ARGV[1] = full task JSON, ARGV[2] = task_id, ARGV[3] = queued_at + redis.call('RPUSH', KEYS[1], ARGV[1]) + redis.call('ZADD', KEYS[2], ARGV[3], ARGV[2]) + return 1 + """) + + # dequeue_ready_delayed_tasks (atomically move due jobs) + self._promote_delayed_sha = self.redis.script_load(""" + -- KEYS[1] = delayed_tasks, KEYS[2] = ml_tasks, ARGV[1] = now + local ready = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1]) + if #ready == 0 then return {} end + redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1]) + for i=1,#ready do + redis.call('LPUSH', KEYS[2], ready[i]) + end + return ready -- array of JSON strings + """) + + # remove_task_from_queue (search & delete server-side) + self._remove_sha = self.redis.script_load(""" + -- KEYS[1] = ml_tasks, KEYS[2] = queued_requests, ARGV[1] = task_id + local len = redis.call('LLEN', KEYS[1]) + for i=0, len-1 do + local item = redis.call('LINDEX', KEYS[1], i) + if item then + local ok, obj = pcall(cjson.decode, item) + if ok and obj['task_id'] == ARGV[1] then + redis.call('LSET', KEYS[1], i, '__DEL__') + redis.call('LREM', KEYS[1], 0, '__DEL__') + redis.call('ZREM', KEYS[2], ARGV[1]) + return 1 + end + end + end + return 0 + """) def enqueue_task(self, task_data: dict) -> None: task_data["status"] = "queued" - 
self.redis.rpush("ml_tasks", json.dumps(task_data)) - self.redis.zadd("queued_requests", {task_data["task_id"]: task_data["queued_at"]}) - + self.redis.evalsha( + self._enqueue_sha, + 2, # number of KEYS + "ml_tasks", "queued_requests", + json.dumps(task_data), # ARGV[1] + task_data["task_id"], # ARGV[2] + task_data["queued_at"], # ARGV[3] + ) + def dequeue_task(self, timeout: Optional[int] = None) -> Optional[dict]: - data = self.redis.blpop("ml_tasks", timeout=timeout or 5) - if data: - _, task_json = data - return json.loads(task_json) + rv = self.redis.blpop("ml_tasks", timeout or 5) + if rv: + _, raw = rv + return json.loads(raw) return None - + def requeue_task(self, task_data: dict) -> None: self.redis.rpush("ml_tasks", json.dumps(task_data)) @@ -31,12 +79,12 @@ def enqueue_delayed_task(self, task_data: dict, delay_seconds: int) -> None: self.redis.zadd("delayed_tasks", {json.dumps(task_data): run_at}) def dequeue_ready_delayed_tasks(self) -> list: - now = time.time() - tasks = self.redis.zrangebyscore("delayed_tasks", 0, now) - for task_json in tasks: - self.redis.zrem("delayed_tasks", task_json) - self.redis.lpush("ml_tasks", task_json) - return [json.loads(t) for t in tasks] + # Single RTT instead of ZRANGEBYSCORE + loop in Python + ready = self.redis.evalsha( + self._promote_delayed_sha, + 2, "delayed_tasks", "ml_tasks", time.time() + ) + return [json.loads(j) for j in ready] def flush_queue(self) -> None: self.redis.ltrim("ml_tasks", 1, 0) @@ -45,22 +93,20 @@ def flush_queue(self) -> None: def save_task_state(self, task_id: str, task_data: dict, result: bool) -> None: task_data["finished_at"] = time.time() - self.redis.set(f"task_result:{task_id}", json.dumps(task_data), ex=3600) - self.redis.set(f"task:{task_id}", json.dumps(task_data), ex=86400) + with self.redis.pipeline() as pipe: # tiny but measurable + pipe.set(f"task_result:{task_id}", json.dumps(task_data), ex=3600) + pipe.set(f"task:{task_id}", json.dumps(task_data), ex=86400) + pipe.execute() def load_task_state(self, task_id: str) -> Optional[dict]: data = self.redis.get(f"task:{task_id}") return json.loads(data) if data else None def remove_task_from_queue(self, task_id: str) -> bool: - tasks = self.redis.lrange("ml_tasks", 0, -1) - for task_json in tasks: - task_dict = json.loads(task_json) - if task_dict.get("task_id") == task_id: - self.redis.lrem("ml_tasks", 1, task_json) - self.redis.zrem("queued_requests", task_id) - return True - return False + return bool(self.redis.evalsha( + self._remove_sha, + 2, "ml_tasks", "queued_requests", task_id + )) def mark_processing(self, task_id: str) -> bool: return self.redis.sadd("processing_tasks", task_id) == 1 diff --git a/modelq/app/base.py b/modelq/app/base.py index 16b8f53..a3bc618 100644 --- a/modelq/app/base.py +++ b/modelq/app/base.py @@ -283,4 +283,7 @@ def target(): th = threading.Thread(target=target) th.start(); th.join(timeout) if th.is_alive(): - raise TaskTimeoutError(func.__name__, timeout) \ No newline at end of file + raise TaskTimeoutError(func.__name__, timeout) + if exc[0]: + raise exc[0] + return res[0] \ No newline at end of file diff --git a/test.py b/test.py index 30f035c..2911cd1 100644 --- a/test.py +++ b/test.py @@ -56,8 +56,8 @@ def image_task(params: dict): output2 = job2.get_result() print(f"Result of subtraction: {output2}") - output3 = task.get_result(mq.redis_client) - print(f"Result of image task: {output3}") + # output3 = task.get_result(mq.redis_client) + # print(f"Result of image task: {output3}") time.sleep(1) except KeyboardInterrupt: 
print("\nGracefully shutting down...") \ No newline at end of file