diff --git a/src/dstack/_internal/proxy/gateway/const.py b/src/dstack/_internal/proxy/gateway/const.py index cb172f73f..7b958030a 100644 --- a/src/dstack/_internal/proxy/gateway/const.py +++ b/src/dstack/_internal/proxy/gateway/const.py @@ -5,3 +5,4 @@ DSTACK_DIR_ON_GATEWAY = Path("/home/ubuntu/dstack") SERVER_CONNECTIONS_DIR_ON_GATEWAY = DSTACK_DIR_ON_GATEWAY / "server-connections" PROXY_PORT_ON_GATEWAY = 8000 +SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE = "Service {ref} is already registered" diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index dc6407d24..adebe6f41 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -8,6 +8,7 @@ from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway import models as gateway_models +from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.services.nginx import ( LimitReqConfig, @@ -63,7 +64,7 @@ async def register_service( async with lock: if await repo.get_service(project_name, run_name) is not None: - raise ProxyError(f"Service {service.fmt()} is already registered") + raise ProxyError(SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(ref=service.fmt())) old_project = await repo.get_project(project_name) new_project = models.Project(name=project_name, ssh_private_key=ssh_private_key) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index b701b822b..8dba43ea8 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -5,6 +5,7 @@ import json import uuid from datetime import datetime +from functools import partial from typing import Optional import httpx @@ -33,6 +34,7 @@ ) from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.core.models.services import OpenAIChatModel +from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services import events @@ -177,7 +179,8 @@ async def _register_service_in_gateway( try: logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) async with conn.client() as client: - await client.register_service( + do_register = partial( + client.register_service, project=run_model.project.name, run_name=run_model.run_name, domain=domain, @@ -190,6 +193,26 @@ async def _register_service_in_gateway( ssh_private_key=run_model.project.ssh_private_key, router=router, ) + try: + await do_register() + except GatewayError as e: + if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( + ref=f"{run_model.project.name}/{run_model.run_name}" + ): + # Happens if there was a communication issue with the gateway when last unregistering + logger.warning( + "Service %s/%s is dangling on gateway %s, unregistering and re-registering", + run_model.project.name, + run_model.run_name, + gateway.name, + ) + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + await do_register() + else: + raise except SSHError: raise ServerClientError("Gateway tunnel is not working") except httpx.RequestError as e: diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 25cbbead3..1f6b1ebf3 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal import settings +from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import ( @@ -2299,13 +2300,13 @@ async def test_returns_400_if_runs_active( @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestSubmitService: @pytest.fixture(autouse=True) - def mock_gateway_connections(self) -> Generator[None, None, None]: + def mock_gateway_connection(self) -> Generator[AsyncMock, None, None]: with patch( "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" ) as get_conn_mock: get_conn_mock.return_value.client = Mock() get_conn_mock.return_value.client.return_value = AsyncMock() - yield + yield get_conn_mock @pytest.mark.asyncio @pytest.mark.parametrize( @@ -2481,3 +2482,54 @@ async def test_return_error_if_specified_gateway_is_true_and_no_gateway_exists( } ] } + + @pytest.mark.asyncio + async def test_unregister_dangling_service( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + mock_gateway_connection: AsyncMock, + ) -> None: + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user, name="test-project") + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + wildcard_domain="example.com", + ) + project.default_gateway_id = gateway.id + await session.commit() + + client_mock = ( + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value + ) + client_mock.register_service.side_effect = [ + GatewayError("Service test-project/test-service is already registered"), + None, # Second call succeeds + ] + + response = await client.post( + "/api/project/test-project/runs/submit", + headers=get_auth_headers(user.token), + json={"run_spec": get_service_run_spec(repo_id=repo.name, run_name="test-service")}, + ) + + assert response.status_code == 200 + assert response.json()["service"]["url"] == "https://test-service.example.com" + # Verify that unregister_service was called to clean up the dangling service + client_mock.unregister_service.assert_called_once_with( + project=project.name, + run_name="test-service", + ) + # Verify that register_service was called twice (first failed, then succeeded) + assert client_mock.register_service.call_count == 2