From acf81c1ab2aff7cc9826a31eb47d710468344dd4 Mon Sep 17 00:00:00 2001 From: Daniel Fuchs Date: Fri, 7 Feb 2025 13:43:08 +0000 Subject: [PATCH] 8349662: SSLTube SSLSubscriptionWrapper has potential races when switching subscriptions --- .../jdk/internal/net/http/common/SSLTube.java | 81 +++++++++++----- .../java/net/httpclient/CookieHeaderTest.java | 8 +- .../java/net/httpclient/DigestEchoClient.java | 97 ++++++++++++++----- 3 files changed, 133 insertions(+), 53 deletions(-) diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/SSLTube.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/SSLTube.java index c35a6f62a1b18..5c8870e38d8ad 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/SSLTube.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/SSLTube.java @@ -119,7 +119,7 @@ void connect(Flow.Subscriber> downReader, // Connect the read sink first. That's the left-hand side // downstream subscriber from the HttpConnection (or more // accurately, the SSLSubscriberWrapper that will wrap it - // when SSLTube::connectFlows is called. + // when SSLTube::connectFlows is called). reader.subscribe(downReader); // Connect the right hand side tube (the socket tube). @@ -191,7 +191,7 @@ public boolean isFinished() { private volatile Flow.Subscription readSubscription; // The DelegateWrapper wraps a subscribed {@code Flow.Subscriber} and - // tracks the subscriber's state. In particular it makes sure that + // tracks the subscriber's state. In particular, it makes sure that // onComplete/onError are not called before onSubscribed. static final class DelegateWrapper implements FlowTube.TubeSubscriber { private final FlowTube.TubeSubscriber delegate; @@ -302,7 +302,7 @@ public String toString() { // Used to read data from the SSLTube. final class SSLSubscriberWrapper implements FlowTube.TubeSubscriber { - private AtomicReference pendingDelegate = + private final AtomicReference pendingDelegate = new AtomicReference<>(); private volatile DelegateWrapper subscribed; private volatile boolean onCompleteReceived; @@ -353,15 +353,15 @@ void setDelegate(Flow.Subscriber> delegate) { return; } // sslDelegate field should have been initialized by the - // the time we reach here, as there can be no subscriber + // time we reach here, as there can be no subscriber // until SSLTube is fully constructed. if (handleNow || !sslDelegate.resumeReader()) { processPendingSubscriber(); } } - // Can be called outside of the flow if an error has already been - // raise. Otherwise, must be called within the SSLFlowDelegate + // Can be called outside the flow if an error has already been + // raised. Otherwise, must be called within the SSLFlowDelegate // downstream reader flow. // If there is a subscription, and if there is a pending delegate, // calls dropSubscription() on the previous delegate (if any), @@ -619,32 +619,57 @@ final class SSLSubscriptionWrapper implements Flow.Subscription { private volatile boolean cancelled; void setSubscription(Flow.Subscription sub) { - long demand = writeDemand.get(); // FIXME: isn't it a racy way of passing the demand? - delegate = sub; - if (debug.on()) - debug.log("setSubscription: demand=%d, cancelled:%s", demand, cancelled); + long demand; + // Avoid race condition and requesting demand twice if + // request() runs concurrently with setSubscription() + boolean cancelled; + synchronized (this) { + demand = writeDemand.get(); + delegate = sub; + cancelled = this.cancelled; + } + if (debug.on()) { + debug.log("setSubscription: demand=%d, cancelled:%s, new subscription %s", + demand, cancelled, sub); + } if (cancelled) - delegate.cancel(); + sub.cancel(); else if (demand > 0) sub.request(demand); } @Override public void request(long n) { - writeDemand.increase(n); - if (debug.on()) debug.log("request: n=%d", n); - Flow.Subscription sub = delegate; - if (sub != null && n > 0) { - sub.request(n); + final long demand = n; + // Avoid race condition and requesting demand twice if + // request() runs concurrently with setSubscription() + Flow.Subscription sub; + long demanded; + synchronized (this) { + sub = delegate; + demanded = writeDemand.get(); + writeDemand.increase(n); + } + if (debug.on()) { + debug.log("request: n=%s to %s (%s already demanded)", + demand, sub, demanded); + } + if (sub != null && demand > 0) { + if (debug.on()) debug.log("requesting %s from %s", demand, sub); + sub.request(demand); } } @Override public void cancel() { - cancelled = true; - if (delegate != null) - delegate.cancel(); + Flow.Subscription sub; + synchronized (this) { + cancelled = true; + sub = delegate; + } + if (debug.on()) debug.log("cancel: cancelling subscription: " + sub); + if (sub != null) sub.cancel(); } } @@ -652,10 +677,16 @@ public void cancel() { @Override public void onSubscribe(Flow.Subscription subscription) { Objects.requireNonNull(subscription); - Flow.Subscription x = writeSubscription.delegate; - if (x != null) - x.cancel(); + Flow.Subscription old; + synchronized (this) { + old = writeSubscription.delegate; + } + if (old != null && old != subscription) { + if (debug.on()) debug.log("onSubscribe: cancelling old subscription: " + old); + old.cancel(); + } + if (debug.on()) debug.log("onSubscribe: new subscription: " + subscription); writeSubscription.setSubscription(subscription); } @@ -664,8 +695,10 @@ public void onNext(List item) { Objects.requireNonNull(item); boolean decremented = writeDemand.tryDecrement(); assert decremented : "Unexpected writeDemand: "; - if (debug.on()) - debug.log("sending %d buffers to SSL flow delegate", item.size()); + if (debug.on()) { + debug.log("sending %s buffers to SSL flow delegate (%s bytes)", + item.size(), Utils.remaining(item)); + } sslDelegate.upstreamWriter().onNext(item); } diff --git a/test/jdk/java/net/httpclient/CookieHeaderTest.java b/test/jdk/java/net/httpclient/CookieHeaderTest.java index d5eca06c0f0ed..b39a23371abf8 100644 --- a/test/jdk/java/net/httpclient/CookieHeaderTest.java +++ b/test/jdk/java/net/httpclient/CookieHeaderTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -33,9 +33,6 @@ * CookieHeaderTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -51,7 +48,6 @@ import java.io.PrintWriter; import java.io.Writer; import java.net.CookieHandler; -import java.net.CookieManager; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -65,7 +61,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -76,7 +71,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; diff --git a/test/jdk/java/net/httpclient/DigestEchoClient.java b/test/jdk/java/net/httpclient/DigestEchoClient.java index 7038d82d91fdc..7b4f7fd34620c 100644 --- a/test/jdk/java/net/httpclient/DigestEchoClient.java +++ b/test/jdk/java/net/httpclient/DigestEchoClient.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -58,6 +58,7 @@ import sun.net.www.HeaderParser; import static java.lang.System.out; +import static java.lang.System.err; import static java.lang.String.format; /** @@ -304,6 +305,9 @@ public static void main(String[] args) throws Exception { } catch(Throwable t) { out.println(DigestEchoServer.now() + ": Unexpected exception: " + t); + t.printStackTrace(System.out); + err.println(DigestEchoServer.now() + + ": Unexpected exception: " + t); t.printStackTrace(); failed = t; throw t; @@ -393,15 +397,21 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, HttpResponse r; CompletableFuture> cf1; String auth = null; + Throwable failed = null; + URI reqURI = null; try { - for (int i=0; i lines = List.of(Arrays.copyOfRange(data, 0, i+1)); + List lines = List.of(Arrays.copyOfRange(data, 0, i + 1)); assert lines.size() == i + 1; String body = lines.stream().collect(Collectors.joining("\r\n")); BodyPublisher reqBody = BodyPublishers.ofString(body); - HttpRequest.Builder builder = HttpRequest.newBuilder(uri).version(clientVersion) + URI baseReq = URI.create(uri + "?iteration=" + i + ",async=" + async + + ",addHeaders=" + addHeaders + ",preemptive=" + preemptive + + ",expectContinue=" + expectContinue + ",version=" + clientVersion); + reqURI = URI.create(baseReq + ",basicCount=" + basicCount.get()); + HttpRequest.Builder builder = HttpRequest.newBuilder(reqURI).version(clientVersion) .POST(reqBody).expectContinue(expectContinue); boolean isTunnel = isProxy(authType) && useSSL; if (addHeaders) { @@ -433,8 +443,10 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, HttpResponse> resp; try { if (async) { + out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request); resp = client.sendAsync(request, BodyHandlers.ofLines()).join(); } else { + out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } } catch (Throwable t) { @@ -443,17 +455,10 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, long n = basicCount.getAndIncrement(); basics.set((basics.get() * n + (stop - start)) / (n + 1)); } - // unwrap CompletionException - if (t instanceof CompletionException) { - assert t.getCause() != null; - t = t.getCause(); - } - out.println(DigestEchoServer.now() - + ": Unexpected exception: " + t); - throw new RuntimeException("Unexpected exception: " + t, t); + throw t; } - if (addHeaders && !preemptive && (i==0 || isSchemeDisabled())) { + if (addHeaders && !preemptive && (i == 0 || isSchemeDisabled())) { assert resp.statusCode() == 401 || resp.statusCode() == 407; Stream respBody = resp.body(); if (respBody != null) { @@ -462,11 +467,15 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, } System.out.println(String.format("%s received: adding header %s: %s", resp.statusCode(), authorizationKey(authType), auth)); - request = HttpRequest.newBuilder(uri).version(clientVersion) + reqURI = URI.create(baseReq + ",withAuthorization=" + + authType + ",basicCount=" + basicCount.get()); + request = HttpRequest.newBuilder(reqURI).version(clientVersion) .POST(reqBody).header(authorizationKey(authType), auth).build(); if (async) { + out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request); resp = client.sendAsync(request, BodyHandlers.ofLines()).join(); } else { + out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } } @@ -500,6 +509,15 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, throw new RuntimeException("Unexpected response: " + respLines); } } + } catch (Throwable t) { + if (reqURI == null) { + failed = t; + throw t; + } + String decoration = "%s Unexpected exception %s for %s".formatted(DigestEchoServer.now(), t, reqURI); + RuntimeException decorated = new RuntimeException(decoration, t); + failed = decorated; + throw decorated; } finally { client = null; System.gc(); @@ -508,7 +526,10 @@ void testBasic(Version clientVersion, Version serverVersion, boolean async, if (queue.remove(100) == ref) break; } var error = TRACKER.checkShutdown(900); - if (error != null) throw error; + if (error != null) { + if (failed != null) error.addSuppressed(failed); + throw error; + } } System.out.println("OK"); } @@ -546,16 +567,22 @@ void testDigest(Version clientVersion, Version serverVersion, byte[] cnonce = new byte[16]; String cnonceStr = null; DigestEchoServer.DigestResponse challenge = null; - + ReferenceQueue queue = new ReferenceQueue<>(); + WeakReference ref = new WeakReference<>(client, queue); + URI reqURI = null; + Throwable failed = null; try { - for (int i=0; i lines = List.of(Arrays.copyOfRange(data, 0, i+1)); + List lines = List.of(Arrays.copyOfRange(data, 0, i + 1)); assert lines.size() == i + 1; String body = lines.stream().collect(Collectors.joining("\r\n")); HttpRequest.BodyPublisher reqBody = HttpRequest.BodyPublishers.ofString(body); + URI baseReq = URI.create(uri + "?iteration=" + i + ",async=" + async + + ",expectContinue=" + expectContinue + ",version=" + clientVersion); + reqURI = URI.create(baseReq + ",digestCount=" + digestCount.get()); HttpRequest.Builder reqBuilder = HttpRequest - .newBuilder(uri).version(clientVersion).POST(reqBody) + .newBuilder(reqURI).version(clientVersion).POST(reqBody) .expectContinue(expectContinue); boolean isTunnel = isProxy(authType) && useSSL; @@ -578,8 +605,10 @@ void testDigest(Version clientVersion, Version serverVersion, HttpRequest request = reqBuilder.build(); HttpResponse> resp; if (async) { + out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request); resp = client.sendAsync(request, BodyHandlers.ofLines()).join(); } else { + out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } System.out.println(resp); @@ -609,16 +638,18 @@ void testDigest(Version clientVersion, Version serverVersion, challenge = DigestEchoServer.DigestResponse .create(authenticate.substring("Digest ".length())); String auth = digestResponse(uri, digestMethod, challenge, cnonceStr); + reqURI = URI.create(baseReq + ",withAuth=" + authType + ",digestCount=" + digestCount.get()); try { - request = HttpRequest.newBuilder(uri).version(clientVersion) - .POST(reqBody).header(authorizationKey(authType), auth).build(); + request = HttpRequest.newBuilder(reqURI).version(clientVersion) + .POST(reqBody).header(authorizationKey(authType), auth).build(); } catch (IllegalArgumentException x) { throw x; } - if (async) { + out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request); resp = client.sendAsync(request, BodyHandlers.ofLines()).join(); } else { + out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } System.out.println(resp); @@ -649,7 +680,29 @@ void testDigest(Version clientVersion, Version serverVersion, throw new RuntimeException("Unexpected response: " + respLines); } } + } catch (Throwable t) { + if (reqURI == null) { + failed = t; + throw t; + } + String decoration = "%s Unexpected exception %s for %s".formatted(DigestEchoServer.now(), t, reqURI); + RuntimeException decorated = new RuntimeException(decoration, t); + failed = decorated; + throw decorated; } finally { + client = null; + System.gc(); + while (!ref.refersTo(null)) { + System.gc(); + if (queue.remove(100) == ref) break; + } + var error = TRACKER.checkShutdown(900); + if (error != null) { + if (failed != null) { + error.addSuppressed(failed); + } + throw error; + } } System.out.println("OK"); }