diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e0b0467..b601642d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,73 +1,143 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -## [0.14.1](https://github.com/bosun-ai/swiftide/compare/v0.14.0...v0.14.1) - 2024-10-27 - -### Bug fixes - -- [5bbcd55](https://github.com/bosun-ai/swiftide/commit/5bbcd55de65d73d7908e91c96f120928edb6b388) Revert 0.14 release as mistralrs is unpublished ([#417](https://github.com/bosun-ai/swiftide/pull/417)) - -````text -Revert the 0.14 release as `mistralrs` is unpublished and unfortunately - cannot be released. -```` - -### Miscellaneous - -- [07c2661](https://github.com/bosun-ai/swiftide/commit/07c2661b7a7cdf75cdba12fab0ca91866793f727) Re-release 0.14 without mistralrs ([#419](https://github.com/bosun-ai/swiftide/pull/419)) - -````text -- **Revert "fix: Revert 0.14 release as mistralrs is unpublished - ([#417](https://github.com/bosun-ai/swiftide/pull/417))"** - - **Fix changelog** -```` - - -**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.0...0.14.1 - - - -## [0.14.0](https://github.com/bosun-ai/swiftide/compare/v0.13.4...v0.14.0) - 2024-10-27 - -### Bug fixes - -- [551a9cb](https://github.com/bosun-ai/swiftide/commit/551a9cb769293e42e15bae5dca3ab677be0ee8ea) *(indexing)* [**breaking**] Node ID no longer memoized ([#414](https://github.com/bosun-ai/swiftide/pull/414)) - -````text -As @shamb0 pointed out in [#392](https://github.com/bosun-ai/swiftide/pull/392), there is a potential issue where Node - ids are get cached before chunking or other transformations, breaking - upserts and potentially resulting in data loss. -```` - +# Changelog + +All notable changes to this project will be documented in this file. + +## [0.14.1](https://github.com/bosun-ai/swiftide/compare/v0.14.0...v0.14.1) - 2024-10-27 + +### Bug fixes + +- [5bbcd55](https://github.com/bosun-ai/swiftide/commit/5bbcd55de65d73d7908e91c96f120928edb6b388) Revert 0.14 release as mistralrs is unpublished ([#417](https://github.com/bosun-ai/swiftide/pull/417)) + +````text +Revert the 0.14 release as `mistralrs` is unpublished and unfortunately + cannot be released. +```` + +### Miscellaneous + +- [07c2661](https://github.com/bosun-ai/swiftide/commit/07c2661b7a7cdf75cdba12fab0ca91866793f727) Re-release 0.14 without mistralrs ([#419](https://github.com/bosun-ai/swiftide/pull/419)) + +````text +- **Revert "fix: Revert 0.14 release as mistralrs is unpublished + ([#417](https://github.com/bosun-ai/swiftide/pull/417))"** + - **Fix changelog** +```` + + +**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.0...0.14.1 + + + +## [0.14.0](https://github.com/bosun-ai/swiftide/compare/v0.13.4...v0.14.0) - 2024-10-27 + +### Bug fixes + +- [551a9cb](https://github.com/bosun-ai/swiftide/commit/551a9cb769293e42e15bae5dca3ab677be0ee8ea) *(indexing)* [**breaking**] Node ID no longer memoized ([#414](https://github.com/bosun-ai/swiftide/pull/414)) + +````text +As @shamb0 pointed out in [#392](https://github.com/bosun-ai/swiftide/pull/392), there is a potential issue where Node + ids are get cached before chunking or other transformations, breaking + upserts and potentially resulting in data loss. +```` + +**BREAKING CHANGE**: This PR reworks Nodes with a builder API and a private +id. Hence, manually creating nodes no longer works. In the future, all +the fields are likely to follow the same pattern, so that we can +decouple the inner fields from the Node's implementation. 
+
+- [c091ffa](https://github.com/bosun-ai/swiftide/commit/c091ffa6be792b0bd7bb03d604e26e40b2adfda8) *(indexing)* Use atomics for key generation in memory storage ([#415](https://github.com/bosun-ai/swiftide/pull/415))
+
+### Miscellaneous
+
+- [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies
+
+
+**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.4...0.14.0
+
+
+
+## [0.13.4](https://github.com/bosun-ai/swiftide/compare/v0.13.3...v0.13.4) - 2024-10-21
+
+### Bug fixes
+
+- [47455fb](https://github.com/bosun-ai/swiftide/commit/47455fb04197a4b51142e2fb4c980e42ac54d11e) *(indexing)* Visibility of ChunkMarkdown builder should be public
+
+- [2b3b401](https://github.com/bosun-ai/swiftide/commit/2b3b401dcddb2cb32214850b9b4dbb0481943d38) *(indexing)* Improve splitters consistency and provide defaults ([#403](https://github.com/bosun-ai/swiftide/pull/403))
+
+
+**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.3...0.13.4
+
-**BREAKING CHANGE**: This PR reworks Nodes with a builder API and a private
-id. Hence, manually creating nodes no longer works. In the future, all
-the fields are likely to follow the same pattern, so that we can
-decouple the inner fields from the Node's implementation.
-
-- [c091ffa](https://github.com/bosun-ai/swiftide/commit/c091ffa6be792b0bd7bb03d604e26e40b2adfda8) *(indexing)* Use atomics for key generation in memory storage ([#415](https://github.com/bosun-ai/swiftide/pull/415))
-
-### Miscellaneous
-
-- [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies
-
-
-**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.4...0.14.0
-
-
-
-## [0.13.4](https://github.com/bosun-ai/swiftide/compare/v0.13.3...v0.13.4) - 2024-10-21
-
-### Bug fixes
-
-- [47455fb](https://github.com/bosun-ai/swiftide/commit/47455fb04197a4b51142e2fb4c980e42ac54d11e) *(indexing)* Visibility of ChunkMarkdown builder should be public
-
-- [2b3b401](https://github.com/bosun-ai/swiftide/commit/2b3b401dcddb2cb32214850b9b4dbb0481943d38) *(indexing)* Improve splitters consistency and provide defaults ([#403](https://github.com/bosun-ai/swiftide/pull/403))
-
-
-**Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.3...0.13.4
-
-
 # Changelog
 
 All notable changes to this project will be documented in this file.
diff --git a/Cargo.lock b/Cargo.lock index 21cff987..b0c6a92c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1259,6 +1259,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "benchmarks" version = "0.14.1" @@ -1292,6 +1298,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "bitpacking" @@ -1756,6 +1765,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -1838,6 +1853,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32c" version = "0.6.8" @@ -2438,6 +2468,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2508,6 +2549,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -2559,6 +2601,12 @@ dependencies = [ "litrs", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "downcast" version = "0.11.0" @@ -2615,6 +2663,9 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -2919,6 +2970,8 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ + "futures-core", + "futures-sink", "spin", ] @@ -3334,6 +3387,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot 0.12.3", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -3581,6 +3645,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -3673,6 +3746,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -4753,6 +4835,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "lebe" @@ -4893,6 +4978,17 @@ dependencies = [ "redox_syscall 0.5.7", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -5350,6 +5446,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -5769,6 +5882,15 @@ dependencies = [ "stfu8", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -5830,6 +5952,15 @@ dependencies = [ "indexmap 2.6.0", ] +[[package]] +name = "pgvector" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6" +dependencies = [ + "sqlx", +] + [[package]] name = "pharos" version = "0.5.3" @@ -5989,6 +6120,27 @@ dependencies = [ "futures-io", ] +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.31" @@ -6054,6 +6206,15 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "portpicker" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9" +dependencies = [ + "rand", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6776,6 +6937,26 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + 
"rand_core", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -7194,6 +7375,17 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha1_smol" version = "1.0.1" @@ -7244,6 +7436,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -7315,6 +7517,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "snafu" @@ -7419,6 +7624,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -7431,6 +7646,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "sqlformat" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" +dependencies = [ + "nom", + "unicode_categories", +] + [[package]] name = "sqlparser" version = "0.49.0" @@ -7452,6 +7677,208 @@ dependencies = [ "syn 2.0.82", ] +[[package]] +name = "sqlx" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" +dependencies = [ + "atoi", + "byteorder", + "bytes", + "chrono", + "crc", + "crossbeam-queue", + "either", + "event-listener 5.3.1", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.14.5", + "hashlink", + "hex", + "indexmap 2.6.0", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlformat", + "thiserror", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.82", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" +dependencies = [ + "dotenvy", + "either", + "heck 0.5.0", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + 
"sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 2.0.82", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" +dependencies = [ + "atoi", + "chrono", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "serde_urlencoded", + "sqlx-core", + "tracing", + "url", + "uuid", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -7519,6 +7946,17 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3c3ee6129eec20fed59acf2e9cfb3ffd20d0bbe39fe334c22af0edc56dfe752" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -7637,9 +8075,12 @@ dependencies = [ "qdrant-client", "serde_json", "spider", + "sqlx", "swiftide", + "swiftide-test-utils", "temp-dir", "tokio", + "tracing", "tracing-subscriber", ] @@ -7698,6 +8139,7 @@ dependencies = [ "mockall", "ollama-rs", "parquet", + "pgvector", "qdrant-client", "redb", "redis", @@ -7707,6 +8149,7 @@ dependencies = [ "serde", "serde_json", "spider", + "sqlx", "strum", "strum_macros", "swiftide-core", @@ -7770,12 +8213,14 @@ dependencies = [ "async-openai", "insta", "mockall", + "portpicker", "qdrant-client", "serde", "serde_json", "swiftide-core", "swiftide-integrations", "temp-dir", + "tempfile", "test-case", "test-log", "testcontainers", @@ -8825,6 +9270,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -9002,6 +9453,12 @@ version = 
"0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.95" @@ -9107,6 +9564,16 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall 0.5.7", + "wasite", +] + [[package]] name = "widestring" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5419f60c..f24ecf3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,8 @@ arrow-array = { version = "52.0", default-features = false } arrow = { version = "52.2" } parquet = { version = "52.2", default-features = false, features = ["async"] } redb = { version = "2.1" } +sqlx = { version = "0.8.2", features = ["postgres", "uuid"] } +pgvector = { version = "0.4.0", features = ["sqlx"] } # Testing test-log = "0.2.16" @@ -61,6 +63,8 @@ temp-dir = "0.1.13" wiremock = "0.6.0" test-case = "3.3.1" insta = { version = "1.39.0", features = ["yaml"] } +tempfile = "3.10.1" +portpicker = "0.1.1" [workspace.lints.rust] unsafe_code = "forbid" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index b1046027..5e5b643d 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -21,13 +21,17 @@ swiftide = { path = "../swiftide/", features = [ "ollama", "fluvio", "lancedb", + "pgvector", ] } tracing-subscriber = "0.3" +tracing = { workspace = true } serde_json = { workspace = true } spider = { workspace = true } qdrant-client = { workspace = true } fluvio = { workspace = true } temp-dir = { workspace = true } +sqlx = { workspace = true } +swiftide-test-utils = { path = "../swiftide-test-utils" } [[example]] doc-scrape-examples = true @@ -91,3 +95,7 @@ path = "fluvio.rs" [[example]] name = "lancedb" path = "lancedb.rs" + +[[example]] +name = "index-md-pgvector" +path = "index_md_into_pgvector.rs" diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs new file mode 100644 index 00000000..9b298def --- /dev/null +++ b/examples/index_md_into_pgvector.rs @@ -0,0 +1,126 @@ +/** +* This example demonstrates how to index markdown into PGVector +*/ +use std::path::PathBuf; +use swiftide::{ + indexing::{ + self, + loaders::FileLoader, + transformers::{ + metadata_qa_text::NAME as METADATA_QA_TEXT_NAME, ChunkMarkdown, Embed, MetadataQAText, + }, + EmbeddedField, + }, + integrations::{self, fastembed::FastEmbed, pgvector::PgVector}, + query::{self, answers, query_transformers, response_transformers}, + traits::SimplePrompt, +}; + +async fn ask_query( + llm_client: impl SimplePrompt + Clone + 'static, + embed: FastEmbed, + vector_store: PgVector, + question: String, +) -> Result> { + // By default the search strategy is SimilaritySingleEmbedding + // which takes the latest query, embeds it, and does a similarity search + // + // Pgvector will return an error if multiple embeddings are set + // + // The pipeline generates subquestions to increase semantic coverage, embeds these in a single + // embedding, retrieves the default top_k documents, 
diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs
new file mode 100644
index 00000000..9b298def
--- /dev/null
+++ b/examples/index_md_into_pgvector.rs
@@ -0,0 +1,126 @@
+/**
+* This example demonstrates how to index markdown into PGVector
+*/
+use std::path::PathBuf;
+use swiftide::{
+    indexing::{
+        self,
+        loaders::FileLoader,
+        transformers::{
+            metadata_qa_text::NAME as METADATA_QA_TEXT_NAME, ChunkMarkdown, Embed, MetadataQAText,
+        },
+        EmbeddedField,
+    },
+    integrations::{self, fastembed::FastEmbed, pgvector::PgVector},
+    query::{self, answers, query_transformers, response_transformers},
+    traits::SimplePrompt,
+};
+
+async fn ask_query(
+    llm_client: impl SimplePrompt + Clone + 'static,
+    embed: FastEmbed,
+    vector_store: PgVector,
+    question: String,
+) -> Result<String, Box<dyn std::error::Error>> {
+    // By default the search strategy is SimilaritySingleEmbedding
+    // which takes the latest query, embeds it, and does a similarity search
+    //
+    // Pgvector will return an error if multiple embeddings are set
+    //
+    // The pipeline generates subquestions to increase semantic coverage, embeds these in a single
+    // embedding, retrieves the default top_k documents, summarizes them and uses that as context
+    // for the final answer.
+    let pipeline = query::Pipeline::default()
+        .then_transform_query(query_transformers::GenerateSubquestions::from_client(
+            llm_client.clone(),
+        ))
+        .then_transform_query(query_transformers::Embed::from_client(embed))
+        .then_retrieve(vector_store.clone())
+        .then_transform_response(response_transformers::Summary::from_client(
+            llm_client.clone(),
+        ))
+        .then_answer(answers::Simple::from_client(llm_client.clone()));
+
+    let result = pipeline.query(question).await?;
+    Ok(result.answer().into())
+}
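+
+// Illustrative only: to override the default retrieval behaviour, a search
+// strategy can be constructed explicitly and handed to the pipeline. The calls
+// below are a sketch based on the strategy API exercised in the
+// swiftide-integrations pgvector tests (`with_top_k`, `from_filter`); this
+// example does not run them.
+//
+// let mut strategy =
+//     swiftide::query::search_strategies::SimilaritySingleEmbedding::<()>::default();
+// strategy.with_top_k(5);
+// let pipeline = query::Pipeline::from_search_strategy(strategy)
+//     .then_transform_query(query_transformers::Embed::from_client(embed.clone()))
+//     .then_retrieve(vector_store.clone())
+//     .then_answer(answers::Simple::from_client(llm_client.clone()));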
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+    tracing::info!("Starting PgVector indexing test");
+
+    // Get the manifest directory path
+    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
+
+    // Create a PathBuf to test dataset from the manifest directory
+    let test_dataset_path = PathBuf::from(manifest_dir).join("../README.md");
+
+    tracing::info!("Test Dataset path: {:?}", test_dataset_path);
+
+    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
+
+    tracing::info!("pgv_db_url :: {:#?}", pgv_db_url);
+
+    let llm_client = integrations::ollama::Ollama::default()
+        .with_default_prompt_model("llama3.2:latest")
+        .to_owned();
+
+    let fastembed =
+        integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed");
+
+    // Configure Pgvector with a default vector size, a single embedding
+    // and in addition to embedding the text metadata, also store it in a field
+    let pgv_storage = PgVector::builder()
+        .try_connect_to_pool(pgv_db_url, Some(10))
+        .await
+        .expect("Failed to connect to postgres server")
+        .vector_size(384)
+        .with_vector(EmbeddedField::Combined)
+        .with_metadata(METADATA_QA_TEXT_NAME)
+        .table_name("swiftide_pgvector_test".to_string())
+        .build()
+        .unwrap();
+
+    // Drop the existing test table before running the test
+    tracing::info!("Dropping existing test table & index if it exists");
+    let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test";
+    let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx";
+
+    if let Ok(pool) = pgv_storage.get_pool() {
+        sqlx::query(drop_table_sql).execute(&pool).await?;
+        sqlx::query(drop_index_sql).execute(&pool).await?;
+    } else {
+        return Err("Failed to get database connection pool".into());
+    }
+
+    tracing::info!("Starting indexing pipeline");
+    indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"]))
+        .then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
+        .then(MetadataQAText::new(llm_client.clone()))
+        .then_in_batch(Embed::new(fastembed.clone()).with_batch_size(100))
+        .then_store_with(pgv_storage.clone())
+        .run()
+        .await?;
+
+    for (i, question) in [
+        "What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.",
+        "How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.",
+    ]
+    .iter()
+    .enumerate()
+    {
+        let result = ask_query(
+            llm_client.clone(),
+            fastembed.clone(),
+            pgv_storage.clone(),
+            question.to_string(),
+        )
+        .await?;
+        tracing::info!("*** Answer Q{} ***", i + 1);
+        tracing::info!("{}", result);
+        tracing::info!("===X===");
+    }
+
+    tracing::info!("PgVector Indexing & retrieval test completed successfully");
+    Ok(())
+}
diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml
index bcd5f7b7..4564b955 100644
--- a/swiftide-integrations/Cargo.toml
+++ b/swiftide-integrations/Cargo.toml
@@ -34,6 +34,13 @@ async-openai = { workspace = true, optional = true }
 qdrant-client = { workspace = true, optional = true, default-features = false, features = [
   "serde",
 ] }
+sqlx = { workspace = true, optional = true, features = [
+  "postgres",
+  "runtime-tokio",
+  "chrono",
+  "uuid"
+] }
+pgvector = { workspace = true, optional = true, features = ["sqlx"] }
 redis = { version = "0.27", features = [
   "aio",
   "tokio-comp",
@@ -102,6 +109,8 @@ default = ["rustls"]
 rustls = ["reqwest/rustls-tls-native-roots"]
 # Qdrant for storage
 qdrant = ["dep:qdrant-client", "swiftide-core/qdrant"]
+# PgVector for storage
+pgvector = ["dep:sqlx", "dep:pgvector"]
 # Redis for caching and storage
 redis = ["dep:redis"]
 # Tree-sitter for code operations and chunking
diff --git a/swiftide-integrations/src/lib.rs b/swiftide-integrations/src/lib.rs
index d1e38a08..74f1c1d6 100644
--- a/swiftide-integrations/src/lib.rs
+++ b/swiftide-integrations/src/lib.rs
@@ -16,6 +16,8 @@ pub mod ollama;
 pub mod openai;
 #[cfg(feature = "parquet")]
 pub mod parquet;
+#[cfg(feature = "pgvector")]
+pub mod pgvector;
 #[cfg(feature = "qdrant")]
 pub mod qdrant;
 #[cfg(feature = "redb")]
diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs
new file mode 100644
index 00000000..e4ca677f
--- /dev/null
+++ b/swiftide-integrations/src/pgvector/mod.rs
@@ -0,0 +1,593 @@
+//! This module integrates with the pgvector database, providing functionalities to create and manage vector collections,
+//! store data, and optimize indexing for efficient searches.
+//!
+//! pgvector is utilized in both the `indexing::Pipeline` and `query::Pipeline` modules.
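+//!
+//! A rough end-to-end sketch (assumes the `pgvector` feature is enabled and a
+//! Postgres server with the pgvector extension is reachable; the URL, table
+//! name, and vector size below are placeholders):
+//!
+//! ```no_run
+//! # use swiftide_integrations::pgvector::PgVector;
+//! # use swiftide_core::indexing::EmbeddedField;
+//! # use swiftide_core::Persist;
+//! # async fn example() -> anyhow::Result<()> {
+//! let store = PgVector::builder()
+//!     .try_connect_to_pool("postgresql://localhost:5432/postgres", Some(10))
+//!     .await?
+//!     .vector_size(384)
+//!     .with_vector(EmbeddedField::Combined)
+//!     .with_metadata("keywords")
+//!     .table_name("my_nodes".to_string())
+//!     .build()?;
+//!
+//! // Creates the pgvector extension, table and HNSW index if missing
+//! store.setup().await?;
+//! # Ok(())
+//! # }
+//! ```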
+mod persist;
+mod pgv_table_types;
+mod retrieve;
+use anyhow::Result;
+use derive_builder::Builder;
+use sqlx::PgPool;
+use std::fmt;
+
+use pgv_table_types::{FieldConfig, MetadataConfig, PgDBConnectionPool, VectorConfig};
+
+const DEFAULT_BATCH_SIZE: usize = 50;
+
+/// Represents a Pgvector client with configuration options.
+///
+/// This struct is used to interact with the Pgvector vector database, providing methods to manage vector collections,
+/// store data, and ensure efficient searches. The client can be cloned with low cost as it shares connections.
+#[derive(Builder, Clone)]
+#[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
+pub struct PgVector {
+    /// Database connection pool.
+    #[builder(default = "PgDBConnectionPool::default()")]
+    connection_pool: PgDBConnectionPool,
+
+    /// Table name to store vectors in.
+    #[builder(default = "String::from(\"swiftide_pgv_store\")")]
+    table_name: String,
+
+    /// Default sizes of vectors. Vectors can also be of different
+    /// sizes by specifying the size in the vector configuration.
+    vector_size: Option<i32>,
+
+    /// Batch size for storing nodes.
+    #[builder(default = "Some(DEFAULT_BATCH_SIZE)")]
+    batch_size: Option<usize>,
+
+    /// Field configuration for the Pgvector table, determining the eventual table schema.
+    ///
+    /// Supports multiple field types; see [`FieldConfig`] for details.
+    #[builder(default)]
+    fields: Vec<FieldConfig>,
+}
+
+impl fmt::Debug for PgVector {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PgVector")
+            .field("table_name", &self.table_name)
+            .field("vector_size", &self.vector_size)
+            .field("batch_size", &self.batch_size)
+            .finish()
+    }
+}
+
+impl PgVector {
+    /// Creates a new instance of `PgVectorBuilder` with default settings.
+    ///
+    /// # Returns
+    ///
+    /// A new `PgVectorBuilder`.
+    pub fn builder() -> PgVectorBuilder {
+        PgVectorBuilder::default()
+    }
+
+    /// Retrieves a connection pool for `PostgreSQL`.
+    ///
+    /// This function returns the connection pool used for interacting with the `PostgreSQL` database.
+    /// It fetches the pool from the `PgDBConnectionPool` struct.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` that, on success, contains the `PgPool` representing the database connection pool.
+    /// On failure, an error is returned.
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error if it fails to retrieve the connection pool, which could occur
+    /// if the underlying connection to `PostgreSQL` has not been properly established.
+    pub fn get_pool(&self) -> Result<PgPool> {
+        self.connection_pool.get_pool()
+    }
+}
+
+impl PgVectorBuilder {
+    /// Tries to asynchronously connect to a `Postgres` server and initialize a connection pool.
+    ///
+    /// This function attempts to establish a connection to the specified `Postgres` server and
+    /// sets up a connection pool with an optional maximum number of connections.
+    ///
+    /// # Arguments
+    ///
+    /// * `url` - A string reference representing the URL of the `Postgres` server to connect to.
+    /// * `connection_max` - An optional value specifying the maximum number of connections in the pool.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` that contains an updated `PgVectorBuilder` with the new connection pool on success.
+    /// On failure, an error is returned.
+    ///
+    /// # Errors
+    ///
+    /// This function returns an error if the connection to the database fails or if retries are exhausted.
+    /// Possible reasons include invalid database URLs, unreachable servers, or exceeded retry limits.
+    pub async fn try_connect_to_pool(
+        mut self,
+        url: impl AsRef<str>,
+        connection_max: Option<u32>,
+    ) -> Result<Self> {
+        let pool = self.connection_pool.clone().unwrap_or_default();
+
+        self.connection_pool = Some(pool.try_connect_to_url(url, connection_max).await?);
+
+        Ok(self)
+    }
+
+    /// Adds a vector configuration to the builder.
+    ///
+    /// # Arguments
+    ///
+    /// * `config` - The vector configuration to add, which can be converted into a `VectorConfig`.
+    ///
+    /// # Returns
+    ///
+    /// A mutable reference to the builder with the new vector configuration added.
+    pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self {
+        // Use `get_or_insert_with` to initialize `fields` if it's `None`
+        self.fields
+            .get_or_insert_with(Self::default_fields)
+            .push(FieldConfig::Vector(config.into()));
+
+        self
+    }
+
+    /// Sets the metadata configuration for the vector similarity search.
+    ///
+    /// This method allows you to specify metadata configurations for vector similarity search using `MetadataConfig`.
+    /// The provided configuration will be added as a new field in the builder.
+    ///
+    /// # Arguments
+    ///
+    /// * `config` - The metadata configuration to use.
+    ///
+    /// # Returns
+    ///
+    /// * Returns a mutable reference to `self` for method chaining.
+    pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self {
+        // Use `get_or_insert_with` to initialize `fields` if it's `None`
+        self.fields
+            .get_or_insert_with(Self::default_fields)
+            .push(FieldConfig::Metadata(config.into()));
+
+        self
+    }
+
+    fn default_fields() -> Vec<FieldConfig> {
+        vec![FieldConfig::ID, FieldConfig::Chunk]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pgvector::PgVector;
+    use futures_util::TryStreamExt;
+    use swiftide_core::{indexing, indexing::EmbeddedField, Persist};
+    use swiftide_core::{
+        indexing::EmbedMode,
+        querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
+        Retrieve,
+    };
+    use test_case::test_case;
+    use testcontainers::{ContainerAsync, GenericImage};
+
+    struct TestContext {
+        pgv_storage: PgVector,
+        _pgv_db_container: ContainerAsync<GenericImage>,
+    }
+
+    impl TestContext {
+        /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
+        /// with configurable metadata fields
+        async fn setup_with_cfg(
+            metadata_fields: Option<Vec<&str>>,
+            embedded_field: indexing::EmbeddedField,
+        ) -> Result<Self, Box<dyn std::error::Error>> {
+            // Start `PostgreSQL` container and obtain the connection URL
+            let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
+            tracing::info!("Postgres database URL: {:#?}", pgv_db_url);
+
+            // Initialize the connection pool outside of the builder chain
+            let mut connection_pool = PgVector::builder()
+                .try_connect_to_pool(pgv_db_url, Some(10))
+                .await
+                .map_err(|err| {
+                    tracing::error!("Failed to connect to Postgres server: {}", err);
+                    err
+                })?;
+
+            // Configure PgVector storage
+            let mut builder = connection_pool
+                .vector_size(384)
+                .with_vector(embedded_field)
+                .table_name("swiftide_pgvector_test".to_string());
+
+            // Add all metadata fields
+            if let Some(metadata_fields_inner) = metadata_fields {
+                for field in metadata_fields_inner {
+                    builder = builder.with_metadata(field);
+                }
+            };
+
+            let pgv_storage = builder.build().map_err(|err| {
+                tracing::error!("Failed to build PgVector: {}", err);
+                err
+            })?;
+
+            // Set up PgVector storage (create the table if not exists)
+            pgv_storage.setup().await.map_err(|err| {
+                tracing::error!("PgVector setup failed: {}", err);
+                err
+            })?;
+
+            Ok(Self {
+                pgv_storage,
+                _pgv_db_container: pgv_db_container,
+            })
+        }
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_metadata_filter_with_vector_search() {
+        let test_context = TestContext::setup_with_cfg(
+            vec!["category", "priority"].into(),
+            EmbeddedField::Combined,
+        )
+        .await
+        .expect("Test setup failed");
+
+        // Create nodes with different metadata and vectors
+        let nodes = vec![
+            indexing::Node::new("content1")
+                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
+                .with_metadata(vec![("category", "A"), ("priority", "1")]),
+            indexing::Node::new("content2")
+                .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])])
+                .with_metadata(vec![("category", "A"), ("priority", "2")]),
+            indexing::Node::new("content3")
+                .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])])
+                .with_metadata(vec![("category", "B"), ("priority", "1")]),
+        ]
+        .into_iter()
+        .map(|node| node.to_owned())
+        .collect();
+
+        // Store all nodes
+        test_context
+            .pgv_storage
+            .batch_store(nodes)
+            .await
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        // Test combined metadata and vector search
+        let mut query = Query::<states::Pending>::new("test_query");
+        query.embedding = Some(vec![1.0; 384]);
+
+        // Search with category filter
+        let search_strategy =
+            SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string());
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query.clone())
+            .await
+            .unwrap();
+
+        assert_eq!(result.documents().len(), 2);
+        assert!(result.documents().contains(&"content1".to_string()));
+        assert!(result.documents().contains(&"content2".to_string()));
+
+        // Additional test with priority filter
+        let search_strategy =
+            SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string());
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query)
+            .await
+            .unwrap();
+
+        assert_eq!(result.documents().len(), 2);
+        assert!(result.documents().contains(&"content1".to_string()));
+        assert!(result.documents().contains(&"content3".to_string()));
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_vector_similarity_search_accuracy() {
+        let test_context = TestContext::setup_with_cfg(None, EmbeddedField::Combined)
+            .await
+            .expect("Test setup failed");
+
+        // Create nodes with known vector relationships
+        let base_vector = vec![1.0; 384];
+        let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::<Vec<_>>();
+        let dissimilar_vector = vec![-1.0; 384];
+
+        let nodes = vec![
+            indexing::Node::new("base_content")
+                .with_vectors([(EmbeddedField::Combined, base_vector)]),
+            indexing::Node::new("similar_content")
+                .with_vectors([(EmbeddedField::Combined, similar_vector)]),
+            indexing::Node::new("dissimilar_content")
+                .with_vectors([(EmbeddedField::Combined, dissimilar_vector)]),
+        ]
+        .into_iter()
+        .map(|node| node.to_owned())
+        .collect();
+
+        // Store all nodes
+        test_context
+            .pgv_storage
+            .batch_store(nodes)
+            .await
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        // Search with base vector
+        let mut query = Query::<states::Pending>::new("test_query");
+        query.embedding = Some(vec![1.0; 384]);
+
+        let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
+        search_strategy.with_top_k(2);
+
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query)
+            .await
+            .unwrap();
+
+        // Verify that similar vectors are retrieved first
+        assert_eq!(result.documents().len(), 2);
+        assert!(result.documents().contains(&"base_content".to_string()));
+        assert!(result.documents().contains(&"similar_content".to_string()));
+    }
+
+    #[derive(Clone)]
+    struct PgVectorTestData<'a> {
+        pub embed_mode: indexing::EmbedMode,
+        pub chunk: &'a str,
+        pub metadata: Option<indexing::Metadata>,
+        pub vectors: Vec<(indexing::EmbeddedField, Vec<f32>)>,
+        pub expected_in_results: bool,
+    }
+
+    impl<'a> PgVectorTestData<'a> {
+        fn to_node(&self) -> indexing::Node {
+            // Create the initial builder
+            let mut base_builder = indexing::Node::builder();
+
+            // Set the required fields
+            let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode);
+
+            // Add metadata if it exists
+            if let Some(metadata) = &self.metadata {
+                builder = builder.metadata(metadata.clone());
+            }
+
+            // Build the node and add vectors
+            let mut node = builder.build().unwrap();
+            node.vectors = Some(self.vectors.clone().into_iter().collect());
+            node
+        }
+    }
+
+    fn create_test_vector(field: EmbeddedField, base_value: f32) -> (EmbeddedField, Vec<f32>) {
+        (field, vec![base_value; 384])
+    }
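+
+    // The cases below cover the embed modes this store must handle. As the test
+    // data shows, `SingleWithMetadata` carries one `Combined` vector per node,
+    // while `Both` additionally carries a `Chunk` vector and, when metadata is
+    // present, one vector per metadata field.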
+    #[test_case(
+        // SingleWithMetadata - No Metadata
+        vec![
+            PgVectorTestData {
+                embed_mode: EmbedMode::SingleWithMetadata,
+                chunk: "single_no_meta_1",
+                metadata: None,
+                vectors: vec![create_test_vector(EmbeddedField::Combined, 1.0)],
+                expected_in_results: true,
+            },
+            PgVectorTestData {
+                embed_mode: EmbedMode::SingleWithMetadata,
+                chunk: "single_no_meta_2",
+                metadata: None,
+                vectors: vec![create_test_vector(EmbeddedField::Combined, 1.1)],
+                expected_in_results: true,
+            }
+        ]
+        ; "SingleWithMetadata mode without metadata")]
+    #[test_case(
+        // SingleWithMetadata - With Metadata
+        vec![
+            PgVectorTestData {
+                embed_mode: EmbedMode::SingleWithMetadata,
+                chunk: "single_with_meta_1",
+                metadata: Some(vec![
+                    ("category", "A"),
+                    ("priority", "high")
+                ].into()),
+                vectors: vec![create_test_vector(EmbeddedField::Combined, 1.2)],
+                expected_in_results: true,
+            },
+            PgVectorTestData {
+                embed_mode: EmbedMode::SingleWithMetadata,
+                chunk: "single_with_meta_2",
+                metadata: Some(vec![
+                    ("category", "B"),
+                    ("priority", "low")
+                ].into()),
+                vectors: vec![create_test_vector(EmbeddedField::Combined, 1.3)],
+                expected_in_results: true,
+            }
+        ]
+        ; "SingleWithMetadata mode with metadata")]
+    #[test_case(
+        // Both - No Metadata
+        vec![
+            PgVectorTestData {
+                embed_mode: EmbedMode::Both,
+                chunk: "both_no_meta_1",
+                metadata: None,
+                vectors: vec![
+                    create_test_vector(EmbeddedField::Combined, 3.0),
+                    create_test_vector(EmbeddedField::Chunk, 3.1)
+                ],
+                expected_in_results: true,
+            },
+            PgVectorTestData {
+                embed_mode: EmbedMode::Both,
+                chunk: "both_no_meta_2",
+                metadata: None,
+                vectors: vec![
+                    create_test_vector(EmbeddedField::Combined, 3.2),
+                    create_test_vector(EmbeddedField::Chunk, 3.3)
+                ],
+                expected_in_results: true,
+            }
+        ]
+        ; "Both mode without metadata")]
+    #[test_case(
+        // Both - With Metadata
+        vec![
+            PgVectorTestData {
+                embed_mode: EmbedMode::Both,
+                chunk: "both_with_meta_1",
+                metadata: Some(vec![
+                    ("category", "P"),
+                    ("priority", "urgent"),
+                    ("tag", "test1")
+                ].into()),
+                vectors: vec![
+                    create_test_vector(EmbeddedField::Combined, 3.4),
+                    create_test_vector(EmbeddedField::Chunk, 3.5),
+                    create_test_vector(EmbeddedField::Metadata("category".into()), 3.6),
+                    create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7),
+                    create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8)
+                ],
+                expected_in_results: true,
+            },
+            PgVectorTestData {
+                embed_mode: EmbedMode::Both,
+                chunk: "both_with_meta_2",
+                metadata: Some(vec![
+                    ("category", "Q"),
+                    ("priority", "low"),
+                    ("tag", "test2")
+                ].into()),
+                vectors: vec![
+                    create_test_vector(EmbeddedField::Combined, 3.9),
+                    create_test_vector(EmbeddedField::Chunk, 4.0),
+                    create_test_vector(EmbeddedField::Metadata("category".into()), 4.1),
+                    create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2),
+                    create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3)
+                ],
+                expected_in_results: true,
+            }
+        ]
+        ; "Both mode with metadata")]
+    #[test_log::test(tokio::test)]
+    async fn test_persist_and_retrieve_nodes(test_cases: Vec<PgVectorTestData<'_>>) {
+        // Extract all possible metadata fields from test cases, deduplicated via a set
+        let metadata_fields: Vec<&str> = test_cases
+            .iter()
+            .filter_map(|case| case.metadata.as_ref())
+            .flat_map(|metadata| metadata.iter().map(|(key, _)| key.as_str()))
+            .collect::<std::collections::HashSet<_>>()
+            .into_iter()
+            .collect();
+
+        // Initialize test context with all required metadata fields
+        let test_context =
+            TestContext::setup_with_cfg(Some(metadata_fields), EmbeddedField::Combined)
+                .await
+                .expect("Test setup failed");
+
+        // Convert test cases to nodes and store them
+        let nodes: Vec<indexing::Node> = test_cases.iter().map(PgVectorTestData::to_node).collect();
+
+        // Test batch storage
+        let stored_nodes = test_context
+            .pgv_storage
+            .batch_store(nodes.clone())
+            .await
+            .try_collect::<Vec<_>>()
+            .await
+            .expect("Failed to store nodes");
+
+        assert_eq!(
+            stored_nodes.len(),
+            nodes.len(),
+            "All nodes should be stored"
+        );
+
+        // Verify storage and retrieval for each test case
+        for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) {
+            // 1. Verify basic node properties
+            assert_eq!(
+                stored_node.chunk, test_case.chunk,
+                "Stored chunk should match"
+            );
+            assert_eq!(
+                stored_node.embed_mode, test_case.embed_mode,
+                "Embed mode should match"
+            );
+
+            // 2. Verify vectors were stored correctly
+            let stored_vectors = stored_node
+                .vectors
+                .as_ref()
+                .expect("Vectors should be present");
+            assert_eq!(
+                stored_vectors.len(),
+                test_case.vectors.len(),
+                "Vector count should match"
+            );
+
+            // 3. Test vector similarity search
+            for (field, vector) in &test_case.vectors {
+                let mut query = Query::<states::Pending>::new("test_query");
+                query.embedding = Some(vector.clone());
+
+                let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
+                search_strategy.with_top_k(nodes.len() as u64);
+
+                let result = test_context
+                    .pgv_storage
+                    .retrieve(&search_strategy, query.clone())
+                    .await
+                    .expect("Retrieval should succeed");
+
+                if test_case.expected_in_results {
+                    assert!(
+                        result.documents().contains(&test_case.chunk.to_string()),
+                        "Document should be found in results for field {field}",
+                    );
+                }
+            }
+
+            // 4. Test metadata filtering if present
+            if let Some(metadata) = &test_case.metadata {
+                for (key, value) in metadata {
+                    let filter_query = format!("{key} = \"{value}\"");
+                    let search_strategy = SimilaritySingleEmbedding::from_filter(filter_query);
+
+                    let mut query = Query::<states::Pending>::new("test_query");
+                    query.embedding = Some(test_case.vectors[0].1.clone());
+
+                    let result = test_context
+                        .pgv_storage
+                        .retrieve(&search_strategy, query)
+                        .await
+                        .expect("Filtered retrieval should succeed");
+
+                    if test_case.expected_in_results {
+                        assert!(
+                            result.documents().contains(&test_case.chunk.to_string()),
+                            "Document should be found when filtering by metadata {key}={value}"
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs
new file mode 100644
index 00000000..c0706576
--- /dev/null
+++ b/swiftide-integrations/src/pgvector/persist.rs
@@ -0,0 +1,102 @@
+//! This module implements the `Persist` trait for the `PgVector` struct.
+//! It provides methods for setting up storage, saving individual nodes, and batch-storing multiple nodes.
+//! This integration enables the Swiftide project to use `PgVector` as a storage backend.
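+//!
+//! A rough sketch of how this is typically wired up (mirrors the
+//! `index_md_into_pgvector` example; `pgv_storage` is assumed to be a fully
+//! configured [`PgVector`](crate::pgvector::PgVector) and `embedding_model` an
+//! embedding transformer — the indexing pipeline invokes `setup` and
+//! `batch_store` through the `Persist` trait):
+//!
+//! ```ignore
+//! indexing::Pipeline::from_loader(FileLoader::new("README.md"))
+//!     .then_in_batch(Embed::new(embedding_model).with_batch_size(100))
+//!     .then_store_with(pgv_storage.clone())
+//!     .run()
+//!     .await?;
+//! ```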
+use crate::pgvector::PgVector;
+use anyhow::Result;
+use async_trait::async_trait;
+use swiftide_core::{
+    indexing::{IndexingStream, Node},
+    Persist,
+};
+
+#[async_trait]
+impl Persist for PgVector {
+    #[tracing::instrument(skip_all)]
+    async fn setup(&self) -> Result<()> {
+        let mut tx = self.connection_pool.get_pool()?.begin().await?;
+
+        // Create extension
+        let sql = "CREATE EXTENSION IF NOT EXISTS vector";
+        sqlx::query(sql).execute(&mut *tx).await?;
+
+        // Create table
+        let create_table_sql = self.generate_create_table_sql()?;
+        sqlx::query(&create_table_sql).execute(&mut *tx).await?;
+
+        // Create HNSW index
+        let index_sql = self.create_index_sql()?;
+        sqlx::query(&index_sql).execute(&mut *tx).await?;
+
+        tx.commit().await?;
+
+        Ok(())
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn store(&self, node: Node) -> Result<Node> {
+        let mut nodes = vec![node];
+        self.store_nodes(&nodes).await?;
+
+        let node = nodes.swap_remove(0);
+
+        Ok(node)
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
+        self.store_nodes(&nodes).await.map(|()| nodes).into()
+    }
+
+    fn batch_size(&self) -> Option<usize> {
+        self.batch_size
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pgvector::PgVector;
+    use swiftide_core::{indexing::EmbeddedField, Persist};
+    use testcontainers::{ContainerAsync, GenericImage};
+
+    struct TestContext {
+        pgv_storage: PgVector,
+        _pgv_db_container: ContainerAsync<GenericImage>,
+    }
+
+    impl TestContext {
+        /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
+        async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
+            // Start PostgreSQL container and obtain the connection URL
+            let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
+
+            // Configure and build PgVector storage
+            let pgv_storage = PgVector::builder()
+                .try_connect_to_pool(pgv_db_url, Some(10))
+                .await?
+                .vector_size(384)
+                .with_vector(EmbeddedField::Combined)
+                .with_metadata("filter")
+                .table_name("swiftide_pgvector_test".to_string())
+                .build()?;
+
+            // Set up PgVector storage (create the table if not exists)
+            pgv_storage.setup().await?;
+
+            Ok(Self {
+                pgv_storage,
+                _pgv_db_container: pgv_db_container,
+            })
+        }
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_persist_setup_no_error_when_table_exists() {
+        let test_context = TestContext::setup().await.expect("Test setup failed");
+
+        test_context
+            .pgv_storage
+            .setup()
+            .await
+            .expect("PgVector setup should not fail when the table already exists");
+    }
+}
diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs
new file mode 100644
index 00000000..3598d873
--- /dev/null
+++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs
@@ -0,0 +1,494 @@
+//! This module provides functionality to convert a `Node` into a `PostgreSQL` table schema.
+//! This conversion is crucial for storing data in `PostgreSQL`, enabling efficient vector similarity searches
+//! through the `pgvector` extension. The module also handles metadata augmentation and ensures compatibility
+//! with `PostgreSQL`'s required data format.
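+//!
+//! For orientation, a store configured with the default `id` and `chunk`
+//! fields, one combined embedding of size 384, and one metadata field
+//! `"category"` maps to roughly the following schema and bulk upsert (a
+//! sketch: the concrete column names come from the normalization helpers
+//! below and follow the configured field order):
+//!
+//! ```sql
+//! CREATE TABLE IF NOT EXISTS swiftide_pgv_store (
+//!     id UUID NOT NULL,
+//!     chunk TEXT NOT NULL,
+//!     vector_combined VECTOR(384),
+//!     meta_category JSONB,
+//!     PRIMARY KEY (id)
+//! );
+//!
+//! INSERT INTO swiftide_pgv_store (id, chunk, vector_combined, meta_category)
+//! SELECT id, chunk, vector_combined, meta_category
+//! FROM UNNEST($1::UUID[], $2::TEXT[], $3::VECTOR[], $4::JSONB[])
+//!     AS t(id, chunk, vector_combined, meta_category)
+//! ON CONFLICT (id) DO UPDATE SET chunk = EXCLUDED.chunk,
+//!     vector_combined = EXCLUDED.vector_combined,
+//!     meta_category = EXCLUDED.meta_category;
+//! ```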
+ +use crate::pgvector::PgVector; +use anyhow::{anyhow, Context, Result}; +use pgvector as ExtPgVector; +use regex::Regex; +use sqlx::postgres::PgArguments; +use sqlx::postgres::PgPoolOptions; +use sqlx::PgPool; +use std::collections::BTreeMap; +use std::sync::Arc; +use swiftide_core::indexing::{EmbeddedField, Node}; +use tokio::time::{sleep, Duration}; + +#[derive(Clone)] +pub struct PgDBConnectionPool(Arc>); + +impl Default for PgDBConnectionPool { + fn default() -> Self { + Self(Arc::new(None)) + } +} + +impl PgDBConnectionPool { + /// Attempts to connect to the database with retries. + async fn connect_with_retry( + database_url: impl AsRef, + max_retries: u32, + pool_options: &PgPoolOptions, + ) -> Result { + for attempt in 1..=max_retries { + match pool_options.clone().connect(database_url.as_ref()).await { + Ok(pool) => { + return Ok(pool); + } + Err(_err) if attempt < max_retries => { + sleep(Duration::from_secs(2)).await; + } + Err(err) => return Err(err), + } + } + unreachable!() + } + + /// Connects to the database using the provided URL and sets the connection pool. + pub async fn try_connect_to_url( + mut self, + database_url: impl AsRef, + connection_max: Option, + ) -> Result { + let pool_options = PgPoolOptions::new().max_connections(connection_max.unwrap_or(10)); + + let pool = Self::connect_with_retry(database_url, 10, &pool_options) + .await + .context("Failed to connect to the database")?; + + self.0 = Arc::new(Some(pool)); + + Ok(self) + } + + /// Retrieves the connection pool, returning an error if the pool is not initialized. + pub fn get_pool(&self) -> Result { + self.0 + .as_ref() + .clone() + .ok_or_else(|| anyhow!("Database connection pool is not initialized")) + } + + /// Returns the connection status of the pool. + pub fn connection_status(&self) -> &'static str { + match self.0.as_ref() { + Some(pool) if !pool.is_closed() => "Open", + Some(_) => "Closed", + None => "Not initialized", + } + } +} + +#[derive(Clone, Debug)] +pub struct VectorConfig { + embedded_field: EmbeddedField, + field: String, +} + +impl VectorConfig { + pub fn new(embedded_field: &EmbeddedField) -> Self { + Self { + embedded_field: embedded_field.clone(), + field: format!( + "vector_{}", + PgVector::normalize_field_name(&embedded_field.to_string()), + ), + } + } +} + +impl From for VectorConfig { + fn from(val: EmbeddedField) -> Self { + Self::new(&val) + } +} + +#[derive(Clone, Debug)] +pub struct MetadataConfig { + field: String, + original_field: String, +} + +impl MetadataConfig { + pub fn new>(original_field: T) -> Self { + let original = original_field.into(); + Self { + field: format!("meta_{}", PgVector::normalize_field_name(&original)), + original_field: original, + } + } +} + +impl> From for MetadataConfig { + fn from(val: T) -> Self { + Self::new(val.as_ref()) + } +} + +#[derive(Clone, Debug)] +pub enum FieldConfig { + Vector(VectorConfig), + Metadata(MetadataConfig), + Chunk, + ID, +} + +impl FieldConfig { + pub fn field_name(&self) -> &str { + match self { + FieldConfig::Vector(config) => &config.field, + FieldConfig::Metadata(config) => &config.field, + FieldConfig::Chunk => "chunk", + FieldConfig::ID => "id", + } + } +} + +/// Structure to hold collected values for bulk upsert +#[derive(Default)] +struct BulkUpsertData { + ids: Vec, + chunks: Vec, + metadata_fields: BTreeMap>, + vector_fields: BTreeMap>, +} + +impl PgVector { + /// Generates a SQL statement to create a table for storing vector embeddings. 
+ /// + /// The table will include columns for an ID, chunk data, metadata, and a vector embedding. + /// + /// # Returns + /// + /// * The generated SQL statement. + /// + /// # Errors + /// + /// * Returns an error if the table name is invalid or if `vector_size` is not configured. + pub fn generate_create_table_sql(&self) -> Result { + // Validate table_name and field_name (e.g., check against allowed patterns) + if !Self::is_valid_identifier(&self.table_name) { + return Err(anyhow::anyhow!("Invalid table name")); + } + + let vector_size = self + .vector_size + .ok_or_else(|| anyhow!("vector_size must be configured"))?; + + let columns: Vec = self + .fields + .iter() + .map(|field| match field { + FieldConfig::ID => "id UUID NOT NULL".to_string(), + FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()), + FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()), + FieldConfig::Vector(_) => format!("{} VECTOR({})", field.field_name(), vector_size), + }) + .chain(std::iter::once("PRIMARY KEY (id)".to_string())) + .collect(); + + let sql = format!( + "CREATE TABLE IF NOT EXISTS {} (\n {}\n)", + self.table_name, + columns.join(",\n ") + ); + + Ok(sql) + } + + /// Generates the SQL statement to create an HNSW index on the vector column. + /// + /// # Errors + /// + /// Returns an error if: + /// - No vector field is found in the table configuration. + /// - The table name or field name is invalid. + pub fn create_index_sql(&self) -> Result { + let index_name = format!("{}_embedding_idx", self.table_name); + let vector_field = self + .fields + .iter() + .find(|f| matches!(f, FieldConfig::Vector(_))) + .ok_or_else(|| anyhow::anyhow!("No vector field found in configuration"))? + .field_name(); + + // Validate table_name and field_name (e.g., check against allowed patterns) + if !Self::is_valid_identifier(&self.table_name) + || !Self::is_valid_identifier(&index_name) + || !Self::is_valid_identifier(vector_field) + { + return Err(anyhow::anyhow!("Invalid table or field name")); + } + + Ok(format!( + "CREATE INDEX IF NOT EXISTS {} ON {} USING hnsw ({} vector_cosine_ops)", + index_name, &self.table_name, vector_field + )) + } + + /// Stores a list of nodes in the database using an upsert operation. + /// + /// # Arguments + /// + /// * `nodes` - A slice of `Node` objects to be stored. + /// + /// # Returns + /// + /// * `Result<()>` - `Ok` if all nodes are successfully stored, `Err` otherwise. + /// + /// # Errors + /// + /// This function will return an error if: + /// - The database connection pool is not established. + /// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or connectivity issues. + /// - Committing the transaction fails. + pub async fn store_nodes(&self, nodes: &[Node]) -> Result<()> { + let pool = self.connection_pool.get_pool()?; + + let mut tx = pool.begin().await?; + let bulk_data = self.prepare_bulk_data(nodes)?; + let sql = self.generate_unnest_upsert_sql()?; + + let query = self.bind_bulk_data_to_query(sqlx::query(&sql), &bulk_data)?; + + query + .execute(&mut *tx) + .await + .map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?; + + tx.commit() + .await + .map_err(|e| anyhow!("Failed to commit transaction: {:?}", e)) + } + + /// Prepares data from nodes into vectors for bulk processing. 
+
+    /// Prepares data from nodes into vectors for bulk processing.
+    #[allow(clippy::implicit_clone)]
+    fn prepare_bulk_data(&self, nodes: &[Node]) -> Result<BulkUpsertData> {
+        let mut bulk_data = BulkUpsertData::default();
+
+        for node in nodes {
+            bulk_data.ids.push(node.id());
+            bulk_data.chunks.push(node.chunk.clone());
+
+            for field in &self.fields {
+                match field {
+                    FieldConfig::Metadata(config) => {
+                        let value = node.metadata.get(&config.original_field).ok_or_else(|| {
+                            anyhow!("Metadata field {} not found", config.original_field)
+                        })?;
+
+                        let entry = bulk_data
+                            .metadata_fields
+                            .entry(config.field.clone())
+                            .or_default();
+
+                        let mut metadata_map = BTreeMap::new();
+                        metadata_map.insert(config.original_field.clone(), value.clone());
+                        entry.push(serde_json::to_value(metadata_map)?);
+                    }
+                    FieldConfig::Vector(config) => {
+                        let data = node
+                            .vectors
+                            .as_ref()
+                            .and_then(|v| v.get(&config.embedded_field))
+                            .map(|v| v.to_vec())
+                            .unwrap_or_default();
+
+                        bulk_data
+                            .vector_fields
+                            .entry(config.field.clone())
+                            .or_default()
+                            .push(ExtPgVector::Vector::from(data));
+                    }
+                    _ => continue, // ID and Chunk already handled
+                }
+            }
+        }
+
+        Ok(bulk_data)
+    }
+
+    /// Generates SQL for UNNEST-based bulk upsert.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<String>` - The generated SQL statement or an error if fields are empty.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `self.fields` is empty, as no valid SQL can be generated.
+    fn generate_unnest_upsert_sql(&self) -> Result<String> {
+        if self.fields.is_empty() {
+            return Err(anyhow!("Cannot generate upsert SQL with empty fields"));
+        }
+
+        let mut columns = Vec::new();
+        let mut unnest_params = Vec::new();
+        let mut param_counter = 1;
+
+        for field in &self.fields {
+            let name = field.field_name();
+            columns.push(name.to_string());
+
+            unnest_params.push(format!(
+                "${param_counter}::{}",
+                match field {
+                    FieldConfig::ID => "UUID[]",
+                    FieldConfig::Chunk => "TEXT[]",
+                    FieldConfig::Metadata(_) => "JSONB[]",
+                    FieldConfig::Vector(_) => "VECTOR[]",
+                }
+            ));
+
+            param_counter += 1;
+        }
+
+        let update_columns = self
+            .fields
+            .iter()
+            .filter(|field| !matches!(field, FieldConfig::ID)) // Skip ID field in updates
+            .map(|field| {
+                let name = field.field_name();
+                format!("{name} = EXCLUDED.{name}")
+            })
+            .collect::<Vec<_>>()
+            .join(", ");
+
+        Ok(format!(
+            r#"
+            INSERT INTO {} ({})
+            SELECT {}
+            FROM UNNEST({}) AS t({})
+            ON CONFLICT (id) DO UPDATE SET {}"#,
+            self.table_name,
+            columns.join(", "),
+            columns.join(", "),
+            unnest_params.join(", "),
+            columns.join(", "),
+            update_columns
+        ))
+    }
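The `UNNEST` construction is the heart of the bulk upsert: each column is bound once as a Postgres array, so a single statement inserts or updates the whole batch. For the same assumed configuration, the generated statement looks roughly like this (a sketch; field order depends on `self.fields`):

````rust
let sql = self.generate_unnest_upsert_sql()?;
// INSERT INTO swiftide_pgvector_test (id, chunk, vector_combined, meta_filter)
// SELECT id, chunk, vector_combined, meta_filter
// FROM UNNEST($1::UUID[], $2::TEXT[], $3::VECTOR[], $4::JSONB[])
//     AS t(id, chunk, vector_combined, meta_filter)
// ON CONFLICT (id) DO UPDATE SET
//     chunk = EXCLUDED.chunk,
//     vector_combined = EXCLUDED.vector_combined,
//     meta_filter = EXCLUDED.meta_filter
````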
+
+    /// Binds bulk data to the SQL query, ensuring data arrays are matched to corresponding fields.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if any metadata or vector field is missing from the bulk data.
+    #[allow(clippy::implicit_clone)]
+    fn bind_bulk_data_to_query<'a>(
+        &self,
+        mut query: sqlx::query::Query<'a, sqlx::Postgres, PgArguments>,
+        bulk_data: &'a BulkUpsertData,
+    ) -> Result<sqlx::query::Query<'a, sqlx::Postgres, PgArguments>> {
+        for field in &self.fields {
+            query = match field {
+                FieldConfig::ID => query.bind(&bulk_data.ids),
+                FieldConfig::Chunk => query.bind(&bulk_data.chunks),
+                FieldConfig::Metadata(config) => {
+                    let values = bulk_data
+                        .metadata_fields
+                        .get(&config.field)
+                        .ok_or_else(|| {
+                            anyhow!("Metadata field {} not found in bulk data", config.field)
+                        })?;
+                    query.bind(values)
+                }
+                FieldConfig::Vector(config) => {
+                    let vectors = bulk_data.vector_fields.get(&config.field).ok_or_else(|| {
+                        anyhow!("Vector field {} not found in bulk data", config.field)
+                    })?;
+                    query.bind(vectors)
+                }
+            };
+        }
+
+        Ok(query)
+    }
+
+    /// Retrieves the name of the vector column configured in the schema.
+    ///
+    /// # Returns
+    /// * `Ok(String)` - The name of the vector column if exactly one is configured.
+    ///
+    /// # Errors
+    /// * Returns an error if no vector field is configured in the schema.
+    /// * Returns an error if multiple vector fields are configured in the schema.
+    pub fn get_vector_column_name(&self) -> Result<String> {
+        let vector_fields: Vec<_> = self
+            .fields
+            .iter()
+            .filter(|field| matches!(field, FieldConfig::Vector(_)))
+            .collect();
+
+        match vector_fields.as_slice() {
+            [field] => Ok(field.field_name().to_string()),
+            [] => Err(anyhow!("No vector field configured in schema")),
+            _ => Err(anyhow!("Multiple vector fields configured in schema")),
+        }
+    }
+}
+
+impl PgVector {
+    pub(crate) fn normalize_field_name(field: &str) -> String {
+        field
+            .to_lowercase()
+            .replace(|c: char| !c.is_alphanumeric(), "_")
+    }
+
+    pub(crate) fn is_valid_identifier(identifier: &str) -> bool {
+        // PostgreSQL identifier rules:
+        // 1. Must start with a letter (a-z) or underscore
+        // 2. Subsequent characters can be letters, underscores, digits (0-9), or dollar signs
+        // 3. Maximum length is 63 bytes
+        // 4. Cannot be a reserved keyword
+
+        // Check length
+        if identifier.is_empty() || identifier.len() > 63 {
+            return false;
+        }
+
+        // Use a regular expression to check the pattern
+        let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_$]*$").unwrap();
+        if !identifier_regex.is_match(identifier) {
+            return false;
+        }
+
+        // Check if it's not a reserved keyword
+        !Self::is_reserved_keyword(identifier)
+    }
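A note on why the validator exists: Postgres can only bind values (`$1`, `$2`, …), never identifiers, so table, column, and index names must be interpolated into the SQL strings above, and `is_valid_identifier` guards that interpolation. For example:

````rust
assert!(PgVector::is_valid_identifier("swiftide_pgvector_test"));
assert!(!PgVector::is_valid_identifier("1st_table"));               // starts with a digit
assert!(!PgVector::is_valid_identifier("users; DROP TABLE users")); // injection attempt
assert!(!PgVector::is_valid_identifier("select"));                  // reserved keyword
````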
+
+    pub(crate) fn is_reserved_keyword(word: &str) -> bool {
+        // This list is not exhaustive. You may want to expand it based on
+        // the PostgreSQL version you're using.
+        const RESERVED_KEYWORDS: &[&str] = &[
+            "SELECT", "FROM", "WHERE", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "TABLE",
+            "INDEX", "ALTER", "ADD", "COLUMN", "AND", "OR", "NOT", "NULL", "TRUE", "FALSE",
+            // Add more keywords as needed
+        ];
+
+        RESERVED_KEYWORDS.contains(&word.to_uppercase().as_str())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_valid_identifiers() {
+        assert!(PgVector::is_valid_identifier("valid_name"));
+        assert!(PgVector::is_valid_identifier("_valid_name"));
+        assert!(PgVector::is_valid_identifier("valid_name_123"));
+        assert!(PgVector::is_valid_identifier("validName"));
+    }
+
+    #[test]
+    fn test_invalid_identifiers() {
+        assert!(!PgVector::is_valid_identifier("")); // Empty string
+        assert!(!PgVector::is_valid_identifier(&"a".repeat(64))); // Too long
+        assert!(!PgVector::is_valid_identifier("123_invalid")); // Starts with a number
+        assert!(!PgVector::is_valid_identifier("invalid-name")); // Contains hyphen
+        assert!(!PgVector::is_valid_identifier("select")); // Reserved keyword
+    }
+}
diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs
new file mode 100644
index 00000000..7987650d
--- /dev/null
+++ b/swiftide-integrations/src/pgvector/retrieve.rs
@@ -0,0 +1,223 @@
+use crate::pgvector::{PgVector, PgVectorBuilder};
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use pgvector::Vector;
+use sqlx::{prelude::FromRow, types::Uuid};
+use swiftide_core::{
+    querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
+    Retrieve,
+};
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, FromRow)]
+struct VectorSearchResult {
+    id: Uuid,
+    chunk: String,
+}
+
+#[allow(clippy::redundant_closure_for_method_calls)]
+#[async_trait]
+impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
+    #[tracing::instrument]
+    async fn retrieve(
+        &self,
+        search_strategy: &SimilaritySingleEmbedding<String>,
+        query_state: Query<states::Pending>,
+    ) -> Result<Query<states::Retrieved>> {
+        let embedding = query_state
+            .embedding
+            .as_ref()
+            .ok_or_else(|| anyhow!("No embedding for query"))?;
+        let embedding = Vector::from(embedding.clone());
+
+        let pool = self.connection_pool.get_pool()?;
+
+        let default_columns: Vec<_> = PgVectorBuilder::default_fields()
+            .iter()
+            .map(|f| f.field_name().to_string())
+            .collect();
+        let vector_column_name = self.get_vector_column_name()?;
+
+        // Start building the SQL query
+        let mut sql = format!(
+            "SELECT {} FROM {}",
+            default_columns.join(", "),
+            self.table_name
+        );
+
+        if let Some(filter) = search_strategy.filter() {
+            let filter_parts: Vec<&str> = filter.split('=').collect();
+            if filter_parts.len() == 2 {
+                let key = filter_parts[0].trim();
+                let value = filter_parts[1].trim().trim_matches('"');
+                tracing::debug!(
+                    "Filter being applied: key = {:#?}, value = {:#?}",
+                    key,
+                    value
+                );
+
+                let sql_filter = format!(
+                    " WHERE meta_{}->>'{}' = '{}'",
+                    PgVector::normalize_field_name(key),
+                    key,
+                    value
+                );
+                sql.push_str(&sql_filter);
+            } else {
+                return Err(anyhow!("Invalid filter format"));
+            }
+        }
+
+        // Add the ORDER BY clause for vector similarity search
+        sql.push_str(&format!(
+            " ORDER BY {} <=> $1 LIMIT $2",
+            &vector_column_name
+        ));
+
+        tracing::debug!("Running retrieve with SQL: {}", sql);
+
+        let top_k = i32::try_from(search_strategy.top_k())
+            .map_err(|_| anyhow!("Failed to convert top_k to i32"))?;
+
+        let data: Vec<VectorSearchResult> = sqlx::query_as(&sql)
+            .bind(embedding)
+            .bind(top_k)
+            .fetch_all(&pool)
+            .await?;
+
+        let docs = data.into_iter().map(|r| r.chunk).collect();
+
+        Ok(query_state.retrieved_documents(docs))
+    }
+}
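The filter handling above accepts exactly one `key = "value"` pair and interpolates both key and value into the statement; only the embedding and `top_k` are bound as parameters. Under the test configuration, a filtered retrieve therefore builds SQL along these lines (a sketch, not verbatim output):

````rust
// For SimilaritySingleEmbedding::from_filter(r#"filter = "true""#.to_string()):
//
// SELECT id, chunk FROM swiftide_pgvector_test
//     WHERE meta_filter->>'filter' = 'true'
//     ORDER BY vector_combined <=> $1
//     LIMIT $2
//
// $1 is the query embedding, $2 is top_k.
````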
+
+#[async_trait]
+impl Retrieve<SimilaritySingleEmbedding> for PgVector {
+    async fn retrieve(
+        &self,
+        search_strategy: &SimilaritySingleEmbedding,
+        query: Query<states::Pending>,
+    ) -> Result<Query<states::Retrieved>> {
+        Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
+            self,
+            &search_strategy.into_concrete_filter::<String>(),
+            query,
+        )
+        .await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pgvector::PgVector;
+    use futures_util::TryStreamExt;
+    use swiftide_core::{indexing, indexing::EmbeddedField, Persist};
+    use swiftide_core::{
+        querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
+        Retrieve,
+    };
+    use testcontainers::{ContainerAsync, GenericImage};
+
+    struct TestContext {
+        pgv_storage: PgVector,
+        _pgv_db_container: ContainerAsync<GenericImage>,
+    }
+
+    impl TestContext {
+        /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage
+        async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
+            // Start PostgreSQL container and obtain the connection URL
+            let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;
+
+            tracing::info!("Postgres database URL: {:#?}", pgv_db_url);
+
+            // Configure and build PgVector storage
+            let pgv_storage = PgVector::builder()
+                .try_connect_to_pool(pgv_db_url, Some(10))
+                .await
+                .map_err(|err| {
+                    tracing::error!("Failed to connect to Postgres server: {}", err);
+                    err
+                })?
+                .vector_size(384)
+                .with_vector(EmbeddedField::Combined)
+                .with_metadata("filter")
+                .table_name("swiftide_pgvector_test".to_string())
+                .build()
+                .map_err(|err| {
+                    tracing::error!("Failed to build PgVector: {}", err);
+                    err
+                })?;
+
+            // Set up PgVector storage (create the table if not exists)
+            pgv_storage.setup().await.map_err(|err| {
+                tracing::error!("PgVector setup failed: {}", err);
+                err
+            })?;
+
+            Ok(Self {
+                pgv_storage,
+                _pgv_db_container: pgv_db_container,
+            })
+        }
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_retrieve_multiple_docs_and_filter() {
+        let test_context = TestContext::setup().await.expect("Test setup failed");
+
+        let nodes = vec![
+            indexing::Node::new("test_query1").with_metadata(("filter", "true")),
+            indexing::Node::new("test_query2").with_metadata(("filter", "true")),
+            indexing::Node::new("test_query3").with_metadata(("filter", "false")),
+        ]
+        .into_iter()
+        .map(|mut node| {
+            node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
+            node.to_owned()
+        })
+        .collect();
+
+        test_context
+            .pgv_storage
+            .batch_store(nodes)
+            .await
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let mut query = Query::<states::Pending>::new("test_query");
+        query.embedding = Some(vec![1.0; 384]);
+
+        let search_strategy = SimilaritySingleEmbedding::<()>::default();
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query.clone())
+            .await
+            .unwrap();
+
+        assert_eq!(result.documents().len(), 3);
+
+        let search_strategy =
+            SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string());
+
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query.clone())
+            .await
+            .unwrap();
+
+        assert_eq!(result.documents().len(), 2);
+
+        let search_strategy =
+            SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string());
+
+        let result = test_context
+            .pgv_storage
+            .retrieve(&search_strategy, query.clone())
+            .await
+            .unwrap();
+        assert_eq!(result.documents().len(), 0);
+    }
+}
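End to end, a filtered retrieval against this store mirrors the test above. A hedged sketch, assuming an embedding has already been computed for the query:

````rust
let mut query = Query::<states::Pending>::new("what is swiftide?");
query.embedding = Some(vec![1.0; 384]); // normally produced by an embedding model

let strategy = SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string());
let retrieved = pgv_storage.retrieve(&strategy, query).await?;
assert_eq!(retrieved.documents().len(), 2); // the two nodes stored with filter = "true"
````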
diff --git a/swiftide-test-utils/Cargo.toml b/swiftide-test-utils/Cargo.toml
index 5df13155..1bce78dd 100644
--- a/swiftide-test-utils/Cargo.toml
+++ b/swiftide-test-utils/Cargo.toml
@@ -13,7 +13,7 @@
 homepage.workspace = true
 
 [dependencies]
 swiftide-core = { path = "../swiftide-core", features = ["test-utils"] }
-swiftide-integrations = { path = "../swiftide-integrations", all-features = true }
+swiftide-integrations = { path = "../swiftide-integrations", features = ["openai"] }
 async-openai = { workspace = true }
 qdrant-client = { workspace = true, default-features = false, features = [
@@ -31,6 +31,8 @@ insta = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 tokio = { workspace = true }
+tempfile = { workspace = true }
+portpicker = { workspace = true }
 
 [features]
 default = ["test-utils"]
diff --git a/swiftide-test-utils/src/test_utils.rs b/swiftide-test-utils/src/test_utils.rs
index 86ba416a..b95d670c 100644
--- a/swiftide-test-utils/src/test_utils.rs
+++ b/swiftide-test-utils/src/test_utils.rs
@@ -3,7 +3,9 @@
 use serde_json::json;
 use testcontainers::{
-    core::wait::HttpWaitStrategy, runners::AsyncRunner as _, ContainerAsync, GenericImage,
+    core::{wait::HttpWaitStrategy, IntoContainerPort, Mount, WaitFor},
+    runners::AsyncRunner,
+    ContainerAsync, GenericImage, ImageExt,
 };
 use wiremock::matchers::{method, path};
 use wiremock::{Mock, MockServer, ResponseTemplate};
@@ -70,6 +72,34 @@ pub async fn start_redis() -> (ContainerAsync<GenericImage>, String) {
     (redis, redis_url)
 }
 
+/// Setup Postgres container.
+/// Returns container server and `server_url`.
+pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String) {
+    // Find a free port on the host for Postgres to use
+    let host_port = portpicker::pick_unused_port().expect("No available free port on the host");
+
+    let postgres = testcontainers::GenericImage::new("pgvector/pgvector", "pg17")
+        .with_wait_for(WaitFor::message_on_stdout(
+            "database system is ready to accept connections",
+        ))
+        .with_mapped_port(host_port, 5432.tcp())
+        .with_env_var("POSTGRES_USER", "myuser")
+        .with_env_var("POSTGRES_PASSWORD", "mypassword")
+        .with_env_var("POSTGRES_DB", "mydatabase")
+        .with_mount(Mount::tmpfs_mount("/var/lib/postgresql/data"))
+        .start()
+        .await
+        .expect("Failed to start Postgres container");
+
+    // Construct the connection URL using the dynamically assigned port
+    let postgres_url = format!(
+        "postgresql://myuser:mypassword@127.0.0.1:{}/mydatabase",
+        host_port
+    );
+
+    (postgres, postgres_url)
+}
+
 /// Mock embeddings creation endpoint.
 /// `embeddings_count` controls number of returned embedding vectors.
 pub async fn mock_embeddings(mock_server: &MockServer, embeddings_count: u8) {
diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml
index e2e9204d..ed7664b3 100644
--- a/swiftide/Cargo.toml
+++ b/swiftide/Cargo.toml
@@ -35,6 +35,7 @@ all = [
   "aws-bedrock",
   "groq",
   "ollama",
+  "pgvector",
 ]
 
 #! ### Integrations
@@ -42,6 +43,9 @@ all = [
 ## Enables Qdrant for storage and retrieval
 qdrant = ["swiftide-integrations/qdrant", "swiftide-core/qdrant"]
 
+## Enables PgVector for storage and retrieval
+pgvector = ["swiftide-integrations/pgvector"]
+
 ## Enables Redis as an indexing cache and storage
 redis = ["swiftide-integrations/redis"]
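With the new `pgvector` feature enabled on the `swiftide` crate (directly or via `all`), wiring up a store follows the builder used in the tests. A sketch that reuses the new `start_postgres` helper, with error handling elided:

````rust
let (_pg_container, db_url) = swiftide_test_utils::start_postgres().await;

let store = PgVector::builder()
    .try_connect_to_pool(db_url, Some(10))
    .await?
    .vector_size(384)
    .with_vector(EmbeddedField::Combined)
    .with_metadata("filter")
    .table_name("my_documents".to_string())
    .build()?;

store.setup().await?; // creates the table and HNSW index if they do not exist
````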