Skip to content

Commit

Permalink
8348108: Race condition in AggregatePublisher.AggregateSubscription
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuch committed Jan 20, 2025
1 parent 3a4d5ff commit eb0811a
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -537,12 +537,20 @@ private static final class AggregateSubscription

@Override
public void request(long n) {
if (cancelled || publisher == null && bodies.isEmpty()) {
return;
synchronized (this) {
// We are finished when publisher is null and bodies
// is empty. This means that the data from the last
// publisher in the list has been consumed.
// If we are finished or cancelled, do nothing.
if (cancelled || (publisher == null && bodies.isEmpty())) {
return;
}
}
try {
demand.increase(n);
} catch (IllegalArgumentException x) {
// request() should not throw - the scheduler will
// invoke onError on the subscriber.
illegalRequest = x;
}
scheduler.runOrSchedule();
Expand All @@ -554,46 +562,68 @@ public void cancel() {
scheduler.runOrSchedule();
}

private boolean cancelSubscription() {
Flow.Subscription subscription = this.subscription;
private boolean cancelSubscription(Flow.Subscription subscription) {
if (subscription != null) {
this.subscription = null;
this.publisher = null;
synchronized (this) {
if (this.subscription == subscription) {
this.subscription = null;
this.publisher = null;
}
}
subscription.cancel();
}
// This nethod is called when cancel is true, so
// we should always stop the scheduler here
scheduler.stop();
return subscription != null;
}

public void run() {
try {
BodyPublisher publisher;
Flow.Subscription subscription = null;
while (error.get() == null
&& (!demand.isFulfilled()
|| (publisher == null && !bodies.isEmpty()))) {
|| (this.publisher == null && !bodies.isEmpty()))) {
boolean cancelled = this.cancelled;
BodyPublisher publisher = this.publisher;
Flow.Subscription subscription = this.subscription;
// make sure we see a consistent state.
synchronized (this) {
publisher = this.publisher;
subscription = this.subscription;
}
Throwable illegalRequest = this.illegalRequest;
if (cancelled) {
bodies.clear();
cancelSubscription();
cancelSubscription(subscription);
return;
}
if (publisher == null && !bodies.isEmpty()) {
this.publisher = publisher = bodies.poll();
// synchronize here to avoid race condition with
// request(long) which could otherwise observe a
// null publisher and an empty bodies list when
// polling the last publisher.
synchronized (this) {
this.publisher = publisher = bodies.poll();
}
publisher.subscribe(this);
subscription = this.subscription;
} else if (publisher == null) {
return;
}
if (illegalRequest != null) {
onError(illegalRequest);
return;
}
if (subscription == null) return;
if (!demand.isFulfilled()) {
long n = demand.decreaseAndGet(demand.get());
demanded.increase(n);
long n = 0;
// synchronize to avoid race condition with
// publisherDone()
synchronized (this) {
if ((subscription = this.subscription) == null) return;
if (!demand.isFulfilled()) {
n = demand.decreaseAndGet(demand.get());
demanded.increase(n);
}
}
if (n > 0 && !cancelled) {
subscription.request(n);
}
}
Expand All @@ -602,20 +632,35 @@ public void run() {
}
}

// It is important to synchronize when setting
// publisher to null to avoid race conditions
// with request(long)
private synchronized void publisherDone() {
publisher = null;
subscription = null;
}


@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
// synchronize for asserting in a consistent state.
synchronized (this) {
// we shouldn't be able to observe a null publisher
// when onSubscribe is called, unless - possibly - if
// there was some error...
assert publisher != null || error.get() != null;
this.subscription = subscription;
}
scheduler.runOrSchedule();
}

@Override
public void onNext(ByteBuffer item) {
// make sure to cancel the subscription if we receive
// an item after the subscription was cancelled or
// make sure to cancel the downstream subscription if we receive
// an item after the aggregate subscription was cancelled or
// an error was reported.
if (cancelled || error.get() != null) {
cancelSubscription();
cancelSubscription(this.subscription);
return;
}
demanded.tryDecrement();
Expand All @@ -625,30 +670,36 @@ public void onNext(ByteBuffer item) {
@Override
public void onError(Throwable throwable) {
if (error.compareAndSet(null, throwable)) {
publisher = null;
subscription = null;
publisherDone();
subscriber.onError(throwable);
scheduler.stop();
}
}

@Override
public void onComplete() {
private synchronized boolean completeAndContinue() {
if (publisher != null && !bodies.isEmpty()) {
while (!demanded.isFulfilled()) {
demand.increase(demanded.decreaseAndGet(demanded.get()));
}
publisher = null;
subscription = null;
publisherDone();
return true; // continue
} else {
publisherDone();
return false; // stop
}
}

@Override
public void onComplete() {
if (completeAndContinue()) {
scheduler.runOrSchedule();
} else {
publisher = null;
subscription = null;
if (!cancelled) {
subscriber.onComplete();
}
scheduler.stop();
}
}
}

}
41 changes: 29 additions & 12 deletions test/jdk/java/net/httpclient/AggregateRequestBodyTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -33,8 +33,6 @@
* @summary Tests HttpRequest.BodyPublishers::concat
*/

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
Expand All @@ -57,6 +55,7 @@
import java.util.concurrent.Flow;
import java.util.concurrent.Flow.Subscriber;
import java.util.concurrent.Flow.Subscription;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -67,12 +66,8 @@
import java.util.stream.LongStream;
import java.util.stream.Stream;
import jdk.httpclient.test.lib.common.HttpServerAdapters;
import jdk.httpclient.test.lib.http2.Http2TestServer;
import javax.net.ssl.SSLContext;

import com.sun.net.httpserver.HttpServer;
import com.sun.net.httpserver.HttpsConfigurator;
import com.sun.net.httpserver.HttpsServer;
import jdk.test.lib.net.SimpleSSLContext;
import org.testng.Assert;
import org.testng.ITestContext;
Expand Down Expand Up @@ -423,6 +418,8 @@ static class RequestSubscriber implements Flow.Subscriber<ByteBuffer> {
ConcurrentLinkedDeque<ByteBuffer> items = new ConcurrentLinkedDeque<>();
CompletableFuture<List<ByteBuffer>> resultCF = new CompletableFuture<>();

Semaphore semaphore = new Semaphore(0);

@Override
public void onSubscribe(Subscription subscription) {
this.subscriptionCF.complete(subscription);
Expand All @@ -431,6 +428,11 @@ public void onSubscribe(Subscription subscription) {
@Override
public void onNext(ByteBuffer item) {
items.addLast(item);
int available = semaphore.availablePermits();
if (semaphore.availablePermits() > Integer.MAX_VALUE - 8) {
onError(new IllegalStateException("too many buffers in queue: " + available));
}
semaphore.release();
}

@Override
Expand All @@ -443,6 +445,18 @@ public void onComplete() {
resultCF.complete(items.stream().collect(Collectors.toUnmodifiableList()));
}

public ByteBuffer take() {
// it is not guaranteed that the buffer will be added to
// the queue in the same thread that calls request(1).
try {
semaphore.acquire();
} catch (InterruptedException x) {
Thread.currentThread().interrupt();
throw new CompletionException(x);
}
return items.pop();
}

CompletableFuture<List<ByteBuffer>> resultCF() { return resultCF; }
}

Expand Down Expand Up @@ -628,8 +642,9 @@ public void testPositiveRequests() {
publisher.subscribe(requestSubscriber1);
Subscription subscription1 = requestSubscriber1.subscriptionCF.join();
subscription1.request(16);
assertTrue(requestSubscriber1.resultCF().isDone());
// onNext() may not be called in the same thread than request()
List<ByteBuffer> list1 = requestSubscriber1.resultCF().join();
assertTrue(requestSubscriber1.resultCF().isDone());
String result1 = stringFromBytes(list1.stream());
assertEquals(result1, "Lorem ipsum dolor sit amet, consectetur adipiscing elit.");
System.out.println("Got expected sentence with one request: \"%s\"".formatted(result1));
Expand All @@ -646,8 +661,8 @@ public void testPositiveRequests() {
subscription2.request(4);
assertFalse(requestSubscriber2.resultCF().isDone());
subscription2.request(1);
assertTrue(requestSubscriber2.resultCF().isDone());
List<ByteBuffer> list2 = requestSubscriber2.resultCF().join();
assertTrue(requestSubscriber2.resultCF().isDone());
String result2 = stringFromBytes(list2.stream());
assertEquals(result2, "Lorem ipsum dolor sit amet, consectetur adipiscing elit.");
System.out.println("Got expected sentence with 4 requests: \"%s\"".formatted(result1));
Expand Down Expand Up @@ -689,7 +704,7 @@ public void testCancel() {
// receive half the data
for (int i = 0; i < n; i++) {
subscription.request(1);
ByteBuffer buffer = subscriber.items.pop();
ByteBuffer buffer = subscriber.take();
}

// cancel subscription
Expand Down Expand Up @@ -789,7 +804,8 @@ public void testCancelSubscription() {
@Test(dataProvider = "variants")
public void test(String uri, boolean sameClient) throws Exception {
checkSkip();
System.out.println("Request to " + uri);
System.out.printf("Request to %s (sameClient: %s)%n", uri, sameClient);
System.err.printf("Request to %s (sameClient: %s)%n", uri, sameClient);

HttpClient client = newHttpClient(sameClient);

Expand All @@ -802,7 +818,8 @@ public void test(String uri, boolean sameClient) throws Exception {
.POST(publisher)
.build();
for (int i = 0; i < ITERATION_COUNT; i++) {
System.out.println("Iteration: " + i);
System.out.println(uri + ": Iteration: " + i);
System.err.println(uri + ": Iteration: " + i);
HttpResponse<String> response = client.send(request, BodyHandlers.ofString());
int expectedResponse = RESPONSE_CODE;
if (response.statusCode() != expectedResponse)
Expand Down

0 comments on commit eb0811a

Please sign in to comment.