Skip to content

Commit

Permalink
feat: Implement traits T for Box<T> for indexing and query traits (#285
Browse files Browse the repository at this point in the history
)

When working with trait objects, some pipeline steps now allow for
Box<dyn Trait> as well.
  • Loading branch information
timonv authored Sep 10, 2024
1 parent dfa546b commit 3c9491b
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 8 deletions.
199 changes: 199 additions & 0 deletions swiftide-core/src/indexing_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ pub trait Transformer: Send + Sync {
}
}

#[async_trait]
/// Lets a `Box<dyn Transformer>` be used wherever a `Transformer` is expected
/// by forwarding every call to the boxed trait object.
impl Transformer for Box<dyn Transformer> {
    /// Passes `node` to the boxed transformer and returns its result unchanged.
    async fn transform_node(&self, node: Node) -> Result<Node> {
        let inner = &**self;
        inner.transform_node(node).await
    }

    /// Returns the concurrency reported by the boxed transformer.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

#[async_trait]
/// Lets a borrowed `&dyn Transformer` be used wherever a `Transformer` is expected.
impl Transformer for &dyn Transformer {
    /// Passes `node` to the referenced transformer and returns its result unchanged.
    async fn transform_node(&self, node: Node) -> Result<Node> {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.transform_node(node).await
    }

    /// Returns the concurrency reported by the referenced transformer.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

#[async_trait]
/// Use a closure as a transformer
impl<F> Transformer for F
Expand Down Expand Up @@ -66,10 +86,65 @@ where
}
}

#[async_trait]
/// Lets a `Box<dyn BatchableTransformer>` stand in for a `BatchableTransformer`
/// by forwarding every call to the boxed trait object.
impl BatchableTransformer for Box<dyn BatchableTransformer> {
    /// Passes the whole batch to the boxed transformer and returns its stream.
    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
        let inner = &**self;
        inner.batch_transform(nodes).await
    }

    /// Returns the concurrency reported by the boxed transformer.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

#[async_trait]
/// Lets a borrowed `&dyn BatchableTransformer` stand in for a `BatchableTransformer`.
impl BatchableTransformer for &dyn BatchableTransformer {
    /// Passes the whole batch to the referenced transformer and returns its stream.
    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.batch_transform(nodes).await
    }

    /// Returns the concurrency reported by the referenced transformer.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

/// Starting point of a stream
///
/// A loader is consumed (taken by value) to produce an `IndexingStream`.
#[cfg_attr(feature = "test-utils", automock, doc(hidden))]
pub trait Loader {
/// Consumes the loader and returns the stream it produces.
fn into_stream(self) -> IndexingStream;

/// Intended for use with `Box<dyn Loader>`
///
/// Only needed if you use trait objects (`Box<dyn Loader>`)
///
/// # Example
///
/// ```ignore
/// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
///     self.into_stream()
/// }
/// ```
///
/// # Panics
///
/// The default implementation always panics through `unimplemented!`;
/// concrete loaders must override it before they are used behind a `Box`.
fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
unimplemented!("Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type")
}
}

/// Lets a `Box<dyn Loader>` be used as a `Loader`.
impl Loader for Box<dyn Loader> {
    /// Consumes the box by dispatching to the boxed loader's own
    /// `into_stream_boxed`.
    fn into_stream(self) -> IndexingStream {
        <dyn Loader as Loader>::into_stream_boxed(self)
    }

    /// Removes the outer box and continues with `into_stream` on the inner
    /// `Box<dyn Loader>`.
    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
        let inner: Self = *self;
        <Self as Loader>::into_stream(inner)
    }
}

