Skip to content

Commit

Permalink
Update CassIO tables
Browse files Browse the repository at this point in the history
  • Loading branch information
clun committed Jan 25, 2024
1 parent 870fb1f commit fc02782
Show file tree
Hide file tree
Showing 14 changed files with 337 additions and 191 deletions.
21 changes: 15 additions & 6 deletions astra-db-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
<properties>
<openai-java.version>0.18.2</openai-java.version>
<langchain4j.version>0.25.0</langchain4j.version>
<okhttp.version>4.10.0</okhttp.version>
<okhttp.version>4.12.0</okhttp.version>
<okio-jvm.version>3.7.0</okio-jvm.version>
</properties>

<dependencies>
Expand All @@ -23,6 +24,18 @@
<artifactId>slf4j-api</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>${okhttp.version}</version>
</dependency>

<dependency>
<groupId>com.squareup.okio</groupId>
<artifactId>okio-jvm</artifactId>
<version>${okio-jvm.version}</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
Expand Down Expand Up @@ -58,11 +71,7 @@
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>${okhttp.version}</version>
</dependency>

<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;
Expand All @@ -11,7 +10,7 @@
* Wrap query parameters as a Bean.
*/
@Data @Builder
public class SimilaritySearchQuery {
public class AnnQuery {

/**
* Maximum number of item returned
Expand All @@ -31,7 +30,7 @@ public class SimilaritySearchQuery {
/**
* Default distance is cosine
*/
private SimilarityMetric distance = SimilarityMetric.DOT_PRODUCT;
private CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;

/**
* If provided search on metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* record.
*/
@Data
public class SimilaritySearchResult<EMBEDDED> {
public class AnnResult<EMBEDDED> {

/**
* Embedded object
Expand All @@ -24,6 +24,6 @@ public class SimilaritySearchResult<EMBEDDED> {
/**
* Default constructor.
*/
public SimilaritySearchResult() {}
public AnnResult() {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import com.dtsx.astra.sdk.db.AstraDBOpsClient;
import com.dtsx.astra.sdk.utils.AstraEnvironment;
import io.stargate.sdk.utils.Utils;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
Expand Down Expand Up @@ -151,10 +150,10 @@ public static synchronized CqlSession init(String token, UUID dbId, String dbReg
* @return
* table to store vector
*/
public static MetadataVectorCassandraTable metadataVectorTable(String tableName, int vectorDimension) {
public static MetadataVectorTable metadataVectorTable(String tableName, int vectorDimension) {
if (tableName == null || tableName.isEmpty()) throw new IllegalArgumentException("Table name must be provided");
if (vectorDimension < 1) throw new IllegalArgumentException("Vector dimension must be greater than 0");
return new MetadataVectorCassandraTable(
return new MetadataVectorTable(
getCqlSession(),
cqlSession.getKeyspace().orElseThrow(() ->
new IllegalArgumentException("CqlSession does not select any keyspace")).asInternal(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
* Option for the similarity metric.
*/
@Getter
public enum SimilarityMetric {
public enum CassandraSimilarityMetric {

/** dot product. */
DOT_PRODUCT("DOT_PRODUCT","similarity_dot_product"),

/** cosine. */
COS("COSINE","similarity_cosine"),
COSINE("COSINE","similarity_cosine"),

/** euclidean. */
DOT("EUCLIDEAN","similarity_euclidean");
EUCLIDEAN("EUCLIDEAN","similarity_euclidean");

/**
* Option.
Expand All @@ -35,7 +35,7 @@ public enum SimilarityMetric {
* @param function
* function to be used in the query
*/
SimilarityMetric(String option, String function) {
CassandraSimilarityMetric(String option, String function) {
this.option = option;
this.function = function;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.dtsx.astra.sdk.cassio;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class ClusteredMetadataVectorRecord {

/** Partition id. */
String partitionId;

/** Row identifier. */
UUID rowId;

/** Text body. */
String body;

/**
* Store special attributes
*/
String attributes;

/**
* Metadata (for metadata filtering)
*/
Map<String, String> metadata = new HashMap<>();

/**
* Embeddings
*/
List<Float> vector;

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
* - column: value
*/
@Slf4j
public class ClusteredMetadataVectorCassandraTable
extends AbstractCassandraTable<ClusteredMetadataVectorCassandraTable.Record> {
public class ClusteredMetadataVectorTable extends AbstractCassandraTable<ClusteredMetadataVectorRecord> {

/**
* Dimension of the vector in use
Expand All @@ -24,7 +23,7 @@ public class ClusteredMetadataVectorCassandraTable
/**
* Similarity Metric, Vector is indexed with this metric.
*/
private final SimilarityMetric similarityMetric;
private final CassandraSimilarityMetric similarityMetric;

/**
* Prepared statements
Expand All @@ -44,28 +43,28 @@ public class ClusteredMetadataVectorCassandraTable
* @param vectorDimension vector dimension
* @param metric similarity metric
*/
public ClusteredMetadataVectorCassandraTable(
public ClusteredMetadataVectorTable(
@NonNull CqlSession session,
@NonNull String keyspaceName,
@NonNull String tableName,
@NonNull Integer vectorDimension,
@NonNull SimilarityMetric metric) {
@NonNull CassandraSimilarityMetric metric) {
super(session, keyspaceName, tableName);
this.vectorDimension = vectorDimension;
this.similarityMetric = metric;
}

/**
* Builder class for creating instances of {@link ClusteredMetadataVectorCassandraTable}.
* Builder class for creating instances of {@link ClusteredMetadataVectorTable}.
* This class follows the builder pattern to allow setting various parameters
* before creating an instance of {@link ClusteredMetadataVectorCassandraTable}.
* before creating an instance of {@link ClusteredMetadataVectorTable}.
*/
public static class Builder {
private CqlSession session;
private String keyspaceName;
private String tableName;
private Integer vectorDimension;
private SimilarityMetric metric = SimilarityMetric.COS;
private CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;

/**
* Sets the CqlSession.
Expand Down Expand Up @@ -117,7 +116,7 @@ public Builder withVectorDimension(Integer vectorDimension) {
* @param metric The SimilarityMetric to be used.
* @return The current Builder instance for chaining.
*/
public Builder withMetric(SimilarityMetric metric) {
public Builder withMetric(CassandraSimilarityMetric metric) {
this.metric = metric;
return this;
}
Expand All @@ -127,8 +126,8 @@ public Builder withMetric(SimilarityMetric metric) {
*
* @return A new instance of ClusteredMetadataVectorCassandraTable.
*/
public ClusteredMetadataVectorCassandraTable build() {
return new ClusteredMetadataVectorCassandraTable(session, keyspaceName, tableName, vectorDimension, metric);
public ClusteredMetadataVectorTable build() {
return new ClusteredMetadataVectorTable(session, keyspaceName, tableName, vectorDimension, metric);
}

/**
Expand All @@ -143,7 +142,7 @@ public Builder() {}
* @return
* builder for the class
*/
public static ClusteredMetadataVectorCassandraTable.Builder builder() {
public static ClusteredMetadataVectorTable.Builder builder() {
return new Builder();
}

Expand Down Expand Up @@ -203,13 +202,13 @@ public void create() {

/* {@inheritDoc} */
@Override
public void put(Record row) {
public void put(ClusteredMetadataVectorRecord row) {

}

/* {@inheritDoc} */
@Override
public Record mapRow(Row row) {
public ClusteredMetadataVectorRecord mapRow(Row row) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.dtsx.astra.sdk.cassio;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.UUID;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class ClusteredRecord {

/** Partition id. */
String partitionId;

/** Row identifier. */
UUID rowId;

/** Text body. */
String body;

}
Loading

0 comments on commit fc02782

Please sign in to comment.