Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add support for PyOpenSSL in paho.mqtt.python #849

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ dependencies = []
proxy = [
"PySocks",
]
openssl = [
"pyOpenSSL"
]

[project.urls]
Homepage = "http://eclipse.org/paho"
Expand Down
255 changes: 203 additions & 52 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,53 @@
from .reasoncodes import ReasonCode, ReasonCodes
from .subscribeoptions import SubscribeOptions

try:
from OpenSSL import SSL
from OpenSSL.crypto import X509

def _subject_alt_name_string(cert: X509) -> list:
"""Extracts the subject alternative name (SAN) entries from the certificate."""
san = []
for i in range(cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == b'subjectAltName':
san_entries = ext.__str__().split(', ')
for entry in san_entries:
key, value = entry.split(':', 1)
san.append((key.strip(), value.strip()))
return san

def _openssl_match_hostname(cert: X509, hostname: str):
"""Verify that *cert* matches the *hostname* according to RFC 2818 and RFC 6125 rules.
CertificateError is raised on failure. On success, the function returns nothing.
"""
if not cert:
raise ValueError("Empty or no certificate. match_hostname needs a certificate.")

dnsnames = []
# Extract subject alternative name (SAN) entries
san = _subject_alt_name_string(cert)
for key, value in san:
if key == 'DNS':
if ssl._dnsname_match(value, hostname):
return
dnsnames.append(value)

if not dnsnames:
# TODO: check if no dns entry to use subject
raise ValueError("pyOpenssl match_hostname: using subject is not supported.")

if len(dnsnames) > 1:
raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}")
elif len(dnsnames) == 1:
raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}")
else:
raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found")

HAS_OPENSSL = True
except ImportError:
HAS_OPENSSL = False

try:
from typing import Literal
except ImportError:
Expand Down Expand Up @@ -851,7 +898,7 @@ def __init__(
self._thread: threading.Thread | None = None
self._thread_terminate = False
self._ssl = False
self._ssl_context: ssl.SSLContext | None = None
self._ssl_context: ssl.SSLContext | SSL.Context | None = None
# Only used when SSL context does not have check_hostname attribute
self._tls_insecure = False
self._logger: logging.Logger | None = None
Expand Down Expand Up @@ -1093,32 +1140,96 @@ def logger(self, value: logging.Logger | None) -> None:
def _sock_recv(self, bufsize: int) -> bytes:
if self._sock is None:
raise ConnectionError("self._sock is None")
try:
return self._sock.recv(bufsize)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except AttributeError as err:
self._easy_log(
MQTT_LOG_DEBUG, "socket was None: %s", err)
raise ConnectionError() from err

if HAS_OPENSSL:
from OpenSSL import SSL

if isinstance(self._ssl_context, SSL.Context):
try:
return self._sock.recv(bufsize)
except SSL.WantReadError as err:
raise BlockingIOError() from err
except SSL.WantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except SSL.WantX509LookupError as err:
raise ConnectionError() from err
except SSL.ZeroReturnError as err:
raise ConnectionError() from err
except SSL.SysCallError as err:
raise ConnectionError() from err
except AttributeError as err:
self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err)
raise ConnectionError() from err
else:
try:
return self._sock.recv(bufsize)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except AttributeError as err:
self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err)
raise ConnectionError() from err
else:
try:
return self._sock.recv(bufsize)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except AttributeError as err:
self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err)
raise ConnectionError() from err

def _sock_send(self, buf: bytes) -> int:
if self._sock is None:
raise ConnectionError("self._sock is None")

try:
return self._sock.send(buf)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except BlockingIOError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
if HAS_OPENSSL:
from OpenSSL import SSL

if isinstance(self._ssl_context, SSL.Context):
try:
return self._sock.send(buf)
except SSL.WantReadError as err:
raise BlockingIOError() from err
except SSL.WantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except SSL.WantX509LookupError as err:
raise ConnectionError() from err
except SSL.ZeroReturnError as err:
raise ConnectionError() from err
except SSL.SysCallError as err:
raise ConnectionError() from err
except BlockingIOError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
else:
try:
return self._sock.send(buf)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except BlockingIOError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
else:
try:
return self._sock.send(buf)
except ssl.SSLWantReadError as err:
raise BlockingIOError() from err
except ssl.SSLWantWriteError as err:
self._call_socket_register_write()
raise BlockingIOError() from err
except BlockingIOError as err:
self._call_socket_register_write()
raise BlockingIOError() from err