/// `Loader` for a borrowed trait object.
///
/// `Loader::into_stream` consumes the loader, which a shared reference cannot
/// provide, so there is no underlying implementation to delegate to. The two
/// methods must not forward to each other: on `&dyn Loader` both calls resolve
/// back to this impl and recurse (allocating a `Box` per round) until the stack
/// overflows. Instead, fail immediately with a descriptive message.
impl Loader for &dyn Loader {
    /// # Panics
    ///
    /// Always panics; a `&dyn Loader` cannot be consumed into a stream.
    fn into_stream(self) -> IndexingStream {
        unimplemented!(
            "A `&dyn Loader` cannot be turned into a stream because `into_stream` consumes the loader; use an owned loader or a `Box<dyn Loader>` instead"
        )
    }

    /// Unboxes the reference and forwards to `into_stream` of this impl.
    ///
    /// # Panics
    ///
    /// Always panics, because `into_stream` for `&dyn Loader` does.
    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
        Loader::into_stream(*self)
    }
}

#[cfg_attr(feature = "test-utils", automock, doc(hidden))]
Expand All @@ -84,6 +159,26 @@ pub trait ChunkerTransformer: Send + Sync + Debug {
}
}

#[async_trait]
/// Lets a `Box<dyn ChunkerTransformer>` stand in for a `ChunkerTransformer`
/// by forwarding every call to the boxed trait object.
impl ChunkerTransformer for Box<dyn ChunkerTransformer> {
    /// Passes `node` to the boxed chunker and returns the stream it produces.
    async fn transform_node(&self, node: Node) -> IndexingStream {
        let inner = &**self;
        inner.transform_node(node).await
    }

    /// Returns the concurrency reported by the boxed chunker.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

#[async_trait]
/// Lets a borrowed `&dyn ChunkerTransformer` stand in for a `ChunkerTransformer`.
impl ChunkerTransformer for &dyn ChunkerTransformer {
    /// Passes `node` to the referenced chunker and returns the stream it produces.
    async fn transform_node(&self, node: Node) -> IndexingStream {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.transform_node(node).await
    }

    /// Returns the concurrency reported by the referenced chunker.
    fn concurrency(&self) -> Option<usize> {
        (**self).concurrency()
    }
}

#[cfg_attr(feature = "test-utils", automock)]
#[async_trait]
/// Caches nodes, typically by their path and hash
Expand All @@ -95,6 +190,26 @@ pub trait NodeCache: Send + Sync + Debug {
async fn set(&self, node: &Node);
}

#[async_trait]
/// Lets a `Box<dyn NodeCache>` stand in for a `NodeCache` by forwarding
/// every call to the boxed trait object.
impl NodeCache for Box<dyn NodeCache> {
    /// Asks the boxed cache about `node` and returns its answer.
    async fn get(&self, node: &Node) -> bool {
        let inner = &**self;
        inner.get(node).await
    }

    /// Forwards `node` to the boxed cache's `set`.
    async fn set(&self, node: &Node) {
        let inner = &**self;
        inner.set(node).await
    }
}

#[async_trait]
/// Lets a borrowed `&dyn NodeCache` stand in for a `NodeCache`.
impl NodeCache for &dyn NodeCache {
    /// Asks the referenced cache about `node` and returns its answer.
    async fn get(&self, node: &Node) -> bool {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.get(node).await
    }

    /// Forwards `node` to the referenced cache's `set`.
    async fn set(&self, node: &Node) {
        let inner = *self;
        inner.set(node).await
    }
}

#[cfg_attr(feature = "test-utils", automock)]
#[async_trait]
/// Embeds a list of strings and returns its embeddings.
Expand All @@ -103,6 +218,20 @@ pub trait EmbeddingModel: Send + Sync + Debug {
async fn embed(&self, input: Vec<String>) -> Result<Embeddings>;
}

#[async_trait]
/// Lets a `Box<dyn EmbeddingModel>` stand in for an `EmbeddingModel`.
impl EmbeddingModel for Box<dyn EmbeddingModel> {
    /// Passes `input` to the boxed model and returns its embeddings result.
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
        let inner = &**self;
        inner.embed(input).await
    }
}

