Skip to content

Commit

Permalink
Add keyed fold scan
Browse files Browse the repository at this point in the history
  • Loading branch information
imDema committed Jun 6, 2024
1 parent e1042b3 commit b74cbda
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/operator/iteration/iterate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ impl<Out: ExchangeData, State: ExchangeData + Sync> Operator for Iterate<Out, St
}
}

impl<Out: ExchangeData, State: ExchangeData + Sync> Display for Iterate<Out, State> {
impl<Out: ExchangeData, State: ExchangeData> Display for Iterate<Out, State> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Iterate<{}>", std::any::type_name::<Out>())
}
Expand Down
87 changes: 87 additions & 0 deletions src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! The actual operator list can be found from the implemented methods of [`Stream`],
//! [`KeyedStream`], [`crate::WindowedStream`]

use std::collections::HashMap;
use std::fmt::Display;
use std::hash::Hash;
use std::ops::{AddAssign, Div};
Expand Down Expand Up @@ -2685,6 +2686,92 @@ where
self.0.split_block(End::new, NextStrategy::random())
}

/// Keyed two-pass scan: folds the values of each key into a per-key state, then maps every
/// value together with a reference to the state stored for its key.
///
/// * `init` - state cloned the first time the fold meets a key that has no entry yet.
/// * `fold` - called as `fold(&key, &mut state, value)` to update the state of `key`.
/// * `map` - called as `map(&key, &state, value)` to produce the output for `value`.
///
/// Returns a keyed stream of `(K, O)` pairs carrying the results of `map`.
///
/// The implementation tags every value as `First`, sends the stream through `iterate` and
/// relies on each value crossing the loop body twice (`First -> Second -> Output`).
/// NOTE(review): this presumes `iterate` runs the body for exactly the `2` rounds requested
/// and hands the merged map back through `state` between rounds -- confirm against
/// `iterate`'s documentation. Under that assumption `map` observes the state after *all*
/// values of the key were folded, which is what `tests/fold_scan.rs` expects.
///
/// # Panics
///
/// Panics if the state visible to the loop body has no entry for a key when a `Second`
/// value is mapped, or if a value reaches a stage where its tag is not expected.
pub fn fold_scan<O, S, L, F>(
self,
init: S,
fold: L,
map: F,
) -> KeyedStream<impl Operator<Out = (K, O)>>
where
Op::Out: ExchangeData,
I: Send,
K: ExchangeDataKey + Sync,
L: Fn(&K, &mut S, I) + Send + Clone + 'static,
F: Fn(&K, &S, I) -> O + Send + Clone + 'static,
S: ExchangeData + Sync,
O: ExchangeData,
{
// Tag recording which stage a value has reached; `Output` carries the final result.
#[derive(Serialize, Deserialize, Clone)]
enum TwoPass<I, O> {
First(I),
Second(I),
Output(O),
}

// Wrap every value as `First`, drop the keyed view and enter the loop with an empty
// per-key state map.
let (state, s) = self.map(|el| TwoPass::First(el.1)).unkey().iterate(
2,
HashMap::<K, S>::default(),
// Loop body: advance each value by one stage.
|s, state| {
s.to_keyed()
.map(move |(k, el)| match el {
TwoPass::First(el) => TwoPass::Second(el),
TwoPass::Second(el) => {
// Look up the state of this key and compute the output; the `unwrap`
// panics if the key has no entry in the state.
TwoPass::Output((map)(k, state.get().get(k).unwrap(), el))
}
// An `Output` value is never expected as input of the body.
TwoPass::Output(_) => unreachable!(),
})
.unkey()
},
// Local update: only `Second` values are folded, into the entry of their key, which is
// created from a clone of `init` when missing. `First` and `Output` are ignored.
move |local: &mut HashMap<K, S>, (k, el)| match el {
TwoPass::First(_) => {}
TwoPass::Second(el) => fold(
&k,
local.entry(k.clone()).or_insert_with(|| init.clone()),
el,
),
TwoPass::Output(_) => {}
},
// Global update: move every local entry into the global map. `extend` replaces an
// existing entry with the same key rather than combining the two states.
// NOTE(review): assumes a key is only ever folded in one local map -- confirm.
move |global: &mut HashMap<K, S>, mut local| {
global.extend(local.drain());
},
// Loop condition closure: always returns `true`.
|_| true,
);

// The state stream is not part of the result: consume it and drop every item.
state.for_each(std::mem::drop);
// Restore the keyed view and unwrap the outputs; any other tag here is treated as a bug.
s.to_keyed().map(|(_, t)| match t {
TwoPass::First(_) | TwoPass::Second(_) => unreachable!(),
TwoPass::Output(o) => o,
})
}

