Completed release v0.14.1 intake
Addressed review feedback:
    - Updated PostgreSQL insertion to use an UNNEST-based upsert for bulk indexing of vector rows (an example of the generated statement appears after the pgv_table_types.rs diff below)
    - Modified 'start_postgres()' to use the 'pgvector/pgvector:pg17' Docker image and 'Mount::tmpfs_mount()' for an in-memory data volume (a sketch of the reworked helper appears after the changed-files summary below)
    - Cleaned up extra logging and tracing for streamlined output

Signed-off-by: shamb0 <r.raajey@gmail.com>
shamb0 committed Oct 30, 2024
1 parent b7aa295 commit bd0b265
Showing 4 changed files with 155 additions and 108 deletions.
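
For context on the second bullet above, here is a minimal sketch of what the reworked start_postgres() helper could look like. The swiftide-test-utils diff is not loaded on this page, so everything below is an assumption rather than the committed code: the wait condition, the POSTGRES_* environment variables, the /var/lib/postgresql/data tmpfs target, and the exact testcontainers-rs method names are illustrative only.

use testcontainers::{
    core::{IntoContainerPort, Mount, WaitFor},
    runners::AsyncRunner,
    ContainerAsync, GenericImage, ImageExt,
};

/// Starts a disposable Postgres 17 container with the pgvector extension preinstalled
/// and its data directory mounted on tmpfs, so test runs leave nothing on disk.
/// Dropping the returned container handle stops the database.
pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String) {
    let container = GenericImage::new("pgvector/pgvector", "pg17")
        // Wait until Postgres reports it is ready to accept connections.
        .with_wait_for(WaitFor::message_on_stderr(
            "database system is ready to accept connections",
        ))
        .with_env_var("POSTGRES_USER", "postgres")
        .with_env_var("POSTGRES_PASSWORD", "postgres")
        .with_env_var("POSTGRES_DB", "postgres")
        // In-memory volume for the data directory, replacing the old TempDir approach.
        .with_mount(Mount::tmpfs_mount("/var/lib/postgresql/data"))
        .start()
        .await
        .expect("failed to start pgvector container");

    let port = container
        .get_host_port_ipv4(5432.tcp())
        .await
        .expect("failed to resolve mapped Postgres port");

    let url = format!("postgres://postgres:postgres@127.0.0.1:{port}/postgres");
    (container, url)
}

The two-element return value matches the new call sites below: the container handle must stay alive for the database to keep running, and the URL is passed directly to PgVector::builder().try_connect_to_pool().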
2 changes: 1 addition & 1 deletion examples/index_md_into_pgvector.rs
@@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

tracing::info!("Test Dataset path: {:?}", test_dataset_path);

let (_pgv_db_container, pgv_db_url, _temp_dir) = swiftide_test_utils::start_postgres().await;
let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

tracing::info!("pgv_db_url :: {:#?}", pgv_db_url);

30 changes: 4 additions & 26 deletions swiftide-integrations/src/pgvector/persist.rs
@@ -13,8 +13,6 @@ use swiftide_core::{
impl Persist for PgVector {
#[tracing::instrument(skip_all)]
async fn setup(&self) -> Result<()> {
tracing::info!("Setting up table {} for PgVector", &self.table_name);

let mut tx = self.connection_pool.get_pool()?.begin().await?;

// Create extension
@@ -23,17 +21,14 @@ impl Persist for PgVector {

// Create table
let create_table_sql = self.generate_create_table_sql()?;
tracing::debug!("Executing CREATE TABLE SQL: {}", create_table_sql);
sqlx::query(&create_table_sql).execute(&mut *tx).await?;

// Create HNSW index
let index_sql = self.create_index_sql()?;
tracing::debug!("Executing CREATE INDEX SQL: {}", index_sql);
sqlx::query(&index_sql).execute(&mut *tx).await?;

tx.commit().await?;

tracing::info!("Table {} setup completed", &self.table_name);
Ok(())
}

@@ -61,51 +56,34 @@
mod tests {
use crate::pgvector::PgVector;
use swiftide_core::{indexing::EmbeddedField, Persist};
use temp_dir::TempDir;
use testcontainers::{ContainerAsync, GenericImage};

struct TestContext {
pgv_storage: PgVector,
_temp_dir: TempDir,
_pgv_db_container: ContainerAsync<GenericImage>,
}

impl TestContext {
/// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
// Start PostgreSQL container and obtain the connection URL
let (pgv_db_container, pgv_db_url, temp_dir) =
swiftide_test_utils::start_postgres().await;

tracing::info!("Postgres database URL: {:#?}", pgv_db_url);
let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

// Configure and build PgVector storage
let pgv_storage = PgVector::builder()
.try_connect_to_pool(pgv_db_url, Some(10))
.await
.map_err(|err| {
tracing::error!("Failed to connect to Postgres server: {}", err);
err
})?
.await?
.vector_size(384)
.with_vector(EmbeddedField::Combined)
.with_metadata("filter")
.table_name("swiftide_pgvector_test".to_string())
.build()
.map_err(|err| {
tracing::error!("Failed to build PgVector: {}", err);
err
})?;
.build()?;

// Set up PgVector storage (create the table if not exists)
pgv_storage.setup().await.map_err(|err| {
tracing::error!("PgVector setup failed: {}", err);
err
})?;
pgv_storage.setup().await?;

Ok(Self {
pgv_storage,
_temp_dir: temp_dir,
_pgv_db_container: pgv_db_container,
})
}
215 changes: 146 additions & 69 deletions swiftide-integrations/src/pgvector/pgv_table_types.rs
@@ -1,12 +1,13 @@
//! This module provides functionality to convert a `Node` into a `PostgreSQL` table schema.
//! This conversion is crucial for storing data in `PostgreSQL`, enabling efficient vector similarity searches
//! through the `pgvector` extension. The module also handles metadata augmentation and ensures compatibility
//! with `PostgreSQL's` required data format.
//! with `PostgreSQL`'s required data format.

use crate::pgvector::PgVector;
use anyhow::{anyhow, Context, Result};
use pgvector as ExtPgVector;
use regex::Regex;
use sqlx::postgres::PgArguments;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::collections::BTreeMap;
@@ -33,23 +34,12 @@ impl PgDBConnectionPool {
for attempt in 1..=max_retries {
match pool_options.clone().connect(database_url.as_ref()).await {
Ok(pool) => {
tracing::info!(
"Successfully connected to PostgreSQL on attempt {}/{}",
attempt,
max_retries
);
return Ok(pool);
}
Err(e) if attempt < max_retries => {
tracing::warn!(
"Connection failed, retrying attempt {}/{}: {}",
attempt,
max_retries,
e
);
Err(_err) if attempt < max_retries => {
sleep(Duration::from_secs(2)).await;
}
Err(e) => return Err(e),
Err(err) => return Err(err),
}
}
unreachable!()
@@ -155,6 +145,15 @@ impl FieldConfig {
}
}

/// Structure to hold collected values for bulk upsert
#[derive(Default)]
struct BulkUpsertData {
ids: Vec<sqlx::types::Uuid>,
chunks: Vec<String>,
metadata_fields: BTreeMap<String, Vec<serde_json::Value>>,
vector_fields: BTreeMap<String, Vec<ExtPgVector::Vector>>,
}

impl PgVector {
/// Generates a SQL statement to create a table for storing vector embeddings.
///
@@ -186,7 +185,7 @@ impl PgVector {
FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()),
FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size),
})
.chain(std::iter::once("record_id SERIAL PRIMARY KEY".to_string()))
.chain(std::iter::once("PRIMARY KEY (id)".to_string()))
.collect();

let sql = format!(
@@ -236,92 +235,170 @@ impl PgVector {
///
/// # Returns
///
/// * `Result<()>` - Ok if all nodes are successfully stored, Err otherwise.
/// * `Result<()>` - `Ok` if all nodes are successfully stored, `Err` otherwise.
///
/// # Errors
///
/// This function will return an error if:
/// - The database connection pool is not established.
/// - Any of the SQL queries fail to execute.
/// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or connectivity issues.
/// - Committing the transaction fails.
pub async fn store_nodes(&self, nodes: &[Node]) -> Result<()> {
let pool = self.connection_pool.get_pool()?;

let mut tx = pool.begin().await?;
let sql = self.generate_bulk_insert_sql(nodes.len());

let mut query = sqlx::query(&sql);
let bulk_data = self.prepare_bulk_data(nodes)?;
let sql = self.generate_unnest_upsert_sql()?;

for node in nodes {
query = self.bind_node_to_bulk_query(query, node)?;
}
let query = self.bind_bulk_data_to_query(sqlx::query(&sql), &bulk_data)?;

query.execute(&mut *tx).await.map_err(|e| {
tracing::error!("Failed to store nodes: {:?}", e);
anyhow!("Failed to store nodes: {:?}", e)
})?;
query
.execute(&mut *tx)
.await
.map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?;

tx.commit()
.await
.map_err(|e| anyhow!("Failed to commit transaction: {:?}", e))
}

/// Generates an SQL upsert statement based on the current fields and table name.
/// Prepares data from nodes into vectors for bulk processing.
#[allow(clippy::implicit_clone)]
fn prepare_bulk_data(&self, nodes: &[Node]) -> Result<BulkUpsertData> {
let mut bulk_data = BulkUpsertData::default();

for node in nodes {
bulk_data.ids.push(node.id());
bulk_data.chunks.push(node.chunk.clone());

for field in &self.fields {
match field {
FieldConfig::Metadata(config) => {
let value = node.metadata.get(&config.original_field).ok_or_else(|| {
anyhow!("Metadata field {} not found", config.original_field)
})?;

let entry = bulk_data
.metadata_fields
.entry(config.field.clone())
.or_default();

let mut metadata_map = BTreeMap::new();
metadata_map.insert(config.original_field.clone(), value.clone());
entry.push(serde_json::to_value(metadata_map)?);
}
FieldConfig::Vector(config) => {
let data = node
.vectors
.as_ref()
.and_then(|v| v.get(&config.embedded_field))
.map(|v| v.to_vec())
.unwrap_or_default();

bulk_data
.vector_fields
.entry(config.field.clone())
.or_default()
.push(ExtPgVector::Vector::from(data));
}
_ => continue, // ID and Chunk already handled
}
}
}

Ok(bulk_data)
}

/// Generates SQL for UNNEST-based bulk upsert.
///
/// # Returns
///
/// * `Result<String>` - The generated SQL statement or an error if fields are empty.
///
/// # Errors
///
/// This function constructs a SQL query that inserts new rows into a `PostgreSQL` database table
/// if they do not already exist (based on the "id" column), or updates them if they do. The generated
/// SQL is intended to be efficient and safe for concurrent use.
#[allow(clippy::redundant_closure_for_method_calls)]
/// Generates a bulk insert SQL statement for inserting multiple nodes.
fn generate_bulk_insert_sql(&self, node_count: usize) -> String {
let columns: Vec<&str> = self.fields.iter().map(|field| field.field_name()).collect();
let placeholders: Vec<String> = (1..=node_count)
.flat_map(|i| {
self.fields
.iter()
.enumerate()
.map(move |(j, _)| format!("${}", (i - 1) * self.fields.len() + j + 1))
/// Returns an error if `self.fields` is empty, as no valid SQL can be generated.
fn generate_unnest_upsert_sql(&self) -> Result<String> {
if self.fields.is_empty() {
return Err(anyhow!("Cannot generate upsert SQL with empty fields"));
}

let mut columns = Vec::new();
let mut unnest_params = Vec::new();
let mut param_counter = 1;

for field in &self.fields {
let name = field.field_name();
columns.push(name.to_string());

unnest_params.push(format!(
"${param_counter}::{}",
match field {
FieldConfig::ID => "UUID[]",
FieldConfig::Chunk => "TEXT[]",
FieldConfig::Metadata(_) => "JSONB[]",
FieldConfig::Vector(_) => "VECTOR[]",
}
));

param_counter += 1;
}

let update_columns = self
.fields
.iter()
.filter(|field| !matches!(field, FieldConfig::ID)) // Skip ID field in updates
.map(|field| {
let name = field.field_name();
format!("{name} = EXCLUDED.{name}")
})
.collect();
.collect::<Vec<_>>()
.join(", ");

format!(
"INSERT INTO {} ({}) VALUES {}",
Ok(format!(
r#"
INSERT INTO {} ({})
SELECT {}
FROM UNNEST({}) AS t({})
ON CONFLICT (id) DO UPDATE SET {}"#,
self.table_name,
columns.join(", "),
placeholders
.chunks(self.fields.len())
.map(|chunk| format!("({})", chunk.join(", ")))
.collect::<Vec<_>>()
.join(", ")
)
columns.join(", "),
unnest_params.join(", "),
columns.join(", "),
update_columns
))
}

/// Binds bulk data to the SQL query, ensuring data arrays are matched to corresponding fields.
///
/// # Errors
///
/// Returns an error if any metadata or vector field is missing from the bulk data.
#[allow(clippy::implicit_clone)]
fn bind_node_to_bulk_query<'a>(
fn bind_bulk_data_to_query<'a>(
&self,
mut query: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
node: &'a Node,
) -> Result<sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>> {
mut query: sqlx::query::Query<'a, sqlx::Postgres, PgArguments>,
bulk_data: &'a BulkUpsertData,
) -> Result<sqlx::query::Query<'a, sqlx::Postgres, PgArguments>> {
for field in &self.fields {
query = match field {
FieldConfig::ID => query.bind(node.id()),
FieldConfig::Chunk => query.bind(&node.chunk),
FieldConfig::ID => query.bind(&bulk_data.ids),
FieldConfig::Chunk => query.bind(&bulk_data.chunks),
FieldConfig::Metadata(config) => {
let value = node.metadata.get(&config.original_field).ok_or_else(|| {
anyhow!("Metadata field {} not found", config.original_field)
})?;
let mut metadata_map = BTreeMap::new();
metadata_map.insert(config.original_field.clone(), value.clone());
query.bind(serde_json::to_value(metadata_map)?)
let values = bulk_data
.metadata_fields
.get(&config.field)
.ok_or_else(|| {
anyhow!("Metadata field {} not found in bulk data", config.field)
})?;
query.bind(values)
}
FieldConfig::Vector(config) => {
let data = node
.vectors
.as_ref()
.and_then(|v| v.get(&config.embedded_field))
.map(|v| v.to_vec())
.unwrap_or_default();
query.bind(ExtPgVector::Vector::from(data))
let vectors = bulk_data.vector_fields.get(&config.field).ok_or_else(|| {
anyhow!("Vector field {} not found in bulk data", config.field)
})?;
query.bind(vectors)
}
};
}
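
To make the UNNEST-based upsert concrete: for a hypothetical table swiftide_pgvector_test with columns id, chunk, meta_filter, and vector_combined (the last two column names are illustrative, not taken from this diff), generate_unnest_upsert_sql() yields a statement shaped like the one below, and prepare_bulk_data() supplies one array per column so a single round trip upserts the entire batch:

-- $1..$4 are bound to Vec<Uuid>, Vec<String>, Vec<serde_json::Value>, and
-- Vec<pgvector::Vector>, one element per node, as collected in BulkUpsertData.
INSERT INTO swiftide_pgvector_test (id, chunk, meta_filter, vector_combined)
SELECT id, chunk, meta_filter, vector_combined
FROM UNNEST($1::UUID[], $2::TEXT[], $3::JSONB[], $4::VECTOR[])
    AS t(id, chunk, meta_filter, vector_combined)
ON CONFLICT (id) DO UPDATE SET
    chunk = EXCLUDED.chunk,
    meta_filter = EXCLUDED.meta_filter,
    vector_combined = EXCLUDED.vector_combined;

Rows whose id already exists are updated in place, which is what lets store_nodes() re-index the same documents without violating the new PRIMARY KEY (id) constraint.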