#[async_trait]
/// Lets a borrowed `&dyn EmbeddingModel` stand in for an `EmbeddingModel`.
impl EmbeddingModel for &dyn EmbeddingModel {
    /// Passes `input` to the referenced model and returns its embeddings result.
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.embed(input).await
    }
}

#[cfg_attr(feature = "test-utils", automock)]
#[async_trait]
/// Embeds a list of strings and returns its embeddings.
Expand All @@ -111,6 +240,20 @@ pub trait SparseEmbeddingModel: Send + Sync + Debug {
async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings>;
}

#[async_trait]
/// Lets a `Box<dyn SparseEmbeddingModel>` stand in for a `SparseEmbeddingModel`.
impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
    /// Passes `input` to the boxed model and returns its sparse embeddings result.
    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings> {
        let inner = &**self;
        inner.sparse_embed(input).await
    }
}

#[async_trait]
/// Lets a borrowed `&dyn SparseEmbeddingModel` stand in for a `SparseEmbeddingModel`.
impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
    /// Passes `input` to the referenced model and returns its sparse embeddings result.
    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings> {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.sparse_embed(input).await
    }
}

#[cfg_attr(feature = "test-utils", automock)]
#[async_trait]
/// Given a string prompt, queries an LLM
Expand All @@ -119,6 +262,20 @@ pub trait SimplePrompt: Debug + Send + Sync {
async fn prompt(&self, prompt: Prompt) -> Result<String>;
}

#[async_trait]
/// Lets a `Box<dyn SimplePrompt>` stand in for a `SimplePrompt`.
impl SimplePrompt for Box<dyn SimplePrompt> {
    /// Passes `prompt` to the boxed implementation and returns its response.
    async fn prompt(&self, prompt: Prompt) -> Result<String> {
        let inner = &**self;
        inner.prompt(prompt).await
    }
}

#[async_trait]
/// Lets a borrowed `&dyn SimplePrompt` stand in for a `SimplePrompt`.
impl SimplePrompt for &dyn SimplePrompt {
    /// Passes `prompt` to the referenced implementation and returns its response.
    async fn prompt(&self, prompt: Prompt) -> Result<String> {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.prompt(prompt).await
    }
}

#[cfg_attr(feature = "test-utils", automock)]
#[async_trait]
/// Persists nodes
Expand All @@ -131,6 +288,38 @@ pub trait Persist: Debug + Send + Sync {
}
}

#[async_trait]
/// Lets a `Box<dyn Persist>` stand in for a `Persist` by forwarding every
/// call to the boxed trait object.
impl Persist for Box<dyn Persist> {
    /// Runs the boxed storage's `setup`.
    async fn setup(&self) -> Result<()> {
        let inner = &**self;
        inner.setup().await
    }

    /// Stores a single node through the boxed storage.
    async fn store(&self, node: Node) -> Result<Node> {
        let inner = &**self;
        inner.store(node).await
    }

    /// Stores a batch of nodes through the boxed storage and returns its stream.
    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
        let inner = &**self;
        inner.batch_store(nodes).await
    }

    /// Returns the batch size reported by the boxed storage.
    fn batch_size(&self) -> Option<usize> {
        (**self).batch_size()
    }
}

#[async_trait]
/// Lets a borrowed `&dyn Persist` stand in for a `Persist`.
impl Persist for &dyn Persist {
    /// Runs the referenced storage's `setup`.
    async fn setup(&self) -> Result<()> {
        // Copy the inner reference out so the call targets the trait object itself.
        let inner = *self;
        inner.setup().await
    }

    /// Stores a single node through the referenced storage.
    async fn store(&self, node: Node) -> Result<Node> {
        let inner = *self;
        inner.store(node).await
    }

    /// Stores a batch of nodes through the referenced storage and returns its stream.
    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
        let inner = *self;
        inner.batch_store(nodes).await
    }

    /// Returns the batch size reported by the referenced storage.
    fn batch_size(&self) -> Option<usize> {
        (**self).batch_size()
    }
}

