Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract ASGI scope creation into function #162

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 89 additions & 76 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, InitVar
from contextlib import ExitStack

from mangum.types import ASGIApp
from mangum.types import ASGIApp, Scope
from mangum.protocols.lifespan import LifespanCycle
from mangum.protocols.http import HTTPCycle
from mangum.exceptions import ConfigurationError
Expand Down Expand Up @@ -91,89 +91,102 @@ def __call__(self, event: dict, context: "LambdaContext") -> dict:
)
stack.enter_context(lifespan_cycle)

request_context = event["requestContext"]

if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()
}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True
).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})
).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
client = (source_ip, 0)

if not path: # pragma: no cover
path = "/"
elif self.api_gateway_base_path:
if path.startswith(self.api_gateway_base_path):
path = path[len(self.api_gateway_base_path) :]

scope = {
"type": "http",
"http_version": "1.1",
"method": http_method,
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": urllib.parse.unquote(path),
"raw_path": None,
"root_path": "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": query_string,
"server": server,
"client": client,
"asgi": {"version": "3.0"},
"aws.event": event,
"aws.context": context,
}

is_binary = event.get("isBase64Encoded", False)
initial_body = event.get("body") or b""
if is_binary:
initial_body = base64.b64decode(initial_body)
elif not isinstance(initial_body, bytes):
initial_body = initial_body.encode()

scope = self._create_scope(event, context)
http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types)
response = http_cycle(self.app, initial_body)

return response

def _create_scope(self, event: dict, context: "LambdaContext") -> Scope:
"""Creates a scope object according to ASGI specification from a Lambda Event.

https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope

The event comes from various sources: AWS ALB, AWS API Gateway of different
versions and configurations(multivalue header, etc).
Thus, some heuristics is applied to guess an event type.

"""
request_context = event["requestContext"]

if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()
}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True
).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})
).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
client = (source_ip, 0)

if not path: # pragma: no cover
path = "/"
elif self.api_gateway_base_path:
if path.startswith(self.api_gateway_base_path):
path = path[len(self.api_gateway_base_path) :]

scope = {
"type": "http",
"http_version": "1.1",
"method": http_method,
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": urllib.parse.unquote(path),
"raw_path": None,
"root_path": "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": query_string,
"server": server,
"client": client,
"asgi": {"version": "3.0"},
"aws.event": event,
"aws.context": context,
}

return scope