From fe8903135643b44aeea3042f78dc4faf16b326f4 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Tue, 31 Mar 2026 01:26:52 +0200 Subject: [PATCH 1/3] fix: enable DNS rebinding protection by default via Ktor route-scoped plugin --- .../sdk/conformance/ConformanceServer.kt | 2 - kotlin-sdk-server/api/kotlin-sdk-server.api | 22 +- .../kotlin/sdk/server/HostValidation.kt | 119 +++++++++ .../kotlin/sdk/server/KtorServer.kt | 95 +++++-- .../server/StreamableHttpServerTransport.kt | 17 +- .../sdk/server/DnsRebindingProtectionTest.kt | 242 ++++++++++++++++++ .../server/KtorApplicationExtensionsTest.kt | 2 +- .../sdk/server/KtorRouteExtensionsTest.kt | 6 +- 8 files changed, 472 insertions(+), 33 deletions(-) create mode 100644 kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt create mode 100644 kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt index b42077529..11946749a 100644 --- a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt @@ -19,8 +19,6 @@ fun main() { json(McpJson) } mcpStreamableHttp( - enableDnsRebindingProtection = true, - allowedHosts = listOf("localhost", "127.0.0.1", "localhost:$port", "127.0.0.1:$port"), eventStore = InMemoryEventStore(), ) { createConformanceServer() diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index fef34b261..a4f5c9ab0 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -28,16 +28,32 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/ClientConnection$De public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/server/ClientConnection;Lio/modelcontextprotocol/kotlin/sdk/types/PingRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionConfig { + public fun ()V + public final fun getAllowedHosts ()Ljava/util/List; + public final fun getAllowedOrigins ()Ljava/util/List; + public final fun setAllowedHosts (Ljava/util/List;)V + public final fun setAllowedOrigins (Ljava/util/List;)V +} + public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore { public abstract fun getStreamIdForEventId (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/HostValidationKt { + public static final fun getDnsRebindingProtection ()Lio/ktor/server/application/RouteScopedPlugin; + public static final fun getLOCALHOST_ALLOWED_HOSTS ()Ljava/util/List; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { - public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V - public static final fun mcp (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V - public static final fun mcp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;)V + public static final fun mcp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V + public static final fun mcp (Lio/ktor/server/routing/Route;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V + public static final fun mcp (Lio/ktor/server/routing/Route;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcp$default (Lio/ktor/server/routing/Route;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcp$default (Lio/ktor/server/routing/Route;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt new file mode 100644 index 000000000..23df79fb9 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt @@ -0,0 +1,119 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.URLBuilder +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.RouteScopedPlugin +import io.ktor.server.application.createRouteScopedPlugin +import io.ktor.server.request.header +import io.ktor.server.response.respondText +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.RPCError + +/** + * Default list of hostnames allowed for localhost DNS rebinding protection. + * Matches the TypeScript SDK's `localhostAllowedHostnames()`. + */ +public val LOCALHOST_ALLOWED_HOSTS: List = listOf("localhost", "127.0.0.1", "[::1]") + +/** + * Extracts the hostname from a Host header value, stripping port and normalizing IPv6. + * + * Examples: + * - `"localhost:3000"` → `"localhost"` + * - `"127.0.0.1:8080"` → `"127.0.0.1"` + * - `"[::1]:3000"` → `"[::1]"` + * - `"example.com"` → `"example.com"` + * + * @return the hostname, or `null` if parsing fails. + */ +internal fun extractHostname(hostHeader: String): String? { + if (hostHeader.isBlank()) return null + return try { + URLBuilder("http://$hostHeader").build().host.ifEmpty { null } + } catch (_: Exception) { + null + } +} + +/** + * Configuration for the [DnsRebindingProtection] Ktor route-scoped plugin. + * + * @property allowedHosts List of hostnames allowed in the `Host` header. + * Comparison is port-agnostic and case-insensitive. + * Defaults to [LOCALHOST_ALLOWED_HOSTS]. + * An empty list will reject **all** requests. + * @property allowedOrigins Optional list of allowed `Origin` header values. + * If `null`, origin validation is disabled. + * If configured, requests **with** an `Origin` header not in the list are rejected, + * but requests **without** an `Origin` header are allowed (non-browser clients). + */ +public class DnsRebindingProtectionConfig { + public var allowedHosts: List = LOCALHOST_ALLOWED_HOSTS + public var allowedOrigins: List? = null +} + +/** + * Ktor route-scoped plugin that validates `Host` and `Origin` headers + * to protect against DNS rebinding attacks. + * + * Install on a route to intercept all requests **before** handlers: + * ```kotlin + * route("/mcp") { + * install(DnsRebindingProtection) { + * allowedHosts = listOf("myapp.com", "localhost") + * } + * // handlers... + * } + * ``` + */ +public val DnsRebindingProtection: RouteScopedPlugin = + createRouteScopedPlugin( + "MCP-DnsRebindingProtection", + ::DnsRebindingProtectionConfig, + ) { + val hosts: Set = pluginConfig.allowedHosts.mapTo(mutableSetOf()) { + extractHostname(it)?.lowercase() ?: it.lowercase() + } + val origins: Set? = pluginConfig.allowedOrigins?.mapTo(mutableSetOf()) { it.lowercase() } + + onCall { call -> + val hostHeader = call.request.header(HttpHeaders.Host) + val hostname = hostHeader?.let { extractHostname(it) }?.lowercase() + + if (hostname == null || hostname !in hosts) { + call.rejectDnsValidation("Invalid Host header: $hostHeader") + return@onCall + } + + if (origins != null) { + val origin = call.request.header(HttpHeaders.Origin)?.lowercase() + // Allow requests without Origin (non-browser clients cannot perform DNS rebinding) + if (origin != null && origin !in origins) { + call.rejectDnsValidation("Invalid Origin header: $origin") + return@onCall + } + } + } + } + +/** + * Responds with a 403 Forbidden JSON-RPC error without requiring ContentNegotiation. + */ +private suspend fun ApplicationCall.rejectDnsValidation(message: String) { + val error = JSONRPCError( + id = null, + error = RPCError( + code = RPCError.ErrorCode.CONNECTION_CLOSED, + message = message, + ), + ) + respondText( + McpJson.encodeToString(error), + ContentType.Application.Json, + HttpStatusCode.Forbidden, + ) +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 739f8d3c1..1b9d6a654 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -35,14 +35,24 @@ private val logger = KotlinLogging.logger {} * Use [Application.mcp] if you want SSE to be installed automatically. * * @param path the URL path to register the SSE endpoint. + * @param enableDnsRebindingProtection whether to install [DnsRebindingProtection] on this route. Defaults to `true`. + * @param allowedHosts hostnames allowed in the `Host` header. Defaults to [LOCALHOST_ALLOWED_HOSTS]. + * @param allowedOrigins origins allowed in the `Origin` header, or `null` to skip origin validation. * @param block factory block with access to the [ServerSSESession] * that creates and returns the [Server] to handle the connection. * @throws IllegalStateException if the [SSE] plugin is not installed. */ @KtorDsl -public fun Route.mcp(path: String, block: ServerSSESession.() -> Server) { +@Suppress("LongParameterList") +public fun Route.mcp( + path: String, + enableDnsRebindingProtection: Boolean = true, + allowedHosts: List? = null, + allowedOrigins: List? = null, + block: ServerSSESession.() -> Server, +) { route(path) { - mcp(block) + mcp(enableDnsRebindingProtection, allowedHosts, allowedOrigins, block) } } @@ -53,12 +63,20 @@ public fun Route.mcp(path: String, block: ServerSSESession.() -> Server) { * **Precondition:** the [SSE] plugin must be installed on the application before calling this function. * Use [Application.mcp] if you want SSE to be installed automatically. * + * @param enableDnsRebindingProtection whether to install [DnsRebindingProtection] on this route. Defaults to `true`. + * @param allowedHosts hostnames allowed in the `Host` header. Defaults to [LOCALHOST_ALLOWED_HOSTS]. + * @param allowedOrigins origins allowed in the `Origin` header, or `null` to skip origin validation. * @param block factory block with access to the [ServerSSESession] * that creates and returns the [Server] to handle the connection. * @throws IllegalStateException if the [SSE] plugin is not installed. */ @KtorDsl -public fun Route.mcp(block: ServerSSESession.() -> Server) { +public fun Route.mcp( + enableDnsRebindingProtection: Boolean = true, + allowedHosts: List? = null, + allowedOrigins: List? = null, + block: ServerSSESession.() -> Server, +) { try { plugin(SSE) } catch (e: MissingApplicationPluginException) { @@ -70,6 +88,13 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) { ) } + if (enableDnsRebindingProtection) { + install(DnsRebindingProtection) { + this.allowedHosts = allowedHosts ?: LOCALHOST_ALLOWED_HOSTS + allowedOrigins?.let { this.allowedOrigins = it } + } + } + val transportManager = TransportManager() sse { @@ -86,20 +111,32 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) { * over [Server-Sent Events (SSE) Transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse) * and sets up routing with the provided configuration block. * + * @param enableDnsRebindingProtection whether to install [DnsRebindingProtection] on this route. Defaults to `true`. + * @param allowedHosts hostnames allowed in the `Host` header. Defaults to [LOCALHOST_ALLOWED_HOSTS]. + * @param allowedOrigins origins allowed in the `Origin` header, or `null` to skip origin validation. * @param block factory block with access to the [ServerSSESession] * that creates and returns the [Server] to handle the connection. */ @KtorDsl -public fun Application.mcp(block: ServerSSESession.() -> Server) { +public fun Application.mcp( + enableDnsRebindingProtection: Boolean = true, + allowedHosts: List? = null, + allowedOrigins: List? = null, + block: ServerSSESession.() -> Server, +) { install(SSE) routing { - mcp(block) + mcp(enableDnsRebindingProtection, allowedHosts, allowedOrigins, block) } } +@Suppress("LongParameterList") private fun Application.mcpStreamableHttp( path: String = "/mcp", + enableDnsRebindingProtection: Boolean, + allowedHosts: List?, + allowedOrigins: List?, configuration: StreamableHttpServerTransport.Configuration, block: RoutingContext.() -> Server, ) { @@ -109,6 +146,13 @@ private fun Application.mcpStreamableHttp( routing { route(path) { + if (enableDnsRebindingProtection) { + install(DnsRebindingProtection) { + this.allowedHosts = allowedHosts ?: LOCALHOST_ALLOWED_HOSTS + allowedOrigins?.let { this.allowedOrigins = it } + } + } + sse { val transport = existingStreamableTransport(call, transportManager) ?: return@sse transport.handleRequest(this, call) @@ -140,10 +184,11 @@ private fun Application.mcpStreamableHttp( * Simple request/response pairs are returned as JSON (not SSE streams). * * @param path The base path for the MCP Streamable HTTP endpoint. Defaults to "/mcp". - * @param enableDnsRebindingProtection Enables DNS rebinding attack protection for the endpoint. Defaults to false. - * @param allowedHosts A list of hostnames allowed to access the endpoint. If `null`, no restrictions are applied. + * @param enableDnsRebindingProtection Enables DNS rebinding attack protection for the endpoint. Defaults to `true`. + * @param allowedHosts A list of hostnames allowed to access the endpoint. + * If `null` and DNS rebinding protection is enabled, defaults to [LOCALHOST_ALLOWED_HOSTS]. * @param allowedOrigins A list of origins allowed to perform cross-origin requests (CORS). - * If `null`, no restrictions are applied. + * If `null`, origin validation is disabled. * @param eventStore An optional [EventStore] instance to enable resumable event stream functionality. * Allows storing and replaying events. * @param block factory block with access to the [RoutingContext] (for reading request headers) @@ -153,7 +198,7 @@ private fun Application.mcpStreamableHttp( @Suppress("LongParameterList") public fun Application.mcpStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, + enableDnsRebindingProtection: Boolean = true, allowedHosts: List? = null, allowedOrigins: List? = null, eventStore: EventStore? = null, @@ -161,10 +206,10 @@ public fun Application.mcpStreamableHttp( ) { mcpStreamableHttp( path = path, + enableDnsRebindingProtection = enableDnsRebindingProtection, + allowedHosts = allowedHosts, + allowedOrigins = allowedOrigins, configuration = StreamableHttpServerTransport.Configuration( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, eventStore = eventStore, enableJsonResponse = true, ), @@ -172,8 +217,12 @@ public fun Application.mcpStreamableHttp( ) } +@Suppress("LongParameterList") private fun Application.mcpStatelessStreamableHttp( path: String = "/mcp", + enableDnsRebindingProtection: Boolean, + allowedHosts: List?, + allowedOrigins: List?, configuration: StreamableHttpServerTransport.Configuration, block: RoutingContext.() -> Server, ) { @@ -181,6 +230,13 @@ private fun Application.mcpStatelessStreamableHttp( routing { route(path) { + if (enableDnsRebindingProtection) { + install(DnsRebindingProtection) { + this.allowedHosts = allowedHosts ?: LOCALHOST_ALLOWED_HOSTS + allowedOrigins?.let { this.allowedOrigins = it } + } + } + post { mcpStatelessStreamableHttpEndpoint( configuration = configuration, @@ -213,9 +269,10 @@ private fun Application.mcpStatelessStreamableHttp( * Simple request/response pairs are returned as JSON (not SSE streams). * * @param path The URL path where the server listens for incoming JSON-RPC requests. Defaults to "/mcp". - * @param enableDnsRebindingProtection Determines whether DNS rebinding protection is enabled. Defaults to `false`. - * @param allowedHosts A list of allowed hostnames. If null, host filtering is disabled. - * @param allowedOrigins A list of allowed origins for CORS. If null, origin filtering is disabled. + * @param enableDnsRebindingProtection Determines whether DNS rebinding protection is enabled. Defaults to `true`. + * @param allowedHosts A list of allowed hostnames. If `null` and DNS rebinding protection is enabled, + * defaults to [LOCALHOST_ALLOWED_HOSTS]. + * @param allowedOrigins A list of allowed origins for CORS. If `null`, origin validation is disabled. * @param eventStore An optional [EventStore] implementation to provide resumability and event replay support. * @param block factory block with access to the [RoutingContext] (for reading request headers) * that creates and returns the [Server] to handle the connection. @@ -224,7 +281,7 @@ private fun Application.mcpStatelessStreamableHttp( @Suppress("LongParameterList") public fun Application.mcpStatelessStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, + enableDnsRebindingProtection: Boolean = true, allowedHosts: List? = null, allowedOrigins: List? = null, eventStore: EventStore? = null, @@ -232,10 +289,10 @@ public fun Application.mcpStatelessStreamableHttp( ) { mcpStatelessStreamableHttp( path = path, + enableDnsRebindingProtection = enableDnsRebindingProtection, + allowedHosts = allowedHosts, + allowedOrigins = allowedOrigins, configuration = StreamableHttpServerTransport.Configuration( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, eventStore = eventStore, enableJsonResponse = true, ), diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 821c525a1..882a8515a 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -144,8 +144,11 @@ public class StreamableHttpServerTransport(private val configuration: Configurat */ public class Configuration( public val enableJsonResponse: Boolean = false, + @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") public val enableDnsRebindingProtection: Boolean = false, + @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") public val allowedHosts: List? = null, + @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") public val allowedOrigins: List? = null, public val eventStore: EventStore? = null, public val retryInterval: Duration? = null, @@ -627,15 +630,18 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } } - @Suppress("ReturnCount") + @Suppress("ReturnCount", "DEPRECATION") private fun validateHeaders(call: ApplicationCall): String? { if (!configuration.enableDnsRebindingProtection) return null configuration.allowedHosts?.let { hosts -> - val hostHeader = call.request.headers[HttpHeaders.Host]?.lowercase() - val allowedHostsLowercase = hosts.map { it.lowercase() } + val hostHeader = call.request.headers[HttpHeaders.Host] + val hostname = hostHeader?.let { extractHostname(it) }?.lowercase() + val allowedHostsLowercase = hosts.map { + extractHostname(it)?.lowercase() ?: it.lowercase() + } - if (hostHeader == null || hostHeader !in allowedHostsLowercase) { + if (hostname == null || hostname !in allowedHostsLowercase) { return "Invalid Host header: $hostHeader" } } @@ -644,7 +650,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat val originHeader = call.request.headers[HttpHeaders.Origin]?.lowercase() val allowedOriginsLowercase = origins.map { it.lowercase() } - if (originHeader == null || originHeader !in allowedOriginsLowercase) { + // Allow requests without Origin (non-browser clients cannot perform DNS rebinding) + if (originHeader != null && originHeader !in allowedOriginsLowercase) { return "Invalid Origin header: $originHeader" } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt new file mode 100644 index 000000000..1ff5eb375 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt @@ -0,0 +1,242 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.application.install +import io.ktor.server.response.respondText +import io.ktor.server.routing.post +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.ktor.server.sse.SSE +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import kotlin.test.Test + +class DnsRebindingProtectionTest { + + private fun testWithPlugin( + config: DnsRebindingProtectionConfig.() -> Unit = {}, + test: suspend ApplicationTestBuilder.() -> Unit, + ): Unit = testApplication { + application { + install(SSE) + routing { + route("/mcp") { + install(DnsRebindingProtection, config) + post { call.respondText("ok") } + } + } + } + test() + } + + @Test + fun `plugin rejects request with missing Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "") + } + response.shouldHaveStatus(HttpStatusCode.Forbidden) + response.bodyAsText() shouldContain "Invalid Host header" + } + + @Test + fun `plugin allows localhost Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin allows 127_0_0_1 Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "127.0.0.1") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin allows IPv6 localhost Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "[::1]") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin strips port from Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost:3000") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin strips port from IPv6 Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "[::1]:8080") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin rejects non-localhost Host header`() = testWithPlugin { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "evil.com") + } + response.shouldHaveStatus(HttpStatusCode.Forbidden) + response.bodyAsText() shouldContain "Invalid Host header" + } + + @Test + fun `plugin allows request without Origin header`() = testWithPlugin( + config = { allowedOrigins = listOf("http://localhost:3000") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin rejects disallowed Origin`() = testWithPlugin( + config = { allowedOrigins = listOf("http://localhost:3000") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost") + header(HttpHeaders.Origin, "http://evil.com") + } + response.shouldHaveStatus(HttpStatusCode.Forbidden) + response.bodyAsText() shouldContain "Invalid Origin header" + } + + @Test + fun `plugin allows matching Origin`() = testWithPlugin( + config = { allowedOrigins = listOf("http://localhost:3000") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost") + header(HttpHeaders.Origin, "http://localhost:3000") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin with custom allowedHosts accepts matching host`() = testWithPlugin( + config = { allowedHosts = listOf("myapp.com") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "myapp.com:443") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `plugin with custom allowedHosts rejects non-matching host`() = testWithPlugin( + config = { allowedHosts = listOf("myapp.com") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "localhost") + } + response.shouldHaveStatus(HttpStatusCode.Forbidden) + } + + @Test + fun `host validation is case insensitive`() = testWithPlugin( + config = { allowedHosts = listOf("MyApp.COM") }, + ) { + val response = client.post("/mcp") { + header(HttpHeaders.Host, "myapp.com") + } + response.shouldHaveStatus(HttpStatusCode.OK) + response.bodyAsText() shouldBe "ok" + } + + @Test + fun `Route mcp with DNS protection enabled rejects non-localhost`() = testApplication { + application { + install(SSE) + routing { + mcp { testServer() } + } + } + + val response = client.post("/") { + header(HttpHeaders.Host, "evil.com") + contentType(ContentType.Application.Json) + } + response.shouldHaveStatus(HttpStatusCode.Forbidden) + } + + @Test + fun `Route mcp with DNS protection disabled allows any host`() = testApplication { + application { + install(SSE) + routing { + mcp(enableDnsRebindingProtection = false) { testServer() } + } + } + + val response = client.post("/") { + header(HttpHeaders.Host, "evil.com") + contentType(ContentType.Application.Json) + } + // Not 403 — the request reaches the handler (may get 400 for missing sessionId, etc.) + response.shouldHaveStatus(HttpStatusCode.BadRequest) + } + + // -- extractHostname unit tests -- + + @Test + fun `extractHostname strips port from hostname`() { + extractHostname("localhost:3000") shouldBe "localhost" + } + + @Test + fun `extractHostname returns hostname without port unchanged`() { + extractHostname("localhost") shouldBe "localhost" + } + + @Test + fun `extractHostname handles IPv4 with port`() { + extractHostname("127.0.0.1:8080") shouldBe "127.0.0.1" + } + + @Test + fun `extractHostname handles IPv6 with port`() { + extractHostname("[::1]:3000") shouldBe "[::1]" + } + + @Test + fun `extractHostname handles IPv6 without port`() { + extractHostname("[::1]") shouldBe "[::1]" + } + + @Test + fun `extractHostname returns null for empty string`() { + extractHostname("") shouldBe null + } + + private fun testServer(): Server = Server( + serverInfo = Implementation(name = "test-server", version = "1.0.0"), + options = ServerOptions(capabilities = ServerCapabilities()), + ) +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt index 05200e59b..48e5ed623 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt @@ -26,7 +26,7 @@ class KtorApplicationExtensionsTest : AbstractKtorExtensionsTest() { @Test fun `Application mcp should installs SSE and coexist with other routes`() = testApplication { application { - mcp { testServer() } + mcp(enableDnsRebindingProtection = false) { testServer() } routing { get("/health") { call.respondText("healthy") } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt index 5ba4448f6..d6b95b4fe 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt @@ -61,7 +61,7 @@ class KtorRouteExtensionsTest : AbstractKtorExtensionsTest() { route("/api/mcp") { get("/test") { call.respondText("test-endpoint") } - mcp { testServer() } + mcp(enableDnsRebindingProtection = false) { testServer() } } } } @@ -91,7 +91,7 @@ class KtorRouteExtensionsTest : AbstractKtorExtensionsTest() { route("/v1") { route("/services") { route("/mcp") { - mcp { testServer() } + mcp(enableDnsRebindingProtection = false) { testServer() } } } } @@ -112,7 +112,7 @@ class KtorRouteExtensionsTest : AbstractKtorExtensionsTest() { routing { route("/api") { - mcp("/mcp-endpoint") { testServer() } + mcp("/mcp-endpoint", enableDnsRebindingProtection = false) { testServer() } } } } From a3ee7fd8485a79d85f7093551e055a64c196f45b Mon Sep 17 00:00:00 2001 From: devcrocod Date: Tue, 31 Mar 2026 01:37:16 +0200 Subject: [PATCH 2/3] refactor: make `LOCALHOST_ALLOWED_HOSTS` internal to limit API exposure --- kotlin-sdk-server/api/kotlin-sdk-server.api | 1 - .../io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index a4f5c9ab0..007fcb672 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -44,7 +44,6 @@ public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventS public final class io/modelcontextprotocol/kotlin/sdk/server/HostValidationKt { public static final fun getDnsRebindingProtection ()Lio/ktor/server/application/RouteScopedPlugin; - public static final fun getLOCALHOST_ALLOWED_HOSTS ()Ljava/util/List; } public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt index 23df79fb9..3700aae80 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt @@ -17,7 +17,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError * Default list of hostnames allowed for localhost DNS rebinding protection. * Matches the TypeScript SDK's `localhostAllowedHostnames()`. */ -public val LOCALHOST_ALLOWED_HOSTS: List = listOf("localhost", "127.0.0.1", "[::1]") +internal val LOCALHOST_ALLOWED_HOSTS: List = listOf("localhost", "127.0.0.1", "[::1]") /** * Extracts the hostname from a Host header value, stripping port and normalizing IPv6. From 8b7b725a935020885ab91b8f511d6361888d3688 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Tue, 31 Mar 2026 02:08:44 +0200 Subject: [PATCH 3/3] replace URLBuilder-based host parsing with strict Host header parser --- .../kotlin/sdk/server/HostValidation.kt | 38 ++++++++++++++----- .../server/StreamableHttpServerTransport.kt | 15 ++++++-- .../sdk/server/DnsRebindingProtectionTest.kt | 25 ++++++++++++ .../server/KtorApplicationExtensionsTest.kt | 2 +- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt index 3700aae80..958c6e686 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/HostValidation.kt @@ -3,7 +3,6 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.ktor.http.ContentType import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode -import io.ktor.http.URLBuilder import io.ktor.server.application.ApplicationCall import io.ktor.server.application.RouteScopedPlugin import io.ktor.server.application.createRouteScopedPlugin @@ -20,23 +19,44 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError internal val LOCALHOST_ALLOWED_HOSTS: List = listOf("localhost", "127.0.0.1", "[::1]") /** - * Extracts the hostname from a Host header value, stripping port and normalizing IPv6. + * Characters that are valid in a URL but must not appear in an HTTP `Host` header. + * Rejecting them prevents the parser from accepting malformed values + * (e.g. `evil.com@localhost`, `host/path`) that a generic URL parser would silently allow. + */ +private val FORBIDDEN_HOST_CHARS: CharArray = charArrayOf('/', '@', '?', '#') + +/** + * Extracts the hostname from a Host header value, stripping the port. + * + * Only accepts the strict `host [ ":" port ]` / `"[" ipv6 "]" [ ":" port ]` + * format defined by RFC 7230. Values containing URL-only characters + * (`/`, `@`, `?`, `#`) or whitespace are rejected. * * Examples: * - `"localhost:3000"` → `"localhost"` * - `"127.0.0.1:8080"` → `"127.0.0.1"` * - `"[::1]:3000"` → `"[::1]"` * - `"example.com"` → `"example.com"` + * - `"evil.com@localhost"` → `null` * - * @return the hostname, or `null` if parsing fails. + * @return the hostname, or `null` if the value is blank, malformed, or contains forbidden characters. */ -internal fun extractHostname(hostHeader: String): String? { - if (hostHeader.isBlank()) return null - return try { - URLBuilder("http://$hostHeader").build().host.ifEmpty { null } - } catch (_: Exception) { - null +internal fun extractHostname(hostHeader: String): String? = when { + hostHeader.isBlank() -> null + + hostHeader.any { it in FORBIDDEN_HOST_CHARS || it.isWhitespace() } -> null + + hostHeader.startsWith("[") -> { + val end = hostHeader.indexOf(']') + val tail = if (end > 0) hostHeader.substring(end + 1) else "" + if (end > 0 && (tail.isEmpty() || tail.startsWith(':'))) { + hostHeader.substring(0, end + 1) + } else { + null + } } + + else -> hostHeader.substringBefore(':').ifEmpty { null } } /** diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 882a8515a..75cb02ab5 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -144,11 +144,20 @@ public class StreamableHttpServerTransport(private val configuration: Configurat */ public class Configuration( public val enableJsonResponse: Boolean = false, - @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") + @Deprecated( + message = "Use install(DnsRebindingProtection) on your Ktor route instead", + level = DeprecationLevel.WARNING, + ) public val enableDnsRebindingProtection: Boolean = false, - @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") + @Deprecated( + message = "Use install(DnsRebindingProtection) on your Ktor route instead", + level = DeprecationLevel.WARNING, + ) public val allowedHosts: List? = null, - @Deprecated("Use install(DnsRebindingProtection) on your Ktor route instead") + @Deprecated( + message = "Use install(DnsRebindingProtection) on your Ktor route instead", + level = DeprecationLevel.WARNING, + ) public val allowedOrigins: List? = null, public val eventStore: EventStore? = null, public val retryInterval: Duration? = null, diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt index 1ff5eb375..95a96fea3 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/DnsRebindingProtectionTest.kt @@ -235,6 +235,31 @@ class DnsRebindingProtectionTest { extractHostname("") shouldBe null } + @Test + fun `extractHostname rejects userinfo in host header`() { + extractHostname("evil.com@localhost") shouldBe null + } + + @Test + fun `extractHostname rejects path in host header`() { + extractHostname("evil.com/path") shouldBe null + } + + @Test + fun `extractHostname rejects query in host header`() { + extractHostname("evil.com?q=1") shouldBe null + } + + @Test + fun `extractHostname rejects fragment in host header`() { + extractHostname("evil.com#frag") shouldBe null + } + + @Test + fun `extractHostname rejects malformed IPv6`() { + extractHostname("[::1") shouldBe null + } + private fun testServer(): Server = Server( serverInfo = Implementation(name = "test-server", version = "1.0.0"), options = ServerOptions(capabilities = ServerCapabilities()), diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt index 48e5ed623..988886791 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt @@ -24,7 +24,7 @@ class KtorApplicationExtensionsTest : AbstractKtorExtensionsTest() { * added to the same application. */ @Test - fun `Application mcp should installs SSE and coexist with other routes`() = testApplication { + fun `Application mcp should install SSE and coexist with other routes`() = testApplication { application { mcp(enableDnsRebindingProtection = false) { testServer() }