Skip to content

Commit

Permalink
[Security] add local auth for simple websocket servers
Browse files Browse the repository at this point in the history
  • Loading branch information
junhaoliao committed May 24, 2022
1 parent 681a751 commit 389f5c2
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 15 deletions.
2 changes: 2 additions & 0 deletions application/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
APP_PORT = 5000
LOCAL_AUTH_KEY = ''

os.environ['LOCAL_AUTH_KEY'] = LOCAL_AUTH_KEY

APP_HOST = '127.0.0.1'

profiles: Profile
Expand Down
10 changes: 9 additions & 1 deletion application/features/Audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from .Connection import Connection
from .. import app
from ..utils import find_free_port
from ..utils import find_free_port, get_headers_dict_from_str, local_auth

AUDIO_CONNECTIONS = {}

Expand Down Expand Up @@ -77,6 +77,12 @@ def __init__(self, *args, **kwargs):
self.module_id = None

def handleConnected(self):
headers = self.headerbuffer.decode('utf-8')
headers = get_headers_dict_from_str(headers)
if not local_auth(headers=headers, abort_func=self.close):
# local auth failure
return

audio_id = self.request.path[1:]
if audio_id not in AUDIO_CONNECTIONS:
print(f'AudioWebSocket: Requested audio_id={audio_id} does not exist.')
Expand Down Expand Up @@ -158,6 +164,8 @@ def handleClose(self):
# if we are in debug mode, run the server in the second round
if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
AUDIO_PORT = find_free_port()
print("AUDIO_PORT =", AUDIO_PORT)

if os.environ.get('SSL_CERT_PATH') is None:
# no certificate provided, run in non-encrypted mode
# FIXME: consider using a self-signing certificate for local connections
Expand Down
10 changes: 9 additions & 1 deletion application/features/Term.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .Connection import Connection
from .. import app
from ..utils import find_free_port
from ..utils import find_free_port, local_auth, get_headers_dict_from_str

TERM_CONNECTIONS = {}

Expand Down Expand Up @@ -79,6 +79,12 @@ def handleMessage(self):
self.term.channel.send(self.data)

def handleConnected(self):
headers = self.headerbuffer.decode('utf-8')
headers = get_headers_dict_from_str(headers)
if not local_auth(headers=headers, abort_func=self.close):
# local auth failure
return

print(self.address, 'connected')
terminal_id = self.request.path[1:]
if terminal_id not in TERM_CONNECTIONS:
Expand Down Expand Up @@ -110,6 +116,8 @@ def handleClose(self):
# if we are in debug mode, run the server in the second round
if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
TERMINAL_PORT = find_free_port()
print("TERMINAL_PORT =", TERMINAL_PORT)

if os.environ.get('SSL_CERT_PATH') is None:
# no certificate provided, run in non-encrypted mode
# FIXME: consider using a self-signing certificate for local connections
Expand Down
29 changes: 29 additions & 0 deletions application/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,32 @@ def validate_password(password):
return False, reason
else:
return True, None


def get_headers_dict_from_str(headers_str):
headers = {}

for line in headers_str.split("\r\n"):
if line.startswith("GET") or ':' not in line:
continue
header_name, header_value = line.split(': ', 1)
headers[header_name] = header_value

return headers


def local_auth(headers, abort_func):
auth_passed = True
local_auth_key = os.getenv('LOCAL_AUTH_KEY')
if local_auth_key != '':
try:
auth_type, auth_key = headers.get('Authorization').split()
if auth_type != 'Bearer' or auth_key != local_auth_key:
auth_passed = False
except Exception as e:
auth_passed = False

if not auth_passed:
abort_func(403, "You are not authorized to access this API.")

return auth_passed
2 changes: 1 addition & 1 deletion desktop_client/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ let mainWindow = null;
const setupLocalAuth = () => {
// Modify the user agent for all requests to the following urls.
const filter = {
urls: ['http://127.0.0.1/*'],
urls: ['http://127.0.0.1/*', 'ws://127.0.0.1/*'],
};

session.defaultSession.webRequest.onBeforeSendHeaders(filter,
Expand Down
15 changes: 3 additions & 12 deletions ictrl_be.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
# IN THE SOFTWARE.

if __name__ == '__main__':
from application import api, app, APP_HOST, APP_PORT, LOCAL_AUTH_KEY
from application import api, app, APP_HOST, APP_PORT
from application.utils import local_auth

if not app.debug:
import os
Expand All @@ -29,17 +30,7 @@

@app.before_request
def before_request():
if LOCAL_AUTH_KEY != '':
try:
auth_type, auth_key = request.headers.get('Authorization').split()
if auth_type != 'Bearer' or auth_key != LOCAL_AUTH_KEY:
abort(403, "You are not authorized to access this API.")
except Exception as e:
abort(403, "You are not authorized to access this API.")

print("Auth failure: is anyone hacking?")
raise e

local_auth(headers=request.headers, abort_func=abort)

# Reference: https://stackoverflow.com/questions/44209978/serving-a-front-end-created-with-create-react-app-with
# -flask
Expand Down

0 comments on commit 389f5c2

Please sign in to comment.