refactor(drizzle): replace query chaining with dynamic query building (#11923)

This replaces usage of our `chainMethods` helper to dynamically chain
queries with [drizzle dynamic query
building](https://orm.drizzle.team/docs/dynamic-query-building).

This is more type-safe, more readable and requires less code
This commit is contained in:
Alessio Gravili
2025-03-31 14:37:45 -06:00
committed by GitHub
parent 9a1c3cf4cc
commit 9c88af4b20
10 changed files with 93 additions and 143 deletions

View File

@@ -1,6 +1,5 @@
import type { ChainedMethods } from '@payloadcms/drizzle/types' import type { SQLiteSelect } from 'drizzle-orm/sqlite-core'
import { chainMethods } from '@payloadcms/drizzle'
import { count, sql } from 'drizzle-orm' import { count, sql } from 'drizzle-orm'
import type { CountDistinct, SQLiteAdapter } from './types.js' import type { CountDistinct, SQLiteAdapter } from './types.js'
@@ -20,30 +19,25 @@ export const countDistinct: CountDistinct = async function countDistinct(
return Number(countResult[0]?.count) return Number(countResult[0]?.count)
} }
const chainedMethods: ChainedMethods = [] let query: SQLiteSelect = db
joins.forEach(({ condition, table }) => {
chainedMethods.push({
args: [table, condition],
method: 'leftJoin',
})
})
// When we have any joins, we need to count each individual ID only once.
// COUNT(*) doesn't work for this well in this case, as it also counts joined tables.
// SELECT (COUNT DISTINCT id) has a very slow performance on large tables.
// Instead, COUNT (GROUP BY id) can be used which is still slower than COUNT(*) but acceptable.
const countResult = await chainMethods({
methods: chainedMethods,
query: db
.select({ .select({
count: sql`COUNT(1) OVER()`, count: sql`COUNT(1) OVER()`,
}) })
.from(this.tables[tableName]) .from(this.tables[tableName])
.where(where) .where(where)
.groupBy(this.tables[tableName].id) .groupBy(this.tables[tableName].id)
.limit(1), .limit(1)
.$dynamic()
joins.forEach(({ condition, table }) => {
query = query.leftJoin(table, condition)
}) })
// When we have any joins, we need to count each individual ID only once.
// COUNT(*) doesn't work for this well in this case, as it also counts joined tables.
// SELECT (COUNT DISTINCT id) has a very slow performance on large tables.
// Instead, COUNT (GROUP BY id) can be used which is still slower than COUNT(*) but acceptable.
const countResult = await query
return Number(countResult[0]?.count) return Number(countResult[0]?.count)
} }

View File

