Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI/JSSE: optimize out array creation in WolfSSLEngine RecvAppData() #244

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 250 additions & 63 deletions native/com_wolfssl_WolfSSLSession.c
Original file line number Diff line number Diff line change
@@ -1002,16 +1002,104 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
}
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
/**
* Read len bytes from wolfSSL_read() back into provided output buffer.
*
* Internal function called by WolfSSLSession.read() calls.
*
* If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait for
* data with select()/poll() up to provided timeout.
*
* Returns number of bytes read on success, or negative on error.
*/
static int SSLReadNonblockingWithSelectPoll(WOLFSSL* ssl, byte* out,
int length, int timeout)
{
byte* data = NULL;
int size = 0, ret, err, sockfd;
int size, ret, err, sockfd;
int pollRx = 0;
int pollTx = 0;
wolfSSL_Mutex* jniSessLock = NULL;
SSLAppData* appData = NULL;

if (ssl == NULL || out == NULL) {
return BAD_FUNC_ARG;
}

/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
return WOLFSSL_FAILURE;
}

jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
return WOLFSSL_FAILURE;
}

do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}

size = wolfSSL_read(ssl, out, length);
err = wolfSSL_get_error(ssl, size);

/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}

if (size < 0 &&
((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) {

sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}

if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}

#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, timeout, pollRx);
#else
ret = socketPoll(sockfd, timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
}
}

} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);

return size;
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
{
int size = 0;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
(void)jcl;

@@ -1027,79 +1115,178 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
return SSL_FAILURE;
}

/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + offset,
(int)length, (int)timeout);

if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
}
}

jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
return size;
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint length, jint timeout)
{
int size = 0;
int maxOutputSz;
int outSz = length;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;

jclass excClass;
jclass buffClass;
jmethodID positionMeth;
jmethodID limitMeth;
jmethodID hasArrayMeth;
jmethodID arrayMeth;
jmethodID setPositionMeth;

jint position;
jint limit;
jboolean hasArray;
jbyteArray bufArr;

(void)jcl;

if (jenv == NULL || ssl == NULL || buf == NULL) {
return BAD_FUNC_ARG;
}

if (length > 0) {
/* Get WolfSSLException class */
excClass = (*jenv)->FindClass(jenv, "com/wolfssl/WolfSSLException");
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
return -1;
}

do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}
/* Get ByteBuffer class */
buffClass = (*jenv)->GetObjectClass(jenv, buf);
if (buffClass == NULL) {
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer class in native read()");
return -1;
}

size = wolfSSL_read(ssl, data + offset, length);
err = wolfSSL_get_error(ssl, size);
/* Get ByteBuffer position */
positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I");
if (positionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer position() method in native read()");
return -1;
}
position = (*jenv)->CallIntMethod(jenv, buf, positionMeth);

/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
/* Get ByteBuffer limit */
limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I");
if (limitMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer limit() method in native read()");
return -1;
}
limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth);

if (size < 0 && ((err == SSL_ERROR_WANT_READ) || \
(err == SSL_ERROR_WANT_WRITE))) {
/* Get and call ByteBuffer.hasArray() before calling array() */
hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z");
if (hasArrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer hasArray() method in native read()");
return -1;
}

sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}
/* ByteBuffer.hasArray() does not throw any exceptions */
hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth);
if (!hasArray) {
(*jenv)->ThrowNew(jenv, excClass,
"ByteBuffer.hasArray() is false in native read()");
return BAD_FUNC_ARG;
}

if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
/* Only read up to maximum space we have in this ByteBuffer */
maxOutputSz = (limit - position);
if (outSz > maxOutputSz) {
outSz = maxOutputSz;
}

#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, (int)timeout, pollRx);
#else
ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
}
/* Get reference to underlying byte[] from ByteBuffer */
arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B");
if (arrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer array() method in native read()");
return -1;
}
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth);

} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
/* Get array elements */
data = (byte*)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
(*jenv)->ThrowNew(jenv, excClass,
"Exception when calling ByteBuffer.array() in native read()");
return -1;
}


if (data != NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + position,
maxOutputSz, (int)timeout);

/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
/* Relase array elements */
if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte*)data,
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte*)data, 0);

/* Update ByteBuffer position() based on bytes written */
setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass,
"position", "(I)Ljava/nio/Buffer;");
if (setPositionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to set ByteBuffer position() from "
"native read()");
size = -1;
}
else {
(*jenv)->CallVoidMethod(jenv, buf, setPositionMeth,
position + size);
}
}
}
}

return size;
10 changes: 9 additions & 1 deletion native/com_wolfssl_WolfSSLSession.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 83 additions & 0 deletions src/java/com/wolfssl/WolfSSLSession.java
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.lang.StringBuilder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import com.wolfssl.WolfSSLException;
@@ -253,6 +254,8 @@ private native int write(long ssl, byte[] data, int offset, int length,
int timeout);
private native int read(long ssl, byte[] data, int offset, int sz,
int timeout);
private native int read(long ssl, ByteBuffer data, int sz, int timeout)
throws WolfSSLException;
private native int accept(long ssl, int timeout);
private native void freeSSL(long ssl);
private native int shutdownSSL(long ssl, int timeout);
@@ -1112,6 +1115,86 @@ public int read(byte[] data, int offset, int sz, int timeout)
return ret;
}

