Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
# Base directory used for paths on the gateway host; other gateway paths in
# this module are built relative to it.
DSTACK_DIR_ON_GATEWAY = Path("/home/ubuntu/dstack")
# Subdirectory of DSTACK_DIR_ON_GATEWAY named "server-connections".
# NOTE(review): presumably holds per-server connection state — confirm against users of this constant.
SERVER_CONNECTIONS_DIR_ON_GATEWAY = DSTACK_DIR_ON_GATEWAY / "server-connections"
# NOTE(review): presumably the port the proxy listens on on the gateway — confirm.
PROXY_PORT_ON_GATEWAY = 8000
# Error message template for an attempt to register a service that is already
# registered; `{ref}` is filled in with `str.format(ref=...)`.
# The gateway raises a ProxyError with this text, and the server compares the
# received GatewayError message against the same formatted template to detect
# a dangling service. The wording is therefore part of the gateway/server
# contract: changing it breaks that detection.
SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE = "Service {ref} is already registered"
3 changes: 2 additions & 1 deletion src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType
from dstack._internal.proxy.gateway import models as gateway_models
from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.services.nginx import (
LimitReqConfig,
Expand Down Expand Up @@ -63,7 +64,7 @@ async def register_service(

async with lock:
if await repo.get_service(project_name, run_name) is not None:
raise ProxyError(f"Service {service.fmt()} is already registered")
raise ProxyError(SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(ref=service.fmt()))

old_project = await repo.get_project(project_name)
new_project = models.Project(name=project_name, ssh_private_key=ssh_private_key)
Expand Down
25 changes: 24 additions & 1 deletion src/dstack/_internal/server/services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import uuid
from datetime import datetime
from functools import partial
from typing import Optional

import httpx
Expand Down Expand Up @@ -33,6 +34,7 @@
)
from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec
from dstack._internal.core.models.services import OpenAIChatModel
from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE
from dstack._internal.server import settings
from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services import events
Expand Down Expand Up @@ -177,7 +179,8 @@ async def _register_service_in_gateway(
try:
logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
async with conn.client() as client:
await client.register_service(
do_register = partial(
client.register_service,
project=run_model.project.name,
run_name=run_model.run_name,
domain=domain,
Expand All @@ -190,6 +193,26 @@ async def _register_service_in_gateway(
ssh_private_key=run_model.project.ssh_private_key,
router=router,
)
try:
await do_register()
except GatewayError as e:
if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(
ref=f"{run_model.project.name}/{run_model.run_name}"
):
# Happens if there was a communication issue with the gateway when last unregistering
logger.warning(
"Service %s/%s is dangling on gateway %s, unregistering and re-registering",
run_model.project.name,
run_model.run_name,
gateway.name,
)
await client.unregister_service(
project=run_model.project.name,
run_name=run_model.run_name,
)
await do_register()
else:
raise
except SSHError:
raise ServerClientError("Gateway tunnel is not working")
except httpx.RequestError as e:
Expand Down
56 changes: 54 additions & 2 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal import settings
from dstack._internal.core.errors import GatewayError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import ApplyAction
from dstack._internal.core.models.configurations import (
Expand Down Expand Up @@ -2299,13 +2300,13 @@ async def test_returns_400_if_runs_active(
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestSubmitService:
@pytest.fixture(autouse=True)
def mock_gateway_connections(self) -> Generator[None, None, None]:
def mock_gateway_connection(self) -> Generator[AsyncMock, None, None]:
with patch(
"dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add"
) as get_conn_mock:
get_conn_mock.return_value.client = Mock()
get_conn_mock.return_value.client.return_value = AsyncMock()
yield
yield get_conn_mock

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2481,3 +2482,54 @@ async def test_return_error_if_specified_gateway_is_true_and_no_gateway_exists(
}
]
}

@pytest.mark.asyncio
async def test_unregister_dangling_service(
    self,
    test_db,
    session: AsyncSession,
    client: AsyncClient,
    mock_gateway_connection: AsyncMock,
) -> None:
    """Check that a dangling gateway registration is cleaned up on submit.

    The mocked gateway client rejects the first ``register_service`` call
    with an "already registered" ``GatewayError`` and accepts the second.
    The submit request is expected to succeed, with exactly one
    ``unregister_service`` call for the run and two ``register_service``
    calls in total.
    """
    run_name = "test-service"

    # Arrange: a project whose default gateway is running on example.com.
    owner = await create_user(session=session, global_role=GlobalRole.USER)
    test_project = await create_project(session=session, owner=owner, name="test-project")
    await add_project_member(
        session=session,
        project=test_project,
        user=owner,
        project_role=ProjectRole.USER,
    )
    repo = await create_repo(session=session, project_id=test_project.id)
    backend = await create_backend(session=session, project_id=test_project.id)
    compute = await create_gateway_compute(session=session, backend_id=backend.id)
    default_gateway = await create_gateway(
        session=session,
        project_id=test_project.id,
        backend_id=backend.id,
        gateway_compute_id=compute.id,
        status=GatewayStatus.RUNNING,
        wildcard_domain="example.com",
    )
    test_project.default_gateway_id = default_gateway.id
    await session.commit()

    # Resolve the object yielded by `async with conn.client() as client`.
    connection = mock_gateway_connection.return_value
    client_context = connection.client.return_value
    gateway_client = client_context.__aenter__.return_value
    # First registration attempt fails as "already registered", the retry succeeds.
    gateway_client.register_service.side_effect = [
        GatewayError("Service test-project/test-service is already registered"),
        None,
    ]

    # Act: submit the service run.
    payload = {"run_spec": get_service_run_spec(repo_id=repo.name, run_name=run_name)}
    resp = await client.post(
        "/api/project/test-project/runs/submit",
        headers=get_auth_headers(owner.token),
        json=payload,
    )

    # Assert: the request succeeded and the service got its URL.
    assert resp.status_code == 200
    assert resp.json()["service"]["url"] == "https://test-service.example.com"
    # The dangling registration must have been removed exactly once...
    gateway_client.unregister_service.assert_called_once_with(
        project=test_project.name,
        run_name=run_name,
    )
    # ...and registration attempted twice (the failure, then the retry).
    assert gateway_client.register_service.call_count == 2