Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/java-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ jobs:
--file duo-client/pom.xml
- name: Test with Maven
run: >
mvn test
--batch-mode
mvn verify
--batch-mode
-file duo-client/pom.xml
- name: Lint with checkstyle
run: mvn checkstyle:check
25 changes: 25 additions & 0 deletions duo-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
<version>3.12.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-tls</artifactId>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -123,6 +135,19 @@
<parallel>methods</parallel>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.5</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.cyclonedx</groupId>
<artifactId>cyclonedx-maven-plugin</artifactId>
Expand Down
42 changes: 41 additions & 1 deletion duo-client/src/main/java/com/duosecurity/client/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class Http {
private Headers.Builder headers;
private SortedMap<String, Object> params = new TreeMap<String, Object>();
protected int sigVersion = 5;
private long maxBackoffMs = MAX_BACKOFF_MS;
private Random random = new Random();
private OkHttpClient httpClient;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
Expand Down Expand Up @@ -314,10 +315,14 @@ private Response executeRequest(Request request) throws Exception {
long backoffMs = INITIAL_BACKOFF_MS;
while (true) {
Response response = httpClient.newCall(request).execute();
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > MAX_BACKOFF_MS) {
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > maxBackoffMs) {
return response;
}

// Close the 429 response to release the connection back to the pool before retrying
if (response.body() != null) {
response.close();
}
sleep(backoffMs + nextRandomInt(1000));
backoffMs *= BACKOFF_FACTOR;
}
Expand All @@ -327,6 +332,13 @@ protected void sleep(long ms) throws Exception {
Thread.sleep(ms);
}

protected void setMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;
}

