Skip to content

Commit

Permalink
Merge pull request #3792 from singlestore-labs/feature/vector-type
Browse files Browse the repository at this point in the history
[SingleStore] Add Vector column type
  • Loading branch information
AndriiSherman authored Jan 8, 2025
2 parents 44616e9 + 8a13fab commit eba87c6
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 1 deletion.
11 changes: 11 additions & 0 deletions drizzle-kit/src/introspect-singlestore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const singlestoreImportsList = new Set([
'tinyint',
'varbinary',
'varchar',
'vector',
'year',
'enum',
]);
Expand Down Expand Up @@ -789,6 +790,16 @@ const column = (
return out;
}

if (lowered.startsWith('vector')) {
const [dimensions, elementType] = lowered.substring('vector'.length + 1, lowered.length - 1).split(',');
let out = `${casing(name)}: vector(${
dbColumnName({ name, casing: rawCasing, withMode: true })
}{ dimensions: ${dimensions}, elementType: ${elementType} })`;

out += defaultValue ? `.default(${mapColumnDefault(defaultValue, isExpression)})` : '';
return out;
}

console.log('uknown', type);
return `// Warning: Can't parse ${type} from database\n\t// ${type}Type: ${type}("${name}")`;
};
Expand Down
2 changes: 1 addition & 1 deletion drizzle-kit/src/serializer/singlestoreSerializer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ export const generateSingleStoreSnapshot = (
if (typeof column.default === 'string') {
columnToSet.default = `'${column.default}'`;
} else {
if (sqlTypeLowered === 'json') {
if (sqlTypeLowered === 'json' || Array.isArray(column.default)) {
columnToSet.default = `'${JSON.stringify(column.default)}'`;
} else if (column.default instanceof Date) {
if (sqlTypeLowered === 'date') {
Expand Down
8 changes: 8 additions & 0 deletions drizzle-kit/tests/push/singlestore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
tinyint,
varbinary,
varchar,
vector,
year,
} from 'drizzle-orm/singlestore-core';
import getPort from 'get-port';
Expand Down Expand Up @@ -249,6 +250,13 @@ const singlestoreSuite: DialectSuite = {
columnNotNull: binary('column_not_null', { length: 1 }).notNull(),
columnDefault: binary('column_default', { length: 12 }),
}),

allVectors: singlestoreTable('all_vectors', {
vectorSimple: vector('vector_simple', { dimensions: 1 }),
vectorElementType: vector('vector_element_type', { dimensions: 1, elementType: 'I8' }),
vectorNotNull: vector('vector_not_null', { dimensions: 1 }).notNull(),
vectorDefault: vector('vector_default', { dimensions: 1 }).default([1]),
}),
};

const { statements } = await diffTestSchemasPushSingleStore(
Expand Down
2 changes: 2 additions & 0 deletions drizzle-orm/src/singlestore-core/columns/all.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { timestamp } from './timestamp.ts';
import { tinyint } from './tinyint.ts';
import { varbinary } from './varbinary.ts';
import { varchar } from './varchar.ts';
import { vector } from './vector.ts';
import { year } from './year.ts';

export function getSingleStoreColumnBuilders() {
Expand Down Expand Up @@ -51,6 +52,7 @@ export function getSingleStoreColumnBuilders() {
tinyint,
varbinary,
varchar,
vector,
year,
};
}
Expand Down
1 change: 1 addition & 0 deletions drizzle-orm/src/singlestore-core/columns/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ export * from './timestamp.ts';
export * from './tinyint.ts';
export * from './varbinary.ts';
export * from './varchar.ts';
export * from './vector.ts';
export * from './year.ts';
83 changes: 83 additions & 0 deletions drizzle-orm/src/singlestore-core/columns/vector.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import type { ColumnBuilderBaseConfig, ColumnBuilderRuntimeConfig, MakeColumnConfig } from '~/column-builder.ts';
import type { ColumnBaseConfig } from '~/column.ts';
import { entityKind } from '~/entity.ts';
import type { AnySingleStoreTable } from '~/singlestore-core/table.ts';
import { SQL } from '~/sql/index.ts';
import { getColumnNameAndConfig } from '~/utils.ts';
import { SingleStoreColumn, SingleStoreColumnBuilder, SingleStoreGeneratedColumnConfig } from './common.ts';

export type SingleStoreVectorBuilderInitial<TName extends string> = SingleStoreVectorBuilder<{
name: TName;
dataType: 'array';
columnType: 'SingleStoreVector';
data: Array<number>;
driverParam: string;
enumValues: undefined;
}>;

export class SingleStoreVectorBuilder<T extends ColumnBuilderBaseConfig<'array', 'SingleStoreVector'>>
extends SingleStoreColumnBuilder<T, SingleStoreVectorConfig>
{
static override readonly [entityKind]: string = 'SingleStoreVectorBuilder';

constructor(name: T['name'], config: SingleStoreVectorConfig) {
super(name, 'array', 'SingleStoreVector');
this.config.dimensions = config.dimensions;
this.config.elementType = config.elementType;
}

/** @internal */
override build<TTableName extends string>(
table: AnySingleStoreTable<{ name: TTableName }>,
): SingleStoreVector<MakeColumnConfig<T, TTableName>> {
return new SingleStoreVector<MakeColumnConfig<T, TTableName>>(
table,
this.config as ColumnBuilderRuntimeConfig<any, any>,
);
}

/** @internal */
override generatedAlwaysAs(as: SQL<unknown> | (() => SQL) | T['data'], config?: SingleStoreGeneratedColumnConfig) {
throw new Error('not implemented');
}
}

export class SingleStoreVector<T extends ColumnBaseConfig<'array', 'SingleStoreVector'>>
extends SingleStoreColumn<T, SingleStoreVectorConfig>
{
static override readonly [entityKind]: string = 'SingleStoreVector';

dimensions: number = this.config.dimensions;
elementType: ElementType | undefined = this.config.elementType;

getSQLType(): string {
return `vector(${this.dimensions}, ${this.elementType || 'F32'})`;
}

override mapToDriverValue(value: Array<number>) {
return JSON.stringify(value);
}

override mapFromDriverValue(value: string): Array<number> {
return JSON.parse(value);
}
}

type ElementType = 'I8' | 'I16' | 'I32' | 'I64' | 'F32' | 'F64';

export interface SingleStoreVectorConfig {
dimensions: number;
elementType?: ElementType;
}

export function vector(
config: SingleStoreVectorConfig,
): SingleStoreVectorBuilderInitial<''>;
export function vector<TName extends string>(
name: TName,
config: SingleStoreVectorConfig,
): SingleStoreVectorBuilderInitial<TName>;
export function vector(a: string | SingleStoreVectorConfig, b?: SingleStoreVectorConfig) {
const { name, config } = getColumnNameAndConfig<SingleStoreVectorConfig>(a, b);
return new SingleStoreVectorBuilder(name, config);
}
9 changes: 9 additions & 0 deletions drizzle-orm/src/singlestore-core/expressions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ export function substring(
chunks.push(sql`)`);
return sql.join(chunks);
}

// Vectors
export function dotProduct(column: SingleStoreColumn | SQL.Aliased, value: Array<number>): SQL {
return sql`${column} <*> ${JSON.stringify(value)}`;
}

export function euclideanDistance(column: SingleStoreColumn | SQL.Aliased, value: Array<number>): SQL {
return sql`${column} <-> ${JSON.stringify(value)}`;
}
5 changes: 5 additions & 0 deletions drizzle-orm/type-tests/singlestore/tables.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
uniqueIndex,
varbinary,
varchar,
vector,
year,
} from '~/singlestore-core/index.ts';
import { singlestoreSchema } from '~/singlestore-core/schema.ts';
Expand Down Expand Up @@ -917,6 +918,8 @@ Expect<
varchar: varchar('varchar', { length: 1 }),
varchar2: varchar('varchar2', { length: 1, enum: ['a', 'b', 'c'] }),
varchardef: varchar('varchardef', { length: 1 }).default(''),
vector: vector('vector', { dimensions: 1 }),
vector2: vector('vector2', { dimensions: 1, elementType: 'I8' }),
year: year('year'),
yeardef: year('yeardef').default(0),
});
Expand Down Expand Up @@ -1015,6 +1018,8 @@ Expect<
varchar: varchar({ length: 1 }),
varchar2: varchar({ length: 1, enum: ['a', 'b', 'c'] }),
varchardef: varchar({ length: 1 }).default(''),
vector: vector({ dimensions: 1 }),
vector2: vector({ dimensions: 1, elementType: 'I8' }),
year: year(),
yeardef: year().default(0),
});
Expand Down
63 changes: 63 additions & 0 deletions integration-tests/tests/singlestore/singlestore-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ import {
uniqueIndex,
uniqueKeyName,
varchar,
vector,
year,
} from 'drizzle-orm/singlestore-core';
import { dotProduct, euclideanDistance } from 'drizzle-orm/singlestore-core/expressions';
import { migrate } from 'drizzle-orm/singlestore/migrator';
import getPort from 'get-port';
import { v4 as uuid } from 'uuid';
Expand Down Expand Up @@ -156,6 +158,12 @@ const aggregateTable = singlestoreTable('aggregate_table', {
nullOnly: int('null_only'),
});