/// Allows for passing defaults from the pipeline to the transformer
/// Required for batch transformers as at least a marker, implementation is not required
pub trait WithIndexingDefaults {
Expand All @@ -144,7 +333,17 @@ pub trait WithBatchIndexingDefaults {
}

// Empty impl: the `dyn Transformer` trait object overrides nothing and so
// uses whatever default behaviour `WithIndexingDefaults` provides.
impl WithIndexingDefaults for dyn Transformer {}
/// Forwards indexing defaults from a `Box<dyn Transformer>` to the boxed
/// trait object.
impl WithIndexingDefaults for Box<dyn Transformer> {
    /// Hands `indexing_defaults` to the boxed `dyn Transformer`.
    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
        // NOTE(review): this call lands on the `WithIndexingDefaults` impl for
        // `dyn Transformer`, which overrides nothing — confirm the defaults are
        // meant to stop there rather than reach the concrete transformer.
        (**self).with_indexing_defaults(indexing_defaults);
    }
}
// Empty impl: the `dyn BatchableTransformer` trait object overrides nothing and
// so uses whatever default behaviour `WithBatchIndexingDefaults` provides.
impl WithBatchIndexingDefaults for dyn BatchableTransformer {}
/// Forwards indexing defaults from a `Box<dyn BatchableTransformer>` to the
/// boxed trait object.
impl WithBatchIndexingDefaults for Box<dyn BatchableTransformer> {
    /// Hands `indexing_defaults` to the boxed `dyn BatchableTransformer`.
    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
        // NOTE(review): this call lands on the `WithBatchIndexingDefaults` impl
        // for `dyn BatchableTransformer`, which overrides nothing — confirm the
        // defaults are meant to stop there.
        (**self).with_indexing_defaults(indexing_defaults);
    }
}

