Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Filter operator #3

Merged
merged 4 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added eggstrain/cargo
Empty file.
77 changes: 77 additions & 0 deletions eggstrain/src/execution/operators/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use super::{Operator, UnaryOperator};
use arrow::compute::filter_record_batch;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::common::cast::as_boolean_array;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::PhysicalExpr;
use datafusion_common::Result;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;

/// Operator that holds a filter predicate together with the child plans it reads from.
pub struct Filter {
/// Physical expression evaluated against each incoming batch by `batch_filter`.
pub predicate: Arc<dyn PhysicalExpr>,
/// Child execution plans; handed back to callers through `Operator::children`.
pub children: Vec<Arc<dyn ExecutionPlan>>,
}

impl Filter {
pub fn new(predicate: Arc<dyn PhysicalExpr>, children: Vec<Arc<dyn ExecutionPlan>>) -> Self {
Self {
predicate,
children,
}
}

/// https://docs.rs/datafusion-physical-plan/36.0.0/src/datafusion_physical_plan/filter.rs.html#307
pub fn batch_filter(&self, batch: RecordBatch) -> Result<RecordBatch> {
self.predicate
.evaluate(&batch)
.and_then(|v| v.into_array(batch.num_rows()))
.and_then(|array| {
Ok(as_boolean_array(&array)?)
// apply filter array to record batch
.and_then(|filter_array| Ok(filter_record_batch(&batch, filter_array)?))
})
}
}

impl Operator for Filter {
    /// Returns a new vector holding a cloned `Arc` handle for every child plan.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        self.children.iter().cloned().collect()
    }
}

#[async_trait]
impl UnaryOperator for Filter {
    type In = RecordBatch;
    type Out = RecordBatch;

    /// Wraps this filter in an `Arc` so it can be used as a `dyn UnaryOperator`.
    fn into_unary(self) -> Arc<dyn UnaryOperator<In = Self::In, Out = Self::Out>> {
        Arc::new(self)
    }

    /// Receives batches from `rx`, filters each one with [`Filter::batch_filter`],
    /// and forwards every non-empty result to `tx`.
    ///
    /// The loop ends when the input channel is closed, or when the output
    /// channel no longer has any receiver (nobody is left to consume results).
    ///
    /// # Panics
    ///
    /// Panics if `batch_filter` returns an error, or if the input receiver lagged
    /// behind the sender: lagging means input batches were dropped, and recovery
    /// from that is not implemented yet.
    async fn execute(
        &self,
        mut rx: broadcast::Receiver<Self::In>,
        tx: broadcast::Sender<Self::Out>,
    ) {
        loop {
            let batch = match rx.recv().await {
                Ok(batch) => batch,
                // Every sender is gone and the queue is drained: input is finished.
                Err(RecvError::Closed) => break,
                // Input batches were overwritten before we could read them. Carrying
                // on would silently produce incomplete results, so keep failing
                // loudly, but report how much data was lost.
                Err(RecvError::Lagged(skipped)) => {
                    todo!(
                        "recover from a lagged input channel ({} batches were dropped)",
                        skipped
                    )
                }
            };

            let filtered_batch = self
                .batch_filter(batch)
                .expect("Filter::batch_filter() fails");

            // Batches in which no row passed the predicate are not forwarded.
            if filtered_batch.num_rows() == 0 {
                continue;
            }

            // `send` only fails when there is no active receiver. That is not an
            // error in this operator: downstream has gone away, so stop working
            // instead of panicking inside the task.
            if tx.send(filtered_batch).is_err() {
                break;
            }
        }
    }
}
1 change: 1 addition & 0 deletions eggstrain/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use datafusion::physical_plan::ExecutionPlan;
use std::sync::Arc;
use tokio::sync::broadcast::{Receiver, Sender};

pub mod filter;
pub mod project;

