Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flags] Use option instead of long cast statements #919

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 39 additions & 91 deletions proxy/common/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Optional, List, Any, cast

from .plugins import Plugins
from .types import IpAddress
from .utils import bytes_, is_py2, is_threadless, set_open_file_limit
from .constants import COMMA, DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_NUM_ACCEPTORS, DEFAULT_NUM_WORKERS
from .constants import DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE
Expand Down Expand Up @@ -109,12 +108,27 @@ def initialize(
print(__version__)
sys.exit(0)

# https://github.com/python/mypy/issues/5865
def option(t: object, key: str, default: Optional[Any] = None) -> Any:
return cast(
t, # type: ignore
opts.get(
key,
default or getattr(args, key),
),
)

# Command line arguments MUST always take preference
# over kwargs passed to the program constructor.
# for f in args.__dict__.keys():
# print(f)
# print(option(Any, f))

# proxy.py currently cannot serve over HTTPS and also perform TLS interception
# at the same time. Check if user is trying to enable both feature
# at the same time.
#
# TODO: Use parser.add_mutually_exclusive_group()
# and remove this logic from here.
# TODO: Use parser.add_mutually_exclusive_group() and remove this logic from here.
if (args.cert_file and args.key_file) and \
(args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file):
print(
Expand Down Expand Up @@ -157,27 +171,9 @@ def initialize(

# --enable flags must be parsed before loading plugins
# otherwise we will miss the plugins passed via constructor
args.enable_web_server = cast(
bool,
opts.get(
'enable_web_server',
args.enable_web_server,
),
)
args.enable_static_server = cast(
bool,
opts.get(
'enable_static_server',
args.enable_static_server,
),
)
args.enable_events = cast(
bool,
opts.get(
'enable_events',
args.enable_events,
),
)
args.enable_web_server = option(bool, 'enable_web_server')
args.enable_static_server = option(bool, 'enable_static_server')
args.enable_events = option(bool, 'enable_events')

# Load default plugins along with user provided --plugins
default_plugins = [
Expand All @@ -191,10 +187,6 @@ def initialize(
default_plugins + auth_plugins + requested_plugins,
)

# https://github.com/python/mypy/issues/5865
#
# def option(t: object, key: str, default: Any) -> Any:
# return cast(t, opts.get(key, default))
args.work_klass = work_klass
args.plugins = plugins
args.auth_code = cast(
Expand All @@ -204,20 +196,8 @@ def initialize(
auth_code,
),
)
args.server_recvbuf_size = cast(
int,
opts.get(
'server_recvbuf_size',
args.server_recvbuf_size,
),
)
args.client_recvbuf_size = cast(
int,
opts.get(
'client_recvbuf_size',
args.client_recvbuf_size,
),
)
args.server_recvbuf_size = option(int, 'server_recvbuf_size')
args.client_recvbuf_size = option(int, 'client_recvbuf_size')
args.pac_file = cast(
Optional[str], opts.get(
'pac_file', bytes_(
Expand All @@ -241,44 +221,18 @@ def initialize(
],
),
)
args.disable_headers = disabled_headers if disabled_headers is not None else DEFAULT_DISABLE_HEADERS
args.certfile = cast(
Optional[str], opts.get(
'cert_file', args.cert_file,
),
)
args.keyfile = cast(Optional[str], opts.get('key_file', args.key_file))
args.ca_key_file = cast(
Optional[str], opts.get(
'ca_key_file', args.ca_key_file,
),
)
args.ca_cert_file = cast(
Optional[str], opts.get(
'ca_cert_file', args.ca_cert_file,
),
)
args.ca_signing_key_file = cast(
Optional[str],
opts.get(
'ca_signing_key_file',
args.ca_signing_key_file,
),
)
args.ca_file = cast(
Optional[str],
opts.get(
'ca_file',
args.ca_file,
),
)
args.hostname = cast(
IpAddress,
opts.get('hostname', ipaddress.ip_address(args.hostname)),
)
args.unix_socket_path = opts.get(
'unix_socket_path', args.unix_socket_path,
)
args.disable_headers = disabled_headers \
if disabled_headers is not None \
else DEFAULT_DISABLE_HEADERS
args.certfile = option(Optional[str], 'cert_file')
args.keyfile = option(Optional[str], 'key_file')
args.ca_key_file = option(Optional[str], 'ca_key_file')
args.ca_cert_file = option(Optional[str], 'ca_cert_file')
args.ca_signing_key_file = option(Optional[str], 'ca_signing_key_file')
args.ca_file = option(Optional[str], 'ca_file')
args.hostname = option(str, 'hostname')
args.hostname = ipaddress.ip_address(args.hostname)
args.unix_socket_path = option(str, 'unix_socket_path')
# AF_UNIX is not available on Windows
# See https://bugs.python.org/issue33408
if not IS_WINDOWS:
Expand All @@ -294,13 +248,13 @@ def initialize(
#
# assert args.unix_socket_path is None
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
args.port = cast(int, opts.get('port', args.port))
args.backlog = cast(int, opts.get('backlog', args.backlog))
num_workers = opts.get('num_workers', args.num_workers)
args.port = option(int, 'port')
args.backlog = option(int, 'backlog')
num_workers = option(int, 'num_workers')
args.num_workers = cast(
int, num_workers if num_workers > 0 else multiprocessing.cpu_count(),
)
num_acceptors = opts.get('num_acceptors', args.num_acceptors)
num_acceptors = option(int, 'num_acceptors')
# See https://github.com/abhinavsingh/proxy.py/pull/714 description
# to understand rationale behind the following logic.
#
Expand All @@ -314,13 +268,7 @@ def initialize(
int, num_acceptors if num_acceptors > 0 else multiprocessing.cpu_count(),
)

args.static_server_dir = cast(
str,
opts.get(
'static_server_dir',
args.static_server_dir,
),
)
args.static_server_dir = option(str, 'static_server_dir')
args.min_compression_limit = cast(
bool,
opts.get(
Expand Down