Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ object TestAppShared {
val values: Vector[TestClientTestCase] = Vector(
DiscoverNewPod,
DnsFailureRecover,
ResolveShortDomainName,
)

/**
Expand All @@ -165,5 +166,14 @@ object TestAppShared {
* - check that after the configured reload TTL, the client sees both servers
*/
case object DnsFailureRecover extends TestClientTestCase

/**
* [[TestClient]] test case verifying that `K8sDnsNameResolver` resolves short k8s
* service domain names using resolv.conf search domains.
*
* For the reasons why this needs a separate test case, see
* [[com.evolution.jgrpc.tools.k8sdns.NameLookupState]].
*/
case object ResolveShortDomainName extends TestClientTestCase
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ private[it] final class TestClient {
testCaseDiscoverNewPod(fixture)
case TestClientTestCase.DnsFailureRecover =>
testCaseDnsFailureRecover(fixture)
case TestClientTestCase.ResolveShortDomainName =>
testCaseResolveShortDomainName(fixture)
}
}

Expand All @@ -86,22 +88,22 @@ private[it] final class TestClient {
private def testCaseDiscoverNewPod(fixture: Fixture): Unit = {
fixture.coreDns.ensureStarted(serviceIps = Set(fixture.srv1Ip))

withRoundRobinLbClient { client =>
callHost2TimesAssertServerIds(client, expectedServerIds = Set(1))
withRoundRobinLbClient() { client =>
callHostManyTimesAssertServerIds(client, expectedServerIds = Set(1))

fixture.coreDns.setServiceIps(Set(fixture.srv1Ip, fixture.srv2Ip))

sleepUntilClientGetsDnsUpdate()

callHost2TimesAssertServerIds(client, expectedServerIds = Set(1, 2))
callHostManyTimesAssertServerIds(client, expectedServerIds = Set(1, 2))
}
}

private def testCaseDnsFailureRecover(fixture: Fixture): Unit = {
fixture.coreDns.ensureStarted(serviceIps = Set(fixture.srv1Ip))

withRoundRobinLbClient { client =>
callHost2TimesAssertServerIds(client, expectedServerIds = Set(1))
withRoundRobinLbClient() { client =>
callHostManyTimesAssertServerIds(client, expectedServerIds = Set(1))

fixture.coreDns.ensureStopped()

Expand All @@ -111,7 +113,15 @@ private[it] final class TestClient {

sleepUntilClientGetsDnsUpdate()

callHost2TimesAssertServerIds(client, expectedServerIds = Set(1, 2))
callHostManyTimesAssertServerIds(client, expectedServerIds = Set(1, 2))
}
}

/**
* Test case body for [[TestClientTestCase.ResolveShortDomainName]].
*
* Unlike the other test cases, which use the default `svcHostname` target, this one
* points the client at `svcHostnameShort` and checks that both servers are reachable.
*/
private def testCaseResolveShortDomainName(fixture: Fixture): Unit = {
// Make sure CoreDNS is started with the IPs of both servers; DNS data is not changed afterwards.
fixture.coreDns.ensureStarted(serviceIps = Set(fixture.srv1Ip, fixture.srv2Ip))

// Override the default target hostname with the short one.
withRoundRobinLbClient(targetHostname = svcHostnameShort) { client =>
// Fails unless the observed server IDs are exactly 1 and 2.
callHostManyTimesAssertServerIds(client, expectedServerIds = Set(1, 2))
}
}

Expand All @@ -125,21 +135,44 @@ private[it] final class TestClient {
Thread.sleep(sleepIntervalSeconds.toLong * 1000)
}

private def callHost2TimesAssertServerIds(
private def callHostManyTimesAssertServerIds(
client: TestSvcBlockingStub,
expectedServerIds: Set[Int],
): Unit = {
val actualServerIds = 0.until(2).map { _ =>
require(expectedServerIds.subsetOf(allServerIds))

var observedServerIdsVec: Vector[Int] = Vector.fill(allServerIds.size) {
client.getId(GetIdRequest()).id
}.toSet
if (actualServerIds != expectedServerIds) {
sys.error(s"GRPC client observed server IDs $actualServerIds, expected $expectedServerIds")
}

// While the client is still establishing connections, several consecutive calls
// do not always show the expected round-robin pattern (1, 2, 1, 2, ...).
// Instead, the client appears to route every call to the first opened connection
// while waiting for the remaining ones to become fully ready.
// So if we haven't observed all the servers yet, wait a bit and call again.
// This helps to recover from such cases.
if (observedServerIdsVec.toSet.size < allServerIds.size) {
Thread.sleep(1000L)
observedServerIdsVec = observedServerIdsVec ++ Vector.fill(allServerIds.size) {
client.getId(GetIdRequest()).id
}
}

val observedServerIds = observedServerIdsVec.toSet
if (observedServerIds != expectedServerIds) {
sys.error(s"GRPC client expected server IDs $expectedServerIds, " +
s"observed $observedServerIds (received in order: $observedServerIdsVec)")
}
}

private def withRoundRobinLbClient[T](body: TestSvcBlockingStub => T): T = {
private def withRoundRobinLbClient[T](
targetHostname: String = svcHostname,
)(
body: TestSvcBlockingStub => T,
): T = {
val channel = NettyChannelBuilder
.forTarget(s"k8s-dns://$svcHostname:${ TestAppShared.ServerPort }")
.forTarget(s"k8s-dns://$targetHostname:${ TestAppShared.ServerPort }")
.usePlaintext()
.defaultLoadBalancingPolicy("round_robin")
.build()
Expand All @@ -156,7 +189,10 @@ private[it] final class TestClient {
}

private object TestClient {
private val svcHostname: String = "svc.example.org"
private val allServerIds = Set(1, 2)
private val clusterHostnameSuffix = "svc.cluster.local"
private val svcHostnameShort = "acme-grpc.acme"
private val svcHostname = s"$svcHostnameShort.$clusterHostnameSuffix"
private val resolveConfPath = "/etc/resolv.conf"

private val coreDnsCoreFilePath = "/etc/coredns/CoreFile"
Expand All @@ -178,6 +214,8 @@ private object TestClient {
Paths.get(resolveConfPath),
Vector(
"nameserver 127.0.0.1",
s"search $clusterHostnameSuffix",
"options ndots:5",
).asJava,
StandardOpenOption.TRUNCATE_EXISTING,
)
Expand Down Expand Up @@ -224,7 +262,7 @@ private object TestClient {
private def writeCoreFile(): Unit = {
Files.writeString(
Paths.get(coreDnsCoreFilePath),
s"""$svcHostname {
s""". {
| hosts $coreDnsHostsFilePath {
| ttl $coreDnsHostsReloadIntervalSeconds
| reload ${ coreDnsHostsReloadIntervalSeconds }s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class K8sDnsNameResolverIt extends AnyFreeSpec with BeforeAndAfterAll {
"should recover after DNS query failure" in {
runTestCase(TestClientTestCase.DnsFailureRecover)
}

"should resolve short domain names with search domains" in {
runTestCase(TestClientTestCase.ResolveShortDomainName)
}
}

private def runTestCase(testCase: TestClientTestCase): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import static java.lang.Math.max;
import static java.lang.String.format;

import com.google.common.net.InetAddresses;
import io.grpc.*;
import io.grpc.SynchronizationContext.ScheduledHandle;
import java.net.InetAddress;
Expand All @@ -12,18 +11,12 @@
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import org.jspecify.annotations.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.Name;
import org.xbill.DNS.Record;
import org.xbill.DNS.Type;
import org.xbill.DNS.lookup.LookupResult;
import org.xbill.DNS.lookup.LookupSession;

/* package */ final class K8sDnsNameResolver extends NameResolver {

Expand All @@ -33,7 +26,8 @@
private final long refreshIntervalSeconds;
private final SynchronizationContext syncCtx;
private final ScheduledExecutorService scheduledExecutor;
private final LookupSession dnsLookupSession;

private NameLookupState nameLookupState;

@Nullable private Listener listener = null;

Expand All @@ -52,8 +46,7 @@ private record SuccessResult(List<InetAddress> addresses, Instant receiveTime) {
this.refreshIntervalSeconds = refreshIntervalSeconds;
this.syncCtx = syncCtx;
this.scheduledExecutor = scheduledExecutor;
this.dnsLookupSession =
LookupSession.defaultBuilder().searchPath(targetUri.host()).clearCaches().build();
this.nameLookupState = NameLookupState.initialize(targetUri.host());
}

@Override
Expand Down Expand Up @@ -159,28 +152,21 @@ private EquivalentAddressGroup mkAddressGroup(InetAddress addr) {
// callback is executed under syncCtx
private void resolveAllAsync(
BiConsumer<@Nullable List<InetAddress>, ? super @Nullable Throwable> cb) {
final var dnsLookupAsyncResult = this.dnsLookupSession.lookupAsync(Name.empty, Type.A);
dnsLookupAsyncResult
.thenApply(
(result) -> {
logger.debug("DNS lookup result: {}", result);
var records =
Optional.ofNullable(result).map(LookupResult::getRecords).orElse(List.of());
return records.stream()
.map(Record::rdataToString)
.distinct()
.sorted() // make sure that result comparison does not depend on order
.map(InetAddresses::forString)
.toList();
})
nameLookupState
.runNextLookup()
.whenComplete(
(addresses, err) ->
(nameLookupState, err) ->
this.syncCtx.execute(
() -> {
if (err != null) {
logger.error("DNS lookup failed", err);
cb.accept(null, err);
} else {
this.nameLookupState = nameLookupState;
logger.debug(
"DNS lookup successful {}", this.nameLookupState.getLastResult());
cb.accept(this.nameLookupState.getLastResult(), null);
}
cb.accept(addresses, err);
}));
}

Expand Down
Loading