feat: postgres transactions

This commit is contained in:
Dan Ribbens
2023-09-14 15:39:58 -04:00
parent a838274ae7
commit 8df2c888fe
15 changed files with 77 additions and 60 deletions

View File

@@ -9,11 +9,13 @@ export const create: Create = async function create({
data,
req,
}) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const collection = this.payload.collections[collectionSlug].config;
const result = await upsertRow({
adapter: this,
data,
db,
fields: collection.fields,
operation: 'create',
tableName: toSnakeCase(collectionSlug),

View File

@@ -11,11 +11,13 @@ export const createGlobal: CreateGlobal = async function createGlobal(
this: PostgresAdapter,
{ data, req = {} as PayloadRequest, slug },
) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const globalConfig = this.payload.globals.config.find((config) => config.slug === slug);
const result = await upsertRow({
adapter: this,
data,
db,
fields: globalConfig.fields,
operation: 'create',
tableName: toSnakeCase(slug),

View File

@@ -47,10 +47,11 @@ export const find: Find = async function find(
tableName,
where: whereArg,
});
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const orderedIDMap: Record<number | string, number> = {};
const selectQuery = this.db.selectDistinct(selectFields)
const selectQuery = db.selectDistinct(selectFields)
.from(table);
if (orderBy?.order && orderBy?.column) {
selectQuery.orderBy(orderBy.order(orderBy.column));
@@ -114,10 +115,10 @@ export const find: Find = async function find(
}
}
const findPromise = this.db.query[tableName].findMany(findManyArgs);
const findPromise = db.query[tableName].findMany(findManyArgs);
if (pagination !== false || selectDistinctResult?.length > limit) {
const selectCount = this.db.select({ count: sql<number>`count(*)` })
const selectCount = db.select({ count: sql<number>`count(*)` })
.from(table)
.where(where);
Object.entries(joins)

View File

@@ -10,8 +10,9 @@ import { transform } from './transform/read';
export const findGlobal: FindGlobal = async function findGlobal(
this: PostgresAdapter,
{ locale, slug, where },
{ locale, req, slug, where },
) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const globalConfig = this.payload.globals.config.find((config) => config.slug === slug);
const tableName = toSnakeCase(slug);
@@ -32,7 +33,7 @@ export const findGlobal: FindGlobal = async function findGlobal(
findManyArgs.where = query;
const doc = await this.db.query[tableName].findFirst(findManyArgs);
const doc = await db.query[tableName].findFirst(findManyArgs);
if (doc) {
const result = transform({

View File

@@ -13,6 +13,7 @@ export const findOne: FindOne = async function findOne({
req = {} as PayloadRequest,
where: incomingWhere,
}) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const collectionConfig: SanitizedCollectionConfig = this.payload.collections[collection].config;
const tableName = toSnakeCase(collection);
@@ -33,7 +34,7 @@ export const findOne: FindOne = async function findOne({
findManyArgs.where = where;
const doc = await this.db.query[tableName].findFirst(findManyArgs);
const doc = await db.query[tableName].findFirst(findManyArgs);
return transform({
config: this.payload.config,

View File

@@ -1,29 +1,31 @@
import type { Payload } from 'payload';
// import { findGlobalVersions } from './findGlobalVersions';
import { createDatabaseAdapter } from 'payload/database';
import type { Args, PostgresAdapter, PostgresAdapterResult } from './types';
import { connect } from './connect';
// import { findVersions } from './findVersions';
import { create } from './create';
import { createGlobal } from './createGlobal';
import { createMigration } from './createMigration';
import { createVersion } from './createVersion';
// import { updateVersion } from './updateVersion';
import { deleteMany } from './deleteMany';
import { deleteOne } from './deleteOne';
import { find } from './find';
// import { deleteVersions } from './deleteVersions';
import { findGlobal } from './findGlobal';
import { findOne } from './findOne';
import { init } from './init';
import { queryDrafts } from './queryDrafts';
import { beginTransaction } from './transactions/beginTransaction';
import { commitTransaction } from './transactions/commitTransaction';
import { rollbackTransaction } from './transactions/rollbackTransaction';
import { webpack } from './webpack';
// import { findGlobalVersions } from './findGlobalVersions';
// import { findVersions } from './findVersions';
import { create } from './create';
// import { updateVersion } from './updateVersion';
import { deleteMany } from './deleteMany';
import { deleteOne } from './deleteOne';
// import { deleteVersions } from './deleteVersions';
import { findGlobal } from './findGlobal';
import { findOne } from './findOne';
import { updateOne } from './update';
import { updateGlobal } from './updateGlobal';
import { webpack } from './webpack';
// import { destroy } from './destroy';

View File

@@ -17,8 +17,9 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({
pagination,
req = {} as PayloadRequest,
sort: sortArg,
where,
where: whereArg,
}) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const collectionConfig: SanitizedCollectionConfig = this.payload.collections[collection].config;
const tableName = toSnakeCase(collection);
const versionsTableName = `_${tableName}_versions`;
@@ -33,17 +34,17 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({
let hasNextPage;
let pagingCounter;
const query = await buildQuery({
const { where } = await buildQuery({
adapter: this,
fields: buildVersionCollectionFields(collectionConfig),
locale,
sort,
tableName: versionsTableName,
where
where: whereArg
});
if (pagination !== false) {
const countResult = await this.db.select({ count: sql<number>`count(*)` }).from(table).where(query);
const countResult = await db.select({ count: sql<number>`count(*)` }).from(table).where(where);
totalDocs = Number(countResult[0].count);
totalPages = Math.ceil(totalDocs / limit);
hasPrevPage = page > 1;
@@ -60,9 +61,9 @@ export const queryDrafts: QueryDrafts = async function queryDrafts({
findManyArgs.limit = limit === 0 ? undefined : limit;
findManyArgs.offset = (page - 1) * limit;
findManyArgs.where = query;
findManyArgs.where = where;
const rawDocs = await this.db.query[tableName].findMany(findManyArgs);
const rawDocs = await db.query[tableName].findMany(findManyArgs);
const docs = rawDocs.map((data) => {
return transform({

View File

@@ -29,7 +29,7 @@ export const beginTransaction: BeginTransaction = async function beginTransactio
this.sessions[id] = db;
await db.execute(sql`BEGIN;`);
await this.sessions[id].execute(sql`BEGIN;`);
} catch (err) {
this.payload.logger.error(
`Error: cannot begin transaction: ${err.message}`,

View File

@@ -12,30 +12,30 @@ export const updateOne: UpdateOne = async function updateOne({
id,
locale,
req,
where,
where: whereArg,
}) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const collection = this.payload.collections[collectionSlug].config;
const tableName = toSnakeCase(collection);
const whereToUse = whereArg || { id: { equals: id } };
let query: Result;
if (where) {
query = await buildQuery({
adapter: this,
collectionSlug,
locale,
where,
});
}
const { where } = await buildQuery({
adapter: this,
fields: collection.fields,
locale,
tableName,
where: whereToUse
});
const result = await upsertRow({
adapter: this,
data,
db,
fields: collection.fields,
id,
locale: req.locale,
operation: 'update',
tableName: toSnakeCase(collectionSlug),
where: query,
where,
});
return result;

View File

@@ -11,6 +11,7 @@ export const updateGlobal: UpdateGlobal = async function updateGlobal(
this: PostgresAdapter,
{ data, req = {} as PayloadRequest, slug },
) {
const db = req.transactionID ? this.sessions[req.transactionID] : this.db;
const globalConfig = this.payload.globals.config.find((config) => config.slug === slug);
const tableName = toSnakeCase(slug);
@@ -19,6 +20,7 @@ export const updateGlobal: UpdateGlobal = async function updateGlobal(
const result = await upsertRow({
adapter: this,
data,
db,
fields: globalConfig.fields,
id: existingGlobal.id,
operation: 'update',

View File

@@ -1,17 +1,20 @@
/* eslint-disable no-param-reassign */
import { eq } from 'drizzle-orm';
import { transform } from '../transform/read';
import { BlockRowToInsert } from '../transform/write/types';
import { insertArrays } from './insertArrays';
import { transformForWrite } from '../transform/write';
import { Args } from './types';
import { deleteExistingRowsByPath } from './deleteExistingRowsByPath';
import { deleteExistingArrayRows } from './deleteExistingArrayRows';
import type { BlockRowToInsert } from '../transform/write/types';
import type { Args } from './types';
import { buildFindManyArgs } from '../find/buildFindManyArgs';
import { transform } from '../transform/read';
import { transformForWrite } from '../transform/write';
import { deleteExistingArrayRows } from './deleteExistingArrayRows';
import { deleteExistingRowsByPath } from './deleteExistingRowsByPath';
import { insertArrays } from './insertArrays';
export const upsertRow = async ({
adapter,
data,
db,
fields,
id,
operation,
@@ -37,18 +40,18 @@ export const upsertRow = async ({
if (id) {
rowToInsert.row.id = id;
[insertedRow] = await adapter.db.insert(adapter.tables[tableName])
[insertedRow] = await db.insert(adapter.tables[tableName])
.values(rowToInsert.row)
.onConflictDoUpdate({ target, set: rowToInsert.row })
.onConflictDoUpdate({ set: rowToInsert.row, target })
.returning();
} else {
[insertedRow] = await adapter.db.insert(adapter.tables[tableName])
[insertedRow] = await db.insert(adapter.tables[tableName])
.values(rowToInsert.row)
.onConflictDoUpdate({ target, set: rowToInsert.row, where })
.onConflictDoUpdate({ set: rowToInsert.row, target, where })
.returning();
}
} else {
[insertedRow] = await adapter.db.insert(adapter.tables[tableName])
[insertedRow] = await db.insert(adapter.tables[tableName])
.values(rowToInsert.row).returning();
}
@@ -96,10 +99,10 @@ export const upsertRow = async ({
promises.push(async () => {
if (operation === 'update') {
await adapter.db.delete(localeTable).where(eq(localeTable._parentID, insertedRow.id));
await db.delete(localeTable).where(eq(localeTable._parentID, insertedRow.id));
}
await adapter.db.insert(localeTable).values(localesToInsert);
await db.insert(localeTable).values(localesToInsert);
});
}
@@ -114,15 +117,15 @@ export const upsertRow = async ({
await deleteExistingRowsByPath({
adapter,
localeColumnName: 'locale',
newRows: relationsToInsert,
parentColumnName: 'parent',
parentID: insertedRow.id,
pathColumnName: 'path',
newRows: relationsToInsert,
tableName: relationshipsTableName,
});
}
await adapter.db.insert(adapter.tables[relationshipsTableName])
await db.insert(adapter.tables[relationshipsTableName])
.values(relationsToInsert).returning();
});
}
@@ -139,14 +142,14 @@ export const upsertRow = async ({
if (operation === 'update') {
await deleteExistingRowsByPath({
adapter,
newRows: blockRows.map(({ row }) => row),
parentID: insertedRow.id,
pathColumnName: '_path',
newRows: blockRows.map(({ row }) => row),
tableName: `${tableName}_${blockName}`,
});
}
insertedBlockRows[blockName] = await adapter.db.insert(adapter.tables[`${tableName}_${blockName}`])
insertedBlockRows[blockName] = await db.insert(adapter.tables[`${tableName}_${blockName}`])
.values(blockRows.map(({ row }) => row)).returning();
insertedBlockRows[blockName].forEach((row, i) => {
@@ -171,7 +174,7 @@ export const upsertRow = async ({
}, []);
if (blockLocaleRowsToInsert.length > 0) {
await adapter.db.insert(adapter.tables[`${tableName}_${blockName}_locales`])
await db.insert(adapter.tables[`${tableName}_${blockName}_locales`])
.values(blockLocaleRowsToInsert).returning();
}
@@ -220,7 +223,7 @@ export const upsertRow = async ({
findManyArgs.where = eq(adapter.tables[tableName].id, insertedRow.id);
const doc = await adapter.db.query[tableName].findFirst(findManyArgs);
const doc = await db.query[tableName].findFirst(findManyArgs);
// //////////////////////////////////
// TRANSFORM DATA

View File

@@ -2,10 +2,12 @@ import type { SQL } from 'drizzle-orm';
import type { Field } from 'payload/types';
import type { GenericColumn, PostgresAdapter } from '../types';
import type { DrizzleDB } from '../types';
type BaseArgs = {
adapter: PostgresAdapter
data: Record<string, unknown>
db: DrizzleDB
fields: Field[]
path?: string
tableName: string