Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIXELS-580] implement using LSH index to perform approximate nearest neighbour search on vector column #88

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f88dedb
implement a simple trino udf to calculate the sum of all the elements…
TiannanSha Dec 4, 2023
5a6d75c
use array instead of features because apparently trino features don't…
TiannanSha Dec 6, 2023
df04bbf
implement three types of distances: euclidean, dotproduct and cosine …
TiannanSha Dec 7, 2023
2300020
make select exactNNS() udf work in trino
TiannanSha Dec 29, 2023
1767b44
fix a test
TiannanSha Dec 29, 2023
d52a31e
fix a comment
TiannanSha Dec 29, 2023
2201829
minor polish
TiannanSha Dec 30, 2023
420408a
make pixels vector type support trino array type
TiannanSha Jan 8, 2024
abbbee7
implement an exact NNS that acts as an aggregation function and should…
TiannanSha Jan 16, 2024
4ca554d
clean up
TiannanSha Jan 16, 2024
5bc07ba
pretty much finished LSH build; LSH search wip
TiannanSha Jan 30, 2024
01a3903
wip: lsh search
TiannanSha Feb 2, 2024
19fedc8
lshSearch work in progress
TiannanSha Feb 5, 2024
3e89598
implement LSH search, including updating mapping from col to buckets …
TiannanSha Feb 9, 2024
0fc8cd1
fix the ser and deser for LSH index
TiannanSha Feb 17, 2024
9150fab
auto decide s3dir for storing LSH buckets. clean up
TiannanSha Feb 18, 2024
dd9016d
implement code for experiments; before fixing remote page too large
TiannanSha Mar 4, 2024
b843296
accidentally got ignored entire folder of lsh build
TiannanSha Mar 5, 2024
fad3f5d
fix lsh_build
TiannanSha Mar 5, 2024
9eaf254
implement LSH load and adjust lsh search
TiannanSha Mar 7, 2024
09afb3f
fix lsh_search() so that it works with lsh_load()
TiannanSha Mar 8, 2024
f231da9
minor changes on lsh_load()
TiannanSha Mar 18, 2024
bbbb793
add bucket write threshold to lsh_build() to avoid sending big messages
TiannanSha Mar 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions connector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -186,5 +186,24 @@
<optional>true</optional>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-main</artifactId>
<version>405</version>
<scope>test</scope>
</dependency>
<!-- Maven Dependency -->
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-testing</artifactId>
<version>405</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not use AWS Java SDK 1.x.
AWS Java SDK 2 is already included in the dependency.

<artifactId>aws-java-sdk-s3</artifactId>
<version>1.12.261</version>
</dependency>