/// Keyed scan expressed as a reduction; a thin adapter over `fold_scan`.
///
/// * `first_map` - called as `first_map(&key, value)` to turn a value into a partial state.
/// * `reduce` - called as `reduce(&key, accumulated, partial)` to merge two states.
/// * `second_map` - called as `second_map(&key, &state, value)` to produce the output.
///
/// The accumulator handed to `fold_scan` is an `Option<S>` starting at `None`: the first
/// partial state of a key is stored as is, later ones are merged through `reduce`.
///
/// # Panics
///
/// Panics if `second_map` is about to be invoked while the accumulator is still `None`.
pub fn reduce_scan<O, S, F1, F2, R>(
    self,
    first_map: F1,
    reduce: R,
    second_map: F2,
) -> KeyedStream<impl Operator<Out = (K, O)>>
where
    Op::Out: ExchangeData,
    F1: Fn(&K, I) -> S + Send + Clone + 'static,
    F2: Fn(&K, &S, I) -> O + Send + Clone + 'static,
    R: Fn(&K, S, S) -> S + Send + Clone + 'static,
    K: Sync,
    S: ExchangeData + Sync,
    O: ExchangeData,
{
    self.fold_scan(
        None,
        move |key, slot: &mut Option<S>, item| {
            let partial = (first_map)(key, item);
            // Take the previous value out of the slot so it can be passed by value.
            let merged = match slot.take() {
                None => partial,
                Some(previous) => (reduce)(key, previous, partial),
            };
            *slot = Some(merged);
        },
        move |key, slot, item| {
            let reduced = slot.as_ref().unwrap();
            (second_map)(key, reduced, item)
        },
    )
}

/// Close the stream and send resulting items to a channel on a single host.
///
/// If the stream is distributed among multiple replicas, parallelism will
Expand Down
27 changes: 0 additions & 27 deletions src/operator/sink/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,30 +134,3 @@ where
.finalize_block();
}
}

// #[cfg(test)]
// mod qtests {
// use std::AvroSinkions::HashSet;

// use crate::config::RuntimeConfig;
// use crate::environment::StreamContext;
// use crate::operator::source;

// #[test]
// fn AvroSink_vec() {
// let env = StreamContext::new(RuntimeConfig::local(4).unwrap());
// let source = source::IteratorSource::new(0..10u8);
// let res = env.stream(source).AvroSink::<Vec<_>>();
// env.execute_blocking();
// assert_eq!(res.get().unwrap(), (0..10).AvroSink::<Vec<_>>());
// }

// #[test]
// fn AvroSink_set() {
// let env = StreamContext::new(RuntimeConfig::local(4).unwrap());
// let source = source::IteratorSource::new(0..10u8);
// let res = env.stream(source).AvroSink::<HashSet<_>>();
// env.execute_blocking();
// assert_eq!(res.get().unwrap(), (0..10).AvroSink::<HashSet<_>>());
// }
// }
52 changes: 52 additions & 0 deletions tests/fold_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,55 @@ fn reduce_scan() {
}
});
}

// Every element of 0..100, grouped by `x % 5`, must come out paired with the sum of all
// the elements that share its key.
#[test]
fn keyed_fold_scan() {
    TestHelper::local_remote_env(|ctx| {
        let collected = ctx
            .stream_par_iter(0..100i32)
            .group_by(|e| e % 5)
            .fold_scan(
                0,
                |_k, acc: &mut i32, x| *acc += x,
                |_k, acc, x| (x, *acc),
            )
            .unkey()
            .map(|(key, (item, total))| (key, item, total))
            .collect_vec();

        ctx.execute_blocking();
        if let Some(mut actual) = collected.get() {
            // Sum of every element of the input range that falls under `key`.
            let key_total = |key: i32| (0..100).filter(|m| m % 5 == key).sum::<i32>();
            let mut expected: Vec<_> = (0..100)
                .map(|x| (x % 5, x, key_total(x % 5)))
                .collect();
            actual.sort();
            expected.sort();
            assert_eq!(expected, actual);
        }
    });
}

// Same expectation as the fold variant, with the per-key sum expressed as a reduction.
#[test]
fn keyed_reduce_scan() {
    TestHelper::local_remote_env(|ctx| {
        let collected = ctx
            .stream_par_iter(0..100i32)
            .group_by(|e| e % 5)
            .reduce_scan(
                |_k, x| x,
                |_k, a, b| a + b,
                |_k, acc, x| (x, *acc),
            )
            .unkey()
            .map(|(key, (item, total))| (key, item, total))
            .collect_vec();

        ctx.execute_blocking();
        if let Some(mut actual) = collected.get() {
            // Sum of every element of the input range that falls under `key`.
            let key_total = |key: i32| (0..100).filter(|m| m % 5 == key).sum::<i32>();
            let mut expected: Vec<_> = (0..100)
                .map(|x| (x % 5, x, key_total(x % 5)))
                .collect();
            actual.sort();
            expected.sort();
            assert_eq!(expected, actual);
        }
    });
}

0 comments on commit b74cbda

Please sign in to comment.