diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index de62304..3976671 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -1,41 +1,28 @@ -"""Main Queue class and @task decorator. +"""Main Queue class. -The Queue class is composed from multiple mixins for organization: +The Queue class is composed from multiple mixins for organization. Methods +shared across the inspection, operations, lock, decorator, resource, event, +and lifecycle surfaces live in the corresponding mixin under +``taskito.mixins.*``. Async wrappers come from ``taskito.async_support.mixins``. -- ``QueueInspectionMixin`` (mixins.py) — read-only inspection, stats, queries -- ``QueueOperationsMixin`` (mixins.py) — write operations (cancel, purge, archive, etc.) -- ``QueueLockMixin`` (mixins.py) — distributed locking -- ``AsyncQueueMixin`` (async_support/mixins.py) — ``a*`` async wrappers for all sync methods - -The Queue class itself (this file) handles: +The Queue class itself (this file) handles only: - Constructor and storage backend initialization -- @task() and @periodic() decorators with task registration - enqueue() / enqueue_many() job submission -- run_worker() worker lifecycle (signals, heartbeat, resources) -- Event bus and webhook management +- _wrap_task() task body wrapping with hooks, middleware, proxies, resources +- Internal helpers (``_get_serializer``, ``_deserialize_payload``, + ``_get_middleware_chain``) """ from __future__ import annotations -import asyncio -import contextlib import functools -import json import logging import os -import signal -import sys -import threading -import urllib.parse -import uuid -from collections.abc import Callable, Sequence +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any -if TYPE_CHECKING: - from taskito.testing import TestMode - -from taskito._taskito import PyQueue, PyTaskConfig +from taskito._taskito import PyQueue from 
taskito.async_support.helpers import run_maybe_async from taskito.async_support.mixins import AsyncQueueMixin from taskito.events import EventBus, EventType @@ -43,20 +30,27 @@ from taskito.interception.built_in import build_default_registry from taskito.middleware import TaskMiddleware from taskito.mixins import ( + QueueDecoratorMixin, + QueueEventsMixin, QueueInspectionMixin, + QueueLifecycleMixin, QueueLockMixin, QueueOperationsMixin, + QueueResourceMixin, ) from taskito.proxies import ProxyRegistry, cleanup_proxies, reconstruct_proxies from taskito.proxies.built_in import register_builtin_handlers from taskito.proxies.metrics import ProxyMetrics -from taskito.resources.definition import ResourceDefinition, ResourceScope -from taskito.resources.runtime import ResourceRuntime from taskito.result import JobResult from taskito.serializers import CloudpickleSerializer, Serializer -from taskito.task import TaskWrapper from taskito.webhooks import WebhookManager +if TYPE_CHECKING: + from taskito._taskito import PyTaskConfig + from taskito.interception.metrics import InterceptionMetrics + from taskito.resources.definition import ResourceDefinition + from taskito.resources.runtime import ResourceRuntime + try: from taskito.workflows.mixins import QueueWorkflowMixin from taskito.workflows.tracker import WorkflowTracker @@ -73,24 +67,11 @@ class QueueWorkflowMixin: # type: ignore[no-redef] logger = logging.getLogger("taskito") -def _resolve_module_name(module_name: str) -> str: - """Resolve __main__ to the actual module name.""" - if module_name != "__main__": - return module_name - import sys - - main = sys.modules.get("__main__") - if main is not None: - spec = getattr(main, "__spec__", None) - if spec and spec.name: - return str(spec.name) - f = getattr(main, "__file__", None) - if f: - return str(os.path.splitext(os.path.basename(f))[0]) - return module_name - - class Queue( + QueueDecoratorMixin, + QueueResourceMixin, + QueueEventsMixin, + QueueLifecycleMixin, 
QueueInspectionMixin, QueueOperationsMixin, QueueLockMixin, @@ -245,7 +226,7 @@ def __init__( self._max_reconstruction_timeout = max_reconstruction_timeout # Argument interception - self._interception_metrics = None + self._interception_metrics: InterceptionMetrics | None = None if interception != "off": from taskito.interception.metrics import InterceptionMetrics @@ -278,456 +259,6 @@ def __init__( if _WORKFLOWS_AVAILABLE and hasattr(self._inner, "submit_workflow"): self._workflow_tracker = WorkflowTracker(self) - def task( - self, - name: str | None = None, - max_retries: int = 3, - retry_backoff: float = 1.0, - timeout: int = 300, - priority: int = 0, - rate_limit: str | None = None, - queue: str = "default", - circuit_breaker: dict | None = None, - retry_on: list[type[Exception]] | None = None, - dont_retry_on: list[type[Exception]] | None = None, - soft_timeout: float | None = None, - middleware: list[TaskMiddleware] | None = None, - retry_delays: list[float] | None = None, - inject: list[str] | None = None, - serializer: Serializer | None = None, - max_retry_delay: int | None = None, - max_concurrent: int | None = None, - ) -> Callable[[Callable[..., Any]], TaskWrapper]: - """Decorator to register a function as a background task. - - Args: - name: Explicit task name. Defaults to ``module.qualname``. - max_retries: Max retry attempts on failure before moving to DLQ. - retry_backoff: Base delay in seconds for exponential backoff between retries. - timeout: Max execution time in seconds before the task is killed. - priority: Priority level (higher = more urgent). - rate_limit: Rate limit string, e.g. ``"100/m"``, ``"10/s"``, ``"3600/h"``. - queue: Named queue to submit to. - circuit_breaker: Optional dict with ``threshold``, ``window`` (seconds), - and ``cooldown`` (seconds) keys. - retry_on: List of exception classes that should trigger retries. - If set, only these exceptions are retried. - dont_retry_on: List of exception classes that should never be retried. 
- soft_timeout: Soft timeout in seconds. Checked via ``current_job.check_timeout()``. - middleware: Per-task middleware instances (in addition to global middleware). - inject: List of resource names to inject as keyword arguments. - serializer: Per-task serializer. Falls back to the queue-level serializer. - max_retry_delay: Maximum backoff delay in seconds. Defaults to 300 - (5 minutes) if not set. - max_concurrent: Maximum number of concurrent running instances of - this task. ``None`` means no limit. - """ - - def decorator(fn: Callable) -> TaskWrapper: - task_name = name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}" - - # Detect Inject["name"] annotations (Phase E) - from taskito.inject import _InjectAlias - - annotation_injects: list[str] = [] - try: - import typing - - hints: dict[str, Any] = {} - with contextlib.suppress(Exception): - # get_type_hints evaluates string annotations - ns = getattr(fn, "__globals__", {}) - ns = {**ns, "Inject": __import__("taskito.inject", fromlist=["Inject"]).Inject} - hints = typing.get_type_hints(fn, globalns=ns, include_extras=True) - # Fallback: check raw annotations if get_type_hints failed - if not hints: - with contextlib.suppress(Exception): - hints = getattr(fn, "__annotations__", {}) - for _param_name, hint in hints.items(): - if isinstance(hint, _InjectAlias): - annotation_injects.append(hint.resource_name) - except Exception: - pass - - # Merge explicit inject= with annotation-detected injects - final_inject = list(inject or []) - for res_name in annotation_injects: - if res_name not in final_inject: - final_inject.append(res_name) - - # Store retry filters - if retry_on or dont_retry_on: - self._task_retry_filters[task_name] = { - "retry_on": retry_on or [], - "dont_retry_on": dont_retry_on or [], - } - - # Store per-task middleware - if middleware: - self._task_middleware[task_name] = middleware - - # Store per-task serializer - if serializer is not None: - self._task_serializers[task_name] = 
serializer - - # Store inject map for resource injection - if final_inject: - self._task_inject_map[task_name] = final_inject - - # Wrap the function with hooks, middleware, and context - wrapped = self._wrap_task(fn, task_name, soft_timeout) - self._task_registry[task_name] = wrapped - - cb_threshold = None - cb_window = None - cb_cooldown = None - cb_half_open_probes = None - cb_half_open_success_rate = None - if circuit_breaker: - cb_threshold = circuit_breaker.get("threshold", 5) - cb_window = circuit_breaker.get("window", 60) - cb_cooldown = circuit_breaker.get("cooldown", 300) - cb_half_open_probes = circuit_breaker.get("half_open_probes") - cb_half_open_success_rate = circuit_breaker.get("half_open_success_rate") - - # Store config for worker startup - config = PyTaskConfig( - name=task_name, - max_retries=max_retries, - retry_backoff=retry_backoff, - timeout=timeout, - priority=priority, - rate_limit=rate_limit, - queue=queue, - circuit_breaker_threshold=cb_threshold, - circuit_breaker_window=cb_window, - circuit_breaker_cooldown=cb_cooldown, - retry_delays=retry_delays, - max_retry_delay=max_retry_delay, - max_concurrent=max_concurrent, - circuit_breaker_half_open_probes=cb_half_open_probes, - circuit_breaker_half_open_success_rate=cb_half_open_success_rate, - ) - self._task_configs.append(config) - - # Return a TaskWrapper that has .delay() and .apply_async() - wrapper = TaskWrapper( - fn=fn, - queue_ref=self, - task_name=task_name, - default_priority=priority, - default_queue=queue, - default_max_retries=max_retries, - default_timeout=timeout, - inject=final_inject or None, - ) - - # Preserve function metadata - functools.update_wrapper(wrapper, fn) - - # Mark async status for native async dispatch - is_async = asyncio.iscoroutinefunction(fn) - wrapper._taskito_is_async = is_async - if is_async: - wrapper._taskito_async_fn = fn - - return wrapper - - return decorator - - def periodic( - self, - cron: str, - name: str | None = None, - args: tuple = (), - 
kwargs: dict | None = None, - queue: str = "default", - timezone: str | None = None, - ) -> Callable[[Callable[..., Any]], TaskWrapper]: - """Decorator to register a periodic (cron-scheduled) task. - - Args: - cron: Cron expression (6-field with seconds), e.g. ``"0 */5 * * * *"`` - for every 5 minutes. - name: Explicit task name. Defaults to ``module.qualname``. - args: Positional arguments to pass to the task on each run. - kwargs: Keyword arguments to pass to the task on each run. - queue: Named queue to submit to. - """ - - def decorator(fn: Callable) -> TaskWrapper: - # If fn is a WorkflowProxy (from @queue.workflow()), create a - # launcher task that submits the workflow on each cron trigger. - if getattr(fn, "_is_workflow_proxy", False): - proxy: Any = fn - launcher_name = f"_wf_launcher_{proxy._name}" - - @self.task(name=launcher_name, queue=queue) - def _wf_launcher() -> str: - run = proxy.submit() - return f"submitted workflow run {run.id}" - - payload = self._get_serializer(launcher_name).dumps(((), {})) - self._periodic_configs.append( - { - "name": launcher_name, - "task_name": launcher_name, - "cron_expr": cron, - "payload": payload, - "queue": queue, - "timezone": timezone, - } - ) - return fn # type: ignore[return-value] - - # Register as a normal task first - wrapper = self.task(name=name, queue=queue)(fn) - - # Store periodic config for registration at worker startup - payload = self._get_serializer(wrapper.name).dumps((args, kwargs or {})) - self._periodic_configs.append( - { - "name": name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}", - "task_name": wrapper.name, - "cron_expr": cron, - "payload": payload, - "queue": queue, - "timezone": timezone, - } - ) - - return wrapper - - return decorator - - # -- Hooks / middleware -- - - def before_task(self, fn: Callable) -> Callable: - """Register a hook called before each task executes. - - Args: - fn: Callback with signature ``fn(task_name, args, kwargs)``. 
- """ - self._hooks["before_task"].append(fn) - return fn - - def after_task(self, fn: Callable) -> Callable: - """Register a hook called after each task completes or fails. - - Args: - fn: Callback with signature - ``fn(task_name, args, kwargs, result, error)``. - """ - self._hooks["after_task"].append(fn) - return fn - - def on_success(self, fn: Callable) -> Callable: - """Register a hook called when a task completes successfully. - - Args: - fn: Callback with signature - ``fn(task_name, args, kwargs, result)``. - """ - self._hooks["on_success"].append(fn) - return fn - - def on_failure(self, fn: Callable) -> Callable: - """Register a hook called when a task raises an exception. - - Args: - fn: Callback with signature - ``fn(task_name, args, kwargs, error)``. - """ - self._hooks["on_failure"].append(fn) - return fn - - # -- Worker Resources -- - - def worker_resource( - self, - name: str, - depends_on: list[str] | None = None, - teardown: Callable | None = None, - health_check: Callable | None = None, - health_check_interval: float = 0.0, - max_recreation_attempts: int = 3, - scope: str = "worker", - pool_size: int | None = None, - pool_min: int = 0, - acquire_timeout: float = 10.0, - max_lifetime: float = 3600.0, - idle_timeout: float = 300.0, - reloadable: bool = False, - frozen: bool = False, - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """Decorator to register a resource factory. - - Args: - name: Resource name used in ``inject=["name"]``. - depends_on: Names of resources this one depends on. - teardown: Optional callable to clean up the resource on shutdown. - health_check: Optional callable that returns truthy if healthy. - health_check_interval: Seconds between health checks (0 = disabled). - max_recreation_attempts: Max times to recreate on health failure. - scope: Resource scope — ``"worker"``, ``"task"``, ``"thread"``, - or ``"request"``. - pool_size: Pool size for task-scoped resources. 
- pool_min: Minimum pre-warmed instances (task scope). - acquire_timeout: Max seconds to wait for pool instance. - max_lifetime: Max seconds a pooled instance lives. - idle_timeout: Max idle seconds before eviction. - reloadable: Whether the resource can be hot-reloaded via SIGHUP. - frozen: Wrap the resource in a read-only proxy. - """ - from taskito.resources.graph import detect_cycle - - def decorator(factory: Callable[..., Any]) -> Callable[..., Any]: - self.register_resource( - ResourceDefinition( - name=name, - factory=factory, - depends_on=depends_on or [], - teardown=teardown, - health_check=health_check, - health_check_interval=health_check_interval, - max_recreation_attempts=max_recreation_attempts, - scope=ResourceScope(scope), - pool_size=pool_size, - pool_min=pool_min, - acquire_timeout=acquire_timeout, - max_lifetime=max_lifetime, - idle_timeout=idle_timeout, - reloadable=reloadable, - frozen=frozen, - ) - ) - # Validate no cycles eagerly - cycle = detect_cycle(self._resource_definitions) - if cycle is not None: - from taskito.exceptions import CircularDependencyError - - # Roll back the registration - del self._resource_definitions[name] - raise CircularDependencyError( - f"Circular dependency detected: {' -> '.join(cycle)}" - ) - return factory - - return decorator - - def register_resource(self, definition: ResourceDefinition) -> None: - """Programmatically register a resource definition. - - Args: - definition: A :class:`~taskito.resources.ResourceDefinition`. - """ - self._resource_definitions[definition.name] = definition - - def health_check(self, name: str) -> bool: - """Run a resource's health check immediately. - - Args: - name: The registered resource name. - - Returns: - True if healthy, False otherwise. 
- """ - runtime = self._resource_runtime - if runtime is None: - return False - defn = self._resource_definitions.get(name) - if defn is None or defn.health_check is None: - return False - try: - instance = runtime.resolve(name) - return bool(defn.health_check(instance)) - except Exception: - return False - - def load_resources(self, toml_path: str) -> None: - """Load resource definitions from a TOML file. - - Must be called before ``run_worker()``. - - Args: - toml_path: Path to the TOML configuration file. - """ - from taskito.resources.toml_config import load_resources_from_toml - - for defn in load_resources_from_toml(toml_path): - self.register_resource(defn) - - def proxy_stats(self) -> list[dict[str, Any]]: - """Return per-handler proxy reconstruction metrics.""" - return self._proxy_metrics.to_list() - - def interception_stats(self) -> dict[str, Any]: - """Return interception performance metrics.""" - if self._interception_metrics is not None: - return self._interception_metrics.to_dict() - return {} - - def register_type( - self, - python_type: type, - strategy: str, - *, - resource: str | None = None, - message: str | None = None, - converter: Callable | None = None, - type_key: str | None = None, - proxy_handler: str | None = None, - ) -> None: - """Register a custom type with the interception system. - - Args: - python_type: The type to register. - strategy: One of ``"pass"``, ``"convert"``, ``"redirect"``, - ``"reject"``, or ``"proxy"``. - resource: Resource name for ``"redirect"`` strategy. - message: Rejection reason for ``"reject"`` strategy. - converter: Converter callable for ``"convert"`` strategy. - type_key: Key for the converter reconstructor dispatch. - proxy_handler: Handler name for ``"proxy"`` strategy. 
- """ - if self._interceptor is None: - raise RuntimeError( - "Interception is disabled; set interception='strict' or " - "'lenient' to use register_type()" - ) - from taskito.interception.strategy import Strategy as S - - strat = S(strategy) - self._interceptor._registry.register( - python_type, - strat, - priority=15, - redirect_resource=resource, - reject_reason=message, - converter=converter, - type_key=type_key, - proxy_handler=proxy_handler, - ) - - def set_queue_rate_limit(self, queue_name: str, rate_limit: str) -> None: - """Set a rate limit for an entire queue. - - Args: - queue_name: Queue name (e.g. ``"default"``). - rate_limit: Rate limit string, e.g. ``"100/m"``, ``"10/s"``. - """ - self._queue_configs.setdefault(queue_name, {})["rate_limit"] = rate_limit - - def set_queue_concurrency(self, queue_name: str, max_concurrent: int) -> None: - """Set a maximum number of concurrent jobs for a queue. - - Args: - queue_name: Queue name (e.g. ``"default"``). - max_concurrent: Maximum number of jobs running simultaneously - from this queue. - """ - self._queue_configs.setdefault(queue_name, {})["max_concurrent"] = max_concurrent - def _get_serializer(self, task_name: str) -> Serializer: """Get the serializer for a task (per-task or queue-level fallback).""" return self._task_serializers.get(task_name, self._serializer) @@ -1057,417 +588,3 @@ def enqueue_many( pass return results - - # -- Events & Webhooks -- - - def _emit_event(self, event_type: EventType, payload: dict[str, Any]) -> None: - """Emit an event to the event bus and webhook manager.""" - self._event_bus.emit(event_type, payload) - self._webhook_manager.notify(event_type, payload) - - def on_event(self, event_type: EventType, callback: Callable[..., Any]) -> None: - """Register a callback for a job lifecycle event. - - Args: - event_type: The event type to listen for. - callback: Called with ``(event_type, payload_dict)``. 
- """ - self._event_bus.on(event_type, callback) - - def add_webhook( - self, - url: str, - events: list[EventType] | None = None, - headers: dict[str, str] | None = None, - secret: str | None = None, - max_retries: int = 3, - timeout: float = 10.0, - retry_backoff: float = 2.0, - ) -> None: - """Register a webhook endpoint for job events. - - Args: - url: URL to POST event payloads to. - events: Event types to subscribe to (None = all). - headers: Extra HTTP headers. - secret: HMAC-SHA256 signing secret. - max_retries: Maximum delivery attempts (default 3). - timeout: HTTP request timeout in seconds (default 10.0). - retry_backoff: Base for exponential backoff between retries (default 2.0). - """ - self._webhook_manager.add_webhook( - url, - events, - headers, - secret, - max_retries=max_retries, - timeout=timeout, - retry_backoff=retry_backoff, - ) - - # -- Worker startup -- - - def _print_banner(self, queues: list[str]) -> None: - """Print ASCII startup banner.""" - from taskito import __version__ - - banner = rf""" - _ _ _ _ -| |_ __ _ ___| | _(_) |_ ___ -| __/ _` / __| |/ / | __/ _ \ -| || (_| \__ \ <| | || (_) | - \__\__,_|___/_|\_\_|\__\___/ v{__version__} -""" - lines = [banner] - lines.append(f"> Backend: {self._backend}") - if self._backend == "sqlite": - lines.append(f"> DB: {self._db_path}") - else: - # Mask password in connection URL for display - url = self._db_url or "" - parsed_url = urllib.parse.urlparse(url) - if parsed_url.password: - masked = parsed_url._replace( - netloc=f"{parsed_url.username}:****@{parsed_url.hostname}" - + (f":{parsed_url.port}" if parsed_url.port else "") - ) - url = urllib.parse.urlunparse(masked) - lines.append(f"> DB: {url}") - lines.append(f"> Schema: {self._schema}") - lines.append(f"> Concurrency: {self._workers} (threads)") - lines.append(f"> Queues: {', '.join(queues)}") - lines.append("") - - task_names = sorted(self._task_registry.keys()) - if task_names: - lines.append("[tasks]") - for name in task_names: - 
lines.append(f" . {name}") - lines.append("") - - if self._periodic_configs: - lines.append("[periodic]") - for pc in self._periodic_configs: - lines.append(f" . {pc['name']} ({pc['cron_expr']})") - lines.append("") - - if self._resource_definitions: - lines.append("[resources]") - for rname, rdef in sorted(self._resource_definitions.items()): - deps = f" (depends: {', '.join(rdef.depends_on)})" if rdef.depends_on else "" - lines.append(f" . {rname}{deps}") - lines.append("") - - print("\n".join(lines)) - - def run_worker( - self, - queues: Sequence[str] | None = None, - tags: list[str] | None = None, - pool: str = "thread", - app: str | None = None, - ) -> None: - """Start the worker loop. Blocks until interrupted. - - Args: - queues: List of queue names to consume from. ``None`` consumes - from all queues. - tags: Optional tags for worker specialization / routing. - pool: Worker pool type — ``"thread"`` (default) or ``"prefork"``. - Prefork spawns child processes with independent GILs for - true parallelism on CPU-bound tasks. - app: Import path to the Queue instance (e.g. ``"myapp:queue"``). - Required when ``pool="prefork"``. - """ - if pool == "prefork": - if sys.platform == "win32": - raise NotImplementedError( - "pool='prefork' is not supported on Windows. " - "Use pool='thread' (default) or run on Linux/macOS." - ) - if not app: - raise ValueError("app= is required when pool='prefork' (e.g. 
app='myapp:queue')") - queue_list = list(queues) if queues else None - - # Make queue accessible from job context (for current_job.update_progress()) - from taskito.context import _set_queue_ref - - _set_queue_ref(self) - - # Register periodic tasks with Rust scheduler - for pc in self._periodic_configs: - self._inner.register_periodic( - name=pc["name"], - task_name=pc["task_name"], - cron_expr=pc["cron_expr"], - args=pc["payload"], - queue=pc["queue"], - timezone=pc.get("timezone"), - ) - - if not logging.root.handlers: - logging.basicConfig( - level=logging.INFO, - format="[%(asctime)s] %(levelname)s %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - worker_queues = queue_list or ["default"] - self._print_banner(worker_queues) - - # Initialize worker resources (before Rust dispatches tasks) - health_checker = None - if self._resource_definitions: - from taskito.resources.health import HealthChecker - - self._resource_runtime = ResourceRuntime(self._resource_definitions) - self._resource_runtime.initialize() - logger.info( - "Initialized %d resource(s): %s", - len(self._resource_definitions), - ", ".join(self._resource_runtime._init_order), - ) - health_checker = HealthChecker(self._resource_runtime) - health_checker.start() - - # Set up signal handlers for graceful shutdown (only in main thread) - is_main = threading.current_thread() is threading.main_thread() - original_sigint = None - original_sigterm = None - - if is_main: - original_sigint = signal.getsignal(signal.SIGINT) - original_sigterm = signal.getsignal(signal.SIGTERM) - - def shutdown_handler(signum: int, frame: Any) -> None: - logger.info("Warm shutdown (waiting for running tasks to finish)...") - with contextlib.suppress(Exception): - self._inner.set_worker_status(worker_id, "draining") - self._inner.request_shutdown() - # Restore original handlers so a second signal force-kills - signal.signal(signal.SIGINT, original_sigint) - signal.signal(signal.SIGTERM, original_sigterm) - - 
signal.signal(signal.SIGINT, shutdown_handler) - signal.signal(signal.SIGTERM, shutdown_handler) - - # SIGHUP handler for hot-reloading resources (Unix only) - if hasattr(signal, "SIGHUP"): - - def sighup_handler(signum: int, frame: Any) -> None: - logger.info("SIGHUP received — reloading reloadable resources") - if self._resource_runtime is not None: - results = self._resource_runtime.reload() - for rname, success in results.items(): - logger.info( - "Reload %s: %s", - rname, - "OK" if success else "FAILED", - ) - - signal.signal(signal.SIGHUP, sighup_handler) - - # Serialize resource names for worker advertisement - resources_json: str | None = None - if self._resource_definitions: - resources_json = json.dumps(sorted(self._resource_definitions.keys())) - - # Generate worker ID and start Python-side heartbeat thread - worker_id = str(uuid.uuid4()) - stop_heartbeat = threading.Event() - heartbeat_thread = threading.Thread( - target=self._run_heartbeat, - args=(worker_id, stop_heartbeat), - daemon=True, - name="taskito-heartbeat", - ) - heartbeat_thread.start() - - self._emit_event( - EventType.WORKER_STARTED, - {"worker_id": worker_id, "queues": worker_queues}, - ) - self._emit_event( - EventType.WORKER_ONLINE, - {"worker_id": worker_id, "queues": worker_queues, "pool": pool}, - ) - - try: - queue_configs_json = json.dumps(self._queue_configs) if self._queue_configs else None - self._inner.run_worker( - task_registry=self._task_registry, - task_configs=self._task_configs, - queues=queue_list, - drain_timeout_secs=self._drain_timeout, - tags=",".join(tags) if tags else None, - worker_id=worker_id, - resources=resources_json, - threads=self._workers, - async_concurrency=self._async_concurrency, - queue_configs=queue_configs_json, - pool=pool if pool != "thread" else None, - app_path=app, - ) - except KeyboardInterrupt: - logger.info("Cold shutdown (terminating immediately)") - finally: - self._emit_event( - EventType.WORKER_STOPPED, - {"worker_id": worker_id}, - ) - 
stop_heartbeat.set() - heartbeat_thread.join(timeout=6) - # Tear down resources before stopping async loop - if health_checker is not None: - health_checker.stop() - if self._resource_runtime is not None: - self._resource_runtime.teardown() - self._resource_runtime = None - logger.info("Worker stopped.") - if is_main: - if original_sigint is not None: - signal.signal(signal.SIGINT, original_sigint) - if original_sigterm is not None: - signal.signal(signal.SIGTERM, original_sigterm) - - def _build_resource_health_json(self) -> str | None: - """Snapshot current resource health as JSON for heartbeat.""" - if not self._resource_definitions: - return None - runtime = self._resource_runtime - health: dict[str, str] = {} - for name in self._resource_definitions: - if runtime is not None and name in runtime._unhealthy: - health[name] = "unhealthy" - else: - health[name] = "healthy" - return json.dumps(health) - - def _run_heartbeat( - self, - worker_id: str, - stop_event: threading.Event, - ) -> None: - """Send periodic heartbeats to storage with current resource health.""" - prev_unhealthy: set[str] = set() - while not stop_event.is_set(): - resource_health = self._build_resource_health_json() - try: - reaped_ids = self._inner.worker_heartbeat(worker_id, resource_health) - # Emit WORKER_OFFLINE events for reaped dead workers - for rid in reaped_ids: - self._emit_event(EventType.WORKER_OFFLINE, {"worker_id": rid}) - except Exception: - logger.debug("Heartbeat failed", exc_info=True) - - # Detect health transitions → emit WORKER_UNHEALTHY - runtime = self._resource_runtime - if runtime is not None: - current_unhealthy = set(runtime._unhealthy) - new_unhealthy = current_unhealthy - prev_unhealthy - if new_unhealthy: - self._emit_event( - EventType.WORKER_UNHEALTHY, - { - "worker_id": worker_id, - "resources": sorted(new_unhealthy), - }, - ) - prev_unhealthy = current_unhealthy - - stop_event.wait(timeout=5.0) - - # -- Resource Status -- - - def resource_status(self) -> 
list[dict[str, Any]]: - """Return per-resource status info. - - Each entry contains: name, scope, health, init_duration_ms, - recreations, depends_on. If this process is running the worker, the - live in-process runtime is authoritative. Otherwise (e.g. the - dashboard is a separate process), health is reconstructed from the - latest heartbeat each worker pushed via ``worker_heartbeat``. - Returns an empty list when nothing is registered and no worker has - reported yet. - """ - if self._resource_runtime is not None: - return self._resource_runtime.status() - return self._resource_status_from_heartbeats() - - def _resource_status_from_heartbeats(self) -> list[dict[str, Any]]: - """Fallback path when the runtime isn't in this process. - - Aggregates each worker's ``resource_health`` JSON snapshot into a - status list shaped like ``ResourceRuntime.status()``. Uses the - rule: any ``unhealthy`` wins; mixed healthy/unhealthy is - ``degraded``; all ``healthy`` → ``healthy``; no workers reporting - a given resource → ``not_initialized``. - """ - observed: dict[str, list[str]] = {} - try: - workers = self._inner.list_workers() - except Exception: - logger.warning("resource_status: failed to list workers", exc_info=True) - workers = [] - - for worker in workers: - raw = worker.get("resource_health") - if not raw: - continue - try: - report = json.loads(raw) - except (TypeError, ValueError): - continue - if not isinstance(report, dict): - continue - for name, health in report.items(): - observed.setdefault(str(name), []).append(str(health).lower()) - - # Build an entry for every registered definition, joined with any - # resource a live worker reports (covers the case where the - # dashboard process has no definitions registered at all). 
- names = set(self._resource_definitions.keys()) | set(observed.keys()) - result: list[dict[str, Any]] = [] - for name in sorted(names): - defn = self._resource_definitions.get(name) - healths = observed.get(name, []) - if not healths: - health = "not_initialized" - elif any(h == "unhealthy" for h in healths): - health = "unhealthy" - elif all(h == "healthy" for h in healths): - health = "healthy" - else: - health = "degraded" - result.append( - { - "name": name, - "scope": defn.scope.value if defn is not None else "unknown", - "health": health, - "init_duration_ms": 0, - "recreations": 0, - "depends_on": defn.depends_on if defn is not None else [], - } - ) - return result - - # -- Test Mode -- - - def test_mode( - self, - propagate_errors: bool = False, - resources: dict[str, Any] | None = None, - ) -> TestMode: - """Return a context manager that runs tasks synchronously (no worker needed). - - Args: - propagate_errors: If True, re-raise task exceptions immediately. - resources: Dict of resource name → mock instance for injection - during test mode. - - Returns: - A :class:`~taskito.testing.TestMode` context manager. 
- """ - from taskito.testing import TestMode - - return TestMode(self, propagate_errors=propagate_errors, resources=resources) diff --git a/py_src/taskito/mixins/__init__.py b/py_src/taskito/mixins/__init__.py new file mode 100644 index 0000000..0b0828a --- /dev/null +++ b/py_src/taskito/mixins/__init__.py @@ -0,0 +1,19 @@ +"""Mixin classes that compose into the main Queue class.""" + +from taskito.mixins.decorators import QueueDecoratorMixin +from taskito.mixins.events import QueueEventsMixin +from taskito.mixins.inspection import QueueInspectionMixin +from taskito.mixins.lifecycle import QueueLifecycleMixin +from taskito.mixins.locks import QueueLockMixin +from taskito.mixins.operations import QueueOperationsMixin +from taskito.mixins.resources import QueueResourceMixin + +__all__ = [ + "QueueDecoratorMixin", + "QueueEventsMixin", + "QueueInspectionMixin", + "QueueLifecycleMixin", + "QueueLockMixin", + "QueueOperationsMixin", + "QueueResourceMixin", +] diff --git a/py_src/taskito/mixins/_shared.py b/py_src/taskito/mixins/_shared.py new file mode 100644 index 0000000..e450700 --- /dev/null +++ b/py_src/taskito/mixins/_shared.py @@ -0,0 +1,6 @@ +"""Shared sentinels and helpers for Queue mixins.""" + +from __future__ import annotations + +_UNSET = object() +"""Sentinel to distinguish 'not passed' from explicit None in mixin defaults.""" diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py new file mode 100644 index 0000000..1596246 --- /dev/null +++ b/py_src/taskito/mixins/decorators.py @@ -0,0 +1,375 @@ +"""Task and periodic decorators, lifecycle hooks, and registration.""" + +from __future__ import annotations + +import asyncio +import contextlib +import functools +import os +import sys +import typing +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from taskito._taskito import PyTaskConfig +from taskito.inject import Inject, _InjectAlias +from taskito.interception.strategy import Strategy as S +from 
taskito.task import TaskWrapper + +if TYPE_CHECKING: + from taskito.interception import ArgumentInterceptor + from taskito.middleware import TaskMiddleware + from taskito.serializers import Serializer + + +def _resolve_module_name(module_name: str) -> str: + """Resolve __main__ to the actual module name.""" + if module_name != "__main__": + return module_name + + main = sys.modules.get("__main__") + if main is not None: + spec = getattr(main, "__spec__", None) + if spec and spec.name: + return str(spec.name) + f = getattr(main, "__file__", None) + if f: + return str(os.path.splitext(os.path.basename(f))[0]) + return module_name + + +class QueueDecoratorMixin: + """Task/periodic decorators, hooks, type registration, queue-level config.""" + + _task_registry: dict[str, Callable] + _task_configs: list[PyTaskConfig] + _periodic_configs: list[dict[str, Any]] + _hooks: dict[str, list[Callable]] + _task_serializers: dict[str, Serializer] + _task_middleware: dict[str, list[TaskMiddleware]] + _task_retry_filters: dict[str, dict[str, list[type[Exception]]]] + _task_inject_map: dict[str, list[str]] + _interceptor: ArgumentInterceptor | None + _queue_configs: dict[str, dict[str, Any]] + + def task( + self, + name: str | None = None, + max_retries: int = 3, + retry_backoff: float = 1.0, + timeout: int = 300, + priority: int = 0, + rate_limit: str | None = None, + queue: str = "default", + circuit_breaker: dict | None = None, + retry_on: list[type[Exception]] | None = None, + dont_retry_on: list[type[Exception]] | None = None, + soft_timeout: float | None = None, + middleware: list[TaskMiddleware] | None = None, + retry_delays: list[float] | None = None, + inject: list[str] | None = None, + serializer: Serializer | None = None, + max_retry_delay: int | None = None, + max_concurrent: int | None = None, + ) -> Callable[[Callable[..., Any]], TaskWrapper]: + """Decorator to register a function as a background task. + + Args: + name: Explicit task name. 
Defaults to ``module.qualname``. + max_retries: Max retry attempts on failure before moving to DLQ. + retry_backoff: Base delay in seconds for exponential backoff between retries. + timeout: Max execution time in seconds before the task is killed. + priority: Priority level (higher = more urgent). + rate_limit: Rate limit string, e.g. ``"100/m"``, ``"10/s"``, ``"3600/h"``. + queue: Named queue to submit to. + circuit_breaker: Optional dict with ``threshold``, ``window`` (seconds), + and ``cooldown`` (seconds) keys. + retry_on: List of exception classes that should trigger retries. + If set, only these exceptions are retried. + dont_retry_on: List of exception classes that should never be retried. + soft_timeout: Soft timeout in seconds. Checked via ``current_job.check_timeout()``. + middleware: Per-task middleware instances (in addition to global middleware). + retry_delays: Explicit per-attempt retry delays in seconds. When set, + overrides the exponential backoff schedule. + inject: List of resource names to inject as keyword arguments. + serializer: Per-task serializer. Falls back to the queue-level serializer. + max_retry_delay: Maximum backoff delay in seconds. Defaults to 300 + (5 minutes) if not set. + max_concurrent: Maximum number of concurrent running instances of + this task. ``None`` means no limit. 
+ """ + + def decorator(fn: Callable) -> TaskWrapper: + task_name = name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}" + + # Detect Inject["name"] annotations (Phase E) + annotation_injects: list[str] = [] + try: + hints: dict[str, Any] = {} + with contextlib.suppress(Exception): + # get_type_hints evaluates string annotations + ns = getattr(fn, "__globals__", {}) + ns = {**ns, "Inject": Inject} + hints = typing.get_type_hints(fn, globalns=ns, include_extras=True) + # Fallback: check raw annotations if get_type_hints failed + if not hints: + with contextlib.suppress(Exception): + hints = getattr(fn, "__annotations__", {}) + for _param_name, hint in hints.items(): + if isinstance(hint, _InjectAlias): + annotation_injects.append(hint.resource_name) + except Exception: + pass + + # Merge explicit inject= with annotation-detected injects + final_inject = list(inject or []) + for res_name in annotation_injects: + if res_name not in final_inject: + final_inject.append(res_name) + + # Store retry filters + if retry_on or dont_retry_on: + self._task_retry_filters[task_name] = { + "retry_on": retry_on or [], + "dont_retry_on": dont_retry_on or [], + } + + # Store per-task middleware + if middleware: + self._task_middleware[task_name] = middleware + + # Store per-task serializer + if serializer is not None: + self._task_serializers[task_name] = serializer + + # Store inject map for resource injection + if final_inject: + self._task_inject_map[task_name] = final_inject + + # Wrap the function with hooks, middleware, and context + wrapped = self._wrap_task(fn, task_name, soft_timeout) # type: ignore[attr-defined] + self._task_registry[task_name] = wrapped + + cb_threshold = None + cb_window = None + cb_cooldown = None + cb_half_open_probes = None + cb_half_open_success_rate = None + if circuit_breaker: + cb_threshold = circuit_breaker.get("threshold", 5) + cb_window = circuit_breaker.get("window", 60) + cb_cooldown = circuit_breaker.get("cooldown", 300) + 
cb_half_open_probes = circuit_breaker.get("half_open_probes") + cb_half_open_success_rate = circuit_breaker.get("half_open_success_rate") + + # Store config for worker startup + config = PyTaskConfig( + name=task_name, + max_retries=max_retries, + retry_backoff=retry_backoff, + timeout=timeout, + priority=priority, + rate_limit=rate_limit, + queue=queue, + circuit_breaker_threshold=cb_threshold, + circuit_breaker_window=cb_window, + circuit_breaker_cooldown=cb_cooldown, + retry_delays=retry_delays, + max_retry_delay=max_retry_delay, + max_concurrent=max_concurrent, + circuit_breaker_half_open_probes=cb_half_open_probes, + circuit_breaker_half_open_success_rate=cb_half_open_success_rate, + ) + self._task_configs.append(config) + + # Return a TaskWrapper that has .delay() and .apply_async() + wrapper = TaskWrapper( + fn=fn, + queue_ref=self, # type: ignore[arg-type] + task_name=task_name, + default_priority=priority, + default_queue=queue, + default_max_retries=max_retries, + default_timeout=timeout, + inject=final_inject or None, + ) + + # Preserve function metadata + functools.update_wrapper(wrapper, fn) + + # Mark async status for native async dispatch + is_async = asyncio.iscoroutinefunction(fn) + wrapper._taskito_is_async = is_async + if is_async: + wrapper._taskito_async_fn = fn + + return wrapper + + return decorator + + def periodic( + self, + cron: str, + name: str | None = None, + args: tuple = (), + kwargs: dict | None = None, + queue: str = "default", + timezone: str | None = None, + ) -> Callable[[Callable[..., Any]], TaskWrapper]: + """Decorator to register a periodic (cron-scheduled) task. + + Args: + cron: Cron expression (6-field with seconds), e.g. ``"0 */5 * * * *"`` + for every 5 minutes. + name: Explicit task name. Defaults to ``module.qualname``. + args: Positional arguments to pass to the task on each run. + kwargs: Keyword arguments to pass to the task on each run. + queue: Named queue to submit to. 
+ """ + + def decorator(fn: Callable) -> TaskWrapper: + # If fn is a WorkflowProxy (from @queue.workflow()), create a + # launcher task that submits the workflow on each cron trigger. + if getattr(fn, "_is_workflow_proxy", False): + proxy: Any = fn + launcher_name = f"_wf_launcher_{proxy._name}" + + @self.task(name=launcher_name, queue=queue) + def _wf_launcher() -> str: + run = proxy.submit() + return f"submitted workflow run {run.id}" + + payload = self._get_serializer(launcher_name).dumps(((), {})) # type: ignore[attr-defined] + self._periodic_configs.append( + { + "name": launcher_name, + "task_name": launcher_name, + "cron_expr": cron, + "payload": payload, + "queue": queue, + "timezone": timezone, + } + ) + return fn # type: ignore[return-value] + + # Register as a normal task first + wrapper = self.task(name=name, queue=queue)(fn) + + # Store periodic config for registration at worker startup + payload = self._get_serializer(wrapper.name).dumps((args, kwargs or {})) # type: ignore[attr-defined] + self._periodic_configs.append( + { + "name": name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}", + "task_name": wrapper.name, + "cron_expr": cron, + "payload": payload, + "queue": queue, + "timezone": timezone, + } + ) + + return wrapper + + return decorator + + # -- Hooks / middleware -- + + def before_task(self, fn: Callable) -> Callable: + """Register a hook called before each task executes. + + Args: + fn: Callback with signature ``fn(task_name, args, kwargs)``. + """ + self._hooks["before_task"].append(fn) + return fn + + def after_task(self, fn: Callable) -> Callable: + """Register a hook called after each task completes or fails. + + Args: + fn: Callback with signature + ``fn(task_name, args, kwargs, result, error)``. + """ + self._hooks["after_task"].append(fn) + return fn + + def on_success(self, fn: Callable) -> Callable: + """Register a hook called when a task completes successfully. 
+ + Args: + fn: Callback with signature + ``fn(task_name, args, kwargs, result)``. + """ + self._hooks["on_success"].append(fn) + return fn + + def on_failure(self, fn: Callable) -> Callable: + """Register a hook called when a task raises an exception. + + Args: + fn: Callback with signature + ``fn(task_name, args, kwargs, error)``. + """ + self._hooks["on_failure"].append(fn) + return fn + + # -- Type registration -- + + def register_type( + self, + python_type: type, + strategy: str, + *, + resource: str | None = None, + message: str | None = None, + converter: Callable | None = None, + type_key: str | None = None, + proxy_handler: str | None = None, + ) -> None: + """Register a custom type with the interception system. + + Args: + python_type: The type to register. + strategy: One of ``"pass"``, ``"convert"``, ``"redirect"``, + ``"reject"``, or ``"proxy"``. + resource: Resource name for ``"redirect"`` strategy. + message: Rejection reason for ``"reject"`` strategy. + converter: Converter callable for ``"convert"`` strategy. + type_key: Key for the converter reconstructor dispatch. + proxy_handler: Handler name for ``"proxy"`` strategy. + """ + if self._interceptor is None: + raise RuntimeError( + "Interception is disabled; set interception='strict' or " + "'lenient' to use register_type()" + ) + strat = S(strategy) + self._interceptor._registry.register( + python_type, + strat, + priority=15, + redirect_resource=resource, + reject_reason=message, + converter=converter, + type_key=type_key, + proxy_handler=proxy_handler, + ) + + # -- Queue-level config -- + + def set_queue_rate_limit(self, queue_name: str, rate_limit: str) -> None: + """Set a rate limit for an entire queue. + + Args: + queue_name: Queue name (e.g. ``"default"``). + rate_limit: Rate limit string, e.g. ``"100/m"``, ``"10/s"``. 
+ """ + self._queue_configs.setdefault(queue_name, {})["rate_limit"] = rate_limit + + def set_queue_concurrency(self, queue_name: str, max_concurrent: int) -> None: + """Set a maximum number of concurrent jobs for a queue. + + Args: + queue_name: Queue name (e.g. ``"default"``). + max_concurrent: Maximum number of jobs running simultaneously + from this queue. + """ + self._queue_configs.setdefault(queue_name, {})["max_concurrent"] = max_concurrent diff --git a/py_src/taskito/mixins/events.py b/py_src/taskito/mixins/events.py new file mode 100644 index 0000000..936c4c1 --- /dev/null +++ b/py_src/taskito/mixins/events.py @@ -0,0 +1,62 @@ +"""Event bus and webhook registration.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.events import EventBus, EventType + from taskito.webhooks import WebhookManager + + +class QueueEventsMixin: + """Event emission, lifecycle event listeners, and webhook registration.""" + + _event_bus: EventBus + _webhook_manager: WebhookManager + + def _emit_event(self, event_type: EventType, payload: dict[str, Any]) -> None: + """Emit an event to the event bus and webhook manager.""" + self._event_bus.emit(event_type, payload) + self._webhook_manager.notify(event_type, payload) + + def on_event(self, event_type: EventType, callback: Callable[..., Any]) -> None: + """Register a callback for a job lifecycle event. + + Args: + event_type: The event type to listen for. + callback: Called with ``(event_type, payload_dict)``. + """ + self._event_bus.on(event_type, callback) + + def add_webhook( + self, + url: str, + events: list[EventType] | None = None, + headers: dict[str, str] | None = None, + secret: str | None = None, + max_retries: int = 3, + timeout: float = 10.0, + retry_backoff: float = 2.0, + ) -> None: + """Register a webhook endpoint for job events. + + Args: + url: URL to POST event payloads to. 
+ events: Event types to subscribe to (None = all). + headers: Extra HTTP headers. + secret: HMAC-SHA256 signing secret. + max_retries: Maximum delivery attempts (default 3). + timeout: HTTP request timeout in seconds (default 10.0). + retry_backoff: Base for exponential backoff between retries (default 2.0). + """ + self._webhook_manager.add_webhook( + url, + events, + headers, + secret, + max_retries=max_retries, + timeout=timeout, + retry_backoff=retry_backoff, + ) diff --git a/py_src/taskito/mixins.py b/py_src/taskito/mixins/inspection.py similarity index 57% rename from py_src/taskito/mixins.py rename to py_src/taskito/mixins/inspection.py index 91fc2a5..c4fe2f7 100644 --- a/py_src/taskito/mixins.py +++ b/py_src/taskito/mixins/inspection.py @@ -1,14 +1,12 @@ -"""Mixin classes that compose into the main Queue class.""" +"""Read-only inspection, stats, and query methods for the Queue.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +from collections import defaultdict +from typing import Any -if TYPE_CHECKING: - from taskito.locks import DistributedLock - from taskito.result import JobResult - -_UNSET = object() # sentinel to distinguish "not passed" from explicit None +from taskito.mixins._shared import _UNSET +from taskito.result import JobResult class QueueInspectionMixin: @@ -19,8 +17,6 @@ class QueueInspectionMixin: def get_job(self, job_id: str) -> JobResult | None: """Retrieve a job by its unique ID.""" - from taskito.result import JobResult - py_job = self._inner.get_job(job_id) if py_job is None: return None @@ -40,8 +36,6 @@ def list_jobs( By default, scoped to this queue's namespace. Pass ``namespace=None`` explicitly to see jobs across all namespaces. """ - from taskito.result import JobResult - ns = self._namespace if namespace is _UNSET else namespace py_jobs = self._inner.list_jobs( status=status, @@ -71,8 +65,6 @@ def list_jobs_filtered( By default, scoped to this queue's namespace. 
Pass ``namespace=None`` explicitly to see jobs across all namespaces. """ - from taskito.result import JobResult - ns = self._namespace if namespace is _UNSET else namespace py_jobs = self._inner.list_jobs_filtered( status=status, @@ -211,153 +203,8 @@ def purge_completed(self, older_than: int = 86400) -> int: return self._inner.purge_completed(older_than) # type: ignore[no-any-return] -class QueueOperationsMixin: - """Dead letters, replay, circuit breakers, logs, workers, queue management.""" - - _inner: Any - - # -- Dead Letters -- - - def dead_letters(self, limit: int = 10, offset: int = 0) -> list[dict]: - """List dead letter queue entries.""" - return self._inner.dead_letters(limit=limit, offset=offset) # type: ignore[no-any-return] - - def retry_dead(self, dead_id: str) -> str: - """Re-enqueue a dead letter job. Returns new job ID.""" - return self._inner.retry_dead(dead_id) # type: ignore[no-any-return] - - def purge_dead(self, older_than: int = 86400) -> int: - """Delete dead letter entries older than a given age.""" - return self._inner.purge_dead(older_than) # type: ignore[no-any-return] - - # -- Replay -- - - def replay(self, job_id: str) -> JobResult: - """Re-enqueue a completed or failed job with the exact same payload.""" - new_id = self._inner.replay_job(job_id) - return self.get_job(new_id) # type: ignore[attr-defined, no-any-return] - - def replay_history(self, job_id: str) -> list[dict]: - """Get replay history for a job.""" - return self._inner.get_replay_history(job_id) # type: ignore[no-any-return] - - # -- Circuit Breakers -- - - def circuit_breakers(self) -> list[dict]: - """List all circuit breaker states.""" - return self._inner.list_circuit_breakers() # type: ignore[no-any-return] - - # -- Logs -- - - def task_logs(self, job_id: str) -> list[dict]: - """Get structured logs for a specific job.""" - return self._inner.get_task_logs(job_id) # type: ignore[no-any-return] - - def query_logs( - self, - task_name: str | None = None, - level: str 
| None = None, - since: int = 3600, - limit: int = 100, - ) -> list[dict]: - """Query structured task logs with filters.""" - return self._inner.query_task_logs( # type: ignore[no-any-return] - task_name=task_name, level=level, since_seconds=since, limit=limit - ) - - # -- Workers -- - - def workers(self) -> list[dict]: - """List all registered workers and their heartbeat status.""" - return self._inner.list_workers() # type: ignore[no-any-return] - - # -- Queue Pause/Resume -- - - def pause(self, queue_name: str = "default") -> None: - """Pause a queue so no new jobs are dispatched from it.""" - self._inner.pause_queue(queue_name) - if hasattr(self, "_emit_event"): - from taskito.events import EventType - - self._emit_event(EventType.QUEUE_PAUSED, {"queue": queue_name}) - - def resume(self, queue_name: str = "default") -> None: - """Resume a paused queue.""" - self._inner.resume_queue(queue_name) - if hasattr(self, "_emit_event"): - from taskito.events import EventType - - self._emit_event(EventType.QUEUE_RESUMED, {"queue": queue_name}) - - def paused_queues(self) -> list[str]: - """List currently paused queues.""" - return self._inner.list_paused_queues() # type: ignore[no-any-return] - - # -- Job Revocation -- - - def purge(self, queue_name: str) -> int: - """Cancel all pending jobs in a queue. Returns count cancelled.""" - return self._inner.purge_queue(queue_name) # type: ignore[no-any-return] - - def revoke_task(self, task_name: str) -> int: - """Cancel all pending jobs for a task name. 
Returns count cancelled.""" - return self._inner.revoke_task(task_name) # type: ignore[no-any-return] - - # -- Job Archival -- - - def archive(self, older_than: int = 86400) -> int: - """Archive completed/dead/cancelled jobs older than the given age (seconds).""" - return self._inner.archive_old_jobs(older_than) # type: ignore[no-any-return] - - def list_archived(self, limit: int = 50, offset: int = 0) -> list[JobResult]: - """List archived jobs with pagination.""" - from taskito.result import JobResult - - py_jobs = self._inner.list_archived(limit=limit, offset=offset) - return [JobResult(py_job=pj, queue=self) for pj in py_jobs] # type: ignore[arg-type] - - -class QueueLockMixin: - """Distributed locking methods for the Queue.""" - - _inner: Any - - def lock( - self, - name: str, - ttl: float = 30.0, - auto_extend: bool = True, - owner_id: str | None = None, - timeout: float | None = None, - retry_interval: float = 0.1, - ) -> DistributedLock: - """Return a sync distributed lock context manager. - - Args: - name: Lock name (unique across the cluster). - ttl: Lock TTL in seconds. Auto-extended at ttl/3 intervals. - auto_extend: Whether to auto-extend the lock in a background thread. - owner_id: Unique owner identifier. Auto-generated if not provided. - timeout: Max seconds to wait for acquisition. None = fail immediately. - retry_interval: Seconds between retries when timeout is set. 
- """ - from taskito.locks import DistributedLock - - return DistributedLock( - inner=self._inner, - name=name, - ttl=ttl, - owner_id=owner_id, - auto_extend=auto_extend, - timeout=timeout, - retry_interval=retry_interval, - ) - - def _aggregate_metrics(raw: list[dict]) -> dict[str, Any]: """Aggregate raw metric rows into per-task statistics.""" - from collections import defaultdict - by_task: dict[str, list[dict]] = defaultdict(list) for r in raw: by_task[r["task_name"]].append(r) diff --git a/py_src/taskito/mixins/lifecycle.py b/py_src/taskito/mixins/lifecycle.py new file mode 100644 index 0000000..298953f --- /dev/null +++ b/py_src/taskito/mixins/lifecycle.py @@ -0,0 +1,406 @@ +"""Worker startup banner, run_worker loop, heartbeat, resource status, test mode.""" + +from __future__ import annotations + +import contextlib +import json +import logging +import signal +import sys +import threading +import urllib.parse +import uuid +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import taskito +from taskito._taskito import PyQueue, PyTaskConfig +from taskito.context import _set_queue_ref +from taskito.events import EventType +from taskito.resources.health import HealthChecker +from taskito.resources.runtime import ResourceRuntime +from taskito.testing import TestMode + +if TYPE_CHECKING: + from collections.abc import Callable + + from taskito.resources.definition import ResourceDefinition + + +logger = logging.getLogger("taskito") + + +class QueueLifecycleMixin: + """Worker startup, heartbeat, resource status aggregation, and test mode.""" + + _inner: PyQueue + _backend: str + _db_path: str + _db_url: str | None + _schema: str + _workers: int + _drain_timeout: int + _async_concurrency: int + _periodic_configs: list[dict[str, Any]] + _resource_definitions: dict[str, ResourceDefinition] + _resource_runtime: ResourceRuntime | None + _task_registry: dict[str, Callable] + _task_configs: list[PyTaskConfig] + _queue_configs: dict[str, dict[str, 
Any]] + + def _print_banner(self, queues: list[str]) -> None: + """Print ASCII startup banner.""" + banner = rf""" + _ _ _ _ +| |_ __ _ ___| | _(_) |_ ___ +| __/ _` / __| |/ / | __/ _ \ +| || (_| \__ \ <| | || (_) | + \__\__,_|___/_|\_\_|\__\___/ v{taskito.__version__} +""" + lines = [banner] + lines.append(f"> Backend: {self._backend}") + if self._backend == "sqlite": + lines.append(f"> DB: {self._db_path}") + else: + # Mask password in connection URL for display + url = self._db_url or "" + parsed_url = urllib.parse.urlparse(url) + if parsed_url.password: + masked = parsed_url._replace( + netloc=f"{parsed_url.username}:****@{parsed_url.hostname}" + + (f":{parsed_url.port}" if parsed_url.port else "") + ) + url = urllib.parse.urlunparse(masked) + lines.append(f"> DB: {url}") + lines.append(f"> Schema: {self._schema}") + lines.append(f"> Concurrency: {self._workers} (threads)") + lines.append(f"> Queues: {', '.join(queues)}") + lines.append("") + + task_names = sorted(self._task_registry.keys()) + if task_names: + lines.append("[tasks]") + for name in task_names: + lines.append(f" . {name}") + lines.append("") + + if self._periodic_configs: + lines.append("[periodic]") + for pc in self._periodic_configs: + lines.append(f" . {pc['name']} ({pc['cron_expr']})") + lines.append("") + + if self._resource_definitions: + lines.append("[resources]") + for rname, rdef in sorted(self._resource_definitions.items()): + deps = f" (depends: {', '.join(rdef.depends_on)})" if rdef.depends_on else "" + lines.append(f" . {rname}{deps}") + lines.append("") + + print("\n".join(lines)) + + def run_worker( + self, + queues: Sequence[str] | None = None, + tags: list[str] | None = None, + pool: str = "thread", + app: str | None = None, + ) -> None: + """Start the worker loop. Blocks until interrupted. + + Args: + queues: List of queue names to consume from. ``None`` consumes + from all queues. + tags: Optional tags for worker specialization / routing. 
+ pool: Worker pool type — ``"thread"`` (default) or ``"prefork"``. + Prefork spawns child processes with independent GILs for + true parallelism on CPU-bound tasks. + app: Import path to the Queue instance (e.g. ``"myapp:queue"``). + Required when ``pool="prefork"``. + """ + if pool == "prefork": + if sys.platform == "win32": + raise NotImplementedError( + "pool='prefork' is not supported on Windows. " + "Use pool='thread' (default) or run on Linux/macOS." + ) + if not app: + raise ValueError("app= is required when pool='prefork' (e.g. app='myapp:queue')") + queue_list = list(queues) if queues else None + + # Make queue accessible from job context (for current_job.update_progress()) + _set_queue_ref(self) + + # Register periodic tasks with Rust scheduler + for pc in self._periodic_configs: + self._inner.register_periodic( + name=pc["name"], + task_name=pc["task_name"], + cron_expr=pc["cron_expr"], + args=pc["payload"], + queue=pc["queue"], + timezone=pc.get("timezone"), + ) + + if not logging.root.handlers: + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + worker_queues = queue_list or ["default"] + self._print_banner(worker_queues) + + # Initialize worker resources (before Rust dispatches tasks) + health_checker = None + if self._resource_definitions: + self._resource_runtime = ResourceRuntime(self._resource_definitions) + self._resource_runtime.initialize() + logger.info( + "Initialized %d resource(s): %s", + len(self._resource_definitions), + ", ".join(self._resource_runtime._init_order), + ) + health_checker = HealthChecker(self._resource_runtime) + health_checker.start() + + # Set up signal handlers for graceful shutdown (only in main thread) + is_main = threading.current_thread() is threading.main_thread() + original_sigint = None + original_sigterm = None + + if is_main: + original_sigint = signal.getsignal(signal.SIGINT) + original_sigterm = 
signal.getsignal(signal.SIGTERM) + + def shutdown_handler(signum: int, frame: Any) -> None: + logger.info("Warm shutdown (waiting for running tasks to finish)...") + with contextlib.suppress(Exception): + self._inner.set_worker_status(worker_id, "draining") + self._inner.request_shutdown() + # Restore original handlers so a second signal force-kills + signal.signal(signal.SIGINT, original_sigint) + signal.signal(signal.SIGTERM, original_sigterm) + + signal.signal(signal.SIGINT, shutdown_handler) + signal.signal(signal.SIGTERM, shutdown_handler) + + # SIGHUP handler for hot-reloading resources (Unix only) + if hasattr(signal, "SIGHUP"): + + def sighup_handler(signum: int, frame: Any) -> None: + logger.info("SIGHUP received — reloading reloadable resources") + if self._resource_runtime is not None: + results = self._resource_runtime.reload() + for rname, success in results.items(): + logger.info( + "Reload %s: %s", + rname, + "OK" if success else "FAILED", + ) + + signal.signal(signal.SIGHUP, sighup_handler) + + # Serialize resource names for worker advertisement + resources_json: str | None = None + if self._resource_definitions: + resources_json = json.dumps(sorted(self._resource_definitions.keys())) + + # Generate worker ID and start Python-side heartbeat thread + worker_id = str(uuid.uuid4()) + stop_heartbeat = threading.Event() + heartbeat_thread = threading.Thread( + target=self._run_heartbeat, + args=(worker_id, stop_heartbeat), + daemon=True, + name="taskito-heartbeat", + ) + heartbeat_thread.start() + + self._emit_event( # type: ignore[attr-defined] + EventType.WORKER_STARTED, + {"worker_id": worker_id, "queues": worker_queues}, + ) + self._emit_event( # type: ignore[attr-defined] + EventType.WORKER_ONLINE, + {"worker_id": worker_id, "queues": worker_queues, "pool": pool}, + ) + + try: + queue_configs_json = json.dumps(self._queue_configs) if self._queue_configs else None + self._inner.run_worker( + task_registry=self._task_registry, + 
task_configs=self._task_configs, + queues=queue_list, + drain_timeout_secs=self._drain_timeout, + tags=",".join(tags) if tags else None, + worker_id=worker_id, + resources=resources_json, + threads=self._workers, + async_concurrency=self._async_concurrency, + queue_configs=queue_configs_json, + pool=pool if pool != "thread" else None, + app_path=app, + ) + except KeyboardInterrupt: + logger.info("Cold shutdown (terminating immediately)") + finally: + self._emit_event( # type: ignore[attr-defined] + EventType.WORKER_STOPPED, + {"worker_id": worker_id}, + ) + stop_heartbeat.set() + heartbeat_thread.join(timeout=6) + # Tear down resources before stopping async loop + if health_checker is not None: + health_checker.stop() + if self._resource_runtime is not None: + self._resource_runtime.teardown() + self._resource_runtime = None + logger.info("Worker stopped.") + if is_main: + if original_sigint is not None: + signal.signal(signal.SIGINT, original_sigint) + if original_sigterm is not None: + signal.signal(signal.SIGTERM, original_sigterm) + + def _build_resource_health_json(self) -> str | None: + """Snapshot current resource health as JSON for heartbeat.""" + if not self._resource_definitions: + return None + runtime = self._resource_runtime + health: dict[str, str] = {} + for name in self._resource_definitions: + if runtime is not None and name in runtime._unhealthy: + health[name] = "unhealthy" + else: + health[name] = "healthy" + return json.dumps(health) + + def _run_heartbeat( + self, + worker_id: str, + stop_event: threading.Event, + ) -> None: + """Send periodic heartbeats to storage with current resource health.""" + prev_unhealthy: set[str] = set() + while not stop_event.is_set(): + resource_health = self._build_resource_health_json() + try: + reaped_ids = self._inner.worker_heartbeat(worker_id, resource_health) + # Emit WORKER_OFFLINE events for reaped dead workers + for rid in reaped_ids: + self._emit_event(EventType.WORKER_OFFLINE, {"worker_id": rid}) # 
type: ignore[attr-defined] + except Exception: + logger.debug("Heartbeat failed", exc_info=True) + + # Detect health transitions → emit WORKER_UNHEALTHY + runtime = self._resource_runtime + if runtime is not None: + current_unhealthy = set(runtime._unhealthy) + new_unhealthy = current_unhealthy - prev_unhealthy + if new_unhealthy: + self._emit_event( # type: ignore[attr-defined] + EventType.WORKER_UNHEALTHY, + { + "worker_id": worker_id, + "resources": sorted(new_unhealthy), + }, + ) + prev_unhealthy = current_unhealthy + + stop_event.wait(timeout=5.0) + + # -- Resource Status -- + + def resource_status(self) -> list[dict[str, Any]]: + """Return per-resource status info. + + Each entry contains: name, scope, health, init_duration_ms, + recreations, depends_on. If this process is running the worker, the + live in-process runtime is authoritative. Otherwise (e.g. the + dashboard is a separate process), health is reconstructed from the + latest heartbeat each worker pushed via ``worker_heartbeat``. + Returns an empty list when nothing is registered and no worker has + reported yet. + """ + if self._resource_runtime is not None: + return self._resource_runtime.status() + return self._resource_status_from_heartbeats() + + def _resource_status_from_heartbeats(self) -> list[dict[str, Any]]: + """Fallback path when the runtime isn't in this process. + + Aggregates each worker's ``resource_health`` JSON snapshot into a + status list shaped like ``ResourceRuntime.status()``. Uses the + rule: any ``unhealthy`` wins; mixed healthy/unhealthy is + ``degraded``; all ``healthy`` → ``healthy``; no workers reporting + a given resource → ``not_initialized``. 
+ """ + observed: dict[str, list[str]] = {} + try: + workers = self._inner.list_workers() + except Exception: + logger.warning("resource_status: failed to list workers", exc_info=True) + workers = [] + + for worker in workers: + raw = worker.get("resource_health") + if not raw: + continue + try: + report = json.loads(raw) + except (TypeError, ValueError): + continue + if not isinstance(report, dict): + continue + for name, health in report.items(): + observed.setdefault(str(name), []).append(str(health).lower()) + + # Build an entry for every registered definition, joined with any + # resource a live worker reports (covers the case where the + # dashboard process has no definitions registered at all). + names = set(self._resource_definitions.keys()) | set(observed.keys()) + result: list[dict[str, Any]] = [] + for name in sorted(names): + defn = self._resource_definitions.get(name) + healths = observed.get(name, []) + if not healths: + health = "not_initialized" + elif any(h == "unhealthy" for h in healths): + health = "unhealthy" + elif all(h == "healthy" for h in healths): + health = "healthy" + else: + health = "degraded" + result.append( + { + "name": name, + "scope": defn.scope.value if defn is not None else "unknown", + "health": health, + "init_duration_ms": 0, + "recreations": 0, + "depends_on": defn.depends_on if defn is not None else [], + } + ) + return result + + # -- Test Mode -- + + def test_mode( + self, + propagate_errors: bool = False, + resources: dict[str, Any] | None = None, + ) -> TestMode: + """Return a context manager that runs tasks synchronously (no worker needed). + + Args: + propagate_errors: If True, re-raise task exceptions immediately. + resources: Dict of resource name → mock instance for injection + during test mode. + + Returns: + A :class:`~taskito.testing.TestMode` context manager. 
+ """ + return TestMode(self, propagate_errors=propagate_errors, resources=resources) # type: ignore[arg-type] diff --git a/py_src/taskito/mixins/locks.py b/py_src/taskito/mixins/locks.py new file mode 100644 index 0000000..c1b5ef8 --- /dev/null +++ b/py_src/taskito/mixins/locks.py @@ -0,0 +1,42 @@ +"""Distributed locking methods for the Queue.""" + +from __future__ import annotations + +from typing import Any + +from taskito.locks import DistributedLock + + +class QueueLockMixin: + """Distributed locking methods for the Queue.""" + + _inner: Any + + def lock( + self, + name: str, + ttl: float = 30.0, + auto_extend: bool = True, + owner_id: str | None = None, + timeout: float | None = None, + retry_interval: float = 0.1, + ) -> DistributedLock: + """Return a sync distributed lock context manager. + + Args: + name: Lock name (unique across the cluster). + ttl: Lock TTL in seconds. Auto-extended at ttl/3 intervals. + auto_extend: Whether to auto-extend the lock in a background thread. + owner_id: Unique owner identifier. Auto-generated if not provided. + timeout: Max seconds to wait for acquisition. None = fail immediately. + retry_interval: Seconds between retries when timeout is set. 
+ """ + return DistributedLock( + inner=self._inner, + name=name, + ttl=ttl, + owner_id=owner_id, + auto_extend=auto_extend, + timeout=timeout, + retry_interval=retry_interval, + ) diff --git a/py_src/taskito/mixins/operations.py b/py_src/taskito/mixins/operations.py new file mode 100644 index 0000000..5a252dc --- /dev/null +++ b/py_src/taskito/mixins/operations.py @@ -0,0 +1,108 @@ +"""Dead letters, replay, circuit breakers, logs, workers, queue management.""" + +from __future__ import annotations + +from typing import Any + +from taskito.events import EventType +from taskito.result import JobResult + + +class QueueOperationsMixin: + """Dead letters, replay, circuit breakers, logs, workers, queue management.""" + + _inner: Any + + # -- Dead Letters -- + + def dead_letters(self, limit: int = 10, offset: int = 0) -> list[dict]: + """List dead letter queue entries.""" + return self._inner.dead_letters(limit=limit, offset=offset) # type: ignore[no-any-return] + + def retry_dead(self, dead_id: str) -> str: + """Re-enqueue a dead letter job. 
Returns new job ID.""" + return self._inner.retry_dead(dead_id) # type: ignore[no-any-return] + + def purge_dead(self, older_than: int = 86400) -> int: + """Delete dead letter entries older than a given age.""" + return self._inner.purge_dead(older_than) # type: ignore[no-any-return] + + # -- Replay -- + + def replay(self, job_id: str) -> JobResult: + """Re-enqueue a completed or failed job with the exact same payload.""" + new_id = self._inner.replay_job(job_id) + return self.get_job(new_id) # type: ignore[attr-defined, no-any-return] + + def replay_history(self, job_id: str) -> list[dict]: + """Get replay history for a job.""" + return self._inner.get_replay_history(job_id) # type: ignore[no-any-return] + + # -- Circuit Breakers -- + + def circuit_breakers(self) -> list[dict]: + """List all circuit breaker states.""" + return self._inner.list_circuit_breakers() # type: ignore[no-any-return] + + # -- Logs -- + + def task_logs(self, job_id: str) -> list[dict]: + """Get structured logs for a specific job.""" + return self._inner.get_task_logs(job_id) # type: ignore[no-any-return] + + def query_logs( + self, + task_name: str | None = None, + level: str | None = None, + since: int = 3600, + limit: int = 100, + ) -> list[dict]: + """Query structured task logs with filters.""" + return self._inner.query_task_logs( # type: ignore[no-any-return] + task_name=task_name, level=level, since_seconds=since, limit=limit + ) + + # -- Workers -- + + def workers(self) -> list[dict]: + """List all registered workers and their heartbeat status.""" + return self._inner.list_workers() # type: ignore[no-any-return] + + # -- Queue Pause/Resume -- + + def pause(self, queue_name: str = "default") -> None: + """Pause a queue so no new jobs are dispatched from it.""" + self._inner.pause_queue(queue_name) + if hasattr(self, "_emit_event"): + self._emit_event(EventType.QUEUE_PAUSED, {"queue": queue_name}) + + def resume(self, queue_name: str = "default") -> None: + """Resume a paused queue.""" 
+ self._inner.resume_queue(queue_name) + if hasattr(self, "_emit_event"): + self._emit_event(EventType.QUEUE_RESUMED, {"queue": queue_name}) + + def paused_queues(self) -> list[str]: + """List currently paused queues.""" + return self._inner.list_paused_queues() # type: ignore[no-any-return] + + # -- Job Revocation -- + + def purge(self, queue_name: str) -> int: + """Cancel all pending jobs in a queue. Returns count cancelled.""" + return self._inner.purge_queue(queue_name) # type: ignore[no-any-return] + + def revoke_task(self, task_name: str) -> int: + """Cancel all pending jobs for a task name. Returns count cancelled.""" + return self._inner.revoke_task(task_name) # type: ignore[no-any-return] + + # -- Job Archival -- + + def archive(self, older_than: int = 86400) -> int: + """Archive completed/dead/cancelled jobs older than the given age (seconds).""" + return self._inner.archive_old_jobs(older_than) # type: ignore[no-any-return] + + def list_archived(self, limit: int = 50, offset: int = 0) -> list[JobResult]: + """List archived jobs with pagination.""" + py_jobs = self._inner.list_archived(limit=limit, offset=offset) + return [JobResult(py_job=pj, queue=self) for pj in py_jobs] # type: ignore[arg-type] diff --git a/py_src/taskito/mixins/resources.py b/py_src/taskito/mixins/resources.py new file mode 100644 index 0000000..5ec810d --- /dev/null +++ b/py_src/taskito/mixins/resources.py @@ -0,0 +1,144 @@ +"""Worker resource registration, health checks, and metrics.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from taskito.exceptions import CircularDependencyError +from taskito.resources.definition import ResourceDefinition, ResourceScope +from taskito.resources.graph import detect_cycle +from taskito.resources.toml_config import load_resources_from_toml + +if TYPE_CHECKING: + from taskito.interception.metrics import InterceptionMetrics + from taskito.proxies.metrics import ProxyMetrics 
+ from taskito.resources.runtime import ResourceRuntime + + +class QueueResourceMixin: + """Worker resource registration, health checks, and proxy/interception stats.""" + + _resource_definitions: dict[str, ResourceDefinition] + _resource_runtime: ResourceRuntime | None + _proxy_metrics: ProxyMetrics + _interception_metrics: InterceptionMetrics | None + + def worker_resource( + self, + name: str, + depends_on: list[str] | None = None, + teardown: Callable | None = None, + health_check: Callable | None = None, + health_check_interval: float = 0.0, + max_recreation_attempts: int = 3, + scope: str = "worker", + pool_size: int | None = None, + pool_min: int = 0, + acquire_timeout: float = 10.0, + max_lifetime: float = 3600.0, + idle_timeout: float = 300.0, + reloadable: bool = False, + frozen: bool = False, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to register a resource factory. + + Args: + name: Resource name used in ``inject=["name"]``. + depends_on: Names of resources this one depends on. + teardown: Optional callable to clean up the resource on shutdown. + health_check: Optional callable that returns truthy if healthy. + health_check_interval: Seconds between health checks (0 = disabled). + max_recreation_attempts: Max times to recreate on health failure. + scope: Resource scope — ``"worker"``, ``"task"``, ``"thread"``, + or ``"request"``. + pool_size: Pool size for task-scoped resources. + pool_min: Minimum pre-warmed instances (task scope). + acquire_timeout: Max seconds to wait for pool instance. + max_lifetime: Max seconds a pooled instance lives. + idle_timeout: Max idle seconds before eviction. + reloadable: Whether the resource can be hot-reloaded via SIGHUP. + frozen: Wrap the resource in a read-only proxy. 
+ """ + + def decorator(factory: Callable[..., Any]) -> Callable[..., Any]: + self.register_resource( + ResourceDefinition( + name=name, + factory=factory, + depends_on=depends_on or [], + teardown=teardown, + health_check=health_check, + health_check_interval=health_check_interval, + max_recreation_attempts=max_recreation_attempts, + scope=ResourceScope(scope), + pool_size=pool_size, + pool_min=pool_min, + acquire_timeout=acquire_timeout, + max_lifetime=max_lifetime, + idle_timeout=idle_timeout, + reloadable=reloadable, + frozen=frozen, + ) + ) + # Validate no cycles eagerly + cycle = detect_cycle(self._resource_definitions) + if cycle is not None: + # Roll back the registration + del self._resource_definitions[name] + raise CircularDependencyError( + f"Circular dependency detected: {' -> '.join(cycle)}" + ) + return factory + + return decorator + + def register_resource(self, definition: ResourceDefinition) -> None: + """Programmatically register a resource definition. + + Args: + definition: A :class:`~taskito.resources.ResourceDefinition`. + """ + self._resource_definitions[definition.name] = definition + + def health_check(self, name: str) -> bool: + """Run a resource's health check immediately. + + Args: + name: The registered resource name. + + Returns: + True if healthy, False otherwise. + """ + runtime = self._resource_runtime + if runtime is None: + return False + defn = self._resource_definitions.get(name) + if defn is None or defn.health_check is None: + return False + try: + instance = runtime.resolve(name) + return bool(defn.health_check(instance)) + except Exception: + return False + + def load_resources(self, toml_path: str) -> None: + """Load resource definitions from a TOML file. + + Must be called before ``run_worker()``. + + Args: + toml_path: Path to the TOML configuration file. 
+ """ + for defn in load_resources_from_toml(toml_path): + self.register_resource(defn) + + def proxy_stats(self) -> list[dict[str, Any]]: + """Return per-handler proxy reconstruction metrics.""" + return self._proxy_metrics.to_list() + + def interception_stats(self) -> dict[str, Any]: + """Return interception performance metrics.""" + if self._interception_metrics is not None: + return self._interception_metrics.to_dict() + return {} diff --git a/py_src/taskito/result.py b/py_src/taskito/result.py index 737b0ef..78e4122 100644 --- a/py_src/taskito/result.py +++ b/py_src/taskito/result.py @@ -81,12 +81,12 @@ def errors(self) -> list[dict]: @property def dependencies(self) -> list[str]: """IDs of jobs this job depends on.""" - return self._queue._inner.get_dependencies(self.id) # type: ignore[no-any-return] + return self._queue._inner.get_dependencies(self.id) @property def dependents(self) -> list[str]: """IDs of jobs that depend on this job.""" - return self._queue._inner.get_dependents(self.id) # type: ignore[no-any-return] + return self._queue._inner.get_dependents(self.id) def _poll_once(self) -> tuple[str, Any]: """Refresh and return (status, deserialized result or None)."""