public void signRequest(String ikey, String skey)
throws UnsupportedEncodingException {
signRequest(ikey, skey, sigVersion);
Expand Down Expand Up @@ -529,6 +541,7 @@ protected abstract static class ClientBuilder<T extends Http> {
private final String uri;

private int timeout = DEFAULT_TIMEOUT_SECS;
private long maxBackoffMs = MAX_BACKOFF_MS;
private String[] caCerts = null;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
private Map<String, String> headers = new HashMap<String, String>();
Expand Down Expand Up @@ -558,6 +571,32 @@ public ClientBuilder<T> useTimeout(int timeout) {
return this;
}

/**
* Set the maximum base backoff time in milliseconds for rate limit (429) retries.
* When a request receives a 429 response, the client retries with exponential
* backoff until the base backoff exceeds this threshold. Note that actual sleep
* time includes up to 1000ms of random jitter on top of the base backoff.
* Setting to 0 disables retries (as does any value below the initial
* backoff of 1000ms). Default is 32000ms (32 seconds).
*
* <p>Note: When using method chaining from outside this package (e.g. with
* {@code AuthBuilder} or {@code AdminBuilder}), assign the builder to a variable
* and call methods separately, then call {@code build()}. This is a known
* limitation of all {@code ClientBuilder} methods.
*
* @param maxBackoffMs the maximum base backoff in milliseconds (must be >= 0)
* @return the Builder
* @throws IllegalArgumentException if maxBackoffMs is negative
*/
public ClientBuilder<T> useMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;

return this;
}

/**
* Provide custom CA certificates for certificate pinning.
*
Expand Down Expand Up @@ -604,6 +643,7 @@ public ClientBuilder<T> addHeader(String name, String value) {
*/
public T build() {
T duoClient = createClient(method, host, uri, timeout);
duoClient.setMaxBackoffMs(maxBackoffMs);
if (caCerts != null) {
duoClient.useCustomCertificates(caCerts);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package com.duosecurity.client;

import okhttp3.OkHttpClient;
import okhttp3.Response;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.tls.HandshakeCertificates;
import okhttp3.tls.HeldCertificate;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import java.lang.reflect.Field;

import static org.junit.Assert.assertEquals;

public class HttpRateLimitRetryIntegrationIT {

private MockWebServer server;
private HandshakeCertificates clientCerts;

@Before
public void setUp() throws Exception {
HeldCertificate serverCert = new HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build();

HandshakeCertificates serverCerts = new HandshakeCertificates.Builder()
.heldCertificate(serverCert)
.build();

clientCerts = new HandshakeCertificates.Builder()
.addTrustedCertificate(serverCert.certificate())
.build();

server = new MockWebServer();
server.useHttps(serverCerts.sslSocketFactory(), false);
server.start();
}

@After
public void tearDown() throws Exception {
server.shutdown();
}

/**
* Builds an Http spy pointing at the MockWebServer, with sleep() stubbed out to avoid real
* delays and the OkHttpClient replaced with one that trusts the test certificate.
*
* <p>The builder must be constructed with host "localhost" (no port) so that CertificatePinner
* accepts the pattern. This method then sets the real host (with port) and replaces the
* OkHttpClient via reflection before the spy is used.
*/
private Http buildSpyHttp(Http.ClientBuilder<Http> builder) throws Exception {
Http spy = Mockito.spy(builder.build());
Mockito.doNothing().when(spy).sleep(Mockito.any(Long.class));

// Point the host at the MockWebServer port (CertificatePinner rejects host:port patterns,
// so the builder uses "localhost" and we fix it here after construction).
Field hostField = Http.class.getDeclaredField("host");
hostField.setAccessible(true);
hostField.set(spy, "localhost:" + server.getPort());

// Replace the OkHttpClient with one configured to trust the test certificate
OkHttpClient testClient = new OkHttpClient.Builder()
.sslSocketFactory(clientCerts.sslSocketFactory(), clientCerts.trustManager())
.build();

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
httpClientField.set(spy, testClient);

return spy;
}

private Http.HttpBuilder defaultBuilder() {
// Use "localhost" without a port — CertificatePinner rejects host:port patterns.
// buildSpyHttp sets the real host (with port) via reflection after construction.
return new Http.HttpBuilder("GET", "localhost", "/foo/bar");
}

@Test
public void testSingleRateLimitRetry() throws Exception {
server.enqueue(new MockResponse().setResponseCode(429));
server.enqueue(new MockResponse().setResponseCode(200));

Http http = buildSpyHttp(defaultBuilder());
Response response = http.executeHttpRequest();

assertEquals(200, response.code());
assertEquals(2, server.getRequestCount());
Mockito.verify(http, Mockito.times(1)).sleep(Mockito.any(Long.class));
}

@Test
public void testRateLimitExhaustsDefaultMaxBackoff() throws Exception {
// Enqueue more responses than will ever be consumed
for (int i = 0; i < 10; i++) {
server.enqueue(new MockResponse().setResponseCode(429));
}

Http http = buildSpyHttp(defaultBuilder());
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
// Default max backoff (32s): sleeps at 1s, 2s, 4s, 8s, 16s, 32s = 6 sleeps, 7 total requests
assertEquals(7, server.getRequestCount());
Mockito.verify(http, Mockito.times(6)).sleep(Mockito.any(Long.class));
}

@Test
public void testCustomMaxBackoffLimitsRetries() throws Exception {
for (int i = 0; i < 10; i++) {
server.enqueue(new MockResponse().setResponseCode(429));
}

Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(4000));
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
// maxBackoff=4000: sleeps at 1s, 2s, 4s = 3 sleeps, 4 total requests (next would be 8s > 4s)
assertEquals(4, server.getRequestCount());
Mockito.verify(http, Mockito.times(3)).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
server.enqueue(new MockResponse().setResponseCode(429));

Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(0));
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
assertEquals(1, server.getRequestCount());
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ public class HttpRateLimitRetryTest {

private final int RANDOM_INT = 234;

@Before
public void before() throws Exception {
http = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
http = Mockito.spy(http);
private void setupHttp(Http client) throws Exception {
http = Mockito.spy(client);

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
Expand All @@ -39,6 +37,12 @@ public void before() throws Exception {
Mockito.doNothing().when(http).sleep(Mockito.any(Long.class));
}

@Before
public void before() throws Exception {
Http client = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
setupHttp(client);
}

@Test
public void testSingleRateLimitRetry() throws Exception {
final List<Response> responses = new ArrayList<Response>();
Expand Down Expand Up @@ -128,4 +132,87 @@ public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
assertEquals(16000L + RANDOM_INT, (long) sleepTimes.get(4));
assertEquals(32000L + RANDOM_INT, (long) sleepTimes.get(5));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(0)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

Response actualRes = http.executeHttpRequest();
assertEquals(1, responses.size());
assertEquals(429, actualRes.code());

// Verify no sleep was called
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffCustomLimit() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(4000)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

// With maxBackoff=4000, retries at 1000, 2000, 4000, then 8000 > 4000 exits
// That's 4 total requests (1 initial + 3 retries)
Response actualRes = http.executeHttpRequest();
assertEquals(4, responses.size());
assertEquals(429, actualRes.code());

ArgumentCaptor<Long> sleepCapture = ArgumentCaptor.forClass(Long.class);
Mockito.verify(http, Mockito.times(3)).sleep(sleepCapture.capture());
List<Long> sleepTimes = sleepCapture.getAllValues();
assertEquals(1000L + RANDOM_INT, (long) sleepTimes.get(0));
assertEquals(2000L + RANDOM_INT, (long) sleepTimes.get(1));
assertEquals(4000L + RANDOM_INT, (long) sleepTimes.get(2));
}

@Test(expected = IllegalArgumentException.class)
public void testMaxBackoffNegativeThrows() {
new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(-1)
.build();
}
}