// Closures used as transformers get the marker traits with no overrides:
// `Fn(Node) -> Result<Node>` for single-node transformers and
// `Fn(Vec<Node>) -> IndexingStream` for batch transformers.
impl<F> WithIndexingDefaults for F where F: Fn(Node) -> Result<Node> {}
impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node>) -> IndexingStream {}
Expand Down
58 changes: 50 additions & 8 deletions swiftide-core/src/query_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{

/// Can transform queries before retrieval
#[async_trait]
pub trait TransformQuery: Send + Sync + ToOwned {
pub trait TransformQuery: Send + Sync {
async fn transform_query(
&self,
query: Query<states::Pending>,
Expand All @@ -21,7 +21,7 @@ pub trait TransformQuery: Send + Sync + ToOwned {
#[async_trait]
impl<F> TransformQuery for F
where
F: Fn(Query<states::Pending>) -> Result<Query<states::Pending>> + Send + Sync + ToOwned,
F: Fn(Query<states::Pending>) -> Result<Query<states::Pending>> + Send + Sync,
{
async fn transform_query(
&self,
Expand All @@ -31,24 +31,45 @@ where
}
}

#[async_trait]
/// Lets a `Box<dyn TransformQuery>` stand in for a `TransformQuery`.
impl TransformQuery for Box<dyn TransformQuery> {
    /// Passes the pending query to the boxed transformer and returns its result.
    async fn transform_query(
        &self,
        query: Query<states::Pending>,
    ) -> Result<Query<states::Pending>> {
        let inner = &**self;
        inner.transform_query(query).await
    }
}

/// A search strategy for the query pipeline
///
/// Declares no methods of its own; implementors must be `Clone`, `Send`,
/// `Sync` and `Default`.
pub trait SearchStrategy: Clone + Send + Sync + Default {}

/// Can retrieve documents given a SearchStrategy
#[async_trait]
pub trait Retrieve<S: SearchStrategy>: Send + Sync + ToOwned {
pub trait Retrieve<S: SearchStrategy>: Send + Sync {
async fn retrieve(
&self,
search_strategy: &S,
query: Query<states::Pending>,
) -> Result<Query<states::Retrieved>>;
}

#[async_trait]
/// Lets a `Box<dyn Retrieve<S>>` stand in for a `Retrieve<S>`.
impl<S: SearchStrategy> Retrieve<S> for Box<dyn Retrieve<S>> {
    /// Passes the strategy and pending query to the boxed retriever and
    /// returns its result.
    async fn retrieve(
        &self,
        search_strategy: &S,
        query: Query<states::Pending>,
    ) -> Result<Query<states::Retrieved>> {
        let inner = &**self;
        inner.retrieve(search_strategy, query).await
    }
}

#[async_trait]
impl<S, F> Retrieve<S> for F
where
S: SearchStrategy,
F: Fn(&S, Query<states::Pending>) -> Result<Query<states::Retrieved>> + Send + Sync + ToOwned,
F: Fn(&S, Query<states::Pending>) -> Result<Query<states::Retrieved>> + Send + Sync,
{
async fn retrieve(
&self,
Expand All @@ -61,41 +82,62 @@ where

/// Can transform a response after retrieval
#[async_trait]
pub trait TransformResponse: Send + Sync + ToOwned {
pub trait TransformResponse: Send + Sync {
async fn transform_response(&self, query: Query<Retrieved>)
-> Result<Query<states::Retrieved>>;
}

#[async_trait]
/// Use a closure as a response transformer
impl<F> TransformResponse for F
where
    F: Fn(Query<Retrieved>) -> Result<Query<Retrieved>> + Send + Sync,
{
    /// Calls the closure with `query` and returns its result directly.
    async fn transform_response(&self, query: Query<Retrieved>) -> Result<Query<Retrieved>> {
        (self)(query)
    }
}

#[async_trait]
/// Lets a `Box<dyn TransformResponse>` stand in for a `TransformResponse`.
impl TransformResponse for Box<dyn TransformResponse> {
    /// Passes the retrieved query to the boxed transformer and returns its result.
    async fn transform_response(&self, query: Query<Retrieved>) -> Result<Query<Retrieved>> {
        let inner = &**self;
        inner.transform_response(query).await
    }
}

/// Can answer the original query
#[async_trait]
pub trait Answer: Send + Sync + ToOwned {
pub trait Answer: Send + Sync {
async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>>;
}

#[async_trait]
/// Use a closure to answer a query
impl<F> Answer for F
where
    F: Fn(Query<Retrieved>) -> Result<Query<states::Answered>> + Send + Sync,
{
    /// Calls the closure with `query` and returns its result directly.
    async fn answer(&self, query: Query<Retrieved>) -> Result<Query<states::Answered>> {
        (self)(query)
    }
}

#[async_trait]
/// Lets a `Box<dyn Answer>` stand in for an `Answer`.
impl Answer for Box<dyn Answer> {
    /// Passes the retrieved query to the boxed implementation and returns its result.
    async fn answer(&self, query: Query<Retrieved>) -> Result<Query<states::Answered>> {
        let inner = &**self;
        inner.answer(query).await
    }
}

/// Evaluates a query
///
/// An evaluator needs to be able to respond to each step in the query pipeline
#[async_trait]
pub trait EvaluateQuery: Send + Sync {
/// Called with a single `QueryEvaluation`; returns `Ok(())` or an error.
async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()>;
}

#[async_trait]
/// Lets a `Box<dyn EvaluateQuery>` stand in for an `EvaluateQuery`.
impl EvaluateQuery for Box<dyn EvaluateQuery> {
    /// Passes `evaluation` to the boxed evaluator and returns its result.
    async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()> {
        let inner = &**self;
        inner.evaluate(evaluation).await
    }
}
Loading

0 comments on commit 3c9491b

Please sign in to comment.