merge v0.23.0-rc changes
This commit is contained in:
@@ -6,231 +6,742 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRequestInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/?test=123", strings.NewReader(`{"test":456}`))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
req.Header.Set("X-Token-Test", "123")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
dummyRecord := &models.Record{}
|
||||
dummyRecord.Id = "id1"
|
||||
c.Set(apis.ContextAuthRecordKey, dummyRecord)
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "id2"
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
|
||||
result := apis.RequestInfo(c)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected *models.RequestInfo instance, got nil")
|
||||
}
|
||||
|
||||
if result.Method != http.MethodPost {
|
||||
t.Fatalf("Expected Method %v, got %v", http.MethodPost, result.Method)
|
||||
}
|
||||
|
||||
rawHeaders, _ := json.Marshal(result.Headers)
|
||||
expectedHeaders := `{"content_type":"application/json","x_token_test":"123"}`
|
||||
if v := string(rawHeaders); v != expectedHeaders {
|
||||
t.Fatalf("Expected Query %v, got %v", expectedHeaders, v)
|
||||
}
|
||||
|
||||
rawQuery, _ := json.Marshal(result.Query)
|
||||
expectedQuery := `{"test":"123"}`
|
||||
if v := string(rawQuery); v != expectedQuery {
|
||||
t.Fatalf("Expected Query %v, got %v", expectedQuery, v)
|
||||
}
|
||||
|
||||
rawData, _ := json.Marshal(result.Data)
|
||||
expectedData := `{"test":456}`
|
||||
if v := string(rawData); v != expectedData {
|
||||
t.Fatalf("Expected Data %v, got %v", expectedData, v)
|
||||
}
|
||||
|
||||
if result.AuthRecord == nil || result.AuthRecord.Id != dummyRecord.Id {
|
||||
t.Fatalf("Expected AuthRecord %v, got %v", dummyRecord, result.AuthRecord)
|
||||
}
|
||||
|
||||
if result.Admin == nil || result.Admin.Id != dummyAdmin.Id {
|
||||
t.Fatalf("Expected Admin %v, got %v", dummyAdmin, result.Admin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponse(t *testing.T) {
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// mock test data
|
||||
// ---
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "id1"
|
||||
|
||||
nonAuthRecord, err := app.Dao().FindRecordById("demo1", "al1h9ijdeojtsjy")
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord, err := app.Dao().FindRecordById("users", "4q1xlclmfloku33")
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unverifiedAuthRecord, err := app.Dao().FindRecordById("clients", "o1y0dd0spd786md")
|
||||
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// temp update the view rule to ensure that request context is set to "expand"
|
||||
demo4, err := app.FindCollectionByNameOrId("demo4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
|
||||
if err := app.Save(demo4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
record *models.Record
|
||||
meta any
|
||||
expectError bool
|
||||
expectedContent []string
|
||||
notExpectedContent []string
|
||||
expectedEvents map[string]int
|
||||
name string
|
||||
auth *core.Record
|
||||
records []*core.Record
|
||||
queryExpand string
|
||||
defaultExpands []string
|
||||
expected []string
|
||||
notExpected []string
|
||||
}{
|
||||
// email visibility checks
|
||||
{
|
||||
name: "non auth record",
|
||||
record: nonAuthRecord,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid auth record but with unverified email in onlyVerified collection",
|
||||
record: unverifiedAuthRecord,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid auth record - without meta",
|
||||
record: authRecord,
|
||||
expectError: false,
|
||||
expectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"id":"`,
|
||||
`"expand":{"rel":{`,
|
||||
name: "[emailVisibility] guest",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
},
|
||||
notExpectedContent: []string{
|
||||
`"meta":`,
|
||||
},
|
||||
expectedEvents: map[string]int{
|
||||
"OnRecordAuthRequest": 1,
|
||||
notExpected: []string{
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid auth record - with meta",
|
||||
record: authRecord,
|
||||
meta: map[string]any{"meta_test": 123},
|
||||
expectError: false,
|
||||
expectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"id":"`,
|
||||
`"expand":{"rel":{`,
|
||||
`"meta":{"meta_test":123`,
|
||||
name: "[emailVisibility] owner",
|
||||
auth: user,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
`"test@example.com"`, // owner
|
||||
},
|
||||
expectedEvents: map[string]int{
|
||||
"OnRecordAuthRequest": 1,
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] manager",
|
||||
auth: user,
|
||||
records: nologinRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] superuser",
|
||||
auth: superuser,
|
||||
records: nologinRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
|
||||
auth: user,
|
||||
records: demo1Records,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"expand":{}`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
|
||||
auth: superuser,
|
||||
records: demo1Records,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test@example.com"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
|
||||
// expand checks
|
||||
{
|
||||
name: "[expand] guest (query)",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "rel",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] guest (default expands)",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] @request.context=expand check",
|
||||
auth: nil,
|
||||
records: demo5Records,
|
||||
queryExpand: "rel_one",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{}`,
|
||||
`"expand":{"`,
|
||||
`"rel_many":[{`,
|
||||
`"rel_one":{`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
app.ResetEventCalls()
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand=rel", nil)
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
|
||||
e.Record.WithCustomData(true)
|
||||
e.Record.Set("customField", "123")
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
responseErr := apis.RecordAuthResponse(app, c, s.record, s.meta)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
hasErr := responseErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, responseErr)
|
||||
}
|
||||
requestEvent := new(core.RequestEvent)
|
||||
requestEvent.App = app
|
||||
requestEvent.Request = req
|
||||
requestEvent.Response = rec
|
||||
requestEvent.Auth = s.auth
|
||||
|
||||
if len(app.EventCalls) != len(s.expectedEvents) {
|
||||
t.Fatalf("[%s] Expected events \n%v, \ngot \n%v", s.name, s.expectedEvents, app.EventCalls)
|
||||
}
|
||||
for k, v := range s.expectedEvents {
|
||||
if app.EventCalls[k] != v {
|
||||
t.Fatalf("[%s] Expected event %s to be called %d times, got %d", s.name, k, v, app.EventCalls[k])
|
||||
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
response := rec.Body.String()
|
||||
|
||||
for _, v := range s.expectedContent {
|
||||
if !strings.Contains(response, v) {
|
||||
t.Fatalf("[%s] Missing %v in response \n%v", s.name, v, response)
|
||||
raw, err := json.Marshal(s.records)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
for _, v := range s.notExpectedContent {
|
||||
if strings.Contains(response, v) {
|
||||
t.Fatalf("[%s] Unexpected %v in response \n%v", s.name, v, response)
|
||||
for _, str := range s.expected {
|
||||
if !strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range s.notExpected {
|
||||
if strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand=rel_many", nil)
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "test_id"
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
|
||||
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
records, err := app.Dao().FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
apis.EnrichRecords(c, app.Dao(), records, "rel_one")
|
||||
scenarios := []struct {
|
||||
name string
|
||||
rule *string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"admin only rule",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty rule",
|
||||
types.Pointer(""),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"false rule",
|
||||
types.Pointer("1=2"),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"true rule",
|
||||
types.Pointer("1=1"),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
expand := record.Expand()
|
||||
if len(expand) == 0 {
|
||||
t.Fatalf("Expected non-empty expand, got nil for record %v", record)
|
||||
}
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
user.Collection().AuthRule = s.rule
|
||||
|
||||
if len(record.GetStringSlice("rel_one")) != 0 {
|
||||
if _, ok := expand["rel_one"]; !ok {
|
||||
t.Fatalf("Expected rel_one to be expanded for record %v, got \n%v", record, expand)
|
||||
err := apis.RecordAuthResponse(event, user, "", nil)
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(record.GetStringSlice("rel_many")) != 0 {
|
||||
if _, ok := expand["rel_many"]; !ok {
|
||||
t.Fatalf("Expected rel_many to be expanded for record %v, got \n%v", record, expand)
|
||||
// in all cases login alert shouldn't be send because of the empty auth method
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
apiErr, ok := err.(*router.ApiError)
|
||||
|
||||
if !ok || apiErr == nil {
|
||||
t.Fatalf("Expected ApiError, got %v", apiErr)
|
||||
}
|
||||
|
||||
if apiErr.Status != http.StatusForbidden {
|
||||
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
|
||||
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
devices []string // mock existing device fingerprints
|
||||
expectDevices []string
|
||||
enabled bool
|
||||
expectEmail bool
|
||||
}{
|
||||
{
|
||||
name: "first login",
|
||||
devices: nil,
|
||||
expectDevices: []string{testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "existing device",
|
||||
devices: []string{"1", testFingerprint},
|
||||
expectDevices: []string{"1", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "new device (< 5)",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "new device (>= 5)",
|
||||
devices: []string{"1", "2", "3", "4", "5"},
|
||||
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "with disabled auth alert collection flag",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2"},
|
||||
enabled: false,
|
||||
expectEmail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
user.Collection().AuthRule = types.Pointer("")
|
||||
user.Collection().AuthAlert.Enabled = s.enabled
|
||||
|
||||
// ensure that there are no other auth origins
|
||||
err = app.DeleteAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert the mock devices
|
||||
for _, fingerprint := range s.devices {
|
||||
d := core.NewAuthOrigin(app)
|
||||
d.SetCollectionRef(user.Collection().Id)
|
||||
d.SetRecordRef(user.Id)
|
||||
d.SetFingerprint(fingerprint)
|
||||
if err = app.Save(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve auth response: %v", err)
|
||||
}
|
||||
|
||||
var expectTotalSend int
|
||||
if s.expectEmail {
|
||||
expectTotalSend = 1
|
||||
}
|
||||
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
|
||||
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
|
||||
}
|
||||
|
||||
devices, err := app.FindAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve auth origins: %v", err)
|
||||
}
|
||||
|
||||
if len(devices) != len(s.expectDevices) {
|
||||
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
|
||||
}
|
||||
|
||||
for _, fingerprint := range s.expectDevices {
|
||||
var exists bool
|
||||
fingerprints := make([]string, 0, len(devices))
|
||||
for _, d := range devices {
|
||||
if d.Fingerprint() == fingerprint {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
fingerprints = append(fingerprints, d.Fingerprint())
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseMFACheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = rec
|
||||
|
||||
resetMFAs := func(authRecord *core.Record) {
|
||||
// ensure that mfa is enabled
|
||||
user.Collection().MFA.Enabled = true
|
||||
user.Collection().MFA.Duration = 5
|
||||
user.Collection().MFA.Rule = ""
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
for _, mfa := range mfas {
|
||||
if err := app.Delete(mfa); err != nil {
|
||||
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// reset response
|
||||
rec = httptest.NewRecorder()
|
||||
event.Response = rec
|
||||
}
|
||||
|
||||
totalMFAs := func(authRecord *core.Record) int {
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
return len(mfas)
|
||||
}
|
||||
|
||||
t.Run("no collection MFA enabled", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no explicit auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=2"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=1"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa first-time", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected 0 mfa records, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
|
||||
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if totalMFAs(user) != 0 {
|
||||
t.Fatal("Expected the expired mfa record to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa for different auth record", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user2.Collection().Id)
|
||||
mfa.SetRecordRef(user2.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no user mfas, got %d", total)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user2); total != 1 {
|
||||
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user