diff --git a/packages/drizzle/src/postgres/columnToCodeConverter.ts b/packages/drizzle/src/postgres/columnToCodeConverter.ts index 571dfa714..2519c9ae5 100644 --- a/packages/drizzle/src/postgres/columnToCodeConverter.ts +++ b/packages/drizzle/src/postgres/columnToCodeConverter.ts @@ -35,6 +35,12 @@ export const columnToCodeConverter: ColumnToCodeConverter = ({ } } + if (column.type === 'vector') { + if (column.dimensions) { + columnBuilderArgsArray.push(`dimensions: ${column.dimensions}`) + } + } + let columnBuilderArgs = '' if (columnBuilderArgsArray.length) { diff --git a/packages/drizzle/src/postgres/schema/buildDrizzleTable.ts b/packages/drizzle/src/postgres/schema/buildDrizzleTable.ts index 8c6b9252b..62599ddf3 100644 --- a/packages/drizzle/src/postgres/schema/buildDrizzleTable.ts +++ b/packages/drizzle/src/postgres/schema/buildDrizzleTable.ts @@ -13,6 +13,7 @@ import { uniqueIndex, uuid, varchar, + vector, } from 'drizzle-orm/pg-core' import type { RawColumn, RawTable } from '../../types.js' @@ -81,6 +82,13 @@ export const buildDrizzleTable = ({ break } + case 'vector': { + const builder = vector(column.name, { dimensions: column.dimensions }) + columns[key] = builder + + break + } + default: columns[key] = rawColumnBuilderMap[column.type](column.name) break diff --git a/packages/drizzle/src/types.ts b/packages/drizzle/src/types.ts index 92a082991..da6f148f1 100644 --- a/packages/drizzle/src/types.ts +++ b/packages/drizzle/src/types.ts @@ -279,6 +279,11 @@ export type IntegerRawColumn = { type: 'integer' } & BaseRawColumn +export type VectorRawColumn = { + dimensions?: number + type: 'vector' +} & BaseRawColumn + export type RawColumn = | ({ type: 'boolean' | 'geometry' | 'jsonb' | 'numeric' | 'serial' | 'text' | 'varchar' @@ -287,6 +292,7 @@ export type RawColumn = | IntegerRawColumn | TimestampRawColumn | UUIDRawColumn + | VectorRawColumn export type IDType = 'integer' | 'numeric' | 'text' | 'uuid' | 'varchar' diff --git a/test/database/postgres-vector.int.spec.ts b/test/database/postgres-vector.int.spec.ts new file mode 100644 index 000000000..06cccdc9a --- /dev/null +++ b/test/database/postgres-vector.int.spec.ts @@ -0,0 +1,121 @@ +/* eslint-disable jest/require-top-level-describe */ +import { PostgresAdapter } from '@payloadcms/db-postgres/types' +import { cosineDistance, desc, gt, sql } from 'drizzle-orm' +import path from 'path' +import { buildConfig, getPayload } from 'payload' +import { fileURLToPath } from 'url' + +const filename = fileURLToPath(import.meta.url) +const dirname = path.dirname(filename) + +// skip on ci as there db does not have the vector extension +const describeToUse = + process.env.PAYLOAD_DATABASE.startsWith('postgres') && process.env.CI !== 'true' + ? describe + : describe.skip + +describeToUse('postgres vector custom column', () => { + it('should add a vector column and query it', async () => { + const { databaseAdapter } = await import(path.resolve(dirname, '../databaseAdapter.js')) + + const init = databaseAdapter.init + + // set options + databaseAdapter.init = ({ payload }) => { + const adapter = init({ payload }) + + adapter.extensions = { + vector: true, + } + adapter.beforeSchemaInit = [ + ({ schema, adapter }) => { + ;(adapter as PostgresAdapter).rawTables.posts.columns.embedding = { + type: 'vector', + dimensions: 5, + name: 'embedding', + } + return schema + }, + ] + return adapter + } + + const config = await buildConfig({ + db: databaseAdapter, + secret: 'secret', + collections: [ + { + slug: 'users', + auth: true, + fields: [], + }, + { + slug: 'posts', + fields: [ + { + type: 'json', + name: 'embedding', + }, + { + name: 'title', + type: 'text', + }, + ], + }, + ], + }) + + const payload = await getPayload({ config }) + + const catEmbedding = [1.5, -0.4, 7.2, 19.6, 20.2] + + await payload.create({ + collection: 'posts', + data: { + embedding: [-5.2, 3.1, 0.2, 8.1, 3.5], + title: 'apple', + }, + }) + + await payload.create({ + collection: 'posts', + data: { + embedding: catEmbedding, + title: 'cat', + }, + }) + + await payload.create({ + collection: 'posts', + data: { + embedding: [-5.1, 2.9, 0.8, 7.9, 3.1], + title: 'fruit', + }, + }) + + await payload.create({ + collection: 'posts', + data: { + embedding: [1.7, -0.3, 6.9, 19.1, 21.1], + title: 'dog', + }, + }) + + const similarity = sql`1 - (${cosineDistance(payload.db.tables.posts.embedding, catEmbedding)})` + + const res = await payload.db.drizzle + .select() + .from(payload.db.tables.posts) + .where(gt(similarity, 0.9)) + .orderBy(desc(similarity)) + + // Only cat and dog + expect(res).toHaveLength(2) + + // similarity sort + expect(res[0].title).toBe('cat') + expect(res[1].title).toBe('dog') + + payload.logger.info(res) + }) +})