From 8df2c888fea8c10a062e5de4333a7d9ca04758bf Mon Sep 17 00:00:00 2001 From: Dan Ribbens Date: Thu, 14 Sep 2023 15:39:58 -0400 Subject: [PATCH] feat: postgres transactions --- .../src/queries/sanitizeQueryValue.ts | 6 +-- packages/db-postgres/src/create/index.ts | 2 + packages/db-postgres/src/createGlobal.ts | 2 + packages/db-postgres/src/find.ts | 7 +-- packages/db-postgres/src/findGlobal.ts | 5 ++- packages/db-postgres/src/findOne.ts | 3 +- packages/db-postgres/src/index.ts | 22 +++++----- packages/db-postgres/src/queryDrafts.ts | 13 +++--- .../src/transactions/beginTransaction.ts | 2 +- packages/db-postgres/src/update/index.ts | 26 +++++------ packages/db-postgres/src/updateGlobal.ts | 2 + packages/db-postgres/src/upsertRow/index.ts | 43 ++++++++++--------- packages/db-postgres/src/upsertRow/types.ts | 2 + ...ateAdapter.ts => createDatabaseAdapter.ts} | 0 packages/payload/src/exports/database.ts | 2 +- 15 files changed, 77 insertions(+), 60 deletions(-) rename packages/payload/src/database/{createAdapter.ts => createDatabaseAdapter.ts} (100%) diff --git a/packages/db-mongodb/src/queries/sanitizeQueryValue.ts b/packages/db-mongodb/src/queries/sanitizeQueryValue.ts index 1210c4b922..fb799790bc 100644 --- a/packages/db-mongodb/src/queries/sanitizeQueryValue.ts +++ b/packages/db-mongodb/src/queries/sanitizeQueryValue.ts @@ -1,8 +1,8 @@ -import type { Field, TabAsField } from 'payload/types' +import type { Field, TabAsField } from 'payload/types'; -import mongoose from 'mongoose' +import mongoose from 'mongoose'; +import { createArrayFromCommaDelineated } from 'payload/utilities'; -import { createArrayFromCommaDelineated } from '../utilities/createArrayFromCommaDelineated' type SanitizeQueryValueArgs = { field: Field | TabAsField diff --git a/packages/db-postgres/src/create/index.ts b/packages/db-postgres/src/create/index.ts index 2b3622043b..8f75e0fab9 100644 --- a/packages/db-postgres/src/create/index.ts +++ b/packages/db-postgres/src/create/index.ts @@ -9,11 +9,13 @@ export const create: Create = async function create({ data, req, }) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const collection = this.payload.collections[collectionSlug].config; const result = await upsertRow({ adapter: this, data, + db, fields: collection.fields, operation: 'create', tableName: toSnakeCase(collectionSlug), diff --git a/packages/db-postgres/src/createGlobal.ts b/packages/db-postgres/src/createGlobal.ts index e2499ebecd..3d5b1a1824 100644 --- a/packages/db-postgres/src/createGlobal.ts +++ b/packages/db-postgres/src/createGlobal.ts @@ -11,11 +11,13 @@ export const createGlobal: CreateGlobal = async function createGlobal( this: PostgresAdapter, { data, req = {} as PayloadRequest, slug }, ) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const globalConfig = this.payload.globals.config.find((config) => config.slug === slug); const result = await upsertRow({ adapter: this, data, + db, fields: globalConfig.fields, operation: 'create', tableName: toSnakeCase(slug), diff --git a/packages/db-postgres/src/find.ts b/packages/db-postgres/src/find.ts index f2d69fadf4..35c021f499 100644 --- a/packages/db-postgres/src/find.ts +++ b/packages/db-postgres/src/find.ts @@ -47,10 +47,11 @@ export const find: Find = async function find( tableName, where: whereArg, }); + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const orderedIDMap: Record = {}; - const selectQuery = this.db.selectDistinct(selectFields) + const selectQuery = db.selectDistinct(selectFields) .from(table); if (orderBy?.order && orderBy?.column) { selectQuery.orderBy(orderBy.order(orderBy.column)); @@ -114,10 +115,10 @@ export const find: Find = async function find( } } - const findPromise = this.db.query[tableName].findMany(findManyArgs); + const findPromise = db.query[tableName].findMany(findManyArgs); if (pagination !== false || selectDistinctResult?.length > limit) { - const selectCount = this.db.select({ count: sql`count(*)` }) + const selectCount = db.select({ count: sql`count(*)` }) .from(table) .where(where); Object.entries(joins) diff --git a/packages/db-postgres/src/findGlobal.ts b/packages/db-postgres/src/findGlobal.ts index 9a120e0c95..4807cef1ce 100644 --- a/packages/db-postgres/src/findGlobal.ts +++ b/packages/db-postgres/src/findGlobal.ts @@ -10,8 +10,9 @@ import { transform } from './transform/read'; export const findGlobal: FindGlobal = async function findGlobal( this: PostgresAdapter, - { locale, slug, where }, + { locale, req, slug, where }, ) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const globalConfig = this.payload.globals.config.find((config) => config.slug === slug); const tableName = toSnakeCase(slug); @@ -32,7 +33,7 @@ export const findGlobal: FindGlobal = async function findGlobal( findManyArgs.where = query; - const doc = await this.db.query[tableName].findFirst(findManyArgs); + const doc = await db.query[tableName].findFirst(findManyArgs); if (doc) { const result = transform({ diff --git a/packages/db-postgres/src/findOne.ts b/packages/db-postgres/src/findOne.ts index 29fc96bd6e..b92ff12c2e 100644 --- a/packages/db-postgres/src/findOne.ts +++ b/packages/db-postgres/src/findOne.ts @@ -13,6 +13,7 @@ export const findOne: FindOne = async function findOne({ req = {} as PayloadRequest, where: incomingWhere, }) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const collectionConfig: SanitizedCollectionConfig = this.payload.collections[collection].config; const tableName = toSnakeCase(collection); @@ -33,7 +34,7 @@ export const findOne: FindOne = async function findOne({ findManyArgs.where = where; - const doc = await this.db.query[tableName].findFirst(findManyArgs); + const doc = await db.query[tableName].findFirst(findManyArgs); return transform({ config: this.payload.config, diff --git a/packages/db-postgres/src/index.ts b/packages/db-postgres/src/index.ts index 4e4a176d9e..1e9524d2b5 100644 --- a/packages/db-postgres/src/index.ts +++ b/packages/db-postgres/src/index.ts @@ -1,29 +1,31 @@ import type { Payload } from 'payload'; +// import { findGlobalVersions } from './findGlobalVersions'; +import { createDatabaseAdapter } from 'payload/database'; + import type { Args, PostgresAdapter, PostgresAdapterResult } from './types'; import { connect } from './connect'; +// import { findVersions } from './findVersions'; +import { create } from './create'; import { createGlobal } from './createGlobal'; import { createMigration } from './createMigration'; import { createVersion } from './createVersion'; +// import { updateVersion } from './updateVersion'; +import { deleteMany } from './deleteMany'; +import { deleteOne } from './deleteOne'; import { find } from './find'; +// import { deleteVersions } from './deleteVersions'; +import { findGlobal } from './findGlobal'; +import { findOne } from './findOne'; import { init } from './init'; import { queryDrafts } from './queryDrafts'; import { beginTransaction } from './transactions/beginTransaction'; import { commitTransaction } from './transactions/commitTransaction'; import { rollbackTransaction } from './transactions/rollbackTransaction'; -import { webpack } from './webpack'; -// import { findGlobalVersions } from './findGlobalVersions'; -// import { findVersions } from './findVersions'; -import { create } from './create'; -// import { updateVersion } from './updateVersion'; -import { deleteMany } from './deleteMany'; -import { deleteOne } from './deleteOne'; -// import { deleteVersions } from './deleteVersions'; -import { findGlobal } from './findGlobal'; -import { findOne } from './findOne'; import { updateOne } from './update'; import { updateGlobal } from './updateGlobal'; +import { webpack } from './webpack'; // import { destroy } from './destroy'; diff --git a/packages/db-postgres/src/queryDrafts.ts b/packages/db-postgres/src/queryDrafts.ts index 613955a3fa..b78025ce8c 100644 --- a/packages/db-postgres/src/queryDrafts.ts +++ b/packages/db-postgres/src/queryDrafts.ts @@ -17,8 +17,9 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({ pagination, req = {} as PayloadRequest, sort: sortArg, - where, + where: whereArg, }) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const collectionConfig: SanitizedCollectionConfig = this.payload.collections[collection].config; const tableName = toSnakeCase(collection); const versionsTableName = `_${tableName}_versions`; @@ -33,17 +34,17 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({ let hasNextPage; let pagingCounter; - const query = await buildQuery({ + const { where } = await buildQuery({ adapter: this, fields: buildVersionCollectionFields(collectionConfig), locale, sort, tableName: versionsTableName, - where + where: whereArg }); if (pagination !== false) { - const countResult = await this.db.select({ count: sql`count(*)` }).from(table).where(query); + const countResult = await db.select({ count: sql`count(*)` }).from(table).where(where); totalDocs = Number(countResult[0].count); totalPages = Math.ceil(totalDocs / limit); hasPrevPage = page > 1; @@ -60,9 +61,9 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({ findManyArgs.limit = limit === 0 ? undefined : limit; findManyArgs.offset = (page - 1) * limit; - findManyArgs.where = query; + findManyArgs.where = where; - const rawDocs = await this.db.query[tableName].findMany(findManyArgs); + const rawDocs = await db.query[tableName].findMany(findManyArgs); const docs = rawDocs.map((data) => { return transform({ diff --git a/packages/db-postgres/src/transactions/beginTransaction.ts b/packages/db-postgres/src/transactions/beginTransaction.ts index 1808d0a980..f9cd2c3dc3 100644 --- a/packages/db-postgres/src/transactions/beginTransaction.ts +++ b/packages/db-postgres/src/transactions/beginTransaction.ts @@ -29,7 +29,7 @@ export const beginTransaction: BeginTransaction = async function beginTransactio this.sessions[id] = db; - await db.execute(sql`BEGIN;`); + await this.sessions[id].execute(sql`BEGIN;`); } catch (err) { this.payload.logger.error( `Error: cannot begin transaction: ${err.message}`, diff --git a/packages/db-postgres/src/update/index.ts b/packages/db-postgres/src/update/index.ts index 87e0eaa09c..9e9485c2ac 100644 --- a/packages/db-postgres/src/update/index.ts +++ b/packages/db-postgres/src/update/index.ts @@ -12,30 +12,30 @@ export const updateOne: UpdateOne = async function updateOne({ id, locale, req, - where, + where: whereArg, }) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const collection = this.payload.collections[collectionSlug].config; + const tableName = toSnakeCase(collection); + const whereToUse = whereArg || { id: { equals: id } }; - let query: Result; - - if (where) { - query = await buildQuery({ - adapter: this, - collectionSlug, - locale, - where, - }); - } + const { where } = await buildQuery({ + adapter: this, + fields: collection.fields, + locale, + tableName, + where: whereToUse + }); const result = await upsertRow({ adapter: this, data, + db, fields: collection.fields, id, - locale: req.locale, operation: 'update', tableName: toSnakeCase(collectionSlug), - where: query, + where, }); return result; diff --git a/packages/db-postgres/src/updateGlobal.ts b/packages/db-postgres/src/updateGlobal.ts index 00df932b32..f932036adc 100644 --- a/packages/db-postgres/src/updateGlobal.ts +++ b/packages/db-postgres/src/updateGlobal.ts @@ -11,6 +11,7 @@ export const updateGlobal: UpdateGlobal = async function updateGlobal( this: PostgresAdapter, { data, req = {} as PayloadRequest, slug }, ) { + const db = req.transactionID ? this.sessions[req.transactionID] : this.db; const globalConfig = this.payload.globals.config.find((config) => config.slug === slug); const tableName = toSnakeCase(slug); @@ -19,6 +20,7 @@ export const updateGlobal: UpdateGlobal = async function updateGlobal( const result = await upsertRow({ adapter: this, data, + db, fields: globalConfig.fields, id: existingGlobal.id, operation: 'update', diff --git a/packages/db-postgres/src/upsertRow/index.ts b/packages/db-postgres/src/upsertRow/index.ts index 74e9c1e77d..90cabbf7ca 100644 --- a/packages/db-postgres/src/upsertRow/index.ts +++ b/packages/db-postgres/src/upsertRow/index.ts @@ -1,17 +1,20 @@ /* eslint-disable no-param-reassign */ import { eq } from 'drizzle-orm'; -import { transform } from '../transform/read'; -import { BlockRowToInsert } from '../transform/write/types'; -import { insertArrays } from './insertArrays'; -import { transformForWrite } from '../transform/write'; -import { Args } from './types'; -import { deleteExistingRowsByPath } from './deleteExistingRowsByPath'; -import { deleteExistingArrayRows } from './deleteExistingArrayRows'; + +import type { BlockRowToInsert } from '../transform/write/types'; +import type { Args } from './types'; + import { buildFindManyArgs } from '../find/buildFindManyArgs'; +import { transform } from '../transform/read'; +import { transformForWrite } from '../transform/write'; +import { deleteExistingArrayRows } from './deleteExistingArrayRows'; +import { deleteExistingRowsByPath } from './deleteExistingRowsByPath'; +import { insertArrays } from './insertArrays'; export const upsertRow = async ({ adapter, data, + db, fields, id, operation, @@ -37,18 +40,18 @@ export const upsertRow = async ({ if (id) { rowToInsert.row.id = id; - [insertedRow] = await adapter.db.insert(adapter.tables[tableName]) + [insertedRow] = await db.insert(adapter.tables[tableName]) .values(rowToInsert.row) - .onConflictDoUpdate({ target, set: rowToInsert.row }) + .onConflictDoUpdate({ set: rowToInsert.row, target }) .returning(); } else { - [insertedRow] = await adapter.db.insert(adapter.tables[tableName]) + [insertedRow] = await db.insert(adapter.tables[tableName]) .values(rowToInsert.row) - .onConflictDoUpdate({ target, set: rowToInsert.row, where }) + .onConflictDoUpdate({ set: rowToInsert.row, target, where }) .returning(); } } else { - [insertedRow] = await adapter.db.insert(adapter.tables[tableName]) + [insertedRow] = await db.insert(adapter.tables[tableName]) .values(rowToInsert.row).returning(); } @@ -96,10 +99,10 @@ export const upsertRow = async ({ promises.push(async () => { if (operation === 'update') { - await adapter.db.delete(localeTable).where(eq(localeTable._parentID, insertedRow.id)); + await db.delete(localeTable).where(eq(localeTable._parentID, insertedRow.id)); } - await adapter.db.insert(localeTable).values(localesToInsert); + await db.insert(localeTable).values(localesToInsert); }); } @@ -114,15 +117,15 @@ export const upsertRow = async ({ await deleteExistingRowsByPath({ adapter, localeColumnName: 'locale', + newRows: relationsToInsert, parentColumnName: 'parent', parentID: insertedRow.id, pathColumnName: 'path', - newRows: relationsToInsert, tableName: relationshipsTableName, }); } - await adapter.db.insert(adapter.tables[relationshipsTableName]) + await db.insert(adapter.tables[relationshipsTableName]) .values(relationsToInsert).returning(); }); } @@ -139,14 +142,14 @@ export const upsertRow = async ({ if (operation === 'update') { await deleteExistingRowsByPath({ adapter, + newRows: blockRows.map(({ row }) => row), parentID: insertedRow.id, pathColumnName: '_path', - newRows: blockRows.map(({ row }) => row), tableName: `${tableName}_${blockName}`, }); } - insertedBlockRows[blockName] = await adapter.db.insert(adapter.tables[`${tableName}_${blockName}`]) + insertedBlockRows[blockName] = await db.insert(adapter.tables[`${tableName}_${blockName}`]) .values(blockRows.map(({ row }) => row)).returning(); insertedBlockRows[blockName].forEach((row, i) => { @@ -171,7 +174,7 @@ export const upsertRow = async ({ }, []); if (blockLocaleRowsToInsert.length > 0) { - await adapter.db.insert(adapter.tables[`${tableName}_${blockName}_locales`]) + await db.insert(adapter.tables[`${tableName}_${blockName}_locales`]) .values(blockLocaleRowsToInsert).returning(); } @@ -220,7 +223,7 @@ export const upsertRow = async ({ findManyArgs.where = eq(adapter.tables[tableName].id, insertedRow.id); - const doc = await adapter.db.query[tableName].findFirst(findManyArgs); + const doc = await db.query[tableName].findFirst(findManyArgs); // ////////////////////////////////// // TRANSFORM DATA diff --git a/packages/db-postgres/src/upsertRow/types.ts b/packages/db-postgres/src/upsertRow/types.ts index c8cf4548ca..86ea8fabed 100644 --- a/packages/db-postgres/src/upsertRow/types.ts +++ b/packages/db-postgres/src/upsertRow/types.ts @@ -2,10 +2,12 @@ import type { SQL } from 'drizzle-orm'; import type { Field } from 'payload/types'; import type { GenericColumn, PostgresAdapter } from '../types'; +import type { DrizzleDB } from '../types'; type BaseArgs = { adapter: PostgresAdapter data: Record + db: DrizzleDB fields: Field[] path?: string tableName: string diff --git a/packages/payload/src/database/createAdapter.ts b/packages/payload/src/database/createDatabaseAdapter.ts similarity index 100% rename from packages/payload/src/database/createAdapter.ts rename to packages/payload/src/database/createDatabaseAdapter.ts diff --git a/packages/payload/src/exports/database.ts b/packages/payload/src/exports/database.ts index 189d69b320..4ee0e8a0af 100644 --- a/packages/payload/src/exports/database.ts +++ b/packages/payload/src/exports/database.ts @@ -48,7 +48,7 @@ export * from '../database/queryValidation/types' export { combineQueries } from '../database/combineQueries' -export { createDatabaseAdapter } from '../database/createAdapter' +export { createDatabaseAdapter } from '../database/createDatabaseAdapter' export { default as flattenWhereToOperators } from '../database/flattenWhereToOperators'