diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs
index 77ace355..a001eeac 100644
--- a/examples/index_md_into_pgvector.rs
+++ b/examples/index_md_into_pgvector.rs
@@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     tracing::info!("Test Dataset path: {:?}", test_dataset_path);
 
-    let (_pgv_db_container, pgv_db_url, _temp_dir) = swiftide_test_utils::start_postgres().await;
+    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
 
     tracing::info!("pgv_db_url :: {:#?}", pgv_db_url);
 
diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs
index 01e51235..c0706576 100644
--- a/swiftide-integrations/src/pgvector/persist.rs
+++ b/swiftide-integrations/src/pgvector/persist.rs
@@ -13,8 +13,6 @@ use swiftide_core::{
 impl Persist for PgVector {
     #[tracing::instrument(skip_all)]
     async fn setup(&self) -> Result<()> {
-        tracing::info!("Setting up table {} for PgVector", &self.table_name);
-
         let mut tx = self.connection_pool.get_pool()?.begin().await?;
 
         // Create extension
@@ -23,17 +21,14 @@ impl Persist for PgVector {
 
         // Create table
         let create_table_sql = self.generate_create_table_sql()?;
-        tracing::debug!("Executing CREATE TABLE SQL: {}", create_table_sql);
         sqlx::query(&create_table_sql).execute(&mut *tx).await?;
 
         // Create HNSW index
         let index_sql = self.create_index_sql()?;
-        tracing::debug!("Executing CREATE INDEX SQL: {}", index_sql);
         sqlx::query(&index_sql).execute(&mut *tx).await?;
 
         tx.commit().await?;
-        tracing::info!("Table {} setup completed", &self.table_name);
 
         Ok(())
     }
 
@@ -61,12 +56,10 @@ impl Persist for PgVector {
 mod tests {
     use crate::pgvector::PgVector;
     use swiftide_core::{indexing::EmbeddedField, Persist};
-    use temp_dir::TempDir;
    use testcontainers::{ContainerAsync, GenericImage};
 
     struct TestContext {
         pgv_storage: PgVector,
-        _temp_dir: TempDir,
         _pgv_db_container: ContainerAsync<GenericImage>,
     }
 
@@ -74,38 +67,23 @@ mod tests {
     /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
     async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
         // Start PostgreSQL container and obtain the connection URL
-        let (pgv_db_container, pgv_db_url, temp_dir) =
-            swiftide_test_utils::start_postgres().await;
-
-        tracing::info!("Postgres database URL: {:#?}", pgv_db_url);
+        let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
 
         // Configure and build PgVector storage
         let pgv_storage = PgVector::builder()
             .try_connect_to_pool(pgv_db_url, Some(10))
-            .await
-            .map_err(|err| {
-                tracing::error!("Failed to connect to Postgres server: {}", err);
-                err
-            })?
+            .await?
             .vector_size(384)
             .with_vector(EmbeddedField::Combined)
             .with_metadata("filter")
             .table_name("swiftide_pgvector_test".to_string())
-            .build()
-            .map_err(|err| {
-                tracing::error!("Failed to build PgVector: {}", err);
-                err
-            })?;
+            .build()?;
 
         // Set up PgVector storage (create the table if not exists)
-        pgv_storage.setup().await.map_err(|err| {
-            tracing::error!("PgVector setup failed: {}", err);
-            err
-        })?;
+        pgv_storage.setup().await?;
 
         Ok(Self {
             pgv_storage,
-            _temp_dir: temp_dir,
             _pgv_db_container: pgv_db_container,
         })
     }
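Note on the removed logging: `setup()` keeps its `#[tracing::instrument(skip_all)]` span, so the call is still traceable; only the per-statement `info!`/`debug!` lines are gone. If the exact DDL is needed while debugging, sqlx can emit it itself. A minimal sketch, assuming sqlx with the `postgres` feature plus the `log` crate; the helper name is illustrative, not part of this PR:

```rust
use sqlx::{
    postgres::{PgConnectOptions, PgPoolOptions},
    ConnectOptions, PgPool,
};

/// Hypothetical helper: connect with statement-level logging enabled, so the
/// CREATE EXTENSION / CREATE TABLE / CREATE INDEX statements run by `setup()`
/// show up under `RUST_LOG=sqlx=debug` without any tracing calls in swiftide.
async fn connect_with_statement_logging(database_url: &str) -> Result<PgPool, sqlx::Error> {
    let options = database_url
        .parse::<PgConnectOptions>()?
        // Log every executed statement at DEBUG level.
        .log_statements(log::LevelFilter::Debug);

    PgPoolOptions::new().connect_with(options).await
}
```
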
diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs
index afd045c0..3598d873 100644
--- a/swiftide-integrations/src/pgvector/pgv_table_types.rs
+++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs
@@ -1,12 +1,13 @@
 //! This module provides functionality to convert a `Node` into a `PostgreSQL` table schema.
 //!
 //! This conversion is crucial for storing data in `PostgreSQL`, enabling efficient vector similarity searches
 //! through the `pgvector` extension. The module also handles metadata augmentation and ensures compatibility
-//! with `PostgreSQL's` required data format.
+//! with `PostgreSQL`'s required data format.
 use crate::pgvector::PgVector;
 use anyhow::{anyhow, Context, Result};
 use pgvector as ExtPgVector;
 use regex::Regex;
+use sqlx::postgres::PgArguments;
 use sqlx::postgres::PgPoolOptions;
 use sqlx::PgPool;
 use std::collections::BTreeMap;
@@ -33,23 +34,12 @@ impl PgDBConnectionPool {
         for attempt in 1..=max_retries {
             match pool_options.clone().connect(database_url.as_ref()).await {
                 Ok(pool) => {
-                    tracing::info!(
-                        "Successfully connected to PostgreSQL on attempt {}/{}",
-                        attempt,
-                        max_retries
-                    );
                     return Ok(pool);
                 }
-                Err(e) if attempt < max_retries => {
-                    tracing::warn!(
-                        "Connection failed, retrying attempt {}/{}: {}",
-                        attempt,
-                        max_retries,
-                        e
-                    );
+                Err(_err) if attempt < max_retries => {
                     sleep(Duration::from_secs(2)).await;
                 }
-                Err(e) => return Err(e),
+                Err(err) => return Err(err),
             }
         }
         unreachable!()
@@ -155,6 +145,15 @@ impl FieldConfig {
     }
 }
 
+/// Structure to hold collected values for bulk upsert
+#[derive(Default)]
+struct BulkUpsertData {
+    ids: Vec<Uuid>,
+    chunks: Vec<String>,
+    metadata_fields: BTreeMap<String, Vec<serde_json::Value>>,
+    vector_fields: BTreeMap<String, Vec<ExtPgVector::Vector>>,
+}
+
 impl PgVector {
     /// Generates a SQL statement to create a table for storing vector embeddings.
     ///
@@ -186,7 +185,7 @@
             FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()),
             FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size),
         })
-        .chain(std::iter::once("record_id SERIAL PRIMARY KEY".to_string()))
+        .chain(std::iter::once("PRIMARY KEY (id)".to_string()))
         .collect();
 
         let sql = format!(
@@ -236,92 +235,170 @@
     ///
     /// # Returns
     ///
-    /// * `Result<()>` - Ok if all nodes are successfully stored, Err otherwise.
+    /// * `Result<()>` - `Ok` if all nodes are successfully stored, `Err` otherwise.
     ///
     /// # Errors
     ///
     /// This function will return an error if:
     /// - The database connection pool is not established.
-    /// - Any of the SQL queries fail to execute.
+    /// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or connectivity issues.
     /// - Committing the transaction fails.
     pub async fn store_nodes(&self, nodes: &[Node]) -> Result<()> {
         let pool = self.connection_pool.get_pool()?;
         let mut tx = pool.begin().await?;
 
-        let sql = self.generate_bulk_insert_sql(nodes.len());
-
-        let mut query = sqlx::query(&sql);
+        let bulk_data = self.prepare_bulk_data(nodes)?;
+        let sql = self.generate_unnest_upsert_sql()?;
 
-        for node in nodes {
-            query = self.bind_node_to_bulk_query(query, node)?;
-        }
+        let query = self.bind_bulk_data_to_query(sqlx::query(&sql), &bulk_data)?;
 
-        query.execute(&mut *tx).await.map_err(|e| {
-            tracing::error!("Failed to store nodes: {:?}", e);
-            anyhow!("Failed to store nodes: {:?}", e)
-        })?;
+        query
+            .execute(&mut *tx)
+            .await
+            .map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?;
 
         tx.commit()
             .await
             .map_err(|e| anyhow!("Failed to commit transaction: {:?}", e))
     }
 
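One practical consequence of the UNNEST rewrite worth spelling out: the old statement needed one bind parameter per cell, while the UNNEST form binds one array per column, so batch size no longer affects the parameter count. A back-of-the-envelope sketch (the 65535 figure is the bind-parameter cap of the Postgres extended query protocol, which encodes the count as a 16-bit integer; it is not something introduced by this PR):

```rust
/// Sketch: bind-parameter counts for the two strategies.
fn bind_parameter_count(rows: usize, cols: usize, use_unnest: bool) -> usize {
    if use_unnest {
        cols // one array parameter per column, independent of batch size
    } else {
        rows * cols // one parameter per cell, as in the old generate_bulk_insert_sql
    }
}

fn main() {
    // Four fields (id, chunk, metadata, vector) over a 20k-node batch:
    assert_eq!(bind_parameter_count(20_000, 4, false), 80_000); // over the 65535 cap
    assert_eq!(bind_parameter_count(20_000, 4, true), 4); // constant, always fine
}
```
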
-    /// Generates an SQL upsert statement based on the current fields and table name.
+    /// Prepares data from nodes into vectors for bulk processing.
+    #[allow(clippy::implicit_clone)]
+    fn prepare_bulk_data(&self, nodes: &[Node]) -> Result<BulkUpsertData> {
+        let mut bulk_data = BulkUpsertData::default();
+
+        for node in nodes {
+            bulk_data.ids.push(node.id());
+            bulk_data.chunks.push(node.chunk.clone());
+
+            for field in &self.fields {
+                match field {
+                    FieldConfig::Metadata(config) => {
+                        let value = node.metadata.get(&config.original_field).ok_or_else(|| {
+                            anyhow!("Metadata field {} not found", config.original_field)
+                        })?;
+
+                        let entry = bulk_data
+                            .metadata_fields
+                            .entry(config.field.clone())
+                            .or_default();
+
+                        let mut metadata_map = BTreeMap::new();
+                        metadata_map.insert(config.original_field.clone(), value.clone());
+                        entry.push(serde_json::to_value(metadata_map)?);
+                    }
+                    FieldConfig::Vector(config) => {
+                        let data = node
+                            .vectors
+                            .as_ref()
+                            .and_then(|v| v.get(&config.embedded_field))
+                            .map(|v| v.to_vec())
+                            .unwrap_or_default();
+
+                        bulk_data
+                            .vector_fields
+                            .entry(config.field.clone())
+                            .or_default()
+                            .push(ExtPgVector::Vector::from(data));
+                    }
+                    _ => continue, // ID and Chunk already handled
+                }
+            }
+        }
+
+        Ok(bulk_data)
+    }
+
+    /// Generates SQL for UNNEST-based bulk upsert.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<String>` - The generated SQL statement or an error if fields are empty.
+    ///
+    /// # Errors
     ///
-    /// This function constructs a SQL query that inserts new rows into a `PostgreSQL` database table
-    /// if they do not already exist (based on the "id" column), or updates them if they do. The generated
-    /// SQL is intended to be efficient and safe for concurrent use.
-    #[allow(clippy::redundant_closure_for_method_calls)]
-    /// Generates a bulk insert SQL statement for inserting multiple nodes.
-    fn generate_bulk_insert_sql(&self, node_count: usize) -> String {
-        let columns: Vec<&str> = self.fields.iter().map(|field| field.field_name()).collect();
-
-        let placeholders: Vec<String> = (1..=node_count)
-            .flat_map(|i| {
-                self.fields
-                    .iter()
-                    .enumerate()
-                    .map(move |(j, _)| format!("${}", (i - 1) * self.fields.len() + j + 1))
+    /// Returns an error if `self.fields` is empty, as no valid SQL can be generated.
+    fn generate_unnest_upsert_sql(&self) -> Result<String> {
+        if self.fields.is_empty() {
+            return Err(anyhow!("Cannot generate upsert SQL with empty fields"));
+        }
+
+        let mut columns = Vec::new();
+        let mut unnest_params = Vec::new();
+        let mut param_counter = 1;
+
+        for field in &self.fields {
+            let name = field.field_name();
+            columns.push(name.to_string());
+
+            unnest_params.push(format!(
+                "${param_counter}::{}",
+                match field {
+                    FieldConfig::ID => "UUID[]",
+                    FieldConfig::Chunk => "TEXT[]",
+                    FieldConfig::Metadata(_) => "JSONB[]",
+                    FieldConfig::Vector(_) => "VECTOR[]",
+                }
+            ));
+
+            param_counter += 1;
+        }
+
+        let update_columns = self
+            .fields
+            .iter()
+            .filter(|field| !matches!(field, FieldConfig::ID)) // Skip ID field in updates
+            .map(|field| {
+                let name = field.field_name();
+                format!("{name} = EXCLUDED.{name}")
             })
-            .collect();
+            .collect::<Vec<_>>()
+            .join(", ");
 
-        format!(
-            "INSERT INTO {} ({}) VALUES {}",
+        Ok(format!(
+            r#"
+            INSERT INTO {} ({})
+            SELECT {}
+            FROM UNNEST({}) AS t({})
+            ON CONFLICT (id) DO UPDATE SET {}"#,
             self.table_name,
             columns.join(", "),
-            placeholders
-                .chunks(self.fields.len())
-                .map(|chunk| format!("({})", chunk.join(", ")))
-                .collect::<Vec<_>>()
-                .join(", ")
-        )
+            columns.join(", "),
+            unnest_params.join(", "),
+            columns.join(", "),
+            update_columns
+        ))
     }
 
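For a concrete picture, this is roughly what the generator yields for the fixture used in the tests above (one metadata field, one combined vector). The `meta_filter` and `vector_combined` column names are assumptions about what `field_name()` returns for those fields, shown only to make the UNNEST shape tangible:

```rust
// Illustrative output only; real column names come from field_name().
const EXAMPLE_UPSERT_SQL: &str = r#"
    INSERT INTO swiftide_pgvector_test (id, chunk, meta_filter, vector_combined)
    SELECT id, chunk, meta_filter, vector_combined
    FROM UNNEST($1::UUID[], $2::TEXT[], $3::JSONB[], $4::VECTOR[])
        AS t(id, chunk, meta_filter, vector_combined)
    ON CONFLICT (id) DO UPDATE SET chunk = EXCLUDED.chunk,
        meta_filter = EXCLUDED.meta_filter,
        vector_combined = EXCLUDED.vector_combined"#;
```

`UNNEST` zips the arrays positionally into rows, and the `ON CONFLICT (id)` clause is what the new `PRIMARY KEY (id)` in `generate_create_table_sql` exists to support: the previous `record_id SERIAL` key could never conflict, so re-indexing the same nodes would add duplicate rows rather than update them.
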
+    /// Binds bulk data to the SQL query, ensuring data arrays are matched to corresponding fields.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if any metadata or vector field is missing from the bulk data.
     #[allow(clippy::implicit_clone)]
-    fn bind_node_to_bulk_query<'a>(
+    fn bind_bulk_data_to_query<'a>(
         &self,
-        mut query: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
-        node: &'a Node,
-    ) -> Result<sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>> {
+        mut query: sqlx::query::Query<'a, sqlx::Postgres, PgArguments>,
+        bulk_data: &'a BulkUpsertData,
+    ) -> Result<sqlx::query::Query<'a, sqlx::Postgres, PgArguments>> {
         for field in &self.fields {
             query = match field {
-                FieldConfig::ID => query.bind(node.id()),
-                FieldConfig::Chunk => query.bind(&node.chunk),
+                FieldConfig::ID => query.bind(&bulk_data.ids),
+                FieldConfig::Chunk => query.bind(&bulk_data.chunks),
                 FieldConfig::Metadata(config) => {
-                    let value = node.metadata.get(&config.original_field).ok_or_else(|| {
-                        anyhow!("Metadata field {} not found", config.original_field)
-                    })?;
-                    let mut metadata_map = BTreeMap::new();
-                    metadata_map.insert(config.original_field.clone(), value.clone());
-                    query.bind(serde_json::to_value(metadata_map)?)
+                    let values = bulk_data
+                        .metadata_fields
+                        .get(&config.field)
+                        .ok_or_else(|| {
+                            anyhow!("Metadata field {} not found in bulk data", config.field)
+                        })?;
+                    query.bind(values)
                 }
                 FieldConfig::Vector(config) => {
-                    let data = node
-                        .vectors
-                        .as_ref()
-                        .and_then(|v| v.get(&config.embedded_field))
-                        .map(|v| v.to_vec())
-                        .unwrap_or_default();
-                    query.bind(ExtPgVector::Vector::from(data))
+                    let vectors = bulk_data.vector_fields.get(&config.field).ok_or_else(|| {
+                        anyhow!("Vector field {} not found in bulk data", config.field)
+                    })?;
+                    query.bind(vectors)
                 }
             };
         }
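Taken together, the new path is: `prepare_bulk_data` pivots nodes into column-major arrays, `generate_unnest_upsert_sql` emits one fixed-arity statement, and `bind_bulk_data_to_query` binds each array (sqlx maps `Vec<T>` to a Postgres array, and the `::TYPE[]` casts in the SQL pin the element types). A usage sketch against the public API, reusing the builder configuration from the test above; `nodes` is any embedded batch, and error conversion via `anyhow` is assumed:

```rust
use swiftide_core::{
    indexing::{EmbeddedField, Node},
    Persist,
};
use swiftide_integrations::pgvector::PgVector;

async fn upsert_twice(pgv_db_url: String, nodes: &[Node]) -> anyhow::Result<()> {
    let storage = PgVector::builder()
        .try_connect_to_pool(pgv_db_url, Some(10))
        .await?
        .vector_size(384)
        .with_vector(EmbeddedField::Combined)
        .with_metadata("filter")
        .table_name("swiftide_pgvector_test".to_string())
        .build()?;

    storage.setup().await?; // CREATE TABLE ... PRIMARY KEY (id), plus HNSW index
    storage.store_nodes(nodes).await?; // first call inserts every node
    storage.store_nodes(nodes).await?; // second call takes the ON CONFLICT path and updates in place
    Ok(())
}
```
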
diff --git a/swiftide-test-utils/src/test_utils.rs b/swiftide-test-utils/src/test_utils.rs
index 54a43caa..b95d670c 100644
--- a/swiftide-test-utils/src/test_utils.rs
+++ b/swiftide-test-utils/src/test_utils.rs
@@ -11,7 +11,6 @@ use wiremock::matchers::{method, path};
 use wiremock::{Mock, MockServer, ResponseTemplate};
 
 use swiftide_integrations as integrations;
-use temp_dir::TempDir;
 
 pub fn openai_client(
     mock_server_uri: &str,
@@ -75,15 +74,11 @@ pub async fn start_redis() -> (ContainerAsync<GenericImage>, String) {
 
 /// Setup Postgres container.
 /// Returns container server and `server_url`.
-pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String, TempDir) {
-    // Create a temporary directory for Postgres data
-    let temp_dir = TempDir::new().expect("Failed to create temp folder");
-    let temp_data_bind_path = temp_dir.path().to_str().unwrap();
-
+pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String) {
    // Find a free port on the host for Postgres to use
    let host_port = portpicker::pick_unused_port().expect("No available free port on the host");
 
-    let postgres = testcontainers::GenericImage::new("ankane/pgvector", "v0.5.1")
+    let postgres = testcontainers::GenericImage::new("pgvector/pgvector", "pg17")
         .with_wait_for(WaitFor::message_on_stdout(
             "database system is ready to accept connections",
         ))
@@ -91,10 +86,7 @@ pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String, TempDir
         .with_env_var("POSTGRES_USER", "myuser")
         .with_env_var("POSTGRES_PASSWORD", "mypassword")
         .with_env_var("POSTGRES_DB", "mydatabase")
-        .with_mount(Mount::bind_mount(
-            temp_data_bind_path,
-            "/var/lib/postgresql/data",
-        ))
+        .with_mount(Mount::tmpfs_mount("/var/lib/postgresql/data"))
         .start()
         .await
         .expect("Failed to start Postgres container");
@@ -105,7 +97,7 @@ pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String, TempDir
         host_port
     );
 
-    (postgres, postgres_url, temp_dir)
+    (postgres, postgres_url)
 }
 
 /// Mock embeddings creation endpoint.
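Finally, a sketch of what a consumer of the updated helper looks like. The container handle must still be kept alive for the duration of the test (dropping it stops and removes the container); what changed is that Postgres data now lives in an in-container tmpfs, so nothing is written to the host and no `TempDir` has to outlive the test. Assumes a tokio test runtime; the test name is illustrative:

```rust
#[tokio::test]
async fn postgres_fixture_smoke_test() {
    // Two-element tuple now: container handle plus connection URL.
    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

    // The URL targets the pgvector/pgvector:pg17 container started above;
    // keep `_pgv_db_container` bound so it is not dropped (and stopped) early.
    assert!(!pgv_db_url.is_empty());
}
```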