synced with master
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -353,7 +354,7 @@ func bindRealtimeEvents(app core.App) {
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
// note: use the outside scoped app instance for the access checks so that the API ruless
|
||||
// note: use the outside scoped app instance for the access checks so that the API rules
|
||||
// are performed out of the delete transaction ensuring that they would still work even if
|
||||
// a cascade-deleted record's API rule relies on an already deleted parent record
|
||||
err := realtimeBroadcastRecord(e.App, "delete", record, true, app)
|
||||
@@ -375,14 +376,17 @@ func bindRealtimeEvents(app core.App) {
|
||||
// delete: broadcast
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeBroadcastDryCachedRecord(e.App, "delete", record)
|
||||
// note: only ensure that it is a collection record
|
||||
// and don't use realtimeResolveRecord because in case of a
|
||||
// custom model it'll fail to resolve since the record is already deleted
|
||||
collection := realtimeResolveRecordCollection(e.App, e.Model)
|
||||
if collection != nil {
|
||||
err := realtimeBroadcastDryCacheKey(e.App, getDryCacheKey("delete", e.Model))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record delete",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.Any("id", e.Model.PK()),
|
||||
slog.String("collectionName", collection.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
@@ -398,7 +402,7 @@ func bindRealtimeEvents(app core.App) {
|
||||
Func: func(e *core.ModelErrorEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeUnsetDryCachedRecord(e.App, "delete", record)
|
||||
err := realtimeUnsetDryCacheKey(e.App, getDryCacheKey("delete", record))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to cleanup after broadcast record delete failure",
|
||||
@@ -418,7 +422,14 @@ func bindRealtimeEvents(app core.App) {
|
||||
// resolveRecord converts *if possible* the provided model interface to a Record.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
|
||||
record, _ := model.(*core.Record)
|
||||
var record *core.Record
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
record = m
|
||||
case core.RecordProxy:
|
||||
record = m.ProxyRecord()
|
||||
}
|
||||
|
||||
if record != nil {
|
||||
if optCollectionType == "" || record.Collection().Type == optCollectionType {
|
||||
return record
|
||||
@@ -447,14 +458,20 @@ func realtimeResolveRecord(app core.App, model core.Model, optCollectionType str
|
||||
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
|
||||
if record, ok := model.(*core.Record); ok {
|
||||
collection = record.Collection()
|
||||
} else {
|
||||
// check if it is custom Record model struct (ignore "private" tables)
|
||||
collection, _ = app.FindCachedCollectionByNameOrId(model.TableName())
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
return m.Collection()
|
||||
case core.RecordProxy:
|
||||
return m.ProxyRecord().Collection()
|
||||
default:
|
||||
// check if it is custom Record model struct
|
||||
collection, err := app.FindCachedCollectionByNameOrId(model.TableName())
|
||||
if err == nil {
|
||||
return collection
|
||||
}
|
||||
}
|
||||
|
||||
return collection
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordData represents the broadcasted record subscrition message data.
|
||||
@@ -489,7 +506,7 @@ func realtimeBroadcastRecord(app core.App, action string, record *core.Record, d
|
||||
(collection.Id + "?"): collection.ListRule,
|
||||
}
|
||||
|
||||
dryCacheKey := action + "/" + record.Id
|
||||
dryCacheKey := getDryCacheKey(action, record)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
@@ -634,15 +651,13 @@ func realtimeBroadcastRecord(app core.App, action string, record *core.Record, d
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeBroadcastDryCachedRecord broadcasts all cached record related messages.
|
||||
func realtimeBroadcastDryCachedRecord(app core.App, action string, record *core.Record) error {
|
||||
// realtimeBroadcastDryCacheKey broadcasts the dry cached key related messages.
|
||||
func realtimeBroadcastDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
key := action + "/" + record.Id
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
@@ -671,15 +686,13 @@ func realtimeBroadcastDryCachedRecord(app core.App, action string, record *core.
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeUnsetDryCachedRecord removes the dry cached record related messages.
|
||||
func realtimeUnsetDryCachedRecord(app core.App, action string, record *core.Record) error {
|
||||
// realtimeUnsetDryCacheKey removes the dry cached key related messages.
|
||||
func realtimeUnsetDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
key := action + "/" + record.Id
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
@@ -697,6 +710,15 @@ func realtimeUnsetDryCachedRecord(app core.App, action string, record *core.Reco
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func getDryCacheKey(action string, model core.Model) string {
|
||||
pkStr, ok := model.PK().(string)
|
||||
if !ok {
|
||||
pkStr = fmt.Sprintf("%v", model.PK())
|
||||
}
|
||||
|
||||
return action + "/" + model.TableName() + "/" + pkStr
|
||||
}
|
||||
|
||||
func isSameAuth(authA, authB *core.Record) bool {
|
||||
if authA == nil {
|
||||
return authB == nil
|
||||
|
||||
@@ -2,10 +2,13 @@ package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +17,7 @@ import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRealtimeConnect(t *testing.T) {
|
||||
@@ -632,3 +636,250 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ core.Model = (*CustomModelResolve)(nil)
|
||||
|
||||
type CustomModelResolve struct {
|
||||
core.BaseModel
|
||||
tableName string
|
||||
|
||||
Created string `db:"created"`
|
||||
}
|
||||
|
||||
func (m *CustomModelResolve) TableName() string {
|
||||
return m.tableName
|
||||
}
|
||||
|
||||
func TestRealtimeRecordResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testCollectionName = "realtime_test_collection"
|
||||
|
||||
testRecordId := core.GenerateDefaultRandomId()
|
||||
|
||||
client0 := subscriptions.NewDefaultClient()
|
||||
client0.Subscribe(testCollectionName + "/*")
|
||||
client0.Discard()
|
||||
// ---
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Subscribe(testCollectionName + "/*")
|
||||
// ---
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Subscribe(testCollectionName + "/" + testRecordId)
|
||||
// ---
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Subscribe("demo1/*")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
op func(testApp core.App) error
|
||||
expected map[string][]string // clientId -> [events]
|
||||
}{
|
||||
{
|
||||
"core.Record",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
r.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"core.RecordProxy",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
|
||||
proxy := &struct {
|
||||
core.BaseRecordProxy
|
||||
}{}
|
||||
proxy.SetProxyRecord(r)
|
||||
proxy.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom model struct",
|
||||
func(testApp core.App) error {
|
||||
m := &CustomModelResolve{tableName: testCollectionName}
|
||||
m.Id = testRecordId
|
||||
|
||||
// create
|
||||
err := testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
m.Created = "123"
|
||||
err = testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
// create new test collection with public read access
|
||||
testCollection := core.NewBaseCollection(testCollectionName)
|
||||
testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true})
|
||||
testCollection.ListRule = types.Pointer("")
|
||||
testCollection.ViewRule = types.Pointer("")
|
||||
err := testApp.Save(testCollection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testApp.SubscriptionsBroker().Register(client0)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var notifications = map[string][]string{}
|
||||
|
||||
var mu sync.Mutex
|
||||
notify := func(clientId string, eventData []byte) {
|
||||
data := struct{ Action string }{}
|
||||
_ = json.Unmarshal(eventData, &data)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if notifications[clientId] == nil {
|
||||
notifications[clientId] = []string{}
|
||||
}
|
||||
notifications[clientId] = append(notifications[clientId], data.Action)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
timeout := time.After(250 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case e, ok := <-client0.Channel():
|
||||
if ok {
|
||||
notify(client0.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client1.Channel():
|
||||
if ok {
|
||||
notify(client1.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client2.Channel():
|
||||
if ok {
|
||||
notify(client2.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client3.Channel():
|
||||
if ok {
|
||||
notify(client3.Id(), e.Data)
|
||||
}
|
||||
case <-timeout:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = s.op(testApp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(s.expected) != len(notifications) {
|
||||
t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications)
|
||||
}
|
||||
|
||||
for id, events := range s.expected {
|
||||
if len(events) != len(notifications[id]) {
|
||||
t.Fatalf("[%s] Expected %d events, got %d:\n%v\n%v", id, len(events), len(notifications[id]), s.expected, notifications)
|
||||
}
|
||||
for _, event := range events {
|
||||
if !slices.Contains(notifications[id], event) {
|
||||
t.Fatalf("[%s] Missing expected event %q in %v", id, event, notifications[id])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user