@@ -13,7 +13,7 @@ import { getTransaction } from './utilities/getTransaction.js'
export const deleteOne: DeleteOne = async function deleteOne( export const deleteOne: DeleteOne = async function deleteOne(
this: DrizzleAdapter, this: DrizzleAdapter,
{ collection: collectionSlug, req, select, where: whereArg, returning }, { collection: collectionSlug, req, returning, select, where: whereArg },
) { ) {
const db = await getTransaction(this, req) const db = await getTransaction(this, req)
const collection = this.payload.collections[collectionSlug].config const collection = this.payload.collections[collectionSlug].config
@@ -32,9 +32,9 @@ export const deleteOne: DeleteOne = async function deleteOne(
const selectDistinctResult = await selectDistinct({ const selectDistinctResult = await selectDistinct({
adapter: this, adapter: this,
chainedMethods: [{ args: [1], method: 'limit' }],
db, db,
joins, joins,
query: ({ query }) => query.limit(1),
selectFields, selectFields,
tableName, tableName,
where, where,

View File

@@ -1,3 +1,6 @@
/**
* @deprecated - will be removed in 4.0. Use query + $dynamic() instead: https://orm.drizzle.team/docs/dynamic-query-building
*/
export type ChainedMethods = { export type ChainedMethods = {
args: unknown[] args: unknown[]
method: string method: string
@@ -7,6 +10,8 @@ export type ChainedMethods = {
* Call and returning methods that would normally be chained together but cannot be because of control logic * Call and returning methods that would normally be chained together but cannot be because of control logic
* @param methods * @param methods
* @param query * @param query
*
* @deprecated - will be removed in 4.0. Use query + $dynamic() instead: https://orm.drizzle.team/docs/dynamic-query-building
*/ */
const chainMethods = <T>({ methods, query }: { methods: ChainedMethods; query: T }): T => { const chainMethods = <T>({ methods, query }: { methods: ChainedMethods; query: T }): T => {
return methods.reduce((query, { args, method }) => { return methods.reduce((query, { args, method }) => {

View File

@@ -3,7 +3,6 @@ import type { FindArgs, FlattenedField, TypeWithID } from 'payload'
import { inArray } from 'drizzle-orm' import { inArray } from 'drizzle-orm'
import type { DrizzleAdapter } from '../types.js' import type { DrizzleAdapter } from '../types.js'
import type { ChainedMethods } from './chainMethods.js'
import buildQuery from '../queries/buildQuery.js' import buildQuery from '../queries/buildQuery.js'
import { selectDistinct } from '../queries/selectDistinct.js' import { selectDistinct } from '../queries/selectDistinct.js'
@@ -62,15 +61,6 @@ export const findMany = async function find({
const orderedIDMap: Record<number | string, number> = {} const orderedIDMap: Record<number | string, number> = {}
let orderedIDs: (number | string)[] let orderedIDs: (number | string)[]
const selectDistinctMethods: ChainedMethods = []
if (orderBy) {
selectDistinctMethods.push({
args: [() => orderBy.map(({ column, order }) => order(column))],
method: 'orderBy',
})
}
const findManyArgs = buildFindManyArgs({ const findManyArgs = buildFindManyArgs({
adapter, adapter,
collectionSlug, collectionSlug,
@@ -84,15 +74,16 @@ export const findMany = async function find({
tableName, tableName,
versions, versions,
}) })
selectDistinctMethods.push({ args: [offset], method: 'offset' })
selectDistinctMethods.push({ args: [limit], method: 'limit' })
const selectDistinctResult = await selectDistinct({ const selectDistinctResult = await selectDistinct({
adapter, adapter,
chainedMethods: selectDistinctMethods,
db, db,
joins, joins,
query: ({ query }) => {
if (orderBy) {
query = query.orderBy(() => orderBy.map(({ column, order }) => order(column)))
}
return query.offset(offset).limit(limit)
},
selectFields, selectFields,
tableName, tableName,
where, where,

View File

@@ -1,5 +1,5 @@
import type { LibSQLDatabase } from 'drizzle-orm/libsql' import type { LibSQLDatabase } from 'drizzle-orm/libsql'
import type { SQLiteSelectBase } from 'drizzle-orm/sqlite-core' import type { SQLiteSelect, SQLiteSelectBase } from 'drizzle-orm/sqlite-core'
import { and, asc, count, desc, eq, or, sql } from 'drizzle-orm' import { and, asc, count, desc, eq, or, sql } from 'drizzle-orm'
import { import {
@@ -16,7 +16,7 @@ import {
import { fieldIsVirtual, fieldShouldBeLocalized } from 'payload/shared' import { fieldIsVirtual, fieldShouldBeLocalized } from 'payload/shared'
import toSnakeCase from 'to-snake-case' import toSnakeCase from 'to-snake-case'
import type { BuildQueryJoinAliases, ChainedMethods, DrizzleAdapter } from '../types.js' import type { BuildQueryJoinAliases, DrizzleAdapter } from '../types.js'
import type { Result } from './buildFindManyArgs.js' import type { Result } from './buildFindManyArgs.js'
import buildQuery from '../queries/buildQuery.js' import buildQuery from '../queries/buildQuery.js'
@@ -25,7 +25,6 @@ import { operatorMap } from '../queries/operatorMap.js'
import { getNameFromDrizzleTable } from '../utilities/getNameFromDrizzleTable.js' import { getNameFromDrizzleTable } from '../utilities/getNameFromDrizzleTable.js'
import { jsonAggBuildObject } from '../utilities/json.js' import { jsonAggBuildObject } from '../utilities/json.js'
import { rawConstraint } from '../utilities/rawConstraint.js' import { rawConstraint } from '../utilities/rawConstraint.js'
import { chainMethods } from './chainMethods.js'
const flattenAllWherePaths = (where: Where, paths: string[]) => { const flattenAllWherePaths = (where: Where, paths: string[]) => {
for (const k in where) { for (const k in where) {
@@ -612,34 +611,6 @@ export const traverseFields = ({
where: joinQueryWhere, where: joinQueryWhere,
}) })
const chainedMethods: ChainedMethods = []
joins.forEach(({ type, condition, table }) => {
chainedMethods.push({
args: [table, condition],
method: type ?? 'leftJoin',
})
})
if (page && limit !== 0) {
const offset = (page - 1) * limit - 1
if (offset > 0) {
chainedMethods.push({
args: [offset],
method: 'offset',
})
}
}
if (limit !== 0) {
chainedMethods.push({
args: [limit],
method: 'limit',
})
}
const db = adapter.drizzle as LibSQLDatabase
for (let key in selectFields) { for (let key in selectFields) {
const val = selectFields[key] const val = selectFields[key]
@@ -654,14 +625,29 @@ export const traverseFields = ({
selectFields.parent = newAliasTable.parent selectFields.parent = newAliasTable.parent
} }
const subQuery = chainMethods({ let query: SQLiteSelect = db
methods: chainedMethods,
query: db
.select(selectFields as any) .select(selectFields as any)
.from(newAliasTable) .from(newAliasTable)
.where(subQueryWhere) .where(subQueryWhere)
.orderBy(() => orderBy.map(({ column, order }) => order(column))), .orderBy(() => orderBy.map(({ column, order }) => order(column)))
}).as(subQueryAlias) .$dynamic()
joins.forEach(({ type, condition, table }) => {
query = query[type ?? 'leftJoin'](table, condition)
})
if (page && limit !== 0) {
const offset = (page - 1) * limit - 1
if (offset > 0) {
query = query.offset(offset)
}
}
if (limit !== 0) {
query = query.limit(limit)
}
const subQuery = query.as(subQueryAlias)
if (shouldCount) { if (shouldCount) {
currentArgs.extras[`${columnName}_count`] = sql`${db currentArgs.extras[`${columnName}_count`] = sql`${db

View File

@@ -1,10 +1,9 @@
import type { PgTableWithColumns } from 'drizzle-orm/pg-core'
import { count, sql } from 'drizzle-orm' import { count, sql } from 'drizzle-orm'
import type { ChainedMethods } from '../types.js'
import type { BasePostgresAdapter, CountDistinct } from './types.js' import type { BasePostgresAdapter, CountDistinct } from './types.js'
import { chainMethods } from '../find/chainMethods.js'
export const countDistinct: CountDistinct = async function countDistinct( export const countDistinct: CountDistinct = async function countDistinct(
this: BasePostgresAdapter, this: BasePostgresAdapter,
{ db, joins, tableName, where }, { db, joins, tableName, where },
@@ -20,30 +19,25 @@ export const countDistinct: CountDistinct = async function countDistinct(
return Number(countResult[0].count) return Number(countResult[0].count)
} }
const chainedMethods: ChainedMethods = [] let query = db
joins.forEach(({ condition, table }) => {
chainedMethods.push({
args: [table, condition],
method: 'leftJoin',
})
})
// When we have any joins, we need to count each individual ID only once.
// COUNT(*) doesn't work for this well in this case, as it also counts joined tables.
// SELECT (COUNT DISTINCT id) has a very slow performance on large tables.
// Instead, COUNT (GROUP BY id) can be used which is still slower than COUNT(*) but acceptable.
const countResult = await chainMethods({
methods: chainedMethods,
query: db
.select({ .select({
count: sql`COUNT(1) OVER()`, count: sql`COUNT(1) OVER()`,
}) })
.from(this.tables[tableName]) .from(this.tables[tableName])
.where(where) .where(where)
.groupBy(this.tables[tableName].id) .groupBy(this.tables[tableName].id)
.limit(1), .limit(1)
.$dynamic()
joins.forEach(({ condition, table }) => {
query = query.leftJoin(table as PgTableWithColumns<any>, condition)
}) })
// When we have any joins, we need to count each individual ID only once.
// COUNT(*) doesn't work for this well in this case, as it also counts joined tables.
// SELECT (COUNT DISTINCT id) has a very slow performance on large tables.
// Instead, COUNT (GROUP BY id) can be used which is still slower than COUNT(*) but acceptable.
const countResult = await query
return Number(countResult[0].count) return Number(countResult[0].count)
} }

View File

@@ -1,7 +1,7 @@
import type { QueryPromise, SQL } from 'drizzle-orm' import type { QueryPromise, SQL } from 'drizzle-orm'
import type { SQLiteColumn } from 'drizzle-orm/sqlite-core' import type { PgSelect } from 'drizzle-orm/pg-core'
import type { SQLiteColumn, SQLiteSelect } from 'drizzle-orm/sqlite-core'
import type { ChainedMethods } from '../find/chainMethods.js'
import type { import type {
DrizzleAdapter, DrizzleAdapter,
DrizzleTransaction, DrizzleTransaction,
@@ -12,13 +12,11 @@ import type {
} from '../types.js' } from '../types.js'
import type { BuildQueryJoinAliases } from './buildQuery.js' import type { BuildQueryJoinAliases } from './buildQuery.js'
import { chainMethods } from '../find/chainMethods.js'
type Args = { type Args = {
adapter: DrizzleAdapter adapter: DrizzleAdapter
chainedMethods?: ChainedMethods
db: DrizzleAdapter['drizzle'] | DrizzleTransaction db: DrizzleAdapter['drizzle'] | DrizzleTransaction
joins: BuildQueryJoinAliases joins: BuildQueryJoinAliases
query?: (args: { query: SQLiteSelect }) => SQLiteSelect
selectFields: Record<string, GenericColumn> selectFields: Record<string, GenericColumn>
tableName: string tableName: string
where: SQL where: SQL
@@ -29,42 +27,40 @@ type Args = {
*/ */
export const selectDistinct = ({ export const selectDistinct = ({
adapter, adapter,
chainedMethods = [],
db, db,
joins, joins,
query: queryModifier = ({ query }) => query,
selectFields, selectFields,
tableName, tableName,
where, where,
}: Args): QueryPromise<{ id: number | string }[] & Record<string, GenericColumn>> => { }: Args): QueryPromise<{ id: number | string }[] & Record<string, GenericColumn>> => {
if (Object.keys(joins).length > 0) { if (Object.keys(joins).length > 0) {
if (where) { let query: SQLiteSelect
chainedMethods.push({ args: [where], method: 'where' })
}
joins.forEach(({ condition, table }) => {
chainedMethods.push({
args: [table, condition],
method: 'leftJoin',
})
})
let query
const table = adapter.tables[tableName] const table = adapter.tables[tableName]
if (adapter.name === 'postgres') { if (adapter.name === 'postgres') {
query = (db as TransactionPg) query = (db as TransactionPg)
.selectDistinct(selectFields as Record<string, GenericPgColumn>) .selectDistinct(selectFields as Record<string, GenericPgColumn>)
.from(table) .from(table)
.$dynamic() as unknown as SQLiteSelect
} }
if (adapter.name === 'sqlite') { if (adapter.name === 'sqlite') {
query = (db as TransactionSQLite) query = (db as TransactionSQLite)
.selectDistinct(selectFields as Record<string, SQLiteColumn>) .selectDistinct(selectFields as Record<string, SQLiteColumn>)
.from(table) .from(table)
.$dynamic()
} }
return chainMethods({ if (where) {
methods: chainedMethods, query = query.where(where)
query, }
joins.forEach(({ condition, table }) => {
query = query.leftJoin(table, condition)
}) })
return queryModifier({
query,
}) as unknown as QueryPromise<{ id: number | string }[] & Record<string, GenericColumn>>
} }
} }

View File

@@ -37,11 +37,8 @@ import type { DrizzleSnapshotJSON } from 'drizzle-kit/api'
import type { SQLiteRaw } from 'drizzle-orm/sqlite-core/query-builders/raw' import type { SQLiteRaw } from 'drizzle-orm/sqlite-core/query-builders/raw'
import type { QueryResult } from 'pg' import type { QueryResult } from 'pg'
import type { ChainedMethods } from './find/chainMethods.js'
import type { Operators } from './queries/operatorMap.js' import type { Operators } from './queries/operatorMap.js'
export { ChainedMethods }
export type PostgresDB = NodePgDatabase<Record<string, unknown>> export type PostgresDB = NodePgDatabase<Record<string, unknown>>
export type SQLiteDB = LibSQLDatabase< export type SQLiteDB = LibSQLDatabase<
@@ -377,3 +374,8 @@ export type RelationMap = Map<
type: 'many' | 'one' type: 'many' | 'one'
} }
> >
/**
* @deprecated - will be removed in 4.0. Use query + $dynamic() instead: https://orm.drizzle.team/docs/dynamic-query-building
*/
export type { ChainedMethods } from './find/chainMethods.js'

View File

@@ -3,9 +3,8 @@ import type { UpdateMany } from 'payload'
import toSnakeCase from 'to-snake-case' import toSnakeCase from 'to-snake-case'
import type { ChainedMethods, DrizzleAdapter } from './types.js' import type { DrizzleAdapter } from './types.js'
import { chainMethods } from './find/chainMethods.js'
import buildQuery from './queries/buildQuery.js' import buildQuery from './queries/buildQuery.js'
import { selectDistinct } from './queries/selectDistinct.js' import { selectDistinct } from './queries/selectDistinct.js'
import { upsertRow } from './upsertRow/index.js' import { upsertRow } from './upsertRow/index.js'
@@ -45,16 +44,10 @@ export const updateMany: UpdateMany = async function updateMany(
const selectDistinctResult = await selectDistinct({ const selectDistinctResult = await selectDistinct({
adapter: this, adapter: this,
chainedMethods: orderBy
? [
{
args: [() => orderBy.map(({ column, order }) => order(column))],
method: 'orderBy',
},
]
: [],
db, db,
joins, joins,
query: ({ query }) =>
orderBy ? query.orderBy(() => orderBy.map(({ column, order }) => order(column))) : query,
selectFields, selectFields,
tableName, tableName,
where, where,
@@ -69,28 +62,17 @@ export const updateMany: UpdateMany = async function updateMany(
const table = this.tables[tableName] const table = this.tables[tableName]
const query = _db.select({ id: table.id }).from(table).where(where) let query = _db.select({ id: table.id }).from(table).where(where).$dynamic()
const chainedMethods: ChainedMethods = []
if (typeof limit === 'number' && limit > 0) { if (typeof limit === 'number' && limit > 0) {
chainedMethods.push({ query = query.limit(limit)
args: [limit],
method: 'limit',
})
} }
if (orderBy) { if (orderBy) {
chainedMethods.push({ query = query.orderBy(() => orderBy.map(({ column, order }) => order(column)))
args: [() => orderBy.map(({ column, order }) => order(column))],
method: 'orderBy',
})
} }
const docsToUpdate = await chainMethods({ const docsToUpdate = await query
methods: chainedMethods,
query,
})
idsToUpdate = docsToUpdate?.map((doc) => doc.id) idsToUpdate = docsToUpdate?.map((doc) => doc.id)
} }

View File

@@ -41,9 +41,9 @@ export const updateOne: UpdateOne = async function updateOne(
// selectDistinct will only return if there are joins // selectDistinct will only return if there are joins
const selectDistinctResult = await selectDistinct({ const selectDistinctResult = await selectDistinct({
adapter: this, adapter: this,
chainedMethods: [{ args: [1], method: 'limit' }],
db, db,
joins, joins,
query: ({ query }) => query.limit(1),
selectFields, selectFields,
tableName, tableName,
where, where,