Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gc/limit number of bytes from dialogue #2290

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dialogue-apache-hc5-client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ dependencies {
implementation 'io.dropwizard.metrics:metrics-core'
implementation 'org.apache.httpcomponents.core5:httpcore5'

testImplementation project(':dialogue-clients')
testImplementation project(':dialogue-test-common')
testImplementation project(':dialogue-serde')
testImplementation 'org.awaitility:awaitility'
testImplementation 'com.github.ben-manes.caffeine:caffeine'
testImplementation 'org.junit.jupiter:junit-jupiter'
testRuntimeOnly 'org.apache.logging.log4j:log4j-slf4j-impl'
testRuntimeOnly 'org.apache.logging.log4j:log4j-core'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,58 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.Uninterruptibles;
import com.palantir.conjure.java.api.config.service.PartialServiceConfiguration;
import com.palantir.conjure.java.api.config.service.ServicesConfigBlock;
import com.palantir.conjure.java.client.config.ClientConfiguration;
import com.palantir.conjure.java.client.config.ClientConfigurations;
import com.palantir.conjure.java.config.ssl.SslSocketFactories;
import com.palantir.conjure.java.dialogue.serde.ConjureBodySerDe;
import com.palantir.conjure.java.dialogue.serde.ConjurePlainSerDe;
import com.palantir.conjure.java.dialogue.serde.DefaultClients;
import com.palantir.conjure.java.dialogue.serde.DefaultConjureRuntime;
import com.palantir.conjure.java.dialogue.serde.Encodings;
import com.palantir.conjure.java.dialogue.serde.Encodings.LimitedSizeEncoding;
import com.palantir.conjure.java.dialogue.serde.ErrorDecoder;
import com.palantir.conjure.java.dialogue.serde.WeightedEncoding;
import com.palantir.dialogue.BodySerDe;
import com.palantir.dialogue.Channel;
import com.palantir.dialogue.Clients;
import com.palantir.dialogue.ConjureRuntime;
import com.palantir.dialogue.PlainSerDe;
import com.palantir.dialogue.Request;
import com.palantir.dialogue.Response;
import com.palantir.dialogue.TestConfigurations;
import com.palantir.dialogue.TestEndpoint;
import com.palantir.dialogue.clients.DialogueClients;
import com.palantir.dialogue.clients.DialogueClients.ReloadingFactory;
import com.palantir.logsafe.exceptions.SafeIoException;
import com.palantir.refreshable.Refreshable;
import io.undertow.Undertow;
import io.undertow.server.HttpHandler;
import io.undertow.server.handlers.BlockingHandler;
import io.undertow.util.Headers;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.xnio.IoUtils;

@ExtendWith(MockitoExtension.class)
public final class NoResponseTest {
private static final int MAX_BYTES = 1024;
private static final int PORT = 8080;

private static Channel create(ClientConfiguration config) {
return ApacheHttpClientChannels.create(config, "test");
Expand Down Expand Up @@ -85,6 +113,82 @@ public void testConnectionClosedAfterDelay() {
}
}

@Test
public void testLargePayload() {
String randomHostname = UUID.randomUUID().toString();

// Set up Undertow server with a large payload
String largePayload = "a".repeat(MAX_BYTES * 2); // This payload should exceed the maximum allowed size
Undertow server = createUndertowServerWithPayload(largePayload);
server.start();

try {
ReloadingFactory factory = DialogueClients.create(Refreshable.only(ServicesConfigBlock.builder()
.defaultSecurity(TestConfigurations.SSL_CONFIG)
.putServices(
"foo",
PartialServiceConfiguration.builder()
.addUris(getUri(server, randomHostname))
.build())
.build()))
.withUserAgent(TestConfigurations.AGENT)
.withDnsResolver(hostname -> {
if (randomHostname.equals(hostname)) {
try {
return ImmutableSet.of(
InetAddress.getByAddress(randomHostname, new byte[] {127, 0, 0, 1}));
} catch (UnknownHostException ignored) {
// fall-through
}
}
return ImmutableSet.of();
});

ReloadingFactory reloadingFactory = factory.withRuntime(new ConjureRuntime() {
@Override
public BodySerDe bodySerDe() {
return new ConjureBodySerDe(
cgouttham marked this conversation as resolved.
Show resolved Hide resolved
ImmutableList.of(WeightedEncoding.of(new LimitedSizeEncoding(/* maxBytes= */ 1024))),
ErrorDecoder.INSTANCE,
Encodings.emptyContainerDeserializer(),
DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC);
}

@Override
public PlainSerDe plainSerDe() {
return ConjurePlainSerDe.INSTANCE;
}

@Override
public Clients clients() {
return DefaultClients.INSTANCE;
}
});

Channel channel = reloadingFactory.getChannel("foo");

ListenableFuture<Response> response =
channel.execute(TestEndpoint.GET, Request.builder().build());

// Verify that the deserializer throws the expected exception due to the large payload
assertThatThrownBy(response::get)
.hasCauseInstanceOf(SafeIoException.class)
.hasMessageContaining("Deserialization exceeded the maximum allowed size");
} finally {
server.stop();
}
}

private static Undertow createUndertowServerWithPayload(String payload) {
return Undertow.builder()
.addHttpListener(PORT, "localhost")
.setHandler(new BlockingHandler(exchange -> {
exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, "application/json");
exchange.getResponseSender().send(payload);
}))
.build();
}

@Test
public void testIdleConnectionClosed() throws Exception {
// Pooled connection should be reused, and retried if they've
Expand Down Expand Up @@ -122,6 +226,13 @@ private static ClientConfiguration defaultClientConfig(int port) {
TestConfigurations.AGENT);
}

private static String getUri(Undertow undertow, String hostname) {
Undertow.ListenerInfo listenerInfo = Iterables.getOnlyElement(undertow.getListenerInfo());
return String.format(
"%s://%s:%d",
listenerInfo.getProtcol(), hostname, ((InetSocketAddress) listenerInfo.getAddress()).getPort());
}

private static void assertSuccessfulRequest(Channel channel) throws Exception {
try (Response response = channel.execute(TestEndpoint.POST, request).get()) {
assertThat(response.code()).isEqualTo(200);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ final class ReloadingClientFactory implements DialogueClients.ReloadingFactory {
}

@Override
@SuppressWarnings("NullAway")
public Channel getNonReloadingChannel(String channelName, ClientConfiguration input) {
ClientConfiguration clientConf = hydrate(input);
ApacheHttpClientChannels.ClientBuilder clientBuilder = ApacheHttpClientChannels.clientBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
import java.util.stream.Collectors;

/** Package private internal API. */
final class ConjureBodySerDe implements BodySerDe {
public final class ConjureBodySerDe implements BodySerDe {

private static final SafeLogger log = SafeLoggerFactory.get(ConjureBodySerDe.class);
private final List<Encoding> encodingsSortedByWeight;
Expand All @@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe {
* {@link Encoding#supportsContentType supports} the serialization format {@link HttpHeaders#ACCEPT accepted}
* by a given request, or the first serializer if no such serializer can be found.
*/
ConjureBodySerDe(
public ConjureBodySerDe(
List<WeightedEncoding> rawEncodings,
ErrorDecoder errorDecoder,
EmptyContainerDeserializer emptyContainerDeserializer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import java.util.UUID;

/** Package private internal API. */
enum ConjurePlainSerDe implements PlainSerDe {
public enum ConjurePlainSerDe implements PlainSerDe {
INSTANCE;

@DoNotLog
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import java.util.concurrent.ExecutionException;

/** Package private internal API. */
enum DefaultClients implements Clients {
public enum DefaultClients implements Clients {
INSTANCE;

private static final SafeLogger log = SafeLoggerFactory.get(DefaultClients.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.palantir.conjure.java.dialogue.serde;

import com.github.benmanes.caffeine.cache.CaffeineSpec;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.palantir.dialogue.BodySerDe;
Expand All @@ -31,8 +30,7 @@
* {@link DefaultConjureRuntime} provides functionality required by generated handlers.
*/
public final class DefaultConjureRuntime implements ConjureRuntime {
@VisibleForTesting
static final CaffeineSpec DEFAULT_SERDE_CACHE_SPEC =
public static final CaffeineSpec DEFAULT_SERDE_CACHE_SPEC =
CaffeineSpec.parse("maximumSize=1000,expireAfterAccess=1m,weakKeys,weakValues");

static final ImmutableList<WeightedEncoding> DEFAULT_ENCODINGS = ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.palantir.conjure.java.serialization.ObjectMappers;
import com.palantir.dialogue.TypeMarker;
import com.palantir.logsafe.Preconditions;
import com.palantir.logsafe.exceptions.SafeIoException;
import java.io.InputStream;
import java.util.function.Supplier;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -84,6 +85,60 @@ public boolean supportsContentType(String contentType) {
};
}

public static final class LimitedSizeEncoding implements Encoding {
private final Encoding jsonEncoding;
private final int maxBytes;

public LimitedSizeEncoding(int maxBytes) {
this.jsonEncoding = Encodings.json();
this.maxBytes = maxBytes;
}

@Override
public <T> Serializer<T> serializer(TypeMarker<T> type) {
return jsonEncoding.serializer(type);
}

@Override
public <T> Deserializer<T> deserializer(TypeMarker<T> type) {
Deserializer<T> delegate = jsonEncoding.deserializer(type);
return input -> {
int chunkSize = 1024; // set this to a suitable size
byte[] buffer = new byte[chunkSize];

int bytesRead;
int totalBytes = 0;
while ((bytesRead = input.readNBytes(buffer, 0, chunkSize)) > 0) {
totalBytes += bytesRead;
if (totalBytes > maxBytes) {
throw new SafeIoException("Deserialization exceeded the maximum allowed size");
}
}

// Reset the input stream to the beginning
if (input.markSupported()) {
input.reset();
} else {
throw new SafeIoException("Cannot reset the input stream");
}

// Now delegate to JSON deserializer
T value = delegate.deserialize(input);
return Preconditions.checkNotNull(value, "cannot deserialize a JSON null value");
};
}

@Override
public String getContentType() {
return jsonEncoding.getContentType();
}

@Override
public boolean supportsContentType(String contentType) {
return jsonEncoding.supportsContentType(contentType);
}
}

/** Returns a serializer for the Conjure CBOR wire format. */
public static Encoding cbor() {
return new AbstractJacksonEncoding(configure(ObjectMappers.newCborClientObjectMapper())) {
Expand Down Expand Up @@ -118,7 +173,7 @@ public boolean supportsContentType(String contentType) {
};
}

static EmptyContainerDeserializer emptyContainerDeserializer() {
public static EmptyContainerDeserializer emptyContainerDeserializer() {
return new JacksonEmptyContainerLoader(JSON_MAPPER.get());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* Note that the weight may not be applied to the Accept header, rather
* used to order values.
*/
final class WeightedEncoding {
public final class WeightedEncoding {

private final Encoding encoding;
private final double weight;
Expand All @@ -39,7 +39,7 @@ static WeightedEncoding of(Encoding encoding, double weight) {
return new WeightedEncoding(encoding, weight);
}

static WeightedEncoding of(Encoding encoding) {
public static WeightedEncoding of(Encoding encoding) {
return new WeightedEncoding(encoding, 1);
}

Expand Down
Loading