pub trait Operator {
Expand Down
77 changes: 65 additions & 12 deletions eggstrain/src/execution/query_dag.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::operators::filter::Filter;
use super::operators::project::Project;
use super::operators::UnaryOperator;
use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::{projection::ProjectionExec, ExecutionPlan, Partitioning};
use datafusion_common::{DataFusionError, Result};
use futures::stream::StreamExt;
Expand All @@ -15,17 +17,21 @@ const BATCH_SIZE: usize = 1024;
/// The operator kinds that currently have an eggstrain implementation.
///
/// Each variant wraps a shared handle to a unary operator that consumes and
/// produces `RecordBatch` values. The commented-out entries below are
/// placeholders for operators that do not have a variant yet.
#[derive(Clone)]
pub(crate) enum EggstrainOperator {
/// Projection operator.
Project(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
/// Filter operator.
Filter(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
// Sort(Arc<dyn UnaryNode>),

// Aggregate(Arc<dyn UnaryNode>),

// TableScan(Arc<dyn UnaryNode>),

// HashJoin(Arc<dyn BinaryNode>),
}

impl EggstrainOperator {
    /// Returns the child plans of the wrapped operator, whichever variant this is.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        match self {
            // Every variant wraps the same operator type, so one arm covers them
            // while the match stays exhaustive if a new variant is added.
            Self::Project(operator) | Self::Filter(operator) => operator.children(),
        }
    }
}
Expand All @@ -49,8 +55,16 @@ fn extract_df_node(plan: Arc<dyn ExecutionPlan>) -> Result<EggstrainOperator> {
let node = Project::new(child_schema, projection_plan).into_unary();

Ok(EggstrainOperator::Project(node))
// } else if id == TypeId::of::<FilterExec>() {
// todo!();
} else if id == TypeId::of::<FilterExec>() {
let filter_plan = root
.downcast_ref::<FilterExec>()
.expect("Unable to downcast_ref to FilterExec");

let node =
Filter::new(filter_plan.predicate().clone(), filter_plan.children()).into_unary();

Ok(EggstrainOperator::Filter(node))

// } else if id == TypeId::of::<HashJoinExec>() {
// todo!();
// } else if id == TypeId::of::<SortExec>() {
Expand Down Expand Up @@ -92,6 +106,7 @@ fn df_execute_node(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBat
}

pub fn build_query_dag(plan: Arc<dyn ExecutionPlan>) -> Result<broadcast::Receiver<RecordBatch>> {
// A tuple containing a plan node and a sender into that node
let mut queue = VecDeque::new();

// Final output is going to be sent to root_rx
Expand All @@ -103,23 +118,61 @@ pub fn build_query_dag(plan: Arc<dyn ExecutionPlan>) -> Result<broadcast::Receiv
queue.push_back((root, root_tx));

while let Some((node, tx)) = queue.pop_front() {
for child in node.children() {
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
let node = node.clone();

match node.children().len() {
0 => {
todo!();
}
1 => {
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
let child_plan = node.children()[0].clone();

if let Ok(child_node) = extract_df_node(child.clone()) {
match child_node.clone() {
EggstrainOperator::Project(project) => {
match node.clone() {
EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
let tx = tx.clone();
tokio::spawn(async move {
project.execute(child_rx, tx).await;
eggnode.execute(child_rx, tx).await;
});
}
};
queue.push_back((child_node, child_tx));
} else {
df_execute_node(child.clone(), tx.clone());

match extract_df_node(child_plan.clone()) {
Ok(val) => {
queue.push_back((val, child_tx));
}
Err(_) => {
df_execute_node(child_plan, child_tx);
}
}
}
2 => {
todo!();
}
_ => {
return Err(DataFusionError::NotImplemented(
"More than 2 children not implemented".to_string(),
));
}
}

// for child in node.children() {
// let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);

// if let Ok(child_node) = extract_df_node(child.clone()) {
// match child_node.clone() {
// EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
// let tx = tx.clone();
// tokio::spawn(async move {
// eggnode.execute(child_rx, tx).await;
// });
// }
// };
// queue.push_back((child_node, child_tx));
// } else {
// df_execute_node(child.clone(), tx.clone());
// }
// }
}

Ok(root_rx)
Expand Down
4 changes: 4 additions & 0 deletions eggstrain/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ async fn main() -> Result<()> {

let physical_plan = tpch.clone().create_physical_plan().await?;

println!("{:#?}", physical_plan.clone());

// let physical_plan = physical_plan.children()[0].clone();

let results = run(physical_plan).await;

results.into_iter().for_each(|batch| {
Expand Down
7 changes: 7 additions & 0 deletions queries/basic_filter.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Selects the total price of every order whose total price is below 850.00.
SELECT
orders.o_totalprice
FROM
orders
WHERE
orders.o_totalprice < 850.00
;
Loading