def _sock_close(self) -> None:
"""Close the connection to the server."""
Expand Down Expand Up @@ -1181,26 +1292,37 @@ def ws_set_options(

def tls_set_context(
self,
context: ssl.SSLContext | None = None,
context: ssl.SSLContext | SSL.Context | None = None,
) -> None:
"""Configure network encryption and authentication context. Enables SSL/TLS support.

:param context: an ssl.SSLContext object. By default this is given by
``ssl.create_default_context()``, if available.
:param context: an ssl.SSLContext or OpenSSL.SSL.Context object. By default, this is given by
``ssl.create_default_context()`` if available.

Must be called before `connect()`, `connect_async()` or `connect_srv()`."""
Must be called before `connect()`, `connect_async()` or `connect_srv()`.
"""
if self._ssl_context is not None:
raise ValueError('SSL/TLS has already been configured.')

if context is None:
context = ssl.create_default_context()
if HAS_OPENSSL:
raise ValueError("OpenSSL custom context is not provided.")
else:
context = ssl.create_default_context()

self._ssl = True
self._ssl_context = context

# Ensure _tls_insecure is consistent with check_hostname attribute
if hasattr(context, 'check_hostname'):
# Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext
if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'):
self._tls_insecure = not context.check_hostname
elif HAS_OPENSSL and isinstance(context, SSL.Context):
# PyOpenSSL Context does not have check_hostname attribute
# Set _tls_insecure based on custom logic if necessary
self._tls_insecure = False # Assuming default to False for PyOpenSSL
else:
# If OpenSSL is not available and context is an SSL.Context, raise an error
raise ValueError("OpenSSL is not available, cannot use SSL.Context.")

def tls_set(
self,
Expand Down Expand Up @@ -4639,42 +4761,71 @@ def _create_socket_connection(self) -> _socket.socket:
else:
return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source)

def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket:
def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket:
if self._ssl_context is None:
raise ValueError(
"Impossible condition. _ssl_context should never be None if _ssl is True"
)

verify_host = not self._tls_insecure
try:
# Try with server_hostname, even it's not supported in certain scenarios
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
server_hostname=self._host,
do_handshake_on_connect=False,
)
if isinstance(self._ssl_context, ssl.SSLContext):
# Use the built-in ssl.SSLContext
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
server_hostname=self._host,
do_handshake_on_connect=False,
)
elif HAS_OPENSSL:
from OpenSSL import SSL

if isinstance(self._ssl_context, SSL.Context):
# Use PyOpenSSL's SSL.Context
conn = SSL.Connection(self._ssl_context, tcp_sock)
conn.set_connect_state()
if self._host:
conn.set_tlsext_host_name(self._host.encode('utf-8'))
ssl_sock = conn
else:
raise ValueError("Unsupported SSL context type")
else:
raise ValueError("Unsupported SSL context type")
except ssl.CertificateError:
# CertificateError is derived from ValueError
raise
except ValueError:
# Python version requires SNI in order to handle server_hostname, but SNI is not available
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
do_handshake_on_connect=False,
)
else:
# If SSL context has already checked hostname, then don't need to do it again
if getattr(self._ssl_context, 'check_hostname', False): # type: ignore
verify_host = False
if isinstance(self._ssl_context, ssl.SSLContext):
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
do_handshake_on_connect=False,
)
else:
raise

ssl_sock.settimeout(self._keepalive)
ssl_sock.do_handshake()

if verify_host:
# TODO: this type error is a true error:
# error: Module has no attribute "match_hostname" [attr-defined]
# Python 3.12 no longer have this method.
ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore
# Function to handle retries for non-blocking SSL handshake
def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1):
for attempt in range(retries):
try:
ssl_sock.do_handshake()
return
except SSL.WantReadError:
if attempt == retries - 1:
raise RuntimeError("Handshake failed after maximum retries")
time.sleep(delay)

if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection):
do_handshake_with_retries(ssl_sock)
if verify_host:
if getattr(self._ssl_context, 'check_hostname', False):
verify_host = False
_openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host)
else:
ssl_sock.do_handshake()
if verify_host:
if getattr(self._ssl_context, 'check_hostname', False):
verify_host = False
ssl.match_hostname(ssl_sock.getpeercert(), self._host)

return ssl_sock

Expand Down