diff --git a/dialogue-apache-hc5-client/build.gradle b/dialogue-apache-hc5-client/build.gradle index 02687cc7d..858312f2b 100644 --- a/dialogue-apache-hc5-client/build.gradle +++ b/dialogue-apache-hc5-client/build.gradle @@ -20,9 +20,11 @@ dependencies { implementation 'io.dropwizard.metrics:metrics-core' implementation 'org.apache.httpcomponents.core5:httpcore5' + testImplementation project(':dialogue-clients') testImplementation project(':dialogue-test-common') testImplementation project(':dialogue-serde') testImplementation 'org.awaitility:awaitility' + testImplementation 'com.github.ben-manes.caffeine:caffeine' testImplementation 'org.junit.jupiter:junit-jupiter' testRuntimeOnly 'org.apache.logging.log4j:log4j-slf4j-impl' testRuntimeOnly 'org.apache.logging.log4j:log4j-core' diff --git a/dialogue-apache-hc5-client/src/test/java/com/palantir/dialogue/hc5/NoResponseTest.java b/dialogue-apache-hc5-client/src/test/java/com/palantir/dialogue/hc5/NoResponseTest.java index 364dcb4b3..b7de6e194 100644 --- a/dialogue-apache-hc5-client/src/test/java/com/palantir/dialogue/hc5/NoResponseTest.java +++ b/dialogue-apache-hc5-client/src/test/java/com/palantir/dialogue/hc5/NoResponseTest.java @@ -19,30 +19,58 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.Uninterruptibles; +import com.palantir.conjure.java.api.config.service.PartialServiceConfiguration; +import com.palantir.conjure.java.api.config.service.ServicesConfigBlock; import com.palantir.conjure.java.client.config.ClientConfiguration; import com.palantir.conjure.java.client.config.ClientConfigurations; import com.palantir.conjure.java.config.ssl.SslSocketFactories; +import com.palantir.conjure.java.dialogue.serde.ConjureBodySerDe; +import com.palantir.conjure.java.dialogue.serde.ConjurePlainSerDe; +import com.palantir.conjure.java.dialogue.serde.DefaultClients; +import com.palantir.conjure.java.dialogue.serde.DefaultConjureRuntime; +import com.palantir.conjure.java.dialogue.serde.Encodings; +import com.palantir.conjure.java.dialogue.serde.Encodings.LimitedSizeEncoding; +import com.palantir.conjure.java.dialogue.serde.ErrorDecoder; +import com.palantir.conjure.java.dialogue.serde.WeightedEncoding; +import com.palantir.dialogue.BodySerDe; import com.palantir.dialogue.Channel; +import com.palantir.dialogue.Clients; +import com.palantir.dialogue.ConjureRuntime; +import com.palantir.dialogue.PlainSerDe; import com.palantir.dialogue.Request; import com.palantir.dialogue.Response; import com.palantir.dialogue.TestConfigurations; import com.palantir.dialogue.TestEndpoint; +import com.palantir.dialogue.clients.DialogueClients; +import com.palantir.dialogue.clients.DialogueClients.ReloadingFactory; +import com.palantir.logsafe.exceptions.SafeIoException; +import com.palantir.refreshable.Refreshable; import io.undertow.Undertow; import io.undertow.server.HttpHandler; import io.undertow.server.handlers.BlockingHandler; +import io.undertow.util.Headers; import java.io.OutputStream; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketTimeoutException; +import java.net.UnknownHostException; import java.time.Duration; +import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; import org.xnio.IoUtils; +@ExtendWith(MockitoExtension.class) public final class NoResponseTest { + private static final int MAX_BYTES = 1024; + private static final int PORT = 8080; private static Channel create(ClientConfiguration config) { return ApacheHttpClientChannels.create(config, "test"); @@ -85,6 +113,82 @@ public void testConnectionClosedAfterDelay() { } } + @Test + public void testLargePayload() { + String randomHostname = UUID.randomUUID().toString(); + + // Set up Undertow server with a large payload + String largePayload = "a".repeat(MAX_BYTES * 2); // This payload should exceed the maximum allowed size + Undertow server = createUndertowServerWithPayload(largePayload); + server.start(); + + try { + ReloadingFactory factory = DialogueClients.create(Refreshable.only(ServicesConfigBlock.builder() + .defaultSecurity(TestConfigurations.SSL_CONFIG) + .putServices( + "foo", + PartialServiceConfiguration.builder() + .addUris(getUri(server, randomHostname)) + .build()) + .build())) + .withUserAgent(TestConfigurations.AGENT) + .withDnsResolver(hostname -> { + if (randomHostname.equals(hostname)) { + try { + return ImmutableSet.of( + InetAddress.getByAddress(randomHostname, new byte[] {127, 0, 0, 1})); + } catch (UnknownHostException ignored) { + // fall-through + } + } + return ImmutableSet.of(); + }); + + ReloadingFactory reloadingFactory = factory.withRuntime(new ConjureRuntime() { + @Override + public BodySerDe bodySerDe() { + return new ConjureBodySerDe( + ImmutableList.of(WeightedEncoding.of(new LimitedSizeEncoding(/* maxBytes= */ 1024))), + ErrorDecoder.INSTANCE, + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + } + + @Override + public PlainSerDe plainSerDe() { + return ConjurePlainSerDe.INSTANCE; + } + + @Override + public Clients clients() { + return DefaultClients.INSTANCE; + } + }); + + Channel channel = reloadingFactory.getChannel("foo"); + + ListenableFuture response = + channel.execute(TestEndpoint.GET, Request.builder().build()); + + // Verify that the deserializer throws the expected exception due to the large payload + assertThatThrownBy(response::get) + .hasCauseInstanceOf(SafeIoException.class) + .hasMessageContaining("Deserialization exceeded the maximum allowed size"); + } finally { + server.stop(); + } + } + + private static Undertow createUndertowServerWithPayload(String payload) { + return Undertow.builder() + .addHttpListener(PORT, "localhost") + .setHandler(new BlockingHandler(exchange -> { + exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, "application/json"); + exchange.getResponseSender().send(payload); + })) + .build(); + } + @Test public void testIdleConnectionClosed() throws Exception { // Pooled connection should be reused, and retried if they've @@ -122,6 +226,13 @@ private static ClientConfiguration defaultClientConfig(int port) { TestConfigurations.AGENT); } + private static String getUri(Undertow undertow, String hostname) { + Undertow.ListenerInfo listenerInfo = Iterables.getOnlyElement(undertow.getListenerInfo()); + return String.format( + "%s://%s:%d", + listenerInfo.getProtcol(), hostname, ((InetSocketAddress) listenerInfo.getAddress()).getPort()); + } + private static void assertSuccessfulRequest(Channel channel) throws Exception { try (Response response = channel.execute(TestEndpoint.POST, request).get()) { assertThat(response.code()).isEqualTo(200); diff --git a/dialogue-clients/src/main/java/com/palantir/dialogue/clients/ReloadingClientFactory.java b/dialogue-clients/src/main/java/com/palantir/dialogue/clients/ReloadingClientFactory.java index be1497d10..20e39464a 100644 --- a/dialogue-clients/src/main/java/com/palantir/dialogue/clients/ReloadingClientFactory.java +++ b/dialogue-clients/src/main/java/com/palantir/dialogue/clients/ReloadingClientFactory.java @@ -79,6 +79,7 @@ final class ReloadingClientFactory implements DialogueClients.ReloadingFactory { } @Override + @SuppressWarnings("NullAway") public Channel getNonReloadingChannel(String channelName, ClientConfiguration input) { ClientConfiguration clientConf = hydrate(input); ApacheHttpClientChannels.ClientBuilder clientBuilder = ApacheHttpClientChannels.clientBuilder() diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index b738b8529..dff9f0a95 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -49,7 +49,7 @@ import java.util.stream.Collectors; /** Package private internal API. */ -final class ConjureBodySerDe implements BodySerDe { +public final class ConjureBodySerDe implements BodySerDe { private static final SafeLogger log = SafeLoggerFactory.get(ConjureBodySerDe.class); private final List encodingsSortedByWeight; @@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe { * {@link Encoding#supportsContentType supports} the serialization format {@link HttpHeaders#ACCEPT accepted} * by a given request, or the first serializer if no such serializer can be found. */ - ConjureBodySerDe( + public ConjureBodySerDe( List rawEncodings, ErrorDecoder errorDecoder, EmptyContainerDeserializer emptyContainerDeserializer, diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjurePlainSerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjurePlainSerDe.java index 5a02ef22a..a72e80eba 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjurePlainSerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjurePlainSerDe.java @@ -25,7 +25,7 @@ import java.util.UUID; /** Package private internal API. */ -enum ConjurePlainSerDe implements PlainSerDe { +public enum ConjurePlainSerDe implements PlainSerDe { INSTANCE; @DoNotLog diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultClients.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultClients.java index fb7549029..705a16de6 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultClients.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultClients.java @@ -46,7 +46,7 @@ import java.util.concurrent.ExecutionException; /** Package private internal API. */ -enum DefaultClients implements Clients { +public enum DefaultClients implements Clients { INSTANCE; private static final SafeLogger log = SafeLoggerFactory.get(DefaultClients.class); diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java index 3e4766fda..eefcd51da 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java @@ -17,7 +17,6 @@ package com.palantir.conjure.java.dialogue.serde; import com.github.benmanes.caffeine.cache.CaffeineSpec; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.palantir.dialogue.BodySerDe; @@ -31,8 +30,7 @@ * {@link DefaultConjureRuntime} provides functionality required by generated handlers. */ public final class DefaultConjureRuntime implements ConjureRuntime { - @VisibleForTesting - static final CaffeineSpec DEFAULT_SERDE_CACHE_SPEC = + public static final CaffeineSpec DEFAULT_SERDE_CACHE_SPEC = CaffeineSpec.parse("maximumSize=1000,expireAfterAccess=1m,weakKeys,weakValues"); static final ImmutableList DEFAULT_ENCODINGS = ImmutableList.of( diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/Encodings.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/Encodings.java index 1e94c1a17..3c5b4de1b 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/Encodings.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/Encodings.java @@ -24,6 +24,7 @@ import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.TypeMarker; import com.palantir.logsafe.Preconditions; +import com.palantir.logsafe.exceptions.SafeIoException; import java.io.InputStream; import java.util.function.Supplier; import javax.annotation.Nullable; @@ -84,6 +85,60 @@ public boolean supportsContentType(String contentType) { }; } + public static final class LimitedSizeEncoding implements Encoding { + private final Encoding jsonEncoding; + private final int maxBytes; + + public LimitedSizeEncoding(int maxBytes) { + this.jsonEncoding = Encodings.json(); + this.maxBytes = maxBytes; + } + + @Override + public Serializer serializer(TypeMarker type) { + return jsonEncoding.serializer(type); + } + + @Override + public Deserializer deserializer(TypeMarker type) { + Deserializer delegate = jsonEncoding.deserializer(type); + return input -> { + int chunkSize = 1024; // set this to a suitable size + byte[] buffer = new byte[chunkSize]; + + int bytesRead; + int totalBytes = 0; + while ((bytesRead = input.readNBytes(buffer, 0, chunkSize)) > 0) { + totalBytes += bytesRead; + if (totalBytes > maxBytes) { + throw new SafeIoException("Deserialization exceeded the maximum allowed size"); + } + } + + // Reset the input stream to the beginning + if (input.markSupported()) { + input.reset(); + } else { + throw new SafeIoException("Cannot reset the input stream"); + } + + // Now delegate to JSON deserializer + T value = delegate.deserialize(input); + return Preconditions.checkNotNull(value, "cannot deserialize a JSON null value"); + }; + } + + @Override + public String getContentType() { + return jsonEncoding.getContentType(); + } + + @Override + public boolean supportsContentType(String contentType) { + return jsonEncoding.supportsContentType(contentType); + } + } + /** Returns a serializer for the Conjure CBOR wire format. */ public static Encoding cbor() { return new AbstractJacksonEncoding(configure(ObjectMappers.newCborClientObjectMapper())) { @@ -118,7 +173,7 @@ public boolean supportsContentType(String contentType) { }; } - static EmptyContainerDeserializer emptyContainerDeserializer() { + public static EmptyContainerDeserializer emptyContainerDeserializer() { return new JacksonEmptyContainerLoader(JSON_MAPPER.get()); } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/WeightedEncoding.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/WeightedEncoding.java index 010ee32b1..06a0cbc3e 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/WeightedEncoding.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/WeightedEncoding.java @@ -24,7 +24,7 @@ * Note that the weight may not be applied to the Accept header, rather * used to order values. */ -final class WeightedEncoding { +public final class WeightedEncoding { private final Encoding encoding; private final double weight; @@ -39,7 +39,7 @@ static WeightedEncoding of(Encoding encoding, double weight) { return new WeightedEncoding(encoding, weight); } - static WeightedEncoding of(Encoding encoding) { + public static WeightedEncoding of(Encoding encoding) { return new WeightedEncoding(encoding, 1); }