</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import java.util.concurrent.atomic.AtomicInteger;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -406,6 +407,7 @@ public Page getNextPage()
blocks[fieldId] = new LazyBlock(rowBatchSize, new PixelsBlockLoader(
vector, type, typeCategory, rowBatchSize));
}
// update the current column in cached LSH index
} catch (IOException e)
{
closeWithSuppression(e);
Expand Down Expand Up @@ -634,6 +636,26 @@ public Block load()
*/
block = new LongArrayBlock(batchSize, Optional.ofNullable(tscv.isNull), tscv.times);
break;
case VECTOR:
VectorColumnVector vcv = (VectorColumnVector) vector;
// builder that simply concatenate all double arrays
BlockBuilder allDoublesBuilder = DOUBLE.createBlockBuilder(null, batchSize * vcv.dimension);
int[] offsets = new int[batchSize+1];
// build a block into which we put a double array
for (int i = 0 ; i < batchSize; i++) {
offsets[i] = i * vcv.dimension;
for (int j = 0; j < vcv.dimension; j++) {
DOUBLE.writeDouble(allDoublesBuilder, vcv.vector[i][j]);
}
}
offsets[batchSize] = batchSize * vcv.dimension;
// after extensive research on how other connectors deal with array type, the following seems to
// be the way to go: basically we stuff all the values of all arrays into one big block, and provide
// an int[] as offsets to tell trino where each array begins and ends. Note that the final offset
// should be the position to tell trino the end of the final array
// Interestingly all the above is NOT documented in trino documentation or code at all.
block = ArrayBlock.fromElementBlock(batchSize, Optional.of(vcv.isNull), offsets, allDoublesBuilder.build());
break;
default:
BlockBuilder blockBuilder = type.createBlockBuilder(null, batchSize);
for (int i = 0; i < batchSize; ++i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.pixelsdb.pixels.common.turbo.Output;
import io.pixelsdb.pixels.common.turbo.WorkerType;
import io.pixelsdb.pixels.core.PixelsFooterCache;
import io.pixelsdb.pixels.core.TypeDescription;
import io.pixelsdb.pixels.core.utils.Pair;
import io.pixelsdb.pixels.executor.join.JoinAlgorithm;
import io.pixelsdb.pixels.executor.predicate.TableScanFilter;
Expand All @@ -44,6 +45,7 @@
import io.pixelsdb.pixels.planner.plan.physical.output.ScanOutput;
import io.pixelsdb.pixels.trino.exception.PixelsErrorCode;
import io.pixelsdb.pixels.trino.impl.PixelsTrinoConfig;
import io.pixelsdb.pixels.trino.vector.lshnns.CachedLSHIndex;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.*;

Expand Down Expand Up @@ -115,6 +117,12 @@ public ConnectorPageSource createPageSource(ConnectorTransactionHandle transacti

try
{
List<PixelsColumnHandle> columnHandles = tableHandle.getColumns();
if (columnHandles.size()==1
&& columnHandles.get(0).getTypeCategory() == TypeDescription.Category.VECTOR) {
// update the LSH index in every node
CachedLSHIndex.getInstance().setCurrColumn(columnHandles.get(0));
}
if (pixelsSplit.getTableType() == Table.TableType.AGGREGATED)
{
// perform aggregation push down.
Expand Down
20 changes: 20 additions & 0 deletions connector/src/main/java/io/pixelsdb/pixels/trino/PixelsPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@
package io.pixelsdb.pixels.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.pixelsdb.pixels.trino.block.TimeArrayBlockEncoding;
import io.pixelsdb.pixels.trino.block.VarcharArrayBlockEncoding;
import io.pixelsdb.pixels.trino.vector.VectorUDF;
import io.pixelsdb.pixels.trino.vector.exactnns.ExactNNSAggFunc;
import io.pixelsdb.pixels.trino.vector.lshnns.lshbuild.BuildLSHIndexAggFunc;
import io.pixelsdb.pixels.trino.vector.lshnns.lshbuild.LSHLoadUDF;
import io.pixelsdb.pixels.trino.vector.lshnns.search.LSHSearchUDF;
import io.trino.spi.Plugin;
import io.trino.spi.block.BlockEncoding;
import io.trino.spi.connector.ConnectorFactory;

import java.util.Set;

public class PixelsPlugin implements Plugin
{
@Override
Expand All @@ -39,4 +47,16 @@ public Iterable<ConnectorFactory> getConnectorFactories()
{
return ImmutableList.of(new PixelsConnectorFactory());
}

@Override
public Set<Class<?>> getFunctions()
{
    // Expose the vector UDFs and aggregation functions shipped with this plugin
    // so Trino registers them alongside the connector.
    return ImmutableSet.<Class<?>>of(
            VectorUDF.class,
            ExactNNSAggFunc.class,
            BuildLSHIndexAggFunc.class,
            LSHSearchUDF.class,
            LSHLoadUDF.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import io.pixelsdb.pixels.trino.impl.PixelsMetadataProxy;
import io.pixelsdb.pixels.trino.impl.PixelsTrinoConfig;
import io.pixelsdb.pixels.trino.properties.PixelsSessionProperties;
import io.pixelsdb.pixels.trino.vector.lshnns.CachedLSHIndex;
import io.trino.spi.HostAddress;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.*;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package io.pixelsdb.pixels.trino.vector;

import io.pixelsdb.pixels.common.physical.Storage;
import io.pixelsdb.pixels.common.physical.StorageFactory;
import io.pixelsdb.pixels.core.PixelsWriter;
import io.pixelsdb.pixels.core.PixelsWriterImpl;
import io.pixelsdb.pixels.core.TypeDescription;
import io.pixelsdb.pixels.core.encoding.EncodingLevel;
import io.pixelsdb.pixels.core.exception.PixelsWriterException;
import io.pixelsdb.pixels.core.vector.VectorColumnVector;
import io.pixelsdb.pixels.core.vector.VectorizedRowBatch;

import java.io.IOException;

/**
 * Writes test files containing a single vector column to S3, used to exercise
 * the vector nearest-neighbour-search code paths end-to-end.
 */
public class S3TestFileGenerator {
    public static void main(String[] args) throws IOException
    {
        writeVectorColumnToS3(getTestVectors(4,2), "exactNNS-test-file3.pxl");
        writeVectorColumnToS3(getTestVectors(4,2), "exactNNS-test-file4.pxl");
        // todo maybe add a large scale test
    }

    /**
     * Writes the given vectors as a one-column pixels file to S3. The target
     * bucket is taken from the PIXELS_S3_TEST_BUCKET_PATH environment variable.
     * Errors are reported to stderr (best-effort test utility).
     *
     * @param vectorsToWrite the vectors to write; all rows are assumed to have
     *                       the same dimension as row 0
     * @param s3File         file name appended to the bucket path
     */
    public static void writeVectorColumnToS3(double[][] vectorsToWrite, String s3File)
    {
        int length = vectorsToWrite.length;
        // guard the empty array as well as a missing first row; both would
        // otherwise throw when reading vectorsToWrite[0].length below
        if (length == 0 || vectorsToWrite[0] == null) {
            return;
        }
        int dimension = vectorsToWrite[0].length;
        // Note you may need to restart intellij to let it pick up the updated environment variable value
        // example path: s3://bucket-name/test-file.pxl
        try
        {
            String pixelsFile = System.getenv("PIXELS_S3_TEST_BUCKET_PATH") + s3File;
            Storage storage = StorageFactory.Instance().getStorage("s3");

            String schemaStr = String.format("struct<v:vector(%s)>", dimension);

            TypeDescription schema = TypeDescription.fromString(schemaStr);
            VectorizedRowBatch rowBatch = schema.createRowBatch();
            VectorColumnVector v = (VectorColumnVector) rowBatch.cols[0];

            PixelsWriter pixelsWriter =
                    PixelsWriterImpl.newBuilder()
                            .setSchema(schema)
                            .setPixelStride(10000)
                            .setRowGroupSize(64 * 1024 * 1024)
                            .setStorage(storage)
                            .setPath(pixelsFile)
                            .setBlockSize(256 * 1024 * 1024)
                            .setReplication((short) 3)
                            .setBlockPadding(true)
                            .setEncodingLevel(EncodingLevel.EL2)
                            .setCompressionBlockSize(1)
                            .build();

            // write every vector (the previous bound of length-1 dropped the
            // last one). 'row' indexes into the current batch while 'i' indexes
            // the input array -- they diverge after a batch reset, so the source
            // of the copy must be vectorsToWrite[i], not vectorsToWrite[row].
            for (int i = 0; i < length; i++)
            {
                int row = rowBatch.size++;
                v.vector[row] = new double[dimension];
                System.arraycopy(vectorsToWrite[i], 0, v.vector[row], 0, dimension);
                v.isNull[row] = false;
                if (rowBatch.size == rowBatch.getMaxSize())
                {
                    pixelsWriter.addRowBatch(rowBatch);
                    rowBatch.reset();
                }
            }

            // flush the final, partially-filled batch
            if (rowBatch.size != 0)
            {
                pixelsWriter.addRowBatch(rowBatch);
                System.out.println("A rowBatch of size " + rowBatch.size + " has been written to " + pixelsFile);
                rowBatch.reset();
            }

            pixelsWriter.close();
        } catch (IOException | PixelsWriterException e)
        {
            // best-effort test utility: log and continue
            e.printStackTrace();
        }
    }

    /**
     * Builds deterministic test vectors with testVectors[i][j] = i + j*0.0001,
     * e.g. testVector[0][500] = 0.05
     *
     * @param length    number of vectors
     * @param dimension dimension of each vector
     * @return a length x dimension array of test vectors
     */
    private static double[][] getTestVectors(int length, int dimension)
    {
        double[][] testVecs = new double[length][dimension];
        for (int i=0; i<length; i++) {
            for (int j=0; j<dimension; j++) {
                testVecs[i][j] = i + j*0.0001;
            }
        }
        return testVecs;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package io.pixelsdb.pixels.trino.vector;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.airlift.slice.Slice;
import io.pixelsdb.pixels.core.TypeDescription;
import io.pixelsdb.pixels.trino.PixelsColumnHandle;
import io.pixelsdb.pixels.trino.vector.VectorDistFuncs;
import io.trino.spi.block.Block;
import io.trino.spi.type.ArrayType;

import static io.trino.spi.type.DoubleType.DOUBLE;

/**
 * Static helpers shared by the vector aggregation functions and UDFs for
 * converting between Trino runtime values (Block/Slice) and plain Java values.
 */
public class VectorAggFuncUtil {

    // utility class: static methods only, not meant to be instantiated
    private VectorAggFuncUtil() {}

    /**
     * Converts a Trino Block holding DOUBLE values into a double[].
     *
     * @param block the input block, may be null
     * @return the doubles stored in the block, or null if the block is null
     */
    public static double[] blockToVec(Block block) {
        if (block == null) {
            return null;
        }
        // todo use offset here
        double[] inputVector = new double[block.getPositionCount()];
        for (int i = 0; i < block.getPositionCount(); i++) {
            inputVector[i] = DOUBLE.getDouble(block, i);
        }
        return inputVector;
    }

    /**
     * Parses a JSON-encoded double array (e.g. "[1.0, 2.0]") out of a Slice.
     *
     * @return the parsed array, or null if parsing fails (best-effort; the
     *         stack trace is printed for debugging)
     */
    public static double[] sliceToVec(Slice slice) {
        ObjectMapper objectMapper = new ObjectMapper();
        try {
            return objectMapper.readValue(slice.toStringUtf8(), double[].class);
        } catch (Exception e) {
            // callers treat null as "no vector"; keep behavior best-effort
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Maps a distance-function name to its enum value.
     *
     * @param distFuncStr one of "euc", "cos" or "dot"
     * @return the matching distance function, or null for an unknown name
     */
    public static VectorDistFuncs.DistFuncEnum sliceToDistFunc(Slice distFuncStr) {
        return switch (distFuncStr.toStringUtf8()) {
            case "euc" -> VectorDistFuncs.DistFuncEnum.EUCLIDEAN_DISTANCE;
            case "cos" -> VectorDistFuncs.DistFuncEnum.COSINE_SIMILARITY;
            case "dot" -> VectorDistFuncs.DistFuncEnum.DOT_PRODUCT;
            default -> null;
        };
    }

    /**
     * Builds a PixelsColumnHandle for a vector column specified as
     * "schema.table.column".
     *
     * @throws IllegalColumnException if the input is not of that form
     */
    public static PixelsColumnHandle sliceToColumn(Slice distFuncStr) {
        String[] schemaTableCol = distFuncStr.toStringUtf8().split("\\.");
        if (schemaTableCol.length != 3) {
            throw new IllegalColumnException("column should be of form schema.table.column");
        }
        return PixelsColumnHandle.builder()
                .setConnectorId("pixels")
                .setSchemaName(schemaTableCol[0])
                .setTableName(schemaTableCol[1])
                .setColumnName(schemaTableCol[2])
                .setColumnAlias(schemaTableCol[2])
                .setColumnType(new ArrayType(DOUBLE))
                .setTypeCategory(TypeDescription.Category.VECTOR)
                .setLogicalOrdinal(0)
                .setColumnComment("")
                .build();
    }

    /** Thrown when a column specifier string is malformed. */
    public static class IllegalColumnException extends IllegalArgumentException {
        public IllegalColumnException(String msg) {
            super(msg);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.pixelsdb.pixels.trino.vector;

import io.trino.spi.block.Block;

/**
 * A distance (or similarity) function between two vectors represented as
 * double arrays, e.g. euclidean distance, dot product or cosine similarity.
 */
@FunctionalInterface
public interface VectorDistFunc {
    /**
     * Computes the distance between the two given vectors.
     *
     * @param vec1 the first vector
     * @param vec2 the second vector; assumed to have the same dimension as vec1
     * @return the distance between vec1 and vec2
     */
    Double getDist(double[] vec1, double[] vec2);
}
Loading