diff --git a/src/dstack/_internal/core/models/compute_groups.py b/src/dstack/_internal/core/models/compute_groups.py index 66e1292eff..3fa967494d 100644 --- a/src/dstack/_internal/core/models/compute_groups.py +++ b/src/dstack/_internal/core/models/compute_groups.py @@ -12,6 +12,13 @@ class ComputeGroupStatus(str, enum.Enum): RUNNING = "running" TERMINATED = "terminated" + @classmethod + def finished_statuses(cls) -> List["ComputeGroupStatus"]: + return [cls.TERMINATED] + + def is_finished(self): + return self in self.finished_statuses() + class ComputeGroupProvisioningData(CoreModel): compute_group_id: str diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index dbea6f777b..209679f0ef 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -23,8 +23,9 @@ from dstack._internal.proxy.lib.deps import get_injector_from_app from dstack._internal.proxy.lib.routers import model_proxy from dstack._internal.server import settings -from dstack._internal.server.background import start_background_tasks -from dstack._internal.server.background.tasks.process_probes import PROBES_SCHEDULER +from dstack._internal.server.background.pipeline_tasks import start_pipeline_tasks +from dstack._internal.server.background.scheduled_tasks import start_scheduled_tasks +from dstack._internal.server.background.scheduled_tasks.probes import PROBES_SCHEDULER from dstack._internal.server.db import get_db, get_session_ctx, migrate from dstack._internal.server.routers import ( auth, @@ -163,8 +164,11 @@ async def lifespan(app: FastAPI): if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() scheduler = None + pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: - scheduler = start_background_tasks() + scheduler = start_scheduled_tasks() + pipeline_manager = start_pipeline_tasks() + app.state.pipeline_manager = pipeline_manager else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() @@ -189,9 +193,15 @@ async def lifespan(app: FastAPI): for func in _ON_STARTUP_HOOKS: await func(app) yield + PROBES_SCHEDULER.shutdown(wait=False) + if pipeline_manager is not None: + pipeline_manager.shutdown() if scheduler is not None: + # Note: Scheduler does not cancel currently running jobs, so scheduled tasks cannot do cleanup. + # TODO: Track and cancel scheduled tasks. 
scheduler.shutdown() - PROBES_SCHEDULER.shutdown(wait=False) + if pipeline_manager is not None: + await pipeline_manager.drain() await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 8577cce6f1..e69de29bb2 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -1,142 +0,0 @@ -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from apscheduler.triggers.interval import IntervalTrigger - -from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_compute_groups import process_compute_groups -from dstack._internal.server.background.tasks.process_events import delete_events -from dstack._internal.server.background.tasks.process_fleets import process_fleets -from dstack._internal.server.background.tasks.process_gateways import ( - process_gateways, - process_gateways_connections, -) -from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes -from dstack._internal.server.background.tasks.process_instances import ( - delete_instance_health_checks, - process_instances, -) -from dstack._internal.server.background.tasks.process_metrics import ( - collect_metrics, - delete_metrics, -) -from dstack._internal.server.background.tasks.process_placement_groups import ( - process_placement_groups, -) -from dstack._internal.server.background.tasks.process_probes import process_probes -from dstack._internal.server.background.tasks.process_prometheus_metrics import ( - collect_prometheus_metrics, - delete_prometheus_metrics, -) -from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs -from dstack._internal.server.background.tasks.process_runs import process_runs -from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs -from dstack._internal.server.background.tasks.process_terminating_jobs import ( - process_terminating_jobs, -) -from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes - -_scheduler = AsyncIOScheduler() - - -def get_scheduler() -> AsyncIOScheduler: - return _scheduler - - -def start_background_tasks() -> AsyncIOScheduler: - # Background processing is implemented via in-memory locks on SQLite - # and SELECT FOR UPDATE on Postgres. Locks may be held for a long time. - # This is currently the main bottleneck for scaling dstack processing - # as processing more resources requires more DB connections. - # TODO: Make background processing efficient by committing locks to DB - # and processing outside of DB transactions. - # - # Now we just try to process as many resources as possible without exhausting DB connections. - # - # Quick tasks can process multiple resources per transaction. - # Potentially long tasks process one resource per transaction - # to avoid holding locks for all the resources if one is slow to process. - # Still, the next batch won't be processed unless all resources are processed, - # so larger batches do not increase processing rate linearly. - # - # The interval, batch_size, and max_instances determine background tasks processing rates. 
- # By default, one server replica can handle: - # - # * 150 active jobs with 2 minutes processing latency - # * 150 active runs with 2 minutes processing latency - # * 150 active instances with 2 minutes processing latency - # - # These latency numbers do not account for provisioning time, - # so it may be slower if a backend is slow to provision. - # - # Users can set SERVER_BACKGROUND_PROCESSING_FACTOR to process more resources per replica. - # They also need to increase max db connections on the client side and db side. - # - # In-memory locking via locksets does not guarantee - # that the first waiting for the lock will acquire it. - # The jitter is needed to give all tasks a chance to acquire locks. - - _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) - _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) - _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) - if settings.ENABLE_PROMETHEUS_METRICS: - _scheduler.add_job( - collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 - ) - _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) - _scheduler.add_job( - process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 - ) - _scheduler.add_job( - process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 - ) - _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) - _scheduler.add_job( - process_fleets, - IntervalTrigger(seconds=10, jitter=2), - max_instances=1, - ) - _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) - for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): - # Add multiple copies of tasks if requested. - # max_instances=1 for additional copies to avoid running too many tasks. - # Move other tasks here when they need per-replica scaling. 
- _scheduler.add_job( - process_submitted_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=4 if replica == 0 else 1, - ) - _scheduler.add_job( - process_running_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_terminating_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_runs, - IntervalTrigger(seconds=2, jitter=1), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_instances, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_compute_groups, - IntervalTrigger(seconds=15, jitter=2), - kwargs={"batch_size": 1}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.start() - return _scheduler diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py new file mode 100644 index 0000000000..355e042476 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -0,0 +1,69 @@ +import asyncio + +from dstack._internal.server.background.pipeline_tasks.base import Pipeline +from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline +from dstack._internal.server.background.pipeline_tasks.placement_groups import ( + PlacementGroupPipeline, +) +from dstack._internal.settings import FeatureFlags +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PipelineManager: + def __init__(self) -> None: + self._pipelines: list[Pipeline] = [] + if FeatureFlags.PIPELINE_PROCESSING_ENABLED: + self._pipelines += [ + ComputeGroupPipeline(), + PlacementGroupPipeline(), + ] + self._hinter = PipelineHinter(self._pipelines) + + def start(self): + for pipeline in self._pipelines: + pipeline.start() + + def shutdown(self): + for pipeline in self._pipelines: + pipeline.shutdown() + + async def drain(self): + results = await asyncio.gather( + *[p.drain() for p in self._pipelines], return_exceptions=True + ) + for pipeline, result in zip(self._pipelines, results): + if isinstance(result, BaseException): + logger.error( + "Unexpected exception when draining pipeline %r", + pipeline, + exc_info=(type(result), result, result.__traceback__), + ) + + @property + def hinter(self): + return self._hinter + + +class PipelineHinter: + def __init__(self, pipelines: list[Pipeline]) -> None: + self._pipelines = pipelines + self._hint_fetch_map = {p.hint_fetch_model_name: p for p in self._pipelines} + + def hint_fetch(self, model_name: str): + pipeline = self._hint_fetch_map.get(model_name) + if pipeline is None: + logger.warning("Model %s not registered for fetch hints", model_name) + return + pipeline.hint_fetch() + + +def start_pipeline_tasks() -> PipelineManager: + """ + Start tasks processed by fetch-workers pipelines based on db + in-memory queues. + Suitable for tasks that run frequently and need to lock rows for a long time. 
+ """ + pipeline_manager = PipelineManager() + pipeline_manager.start() + return pipeline_manager diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py new file mode 100644 index 0000000000..30be480bf9 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -0,0 +1,344 @@ +import asyncio +import math +import random +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar + +from sqlalchemy import and_, or_, update +from sqlalchemy.orm import Mapped + +from dstack._internal.server.db import get_session_ctx +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class PipelineItem: + __tablename__: str + id: uuid.UUID + lock_expires_at: datetime + lock_token: uuid.UUID + prev_lock_expired: bool + + +class PipelineModel(Protocol): + __tablename__: str + __mapper__: ClassVar[Any] + __table__: ClassVar[Any] + id: Mapped[uuid.UUID] + lock_expires_at: Mapped[Optional[datetime]] + lock_token: Mapped[Optional[uuid.UUID]] + + +class PipelineError(Exception): + pass + + +class Pipeline(ABC): + def __init__( + self, + workers_num: int, + queue_lower_limit_factor: float, + queue_upper_limit_factor: float, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeat_trigger: timedelta, + ) -> None: + self._workers_num = workers_num + self._queue_lower_limit_factor = queue_lower_limit_factor + self._queue_upper_limit_factor = queue_upper_limit_factor + self._queue_desired_minsize = math.ceil(workers_num * queue_lower_limit_factor) + self._queue_maxsize = math.ceil(workers_num * queue_upper_limit_factor) + self._min_processing_interval = min_processing_interval + self._lock_timeout = lock_timeout + self._heartbeat_trigger = heartbeat_trigger + self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize) + self._tasks: list[asyncio.Task] = [] + self._running = False + self._shutdown = False + + def start(self): + """ + Starts all pipeline tasks. + """ + if self._running: + return + if self._shutdown: + raise PipelineError("Cannot start pipeline after shutdown.") + self._running = True + self._tasks.append(asyncio.create_task(self._heartbeater.start())) + for worker in self._workers: + self._tasks.append(asyncio.create_task(worker.start())) + self._tasks.append(asyncio.create_task(self._fetcher.start())) + + def shutdown(self): + """ + Stops the pipeline from processing new items and signals running tasks to cancel. + """ + if self._shutdown: + return + self._shutdown = True + self._running = False + self._fetcher.stop() + for worker in self._workers: + worker.stop() + self._heartbeater.stop() + for task in self._tasks: + if not task.done(): + task.cancel() + + async def drain(self): + """ + Waits for all pipeline tasks to finish cleanup after shutdown. + """ + if not self._shutdown: + raise PipelineError("Cannot drain running pipeline. 
Call `shutdown()` first.") + results = await asyncio.gather(*self._tasks, return_exceptions=True) + for task, result in zip(self._tasks, results): + if isinstance(result, BaseException) and not isinstance( + result, asyncio.CancelledError + ): + logger.error( + "Unexpected exception when draining pipeline task %r", + task, + exc_info=(type(result), result, result.__traceback__), + ) + + def hint_fetch(self): + self._fetcher.hint() + + @property + @abstractmethod + def hint_fetch_model_name(self) -> str: + pass + + @property + @abstractmethod + def _heartbeater(self) -> "Heartbeater": + pass + + @property + @abstractmethod + def _fetcher(self) -> "Fetcher": + pass + + @property + @abstractmethod + def _workers(self) -> Sequence["Worker"]: + pass + + +ModelT = TypeVar("ModelT", bound=PipelineModel) + + +class Heartbeater(Generic[ModelT]): + def __init__( + self, + model_type: type[ModelT], + lock_timeout: timedelta, + heartbeat_trigger: timedelta, + heartbeat_delay: float = 1.0, + ) -> None: + self._model_type = model_type + self._lock_timeout = lock_timeout + self._hearbeat_margin = heartbeat_trigger + self._items: dict[uuid.UUID, PipelineItem] = {} + self._untrack_lock = asyncio.Lock() + self._heartbeat_delay = heartbeat_delay + self._running = False + + async def start(self): + self._running = True + while self._running: + try: + await self.heartbeat() + except Exception: + logger.exception("Unexpected exception when running heartbeat") + await asyncio.sleep(self._heartbeat_delay) + + def stop(self): + self._running = False + + async def track(self, item: PipelineItem): + self._items[item.id] = item + + async def untrack(self, item: PipelineItem): + async with self._untrack_lock: + tracked = self._items.get(item.id) + # Prevent expired fetch iteration to unlock item processed by new iteration. + if tracked is not None and tracked.lock_token == item.lock_token: + del self._items[item.id] + + async def heartbeat(self): + items_to_update: list[PipelineItem] = [] + now = get_current_datetime() + items = list(self._items.values()) + failed_to_heartbeat_count = 0 + for item in items: + if item.lock_expires_at < now: + failed_to_heartbeat_count += 1 + await self.untrack(item) + elif item.lock_expires_at < now + self._hearbeat_margin: + items_to_update.append(item) + if failed_to_heartbeat_count > 0: + logger.warning( + "Failed to heartbeat %d %s items in time." + " The items are expected to be processed on another fetch iteration.", + failed_to_heartbeat_count, + self._model_type.__tablename__, + ) + if len(items_to_update) == 0: + return + logger.debug( + "Updating lock_expires_at for items: %s", [str(r.id) for r in items_to_update] + ) + async with get_session_ctx() as session: + per_item_filters = [ + and_( + self._model_type.id == item.id, self._model_type.lock_token == item.lock_token + ) + for item in items_to_update + ] + res = await session.execute( + update(self._model_type) + .where(or_(*per_item_filters)) + .values(lock_expires_at=now + self._lock_timeout) + .returning(self._model_type.id) + ) + updated_ids = set(res.scalars().all()) + failed_to_update_count = 0 + for item in items_to_update: + if item.id in updated_ids: + item.lock_expires_at = now + self._lock_timeout + else: + failed_to_update_count += 1 + await self.untrack(item) + if failed_to_update_count > 0: + logger.warning( + "Failed to update %s lock_expires_at of %d items: lock_token changed." 
+ " The items are expected to be processed and updated on another fetch iteration.", + self._model_type.__tablename__, + failed_to_update_count, + ) + + +class Fetcher(ABC): + _DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5] + + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater, + queue_check_delay: float = 1.0, + fetch_delays: Optional[list[float]] = None, + ) -> None: + self._queue = queue + self._queue_desired_minsize = queue_desired_minsize + self._min_processing_interval = min_processing_interval + self._lock_timeout = lock_timeout + self._heartbeater = heartbeater + self._queue_check_delay = queue_check_delay + if fetch_delays is None: + fetch_delays = self._DEFAULT_FETCH_DELAYS + self._fetch_delays = fetch_delays + self._running = False + self._fetch_event = asyncio.Event() + + async def start(self): + self._running = True + empty_fetch_count = 0 + while self._running: + if self._queue.qsize() >= self._queue_desired_minsize: + await asyncio.sleep(self._queue_check_delay) + continue + fetch_limit = self._queue.maxsize - self._queue.qsize() + try: + items = await self.fetch(limit=fetch_limit) + except Exception: + logger.exception("Unexpected exception when fetching new items") + items = [] + if len(items) == 0: + try: + await asyncio.wait_for( + self._fetch_event.wait(), + timeout=self._next_fetch_delay(empty_fetch_count), + ) + except TimeoutError: + pass + empty_fetch_count += 1 + self._fetch_event.clear() + continue + else: + empty_fetch_count = 0 + for item in items: + self._queue.put_nowait(item) # should never raise + await self._heartbeater.track(item) + + def stop(self): + self._running = False + + def hint(self): + self._fetch_event.set() + + @abstractmethod + async def fetch(self, limit: int) -> list[PipelineItem]: + pass + + def _next_fetch_delay(self, empty_fetch_count: int) -> float: + next_delay = self._fetch_delays[min(empty_fetch_count, len(self._fetch_delays) - 1)] + jitter = random.random() * 0.4 - 0.2 + return next_delay * (1 + jitter) + + +class Worker(ABC): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater, + ) -> None: + self._queue = queue + self._heartbeater = heartbeater + self._running = False + + async def start(self): + self._running = True + while self._running: + item = await self._queue.get() + logger.debug("Processing %s item %s", item.__tablename__, item.id) + try: + await self.process(item) + except Exception: + logger.exception("Unexpected exception when processing item") + finally: + await self._heartbeater.untrack(item) + logger.debug("Processed %s item %s", item.__tablename__, item.id) + + def stop(self): + self._running = False + + @abstractmethod + async def process(self, item: PipelineItem): + pass + + +UpdateMap = dict[str, Any] + + +def get_unlock_update_map() -> UpdateMap: + return { + "lock_expires_at": None, + "lock_token": None, + "lock_owner": None, + } + + +def get_processed_update_map() -> UpdateMap: + return {"last_processed_at": get_current_datetime()} diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py new file mode 100644 index 0000000000..685c5205a8 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -0,0 +1,335 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import datetime, 
timedelta +from typing import Sequence + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithGroupProvisioningSupport +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group +from dstack._internal.server.services.instances import switch_instance_status +from dstack._internal.server.services.locking import get_locker +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +TERMINATION_RETRY_TIMEOUT = timedelta(seconds=60) +TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) + + +class ComputeGroupPipeline(Pipeline): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[ComputeGroupModel]( + model_type=ComputeGroupModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = ComputeGroupFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + ComputeGroupWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return ComputeGroupModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher: + return self.__fetcher + + @property + def _workers(self) -> Sequence["ComputeGroupWorker"]: + return self.__workers + + +class ComputeGroupFetcher(Fetcher): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[ComputeGroupModel], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + async def fetch(self, limit: int) -> list[PipelineItem]: + compute_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( + 
ComputeGroupModel.__tablename__ + ) + async with compute_group_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(ComputeGroupModel) + .where( + ComputeGroupModel.status.not_in(ComputeGroupStatus.finished_statuses()), + ComputeGroupModel.last_processed_at <= now - self._min_processing_interval, + or_( + ComputeGroupModel.lock_expires_at.is_(None), + ComputeGroupModel.lock_expires_at < now, + ), + or_( + ComputeGroupModel.lock_owner.is_(None), + ComputeGroupModel.lock_owner == ComputeGroupPipeline.__name__, + ), + ) + .order_by(ComputeGroupModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=ComputeGroupModel) + .options( + load_only( + ComputeGroupModel.id, + ComputeGroupModel.lock_token, + ComputeGroupModel.lock_expires_at, + ) + ) + ) + compute_group_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for compute_group_model in compute_group_models: + prev_lock_expired = compute_group_model.lock_expires_at is not None + compute_group_model.lock_expires_at = lock_expires_at + compute_group_model.lock_token = lock_token + compute_group_model.lock_owner = ComputeGroupPipeline.__name__ + items.append( + PipelineItem( + __tablename__=ComputeGroupModel.__tablename__, + id=compute_group_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + ) + ) + await session.commit() + return items + + +class ComputeGroupWorker(Worker): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[ComputeGroupModel], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + async def process(self, item: PipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(ComputeGroupModel) + .where( + ComputeGroupModel.id == item.id, + ComputeGroupModel.lock_token == item.lock_token, + ) + # Terminating instances belonging to a compute group are locked implicitly by locking the compute group. + .options( + joinedload(ComputeGroupModel.instances), + joinedload(ComputeGroupModel.project).joinedload(ProjectModel.backends), + ) + ) + compute_group_model = res.unique().scalar_one_or_none() + if compute_group_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + terminate_result = _TerminateResult() + # TODO: Fetch only compute groups with all instances terminating. 
+ if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): + terminate_result = await _terminate_compute_group(compute_group_model) + if terminate_result.compute_group_update_map: + logger.info("Terminated compute group %s", compute_group_model.id) + else: + terminate_result.compute_group_update_map = get_processed_update_map() + + terminate_result.compute_group_update_map |= get_unlock_update_map() + + async with get_session_ctx() as session: + res = await session.execute( + update(ComputeGroupModel) + .where( + ComputeGroupModel.id == compute_group_model.id, + ComputeGroupModel.lock_token == compute_group_model.lock_token, + ) + .values(**terminate_result.compute_group_update_map) + .returning(ComputeGroupModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + if not terminate_result.instances_update_map: + return + instances_ids = [i.id for i in compute_group_model.instances] + res = await session.execute( + update(InstanceModel) + .where(InstanceModel.id.in_(instances_ids)) + .values(**terminate_result.instances_update_map) + ) + for instance_model in compute_group_model.instances: + switch_instance_status(session, instance_model, InstanceStatus.TERMINATED) + + +@dataclass +class _TerminateResult: + compute_group_update_map: UpdateMap = field(default_factory=dict) + instances_update_map: UpdateMap = field(default_factory=dict) + + +async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult: + result = _TerminateResult() + if ( + compute_group_model.last_termination_retry_at is not None + and _next_termination_retry_at(compute_group_model.last_termination_retry_at) + > get_current_datetime() + ): + return result + compute_group = compute_group_model_to_compute_group(compute_group_model) + cgpd = compute_group.provisioning_data + backend = await backends_services.get_project_backend_by_type( + project=compute_group_model.project, + backend_type=cgpd.backend, + ) + if backend is None: + logger.error( + "Failed to terminate compute group %s. Backend %s not available." + " Please terminate it manually to avoid unexpected charges.", + compute_group.name, + cgpd.backend, + ) + return _get_terminated_result() + logger.debug("Terminating compute group %s", compute_group.name) + compute = backend.compute() + assert isinstance(compute, ComputeWithGroupProvisioningSupport) + try: + await run_async( + compute.terminate_compute_group, + compute_group, + ) + except Exception as e: + if compute_group_model.first_termination_retry_at is None: + result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() + result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() + if _next_termination_retry_at( + result.compute_group_update_map["last_termination_retry_at"] + ) < _get_termination_deadline( + result.compute_group_update_map.get( + "first_termination_retry_at", compute_group_model.first_termination_retry_at + ) + ): + logger.warning( + "Failed to terminate compute group %s. Will retry. Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + return result + logger.error( + "Failed all attempts to terminate compute group %s." + " Please terminate it manually to avoid unexpected charges." 
+ " Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + terminated_result = _get_terminated_result() + return _TerminateResult( + compute_group_update_map=result.compute_group_update_map + | terminated_result.compute_group_update_map, + instances_update_map=result.instances_update_map | terminated_result.instances_update_map, + ) + + +def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime: + return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT + + +def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime: + return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION + + +def _get_terminated_result() -> _TerminateResult: + now = get_current_datetime() + return _TerminateResult( + compute_group_update_map={ + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + "status": ComputeGroupStatus.TERMINATED, + }, + instances_update_map={ + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + "finished_at": now, + "status": InstanceStatus.TERMINATED, + }, + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py new file mode 100644 index 0000000000..9fac5665a5 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -0,0 +1,263 @@ +import asyncio +import uuid +from datetime import timedelta +from typing import Sequence + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport +from dstack._internal.core.errors import PlacementGroupInUseError +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + PlacementGroupModel, + ProjectModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PlacementGroupPipeline(Pipeline): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[PlacementGroupModel]( + model_type=PlacementGroupModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = PlacementGroupFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + 
heartbeater=self._heartbeater, + ) + self.__workers = [ + PlacementGroupWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return PlacementGroupModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher: + return self.__fetcher + + @property + def _workers(self) -> Sequence["PlacementGroupWorker"]: + return self.__workers + + +class PlacementGroupFetcher(Fetcher): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[PlacementGroupModel], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + async def fetch(self, limit: int) -> list[PipelineItem]: + placement_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( + PlacementGroupModel.__tablename__ + ) + async with placement_group_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.fleet_deleted == True, + PlacementGroupModel.deleted == False, + PlacementGroupModel.last_processed_at + <= now - self._min_processing_interval, + or_( + PlacementGroupModel.lock_expires_at.is_(None), + PlacementGroupModel.lock_expires_at < now, + ), + or_( + PlacementGroupModel.lock_owner.is_(None), + PlacementGroupModel.lock_owner == PlacementGroupPipeline.__name__, + ), + ) + .order_by(PlacementGroupModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + PlacementGroupModel.id, + PlacementGroupModel.lock_token, + PlacementGroupModel.lock_expires_at, + ) + ) + ) + placement_group_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for placement_group_model in placement_group_models: + prev_lock_expired = placement_group_model.lock_expires_at is not None + placement_group_model.lock_expires_at = lock_expires_at + placement_group_model.lock_token = lock_token + placement_group_model.lock_owner = PlacementGroupPipeline.__name__ + items.append( + PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, + id=placement_group_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + ) + ) + await session.commit() + return items + + +class PlacementGroupWorker(Worker): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[PlacementGroupModel], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + async def process(self, item: PipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.id == item.id, + PlacementGroupModel.lock_token == item.lock_token, + ) + .options(joinedload(PlacementGroupModel.project).joinedload(ProjectModel.backends)) + ) + placement_group_model = res.unique().scalar_one_or_none() + if placement_group_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." 
+ " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + update_map = await _delete_placement_group(placement_group_model) + if update_map: + logger.info("Deleted placement group %s", placement_group_model.name) + else: + update_map = get_processed_update_map() + + update_map |= get_unlock_update_map() + + async with get_session_ctx() as session: + res = await session.execute( + update(PlacementGroupModel) + .where( + PlacementGroupModel.id == placement_group_model.id, + PlacementGroupModel.lock_token == placement_group_model.lock_token, + ) + .values(**update_map) + .returning(PlacementGroupModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + + +async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> UpdateMap: + placement_group = placement_group_model_to_placement_group(placement_group_model) + if placement_group.provisioning_data is None: + logger.error( + "Failed to delete placement group %s. provisioning_data is None.", placement_group.name + ) + return _get_deleted_update_map() + backend = await backends_services.get_project_backend_by_type( + project=placement_group_model.project, + backend_type=placement_group.provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to delete placement group %s. Backend not available. Please delete it manually.", + placement_group.name, + ) + return _get_deleted_update_map() + compute = backend.compute() + assert isinstance(compute, ComputeWithPlacementGroupSupport) + try: + await run_async(compute.delete_placement_group, placement_group) + except PlacementGroupInUseError: + logger.info( + "Placement group %s is still in use. Skipping deletion for now.", placement_group.name + ) + return {} + except Exception: + logger.exception( + "Got exception when deleting placement group %s. 
Please delete it manually.", + placement_group.name, + ) + return _get_deleted_update_map() + + return _get_deleted_update_map() + + +def _get_deleted_update_map() -> UpdateMap: + now = get_current_datetime() + return { + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + } diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py new file mode 100644 index 0000000000..c4baf96c58 --- /dev/null +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -0,0 +1,159 @@ +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.interval import IntervalTrigger + +from dstack._internal.server import settings +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( + process_compute_groups, +) +from dstack._internal.server.background.scheduled_tasks.events import delete_events +from dstack._internal.server.background.scheduled_tasks.fleets import process_fleets +from dstack._internal.server.background.scheduled_tasks.gateways import ( + process_gateways, + process_gateways_connections, +) +from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( + process_idle_volumes, +) +from dstack._internal.server.background.scheduled_tasks.instances import ( + delete_instance_health_checks, + process_instances, +) +from dstack._internal.server.background.scheduled_tasks.metrics import ( + collect_metrics, + delete_metrics, +) +from dstack._internal.server.background.scheduled_tasks.placement_groups import ( + process_placement_groups, +) +from dstack._internal.server.background.scheduled_tasks.probes import process_probes +from dstack._internal.server.background.scheduled_tasks.prometheus_metrics import ( + collect_prometheus_metrics, + delete_prometheus_metrics, +) +from dstack._internal.server.background.scheduled_tasks.running_jobs import ( + process_running_jobs, +) +from dstack._internal.server.background.scheduled_tasks.runs import process_runs +from dstack._internal.server.background.scheduled_tasks.submitted_jobs import ( + process_submitted_jobs, +) +from dstack._internal.server.background.scheduled_tasks.terminating_jobs import ( + process_terminating_jobs, +) +from dstack._internal.server.background.scheduled_tasks.volumes import ( + process_submitted_volumes, +) +from dstack._internal.settings import FeatureFlags + +_scheduler = AsyncIOScheduler() + + +def get_scheduler() -> AsyncIOScheduler: + return _scheduler + + +def start_scheduled_tasks() -> AsyncIOScheduler: + """ + Start periodic tasks triggered by `apscheduler` at specific times/intervals. + Suitable for tasks that run infrequently and don't need to lock rows for a long time. + """ + # Background processing is implemented via in-memory locks on SQLite + # and SELECT FOR UPDATE on Postgres. Locks may be held for a long time. + # This is currently the main bottleneck for scaling dstack processing + # as processing more resources requires more DB connections. + # TODO: Make background processing efficient by committing locks to DB + # and processing outside of DB transactions. + # + # Now we just try to process as many resources as possible without exhausting DB connections. + # + # Quick tasks can process multiple resources per transaction. + # Potentially long tasks process one resource per transaction + # to avoid holding locks for all the resources if one is slow to process. 
+ # Still, the next batch won't be processed unless all resources are processed, + # so larger batches do not increase processing rate linearly. + # + # The interval, batch_size, and max_instances determine background tasks processing rates. + # By default, one server replica can handle: + # + # * 150 active jobs with 2 minutes processing latency + # * 150 active runs with 2 minutes processing latency + # * 150 active instances with 2 minutes processing latency + # + # These latency numbers do not account for provisioning time, + # so it may be slower if a backend is slow to provision. + # + # Users can set SERVER_BACKGROUND_PROCESSING_FACTOR to process more resources per replica. + # They also need to increase max db connections on the client side and db side. + # + # In-memory locking via locksets does not guarantee + # that the first waiting for the lock will acquire it. + # The jitter is needed to give all tasks a chance to acquire locks. + + _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) + _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) + _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) + _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) + if settings.ENABLE_PROMETHEUS_METRICS: + _scheduler.add_job( + collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 + ) + _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) + _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) + _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) + _scheduler.add_job( + process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + ) + _scheduler.add_job( + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 + ) + if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) + _scheduler.add_job( + process_fleets, + IntervalTrigger(seconds=10, jitter=2), + max_instances=1, + ) + _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) + for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): + # Add multiple copies of tasks if requested. + # max_instances=1 for additional copies to avoid running too many tasks. + # Move other tasks here when they need per-replica scaling. 
+ _scheduler.add_job( + process_submitted_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=4 if replica == 0 else 1, + ) + _scheduler.add_job( + process_running_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_terminating_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_runs, + IntervalTrigger(seconds=2, jitter=1), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_instances, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_compute_groups, + IntervalTrigger(seconds=15, jitter=2), + kwargs={"batch_size": 1}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.start() + return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/common.py b/src/dstack/_internal/server/background/scheduled_tasks/common.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/common.py rename to src/dstack/_internal/server/background/scheduled_tasks/common.py diff --git a/src/dstack/_internal/server/background/tasks/process_compute_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_compute_groups.py rename to src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py diff --git a/src/dstack/_internal/server/background/tasks/process_events.py b/src/dstack/_internal/server/background/scheduled_tasks/events.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_events.py rename to src/dstack/_internal/server/background/scheduled_tasks/events.py diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_fleets.py rename to src/dstack/_internal/server/background/scheduled_tasks/fleets.py diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_gateways.py rename to src/dstack/_internal/server/background/scheduled_tasks/gateways.py diff --git a/src/dstack/_internal/server/background/tasks/process_idle_volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_idle_volumes.py rename to src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_instances.py rename to src/dstack/_internal/server/background/scheduled_tasks/instances.py index da47cf16ed..196f347c4f 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -59,7 +59,7 @@ JobProvisioningData, ) from 
dstack._internal.server import settings as server_settings -from dstack._internal.server.background.tasks.common import get_provisioning_timeout +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, diff --git a/src/dstack/_internal/server/background/tasks/process_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_metrics.py rename to src/dstack/_internal/server/background/scheduled_tasks/metrics.py diff --git a/src/dstack/_internal/server/background/tasks/process_placement_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_placement_groups.py rename to src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py diff --git a/src/dstack/_internal/server/background/tasks/process_probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_probes.py rename to src/dstack/_internal/server/background/scheduled_tasks/probes.py diff --git a/src/dstack/_internal/server/background/tasks/process_prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_prometheus_metrics.py rename to src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_running_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 7275106ceb..f413edf44b 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -37,7 +37,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint -from dstack._internal.server.background.tasks.common import get_provisioning_timeout +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_runs.py rename to src/dstack/_internal/server/background/scheduled_tasks/runs.py diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_submitted_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index a021096613..79746e9338 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -57,7 +57,9 @@ from 
dstack._internal.core.models.volumes import Volume from dstack._internal.core.services.profiles import get_termination from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_compute_groups import ComputeGroupStatus +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( + ComputeGroupStatus, +) from dstack._internal.server.db import ( get_db, get_session_ctx, diff --git a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_terminating_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_volumes.py rename to src/dstack/_internal/server/background/scheduled_tasks/volumes.py diff --git a/src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py new file mode 100644 index 0000000000..e341b3b4a4 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add ComputeGroupModel pipeline columns + +Revision ID: 57cff3ec86ce +Revises: 706e0acc3a7d +Create Date: 2026-02-18 11:07:48.686185 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "57cff3ec86ce" +down_revision = "706e0acc3a7d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("compute_groups", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("compute_groups", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py new file mode 100644 index 0000000000..56297fde36 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add PlacementGroupModel pipeline columns + +Revision ID: 9c2a227b0154 +Revises: 57cff3ec86ce +Create Date: 2026-02-18 11:08:57.860277 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. 
+revision = "9c2a227b0154" +down_revision = "57cff3ec86ce" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py b/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py new file mode 100644 index 0000000000..ad35a23d06 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py @@ -0,0 +1,57 @@ +"""Add pipeline indexes for compute and placement groups + +Revision ID: a8ed24fd7f90 +Revises: 9c2a227b0154 +Create Date: 2026-02-18 11:22:25.972000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a8ed24fd7f90" +down_revision = "9c2a227b0154" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.create_index( + "ix_compute_groups_pipeline_fetch_q", + "compute_groups", + [sa.literal_column("last_processed_at ASC")], + unique=False, + postgresql_where=sa.text("(status NOT IN ('TERMINATED'))"), + sqlite_where=sa.text("(status NOT IN ('TERMINATED'))"), + postgresql_concurrently=True, + ) + op.create_index( + "ix_placement_groups_pipeline_fetch_q", + "placement_groups", + [sa.literal_column("last_processed_at ASC")], + unique=False, + postgresql_where=sa.text("deleted IS FALSE"), + sqlite_where=sa.text("deleted = 0"), + postgresql_concurrently=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_placement_groups_pipeline_fetch_q", + "placement_groups", + postgresql_concurrently=True, + ) + op.drop_index( + "ix_compute_groups_pipeline_fetch_q", + "compute_groups", + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 7e9db282d1..a837137a10 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -196,6 +196,12 @@ class BaseModel(DeclarativeBase): metadata = MetaData(naming_convention=constraint_naming_convention) +class PipelineModelMixin: + lock_expires_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + lock_token: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) + lock_owner: Mapped[Optional[str]] = mapped_column(String(100)) + + class UserModel(BaseModel): __tablename__ = "users" @@ -768,7 +774,7 @@ class VolumeAttachmentModel(BaseModel): attachment_data: Mapped[Optional[str]] = mapped_column(Text) -class PlacementGroupModel(BaseModel): +class PlacementGroupModel(PipelineModelMixin, BaseModel): __tablename__ = "placement_groups" id: Mapped[uuid.UUID] = mapped_column( @@ -794,8 +800,17 @@ class PlacementGroupModel(BaseModel): configuration: Mapped[str] = mapped_column(Text) provisioning_data: Mapped[Optional[str]] = mapped_column(Text) + __table_args__ = ( + Index( + "ix_placement_groups_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + -class ComputeGroupModel(BaseModel): +class ComputeGroupModel(PipelineModelMixin, BaseModel): __tablename__ = "compute_groups" id: Mapped[uuid.UUID] = mapped_column( @@ -823,6 +838,15 @@ class ComputeGroupModel(BaseModel): instances: Mapped[List["InstanceModel"]] = relationship(back_populates="compute_group") + __table_args__ = ( + Index( + "ix_compute_groups_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=status.not_in(ComputeGroupStatus.finished_statuses()), + sqlite_where=status.not_in(ComputeGroupStatus.finished_statuses()), + ), + ) + class JobMetricsPoint(BaseModel): __tablename__ = "job_metrics_points" diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index a9496ad348..3310cda996 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -77,7 +77,7 @@ def get_default_python_verison() -> str: def get_default_image(nvcc: bool = False) -> str: """ Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances). - See `dstack._internal.server.background.tasks.process_running_jobs._patch_base_image_for_aws_efa` for details. + See `dstack._internal.server.background.scheduled_tasks.running_jobs._patch_base_image_for_aws_efa` for details. Args: nvcc: If True, returns 'devel' variant, otherwise 'base'. 
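For readers of this diff: the PipelineModelMixin columns above (lock_expires_at, lock_token, lock_owner) together with the partial ix_*_pipeline_fetch_q indexes support an optimistic fetch-and-lock scheme, and the Heartbeater tests further below exercise the compare-and-set side of it. The sketch that follows is an editorial illustration only, assuming a claim function name (try_claim), a query shape, and a 30-second lock timeout taken from the test fixture; the actual implementation lives in background/pipeline_tasks/base.py, which is not part of this section.

# Editorial sketch, not part of the diff: optimistic locking with the new
# PipelineModelMixin columns. Function name, query shape, and timeout are assumptions.
import uuid
from datetime import timedelta
from typing import Optional

from sqlalchemy import or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.models.compute_groups import ComputeGroupStatus
from dstack._internal.server.models import ComputeGroupModel
from dstack._internal.utils.common import get_current_datetime

LOCK_TIMEOUT = timedelta(seconds=30)  # assumed; matches the Heartbeater test fixture


async def try_claim(session: AsyncSession, owner: str) -> Optional[uuid.UUID]:
    now = get_current_datetime()
    # The ix_compute_groups_pipeline_fetch_q partial index covers this predicate:
    # non-terminated groups ordered by last_processed_at.
    row_id = await session.scalar(
        select(ComputeGroupModel.id)
        .where(
            ComputeGroupModel.status.not_in(ComputeGroupStatus.finished_statuses()),
            or_(
                ComputeGroupModel.lock_expires_at.is_(None),
                ComputeGroupModel.lock_expires_at < now,
            ),
        )
        .order_by(ComputeGroupModel.last_processed_at.asc())
        .limit(1)
    )
    if row_id is None:
        return None
    token = uuid.uuid4()
    # Compare-and-set claim: succeeds only if no other worker re-locked the row
    # between the select above and this update.
    res = await session.execute(
        update(ComputeGroupModel)
        .where(
            ComputeGroupModel.id == row_id,
            or_(
                ComputeGroupModel.lock_expires_at.is_(None),
                ComputeGroupModel.lock_expires_at < now,
            ),
        )
        .values(lock_token=token, lock_owner=owner, lock_expires_at=now + LOCK_TIMEOUT)
    )
    await session.commit()
    return token if res.rowcount == 1 else None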
diff --git a/src/dstack/_internal/server/services/pipelines.py b/src/dstack/_internal/server/services/pipelines.py new file mode 100644 index 0000000000..19f4df902d --- /dev/null +++ b/src/dstack/_internal/server/services/pipelines.py @@ -0,0 +1,12 @@ +from typing import Protocol + +from fastapi import Request + + +class PipelineHinterProtocol(Protocol): + def hint_fetch(self, model_name: str) -> None: + pass + + +def get_pipeline_hinter(request: Request) -> PipelineHinterProtocol: + return request.app.state.pipeline_manager.hinter diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 6089e37c07..d94bb56547 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -47,3 +47,6 @@ class FeatureFlags: # DSTACK_FF_AUTOCREATED_FLEETS_ENABLED enables legacy autocreated fleets: # If there are no fleet suitable for the run, a new fleet is created automatically instead of an error. AUTOCREATED_FLEETS_ENABLED = os.getenv("DSTACK_FF_AUTOCREATED_FLEETS_ENABLED") is not None + # DSTACK_FF_PIPELINE_PROCESSING_ENABLED enables new pipeline-based processing tasks (background/pipeline_tasks/) + # instead of scheduler-based processing tasks (background/scheduled_tasks/) for tasks that implement pipelines. + PIPELINE_PROCESSING_ENABLED = os.getenv("DSTACK_FF_PIPELINE_PROCESSING_ENABLED") is not None diff --git a/src/dstack/_internal/server/background/tasks/__init__.py b/src/tests/_internal/server/background/pipeline_tasks/__init__.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/__init__.py rename to src/tests/_internal/server/background/pipeline_tasks/__init__.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_base.py b/src/tests/_internal/server/background/pipeline_tasks/test_base.py new file mode 100644 index 0000000000..7e84d9f80d --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_base.py @@ -0,0 +1,183 @@ +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.base import Heartbeater, PipelineItem +from dstack._internal.server.models import PlacementGroupModel +from dstack._internal.server.testing.common import ( + create_fleet, + create_placement_group, + create_project, +) + + +@pytest.fixture +def now() -> datetime: + return datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + + +@pytest.fixture +def heartbeater() -> Heartbeater[PlacementGroupModel]: + return Heartbeater( + model_type=PlacementGroupModel, + lock_timeout=timedelta(seconds=30), + heartbeat_trigger=timedelta(seconds=5), + ) + + +async def _create_locked_placement_group( + session: AsyncSession, + now: datetime, + lock_expires_in: timedelta, +) -> PlacementGroupModel: + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg", + ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = now + lock_expires_in + await session.commit() + return placement_group + + +def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem: + assert placement_group.lock_token is not None + assert placement_group.lock_expires_at is not None + return PipelineItem( + 
__tablename__=PlacementGroupModel.__tablename__, + id=placement_group.id, + lock_token=placement_group.lock_token, + lock_expires_at=placement_group.lock_expires_at, + prev_lock_expired=False, + ) + + +class TestHeartbeater: + @pytest.mark.asyncio + async def test_untrack_preserves_item_when_lock_token_mismatches( + self, heartbeater: Heartbeater[PlacementGroupModel], now: datetime + ): + item = PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, + id=uuid.uuid4(), + lock_token=uuid.uuid4(), + lock_expires_at=now + timedelta(seconds=10), + prev_lock_expired=True, + ) + await heartbeater.track(item) + + stale_item = PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, + id=item.id, + lock_token=uuid.uuid4(), + lock_expires_at=item.lock_expires_at, + prev_lock_expired=False, + ) + await heartbeater.untrack(stale_item) + + assert item.id in heartbeater._items + await heartbeater.untrack(item) + assert item.id not in heartbeater._items + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_extends_locks_close_to_expiration( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=2), + ) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + expected_lock_expires_at = now + timedelta(seconds=30) + tracked_item = heartbeater._items[placement_group.id] + assert tracked_item.lock_expires_at == expected_lock_expires_at + + await session.refresh(placement_group) + assert placement_group.lock_expires_at == expected_lock_expires_at + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_untracks_expired_items_without_db_update( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + original_lock_expires_at = now - timedelta(seconds=1) + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=-1), + ) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + assert placement_group.id not in heartbeater._items + + await session.refresh(placement_group) + assert placement_group.lock_expires_at == original_lock_expires_at + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_untracks_item_when_lock_token_changed_in_db( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + original_lock_expires_at = now + timedelta(seconds=2) + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=2), + ) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) + + new_lock_token = uuid.uuid4() + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.id == placement_group.id) + .values(lock_token=new_lock_token) + 
.execution_options(synchronize_session=False) + ) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + assert placement_group.id not in heartbeater._items + + await session.refresh(placement_group) + assert placement_group.lock_token == new_lock_token + assert placement_group.lock_expires_at == original_lock_expires_at diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py new file mode 100644 index 0000000000..6d24669f7c --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -0,0 +1,113 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem +from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupWorker +from dstack._internal.server.models import ComputeGroupModel +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_compute_group, + create_fleet, + create_project, +) + + +@pytest.fixture +def worker() -> ComputeGroupWorker: + return ComputeGroupWorker(queue=Mock(), heartbeater=Mock()) + + +def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> PipelineItem: + assert compute_group.lock_token is not None + assert compute_group.lock_expires_at is not None + return PipelineItem( + __tablename__=compute_group.__tablename__, + id=compute_group.id, + lock_token=compute_group.lock_token, + lock_expires_at=compute_group.lock_expires_at, + prev_lock_expired=False, + ) + + +class TestComputeGroupWorker: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_terminates_compute_group( + self, test_db, session: AsyncSession, worker: ComputeGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group( + session=session, + project=project, + fleet=fleet, + ) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + await worker.process(_compute_group_to_pipeline_item(compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED + assert compute_group.deleted + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retries_compute_group_termination( + self, test_db, session: AsyncSession, worker: ComputeGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await 
create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + ) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await worker.process(_compute_group_to_pipeline_item(compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status != ComputeGroupStatus.TERMINATED + assert compute_group.first_termination_retry_at is not None + assert compute_group.last_termination_retry_at is not None + # Simulate termination deadline exceeded + compute_group.first_termination_retry_at = datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc) + compute_group.last_termination_retry_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + compute_group.last_processed_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await worker.process(_compute_group_to_pipeline_item(compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py new file mode 100644 index 0000000000..87cab83e12 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -0,0 +1,63 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem +from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_fleet, + create_placement_group, + create_project, +) + + +@pytest.fixture +def worker() -> PlacementGroupWorker: + return PlacementGroupWorker(queue=Mock(), heartbeater=Mock()) + + +def _placement_group_to_pipeline_item(placement_group) -> PipelineItem: + assert placement_group.lock_token is not None + assert placement_group.lock_expires_at is not None + return PipelineItem( + __tablename__=placement_group.__tablename__, + id=placement_group.id, + lock_token=placement_group.lock_token, + lock_expires_at=placement_group.lock_expires_at, + prev_lock_expired=False, + ) + + +class TestPlacementGroupWorker: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def 
test_deletes_placement_group( + self, test_db, session: AsyncSession, worker: PlacementGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test1-pg", + fleet_deleted=True, + ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + aws_mock = Mock() + m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + await worker.process(_placement_group_to_pipeline_item(placement_group)) + aws_mock.compute.return_value.delete_placement_group.assert_called_once() + await session.refresh(placement_group) + assert placement_group.deleted diff --git a/src/tests/_internal/server/background/tasks/__init__.py b/src/tests/_internal/server/background/scheduled_tasks/__init__.py similarity index 100% rename from src/tests/_internal/server/background/tasks/__init__.py rename to src/tests/_internal/server/background/scheduled_tasks/__init__.py diff --git a/src/tests/_internal/server/background/tasks/test_process_compute_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_compute_groups.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_compute_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_compute_groups.py index 11ce734606..b2b1920199 100644 --- a/src/tests/_internal/server/background/tasks/test_process_compute_groups.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_compute_groups.py @@ -6,8 +6,8 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.server.background.tasks.process_compute_groups import ( - ComputeGroupStatus, +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( process_compute_groups, ) from dstack._internal.server.testing.common import ( diff --git a/src/tests/_internal/server/background/tasks/test_process_events.py b/src/tests/_internal/server/background/scheduled_tasks/test_events.py similarity index 94% rename from src/tests/_internal/server/background/tasks/test_process_events.py rename to src/tests/_internal/server/background/scheduled_tasks/test_events.py index 21043e0bae..91eb066f58 100644 --- a/src/tests/_internal/server/background/tasks/test_process_events.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_events.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_events import delete_events +from dstack._internal.server.background.scheduled_tasks.events import delete_events from dstack._internal.server.services import events from dstack._internal.server.testing.common import create_user, list_events diff --git a/src/tests/_internal/server/background/tasks/test_process_fleets.py b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_fleets.py rename to src/tests/_internal/server/background/scheduled_tasks/test_fleets.py index 
ae7155c3ca..2ef1b27ab9 100644 --- a/src/tests/_internal/server/background/tasks/test_process_fleets.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py @@ -6,7 +6,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.background.tasks.process_fleets import process_fleets +from dstack._internal.server.background.scheduled_tasks.fleets import process_fleets from dstack._internal.server.models import InstanceModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_gateways.py rename to src/tests/_internal/server/background/scheduled_tasks/test_gateways.py index b280b8948d..5f19d2cfcd 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py @@ -5,7 +5,7 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus -from dstack._internal.server.background.tasks.process_gateways import process_gateways +from dstack._internal.server.background.scheduled_tasks.gateways import process_gateways from dstack._internal.server.testing.common import ( AsyncContextManager, ComputeMockSpec, diff --git a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_idle_volumes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py index 9d73afbb78..6a7acf0c43 100644 --- a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py @@ -6,7 +6,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.background.tasks.process_idle_volumes import ( +from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( _get_idle_time, _should_delete_volume, process_idle_volumes, diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_instances.py rename to src/tests/_internal/server/background/scheduled_tasks/test_instances.py index 8d94ee059b..1b9789953e 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py @@ -39,7 +39,7 @@ JobProvisioningData, JobStatus, ) -from dstack._internal.server.background.tasks.process_instances import ( +from dstack._internal.server.background.scheduled_tasks.instances import ( delete_instance_health_checks, process_instances, ) @@ -101,7 +101,7 @@ async def test_check_shim_transitions_provisioning_on_ready( await session.commit() with patch( - 
"dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -130,7 +130,7 @@ async def test_check_shim_transitions_provisioning_on_terminating( health_reason = "Shim problem" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_reason) await process_instances() @@ -177,7 +177,7 @@ async def test_check_shim_transitions_provisioning_on_busy( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -202,7 +202,7 @@ async def test_check_shim_start_termination_deadline(self, test_db, session: Asy ) health_status = "SSH connection fail" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -232,7 +232,7 @@ async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( ) health_status = "SSH connection fail" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -257,7 +257,7 @@ async def test_check_shim_stop_termination_deadline(self, test_db, session: Asyn await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -283,7 +283,7 @@ async def test_check_shim_terminate_instance_by_deadline(self, test_db, session: health_status = "Not ok" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -347,7 +347,7 @@ async def test_check_shim_process_ureachable_state( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -378,7 +378,7 @@ async def test_check_shim_switch_to_unreachable_state( ) with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False) await 
process_instances() @@ -412,7 +412,7 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes ) with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck( reachable=True, health_response=health_response @@ -440,7 +440,7 @@ class TestRemoveDanglingTasks: @pytest.fixture def disable_maybe_install_components(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._maybe_install_components", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_components", Mock(return_value=None), ) @@ -607,7 +607,7 @@ def mock_terminate_in_backend(error: Optional[Exception] = None): if error is not None: terminate_instance.side_effect = error with patch( - "dstack._internal.server.background.tasks.process_instances.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.scheduled_tasks.instances.backends_services.get_project_backend_by_type" ) as get_backend: get_backend.return_value = backend yield terminate_instance @@ -1153,7 +1153,7 @@ def host_info(self) -> dict: def deploy_instance_mock(self, monkeypatch: pytest.MonkeyPatch, host_info: dict): mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, GoArchType.AMD64)) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._deploy_instance", mock + "dstack._internal.server.background.scheduled_tasks.instances._deploy_instance", mock ) return mock @@ -1262,7 +1262,7 @@ def component_list(self) -> ComponentList: def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: caplog.set_level( level=logging.DEBUG, - logger="dstack._internal.server.background.tasks.process_instances", + logger="dstack._internal.server.background.scheduled_tasks.instances", ) return caplog @@ -1308,7 +1308,7 @@ def component_list(self) -> ComponentList: def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=self.EXPECTED_VERSION) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_runner_version", mock, ) return mock @@ -1317,7 +1317,7 @@ def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Moc def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value="https://example.com/runner") monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_download_url", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_runner_download_url", mock, ) return mock @@ -1424,7 +1424,7 @@ def component_list(self) -> ComponentList: def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=self.EXPECTED_VERSION) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_version", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_shim_version", mock, ) return mock @@ -1433,7 +1433,7 @@ def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = 
Mock(return_value="https://example.com/shim") monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_download_url", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_shim_download_url", mock, ) return mock @@ -1547,7 +1547,7 @@ def component_list(self) -> ComponentList: def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=False) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._maybe_install_runner", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_runner", mock, ) return mock @@ -1556,7 +1556,7 @@ def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=False) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._maybe_install_shim", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_shim", mock, ) return mock diff --git a/src/tests/_internal/server/background/tasks/test_process_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_metrics.py index 0be650a223..df52dd88e2 100644 --- a/src/tests/_internal/server/background/tasks/test_process_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py @@ -10,7 +10,7 @@ from dstack._internal.core.models.runs import JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_metrics import ( +from dstack._internal.server.background.scheduled_tasks.metrics import ( collect_metrics, delete_metrics, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_placement_groups.py similarity index 94% rename from src/tests/_internal/server/background/tasks/test_process_placement_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_placement_groups.py index a45051a48e..14b9d2189d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_placement_groups.py @@ -3,7 +3,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.server.background.tasks.process_placement_groups import ( +from dstack._internal.server.background.scheduled_tasks.placement_groups import ( process_placement_groups, ) from dstack._internal.server.testing.common import ( diff --git a/src/tests/_internal/server/background/tasks/test_process_probes.py b/src/tests/_internal/server/background/scheduled_tasks/test_probes.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_probes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_probes.py index 928709dd7f..bfd569ab1b 100644 --- a/src/tests/_internal/server/background/tasks/test_process_probes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_probes.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.configurations import ProbeConfig, ServiceConfiguration from dstack._internal.core.models.instances import 
InstanceStatus from dstack._internal.core.models.runs import JobStatus -from dstack._internal.server.background.tasks.process_probes import ( +from dstack._internal.server.background.scheduled_tasks.probes import ( PROCESSING_OVERHEAD_TIMEOUT, SSH_CONNECT_TIMEOUT, process_probes, @@ -140,7 +140,7 @@ async def test_schedules_probe_execution(self, test_db, session: AsyncSession) - processing_time = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc) with freeze_time(processing_time): with patch( - "dstack._internal.server.background.tasks.process_probes.PROBES_SCHEDULER" + "dstack._internal.server.background.scheduled_tasks.probes.PROBES_SCHEDULER" ) as scheduler_mock: await process_probes() assert scheduler_mock.add_job.call_count == 2 @@ -210,7 +210,7 @@ async def test_deactivates_probe_when_until_ready_and_ready_after_reached( probe_regular = await create_probe(session, job, probe_num=1, success_streak=3) with patch( - "dstack._internal.server.background.tasks.process_probes.PROBES_SCHEDULER" + "dstack._internal.server.background.scheduled_tasks.probes.PROBES_SCHEDULER" ) as scheduler_mock: await process_probes() diff --git a/src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py index 7c59b6dd1f..80961d5c10 100644 --- a/src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py @@ -11,7 +11,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.background.tasks.process_prometheus_metrics import ( +from dstack._internal.server.background.scheduled_tasks.prometheus_metrics import ( collect_prometheus_metrics, delete_prometheus_metrics, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_running_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 12edeec208..0d748f4e91 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -37,7 +37,7 @@ VolumeStatus, ) from dstack._internal.server import settings as server_settings -from dstack._internal.server.background.tasks.process_running_jobs import ( +from dstack._internal.server.background.scheduled_tasks.running_jobs import ( _patch_base_image_for_aws_efa, process_running_jobs, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/scheduled_tasks/test_runs.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_runs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_runs.py index b9420d8e9a..ffb63de358 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_runs.py @@ -8,7 +8,7 @@ from pydantic import 
parse_obj_as from sqlalchemy.ext.asyncio import AsyncSession -import dstack._internal.server.background.tasks.process_runs as process_runs +import dstack._internal.server.background.scheduled_tasks.runs as process_runs from dstack._internal.core.models.configurations import ( ProbeConfig, ServiceConfiguration, @@ -100,7 +100,7 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): expected_duration = (current_time - run.submitted_at).total_seconds() with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -131,7 +131,7 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession): await create_job(session=session, run=run, status=JobStatus.PULLING) with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -198,7 +198,7 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession): with ( patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics, ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) @@ -297,7 +297,7 @@ async def test_submitted_to_provisioning_if_any(self, test_db, session: AsyncSes expected_duration = (current_time - run.submitted_at).total_seconds() with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -351,7 +351,7 @@ async def test_all_no_capacity_to_pending(self, test_db, session: AsyncSession): with ( patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics, ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index 8a3a4b1d57..b06eb50ec2 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -27,7 +27,7 @@ VolumeMountPoint, VolumeStatus, ) -from dstack._internal.server.background.tasks.process_submitted_jobs import ( +from dstack._internal.server.background.scheduled_tasks.submitted_jobs import ( _prepare_job_runtime_data, process_submitted_jobs, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_volumes.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_submitted_volumes.py index dfeef1e42e..8c9a6bf3cf 
100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_volumes.py @@ -5,7 +5,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus -from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes +from dstack._internal.server.background.scheduled_tasks.volumes import process_submitted_volumes from dstack._internal.server.testing.common import ( ComputeMockSpec, create_project, diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py index 1d1c143d4f..d2b4d2d318 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py @@ -10,7 +10,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus, JobTerminationReason from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.background.tasks.process_terminating_jobs import ( +from dstack._internal.server.background.scheduled_tasks.terminating_jobs import ( process_terminating_jobs, ) from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 0c5ca338df..25cbbead36 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -12,6 +12,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal import settings from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import ( @@ -71,7 +72,6 @@ list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str -from tests._internal.server.background.tasks.test_process_running_jobs import settings pytestmark = pytest.mark.usefixtures("image_config_mock")
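For context on how the pieces in this diff fit together at runtime: pipeline processing is opt-in via DSTACK_FF_PIPELINE_PROCESSING_ENABLED (settings.py above), and request handlers can nudge a pipeline to fetch sooner through the get_pipeline_hinter dependency added in services/pipelines.py. The sketch below is an editorial illustration, not code from this PR; the router, route path, and the "compute_groups" value passed to hint_fetch are assumptions, and it presumes app.state.pipeline_manager has been set (i.e., background processing and the feature flag are enabled).

# Editorial sketch, not part of the diff: using the pipeline hinter from a route.
# Router, path, and the model name argument are assumptions for illustration.
from fastapi import APIRouter, Depends

from dstack._internal.server.services.pipelines import (
    PipelineHinterProtocol,
    get_pipeline_hinter,
)

router = APIRouter()


@router.post("/api/project/{project_name}/fleets/apply")
async def apply_fleet(
    project_name: str,
    hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter),
):
    # ... create or update the compute group row here ...
    # Hint the pipeline so the new row is picked up before the next poll interval.
    hinter.hint_fetch("compute_groups")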