/**
* Reads bytes from the SSL session and returns the read bytes into
* the provided ByteBuffer, using socket timeout value in milliseconds.
*
* The bytes read are removed from the internal receive buffer.
* <p>
* If necessary, <code>read()</code> will negotiate an SSL/TLS session
* if the handshake has not already been performed yet by <code>connect()
* </code> or <code>accept()</code>.
* <p>
* The SSL/TLS protocol uses SSL records which have a maximum size of
* 16kB. As such, wolfSSL needs to read an entire SSL record internally
* before it is able to process and decrypt the record. Because of this,
* a call to <code>read()</code> will only be able to return the
* maximum buffer size which has been decrypted at the time of calling.
* There may be additional not-yet-decrypted data waiting in the internal
* wolfSSL receive buffer which will be retrieved and decrypted with the
* next call to <code>read()</code>.
*
* @param data ByteBuffer where the data read from the SSL connection
* will be placed. position() will be updated after this
* method writes data to the ByteBuffer.
* @param sz number of bytes to read into <b><code>data</code></b>,
* may be adjusted to the maximum space in data if that is
* smaller than this size.
* @param timeout read timeout, milliseconds.
* @return the number of bytes read upon success. <code>SSL_FAILURE
* </code> will be returned upon failure which may be caused
* by either a clean (close notify alert) shutdown or just
* that the peer closed the connection. <code>
* SSL_FATAL_ERROR</code> upon failure when either an error
* occurred or, when using non-blocking sockets, the
* <b>SSL_ERROR_WANT_READ</b> or <b>SSL_ERROR_WANT_WRITE</b>
* error was received and the application needs to call
* <code>read()</code> again. Use <code>getError</code> to
* get a specific error code.
* <code>BAD_FUNC_ARC</code> when bad arguments are used.
* @throws IllegalStateException WolfSSLContext has been freed
* @throws SocketTimeoutException if socket timeout occurs
* @throws SocketException Native socket select/poll() failed
*/
public int read(ByteBuffer data, int sz, int timeout)
throws IllegalStateException, SocketTimeoutException, SocketException {

int ret;
long localPtr;

confirmObjectIsActive();

/* Fix for Infer scan, since not synchronizing on sslLock for
* access to this.sslPtr, see note below */
synchronized (sslLock) {
localPtr = this.sslPtr;
}

WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, "entered read(ByteBuffer, " +
"sz: " + sz + ", timeout: " + timeout + ")");

/* not synchronizing on sslLock here since JNI read() locks
* session mutex around native wolfSSL_read() call. If sslLock
* is locked here, since we call select() inside native JNI we
* could timeout waiting for corresponding write() operation to
* occur if needed */
try {
ret = read(localPtr, data, sz, timeout);
} catch (WolfSSLException e) {
/* JNI code may throw WolfSSLException on JNI specific errors */
throw new SocketException(e.getMessage());
}

WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, "read() ret: " + ret +
", err: " + getError(ret));

throwExceptionFromIOReturnValue(ret, "wolfSSL_read()");

return ret;
}

/**
* Waits for an SSL client to initiate the SSL/TLS handshake.
* This method is called on the server side. When it is called, the
57 changes: 35 additions & 22 deletions src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java
Original file line number Diff line number Diff line change
@@ -818,15 +818,23 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
int ret = 0;
int idx = 0; /* index into out[] array */
int err = 0;
byte[] tmp;
byte[] tmp = null;

/* create read buffer of max output size */
/* Calculate maximum output size across ByteBuffer arrays */
maxOutSz = getTotalOutputSize(out, ofst, length);
tmp = new byte[maxOutSz];

synchronized (ioLock) {
try {
ret = this.ssl.read(tmp, maxOutSz);
/* If we only have one ByteBuffer, skip allocating
* separate intermediate byte[] and write directly to underlying
* ByteBuffer array */
if (out.length == 1) {
ret = this.ssl.read(out[0], maxOutSz, 0);
}
else {
tmp = new byte[maxOutSz];
ret = this.ssl.read(tmp, maxOutSz);
}
} catch (SocketTimeoutException | SocketException e) {
throw new SSLException(e);
}
@@ -883,27 +891,32 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
}
}
else {
/* write processed data into output buffers */
for (i = 0; i < ret;) {
if (idx + ofst >= length) {
/* no more output buffers left */
break;
}
if (out.length == 1) {
totalRead = ret;
}
else {
/* write processed data into output buffers */
for (i = 0; i < ret;) {
if (idx + ofst >= length) {
/* no more output buffers left */
break;
}

bufSpace = out[idx + ofst].remaining();
if (bufSpace == 0) {
/* no more space in current out buffer, advance */
idx++;
continue;
}
bufSpace = out[idx + ofst].remaining();
if (bufSpace == 0) {
/* no more space in current out buffer, advance */
idx++;
continue;
}

sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
out[idx + ofst].put(tmp, i, sz);
i += sz;
totalRead += sz;
sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
out[idx + ofst].put(tmp, i, sz);
i += sz;
totalRead += sz;

if ((ret - i) > 0) {
idx++; /* go to next output buffer */
if ((ret - i) > 0) {
idx++; /* go to next output buffer */
}
}
}
}