Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Laub committed Nov 8, 2024
1 parent 35d9525 commit bed8cd1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
import com.palantir.dialogue.Response;
import com.palantir.dialogue.core.CautiousIncreaseAggressiveDecreaseConcurrencyLimiter.Behavior;
import com.palantir.dialogue.core.DialogueChannel.StateHolder;
import com.palantir.dialogue.core.DialogueChannel.StateHolderKey;
import com.palantir.dialogue.futures.DialogueFutures;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.exceptions.SafeIllegalArgumentException;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import com.palantir.tritium.metrics.registry.TaggedMetricRegistry;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.LongStream;
import org.immutables.value.Value;

Expand All @@ -43,26 +42,16 @@
final class ConcurrencyLimitedChannel implements LimitedChannel {
private static final SafeLogger log = SafeLoggerFactory.get(ConcurrencyLimitedChannel.class);

@Value.Immutable
interface ConcurrencyLimitedChannelState extends StateHolder {
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter hostLimiter();

CautiousIncreaseAggressiveDecreaseConcurrencyLimiter endpointLimiter();
}
private static final StateHolderKey<ConcurrencyLimitedChannelState> STATE_HOLDER_KEY =
new StateHolderKey<>(ConcurrencyLimitedChannelState.class, ConcurrencyLimitedChannel::createState);

private final NeverThrowChannel delegate;
private final CautiousIncreaseAggressiveDecreaseConcurrencyLimiter limiter;
private final String channelNameForLogging;

static LimitedChannel createForHost(
Config cf,
Channel channel,
int uriIndex,
TargetUri targetUri,
BiFunction<TargetUri, Supplier<StateHolder>, StateHolder> stateHolderFactory) {
static LimitedChannel createForHost(Config cf, Channel channel, int uriIndex, StateHolder stateHolder) {
TaggedMetricRegistry metrics = cf.clientConf().taggedMetricRegistry();
ConcurrencyLimitedChannelState state = (ConcurrencyLimitedChannelState)
stateHolderFactory.apply(targetUri, ConcurrencyLimitedChannel::createState);
ConcurrencyLimitedChannelState state = stateHolder.getState(STATE_HOLDER_KEY);
ConcurrencyLimitedChannelInstrumentation instrumentation = new HostConcurrencyLimitedChannelInstrumentation(
cf.channelName(), uriIndex, state.hostLimiter(), metrics);
return new ConcurrencyLimitedChannel(channel, state.hostLimiter(), instrumentation);
Expand All @@ -73,14 +62,8 @@ static LimitedChannel createForHost(
* Metrics are not reported by this component per-endpoint, only by the per-endpoint queue.
*/
static LimitedChannel createForEndpoint(
Channel channel,
String channelName,
int uriIndex,
TargetUri targetUri,
Endpoint endpoint,
BiFunction<TargetUri, Supplier<StateHolder>, StateHolder> stateHolderFactory) {
ConcurrencyLimitedChannelState state = (ConcurrencyLimitedChannelState)
stateHolderFactory.apply(targetUri, ConcurrencyLimitedChannel::createState);
Channel channel, String channelName, int uriIndex, Endpoint endpoint, StateHolder stateHolder) {
ConcurrencyLimitedChannelState state = stateHolder.getState(STATE_HOLDER_KEY);
return new ConcurrencyLimitedChannel(
channel,
state.endpointLimiter(),
Expand All @@ -103,10 +86,6 @@ static ConcurrencyLimitedChannelState createState() {
.build();
}

static CautiousIncreaseAggressiveDecreaseConcurrencyLimiter createLimiter(Behavior behavior) {
return new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(behavior);
}

@Override
public Optional<ListenableFuture<Response>> maybeExecute(
Endpoint endpoint, Request request, LimitEnforcement limitEnforcement) {
Expand Down Expand Up @@ -210,4 +189,11 @@ public String channelNameForLogging() {
return channelNameForLogging;
}
}

@Value.Immutable
interface ConcurrencyLimitedChannelState {
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter hostLimiter();

CautiousIncreaseAggressiveDecreaseConcurrencyLimiter endpointLimiter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import com.palantir.dialogue.EndpointChannelFactory;
import com.palantir.dialogue.Request;
import com.palantir.dialogue.Response;
import com.palantir.logsafe.Preconditions;
import com.palantir.logsafe.Safe;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.UnsafeArg;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import com.palantir.refreshable.Refreshable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -82,10 +84,38 @@ public String toString() {
+ cf.channelName() + ", delegate=" + delegate + '}';
}

interface StateHolder {}
static final class StateHolderKey<T> {
private final Class<T> valueClass;
private final Supplier<T> factory;

interface TargetUriStateSupplier extends Function<List<TargetUri>, LimitedChannel> {
StateHolder getState(TargetUri targetUri, Supplier<StateHolder> stateHolderFactory);
T cast(final Object value) {
return valueClass.cast(value);
}

Supplier<T> getFactory() {
return factory;
}

StateHolderKey(final Class<T> valueClass, Supplier<T> factory) {
this.valueClass = valueClass;
this.factory = factory;
}
}

static final class StateHolder {
@SuppressWarnings("DangerousIdentityKey")
private final Map<StateHolderKey<?>, Object> state = new HashMap<>();

<T> T getState(StateHolderKey<T> key) {
if (state.containsKey(key)) {
return key.cast(state.get(key));
} else {
T value = key.getFactory().get();
Preconditions.checkNotNull(value, "state factory cannot produce a null value");
state.put(key, value);
return value;
}
}
}

public static final class Builder {
Expand Down Expand Up @@ -185,34 +215,33 @@ public DialogueChannel build() {
// Reloading currently forgets channel state (pinned target, channel scores, concurrency limits, etc...)
// In a future change we should attempt to retain this state for channels that are retained between
// updates.
LimitedChannel nodeSelectionChannel = new SupplierChannel(cf.uris().map(new TargetUriStateSupplier() {
private final Map<TargetUri, StateHolder> state = new HashMap<>();

@Override
public StateHolder getState(TargetUri targetUri, Supplier<StateHolder> stateHolderFactory) {
return state.getOrDefault(targetUri, stateHolderFactory.get());
}

@Override
public LimitedChannel apply(List<TargetUri> targetUris) {
// remove state for uris we no longer care about
Set<TargetUri> toRemove = state.keySet().stream()
.filter(uri -> !targetUris.contains(uri))
.collect(Collectors.toSet());
toRemove.forEach(state::remove);

reloadMeter.mark();
log.info(
"Reloaded channel '{}' targets. (uris: {}, numUris: {}, targets: {}, numTargets: {})",
SafeArg.of("channel", cf.channelName()),
UnsafeArg.of("uris", cf.clientConf().uris()),
SafeArg.of("numUris", cf.clientConf().uris().size()),
UnsafeArg.of("targets", targetUris),
SafeArg.of("numTargets", targetUris.size()));
ImmutableList<LimitedChannel> targetChannels = createHostChannels(cf, targetUris, this);
return NodeSelectionStrategyChannel.create(cf, targetChannels);
}
}));
LimitedChannel nodeSelectionChannel =
new SupplierChannel(cf.uris().map(new Function<List<TargetUri>, LimitedChannel>() {
private final Map<TargetUri, StateHolder> state = new HashMap<>();

@Override
public LimitedChannel apply(List<TargetUri> targetUris) {
// remove state for uris we no longer care about, and create new StateHolders
// for uris we don't know about yet
Set<TargetUri> toRemove = state.keySet().stream()
.filter(uri -> !targetUris.contains(uri))
.collect(Collectors.toSet());
toRemove.forEach(state::remove);
targetUris.forEach(uri -> state.putIfAbsent(uri, new StateHolder()));

reloadMeter.mark();
log.info(
"Reloaded channel '{}' targets. (uris: {}, numUris: {}, targets: {}, numTargets: {})",
SafeArg.of("channel", cf.channelName()),
UnsafeArg.of("uris", cf.clientConf().uris()),
SafeArg.of("numUris", cf.clientConf().uris().size()),
UnsafeArg.of("targets", targetUris),
SafeArg.of("numTargets", targetUris.size()));
ImmutableList<LimitedChannel> targetChannels =
createHostChannels(cf, targetUris, Collections.unmodifiableMap(state));
return NodeSelectionStrategyChannel.create(cf, targetChannels);
}
}));

LimitedChannel stickyValidationChannel = new StickyValidationChannel(nodeSelectionChannel);

Expand All @@ -233,7 +262,7 @@ public LimitedChannel apply(List<TargetUri> targetUris) {
}

private static ImmutableList<LimitedChannel> createHostChannels(
Config cf, List<TargetUri> targetUris, TargetUriStateSupplier stateSupplier) {
Config cf, List<TargetUri> targetUris, Map<TargetUri, StateHolder> state) {
ImmutableList.Builder<LimitedChannel> perUriChannels = ImmutableList.builder();
for (int uriIndex = 0; uriIndex < targetUris.size(); uriIndex++) {
final int uriIndexForInstrumentation =
Expand All @@ -250,6 +279,9 @@ private static ImmutableList<LimitedChannel> createHostChannels(
channel =
new TraceEnrichingChannel(channel, DialogueTracing.tracingTags(cf, uriIndexForInstrumentation));

StateHolder stateHolder = state.get(targetUri);
Preconditions.checkNotNull(stateHolder, "no StateHolder exists for this TargetUri");

LimitedChannel limitedChannel;
if (cf.isConcurrencyLimitingEnabled()) {
Channel unlimited = channel;
Expand All @@ -258,16 +290,11 @@ private static ImmutableList<LimitedChannel> createHostChannels(
return unlimited;
}
LimitedChannel limited = ConcurrencyLimitedChannel.createForEndpoint(
unlimited,
cf.channelName(),
uriIndexForInstrumentation,
targetUri,
endpoint,
stateSupplier::getState);
unlimited, cf.channelName(), uriIndexForInstrumentation, endpoint, stateHolder);
return QueuedChannel.create(cf, endpoint, limited);
});
limitedChannel = ConcurrencyLimitedChannel.createForHost(
cf, channel, uriIndexForInstrumentation, targetUri, stateSupplier::getState);
cf, channel, uriIndexForInstrumentation, stateHolder);
} else {
limitedChannel = new ChannelToLimitedChannelAdapter(channel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public void testUnavailable_endpoint() {
public void testWithDefaultLimiter() {
channel = new ConcurrencyLimitedChannel(
delegate,
ConcurrencyLimitedChannel.createLimiter(Behavior.HOST_LEVEL),
new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(Behavior.HOST_LEVEL),
NopConcurrencyLimitedChannelInstrumentation.INSTANCE);

assertThat(channel.maybeExecute(endpoint, request, LimitEnforcement.DEFAULT_ENABLED))
Expand Down

0 comments on commit bed8cd1

Please sign in to comment.