Skip to content

Commit

Permalink
refactor(account): Login by code is now a stage
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Aug 20, 2024
1 parent 91b4771 commit f99d2c6
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 50 deletions.
1 change: 1 addition & 0 deletions allauth/account/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def generate_emailconfirmation_key(self, email):

def get_login_stages(self):
ret = []
ret.append("allauth.account.stages.LoginByCodeStage")
ret.append("allauth.account.stages.EmailVerificationStage")
if allauth_app_settings.MFA_ENABLED:
ret.append("allauth.mfa.stages.AuthenticateStage")
Expand Down
44 changes: 27 additions & 17 deletions allauth/account/internal/flows/login_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from allauth.account.models import Login


LOGIN_CODE_SESSION_KEY = "account_login_code"
LOGIN_CODE_STATE_KEY = "login_code"


def request_login_code(request: HttpRequest, email: str) -> None:
from allauth.account.utils import filter_users_by_email
from allauth.account.utils import filter_users_by_email, stash_login

adapter = get_adapter()
users = filter_users_by_email(email, is_active=True, prefer_verified=True)
Expand All @@ -28,6 +28,7 @@ def request_login_code(request: HttpRequest, email: str) -> None:
"failed_attempts": 0,
}
if not users:
user = None
send_unknown_account_mail(request, email)
else:
user = users[0]
Expand All @@ -40,27 +41,29 @@ def request_login_code(request: HttpRequest, email: str) -> None:
pending_login.update(
{"code": code, "user_id": user._meta.pk.value_to_string(user)}
)

request.session[LOGIN_CODE_SESSION_KEY] = pending_login
login = Login(user=user, email=email)
login.state[LOGIN_CODE_STATE_KEY] = pending_login
login.state["stages"] = {"current": "login_by_code"}
adapter.add_message(
request,
messages.SUCCESS,
"account/messages/login_code_sent.txt",
{"email": email},
)
stash_login(request, login)


