synced with master

This commit is contained in:
Gani Georgiev
2025-02-10 09:38:15 +02:00
57 changed files with 1572 additions and 1147 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
@@ -353,7 +354,7 @@ func bindRealtimeEvents(app core.App) {
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
// note: use the outside scoped app instance for the access checks so that the API ruless
// note: use the outside scoped app instance for the access checks so that the API rules
// are performed out of the delete transaction ensuring that they would still work even if
// a cascade-deleted record's API rule relies on an already deleted parent record
err := realtimeBroadcastRecord(e.App, "delete", record, true, app)
@@ -375,14 +376,17 @@ func bindRealtimeEvents(app core.App) {
// delete: broadcast
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastDryCachedRecord(e.App, "delete", record)
// note: only ensure that it is a collection record
// and don't use realtimeResolveRecord because in case of a
// custom model it'll fail to resolve since the record is already deleted
collection := realtimeResolveRecordCollection(e.App, e.Model)
if collection != nil {
err := realtimeBroadcastDryCacheKey(e.App, getDryCacheKey("delete", e.Model))
if err != nil {
app.Logger().Debug(
"Failed to broadcast record delete",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.Any("id", e.Model.PK()),
slog.String("collectionName", collection.Name),
slog.String("error", err.Error()),
)
}
@@ -398,7 +402,7 @@ func bindRealtimeEvents(app core.App) {
Func: func(e *core.ModelErrorEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeUnsetDryCachedRecord(e.App, "delete", record)
err := realtimeUnsetDryCacheKey(e.App, getDryCacheKey("delete", record))
if err != nil {
app.Logger().Debug(
"Failed to cleanup after broadcast record delete failure",
@@ -418,7 +422,14 @@ func bindRealtimeEvents(app core.App) {
// resolveRecord converts *if possible* the provided model interface to a Record.
// This is usually helpful if the provided model is a custom Record model struct.
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
record, _ := model.(*core.Record)
var record *core.Record
switch m := model.(type) {
case *core.Record:
record = m
case core.RecordProxy:
record = m.ProxyRecord()
}
if record != nil {
if optCollectionType == "" || record.Collection().Type == optCollectionType {
return record
@@ -447,14 +458,20 @@ func realtimeResolveRecord(app core.App, model core.Model, optCollectionType str
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
// This is usually helpful if the provided model is a custom Record model struct.
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
if record, ok := model.(*core.Record); ok {
collection = record.Collection()
} else {
// check if it is custom Record model struct (ignore "private" tables)
collection, _ = app.FindCachedCollectionByNameOrId(model.TableName())
switch m := model.(type) {
case *core.Record:
return m.Collection()
case core.RecordProxy:
return m.ProxyRecord().Collection()
default:
// check if it is custom Record model struct
collection, err := app.FindCachedCollectionByNameOrId(model.TableName())
if err == nil {
return collection
}
}
return collection
return nil
}
// recordData represents the broadcasted record subscrition message data.
@@ -489,7 +506,7 @@ func realtimeBroadcastRecord(app core.App, action string, record *core.Record, d
(collection.Id + "?"): collection.ListRule,
}
dryCacheKey := action + "/" + record.Id
dryCacheKey := getDryCacheKey(action, record)
group := new(errgroup.Group)
@@ -634,15 +651,13 @@ func realtimeBroadcastRecord(app core.App, action string, record *core.Record, d
return group.Wait()
}
// realtimeBroadcastDryCachedRecord broadcasts all cached record related messages.
func realtimeBroadcastDryCachedRecord(app core.App, action string, record *core.Record) error {
// realtimeBroadcastDryCacheKey broadcasts the dry cached key related messages.
func realtimeBroadcastDryCacheKey(app core.App, key string) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
key := action + "/" + record.Id
group := new(errgroup.Group)
for _, chunk := range chunks {
@@ -671,15 +686,13 @@ func realtimeBroadcastDryCachedRecord(app core.App, action string, record *core.
return group.Wait()
}
// realtimeUnsetDryCachedRecord removes the dry cached record related messages.
func realtimeUnsetDryCachedRecord(app core.App, action string, record *core.Record) error {
// realtimeUnsetDryCacheKey removes the dry cached key related messages.
func realtimeUnsetDryCacheKey(app core.App, key string) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
key := action + "/" + record.Id
group := new(errgroup.Group)
for _, chunk := range chunks {
@@ -697,6 +710,15 @@ func realtimeUnsetDryCachedRecord(app core.App, action string, record *core.Reco
return group.Wait()
}
func getDryCacheKey(action string, model core.Model) string {
pkStr, ok := model.PK().(string)
if !ok {
pkStr = fmt.Sprintf("%v", model.PK())
}
return action + "/" + model.TableName() + "/" + pkStr
}
func isSameAuth(authA, authB *core.Record) bool {
if authA == nil {
return authB == nil

View File

@@ -2,10 +2,13 @@ package apis_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -14,6 +17,7 @@ import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRealtimeConnect(t *testing.T) {
@@ -632,3 +636,250 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
}
}
// -------------------------------------------------------------------
var _ core.Model = (*CustomModelResolve)(nil)
type CustomModelResolve struct {
core.BaseModel
tableName string
Created string `db:"created"`
}
func (m *CustomModelResolve) TableName() string {
return m.tableName
}
func TestRealtimeRecordResolve(t *testing.T) {
t.Parallel()
const testCollectionName = "realtime_test_collection"
testRecordId := core.GenerateDefaultRandomId()
client0 := subscriptions.NewDefaultClient()
client0.Subscribe(testCollectionName + "/*")
client0.Discard()
// ---
client1 := subscriptions.NewDefaultClient()
client1.Subscribe(testCollectionName + "/*")
// ---
client2 := subscriptions.NewDefaultClient()
client2.Subscribe(testCollectionName + "/" + testRecordId)
// ---
client3 := subscriptions.NewDefaultClient()
client3.Subscribe("demo1/*")
scenarios := []struct {
name string
op func(testApp core.App) error
expected map[string][]string // clientId -> [events]
}{
{
"core.Record",
func(testApp core.App) error {
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
if err != nil {
return err
}
r := core.NewRecord(c)
r.Id = testRecordId
// create
err = testApp.Save(r)
if err != nil {
return err
}
// update
err = testApp.Save(r)
if err != nil {
return err
}
// delete
err = testApp.Delete(r)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
{
"core.RecordProxy",
func(testApp core.App) error {
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
if err != nil {
return err
}
r := core.NewRecord(c)
proxy := &struct {
core.BaseRecordProxy
}{}
proxy.SetProxyRecord(r)
proxy.Id = testRecordId
// create
err = testApp.Save(proxy)
if err != nil {
return err
}
// update
err = testApp.Save(proxy)
if err != nil {
return err
}
// delete
err = testApp.Delete(proxy)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
{
"custom model struct",
func(testApp core.App) error {
m := &CustomModelResolve{tableName: testCollectionName}
m.Id = testRecordId
// create
err := testApp.Save(m)
if err != nil {
return err
}
// update
m.Created = "123"
err = testApp.Save(m)
if err != nil {
return err
}
// delete
err = testApp.Delete(m)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
// create new test collection with public read access
testCollection := core.NewBaseCollection(testCollectionName)
testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true})
testCollection.ListRule = types.Pointer("")
testCollection.ViewRule = types.Pointer("")
err := testApp.Save(testCollection)
if err != nil {
t.Fatal(err)
}
testApp.SubscriptionsBroker().Register(client0)
testApp.SubscriptionsBroker().Register(client1)
testApp.SubscriptionsBroker().Register(client2)
testApp.SubscriptionsBroker().Register(client3)
var wg sync.WaitGroup
var notifications = map[string][]string{}
var mu sync.Mutex
notify := func(clientId string, eventData []byte) {
data := struct{ Action string }{}
_ = json.Unmarshal(eventData, &data)
mu.Lock()
defer mu.Unlock()
if notifications[clientId] == nil {
notifications[clientId] = []string{}
}
notifications[clientId] = append(notifications[clientId], data.Action)
}
wg.Add(1)
go func() {
defer wg.Done()
timeout := time.After(250 * time.Millisecond)
for {
select {
case e, ok := <-client0.Channel():
if ok {
notify(client0.Id(), e.Data)
}
case e, ok := <-client1.Channel():
if ok {
notify(client1.Id(), e.Data)
}
case e, ok := <-client2.Channel():
if ok {
notify(client2.Id(), e.Data)
}
case e, ok := <-client3.Channel():
if ok {
notify(client3.Id(), e.Data)
}
case <-timeout:
return
}
}
}()
err = s.op(testApp)
if err != nil {
t.Fatal(err)
}
wg.Wait()
if len(s.expected) != len(notifications) {
t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications)
}
for id, events := range s.expected {
if len(events) != len(notifications[id]) {
t.Fatalf("[%s] Expected %d events, got %d:\n%v\n%v", id, len(events), len(notifications[id]), s.expected, notifications)
}
for _, event := range events {
if !slices.Contains(notifications[id], event) {
t.Fatalf("[%s] Missing expected event %q in %v", id, event, notifications[id])
}
}
}
})
}
}