Skip to content

Commit

Permalink
Add SNI parser
Browse files Browse the repository at this point in the history
  • Loading branch information
pvizeli committed Jan 28, 2019
1 parent 357a6b4 commit 1fcef02
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ venv.bak/

# mypy
.mypy_cache/

# Editors
.vscode/
17 changes: 17 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[isort]
multi_line_output = 4
indent = " "
not_skip = __init__.py
force_sort_within_sections = true
sections = FUTURE,STDLIB,INBETWEENS,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
default_section = THIRDPARTY
forced_separate = tests
combine_as_imports = true
use_parentheses = true

[yapf]
based_on_style = chromium
indent_width = 4

[flake8]
max-line-length = 80
Empty file added snitun/__init__.py
Empty file.
Empty file added snitun/client/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions snitun/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""SniTun Exceptions."""

class SniTunError(Exception):
"""Base Exception for SniTun exceptions."""


class ParseSNIError(SniTunError):
"""Invalid ClientHello data."""
Empty file added snitun/server/__init__.py
Empty file.
121 changes: 121 additions & 0 deletions snitun/server/sni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""TLS ClientHello parser."""
import logging

from ..exceptions import ParseSNIError

_LOGGER = logging.getLogger(__name__)

TLS_HEADER_LEN = 5
TLS_HANDSHAKE_CONTENT_TYPE = bytes(0x16)
TLS_HANDSHAKE_TYPE_CLIENT_HELLO = bytes(0x01)


def parse_tls_sni(data: bytes) -> str:
"""Parse TLS SNI extention."""

if len(data) < TLS_HEADER_LEN:
_LOGGER.debug("Invalid TLS header")
raise ParseSNIError()

# If TLS handshake
if data[0] != TLS_HANDSHAKE_CONTENT_TYPE:
_LOGGER.debug("Not TLS handshake received")
raise ParseSNIError()

# Check compatible ClientHello
if int(data[1]) < 3:
_LOGGER.debug("Received ClientHello without SNI support")
raise ParseSNIError()

# Calculate TLS record size
tls_size = int(data[3] << 8) + int(data[4]) + TLS_HEADER_LEN
if len(data) < tls_size:
_LOGGER.debug("Can't calculate the TLS record size")
raise ParseSNIError()

# Check if handshake is a ClientHello
pos = TLS_HEADER_LEN + 1
if data[pos] != TLS_HANDSHAKE_TYPE_CLIENT_HELLO:
_LOGGER.debug("Invalid ClientHello type")
raise ParseSNIError()

# Seek fixed length header part
pos += 38

# Seek SessionID
try:
pos += 1 + int(data[pos])
except IndexError:
_LOGGER.debug("Invalid SessionID")
raise ParseSNIError() from None

# Seek Cipher Suites
try:
pos += 2 + int(data[pos] << 8) + int(data[pos + 1])
except IndexError:
_LOGGER.debug("Invalid CipherSuites")
raise ParseSNIError() from None

# Seek Compression Methods
try:
pos += 1 + int(data[pos])
except IndexError:
_LOGGER.debug("Invalid CompressionMethods")
raise ParseSNIError() from None

# Check data buffer + extension size
if pos + 2 > len(data):
_LOGGER.debug("Mismatch Extension TLS header")
raise ParseSNIError()

# Process extension
return _parse_extension(data, pos)


def _parse_extension(data: bytes, pos: int) -> str:
"""Parse TLS ClientHello Extension."""

# Seek Extension start
try:
tls_extension_size = int(data[pos] << 8) + int(data[pos + 1])
pos += 2
except IndexError:
raise ParseSNIError() from None

# Check data buffer + extension size
if pos + tls_extension_size > len(data):
_LOGGER.debug("Mismatch Extension TLS header")
raise ParseSNIError()

# Loop over extension until we have our SNI
while pos + 4 <= len(data):

# SNI?
if data[pos] == 0x00 and data[pos + 1] == 0x00:
return _parse_host_name(data, pos + 4)

pos = 4 + int(data[pos + 2] << 8) + int(data[pos + 3])

_LOGGER.debug("Can't find any ServerName Extension")
raise ParseSNIError()


def _parse_host_name(data: bytes, pos: int) -> str:
"""Parse TLS ServerName Extension."""

# Skeep list size
pos += 2

while pos + 3 < len(data):
size = int(data[pos + 1] << 8) + data[pos + 2]

# Unknown server name type
if data[pos] != 0x00:
_LOGGER.debug("Unknown ServerName type")
pos += 3 + size
continue

return str(data[pos + 3:size])

_LOGGER.debug("Not found any valid ServerName")
raise ParseSNIError()

0 comments on commit 1fcef02

Please sign in to comment.