feat(db-postgres): add vector raw column type (#10422)
Example how you can add a vector column, enable the `vector` extension and query your embeddings in the included test - https://github.com/payloadcms/payload/compare/feat/more-types?expand=1#diff-7d876370487cb625eb42ff1ad7cffa78e8327367af3de2930837ed123f5e3ae6R1-R117
This commit is contained in:
@@ -35,6 +35,12 @@ export const columnToCodeConverter: ColumnToCodeConverter = ({
|
||||
}
|
||||
}
|
||||
|
||||
if (column.type === 'vector') {
|
||||
if (column.dimensions) {
|
||||
columnBuilderArgsArray.push(`dimensions: ${column.dimensions}`)
|
||||
}
|
||||
}
|
||||
|
||||
let columnBuilderArgs = ''
|
||||
|
||||
if (columnBuilderArgsArray.length) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
uniqueIndex,
|
||||
uuid,
|
||||
varchar,
|
||||
vector,
|
||||
} from 'drizzle-orm/pg-core'
|
||||
|
||||
import type { RawColumn, RawTable } from '../../types.js'
|
||||
@@ -81,6 +82,13 @@ export const buildDrizzleTable = ({
|
||||
break
|
||||
}
|
||||
|
||||
case 'vector': {
|
||||
const builder = vector(column.name, { dimensions: column.dimensions })
|
||||
columns[key] = builder
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
columns[key] = rawColumnBuilderMap[column.type](column.name)
|
||||
break
|
||||
|
||||
@@ -279,6 +279,11 @@ export type IntegerRawColumn = {
|
||||
type: 'integer'
|
||||
} & BaseRawColumn
|
||||
|
||||
export type VectorRawColumn = {
|
||||
dimensions?: number
|
||||
type: 'vector'
|
||||
} & BaseRawColumn
|
||||
|
||||
export type RawColumn =
|
||||
| ({
|
||||
type: 'boolean' | 'geometry' | 'jsonb' | 'numeric' | 'serial' | 'text' | 'varchar'
|
||||
@@ -287,6 +292,7 @@ export type RawColumn =
|
||||
| IntegerRawColumn
|
||||
| TimestampRawColumn
|
||||
| UUIDRawColumn
|
||||
| VectorRawColumn
|
||||
|
||||
export type IDType = 'integer' | 'numeric' | 'text' | 'uuid' | 'varchar'
|
||||
|
||||
|
||||
121
test/database/postgres-vector.int.spec.ts
Normal file
121
test/database/postgres-vector.int.spec.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
/* eslint-disable jest/require-top-level-describe */
|
||||
import { PostgresAdapter } from '@payloadcms/db-postgres/types'
|
||||
import { cosineDistance, desc, gt, sql } from 'drizzle-orm'
|
||||
import path from 'path'
|
||||
import { buildConfig, getPayload } from 'payload'
|
||||
import { fileURLToPath } from 'url'
|
||||
|
||||
const filename = fileURLToPath(import.meta.url)
|
||||
const dirname = path.dirname(filename)
|
||||
|
||||
// skip on ci as there db does not have the vector extension
|
||||
const describeToUse =
|
||||
process.env.PAYLOAD_DATABASE.startsWith('postgres') && process.env.CI !== 'true'
|
||||
? describe
|
||||
: describe.skip
|
||||
|
||||
describeToUse('postgres vector custom column', () => {
|
||||
it('should add a vector column and query it', async () => {
|
||||
const { databaseAdapter } = await import(path.resolve(dirname, '../databaseAdapter.js'))
|
||||
|
||||
const init = databaseAdapter.init
|
||||
|
||||
// set options
|
||||
databaseAdapter.init = ({ payload }) => {
|
||||
const adapter = init({ payload })
|
||||
|
||||
adapter.extensions = {
|
||||
vector: true,
|
||||
}
|
||||
adapter.beforeSchemaInit = [
|
||||
({ schema, adapter }) => {
|
||||
;(adapter as PostgresAdapter).rawTables.posts.columns.embedding = {
|
||||
type: 'vector',
|
||||
dimensions: 5,
|
||||
name: 'embedding',
|
||||
}
|
||||
return schema
|
||||
},
|
||||
]
|
||||
return adapter
|
||||
}
|
||||
|
||||
const config = await buildConfig({
|
||||
db: databaseAdapter,
|
||||
secret: 'secret',
|
||||
collections: [
|
||||
{
|
||||
slug: 'users',
|
||||
auth: true,
|
||||
fields: [],
|
||||
},
|
||||
{
|
||||
slug: 'posts',
|
||||
fields: [
|
||||
{
|
||||
type: 'json',
|
||||
name: 'embedding',
|
||||
},
|
||||
{
|
||||
name: 'title',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const payload = await getPayload({ config })
|
||||
|
||||
const catEmbedding = [1.5, -0.4, 7.2, 19.6, 20.2]
|
||||
|
||||
await payload.create({
|
||||
collection: 'posts',
|
||||
data: {
|
||||
embedding: [-5.2, 3.1, 0.2, 8.1, 3.5],
|
||||
title: 'apple',
|
||||
},
|
||||
})
|
||||
|
||||
await payload.create({
|
||||
collection: 'posts',
|
||||
data: {
|
||||
embedding: catEmbedding,
|
||||
title: 'cat',
|
||||
},
|
||||
})
|
||||
|
||||
await payload.create({
|
||||
collection: 'posts',
|
||||
data: {
|
||||
embedding: [-5.1, 2.9, 0.8, 7.9, 3.1],
|
||||
title: 'fruit',
|
||||
},
|
||||
})
|
||||
|
||||
await payload.create({
|
||||
collection: 'posts',
|
||||
data: {
|
||||
embedding: [1.7, -0.3, 6.9, 19.1, 21.1],
|
||||
title: 'dog',
|
||||
},
|
||||
})
|
||||
|
||||
const similarity = sql<number>`1 - (${cosineDistance(payload.db.tables.posts.embedding, catEmbedding)})`
|
||||
|
||||
const res = await payload.db.drizzle
|
||||
.select()
|
||||
.from(payload.db.tables.posts)
|
||||
.where(gt(similarity, 0.9))
|
||||
.orderBy(desc(similarity))
|
||||
|
||||
// Only cat and dog
|
||||
expect(res).toHaveLength(2)
|
||||
|
||||
// similarity sort
|
||||
expect(res[0].title).toBe('cat')
|
||||
expect(res[1].title).toBe('dog')
|
||||
|
||||
payload.logger.info(res)
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user