added action arg to the before Dao hook to allow skipping the default persist behavior

This commit is contained in:
Gani Georgiev
2023-07-29 19:52:36 +03:00
parent 6da94aef8d
commit cdeb9a94ed
5 changed files with 232 additions and 182 deletions

View File

@@ -5,6 +5,8 @@ package daos
import (
"errors"
"fmt"
"strings"
"time"
"github.com/pocketbase/dbx"
@@ -45,12 +47,12 @@ type Dao struct {
ModelQueryTimeout time.Duration
// write hooks
BeforeCreateFunc func(eventDao *Dao, m models.Model) error
AfterCreateFunc func(eventDao *Dao, m models.Model)
BeforeUpdateFunc func(eventDao *Dao, m models.Model) error
AfterUpdateFunc func(eventDao *Dao, m models.Model)
BeforeDeleteFunc func(eventDao *Dao, m models.Model) error
AfterDeleteFunc func(eventDao *Dao, m models.Model)
BeforeCreateFunc func(eventDao *Dao, m models.Model, action func() error) error
AfterCreateFunc func(eventDao *Dao, m models.Model) error
BeforeUpdateFunc func(eventDao *Dao, m models.Model, action func() error) error
AfterUpdateFunc func(eventDao *Dao, m models.Model) error
BeforeDeleteFunc func(eventDao *Dao, m models.Model, action func() error) error
AfterDeleteFunc func(eventDao *Dao, m models.Model) error
}
// DB returns the default dao db builder (*dbx.DB or *dbx.TX).
@@ -151,56 +153,75 @@ func (dao *Dao) RunInTransaction(fn func(txDao *Dao) error) error {
txDao := New(tx)
if dao.BeforeCreateFunc != nil {
txDao.BeforeCreateFunc = func(eventDao *Dao, m models.Model) error {
return dao.BeforeCreateFunc(eventDao, m)
txDao.BeforeCreateFunc = func(eventDao *Dao, m models.Model, action func() error) error {
return dao.BeforeCreateFunc(eventDao, m, action)
}
}
if dao.BeforeUpdateFunc != nil {
txDao.BeforeUpdateFunc = func(eventDao *Dao, m models.Model) error {
return dao.BeforeUpdateFunc(eventDao, m)
txDao.BeforeUpdateFunc = func(eventDao *Dao, m models.Model, action func() error) error {
return dao.BeforeUpdateFunc(eventDao, m, action)
}
}
if dao.BeforeDeleteFunc != nil {
txDao.BeforeDeleteFunc = func(eventDao *Dao, m models.Model) error {
return dao.BeforeDeleteFunc(eventDao, m)
txDao.BeforeDeleteFunc = func(eventDao *Dao, m models.Model, action func() error) error {
return dao.BeforeDeleteFunc(eventDao, m, action)
}
}
if dao.AfterCreateFunc != nil {
txDao.AfterCreateFunc = func(eventDao *Dao, m models.Model) {
txDao.AfterCreateFunc = func(eventDao *Dao, m models.Model) error {
afterCalls = append(afterCalls, afterCallGroup{"create", eventDao, m})
return nil
}
}
if dao.AfterUpdateFunc != nil {
txDao.AfterUpdateFunc = func(eventDao *Dao, m models.Model) {
txDao.AfterUpdateFunc = func(eventDao *Dao, m models.Model) error {
afterCalls = append(afterCalls, afterCallGroup{"update", eventDao, m})
return nil
}
}
if dao.AfterDeleteFunc != nil {
txDao.AfterDeleteFunc = func(eventDao *Dao, m models.Model) {
txDao.AfterDeleteFunc = func(eventDao *Dao, m models.Model) error {
afterCalls = append(afterCalls, afterCallGroup{"delete", eventDao, m})
return nil
}
}
return fn(txDao)
})
if txError == nil {
// execute after event calls on successful transaction
// (note: using the non-transaction dao to allow following queries in the after hooks)
for _, call := range afterCalls {
switch call.Action {
case "create":
dao.AfterCreateFunc(dao, call.Model)
case "update":
dao.AfterUpdateFunc(dao, call.Model)
case "delete":
dao.AfterDeleteFunc(dao, call.Model)
}
}
if txError != nil {
return txError
}
return txError
// execute after event calls on successful transaction
// (note: using the non-transaction dao to allow following queries in the after hooks)
var errs []error
for _, call := range afterCalls {
var err error
switch call.Action {
case "create":
err = dao.AfterCreateFunc(dao, call.Model)
case "update":
err = dao.AfterUpdateFunc(dao, call.Model)
case "delete":
err = dao.AfterDeleteFunc(dao, call.Model)
}
if err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
// @todo after go 1.20+ upgrade consider replacing with errors.Join()
var errsMsg strings.Builder
for _, err := range errs {
errsMsg.WriteString(err.Error())
errsMsg.WriteString("; ")
}
return fmt.Errorf("after transaction errors: %s", errsMsg.String())
}
return nil
}
return errors.New("failed to start transaction (unknown dao.NonconcurrentDB() instance)")
@@ -213,21 +234,23 @@ func (dao *Dao) Delete(m models.Model) error {
}
return dao.lockRetry(func(retryDao *Dao) error {
if retryDao.BeforeDeleteFunc != nil {
if err := retryDao.BeforeDeleteFunc(retryDao, m); err != nil {
action := func() error {
if err := retryDao.NonconcurrentDB().Model(m).Delete(); err != nil {
return err
}
if retryDao.AfterDeleteFunc != nil {
retryDao.AfterDeleteFunc(retryDao, m)
}
return nil
}
if err := retryDao.NonconcurrentDB().Model(m).Delete(); err != nil {
return err
if retryDao.BeforeDeleteFunc != nil {
return retryDao.BeforeDeleteFunc(retryDao, m, action)
}
if retryDao.AfterDeleteFunc != nil {
retryDao.AfterDeleteFunc(retryDao, m)
}
return nil
return action()
})
}
@@ -258,35 +281,35 @@ func (dao *Dao) update(m models.Model) error {
m.RefreshUpdated()
action := func() error {
if v, ok := any(m).(models.ColumnValueMapper); ok {
dataMap := v.ColumnValueMap()
_, err := dao.NonconcurrentDB().Update(
m.TableName(),
dataMap,
dbx.HashExp{"id": m.GetId()},
).Execute()
if err != nil {
return err
}
} else if err := dao.NonconcurrentDB().Model(m).Update(); err != nil {
return err
}
if dao.AfterUpdateFunc != nil {
return dao.AfterUpdateFunc(dao, m)
}
return nil
}
if dao.BeforeUpdateFunc != nil {
if err := dao.BeforeUpdateFunc(dao, m); err != nil {
return err
}
return dao.BeforeUpdateFunc(dao, m, action)
}
if v, ok := any(m).(models.ColumnValueMapper); ok {
dataMap := v.ColumnValueMap()
_, err := dao.NonconcurrentDB().Update(
m.TableName(),
dataMap,
dbx.HashExp{"id": m.GetId()},
).Execute()
if err != nil {
return err
}
} else {
if err := dao.NonconcurrentDB().Model(m).Update(); err != nil {
return err
}
}
if dao.AfterUpdateFunc != nil {
dao.AfterUpdateFunc(dao, m)
}
return nil
return action()
}
func (dao *Dao) create(m models.Model) error {
@@ -306,36 +329,36 @@ func (dao *Dao) create(m models.Model) error {
m.RefreshUpdated()
}
action := func() error {
if v, ok := any(m).(models.ColumnValueMapper); ok {
dataMap := v.ColumnValueMap()
if _, ok := dataMap["id"]; !ok {
dataMap["id"] = m.GetId()
}
_, err := dao.NonconcurrentDB().Insert(m.TableName(), dataMap).Execute()
if err != nil {
return err
}
} else if err := dao.NonconcurrentDB().Model(m).Insert(); err != nil {
return err
}
// clears the "new" model flag
m.MarkAsNotNew()
if dao.AfterCreateFunc != nil {
return dao.AfterCreateFunc(dao, m)
}
return nil
}
if dao.BeforeCreateFunc != nil {
if err := dao.BeforeCreateFunc(dao, m); err != nil {
return err
}
return dao.BeforeCreateFunc(dao, m, action)
}
if v, ok := any(m).(models.ColumnValueMapper); ok {
dataMap := v.ColumnValueMap()
if _, ok := dataMap["id"]; !ok {
dataMap["id"] = m.GetId()
}
_, err := dao.NonconcurrentDB().Insert(m.TableName(), dataMap).Execute()
if err != nil {
return err
}
} else {
if err := dao.NonconcurrentDB().Model(m).Insert(); err != nil {
return err
}
}
// clears the "new" model flag
m.MarkAsNotNew()
if dao.AfterCreateFunc != nil {
dao.AfterCreateFunc(dao, m)
}
return nil
return action()
}
func (dao *Dao) lockRetry(op func(retryDao *Dao) error) error {