diff --git a/apis/record_auth_with_otp.go b/apis/record_auth_with_otp.go index 742c6bfc..92e5bdce 100644 --- a/apis/record_auth_with_otp.go +++ b/apis/record_auth_with_otp.go @@ -65,10 +65,19 @@ func recordAuthWithOTP(e *core.RequestEvent) error { // --- return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error { + otpId := e.OTP.Id + otpSentTo := e.OTP.SentTo() + + // eagerly delete the OTP to avoid unnecessery double delete model hook calls + // triggered by the password change below + err := e.App.Delete(e.OTP) + if err != nil { + e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id) + } + // update the user email verified state in case the OTP originate from an email address matching the current record one // // note: don't wait for success auth response (it could fail because of MFA) and because we already validated the OTP above - otpSentTo := e.OTP.SentTo() if !e.Record.Verified() && otpSentTo != "" && e.Record.Email() == otpSentTo { e.Record.SetVerified(true) @@ -82,18 +91,12 @@ func recordAuthWithOTP(e *core.RequestEvent) error { if err := e.App.Save(e.Record); err != nil { e.App.Logger().Error("Failed to update record verified state after successful OTP validation", "error", err, - "otpId", e.OTP.Id, + "otpId", otpId, "recordId", e.Record.Id, ) } } - // try to delete the used otp - err = e.App.Delete(e.OTP) - if err != nil { - e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id) - } - return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil) }) } diff --git a/apis/record_auth_with_otp_test.go b/apis/record_auth_with_otp_test.go index 52183e3c..a0d950be 100644 --- a/apis/record_auth_with_otp_test.go +++ b/apis/record_auth_with_otp_test.go @@ -406,10 +406,10 @@ func TestRecordAuthWithOTP(t *testing.T) { "OnModelCreate": 1, "OnModelCreateExecute": 1, "OnModelAfterCreateSuccess": 1, - // 2 record OTPs + 2 ExternalAuths delete - "OnModelDelete": 4, - "OnModelDeleteExecute": 4, - "OnModelAfterDeleteSuccess": 4, + // record OTP + 2 ExternalAuths delete + "OnModelDelete": 3, + "OnModelDeleteExecute": 3, + "OnModelAfterDeleteSuccess": 3, // user verified update "OnModelUpdate": 1, "OnModelUpdateExecute": 1, @@ -419,9 +419,9 @@ func TestRecordAuthWithOTP(t *testing.T) { "OnRecordCreate": 1, "OnRecordCreateExecute": 1, "OnRecordAfterCreateSuccess": 1, - "OnRecordDelete": 4, - "OnRecordDeleteExecute": 4, - "OnRecordAfterDeleteSuccess": 4, + "OnRecordDelete": 3, + "OnRecordDeleteExecute": 3, + "OnRecordAfterDeleteSuccess": 3, "OnRecordUpdate": 1, "OnRecordUpdateExecute": 1, "OnRecordAfterUpdateSuccess": 1, diff --git a/core/external_auth_model_test.go b/core/external_auth_model_test.go index 512f6771..46f7be7b 100644 --- a/core/external_auth_model_test.go +++ b/core/external_auth_model_test.go @@ -308,3 +308,104 @@ func TestExternalAuthValidateHook(t *testing.T) { }) } } + +func TestExternalAuthClearOnVerfiedUpgrade(t *testing.T) { + t.Parallel() + + app, _ := tests.NewTestApp() + defer app.Cleanup() + + t.Run("unverified->no changes", func(t *testing.T) { + user, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + if user.Verified() { + t.Fatal("Expected user to be unverified") + } + + beforeAuths, err := app.FindAllExternalAuthsByRecord(user) + if err != nil || len(beforeAuths) == 0 { + t.Fatalf("Expected at least one external auth (%v)", err) + } + + oldTokenKey := user.TokenKey() + + if err = app.Save(user); err != nil { + t.Fatal(err) + } + + if oldTokenKey != user.TokenKey() { + t.Fatal("Expected tokenKey to remain unchanged") + } + + afterAuths, err := app.FindAllExternalAuthsByRecord(user) + if err != nil || len(afterAuths) != len(beforeAuths) { + t.Fatalf("Expected %d external auths, found %d (%v)", len(afterAuths), len(beforeAuths), err) + } + }) + + t.Run("unverified->verified", func(t *testing.T) { + user, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + if user.Verified() { + t.Fatal("Expected user to be unverified") + } + + externalAuths, err := app.FindAllExternalAuthsByRecord(user) + if err != nil || len(externalAuths) == 0 { + t.Fatalf("Expected at least one external auth (%v)", err) + } + + oldTokenKey := user.TokenKey() + + user.SetVerified(true) + if err = app.Save(user); err != nil { + t.Fatal(err) + } + + if oldTokenKey == user.TokenKey() { + t.Fatal("Expected tokenKey to be renewed") + } + + externalAuths, err = app.FindAllExternalAuthsByRecord(user) + if err != nil || len(externalAuths) != 0 { + t.Fatalf("Expected all user external auths to be deleted, found %d (%v)", len(externalAuths), err) + } + }) + + t.Run("verified->no changes", func(t *testing.T) { + user, err := app.FindAuthRecordByEmail("users", "test3@example.com") + if err != nil { + t.Fatal(err) + } + + if !user.Verified() { + t.Fatal("Expected user to be verified") + } + + beforeAuths, err := app.FindAllExternalAuthsByRecord(user) + if err != nil || len(beforeAuths) == 0 { + t.Fatalf("Expected at least one external auth (%v)", err) + } + + oldTokenKey := user.TokenKey() + + if err = app.Save(user); err != nil { + t.Fatal(err) + } + + if oldTokenKey != user.TokenKey() { + t.Fatal("Expected tokenKey to remain unchanged") + } + + afterAuths, err := app.FindAllExternalAuthsByRecord(user) + if err != nil || len(afterAuths) != len(beforeAuths) { + t.Fatalf("Expected %d external auths, found %d (%v)", len(afterAuths), len(beforeAuths), err) + } + }) +} diff --git a/core/mfa_model_test.go b/core/mfa_model_test.go index 97669749..694b7bef 100644 --- a/core/mfa_model_test.go +++ b/core/mfa_model_test.go @@ -300,3 +300,64 @@ func TestMFAValidateHook(t *testing.T) { }) } } + +func TestMFAClearOnPasswordChange(t *testing.T) { + t.Parallel() + + app, _ := tests.NewTestApp() + defer app.Cleanup() + + user1, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + user2, err := app.FindAuthRecordByEmail("users", "test2@example.com") + if err != nil { + t.Fatal(err) + } + + mfasToCreate := map[*core.Record]int{ + user1: 3, + user2: 2, + } + for user, total := range mfasToCreate { + for range total { + mfa := core.NewMFA(app) + mfa.SetCollectionRef(user.Collection().Id) + mfa.SetRecordRef(user.Id) + mfa.SetMethod(core.MFAMethodPassword) + if err := app.Save(mfa); err != nil { + t.Fatal(err) + } + } + } + + // update both users + err = app.Save(user1) + if err != nil { + t.Fatal(err) + } + + user2.SetRandomPassword() + err = app.Save(user2) + if err != nil { + t.Fatal(err) + } + + expectedMFAs := map[*core.Record]int{ + user1: 3, + user2: 0, + } + + for user, expected := range expectedMFAs { + mfas, err := app.FindAllMFAsByRecord(user) + if err != nil { + t.Fatal(err) + } + + if len(mfas) != expected { + t.Fatalf("Expected %d MFAs, got %d", expected, len(mfas)) + } + } +} diff --git a/core/otp_model.go b/core/otp_model.go index 385b5013..78e58e3c 100644 --- a/core/otp_model.go +++ b/core/otp_model.go @@ -135,7 +135,7 @@ func (app *BaseApp) registerOTPHooks() { return err } - if e.Record.Original().TokenKey() != e.Record.TokenKey() { + if !e.Record.Original().IsNew() && e.Record.Original().TokenKey() != e.Record.TokenKey() { err := e.App.DeleteAllOTPsByRecord(e.Record) if err != nil { return fmt.Errorf( diff --git a/core/otp_model_test.go b/core/otp_model_test.go index f0d5155a..a30aca13 100644 --- a/core/otp_model_test.go +++ b/core/otp_model_test.go @@ -300,3 +300,64 @@ func TestOTPValidateHook(t *testing.T) { }) } } + +func TestOTPClearOnTokenKeyChange(t *testing.T) { + t.Parallel() + + app, _ := tests.NewTestApp() + defer app.Cleanup() + + user1, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + user2, err := app.FindAuthRecordByEmail("users", "test2@example.com") + if err != nil { + t.Fatal(err) + } + + otpsToCreate := map[*core.Record]int{ + user1: 3, + user2: 2, + } + for user, total := range otpsToCreate { + for range total { + otp := core.NewOTP(app) + otp.SetCollectionRef(user.Collection().Id) + otp.SetRecordRef(user.Id) + otp.SetPassword("123456") + if err := app.Save(otp); err != nil { + t.Fatal(err) + } + } + } + + // update both users + err = app.Save(user1) + if err != nil { + t.Fatal(err) + } + + user2.RefreshTokenKey() + err = app.Save(user2) + if err != nil { + t.Fatal(err) + } + + expectedOTPs := map[*core.Record]int{ + user1: 3, + user2: 0, + } + + for user, expected := range expectedOTPs { + otps, err := app.FindAllOTPsByRecord(user) + if err != nil { + t.Fatal(err) + } + + if len(otps) != expected { + t.Fatalf("Expected %d OTPs, got %d", expected, len(otps)) + } + } +} diff --git a/plugins/migratecmd/migratecmd_test.go b/plugins/migratecmd/migratecmd_test.go index 310147cc..e19846e1 100644 --- a/plugins/migratecmd/migratecmd_test.go +++ b/plugins/migratecmd/migratecmd_test.go @@ -137,7 +137,7 @@ migrate((app) => { "listRule": "@request.auth.id != '' && 1 > 0 || 'backtick` + "`" + `test' = 0", "manageRule": "1 != 2", "mfa": { - "duration": 1800, + "duration": 600, "enabled": false, "rule": "" }, @@ -319,7 +319,7 @@ func init() { "listRule": "@request.auth.id != '' && 1 > 0 || 'backtick` + "` + \"`\" + `" + `test' = 0", "manageRule": "1 != 2", "mfa": { - "duration": 1800, + "duration": 600, "enabled": false, "rule": "" }, @@ -590,7 +590,7 @@ migrate((app) => { "listRule": "@request.auth.id != '' && 1 > 0 || 'backtick` + "`" + `test' = 0", "manageRule": "1 != 2", "mfa": { - "duration": 1800, + "duration": 600, "enabled": false, "rule": "" }, @@ -775,7 +775,7 @@ func init() { "listRule": "@request.auth.id != '' && 1 > 0 || 'backtick` + "` + \"`\" + `" + `test' = 0", "manageRule": "1 != 2", "mfa": { - "duration": 1800, + "duration": 600, "enabled": false, "rule": "" },