def get_pending_login(
request: HttpRequest, peek: bool = False
login: Login, peek: bool = False
) -> Tuple[Optional[AbstractBaseUser], Optional[Dict[str, Any]]]:
if peek:
data = request.session.get(LOGIN_CODE_SESSION_KEY)
data = login.state.get(LOGIN_CODE_STATE_KEY)
else:
data = request.session.pop(LOGIN_CODE_SESSION_KEY, None)
data = login.state.pop(LOGIN_CODE_STATE_KEY, None)
if not data:
return None, None
if time.time() - data["at"] >= app_settings.LOGIN_BY_CODE_TIMEOUT:
request.session.pop(LOGIN_CODE_SESSION_KEY, None)
login.state.pop(LOGIN_CODE_STATE_KEY, None)
return None, None
user_id_str = data.get("user_id")
user = None
Expand All @@ -70,30 +73,37 @@ def get_pending_login(
return user, data


def record_invalid_attempt(request: HttpRequest, pending_login: Dict[str, Any]) -> bool:
def record_invalid_attempt(request, login: Login) -> bool:
from allauth.account.utils import stash_login, unstash_login

pending_login = login.state[LOGIN_CODE_STATE_KEY]
n = pending_login["failed_attempts"]
n += 1
pending_login["failed_attempts"] = n
if n >= app_settings.LOGIN_BY_CODE_MAX_ATTEMPTS:
request.session.pop(LOGIN_CODE_SESSION_KEY, None)
unstash_login(request)
return False
else:
request.session[LOGIN_CODE_SESSION_KEY] = pending_login
login.state[LOGIN_CODE_STATE_KEY] = pending_login
stash_login(request, login)
return True


def perform_login_by_code(
request: HttpRequest,
user: AbstractBaseUser,
stage,
redirect_url: Optional[str],
pending_login: Dict[str, Any],
):
request.session.pop(LOGIN_CODE_SESSION_KEY, None)
record_authentication(request, method="code", email=pending_login["email"])
state = stage.login.state.pop(LOGIN_CODE_STATE_KEY)
email = state["email"]
record_authentication(request, method="code", email=email)
# Just requesting a login code does is not considered to be a real login,
# yet, is needed in order to make the stage machinery work. Now that we've
# completed the code, let's start a real login.
login = Login(
user=user,
user=stage.login.user,
redirect_url=redirect_url,
email=pending_login["email"],
email=email,
)
return perform_login(request, login)

Expand Down
19 changes: 11 additions & 8 deletions allauth/account/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
from django.core import signing
from django.db import models
from django.db.models import Q
Expand Down Expand Up @@ -230,6 +231,9 @@ class Login:
case email verification is optional and we are only logging in).
"""

# Optional, because we might be prentending logins to prevent user
# enumeration.
user: Optional[AbstractBaseUser]
email_verification: app_settings.EmailVerificationMethod
signal_kwargs: Optional[Dict]
signup: bool
Expand Down Expand Up @@ -271,7 +275,7 @@ def serialize(self):
signal_kwargs["sociallogin"] = sociallogin.serialize()

data = {
"user_pk": user_pk_to_url_str(self.user),
"user_pk": user_pk_to_url_str(self.user) if self.user else None,
"email_verification": self.email_verification,
"signup": self.signup,
"redirect_url": self.redirect_url,
Expand All @@ -286,13 +290,12 @@ def serialize(self):
def deserialize(cls, data):
from allauth.account.utils import url_str_to_user_pk

user = (
get_user_model()
.objects.filter(pk=url_str_to_user_pk(data["user_pk"]))
.first()
)
if user is None:
raise ValueError()
user = None
user_pk = data["user_pk"]
if user_pk is not None:
user = (
get_user_model().objects.filter(pk=url_str_to_user_pk(user_pk)).first()
)
try:
# :-( Knowledge of the `socialaccount` is entering the `account` app.
signal_kwargs = data["signal_kwargs"]
Expand Down
14 changes: 14 additions & 0 deletions allauth/account/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,17 @@ def handle(self):
self.request, login.user
)
return response, cont


class LoginByCodeStage(LoginStage):
key = "login_by_code"

def handle(self):
from allauth.account.internal.flows import login_by_code

user, data = login_by_code.get_pending_login(self.login, peek=True)
if data is None:
# No pending login, just continue.
return None, True
response = HttpResponseRedirect(reverse("account_confirm_login_code"))
return response, True
12 changes: 6 additions & 6 deletions allauth/account/tests/test_login_by_code.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from unittest.mock import ANY

from django.contrib.auth import SESSION_KEY
from django.urls import reverse

import pytest

from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.account.internal.flows.login_by_code import LOGIN_CODE_SESSION_KEY
from allauth.account.internal.flows.login import LOGIN_SESSION_KEY
from allauth.account.internal.flows.login_by_code import LOGIN_CODE_STATE_KEY


@pytest.fixture
Expand All @@ -23,7 +23,7 @@ def f(client, email):
resp["location"] == reverse("account_confirm_login_code") + "?next=%2Ffoo"
)
assert len(mailoutbox) == 1
code = client.session[LOGIN_CODE_SESSION_KEY]["code"]
code = client.session[LOGIN_SESSION_KEY]["state"][LOGIN_CODE_STATE_KEY]["code"]
assert len(code) == 6
assert code in mailoutbox[0].body
return code
Expand All @@ -39,7 +39,7 @@ def test_login_by_code(client, user, request_login_by_code):
data={"code": code_with_ws, "next": "/foo"},
)
assert resp.status_code == 302
assert client.session[SESSION_KEY] == str(user.pk)
assert LOGIN_SESSION_KEY not in client.session
assert resp["location"] == "/foo"
assert client.session[AUTHENTICATION_METHODS_SESSION_KEY][-1] == {
"method": "code",
Expand All @@ -58,10 +58,10 @@ def test_login_by_code_max_attempts(client, user, request_login_by_code, setting
if i >= 1:
assert resp.status_code == 302
assert resp["location"] == reverse("account_request_login_code")
assert LOGIN_CODE_SESSION_KEY not in client.session
assert LOGIN_SESSION_KEY not in client.session
else:
assert resp.status_code == 200
assert LOGIN_CODE_SESSION_KEY in client.session
assert LOGIN_SESSION_KEY in client.session
assert resp.context["form"].errors == {"code": ["Incorrect code."]}


Expand Down
15 changes: 11 additions & 4 deletions allauth/account/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
EmailConfirmation,
get_emailconfirmation_model,
)
from allauth.account.stages import EmailVerificationStage, LoginStageController
from allauth.account.stages import (
EmailVerificationStage,
LoginByCodeStage,
LoginStageController,
)
from allauth.account.utils import (
complete_signup,
perform_login,
Expand Down Expand Up @@ -957,8 +961,11 @@ class ConfirmLoginCodeView(RedirectAuthenticatedUserMixin, NextRedirectMixin, Fo

@method_decorator(never_cache)
def dispatch(self, request, *args, **kwargs):
self.stage = LoginStageController.enter(request, LoginByCodeStage.key)
if not self.stage:
return HttpResponseRedirect(reverse("account_request_login_code"))
self.user, self.pending_login = flows.login_by_code.get_pending_login(
request, peek=True
self.stage.login, peek=True
)
if not self.pending_login:
return HttpResponseRedirect(reverse("account_request_login_code"))
Expand All @@ -975,12 +982,12 @@ def get_form_kwargs(self):
def form_valid(self, form):
redirect_url = self.get_next_url()
return flows.login_by_code.perform_login_by_code(
self.request, self.user, redirect_url, self.pending_login
self.request, self.stage, redirect_url
)

def form_invalid(self, form):
attempts_left = flows.login_by_code.record_invalid_attempt(
self.request, self.pending_login
self.request, self.stage.login
)
if attempts_left:
return super().form_invalid(form)
Expand Down
31 changes: 31 additions & 0 deletions allauth/headless/account/tests/test_login_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,34 @@ def test_login_by_code_rate_limit(
"param": "email",
},
]


def test_login_by_code_max_attemps(headless_reverse, user, client, settings):
settings.ACCOUNT_LOGIN_BY_CODE_MAX_ATTEMPTS = 2
resp = client.post(
headless_reverse("headless:account:request_login_code"),
data={"email": user.email},
content_type="application/json",
)
assert resp.status_code == 401
for i in range(3):
resp = client.post(
headless_reverse("headless:account:confirm_login_code"),
data={"code": "wrong"},
content_type="application/json",
)
session_resp = client.get(
headless_reverse("headless:account:current_session"),
data={"code": "wrong"},
content_type="application/json",
)
assert session_resp.status_code == 401
pending_flows = [
f for f in session_resp.json()["data"]["flows"] if f.get("is_pending")
]
if i >= 1:
assert resp.status_code == 409 if i >= 2 else 400
assert len(pending_flows) == 0
else:
assert resp.status_code == 400
assert len(pending_flows) == 1
15 changes: 11 additions & 4 deletions allauth/headless/account/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ForbiddenResponse,
)
from allauth.headless.base.views import APIView, AuthenticatedAPIView
from allauth.headless.internal import authkit
from allauth.headless.internal.restkit.response import ErrorResponse


Expand All @@ -50,15 +51,17 @@ class ConfirmLoginCodeView(APIView):
input_class = ConfirmLoginCodeInput

def dispatch(self, request, *args, **kwargs):
auth_status = authkit.AuthenticationStatus(request)
self.stage = auth_status.get_pending_stage()
if not self.stage:
return ConflictResponse(request)
self.user, self.pending_login = flows.login_by_code.get_pending_login(
request, peek=True
self.stage.login, peek=True
)
return super().dispatch(request, *args, **kwargs)

def post(self, request, *args, **kwargs):
flows.login_by_code.perform_login_by_code(
self.request, self.user, None, self.pending_login
)
flows.login_by_code.perform_login_by_code(self.request, self.stage, None)
return AuthenticationResponse(request)

def get_input_kwargs(self):
Expand All @@ -68,6 +71,10 @@ def get_input_kwargs(self):
)
return kwargs

def handle_invalid_input(self, input):
flows.login_by_code.record_invalid_attempt(self.request, self.stage.login)
return super().handle_invalid_input(input)


@method_decorator(rate_limit(action="login"), name="handle")
class LoginView(APIView):
Expand Down
15 changes: 9 additions & 6 deletions allauth/headless/base/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ def _get_flows(self, request, user):
if not allauth_settings.SOCIALACCOUNT_ONLY:
ret.append({"id": Flow.LOGIN})
if account_settings.LOGIN_BY_CODE_ENABLED:
code_flow = {"id": Flow.LOGIN_BY_CODE}
_, data = flows.login_by_code.get_pending_login(request, peek=True)
if data:
code_flow["is_pending"] = True
ret.append(code_flow)
ret.append({"id": Flow.LOGIN_BY_CODE})
if (
get_account_adapter().is_open_for_signup(request)
and not allauth_settings.SOCIALACCOUNT_ONLY
Expand All @@ -72,9 +68,16 @@ def _get_flows(self, request, user):
pending_flow = {"id": stage_key, "is_pending": True}
if stage and stage_key == Flow.MFA_AUTHENTICATE:
self._enrich_mfa_flow(stage, pending_flow)
ret.append(pending_flow)
self._upsert_pending_flow(ret, pending_flow)
return ret

def _upsert_pending_flow(self, flows, pending_flow):
flow = next((flow for flow in flows if flow["id"] == pending_flow["id"]), None)
if flow:
flow.update(pending_flow)
else:
flows.append(pending_flow)

def _enrich_mfa_flow(self, stage, flow: dict) -> None:
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
from allauth.mfa.models import Authenticator
Expand Down
4 changes: 2 additions & 2 deletions allauth/headless/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum

from allauth.account.stages import EmailVerificationStage
from allauth.account.stages import EmailVerificationStage, LoginByCodeStage


class Client(str, Enum):
Expand All @@ -11,7 +11,7 @@ class Client(str, Enum):
class Flow(str, Enum):
VERIFY_EMAIL = EmailVerificationStage.key
LOGIN = "login"
LOGIN_BY_CODE = "login_by_code"
LOGIN_BY_CODE = LoginByCodeStage.key
SIGNUP = "signup"
PROVIDER_REDIRECT = "provider_redirect"
PROVIDER_SIGNUP = "provider_signup"
Expand Down
Loading

0 comments on commit f99d2c6

Please sign in to comment.