diff --git a/server/auth.go b/server/auth.go index e27beeb..98c0809 100644 --- a/server/auth.go +++ b/server/auth.go @@ -74,6 +74,21 @@ func (s *Server) handlePostAuthRefresh() gin.HandlerFunc { } } +func (s *Server) handlePostAuth2FA() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.Auth2FAReq + + if err := c.BindJSON(&req); err != nil { + return + } + + if err := s.b.UpgradeAuth(c.GetString("AuthUID"), req.TwoFactorCode); err != nil { + _ = c.AbortWithError(http.StatusUnauthorized, err) + return + } + } +} + func (s *Server) handleDeleteAuth() gin.HandlerFunc { return func(c *gin.Context) { if err := s.b.DeleteSession(c.GetString("UserID"), c.GetString("AuthUID")); err != nil { diff --git a/server/backend/account.go b/server/backend/account.go index e46f440..d22dbde 100644 --- a/server/backend/account.go +++ b/server/backend/account.go @@ -9,6 +9,11 @@ import ( "github.com/google/uuid" ) +// TODO: Add other options e.g. real 2FA time-based OTP. +type totp struct { + want *string +} + type account struct { userID string username string @@ -17,6 +22,7 @@ type account struct { userSettings proton.UserSettings contacts map[string]*proton.Contact + totp totp auth map[string]auth authLock sync.RWMutex diff --git a/server/backend/api_auth.go b/server/backend/api_auth.go index a8f11d7..10d0387 100644 --- a/server/backend/api_auth.go +++ b/server/backend/api_auth.go @@ -55,7 +55,15 @@ func (b *Backend) NewAuth(username string, ephemeral, proof []byte, session stri return proton.Auth{}, fmt.Errorf("invalid proof: %w", err) } - authUID, auth := uuid.NewString(), newAuth(b.authLife) + var scope Scope + + if acc.totp.want != nil { + scope = ScopeTOTP + } else { + scope = ScopeFull + } + + authUID, auth := uuid.NewString(), newAuth(scope) acc.authLock.Lock() defer acc.authLock.Unlock() @@ -83,7 +91,7 @@ func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) { return proton.Auth{}, fmt.Errorf("invalid auth ref") } - newAuth := newAuth(b.authLife) + newAuth := newAuth(auth.scope) acc.auth[authUID] = newAuth @@ -93,8 +101,43 @@ func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) { return proton.Auth{}, fmt.Errorf("invalid auth") } -func (b *Backend) VerifyAuth(authUID, authAcc string) (string, error) { +func (b *Backend) UpgradeAuth(authUID, totp string) error { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + acc.authLock.Lock() + defer acc.authLock.Unlock() + + auth, ok := acc.auth[authUID] + if !ok { + continue + } + + if auth.scope != ScopeTOTP { + return fmt.Errorf("invalid scope") + } else if acc.totp.want == nil { + return fmt.Errorf("2FA not enabled") + } else if *acc.totp.want != totp { + return fmt.Errorf("invalid 2FA code") + } + + auth.scope = ScopeFull + + acc.auth[authUID] = auth + + return nil + } + + return fmt.Errorf("no such auth") +} + +func (b *Backend) VerifyAuth(authUID, authAcc string, scope Scope) (string, error) { return withAccAuth(b, authUID, authAcc, func(acc *account) (string, error) { + if acc.auth[authUID].scope != scope { + return "", fmt.Errorf("invalid scope") + } + return acc.userID, nil }) } diff --git a/server/backend/backend.go b/server/backend/backend.go index 92b2532..0cd316d 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -65,6 +65,13 @@ func (b *Backend) SetAuthLife(authLife time.Duration) { b.authLife = authLife } +func (b *Backend) SetAuthTOTP(userID, totp string) error { + return b.withAcc(userID, func(acc *account) error { + acc.totp.want = &totp + return nil + }) +} + func (b *Backend) SetMaxUpdatesPerEvent(max int) { b.maxUpdatesPerEvent = max } @@ -556,7 +563,7 @@ func withAccAuth[T any](b *Backend, authUID, authAcc string, fn func(acc *accoun } if time.Since(val.creation) > b.authLife { - acc.auth[authUID] = auth{ref: val.ref, creation: val.creation} + acc.auth[authUID] = newAuthFromExpired(val) } else if val.acc == authAcc { return fn(acc) } diff --git a/server/backend/contact.go b/server/backend/contact.go index 158d7e7..9d7fad4 100644 --- a/server/backend/contact.go +++ b/server/backend/contact.go @@ -19,7 +19,7 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin ContactMetadata: proton.ContactMetadata{ ID: contactID, Name: names[0].Value, - ContactEmails: []proton.ContactEmail{proton.ContactEmail{ + ContactEmails: []proton.ContactEmail{{ ID: "1", Name: names[0].Value, Email: emails[0].Value, diff --git a/server/backend/types.go b/server/backend/types.go index cc58fbb..ee06c32 100644 --- a/server/backend/types.go +++ b/server/backend/types.go @@ -35,22 +35,42 @@ func (v *ID) FromString(s string) error { return nil } +type Scope int + +// TODO: Add more scopes? +const ( + ScopeNone Scope = iota + ScopeTOTP + ScopeFull +) + type auth struct { acc string ref string + scope Scope + creation time.Time } -func newAuth(authLife time.Duration) auth { +func newAuth(scope Scope) auth { return auth{ - acc: uuid.NewString(), - ref: uuid.NewString(), - + acc: uuid.NewString(), + ref: uuid.NewString(), + scope: scope, creation: time.Now(), } } +func newAuthFromExpired(old auth) auth { + return auth{ + acc: "", + ref: old.ref, + scope: old.scope, + creation: old.creation, + } +} + func (auth *auth) toAuth(userID, authUID string, proof []byte) proton.Auth { return proton.Auth{ UserID: userID, diff --git a/server/router.go b/server/router.go index 89f5b60..e585d67 100644 --- a/server/router.go +++ b/server/router.go @@ -14,6 +14,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server/backend" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -38,7 +39,7 @@ func initRouter(s *Server) { } // These routes require auth. - if core := core.Group("", s.requireAuth()); core != nil { + if core := core.Group("", s.requireAuth(backend.ScopeFull)); core != nil { if users := core.Group("/users"); users != nil { users.GET("", s.handleGetUsers()) } @@ -77,7 +78,7 @@ func initRouter(s *Server) { } // All mail routes need authentication. - if mail := s.r.Group("/mail/v4", s.requireAuth()); mail != nil { + if mail := s.r.Group("/mail/v4", s.requireAuth(backend.ScopeFull)); mail != nil { if settings := mail.Group("/settings"); settings != nil { settings.GET("", s.handleGetMailSettings()) settings.PUT("/attachpublic", s.handlePutMailSettingsAttachPublicKey()) @@ -106,7 +107,7 @@ func initRouter(s *Server) { } // All contacts routes need authentication. - if contacts := s.r.Group("/contacts/v4", s.requireAuth()); contacts != nil { + if contacts := s.r.Group("/contacts/v4", s.requireAuth(backend.ScopeFull)); contacts != nil { contacts.GET("", s.handleGetContacts()) contacts.POST("", s.handlePostContacts()) contacts.GET("/:contactID", s.handleGetContact()) @@ -115,7 +116,7 @@ func initRouter(s *Server) { } // All data routes need authentication. - if data := s.r.Group("/data/v1", s.requireAuth()); data != nil { + if data := s.r.Group("/data/v1", s.requireAuth(backend.ScopeFull)); data != nil { if stats := data.Group("/stats"); stats != nil { stats.POST("", s.handlePostDataStats()) stats.POST("/multiple", s.handlePostDataStatsMultiple()) @@ -128,8 +129,13 @@ func initRouter(s *Server) { auth.POST("/info", s.handlePostAuthInfo()) auth.POST("/refresh", s.handlePostAuthRefresh()) + // These routes require auth with only TOTP scope. + if auth := auth.Group("", s.requireAuth(backend.ScopeTOTP)); auth != nil { + auth.POST("/2fa", s.handlePostAuth2FA()) + } + // These routes require auth. - if auth := auth.Group("", s.requireAuth()); auth != nil { + if auth := auth.Group("", s.requireAuth(backend.ScopeFull)); auth != nil { auth.DELETE("", s.handleDeleteAuth()) if sessions := auth.Group("/sessions"); sessions != nil { @@ -278,7 +284,7 @@ func (s *Server) handleOffline() gin.HandlerFunc { } } -func (s *Server) requireAuth() gin.HandlerFunc { +func (s *Server) requireAuth(scope backend.Scope) gin.HandlerFunc { return func(c *gin.Context) { authUID := c.Request.Header.Get("x-pm-uid") if authUID == "" { @@ -292,7 +298,7 @@ func (s *Server) requireAuth() gin.HandlerFunc { return } - userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1]) + userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1], scope) if err != nil { c.AbortWithStatus(http.StatusUnauthorized) return diff --git a/server/server.go b/server/server.go index 0c8d5cc..c183310 100644 --- a/server/server.go +++ b/server/server.go @@ -204,7 +204,6 @@ func (s *Server) AddMessageCreatedEvent(userID, messageID string) error { return s.b.AddMessageCreatedUpdate(userID, messageID) } -// SetMaxUpdatesPerEvent func (s *Server) SetMaxUpdatesPerEvent(max int) { s.b.SetMaxUpdatesPerEvent(max) } @@ -213,6 +212,10 @@ func (s *Server) SetAuthLife(authLife time.Duration) { s.b.SetAuthLife(authLife) } +func (s *Server) SetAuthTOTP(userID, totp string) error { + return s.b.SetAuthTOTP(userID, totp) +} + func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) { s.minAppVersion = minAppVersion } diff --git a/server/server_test.go b/server/server_test.go index 29f053a..975eb6c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -48,6 +48,32 @@ func TestServer_LoginLogout(t *testing.T) { }) } +func TestServer_Login_2FA(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + userID, _, err := s.CreateUser("user", []byte("pass")) + require.NoError(t, err) + + // Set the expected 2FA code. + require.NoError(t, s.SetAuthTOTP(userID, "123123")) + + // Create a new client. + c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c.Close() + + // Most requests should fail; we haven't provided the 2FA code. + must_fail(c.GetUser(ctx)) + + // Provide the 2FA code. + require.NoError(t, c.Auth2FA(ctx, proton.Auth2FAReq{ + TwoFactorCode: "123123", + })) + + // Now requests should succeed. + must(c.GetUser(ctx)) + }) +} + func TestServerMulti(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { _, _, err := s.CreateUser("user", []byte("pass")) @@ -2251,6 +2277,14 @@ func must[T any](t T, err error) T { return t } +func must_fail[T any](t T, err error) T { + if err == nil { + panic(err) + } + + return t +} + func elementsMatch[T comparable](want, got []T) bool { if len(want) != len(got) { return false