const vectorSearchTable = singlestoreTable('vector_search', {
id: serial('id').notNull(),
text: text('text').notNull(),
embedding: vector('embedding', { dimensions: 10 }),
});

// To test another schema and multischema
const mySchema = singlestoreSchema(`mySchema`);

Expand Down Expand Up @@ -366,6 +374,31 @@ export function tests(driver?: string) {
]);
}

async function setupVectorSearchTest(db: TestSingleStoreDB) {
await db.execute(sql`drop table if exists \`vector_search\``);
await db.execute(
sql`
create table \`vector_search\` (
\`id\` integer primary key auto_increment not null,
\`text\` text not null,
\`embedding\` vector(10) not null
)
`,
);
await db.insert(vectorSearchTable).values([
{
id: 1,
text: 'I like dogs',
embedding: [0.6119, 0.1395, 0.2921, 0.3664, 0.4561, 0.7852, 0.1997, 0.5142, 0.5924, 0.0465],
},
{
id: 2,
text: 'I like cats',
embedding: [0.6075, 0.1705, 0.0651, 0.9489, 0.9656, 0.8084, 0.3046, 0.0977, 0.6842, 0.4402],
},
]);
}

test('table config: unsigned ints', async () => {
const unsignedInts = singlestoreTable('cities1', {
bigint: bigint('bigint', { mode: 'number', unsigned: true }),
Expand Down Expand Up @@ -2907,6 +2940,36 @@ export function tests(driver?: string) {
expect(result2[0]?.value).toBe(null);
});

test('simple vector search', async (ctx) => {
const { db } = ctx.singlestore;
const table = vectorSearchTable;
const embedding = [0.42, 0.93, 0.88, 0.57, 0.32, 0.64, 0.76, 0.52, 0.19, 0.81]; // ChatGPT's 10 dimension embedding for "dogs are cool" not sure how accurate but it works
await setupVectorSearchTest(db);

const withRankEuclidean = db.select({
id: table.id,
text: table.text,
rank: sql`row_number() over (order by ${euclideanDistance(table.embedding, embedding)})`.as('rank'),
}).from(table).as('with_rank');
const withRankDotProduct = db.select({
id: table.id,
text: table.text,
rank: sql`row_number() over (order by ${dotProduct(table.embedding, embedding)})`.as('rank'),
}).from(table).as('with_rank');
const result1 = await db.select({ id: withRankEuclidean.id, text: withRankEuclidean.text }).from(
withRankEuclidean,
).where(eq(withRankEuclidean.rank, 1));
const result2 = await db.select({ id: withRankDotProduct.id, text: withRankDotProduct.text }).from(
withRankDotProduct,
).where(eq(withRankDotProduct.rank, 1));

expect(result1.length).toEqual(1);
expect(result1[0]).toEqual({ id: 1, text: 'I like dogs' });

expect(result2.length).toEqual(1);
expect(result2[0]).toEqual({ id: 1, text: 'I like dogs' });
});

test('test $onUpdateFn and $onUpdate works as $default', async (ctx) => {
const { db } = ctx.singlestore;

Expand Down

0 comments on commit eba87c6

Please sign in to comment.