diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index 0f89b770f..280ab14f1 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -17,6 +17,8 @@ class VolumeStatus(str, Enum): SUBMITTED = "submitted" + # PROVISIONING is currently not used since on all backends supporting volumes, + # volumes become ACTIVE (ready to be used) almost immediately after provisioning. PROVISIONING = "provisioning" ACTIVE = "active" FAILED = "failed" diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 01feb958d..d9f67680c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -6,6 +6,7 @@ from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) +from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -17,6 +18,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), GatewayPipeline(), PlacementGroupPipeline(), + VolumePipeline(), ] self._hinter = PipelineHinter(self._pipelines) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 938c6013c..33e839b8b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -27,6 +27,7 @@ from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group from dstack._internal.server.services.instances import emit_instance_status_change_event from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -107,6 +108,7 @@ def __init__( queue_check_delay=queue_check_delay, ) + @sentry_utils.instrument_named_task("pipeline_tasks.ComputeGroupFetcher.fetch") async def fetch(self, limit: int) -> list[PipelineItem]: compute_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( ComputeGroupModel.__tablename__ @@ -172,6 +174,7 @@ def __init__( heartbeater=heartbeater, ) + @sentry_utils.instrument_named_task("pipeline_tasks.ComputeGroupWorker.process") async def process(self, item: PipelineItem): async with get_session_ctx() as session: res = await session.execute( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index c64cd719a..cdd0904e1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -35,6 +35,7 @@ from dstack._internal.server.services.gateways.pool import gateway_connections_pool from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt +from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -118,6 +119,7 @@ def __init__( queue_check_delay=queue_check_delay, ) + 
@sentry_utils.instrument_named_task("pipeline_tasks.GatewayFetcher.fetch") async def fetch(self, limit: int) -> list[GatewayPipelineItem]: gateway_lock, _ = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) async with gateway_lock: @@ -193,6 +195,7 @@ def __init__( heartbeater=heartbeater, ) + @sentry_utils.instrument_named_task("pipeline_tasks.GatewayWorker.process") async def process(self, item: GatewayPipelineItem): if item.to_be_deleted: await _process_to_be_deleted_item(item) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index a184379c3..193358ec0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -26,6 +26,7 @@ from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -103,6 +104,7 @@ def __init__( queue_check_delay=queue_check_delay, ) + @sentry_utils.instrument_named_task("pipeline_tasks.PlacementGroupFetcher.fetch") async def fetch(self, limit: int) -> list[PipelineItem]: placement_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( PlacementGroupModel.__tablename__ @@ -170,6 +172,7 @@ def __init__( heartbeater=heartbeater, ) + @sentry_utils.instrument_named_task("pipeline_tasks.PlacementGroupWorker.process") async def process(self, item: PipelineItem): async with get_session_ctx() as session: res = await session.execute( @@ -230,6 +233,7 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> backend_type=placement_group.provisioning_data.backend, ) if backend is None: + # TODO: Retry deletion logger.error( "Failed to delete placement group %s. Backend not available. Please delete it manually.", placement_group.name, @@ -245,6 +249,7 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> ) return {} except Exception: + # TODO: Retry deletion logger.exception( "Got exception when deleting placement group %s. 
Please delete it manually.", placement_group.name, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py new file mode 100644 index 000000000..578fe8423 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -0,0 +1,448 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Sequence + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.errors import BackendError, BackendNotAvailable +from dstack._internal.core.models.volumes import VolumeStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + ProjectModel, + UserModel, + VolumeAttachmentModel, + VolumeModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.volumes import ( + emit_volume_status_change_event, + volume_model_to_volume, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class VolumePipelineItem(PipelineItem): + status: VolumeStatus + to_be_deleted: bool + + +class VolumePipeline(Pipeline[VolumePipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[VolumePipelineItem]( + model_type=VolumeModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = VolumeFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + VolumeWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return VolumeModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[VolumePipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[VolumePipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["VolumeWorker"]: + return self.__workers + + +class VolumeFetcher(Fetcher[VolumePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[VolumePipelineItem], + queue_desired_minsize: int, + 
min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[VolumePipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.VolumeFetcher.fetch") + async def fetch(self, limit: int) -> list[VolumePipelineItem]: + volume_lock, _ = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) + async with volume_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(VolumeModel) + .where( + or_( + VolumeModel.status == VolumeStatus.SUBMITTED, + VolumeModel.to_be_deleted == True, + ), + VolumeModel.deleted == False, + or_( + VolumeModel.last_processed_at <= now - self._min_processing_interval, + VolumeModel.last_processed_at == VolumeModel.created_at, + ), + or_( + VolumeModel.lock_expires_at.is_(None), + VolumeModel.lock_expires_at < now, + ), + or_( + VolumeModel.lock_owner.is_(None), + VolumeModel.lock_owner == VolumePipeline.__name__, + ), + ) + .order_by(VolumeModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + VolumeModel.id, + VolumeModel.lock_token, + VolumeModel.lock_expires_at, + VolumeModel.status, + VolumeModel.to_be_deleted, + ) + ) + ) + volume_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for volume_model in volume_models: + prev_lock_expired = volume_model.lock_expires_at is not None + volume_model.lock_expires_at = lock_expires_at + volume_model.lock_token = lock_token + volume_model.lock_owner = VolumePipeline.__name__ + items.append( + VolumePipelineItem( + __tablename__=VolumeModel.__tablename__, + id=volume_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=volume_model.status, + to_be_deleted=volume_model.to_be_deleted, + ) + ) + await session.commit() + return items + + +class VolumeWorker(Worker[VolumePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[VolumePipelineItem], + heartbeater: Heartbeater[VolumePipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.VolumeWorker.process") + async def process(self, item: VolumePipelineItem): + if item.to_be_deleted: + await _process_to_be_deleted_item(item) + elif item.status == VolumeStatus.SUBMITTED: + await _process_submitted_item(item) + elif item.status == VolumeStatus.ACTIVE: + pass + + +async def _process_submitted_item(item: VolumePipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.id == item.id, + VolumeModel.lock_token == item.lock_token, + ) + .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(VolumeModel.user)) + .options( + joinedload(VolumeModel.attachments) + .joinedload(VolumeAttachmentModel.instance) + .joinedload(InstanceModel.fleet) + .load_only(FleetModel.name) + ) + ) + volume_model = res.unique().scalar_one_or_none() + if volume_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." 
+ " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + result = await _process_submitted_volume(volume_model) + update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + async with get_session_ctx() as session: + res = await session.execute( + update(VolumeModel) + .where( + VolumeModel.id == volume_model.id, + VolumeModel.lock_token == volume_model.lock_token, + ) + .values(**update_map) + .returning(VolumeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + # TODO: Clean up volume. + return + emit_volume_status_change_event( + session=session, + volume_model=volume_model, + old_status=volume_model.status, + new_status=update_map.get("status", volume_model.status), + status_message=update_map.get("status_message", volume_model.status_message), + ) + + +@dataclass +class _SubmittedResult: + update_map: UpdateMap = field(default_factory=dict) + + +async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResult: + volume = volume_model_to_volume(volume_model) + try: + backend = await backends_services.get_project_backend_by_type_or_error( + project=volume_model.project, + backend_type=volume.configuration.backend, + overrides=True, + ) + except BackendNotAvailable: + logger.error( + "Failed to process volume %s. Backend %s not available.", + volume.name, + volume.configuration.backend.value, + ) + return _SubmittedResult( + update_map={ + "status": VolumeStatus.FAILED, + "status_message": "Backend not available", + } + ) + + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + if volume.configuration.volume_id is not None: + logger.info("Registering external volume %s", volume_model.name) + vpd = await run_async( + compute.register_volume, + volume=volume, + ) + else: + logger.info("Provisioning new volume %s", volume_model.name) + vpd = await run_async( + compute.create_volume, + volume=volume, + ) + except BackendError as e: + logger.info("Failed to create volume %s: %s", volume_model.name, repr(e)) + status_message = f"Backend error: {repr(e)}" + if len(e.args) > 0: + status_message = str(e.args[0]) + return _SubmittedResult( + update_map={ + "status": VolumeStatus.FAILED, + "status_message": status_message, + } + ) + except Exception as e: + logger.exception("Got exception when creating volume %s", volume_model.name) + return _SubmittedResult( + update_map={ + "status": VolumeStatus.FAILED, + "status_message": f"Unexpected error: {repr(e)}", + } + ) + + logger.info("Added new volume %s", volume_model.name) + # Provisioned volumes marked as active since they become available almost immediately in AWS + # TODO: Consider checking volume state + return _SubmittedResult( + update_map={ + "status": VolumeStatus.ACTIVE, + "volume_provisioning_data": vpd.json(), + } + ) + + +async def _process_to_be_deleted_item(item: VolumePipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.id == item.id, + VolumeModel.lock_token == item.lock_token, + ) + .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(VolumeModel.user).load_only(UserModel.name)) + .options( + 
joinedload(VolumeModel.attachments) + .joinedload(VolumeAttachmentModel.instance) + .joinedload(InstanceModel.fleet) + .load_only(FleetModel.name) + ) + ) + volume_model = res.unique().scalar_one_or_none() + if volume_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + result = await _process_to_be_deleted_volume(volume_model) + update_map = result.update_map | get_unlock_update_map() + async with get_session_ctx() as session: + res = await session.execute( + update(VolumeModel) + .where( + VolumeModel.id == volume_model.id, + VolumeModel.lock_token == volume_model.lock_token, + ) + .values(**update_map) + .returning(VolumeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + events.emit( + session, + "Volume deleted", + actor=events.SystemActor(), + targets=[events.Target.from_model(volume_model)], + ) + + +@dataclass +class _DeletedResult: + update_map: UpdateMap = field(default_factory=dict) + + +async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedResult: + volume = volume_model_to_volume(volume_model) + if volume.external: + return _get_deleted_result() + if volume.provisioning_data is None: + # The volume wasn't provisioned so there is nothing to delete + return _get_deleted_result() + if volume.provisioning_data.backend is None: + logger.error( + f"Failed to delete volume {volume_model.name}. volume.provisioning_data.backend is None." + ) + return _get_deleted_result() + try: + backend = await backends_services.get_project_backend_by_type_or_error( + project=volume_model.project, + backend_type=volume.provisioning_data.backend, + ) + except BackendNotAvailable: + # TODO: Retry deletion + logger.error( + f"Failed to delete volume {volume_model.name}. Backend {volume.configuration.backend} not available." + " Please terminate it manually to avoid unexpected charges.", + ) + return _get_deleted_result() + + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + await run_async( + compute.delete_volume, + volume=volume, + ) + except Exception: + # TODO: Retry deletion + logger.exception( + "Got exception when deleting volume %s. 
Please terminate it manually to avoid unexpected charges.", + volume.name, + ) + return _get_deleted_result() + + +def _get_deleted_result() -> _DeletedResult: + now = get_current_datetime() + return _DeletedResult( + update_map={ + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + } + ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 6067d9d4d..45ae8ec7f 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -99,9 +99,6 @@ def start_scheduled_tasks() -> AsyncIOScheduler: ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job( - process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 - ) _scheduler.add_job( process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) @@ -116,6 +113,9 @@ def start_scheduled_tasks() -> AsyncIOScheduler: process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) + _scheduler.add_job( + process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + ) for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): # Add multiple copies of tasks if requested. # max_instances=1 for additional copies to avoid running too many tasks. diff --git a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py index 6b449efab..feb1cc507 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py @@ -39,7 +39,7 @@ async def process_compute_groups(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_compute_group(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(ComputeGroupModel.__tablename__) async with get_session_ctx() as session: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/events.py b/src/dstack/_internal/server/background/scheduled_tasks/events.py index 22df5bcf3..1fbf60217 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/events.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/events.py @@ -9,7 +9,7 @@ from dstack._internal.utils.common import get_current_datetime -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def delete_events(): cutoff = get_current_datetime() - timedelta(seconds=settings.SERVER_EVENTS_TTL_SECONDS) stmt = delete(EventModel).where(EventModel.recorded_at < cutoff) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py index 50c3dcfe2..a758f86ad 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py @@ -39,7 +39,7 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def process_fleets(): fleet_lock, fleet_lockset = 
get_locker(get_db().dialect_name).get_lockset( FleetModel.__tablename__ diff --git a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py index 3b6bee012..fc12e8e3b 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py @@ -35,7 +35,7 @@ async def process_gateways_connections(): await _process_active_connections() -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def process_gateways(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) async with get_session_ctx() as session: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py index cd5b66bc7..c77039013 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py @@ -19,6 +19,7 @@ volume_model_to_volume, ) from dstack._internal.server.utils import sentry_utils +from dstack._internal.settings import FeatureFlags from dstack._internal.utils import common from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -26,7 +27,7 @@ logger = get_logger(__name__) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def process_idle_volumes(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) async with get_session_ctx() as session: @@ -35,7 +36,9 @@ async def process_idle_volumes(): select(VolumeModel.id) .where( VolumeModel.status == VolumeStatus.ACTIVE, + VolumeModel.auto_cleanup_enabled.is_not(False), VolumeModel.deleted == False, + VolumeModel.lock_expires_at.is_(None), VolumeModel.id.not_in(lockset), ) .order_by(VolumeModel.last_processed_at.asc()) @@ -90,23 +93,31 @@ def _get_idle_time(volume: VolumeModel) -> datetime.timedelta: async def _delete_idle_volumes(session: AsyncSession, volumes: List[VolumeModel]): - # Note: Multiple volumes are deleted in the same transaction, - # so long deletion of one volume may block processing other volumes. for volume_model in volumes: logger.info("Deleting idle volume %s", volume_model.name) - try: - await _delete_idle_volume(session, volume_model) - except Exception: - logger.exception("Error when deleting idle volume %s", volume_model.name) - - volume_model.deleted = True - volume_model.deleted_at = get_current_datetime() - events.emit( - session=session, - message="Volume deleted due to exceeding auto_cleanup_duration", - actor=events.SystemActor(), - targets=[events.Target.from_model(volume_model)], - ) + if FeatureFlags.PIPELINE_PROCESSING_ENABLED: + volume_model.to_be_deleted = True + events.emit( + session=session, + message="Volume marked for deletion due to exceeding auto_cleanup_duration", + actor=events.SystemActor(), + targets=[events.Target.from_model(volume_model)], + ) + else: + try: + # Note: Multiple volumes are deleted in the same transaction, + # so long deletion of one volume may block processing other volumes. 
+ await _delete_idle_volume(session, volume_model) + except Exception: + logger.exception("Error when deleting idle volume %s", volume_model.name) + volume_model.deleted = True + volume_model.deleted_at = get_current_datetime() + events.emit( + session=session, + message="Volume deleted due to exceeding auto_cleanup_duration", + actor=events.SystemActor(), + targets=[events.Target.from_model(volume_model)], + ) await session.commit() diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 196f347c4..e5ecba527 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -152,7 +152,7 @@ async def process_instances(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def delete_instance_health_checks(): now = get_current_datetime() cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) @@ -163,7 +163,7 @@ async def delete_instance_health_checks(): await session.commit() -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_instance(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) async with get_session_ctx() as session: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py index ca2d25fe5..f75c5f3ea 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py @@ -27,7 +27,7 @@ MIN_COLLECT_INTERVAL_SECONDS = 9 -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def collect_metrics(): async with get_session_ctx() as session: res = await session.execute( @@ -47,7 +47,7 @@ async def collect_metrics(): await _collect_jobs_metrics(batch) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def delete_metrics(): now_timestamp_micro = int(get_current_datetime().timestamp() * 1_000_000) running_timestamp_micro_cutoff = ( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py index 1f6130001..71ab51b07 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py @@ -19,7 +19,7 @@ logger = get_logger(__name__) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def process_placement_groups(): lock, lockset = get_locker(get_db().dialect_name).get_lockset( PlacementGroupModel.__tablename__ diff --git a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py index 2f8bf7214..5b039fe2e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py @@ -35,7 +35,7 @@ METRICS_TTL_SECONDS = 600 -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def collect_prometheus_metrics(): now = get_current_datetime() cutoff = now - 
timedelta(seconds=MIN_COLLECT_INTERVAL_SECONDS) @@ -63,7 +63,7 @@ async def collect_prometheus_metrics(): await _collect_jobs_metrics(batch, now) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def delete_prometheus_metrics(): now = get_current_datetime() cutoff = now - timedelta(seconds=METRICS_TTL_SECONDS) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index f413edf44..9d3bd04c3 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -103,7 +103,7 @@ async def process_running_jobs(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_running_job(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) async with get_session_ctx() as session: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index e9421e5cb..e0c6793ce 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -76,7 +76,7 @@ async def process_runs(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_run(): run_lock, run_lockset = get_locker(get_db().dialect_name).get_lockset(RunModel.__tablename__) job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 79746e933..5d1b2e1a7 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -164,7 +164,7 @@ def _get_effective_batch_size(batch_size: int) -> int: return batch_size -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_submitted_job(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) async with get_session_ctx() as session: @@ -1042,10 +1042,15 @@ async def _attach_volumes( ) job_runtime_data.volume_names.append(volume.name) break # attach next mount point - except (ServerClientError, BackendError) as e: - logger.warning("%s: failed to attached volume: %s", fmt(job_model), repr(e)) + except ServerClientError as e: + logger.info("%s: failed to attach volume: %s", fmt(job_model), repr(e)) job_model.termination_reason = JobTerminationReason.VOLUME_ERROR - job_model.termination_reason_message = "Failed to attach volume" + job_model.termination_reason_message = f"Failed to attach volume: {e.msg}" + switch_job_status(session, job_model, JobStatus.TERMINATING) + except BackendError as e: + logger.warning("%s: failed to attach volume: %s", fmt(job_model), repr(e)) + job_model.termination_reason = JobTerminationReason.VOLUME_ERROR + job_model.termination_reason_message = f"Failed to attach volume: {str(e)}" switch_job_status(session, job_model, JobStatus.TERMINATING) except Exception: logger.exception( @@ -1053,7 +1058,7 @@ async def _attach_volumes( fmt(job_model), ) job_model.termination_reason = 
JobTerminationReason.VOLUME_ERROR - job_model.termination_reason_message = "Failed to attach volume" + job_model.termination_reason_message = "Failed to attach volume: unexpected error" switch_job_status(session, job_model, JobStatus.TERMINATING) finally: job_model.job_runtime_data = job_runtime_data.json() @@ -1069,10 +1074,14 @@ async def _attach_volume( compute = backend.compute() assert isinstance(compute, ComputeWithVolumeSupport) volume = volume_model_to_volume(volume_model) - # Refresh only to check if the volume wasn't deleted before the lock + # Refresh only to check if the volume wasn't deleted or marked for deletion before the lock await session.refresh(volume_model) if volume_model.deleted: raise ServerClientError("Cannot attach a deleted volume") + if volume_model.to_be_deleted: + raise ServerClientError("Cannot attach a volume marked for deletion") + if volume_model.lock_expires_at is not None: + raise ServerClientError("Cannot attach a volume locked for processing") attachment_data = await common_utils.run_async( compute.attach_volume, volume=volume, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 6a358dcd6..3749076c1 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -35,7 +35,7 @@ async def process_terminating_jobs(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def _process_next_terminating_job(): job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py index 66124619a..a61f79694 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py @@ -24,7 +24,7 @@ logger = get_logger(__name__) -@sentry_utils.instrument_background_task +@sentry_utils.instrument_scheduled_task async def process_submitted_volumes(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) async with get_session_ctx() as session: @@ -33,6 +33,7 @@ async def process_submitted_volumes(): select(VolumeModel) .where( VolumeModel.status == VolumeStatus.SUBMITTED, + VolumeModel.deleted == False, VolumeModel.id.not_in(lockset), ) .order_by(VolumeModel.last_processed_at.asc()) diff --git a/src/dstack/_internal/server/migrations/versions/2026/02_23_1134_ccfac6ac7924_add_volumemodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/02_23_1134_ccfac6ac7924_add_volumemodel_pipeline_columns.py new file mode 100644 index 000000000..4034f227e --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/02_23_1134_ccfac6ac7924_add_volumemodel_pipeline_columns.py @@ -0,0 +1,53 @@ +"""Add VolumeModel pipeline columns + +Revision ID: ccfac6ac7924 +Revises: 140331002ece +Create Date: 2026-02-23 11:34:24.731339+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. 
+revision = "ccfac6ac7924" +down_revision = "140331002ece" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.add_column( + sa.Column("to_be_deleted", sa.Boolean(), server_default=sa.false(), nullable=False) + ) + batch_op.add_column(sa.Column("auto_cleanup_enabled", sa.Boolean(), nullable=True)) + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + batch_op.drop_column("auto_cleanup_enabled") + batch_op.drop_column("to_be_deleted") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/02_24_0945_9a363c3cbe04_add_ix_volumes_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/02_24_0945_9a363c3cbe04_add_ix_volumes_pipeline_fetch_q_index.py new file mode 100644 index 000000000..1d729dbbd --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/02_24_0945_9a363c3cbe04_add_ix_volumes_pipeline_fetch_q_index.py @@ -0,0 +1,50 @@ +"""Add ix_volumes_pipeline_fetch_q index + +Revision ID: 9a363c3cbe04 +Revises: ccfac6ac7924 +Create Date: 2026-02-24 09:45:54.068288+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9a363c3cbe04" +down_revision = "ccfac6ac7924" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_volumes_pipeline_fetch_q", + table_name="volumes", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_volumes_pipeline_fetch_q", + "volumes", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_volumes_pipeline_fetch_q", + table_name="volumes", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index df9cf8607..a7a8ec0bd 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -732,7 +732,7 @@ class InstanceHealthCheckModel(BaseModel): response: Mapped[str] = mapped_column(Text) -class VolumeModel(BaseModel): +class VolumeModel(PipelineModelMixin, BaseModel): __tablename__ = "volumes" id: Mapped[uuid.UUID] = mapped_column( @@ -753,6 +753,7 @@ class VolumeModel(BaseModel): last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + to_be_deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) # NOTE: `status` must be changed only via `switch_volume_status()` status: Mapped[VolumeStatus] = mapped_column(EnumAsString(VolumeStatus, 100), index=True) @@ -760,12 +761,23 @@ class VolumeModel(BaseModel): configuration: Mapped[str] = mapped_column(Text) volume_provisioning_data: Mapped[Optional[str]] = mapped_column(Text) + # auto_cleanup_enabled is set for all new models but old models may not have it. + auto_cleanup_enabled: Mapped[Optional[bool]] = mapped_column(Boolean) attachments: Mapped[List["VolumeAttachmentModel"]] = relationship(back_populates="volume") # Deprecated in favor of VolumeAttachmentModel.attachment_data volume_attachment_data: Mapped[Optional[str]] = mapped_column(Text) + __table_args__ = ( + Index( + "ix_volumes_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class VolumeAttachmentModel(BaseModel): __tablename__ = "volumes_attachments" diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index ead5465c4..fccb2fd47 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -15,6 +15,7 @@ ListVolumesRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember +from dstack._internal.server.services.pipelines import PipelineHinterProtocol, get_pipeline_hinter from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -92,6 +93,7 @@ async def create_volume( body: CreateVolumeRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): """ Creates a volume given a volume configuration. 
@@ -103,6 +105,7 @@ async def create_volume( project=project, user=user, configuration=body.configuration, + pipeline_hinter=pipeline_hinter, ) ) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 2ddadbfb1..eb10bda5c 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -720,6 +720,10 @@ async def get_job_configured_volume_models( ) if volume_model is None: raise ResourceNotExistsError(f"Volume {mount_point.name} not found") + if volume_model.to_be_deleted: + raise ServerClientError( + f"Volume {mount_point.name} is marked for deletion and cannot be attached" + ) mount_point_volume_models.append(volume_model) volume_models.append(mount_point_volume_models) return volume_models @@ -729,7 +733,7 @@ def check_can_attach_job_volumes(volumes: List[List[Volume]]): """ Performs basic checks if volumes can be attached. This is useful to show error ASAP (when user submits the run). - If the attachment is to fail anyway, the error will be handled when proccessing submitted jobs. + If the attachment is to fail anyway, the error will be handled when processing submitted jobs. """ if len(volumes) == 0: return diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 49a3d7959..f0d2fc703 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -13,6 +13,7 @@ ResourceExistsError, ServerClientError, ) +from dstack._internal.core.models.profiles import parse_duration from dstack._internal.core.models.volumes import ( Volume, VolumeAttachment, @@ -39,8 +40,10 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import list_user_project_models +from dstack._internal.settings import FeatureFlags from dstack._internal.utils import common, random_names from dstack._internal.utils.logging import get_logger @@ -58,13 +61,45 @@ def switch_volume_status( return volume_model.status = new_status + emit_volume_status_change_event( + session=session, + volume_model=volume_model, + old_status=old_status, + new_status=new_status, + status_message=volume_model.status_message, + actor=actor, + ) - msg = f"Volume status changed {old_status.upper()} -> {new_status.upper()}" - if volume_model.status_message is not None: - msg += f" ({volume_model.status_message})" + +def emit_volume_status_change_event( + session: AsyncSession, + volume_model: VolumeModel, + old_status: VolumeStatus, + new_status: VolumeStatus, + status_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + msg = get_volume_status_change_message( + old_status=old_status, + new_status=new_status, + status_message=status_message, + ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(volume_model)]) +def get_volume_status_change_message( + old_status: VolumeStatus, + new_status: VolumeStatus, + status_message: Optional[str], +) -> str: + msg = f"Volume status changed {old_status.upper()} -> {new_status.upper()}" + if status_message is not None: + msg += f" ({status_message})" + return msg + + async def list_volumes( session: AsyncSession, user: UserModel, @@ -223,6 +258,7 @@ async def create_volume( project: 
ProjectModel, user: UserModel, configuration: VolumeConfiguration, + pipeline_hinter: PipelineHinterProtocol, ) -> Volume: spec = await apply_plugin_policies( user=user.name, @@ -254,6 +290,7 @@ async def create_volume( else: configuration.name = await generate_volume_name(session=session, project=project) + now = common.get_current_datetime() volume_model = VolumeModel( id=uuid.uuid4(), name=configuration.name, @@ -261,7 +298,10 @@ async def create_volume( project=project, status=VolumeStatus.SUBMITTED, configuration=configuration.json(), + auto_cleanup_enabled=_get_autocleanup_enabled(configuration), attachments=[], + created_at=now, + last_processed_at=now, ) session.add(volume_model) events.emit( @@ -271,11 +311,88 @@ async def create_volume( targets=[events.Target.from_model(volume_model)], ) await session.commit() + pipeline_hinter.hint_fetch(VolumeModel.__name__) return volume_model_to_volume(volume_model) async def delete_volumes( session: AsyncSession, project: ProjectModel, names: List[str], user: UserModel +): + # Keep both delete code paths while pipeline processing is behind a feature flag: + # - pipeline path marks volumes for async deletion by VolumePipeline + # - sync path deletes volume inline for non-pipeline processing + # TODO: Drop sync path after pipeline processing is enabled by default. + if FeatureFlags.PIPELINE_PROCESSING_ENABLED: + await _delete_volumes_pipeline( + session=session, + project=project, + names=names, + user=user, + ) + else: + await _delete_volumes_sync( + session=session, + project=project, + names=names, + user=user, + ) + + +async def _delete_volumes_pipeline( + session: AsyncSession, project: ProjectModel, names: List[str], user: UserModel +): + res = await session.execute( + select(VolumeModel).where( + VolumeModel.project_id == project.id, + VolumeModel.name.in_(names), + VolumeModel.deleted == False, + ) + ) + volume_models = res.scalars().all() + volumes_ids = sorted([v.id for v in volume_models]) + await session.commit() + logger.info("Deleting volumes: %s", [v.name for v in volume_models]) + async with get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids): + # Refetch after lock + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.project_id == project.id, + VolumeModel.id.in_(volumes_ids), + VolumeModel.deleted == False, + VolumeModel.lock_expires_at.is_(None), + ) + .options(selectinload(VolumeModel.attachments)) + .execution_options(populate_existing=True) + .order_by(VolumeModel.id) # take locks in order + .with_for_update(key_share=True, of=VolumeModel) + ) + volume_models = res.scalars().unique().all() + if len(volume_models) != len(volumes_ids): + # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. + raise ServerClientError( + "Failed to delete volumes: volumes are being processed currently. Try again later." + ) + for volume_model in volume_models: + if len(volume_model.attachments) > 0: + raise ServerClientError( + f"Failed to delete volume {volume_model.name}. Volume is in use." 
+ ) + for volume_model in volume_models: + if not volume_model.to_be_deleted: + volume_model.to_be_deleted = True + events.emit( + session, + message="Volume marked for deletion", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(volume_model)], + ) + await session.commit() + + +async def _delete_volumes_sync( + session: AsyncSession, project: ProjectModel, names: List[str], user: UserModel ): res = await session.execute( select(VolumeModel).where( @@ -494,3 +611,8 @@ def _get_volume_cost(volume: Volume) -> float: * volume.provisioning_data.price / _VOLUME_PRICING_PERIOD.total_seconds() ) + + +def _get_autocleanup_enabled(configuration: VolumeConfiguration) -> bool: + auto_cleanup_duration = parse_duration(configuration.auto_cleanup_duration) + return auto_cleanup_duration is not None and auto_cleanup_duration > 0 diff --git a/src/dstack/_internal/server/utils/sentry_utils.py b/src/dstack/_internal/server/utils/sentry_utils.py index 8dd7326b7..a99173b08 100644 --- a/src/dstack/_internal/server/utils/sentry_utils.py +++ b/src/dstack/_internal/server/utils/sentry_utils.py @@ -6,15 +6,29 @@ from sentry_sdk.types import Event, Hint -def instrument_background_task(f): +def instrument_scheduled_task(f): @functools.wraps(f) async def wrapper(*args, **kwargs): - with sentry_sdk.start_transaction(name=f"background.{f.__name__}"): - return await f(*args, **kwargs) + with sentry_sdk.isolation_scope(): + with sentry_sdk.start_transaction(name=f"scheduled_tasks.{f.__name__}"): + return await f(*args, **kwargs) return wrapper +def instrument_named_task(name: str): + def decorator(f): + @functools.wraps(f) + async def wrapper(*args, **kwargs): + with sentry_sdk.isolation_scope(): + with sentry_sdk.start_transaction(name=name): + return await f(*args, **kwargs) + + return wrapper + + return decorator + + class AsyncioCancelledErrorFilterEventProcessor: # See https://docs.sentry.io/platforms/python/configuration/filtering/#filtering-error-events def __call__(self, event: Event, hint: Hint) -> Optional[Event]: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py new file mode 100644 index 000000000..4d22c59b9 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py @@ -0,0 +1,336 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError, BackendNotAvailable +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus +from dstack._internal.server.background.pipeline_tasks.volumes import ( + VolumePipelineItem, + VolumeWorker, +) +from dstack._internal.server.models import VolumeModel +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_project, + create_user, + create_volume, + get_volume_configuration, + get_volume_provisioning_data, + list_events, +) + + +@pytest.fixture +def worker() -> VolumeWorker: + return VolumeWorker(queue=Mock(), heartbeater=Mock()) + + +def _volume_to_pipeline_item(volume_model: VolumeModel) -> VolumePipelineItem: + assert volume_model.lock_token is not None + assert volume_model.lock_expires_at is not None + return VolumePipelineItem( + __tablename__=volume_model.__tablename__, + id=volume_model.id, + lock_token=volume_model.lock_token, + 
lock_expires_at=volume_model.lock_expires_at, + prev_lock_expired=False, + status=volume_model.status, + to_be_deleted=volume_model.to_be_deleted, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestVolumeWorkerSubmitted: + async def test_submitted_to_active(self, test_db, session: AsyncSession, worker: VolumeWorker): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error" + ) as get_backend_mock: + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.create_volume.return_value = VolumeProvisioningData( + backend=BackendType.AWS, + volume_id="vol-1234", + size_gb=100, + ) + get_backend_mock.return_value = backend_mock + + await worker.process(_volume_to_pipeline_item(volume)) + + get_backend_mock.assert_called_once() + backend_mock.compute.return_value.create_volume.assert_called_once() + backend_mock.compute.return_value.register_volume.assert_not_called() + + await session.refresh(volume) + assert volume.status == VolumeStatus.ACTIVE + assert volume.volume_provisioning_data is not None + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Volume status changed SUBMITTED -> ACTIVE" + + async def test_registers_external_volume( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + configuration=get_volume_configuration(volume_id="vol-external-123"), + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error" + ) as get_backend_mock: + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.register_volume.return_value = ( + VolumeProvisioningData( + backend=BackendType.AWS, + volume_id="vol-external-123", + size_gb=100, + ) + ) + get_backend_mock.return_value = backend_mock + + await worker.process(_volume_to_pipeline_item(volume)) + + get_backend_mock.assert_called_once() + backend_mock.compute.return_value.register_volume.assert_called_once() + backend_mock.compute.return_value.create_volume.assert_not_called() + + await session.refresh(volume) + assert volume.status == VolumeStatus.ACTIVE + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Volume status changed SUBMITTED -> ACTIVE" + + async def test_marks_volume_failed_if_backend_not_available( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + 
volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error" + ) as get_backend_mock: + get_backend_mock.side_effect = BackendNotAvailable() + await worker.process(_volume_to_pipeline_item(volume)) + get_backend_mock.assert_called_once() + + await session.refresh(volume) + assert volume.status == VolumeStatus.FAILED + assert volume.status_message == "Backend not available" + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Volume status changed SUBMITTED -> FAILED (Backend not available)" + ) + + async def test_marks_volume_failed_if_backend_returns_error( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error" + ) as get_backend_mock: + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.create_volume.side_effect = BackendError( + "Some error" + ) + get_backend_mock.return_value = backend_mock + + await worker.process(_volume_to_pipeline_item(volume)) + + get_backend_mock.assert_called_once() + backend_mock.compute.return_value.create_volume.assert_called_once() + + await session.refresh(volume) + assert volume.status == VolumeStatus.FAILED + assert volume.status_message == "Some error" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Volume status changed SUBMITTED -> FAILED (Some error)" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestVolumeWorkerDeleted: + async def test_marks_volume_deleted( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + volume_provisioning_data=get_volume_provisioning_data(backend=BackendType.AWS), + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + volume.to_be_deleted = True + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error" + ) as get_backend_mock: + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backend_mock.return_value = backend_mock + + await worker.process(_volume_to_pipeline_item(volume)) + + get_backend_mock.assert_called_once() + backend_mock.compute.return_value.delete_volume.assert_called_once() + + await session.refresh(volume) + assert volume.deleted is True + assert volume.deleted_at is not None + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Volume deleted" + + async def test_marks_external_volume_deleted_without_backend_call( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): 
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            configuration=get_volume_configuration(volume_id="vol-external-123"),
+        )
+        volume.lock_token = uuid.uuid4()
+        volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        volume.to_be_deleted = True
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error"
+        ) as get_backend_mock:
+            await worker.process(_volume_to_pipeline_item(volume))
+            get_backend_mock.assert_not_called()
+
+        await session.refresh(volume)
+        assert volume.deleted is True
+        assert volume.deleted_at is not None
+        events = await list_events(session)
+        assert len(events) == 1
+        assert events[0].message == "Volume deleted"
+
+    async def test_marks_volume_deleted_if_backend_not_available(
+        self, test_db, session: AsyncSession, worker: VolumeWorker
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            volume_provisioning_data=get_volume_provisioning_data(backend=BackendType.AWS),
+        )
+        volume.lock_token = uuid.uuid4()
+        volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        volume.to_be_deleted = True
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error"
+        ) as get_backend_mock:
+            get_backend_mock.side_effect = BackendNotAvailable()
+            await worker.process(_volume_to_pipeline_item(volume))
+            get_backend_mock.assert_called_once()
+
+        await session.refresh(volume)
+        assert volume.deleted is True
+        assert volume.deleted_at is not None
+        events = await list_events(session)
+        assert len(events) == 1
+        assert events[0].message == "Volume deleted"
+
+    async def test_marks_volume_deleted_if_backend_delete_errors(
+        self, test_db, session: AsyncSession, worker: VolumeWorker
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            volume_provisioning_data=get_volume_provisioning_data(backend=BackendType.AWS),
+        )
+        volume.lock_token = uuid.uuid4()
+        volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        volume.to_be_deleted = True
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.background.pipeline_tasks.volumes.backends_services.get_project_backend_by_type_or_error"
+        ) as get_backend_mock:
+            backend_mock = Mock()
+            backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+            backend_mock.compute.return_value.delete_volume.side_effect = BackendError(
+                "Delete failed"
+            )
+            get_backend_mock.return_value = backend_mock
+
+            await worker.process(_volume_to_pipeline_item(volume))
+
+            get_backend_mock.assert_called_once()
+            backend_mock.compute.return_value.delete_volume.assert_called_once()
+
+        await session.refresh(volume)
+        assert volume.deleted is True
+        assert volume.deleted_at is not None
+        events = await list_events(session)
+        assert len(events) == 1
+        assert events[0].message == "Volume deleted"
diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py
index 6a7acf0c4..5bf844fee 100644
--- a/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py
+++ b/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py
@@ -22,12 +22,25 @@
     get_volume_provisioning_data,
     list_events,
 )
+from dstack._internal.settings import FeatureFlags
 from dstack._internal.utils.common import get_current_datetime
 
 
+@pytest.fixture
+def patch_pipeline_processing_flag(monkeypatch: pytest.MonkeyPatch):
+    def _apply(enabled: bool):
+        monkeypatch.setattr(FeatureFlags, "PIPELINE_PROCESSING_ENABLED", enabled)
+
+    return _apply
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-class TestProcessIdleVolumes:
+class TestProcessIdleVolumesScheduledTask:
+    @pytest.fixture(autouse=True)
+    def _patch_feature_flag(self, patch_pipeline_processing_flag):
+        patch_pipeline_processing_flag(False)
+
     async def test_deletes_idle_volumes(self, test_db, session: AsyncSession):
         project = await create_project(session=session)
         user = await create_user(session=session)
@@ -71,17 +84,169 @@ async def test_deletes_idle_volumes(self, test_db, session: AsyncSession):
             m.return_value = aws_mock
             aws_mock.compute.return_value = Mock(spec=ComputeMockSpec)
             await process_idle_volumes()
+            m.assert_called_once()
 
         await session.refresh(volume1)
         await session.refresh(volume2)
         events = await list_events(session)
+        assert not volume1.to_be_deleted
         assert volume1.deleted
         assert volume1.deleted_at is not None
+        assert not volume2.to_be_deleted
         assert not volume2.deleted
         assert volume2.deleted_at is None
         assert len(events) == 1
         assert events[0].message == "Volume deleted due to exceeding auto_cleanup_duration"
 
+    async def test_deletes_idle_volume_with_null_auto_cleanup_enabled(
+        self, test_db, session: AsyncSession
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=get_volume_configuration(
+                name="test-volume",
+                auto_cleanup_duration="1h",
+            ),
+            volume_provisioning_data=get_volume_provisioning_data(),
+            last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+            - datetime.timedelta(hours=2),
+        )
+        volume.auto_cleanup_enabled = None
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+        ) as m:
+            aws_mock = Mock()
+            m.return_value = aws_mock
+            aws_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+            await process_idle_volumes()
+            m.assert_called_once()
+
+        await session.refresh(volume)
+        events = await list_events(session)
+        assert not volume.to_be_deleted
+        assert volume.deleted
+        assert volume.deleted_at is not None
+        assert len(events) == 1
+        assert events[0].message == "Volume deleted due to exceeding auto_cleanup_duration"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestProcessIdleVolumesPipelineTask:
+    @pytest.fixture(autouse=True)
+    def _patch_feature_flag(self, patch_pipeline_processing_flag):
+        patch_pipeline_processing_flag(True)
+
+    async def test_deletes_idle_volumes(self, test_db, session: AsyncSession):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+
+        config1 = get_volume_configuration(
+            name="test-volume",
+            auto_cleanup_duration="1h",
+        )
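+        # Clarifying comment (not part of the original test): both volumes get
+        # last_job_processed_at set 2 hours in the past, so only volume1's 1h
+        # auto_cleanup_duration is exceeded; volume2 (3h below) should stay untouched.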
+        config2 = get_volume_configuration(
+            name="test-volume",
+            auto_cleanup_duration="3h",
+        )
+        volume1 = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=config1,
+            volume_provisioning_data=get_volume_provisioning_data(),
+            last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+            - datetime.timedelta(hours=2),
+        )
+        volume2 = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=config2,
+            volume_provisioning_data=get_volume_provisioning_data(),
+            last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+            - datetime.timedelta(hours=2),
+        )
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+        ) as m:
+            aws_mock = Mock()
+            m.return_value = aws_mock
+            aws_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+            await process_idle_volumes()
+            m.assert_not_called()
+
+        await session.refresh(volume1)
+        await session.refresh(volume2)
+        events = await list_events(session)
+        assert volume1.to_be_deleted
+        assert not volume1.deleted
+        assert volume1.deleted_at is None
+        assert not volume2.to_be_deleted
+        assert not volume2.deleted
+        assert volume2.deleted_at is None
+        assert len(events) == 1
+        assert (
+            events[0].message
+            == "Volume marked for deletion due to exceeding auto_cleanup_duration"
+        )
+
+    async def test_deletes_idle_volume_with_null_auto_cleanup_enabled(
+        self, test_db, session: AsyncSession
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=get_volume_configuration(
+                name="test-volume",
+                auto_cleanup_duration="1h",
+            ),
+            volume_provisioning_data=get_volume_provisioning_data(),
+            last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+            - datetime.timedelta(hours=2),
+        )
+        volume.auto_cleanup_enabled = None
+        await session.commit()
+
+        with patch(
+            "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+        ) as m:
+            aws_mock = Mock()
+            m.return_value = aws_mock
+            aws_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+            await process_idle_volumes()
+            m.assert_not_called()
+
+        await session.refresh(volume)
+        events = await list_events(session)
+        assert volume.to_be_deleted
+        assert not volume.deleted
+        assert volume.deleted_at is None
+        assert len(events) == 1
+        assert (
+            events[0].message
+            == "Volume marked for deletion due to exceeding auto_cleanup_duration"
+        )
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
index b06eb50ec..f33f608c7 100644
--- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
+++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
@@ -492,6 +492,79 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn
         )
         assert job.instance.volume_attachments[0].volume == volume
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_fails_job_when_attaching_volume_marked_for_deletion(
+        self, test_db, session: AsyncSession
+    ):
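+        # Clarifying comment (an inference from the asserts below, not stated in the
+        # original change): a volume with to_be_deleted=True must not be attached; the
+        # job is expected to terminate with VOLUME_ERROR before attach_volume is called.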
+        project = await create_project(session)
+        user = await create_user(session)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        volume = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            volume_provisioning_data=get_volume_provisioning_data(),
+            backend=BackendType.AWS,
+            region="us-east-1",
+        )
+        volume.to_be_deleted = True
+        await session.commit()
+        fleet = await create_fleet(session=session, project=project)
+        await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.AWS,
+            region="us-east-1",
+        )
+        run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
+        run_spec.configuration.volumes = [VolumeMountPoint(name=volume.name, path="/volume")]
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+            run_name="test-run",
+            run_spec=run_spec,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=False,
+        )
+
+        with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m:
+            backend_mock = Mock()
+            m.return_value = backend_mock
+            backend_mock.TYPE = BackendType.AWS
+            backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+            backend_mock.compute.return_value.attach_volume.return_value = VolumeAttachmentData()
+            # Submitted jobs processing happens in two steps
+            await process_submitted_jobs()
+            await process_submitted_jobs()
+            backend_mock.compute.return_value.attach_volume.assert_not_called()
+
+        await session.refresh(job)
+        res = await session.execute(
+            select(JobModel).options(
+                joinedload(JobModel.instance)
+                .joinedload(InstanceModel.volume_attachments)
+                .joinedload(VolumeAttachmentModel.volume)
+            )
+        )
+        job = res.unique().scalar_one()
+        assert job.status == JobStatus.TERMINATING
+        assert job.termination_reason == JobTerminationReason.VOLUME_ERROR
+        assert job.termination_reason_message is not None
+        assert "marked for deletion and cannot be attached" in job.termination_reason_message
+        assert job.instance is None
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSession):
diff --git a/src/tests/_internal/server/services/test_volumes.py b/src/tests/_internal/server/services/test_volumes.py
index 4de9c3f05..6bfb9bae6 100644
--- a/src/tests/_internal/server/services/test_volumes.py
+++ b/src/tests/_internal/server/services/test_volumes.py
@@ -10,7 +10,10 @@
     _get_volume_cost,
     _validate_volume_configuration,
 )
-from dstack._internal.server.testing.common import get_volume, get_volume_provisioning_data
+from dstack._internal.server.testing.common import (
+    get_volume,
+    get_volume_provisioning_data,
+)
 
 
 class TestValidateVolumeConfiguration: