diff --git a/Cargo.lock b/Cargo.lock index f59a3a4fd..472ae2715 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -336,7 +336,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "53.2.0" -source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=53.2.0/json#24c93dff8203e766ea30cdd0461d06c10608f53c" +source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=53.2.0%2Fjson#24c93dff8203e766ea30cdd0461d06c10608f53c" dependencies = [ "arrow-array", "arrow-buffer", @@ -645,6 +645,7 @@ dependencies = [ "k8s-openapi", "kube", "lazy_static", + "log", "petgraph", "postgres", "postgres-types", @@ -689,6 +690,7 @@ dependencies = [ "datafusion-proto", "dyn-clone", "hex", + "itertools 0.13.0", "petgraph", "proc-macro2", "prost 0.13.3", @@ -2827,7 +2829,7 @@ checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" [[package]] name = "datafusion" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -2883,7 +2885,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow-schema", "async-trait", @@ -2897,7 +2899,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = 
"git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -2921,7 +2923,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "log", "tokio", @@ -2930,7 +2932,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "chrono", @@ -2950,7 +2952,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -2973,7 +2975,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "datafusion-common", @@ -2984,7 +2986,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = 
"git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "arrow-buffer", @@ -3010,7 +3012,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -3030,7 +3032,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -3054,7 +3056,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "arrow-array", @@ -3076,7 +3078,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "datafusion-common", "datafusion-expr", @@ -3090,7 +3092,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" version = "43.0.0" -source = 
"git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -3099,7 +3101,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "async-trait", @@ -3118,7 +3120,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -3145,7 +3147,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -3159,7 +3161,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "arrow-schema", @@ -3174,7 +3176,7 @@ dependencies = [ [[package]] name = 
"datafusion-physical-plan" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "ahash", "arrow", @@ -3208,7 +3210,7 @@ dependencies = [ [[package]] name = "datafusion-proto" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "chrono", @@ -3223,7 +3225,7 @@ dependencies = [ [[package]] name = "datafusion-proto-common" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "chrono", @@ -3235,7 +3237,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "43.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0/arroyo#319f59b143208e025d7bf952748492d847068a39" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=43.0.0%2Farroyo#319f59b143208e025d7bf952748492d847068a39" dependencies = [ "arrow", "arrow-array", @@ -6282,7 +6284,7 @@ dependencies = [ [[package]] name = "object_store" version = "0.11.1" -source = "git+http://github.com/ArroyoSystems/arrow-rs?branch=object_store_0.11.1/arroyo#4cfe48061503161e43cd3cd7960e74ce789bd3b9" +source = "git+http://github.com/ArroyoSystems/arrow-rs?branch=object_store_0.11.1%2Farroyo#4cfe48061503161e43cd3cd7960e74ce789bd3b9" dependencies = [ "async-trait", "base64 0.22.1", @@ -6533,7 +6535,7 @@ dependencies = [ [[package]] name = "parquet" 
version = "53.2.0" -source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=53.2.0/parquet_bytes#424920f863e1b8286c3ce8261cce16f0360428c5" +source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=53.2.0%2Fparquet_bytes#424920f863e1b8286c3ce8261cce16f0360428c5" dependencies = [ "ahash", "arrow-array", diff --git a/crates/arroyo-api/src/pipelines.rs b/crates/arroyo-api/src/pipelines.rs index 2f62fcd46..7e829c5db 100644 --- a/crates/arroyo-api/src/pipelines.rs +++ b/crates/arroyo-api/src/pipelines.rs @@ -6,10 +6,11 @@ use axum::{debug_handler, Json}; use axum_extra::extract::WithRejection; use http::StatusCode; +use petgraph::visit::NodeRef; use petgraph::{Direction, EdgeDirection}; use std::collections::HashMap; - -use petgraph::visit::NodeRef; +use std::num::ParseIntError; +use std::str::FromStr; use std::time::{Duration, SystemTime}; use crate::{compiler_service, connection_profiles, jobs, types}; @@ -23,7 +24,9 @@ use arroyo_rpc::api_types::{JobCollection, PaginationQueryParams, PipelineCollec use arroyo_rpc::grpc::api::{ArrowProgram, ConnectorOp}; use arroyo_connectors::kafka::{KafkaConfig, KafkaTable, SchemaRegistry}; -use arroyo_datastream::logical::{LogicalNode, LogicalProgram, OperatorName}; +use arroyo_datastream::logical::{ + ChainedLogicalOperator, LogicalNode, LogicalProgram, OperatorChain, OperatorName, +}; use arroyo_df::{ArroyoSchemaProvider, CompiledSql, SqlConfig}; use arroyo_formats::ser::ArrowSerializer; use arroyo_rpc::formats::Format; @@ -268,17 +271,19 @@ async fn register_schemas(compiled_sql: &mut CompiledSql) -> anyhow::Result<()> let schema = edge.weight().schema.schema.clone(); let node = compiled_sql.program.graph.node_weight_mut(idx).unwrap(); - if node.operator_name == OperatorName::ConnectorSink { - let mut op = ConnectorOp::decode(&node.operator_config[..]).map_err(|_| { - anyhow!( - "failed to decode configuration for connector node {:?}", - node - ) - })?; - - try_register_confluent_schema(&mut op, &schema).await?; - - 
node.operator_config = op.encode_to_vec(); + for (node, _) in node.operator_chain.iter_mut() { + if node.operator_name == OperatorName::ConnectorSink { + let mut op = ConnectorOp::decode(&node.operator_config[..]).map_err(|_| { + anyhow!( + "failed to decode configuration for connector node {:?}", + node + ) + })?; + + try_register_confluent_schema(&mut op, &schema).await?; + + node.operator_config = op.encode_to_vec(); + } } } @@ -324,19 +329,31 @@ pub(crate) async fn create_pipeline_int<'a>( let g = &mut compiled.program.graph; for idx in g.node_indices() { let should_replace = { - let node = g.node_weight(idx).unwrap(); - node.operator_name == OperatorName::ConnectorSink - && node.operator_config != default_sink().encode_to_vec() + let node = &g.node_weight(idx).unwrap().operator_chain; + node.is_sink() + && node.iter().next().unwrap().0.operator_config + != default_sink().encode_to_vec() }; if should_replace { if enable_sinks { let new_idx = g.add_node(LogicalNode { - operator_id: format!("{}_1", g.node_weight(idx).unwrap().operator_id), + node_id: g.node_weights().map(|n| n.node_id).max().unwrap() + 1, description: "Preview sink".to_string(), - operator_name: OperatorName::ConnectorSink, - operator_config: default_sink().encode_to_vec(), + operator_chain: OperatorChain::new(ChainedLogicalOperator { + operator_id: format!( + "{}_1", + g.node_weight(idx) + .unwrap() + .operator_chain + .first() + .operator_id + ), + operator_name: OperatorName::ConnectorSink, + operator_config: default_sink().encode_to_vec(), + }), parallelism: 1, }); + let edges: Vec<_> = g .edges_directed(idx, Direction::Incoming) .map(|e| (e.source(), e.weight().clone())) @@ -345,8 +362,14 @@ pub(crate) async fn create_pipeline_int<'a>( g.add_edge(source, new_idx, weight); } } else { - g.node_weight_mut(idx).unwrap().operator_config = - default_sink().encode_to_vec(); + g.node_weight_mut(idx) + .unwrap() + .operator_chain + .iter_mut() + .next() + .unwrap() + .0 + .operator_config = 
default_sink().encode_to_vec(); } } } @@ -452,8 +475,9 @@ impl TryInto for DbPipeline { .as_object() .unwrap() .into_iter() - .map(|(k, v)| (k.clone(), v.as_u64().unwrap() as usize)) - .collect(), + .map(|(k, v)| Ok((u32::from_str(k)?, v.as_u64().unwrap() as usize))) + .collect::, ParseIntError>>() + .map_err(|e| bad_request(format!("invalid node_id: {}", e)))?, ); let stop = match self.stop { @@ -682,10 +706,10 @@ pub async fn patch_pipeline( .ok_or_else(|| not_found("Job"))?; let program = ArrowProgram::decode(&res.program[..]).map_err(log_and_map)?; - let map: HashMap = program + let map: HashMap<_, _> = program .nodes .into_iter() - .map(|node| (node.node_id, parallelism as u32)) + .map(|node| (node.node_id.to_string(), parallelism as u32)) .collect(); Some(serde_json::to_value(map).map_err(log_and_map)?) diff --git a/crates/arroyo-connectors/src/blackhole/mod.rs b/crates/arroyo-connectors/src/blackhole/mod.rs index 33de9e1d4..02c180886 100644 --- a/crates/arroyo-connectors/src/blackhole/mod.rs +++ b/crates/arroyo-connectors/src/blackhole/mod.rs @@ -1,7 +1,7 @@ use crate::blackhole::operator::BlackholeSinkFunc; use anyhow::anyhow; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -120,8 +120,8 @@ impl Connector for BlackholeConnector { _: Self::ProfileT, _: Self::TableT, _: OperatorConfig, - ) -> anyhow::Result { - Ok(OperatorNode::from_operator(Box::new( + ) -> anyhow::Result { + Ok(ConstructedOperator::from_operator(Box::new( BlackholeSinkFunc::new(), ))) } diff --git a/crates/arroyo-connectors/src/blackhole/operator.rs b/crates/arroyo-connectors/src/blackhole/operator.rs index e5ca2eec3..6f0f6a50c 100644 --- a/crates/arroyo-connectors/src/blackhole/operator.rs +++ b/crates/arroyo-connectors/src/blackhole/operator.rs @@ -1,5 +1,5 @@ use 
arrow::array::RecordBatch; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use async_trait::async_trait; @@ -18,7 +18,12 @@ impl ArrowOperator for BlackholeSinkFunc { "BlackholeSink".to_string() } - async fn process_batch(&mut self, _: RecordBatch, _: &mut ArrowContext) { + async fn process_batch( + &mut self, + _: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { // no-op } } diff --git a/crates/arroyo-connectors/src/confluent/mod.rs b/crates/arroyo-connectors/src/confluent/mod.rs index 27c5cab01..e346230c9 100644 --- a/crates/arroyo-connectors/src/confluent/mod.rs +++ b/crates/arroyo-connectors/src/confluent/mod.rs @@ -4,7 +4,7 @@ use crate::kafka::{ use crate::{kafka, pull_opt}; use anyhow::anyhow; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -195,7 +195,7 @@ impl Connector for ConfluentConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { KafkaConnector {}.make_operator(profile.into(), table, config) } } diff --git a/crates/arroyo-connectors/src/filesystem/delta.rs b/crates/arroyo-connectors/src/filesystem/delta.rs index 30ebde66b..060f3466b 100644 --- a/crates/arroyo-connectors/src/filesystem/delta.rs +++ b/crates/arroyo-connectors/src/filesystem/delta.rs @@ -14,7 +14,7 @@ use crate::filesystem::{ use crate::EmptyConfig; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use super::sink::{LocalParquetFileSystemSink, ParquetFileSystemSink}; @@ -154,7 +154,7 @@ impl Connector for DeltaLakeConnector { _: Self::ProfileT, table: Self::TableT, config: 
OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let TableType::Sink { write_path, file_settings, @@ -179,13 +179,15 @@ impl Connector for DeltaLakeConnector { let is_local = backend_config.is_local(); match (&format_settings, is_local) { (Some(FormatSettings::Parquet { .. }), true) => { - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( LocalParquetFileSystemSink::new(write_path.to_string(), table, config), ))) } - (Some(FormatSettings::Parquet { .. }), false) => Ok(OperatorNode::from_operator( - Box::new(ParquetFileSystemSink::new(table, config)), - )), + (Some(FormatSettings::Parquet { .. }), false) => { + Ok(ConstructedOperator::from_operator(Box::new( + ParquetFileSystemSink::new(table, config), + ))) + } _ => bail!("Delta Lake sink only supports Parquet format"), } } diff --git a/crates/arroyo-connectors/src/filesystem/mod.rs b/crates/arroyo-connectors/src/filesystem/mod.rs index 0aecefefe..d181e5024 100644 --- a/crates/arroyo-connectors/src/filesystem/mod.rs +++ b/crates/arroyo-connectors/src/filesystem/mod.rs @@ -21,7 +21,7 @@ use crate::{pull_opt, pull_option_to_i64, EmptyConfig}; use crate::filesystem::source::FileSystemSourceFunc; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use self::sink::{ JsonFileSystemSink, LocalJsonFileSystemSink, LocalParquetFileSystemSink, ParquetFileSystemSink, @@ -228,10 +228,10 @@ impl Connector for FileSystemConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> Result { + ) -> Result { match &table.table_type { - TableType::Source { .. } => { - Ok(OperatorNode::from_source(Box::new(FileSystemSourceFunc { + TableType::Source { .. 
} => Ok(ConstructedOperator::from_source(Box::new( + FileSystemSourceFunc { table: table.table_type.clone(), format: config .format @@ -239,8 +239,8 @@ impl Connector for FileSystemConnector { framing: config.framing.clone(), bad_data: config.bad_data.clone(), file_states: HashMap::new(), - }))) - } + }, + ))), TableType::Sink { file_settings: _, format_settings, @@ -250,23 +250,25 @@ impl Connector for FileSystemConnector { let backend_config = BackendConfig::parse_url(write_path, true)?; match (format_settings, backend_config.is_local()) { (Some(FormatSettings::Parquet { .. }), true) => { - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( LocalParquetFileSystemSink::new(write_path.to_string(), table, config), ))) } (Some(FormatSettings::Parquet { .. }), false) => { - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( ParquetFileSystemSink::new(table, config), ))) } (Some(FormatSettings::Json { .. }), true) => { - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( LocalJsonFileSystemSink::new(write_path.to_string(), table, config), ))) } - (Some(FormatSettings::Json { .. }), false) => Ok(OperatorNode::from_operator( - Box::new(JsonFileSystemSink::new(table, config)), - )), + (Some(FormatSettings::Json { .. 
}), false) => { + Ok(ConstructedOperator::from_operator(Box::new( + JsonFileSystemSink::new(table, config), + ))) + } (None, _) => bail!("have to have some format settings"), } } diff --git a/crates/arroyo-connectors/src/filesystem/sink/local.rs b/crates/arroyo-connectors/src/filesystem/sink/local.rs index 8b4082cec..96b426547 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/local.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/local.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, fs::create_dir_all, path::Path, sync::Arc, time::SystemTime}; use arrow::record_batch::RecordBatch; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::OperatorContext; use arroyo_rpc::{ df::{ArroyoSchema, ArroyoSchemaRef}, formats::Format, @@ -215,7 +215,7 @@ impl TwoPhaseCommitter for LocalFileSystemWrite async fn init( &mut self, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, data_recovery: Vec, ) -> Result<()> { let mut max_file_index = 0; @@ -255,7 +255,7 @@ impl TwoPhaseCommitter for LocalFileSystemWrite }) } } - self.subtask_id = ctx.task_info.task_index; + self.subtask_id = ctx.task_info.task_index as usize; self.finished_files = recovered_files; self.next_file_index = max_file_index; Ok(()) diff --git a/crates/arroyo-connectors/src/filesystem/sink/mod.rs b/crates/arroyo-connectors/src/filesystem/sink/mod.rs index 9c066c56e..8c0c24249 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/mod.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/mod.rs @@ -15,7 +15,7 @@ use ::arrow::{ util::display::{ArrayFormatter, FormatOptions}, }; use anyhow::{bail, Result}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::OperatorContext; use arroyo_rpc::{df::ArroyoSchemaRef, formats::Format, OperatorConfig, TIMESTAMP_FIELD}; use arroyo_storage::StorageProvider; use async_trait::async_trait; @@ -1550,7 +1550,7 @@ impl TwoPhaseCommitter for FileSystemSink, ) -> Result<()> { 
self.start(Arc::new(ctx.in_schemas.first().unwrap().clone()))?; @@ -1570,7 +1570,7 @@ impl TwoPhaseCommitter for FileSystemSink TwoPhaseCommitter for FileSystemSink { committer: TPC, @@ -43,7 +44,7 @@ pub trait TwoPhaseCommitter: Send + 'static { fn name(&self) -> String; async fn init( &mut self, - task_info: &mut ArrowContext, + task_info: &mut OperatorContext, data_recovery: Vec, ) -> Result<()>; async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>; @@ -83,7 +84,7 @@ impl TwoPhaseCommitterOperator { &mut self, epoch: u32, mut commit_data: HashMap>>, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) { info!("received commit message"); let pre_commits = match self.committer.commit_strategy() { @@ -107,13 +108,16 @@ impl TwoPhaseCommitterOperator { .commit(&ctx.task_info, pre_commits) .await .expect("committer committed"); + let checkpoint_event = arroyo_rpc::ControlResp::CheckpointEvent(CheckpointEvent { checkpoint_epoch: epoch, + node_id: ctx.task_info.node_id, operator_id: ctx.task_info.operator_id.clone(), - subtask_index: ctx.task_info.task_index as u32, + subtask_index: ctx.task_info.task_index, time: SystemTime::now(), event_type: arroyo_rpc::grpc::rpc::TaskCheckpointEventType::FinishedCommit, }); + ctx.control_tx .send(checkpoint_event) .await @@ -152,7 +156,11 @@ impl ArrowOperator for TwoPhaseCommitterOperator { tables } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + fn is_committing(&self) -> bool { + true + } + + async fn on_start(&mut self, ctx: &mut OperatorContext) { let tracking_key_state: &mut GlobalKeyedView = ctx .table_manager .get_global_keyed_state("r") @@ -176,26 +184,23 @@ impl ArrowOperator for TwoPhaseCommitterOperator { } } - async fn process_batch(&mut self, batch: RecordBatch, _ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.committer .insert_batch(batch) .await .expect("record inserted"); } - async 
fn on_close(&mut self, _final_mesage: &Option, ctx: &mut ArrowContext) { - if let Some(ControlMessage::Commit { epoch, commit_data }) = ctx.control_rx.recv().await { - self.handle_commit(epoch, commit_data, ctx).await; - } else { - warn!("no commit message received, not committing") - } - } - async fn handle_commit( &mut self, epoch: u32, commit_data: &HashMap>>, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) { self.handle_commit(epoch, commit_data.clone(), ctx).await; } @@ -203,7 +208,8 @@ impl ArrowOperator for TwoPhaseCommitterOperator { async fn handle_checkpoint( &mut self, checkpoint_barrier: arroyo_types::CheckpointBarrier, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + _: &mut dyn Collector, ) { let (recovery_data, pre_commits) = self .committer @@ -224,7 +230,7 @@ impl ArrowOperator for TwoPhaseCommitterOperator { .await .expect("should be able to get table"); recovery_data_state - .insert(ctx.task_info.task_index, recovery_data) + .insert(ctx.task_info.task_index as usize, recovery_data) .await; self.pre_commits.clear(); if pre_commits.is_empty() { diff --git a/crates/arroyo-connectors/src/filesystem/source.rs b/crates/arroyo-connectors/src/filesystem/source.rs index 3bf43665d..f2792f897 100644 --- a/crates/arroyo-connectors/src/filesystem/source.rs +++ b/crates/arroyo-connectors/src/filesystem/source.rs @@ -18,7 +18,7 @@ use futures::StreamExt; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::arrow::ParquetRecordBatchStreamBuilder; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use regex::Regex; use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader}; use tokio::select; @@ -60,8 +60,12 @@ impl SourceOperator for FileSystemSourceFunc { "FileSystem".to_string() } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, 
+ ) -> SourceFinishType { + match self.run_int(ctx, collector).await { Ok(s) => s, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -83,7 +87,11 @@ impl FileSystemSourceFunc { } } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let (storage_provider, regex_pattern) = match &self.table { TableType::Source { path, @@ -116,7 +124,7 @@ impl FileSystemSourceFunc { )) } }; - ctx.initialize_deserializer( + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -136,7 +144,7 @@ impl FileSystemSourceFunc { // hash the path and modulo by the number of tasks let mut hasher = DefaultHasher::new(); path.hash(&mut hasher); - if (hasher.finish() as usize) % parallelism != task_index { + if (hasher.finish() as usize) % parallelism as usize != task_index as usize { return ready(false); } @@ -164,7 +172,10 @@ impl FileSystemSourceFunc { continue; } - if let Some(finish_type) = self.read_file(ctx, &storage_provider, &obj_key).await? { + if let Some(finish_type) = self + .read_file(ctx, collector, &storage_provider, &obj_key) + .await? + { return Ok(finish_type); } } @@ -274,7 +285,8 @@ impl FileSystemSourceFunc { async fn read_file( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, storage_provider: &StorageProvider, obj_key: &String, ) -> Result, UserError> { @@ -298,7 +310,7 @@ impl FileSystemSourceFunc { .get_newline_separated_stream(storage_provider, obj_key.to_string()) .await? 
.skip(records_read); - self.read_line_file(ctx, line_reader, obj_key, records_read) + self.read_line_file(ctx, collector, line_reader, obj_key, records_read) .await } Format::Avro(_) => todo!(), @@ -307,12 +319,12 @@ impl FileSystemSourceFunc { .get_record_batch_stream( storage_provider, obj_key, - ctx.out_schema.as_ref().unwrap().schema.clone(), + ctx.out_schema.schema.clone(), ) .await? .skip(records_read); - self.read_parquet_file(ctx, record_batch_stream, obj_key, records_read) + self.read_parquet_file(ctx, collector, record_batch_stream, obj_key, records_read) .await } Format::RawString(_) => todo!(), @@ -323,7 +335,8 @@ impl FileSystemSourceFunc { async fn read_parquet_file( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, mut record_batch_stream: impl Stream> + Unpin + Send, obj_key: &String, mut records_read: usize, @@ -333,7 +346,7 @@ impl FileSystemSourceFunc { item = record_batch_stream.next() => { match item.transpose()? { Some(batch) => { - ctx.collect(batch).await; + collector.collect(batch).await; records_read += 1; } None => { @@ -346,7 +359,7 @@ impl FileSystemSourceFunc { msg_res = ctx.control_rx.recv() => { if let Some(control_message) = msg_res { self.file_states.insert(obj_key.to_string(), FileReadState::RecordsRead(records_read)); - if let Some(finish_type) = self.process_control_message(ctx, control_message).await { + if let Some(finish_type) = self.process_control_message(ctx, collector, control_message).await { return Ok(Some(finish_type)) } } @@ -357,7 +370,8 @@ impl FileSystemSourceFunc { async fn read_line_file( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, mut line_reader: impl Stream> + Unpin + Send, obj_key: &String, mut records_read: usize, @@ -367,15 +381,15 @@ impl FileSystemSourceFunc { line = line_reader.next() => { match line.transpose()? 
{ Some(line) => { - ctx.deserialize_slice(line.as_bytes(), SystemTime::now(), None).await?; + collector.deserialize_slice(line.as_bytes(), SystemTime::now(), None).await?; records_read += 1; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } None => { info!("finished reading file {}", obj_key); - ctx.flush_buffer().await?; + collector.flush_buffer().await?; self.file_states.insert(obj_key.to_string(), FileReadState::Finished); return Ok(None); } @@ -384,7 +398,7 @@ impl FileSystemSourceFunc { msg_res = ctx.control_rx.recv() => { if let Some(control_message) = msg_res { self.file_states.insert(obj_key.to_string(), FileReadState::RecordsRead(records_read)); - if let Some(finish_type) = self.process_control_message(ctx, control_message).await { + if let Some(finish_type) = self.process_control_message(ctx, collector, control_message).await { return Ok(Some(finish_type)) } } @@ -395,7 +409,8 @@ impl FileSystemSourceFunc { async fn process_control_message( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, control_message: ControlMessage, ) -> Option { match control_message { @@ -409,7 +424,7 @@ impl FileSystemSourceFunc { .await; } // checkpoint our state - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { Some(SourceFinishType::Immediate) } else { None diff --git a/crates/arroyo-connectors/src/fluvio/mod.rs b/crates/arroyo-connectors/src/fluvio/mod.rs index 9ff76464a..a8834a299 100644 --- a/crates/arroyo-connectors/src/fluvio/mod.rs +++ b/crates/arroyo-connectors/src/fluvio/mod.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, bail}; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ConnectionProfile, ConnectionSchema, 
TestSourceMessage}; use arroyo_rpc::OperatorConfig; use fluvio::Offset; @@ -173,10 +173,10 @@ impl Connector for FluvioConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { match table.type_ { - TableType::Source { offset } => { - Ok(OperatorNode::from_source(Box::new(FluvioSourceFunc { + TableType::Source { offset } => Ok(ConstructedOperator::from_source(Box::new( + FluvioSourceFunc { topic: table.topic, endpoint: table.endpoint.clone(), offset_mode: offset, @@ -185,18 +185,20 @@ impl Connector for FluvioConnector { .ok_or_else(|| anyhow!("format required for fluvio source"))?, framing: config.framing, bad_data: config.bad_data, - }))) - } - TableType::Sink { .. } => Ok(OperatorNode::from_operator(Box::new(FluvioSinkFunc { - topic: table.topic, - endpoint: table.endpoint, - producer: None, - serializer: ArrowSerializer::new( - config - .format - .ok_or_else(|| anyhow!("format required for fluvio sink"))?, - ), - }))), + }, + ))), + TableType::Sink { .. 
} => Ok(ConstructedOperator::from_operator(Box::new( + FluvioSinkFunc { + topic: table.topic, + endpoint: table.endpoint, + producer: None, + serializer: ArrowSerializer::new( + config + .format + .ok_or_else(|| anyhow!("format required for fluvio sink"))?, + ), + }, + ))), } } } diff --git a/crates/arroyo-connectors/src/fluvio/sink.rs b/crates/arroyo-connectors/src/fluvio/sink.rs index edf1ad513..7580a0293 100644 --- a/crates/arroyo-connectors/src/fluvio/sink.rs +++ b/crates/arroyo-connectors/src/fluvio/sink.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use arroyo_formats::ser::ArrowSerializer; use tracing::info; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_types::CheckpointBarrier; @@ -32,23 +32,29 @@ impl ArrowOperator for FluvioSinkFunc { format!("fluvio-sink-{}", self.topic) } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { match self.get_producer().await { Ok(producer) => { self.producer = Some(producer); } Err(e) => { - ctx.report_error( - "Failed to construct Fluvio producer".to_string(), - e.to_string(), - ) - .await; + ctx.error_reporter + .report_error( + "Failed to construct Fluvio producer".to_string(), + e.to_string(), + ) + .await; panic!("Failed to construct Fluvio producer: {:?}", e); } } } - async fn process_batch(&mut self, batch: RecordBatch, _: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { let values = self.serializer.serialize(&batch); for v in values { self.producer @@ -60,7 +66,12 @@ impl ArrowOperator for FluvioSinkFunc { } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, _: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { 
self.producer.as_mut().unwrap().flush().await.unwrap(); } } diff --git a/crates/arroyo-connectors/src/fluvio/source.rs b/crates/arroyo-connectors/src/fluvio/source.rs index 4e7468a78..491a82b54 100644 --- a/crates/arroyo-connectors/src/fluvio/source.rs +++ b/crates/arroyo-connectors/src/fluvio/source.rs @@ -1,5 +1,5 @@ use anyhow::anyhow; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, Framing}; @@ -48,16 +48,18 @@ impl SourceOperator for FluvioSourceFunc { global_table_config("f", "fluvio source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { - ctx.initialize_deserializer( + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), ); - } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -71,7 +73,7 @@ impl SourceOperator for FluvioSourceFunc { impl FluvioSourceFunc { async fn get_consumer( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, ) -> anyhow::Result>>> { info!("Creating Fluvio consumer for {:?}", self.endpoint); @@ -110,7 +112,9 @@ impl FluvioSourceFunc { let has_state = !state.is_empty(); let parts: Vec<_> = (0..partitions) - .filter(|i| *i % ctx.task_info.parallelism == ctx.task_info.task_index) + .filter(|i| { + *i % ctx.task_info.parallelism as usize == ctx.task_info.task_index as usize + }) .map(|i| { let offset = state .get(&(i as u32)) @@ -141,7 +145,11 @@ impl FluvioSourceFunc { Ok(streams) } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut 
self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let mut streams = self .get_consumer(ctx) .await @@ -150,10 +158,9 @@ impl FluvioSourceFunc { if streams.is_empty() { warn!("Fluvio Consumer {}-{} is subscribed to no partitions, as there are more subtasks than partitions... setting idle", ctx.task_info.operator_id, ctx.task_info.task_index); - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; } let mut flush_ticker = tokio::time::interval(Duration::from_millis(50)); @@ -166,10 +173,10 @@ impl FluvioSourceFunc { match message { Some((_, Ok(msg))) => { let timestamp = from_millis(msg.timestamp().max(0) as u64); - ctx.deserialize_slice(msg.value(), timestamp, None).await?; + collector.deserialize_slice(msg.value(), timestamp, None).await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } offsets.insert(msg.partition(), msg.offset()); @@ -183,8 +190,8 @@ impl FluvioSourceFunc { } } _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } control_message = ctx.control_rx.recv() => { @@ -202,7 +209,7 @@ impl FluvioSourceFunc { }).await; } - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } }, diff --git a/crates/arroyo-connectors/src/impulse/mod.rs b/crates/arroyo-connectors/src/impulse/mod.rs index a5e46981a..0e7d83cd5 100644 --- a/crates/arroyo-connectors/src/impulse/mod.rs +++ b/crates/arroyo-connectors/src/impulse/mod.rs @@ -2,7 +2,7 @@ mod operator; use anyhow::{anyhow, bail}; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use 
arroyo_rpc::api_types::connections::FieldType::Primitive; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, PrimitiveType, TestSourceMessage, @@ -185,20 +185,22 @@ impl Connector for ImpulseConnector { _: Self::ProfileT, table: Self::TableT, _: OperatorConfig, - ) -> anyhow::Result { - Ok(OperatorNode::from_source(Box::new(ImpulseSourceFunc { - interval: table - .event_time_interval - .map(|i| Duration::from_nanos(i as u64)), - spec: ImpulseSpec::EventsPerSecond(table.event_rate as f32), - limit: table - .message_count - .map(|n| n as usize) - .unwrap_or(usize::MAX), - state: ImpulseSourceState { - counter: 0, - start_time: SystemTime::now(), + ) -> anyhow::Result { + Ok(ConstructedOperator::from_source(Box::new( + ImpulseSourceFunc { + interval: table + .event_time_interval + .map(|i| Duration::from_nanos(i as u64)), + spec: ImpulseSpec::EventsPerSecond(table.event_rate as f32), + limit: table + .message_count + .map(|n| n as usize) + .unwrap_or(usize::MAX), + state: ImpulseSourceState { + counter: 0, + start_time: SystemTime::now(), + }, }, - }))) + ))) } } diff --git a/crates/arroyo-connectors/src/impulse/operator.rs b/crates/arroyo-connectors/src/impulse/operator.rs index a7eb5bf3e..582a7edb9 100644 --- a/crates/arroyo-connectors/src/impulse/operator.rs +++ b/crates/arroyo-connectors/src/impulse/operator.rs @@ -11,7 +11,7 @@ use bincode::{Decode, Encode}; use datafusion::common::ScalarValue; use std::time::{Duration, SystemTime}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_types::{from_millis, print_time, to_millis, to_nanos}; @@ -74,7 +74,7 @@ impl ImpulseSourceFunc { } } - fn batch_size(&self, ctx: &mut ArrowContext) -> usize { + fn batch_size(&self, ctx: &mut SourceContext) -> usize { let duration_micros = self.delay(ctx).as_micros(); if duration_micros == 0 { return 
8192; @@ -83,7 +83,7 @@ impl ImpulseSourceFunc { batch_size.clamp(1, 8192) as usize } - fn delay(&self, ctx: &mut ArrowContext) -> Duration { + fn delay(&self, ctx: &mut SourceContext) -> Duration { match self.spec { ImpulseSpec::Delay(d) => d, ImpulseSpec::EventsPerSecond(eps) => { @@ -92,7 +92,11 @@ impl ImpulseSourceFunc { } } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { let delay = self.delay(ctx); info!( "Starting impulse source with start {} delay {:?} and limit {}", @@ -116,7 +120,7 @@ impl ImpulseSourceFunc { let start_time = SystemTime::now() - delay * self.state.counter as u32; - let schema = ctx.out_schema.as_ref().unwrap().schema.clone(); + let schema = ctx.out_schema.schema.clone(); let batch_size = self.batch_size(ctx); @@ -138,18 +142,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = timestamp_builder.finish(); - ctx.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; items = 0; } @@ -163,18 +168,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = timestamp_builder.finish(); - ctx.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + 
Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; items = 0; } ctx.table_manager @@ -183,7 +189,7 @@ impl ImpulseSourceFunc { .unwrap() .insert(ctx.task_info.task_index, self.state) .await; - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return SourceFinishType::Immediate; } } @@ -203,7 +209,7 @@ impl ImpulseSourceFunc { unreachable!("sources shouldn't receive commit messages"); } Ok(ControlMessage::LoadCompacted { compacted }) => { - ctx.load_compacted(compacted).await; + ctx.table_manager.load_compacted(&compacted).await.unwrap(); } Ok(ControlMessage::NoOp) => {} Err(_) => { @@ -222,18 +228,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = timestamp_builder.finish(); - ctx.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; } SourceFinishType::Final @@ -250,7 +257,7 @@ impl SourceOperator for ImpulseSourceFunc { arroyo_state::global_table_config("i", "impulse source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut SourceContext) { let s = ctx .table_manager .get_global_keyed_state("i") @@ -262,7 +269,11 @@ impl SourceOperator for ImpulseSourceFunc { } } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - self.run(ctx).await + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + self.run(ctx, collector).await } } diff --git 
a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 57447927f..9d4e402b4 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -41,7 +41,7 @@ use crate::{pull_opt, send, ConnectionType}; use crate::kafka::sink::KafkaSinkFunc; use crate::kafka::source::KafkaSourceFunc; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; mod sink; mod source; @@ -364,7 +364,7 @@ impl Connector for KafkaConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { match &table.type_ { TableType::Source { group_id, @@ -398,48 +398,52 @@ impl Connector for KafkaConnector { None }; - Ok(OperatorNode::from_source(Box::new(KafkaSourceFunc { - topic: table.topic, - bootstrap_servers: profile.bootstrap_servers.to_string(), - group_id: group_id.clone(), - group_id_prefix: group_id_prefix.clone(), - offset_mode: *offset, - format: config.format.expect("Format must be set for Kafka source"), - framing: config.framing, - schema_resolver, - bad_data: config.bad_data, - client_configs, - context: Context::new(Some(profile.clone())), - messages_per_second: NonZeroU32::new( - config - .rate_limit - .map(|l| l.messages_per_second) - .unwrap_or(u32::MAX), - ) - .unwrap(), - metadata_fields: config.metadata_fields, - }))) + Ok(ConstructedOperator::from_source(Box::new( + KafkaSourceFunc { + topic: table.topic, + bootstrap_servers: profile.bootstrap_servers.to_string(), + group_id: group_id.clone(), + group_id_prefix: group_id_prefix.clone(), + offset_mode: *offset, + format: config.format.expect("Format must be set for Kafka source"), + framing: config.framing, + schema_resolver, + bad_data: config.bad_data, + client_configs, + context: Context::new(Some(profile.clone())), + messages_per_second: NonZeroU32::new( + config + .rate_limit + .map(|l| 
l.messages_per_second) + .unwrap_or(u32::MAX), + ) + .unwrap(), + metadata_fields: config.metadata_fields, + }, + ))) } TableType::Sink { commit_mode, key_field, timestamp_field, - } => Ok(OperatorNode::from_operator(Box::new(KafkaSinkFunc { - bootstrap_servers: profile.bootstrap_servers.to_string(), - producer: None, - consistency_mode: (*commit_mode).into(), - timestamp_field: timestamp_field.clone(), - timestamp_col: None, - key_field: key_field.clone(), - key_col: None, - write_futures: vec![], - client_config: client_configs(&profile, Some(table.clone()))?, - context: Context::new(Some(profile.clone())), - topic: table.topic, - serializer: ArrowSerializer::new( - config.format.expect("Format must be defined for KafkaSink"), - ), - }))), + } => Ok(ConstructedOperator::from_operator(Box::new( + KafkaSinkFunc { + bootstrap_servers: profile.bootstrap_servers.to_string(), + producer: None, + consistency_mode: (*commit_mode).into(), + timestamp_field: timestamp_field.clone(), + timestamp_col: None, + key_field: key_field.clone(), + key_col: None, + write_futures: vec![], + client_config: client_configs(&profile, Some(table.clone()))?, + context: Context::new(Some(profile.clone())), + topic: table.topic, + serializer: ArrowSerializer::new( + config.format.expect("Format must be defined for KafkaSink"), + ), + }, + ))), } } } diff --git a/crates/arroyo-connectors/src/kafka/sink/mod.rs b/crates/arroyo-connectors/src/kafka/sink/mod.rs index 4b90d4080..ba2123618 100644 --- a/crates/arroyo-connectors/src/kafka/sink/mod.rs +++ b/crates/arroyo-connectors/src/kafka/sink/mod.rs @@ -2,7 +2,7 @@ use anyhow::Result; use std::borrow::Cow; use arroyo_rpc::grpc::rpc::{GlobalKeyedTableConfig, TableConfig, TableEnum}; -use arroyo_rpc::{CheckpointEvent, ControlMessage, ControlResp}; +use arroyo_rpc::{CheckpointEvent, ControlResp}; use arroyo_types::*; use std::collections::HashMap; use std::fmt::{Display, Formatter}; @@ -16,7 +16,7 @@ use rdkafka::ClientConfig; use 
arrow::array::{Array, AsArray, RecordBatch}; use arrow::datatypes::{DataType, TimeUnit}; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ArrowOperator, AsDisplayable, DisplayableOperator}; use arroyo_rpc::df::ArroyoSchema; use arroyo_types::CheckpointBarrier; @@ -75,10 +75,6 @@ impl From for ConsistencyMode { } impl KafkaSinkFunc { - fn is_committing(&self) -> bool { - matches!(self.consistency_mode, ConsistencyMode::ExactlyOnce { .. }) - } - fn set_timestamp_col(&mut self, schema: &ArroyoSchema) { if let Some(f) = &self.timestamp_field { if let Ok(f) = schema.schema.field_with_name(f) { @@ -163,7 +159,7 @@ impl KafkaSinkFunc { Ok(()) } - async fn flush(&mut self, ctx: &mut ArrowContext) { + async fn flush(&mut self, ctx: &mut OperatorContext) { self.producer .as_ref() .unwrap() @@ -188,7 +184,7 @@ impl KafkaSinkFunc { ts: Option, k: Option>, v: Vec, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) { let mut rec = { let mut rec = FutureRecord::, Vec>::to(&self.topic); @@ -271,7 +267,11 @@ impl ArrowOperator for KafkaSinkFunc { } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + fn is_committing(&self) -> bool { + matches!(self.consistency_mode, ConsistencyMode::ExactlyOnce { .. 
}) + } + + async fn on_start(&mut self, ctx: &mut OperatorContext) { self.set_timestamp_col(&ctx.in_schemas[0]); self.set_key_col(&ctx.in_schemas[0]); @@ -279,7 +279,12 @@ impl ArrowOperator for KafkaSinkFunc { .expect("Producer creation failed"); } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let values = self.serializer.serialize(&batch); let timestamps = batch .column( @@ -306,7 +311,12 @@ impl ArrowOperator for KafkaSinkFunc { } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.flush(ctx).await; if let ConsistencyMode::ExactlyOnce { next_transaction_index, @@ -330,7 +340,7 @@ impl ArrowOperator for KafkaSinkFunc { &mut self, epoch: u32, _commit_data: &HashMap>>, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) { let ConsistencyMode::ExactlyOnce { next_transaction_index: _, @@ -342,8 +352,10 @@ impl ArrowOperator for KafkaSinkFunc { }; let Some(committing_producer) = producer_to_complete.take() else { - unimplemented!("received a commit message without a producer ready to commit. Restoring from commit phase not yet implemented"); + error!("received a commit message without a producer ready to commit. 
Restoring from commit phase not yet implemented"); + return; }; + let mut commits_attempted = 0; loop { if committing_producer @@ -360,8 +372,9 @@ impl ArrowOperator for KafkaSinkFunc { } let checkpoint_event = ControlResp::CheckpointEvent(CheckpointEvent { checkpoint_epoch: epoch, + node_id: ctx.task_info.node_id, operator_id: ctx.task_info.operator_id.clone(), - subtask_index: ctx.task_info.task_index as u32, + subtask_index: ctx.task_info.task_index, time: SystemTime::now(), event_type: arroyo_rpc::grpc::rpc::TaskCheckpointEventType::FinishedCommit, }); @@ -371,15 +384,12 @@ impl ArrowOperator for KafkaSinkFunc { .expect("sent commit event"); } - async fn on_close(&mut self, _: &Option, ctx: &mut ArrowContext) { + async fn on_close( + &mut self, + _: &Option, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.flush(ctx).await; - if !self.is_committing() { - return; - } - if let Some(ControlMessage::Commit { epoch, commit_data }) = ctx.control_rx.recv().await { - self.handle_commit(epoch, &commit_data, ctx).await; - } else { - warn!("no commit message received, not committing") - } } } diff --git a/crates/arroyo-connectors/src/kafka/sink/test.rs b/crates/arroyo-connectors/src/kafka/sink/test.rs index 6885adeba..30896e80b 100644 --- a/crates/arroyo-connectors/src/kafka/sink/test.rs +++ b/crates/arroyo-connectors/src/kafka/sink/test.rs @@ -8,7 +8,7 @@ use arrow::array::{RecordBatch, UInt32Array}; use arrow::datatypes::Field; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::OperatorContext; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{Format, JsonFormat}; @@ -24,6 +24,7 @@ use tokio::sync::mpsc::channel; use super::{ConsistencyMode, KafkaSinkFunc}; use crate::kafka::Context; +use crate::test::DummyCollector; pub struct KafkaTopicTester { topic: String, @@ -86,21 +87,17 @@ 
impl KafkaTopicTester { key_col: None, }; - let (_, control_rx) = channel(128); let (command_tx, _) = channel(128); - let task_info = get_test_task_info(); + let task_info = Arc::new(get_test_task_info()); - let mut ctx = ArrowContext::new( + let mut ctx = OperatorContext::new( task_info, None, - control_rx, command_tx, 1, vec![ArroyoSchema::new_unkeyed(schema(), 0)], None, - None, - vec![vec![]], HashMap::new(), ) .await; @@ -138,7 +135,7 @@ async fn get_data(consumer: &mut StreamConsumer) -> String { struct KafkaSinkWithWrites { sink: KafkaSinkFunc, - ctx: ArrowContext, + ctx: OperatorContext, } #[tokio::test] @@ -158,7 +155,7 @@ async fn test_kafka_checkpoint_flushes() { sink_with_writes .sink - .process_batch(batch, &mut sink_with_writes.ctx) + .process_batch(batch, &mut sink_with_writes.ctx, &mut DummyCollector {}) .await; } let barrier = CheckpointBarrier { @@ -169,7 +166,7 @@ async fn test_kafka_checkpoint_flushes() { }; sink_with_writes .sink - .handle_checkpoint(barrier, &mut sink_with_writes.ctx) + .handle_checkpoint(barrier, &mut sink_with_writes.ctx, &mut DummyCollector {}) .await; for message in 1u32..200 { @@ -196,7 +193,7 @@ async fn test_kafka() { sink_with_writes .sink - .process_batch(batch, &mut sink_with_writes.ctx) + .process_batch(batch, &mut sink_with_writes.ctx, &mut DummyCollector {}) .await; sink_with_writes .sink diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 299bf40b6..38c4a1cd2 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -14,13 +14,13 @@ use tokio::time::MissedTickBehavior; use tracing::{debug, error, info, warn}; use arroyo_formats::de::FieldValueType; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, 
Framing}; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::schema_resolver::SchemaResolver; -use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, ControlResp, MetadataField}; +use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, MetadataField}; use arroyo_types::*; use super::{Context, SourceOffset, StreamConsumer}; @@ -51,7 +51,7 @@ pub struct KafkaState { } impl KafkaSourceFunc { - async fn get_consumer(&mut self, ctx: &mut ArrowContext) -> anyhow::Result { + async fn get_consumer(&mut self, ctx: &mut SourceContext) -> anyhow::Result { info!("Creating kafka consumer for {}", self.bootstrap_servers); let mut client_config = ClientConfig::new(); @@ -113,7 +113,9 @@ impl KafkaSourceFunc { partitions .iter() .enumerate() - .filter(|(i, _)| i % ctx.task_info.parallelism == ctx.task_info.task_index) + .filter(|(i, _)| { + i % ctx.task_info.parallelism as usize == ctx.task_info.task_index as usize + }) .map(|(_, p)| { let offset = state .get(&p.id()) @@ -145,7 +147,11 @@ impl KafkaSourceFunc { Ok(consumer) } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let consumer = self .get_consumer(ctx) .await @@ -157,21 +163,20 @@ impl KafkaSourceFunc { if consumer.assignment().unwrap().count() == 0 { warn!("Kafka Consumer {}-{} is subscribed to no partitions, as there are more subtasks than partitions... 
setting idle", ctx.task_info.operator_id, ctx.task_info.task_index); - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; } if let Some(schema_resolver) = &self.schema_resolver { - ctx.initialize_deserializer_with_resolver( + collector.initialize_deserializer_with_resolver( self.format.clone(), self.framing.clone(), self.bad_data.clone(), schema_resolver.clone(), ); } else { - ctx.initialize_deserializer( + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -209,11 +214,10 @@ impl KafkaSourceFunc { None }; - ctx.deserialize_slice(v, from_millis(timestamp.max(0) as u64), connector_metadata.as_ref()).await?; - + collector.deserialize_slice(v, from_millis(timestamp.max(0) as u64), connector_metadata.as_ref()).await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } offsets.insert(msg.partition(), msg.offset()); @@ -226,8 +230,8 @@ impl KafkaSourceFunc { } } _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } control_message = ctx.control_rx.recv() => { @@ -251,7 +255,7 @@ impl KafkaSourceFunc { // fails. The actual offset is stored in state. 
warn!("Failed to commit offset to Kafka {:?}", e); } - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } }, @@ -286,19 +290,15 @@ impl KafkaSourceFunc { #[async_trait] impl SourceOperator for KafkaSourceFunc { - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { - ctx.control_tx - .send(ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: e.name.clone(), - details: e.details.clone(), - }) - .await - .unwrap(); + ctx.report_user_error(e.clone()).await; panic!("{}: {}", e.name, e.details); } diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index 1898f1115..3dab460c3 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -13,14 +13,17 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use crate::kafka::SourceOffset; -use arroyo_operator::context::{batch_bounded, ArrowContext, BatchReceiver}; +use arroyo_operator::context::{ + batch_bounded, ArrowCollector, BatchReceiver, OperatorContext, SourceCollector, SourceContext, +}; use arroyo_operator::operator::SourceOperator; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{Format, RawStringFormat}; use arroyo_rpc::grpc::rpc::{CheckpointMetadata, OperatorCheckpointMetadata, OperatorMetadata}; use arroyo_rpc::{CheckpointCompleted, ControlMessage, ControlResp, MetadataField}; use arroyo_types::{ - single_item_hash_map, to_micros, ArrowMessage, CheckpointBarrier, SignalMessage, TaskInfo, + single_item_hash_map, to_micros, ArrowMessage, ChainInfo, CheckpointBarrier, SignalMessage, + 
TaskInfo, }; use rdkafka::admin::{AdminClient, AdminOptions, NewTopic}; use rdkafka::producer::{BaseProducer, BaseRecord}; @@ -105,32 +108,54 @@ impl KafkaTopicTester { operator_ids: vec![task_info.operator_id.clone()], }); - let mut ctx = ArrowContext::new( - task_info, - checkpoint_metadata, - control_rx, - command_tx, + let out_schema = Some(ArroyoSchema::new_unkeyed( + Arc::new(Schema::new(vec![ + Field::new( + "_timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("value", DataType::Utf8, false), + ])), + 0, + )); + + let task_info = Arc::new(task_info); + + let ctx = OperatorContext::new( + task_info.clone(), + checkpoint_metadata.as_ref(), + command_tx.clone(), 1, vec![], - Some(ArroyoSchema::new_unkeyed( - Arc::new(Schema::new(vec![ - Field::new( - "_timestamp", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("value", DataType::Utf8, false), - ])), - 0, - )), - None, - vec![vec![data_tx]], + out_schema, kafka.tables(), ) .await; + let chain_info = Arc::new(ChainInfo { + job_id: ctx.task_info.job_id.clone(), + node_id: ctx.task_info.node_id, + description: "kafka source".to_string(), + task_index: ctx.task_info.task_index, + }); + + let mut ctx = SourceContext::from_operator(ctx, chain_info.clone(), control_rx); + let arrow_collector = ArrowCollector::new( + chain_info.clone(), + Some(ctx.out_schema.clone()), + vec![vec![data_tx]], + ); + let mut collector = SourceCollector::new( + ctx.out_schema.clone(), + arrow_collector, + command_tx, + &chain_info, + &task_info, + ); + tokio::spawn(async move { - kafka.run(&mut ctx).await; + kafka.run(&mut ctx, &mut collector).await; }); KafkaSourceWithReads { to_control_tx, @@ -356,6 +381,7 @@ async fn test_kafka_with_metadata_fields() { let mut task_info = arroyo_types::get_test_task_info(); task_info.job_id = format!("kafka-job-{}", random::()); + let task_info = Arc::new(task_info); kafka_topic_tester.create_topic().await; @@ -388,11 +414,10 @@ async 
fn test_kafka_with_metadata_fields() { let checkpoint_metadata = None; - let mut ctx = ArrowContext::new( + let ctx = OperatorContext::new( task_info.clone(), - checkpoint_metadata, - control_rx, - command_tx, + checkpoint_metadata.as_ref(), + command_tx.clone(), 1, vec![], Some(ArroyoSchema::new_unkeyed( @@ -407,18 +432,37 @@ async fn test_kafka_with_metadata_fields() { ])), 0, )), - None, - vec![vec![data_tx]], kafka.tables(), ) .await; + let chain_info = Arc::new(ChainInfo { + job_id: ctx.task_info.job_id.clone(), + node_id: ctx.task_info.node_id, + description: "kafka source".to_string(), + task_index: ctx.task_info.task_index, + }); + + let mut ctx = SourceContext::from_operator(ctx, chain_info.clone(), control_rx); + let arrow_collector = ArrowCollector::new( + chain_info.clone(), + Some(ctx.out_schema.clone()), + vec![vec![data_tx]], + ); + let mut collector = SourceCollector::new( + ctx.out_schema.clone(), + arrow_collector, + command_tx, + &chain_info, + &task_info, + ); + tokio::spawn(async move { - kafka.run(&mut ctx).await; + kafka.run(&mut ctx, &mut collector).await; }); let mut reader = kafka_topic_tester - .get_source_with_reader(task_info.clone(), None) + .get_source_with_reader((*task_info).clone(), None) .await; let mut producer = kafka_topic_tester.get_producer(); diff --git a/crates/arroyo-connectors/src/kinesis/mod.rs b/crates/arroyo-connectors/src/kinesis/mod.rs index 1ece9279f..245ea3798 100644 --- a/crates/arroyo-connectors/src/kinesis/mod.rs +++ b/crates/arroyo-connectors/src/kinesis/mod.rs @@ -13,7 +13,7 @@ use crate::{pull_opt, pull_option_to_i64, ConnectionSchema, ConnectionType, Empt use crate::kinesis::sink::{FlushConfig, KinesisSinkFunc}; use crate::kinesis::source::KinesisSourceFunc; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; const TABLE_SCHEMA: &str = include_str!("./table.json"); const ICON: &str = include_str!("./kinesis.svg"); @@ 
-175,10 +175,10 @@ impl Connector for KinesisConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> Result { + ) -> Result { match table.type_ { - TableType::Source { offset } => { - Ok(OperatorNode::from_source(Box::new(KinesisSourceFunc { + TableType::Source { offset } => Ok(ConstructedOperator::from_source(Box::new( + KinesisSourceFunc { stream_name: table.stream_name, kinesis_client: None, aws_region: table.aws_region, @@ -189,8 +189,8 @@ impl Connector for KinesisConnector { .ok_or_else(|| anyhow!("format required for kinesis source"))?, framing: config.framing, bad_data: config.bad_data, - }))) - } + }, + ))), TableType::Sink { batch_flush_interval_millis, batch_max_buffer_size, @@ -201,18 +201,20 @@ impl Connector for KinesisConnector { batch_max_buffer_size, records_per_batch, ); - Ok(OperatorNode::from_operator(Box::new(KinesisSinkFunc { - client: None, - in_progress_batch: None, - aws_region: table.aws_region, - name: table.stream_name, - serializer: ArrowSerializer::new( - config - .format - .ok_or_else(|| anyhow!("Format must be defined for KinesisSink"))?, - ), - flush_config, - }))) + Ok(ConstructedOperator::from_operator(Box::new( + KinesisSinkFunc { + client: None, + in_progress_batch: None, + aws_region: table.aws_region, + name: table.stream_name, + serializer: ArrowSerializer::new( + config + .format + .ok_or_else(|| anyhow!("Format must be defined for KinesisSink"))?, + ), + flush_config, + }, + ))) } } } diff --git a/crates/arroyo-connectors/src/kinesis/sink.rs b/crates/arroyo-connectors/src/kinesis/sink.rs index 5382660e2..27b820770 100644 --- a/crates/arroyo-connectors/src/kinesis/sink.rs +++ b/crates/arroyo-connectors/src/kinesis/sink.rs @@ -5,7 +5,7 @@ use std::time::{Duration, Instant}; use anyhow::{bail, Result}; use arrow::array::RecordBatch; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use 
arroyo_operator::operator::ArrowOperator; use arroyo_rpc::retry; use arroyo_types::CheckpointBarrier; @@ -32,7 +32,7 @@ impl ArrowOperator for KinesisSinkFunc { format!("kinesis-producer-{}", self.name) } - async fn on_start(&mut self, _ctx: &mut ArrowContext) { + async fn on_start(&mut self, _ctx: &mut OperatorContext) { let mut loader = aws_config::defaults(BehaviorVersion::v2024_03_28()); if let Some(region) = &self.aws_region { loader = loader.region(Region::new(region.clone())); @@ -44,7 +44,12 @@ impl ArrowOperator for KinesisSinkFunc { self.in_progress_batch = Some(BatchRecordPreparer::new(client, self.name.clone())); } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { for v in self.serializer.serialize(&batch) { self.in_progress_batch .as_mut() @@ -56,7 +61,12 @@ impl ArrowOperator for KinesisSinkFunc { } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, _: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { retry!( self.in_progress_batch.as_mut().unwrap().flush().await, 30, @@ -67,7 +77,7 @@ impl ArrowOperator for KinesisSinkFunc { .expect("could not flush to Kinesis during checkpointing"); } - async fn handle_tick(&mut self, _: u64, ctx: &mut ArrowContext) { + async fn handle_tick(&mut self, _: u64, ctx: &mut OperatorContext, _: &mut dyn Collector) { self.maybe_flush_with_retries(ctx) .await .expect("failed to flush batch during tick"); @@ -75,7 +85,7 @@ impl ArrowOperator for KinesisSinkFunc { } impl KinesisSinkFunc { - async fn maybe_flush_with_retries(&mut self, ctx: &mut ArrowContext) -> Result<()> { + async fn maybe_flush_with_retries(&mut self, ctx: &mut OperatorContext) -> Result<()> { if !self .flush_config .should_flush(self.in_progress_batch.as_ref().unwrap()) diff --git 
a/crates/arroyo-connectors/src/kinesis/source.rs b/crates/arroyo-connectors/src/kinesis/source.rs index b7565b85d..ce43a0411 100644 --- a/crates/arroyo-connectors/src/kinesis/source.rs +++ b/crates/arroyo-connectors/src/kinesis/source.rs @@ -7,7 +7,7 @@ use std::{ }; use anyhow::{anyhow, bail, Context as AnyhowContext, Result}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, Framing}; @@ -164,16 +164,18 @@ impl SourceOperator for KinesisSourceFunc { global_table_config("k", "kinesis source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { - ctx.initialize_deserializer( + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), ); - } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(UserError { name, details, .. }) => { ctx.report_error(name.clone(), details.clone()).await; @@ -189,7 +191,7 @@ impl KinesisSourceFunc { /// It returns a future for each shard to fetch the next shard iterator id. 
async fn init_shards( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, ) -> anyhow::Result>>> { let mut futures = Vec::new(); let s: &mut GlobalKeyedView = ctx @@ -204,7 +206,7 @@ impl KinesisSourceFunc { .filter(|(shard_id, _shard_state)| { let mut hasher = DefaultHasher::new(); shard_id.hash(&mut hasher); - let shard_hash = hasher.finish() as usize; + let shard_hash = hasher.finish() as u32; shard_hash % ctx.task_info.parallelism == ctx.task_info.task_index }) { @@ -223,7 +225,7 @@ impl KinesisSourceFunc { &mut self, shard_id: String, async_result: AsyncResult, - ctx: &mut ArrowContext, + collector: &mut SourceCollector, ) -> Result>>, UserError> { match async_result { AsyncResult::ShardIteratorIdUpdate(new_shard_iterator) => { @@ -231,7 +233,8 @@ impl KinesisSourceFunc { .await } AsyncResult::GetRecords(get_records) => { - self.handle_get_records(shard_id, get_records, ctx).await + self.handle_get_records(shard_id, get_records, collector) + .await } AsyncResult::NeedNewIterator => self.handle_need_new_iterator(shard_id).await, } @@ -304,14 +307,14 @@ impl KinesisSourceFunc { &mut self, shard_id: String, get_records: GetRecordsOutput, - ctx: &mut ArrowContext, + collector: &mut SourceCollector, ) -> Result>>, UserError> { let last_sequence_number = get_records .records() .last() .map(|record| record.sequence_number().to_owned()); - let next_shard_iterator = self.process_records(get_records, ctx).await?; + let next_shard_iterator = self.process_records(get_records, collector).await?; let shard_state = self.shards.get_mut(&shard_id).unwrap(); if let Some(last_sequence_number) = last_sequence_number { @@ -351,7 +354,11 @@ impl KinesisSourceFunc { /// * A `FuturesUnordered` tha contains futures for reading off of shards. /// * An interval that periodically polls for new shards, initializing their futures. /// * Polling off of the control queue, to perform checkpointing and stop the operator. 
- async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { self.init_client().await; let starting_futures = self .init_shards(ctx) @@ -368,13 +375,13 @@ impl KinesisSourceFunc { result = futures.select_next_some() => { let shard_id = result.name; if let Some(future) = self.handle_async_result_split(shard_id, - result.result.map_err(|e| UserError::new("Fatal Kinesis error", e.to_string()))?, ctx).await? { + result.result.map_err(|e| UserError::new("Fatal Kinesis error", e.to_string()))?, collector).await? { futures.push(future); } }, _ = shard_poll_interval.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } match self.sync_shards(ctx).await { Err(err) => { @@ -394,7 +401,7 @@ impl KinesisSourceFunc { for (shard_id, shard_state) in &self.shards { s.insert(shard_id.clone(), shard_state.clone()).await; } - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } }, @@ -428,17 +435,18 @@ impl KinesisSourceFunc { async fn process_records( &mut self, get_records_output: GetRecordsOutput, - ctx: &mut ArrowContext, + collector: &mut SourceCollector, ) -> Result, UserError> { let records = get_records_output.records; for record in records { let data = record.data.into_inner(); let timestamp = record.approximate_arrival_timestamp.unwrap(); - ctx.deserialize_slice(&data, from_nanos(timestamp.as_nanos() as u128), None) + collector + .deserialize_slice(&data, from_nanos(timestamp.as_nanos() as u128), None) .await?; - if ctx.should_flush() { - ctx.flush_buffer().await? + if collector.should_flush() { + collector.flush_buffer().await? 
} } Ok(get_records_output.next_shard_iterator) @@ -446,7 +454,7 @@ impl KinesisSourceFunc { async fn sync_shards( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, ) -> Result>>> { let mut futures = Vec::new(); for shard in self.get_splits().await? { @@ -454,7 +462,7 @@ impl KinesisSourceFunc { let shard_id = shard.shard_id().to_string(); let mut hasher = DefaultHasher::new(); shard_id.hash(&mut hasher); - let shard_hash = hasher.finish() as usize; + let shard_hash = hasher.finish() as u32; if self.shards.contains_key(&shard_id) || shard_hash % ctx.task_info.parallelism != ctx.task_info.task_index diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs index da5d32ce4..bc4892269 100644 --- a/crates/arroyo-connectors/src/lib.rs +++ b/crates/arroyo-connectors/src/lib.rs @@ -1,14 +1,3 @@ -use crate::confluent::ConfluentConnector; -use crate::filesystem::delta::DeltaLakeConnector; -use crate::filesystem::FileSystemConnector; -use crate::kinesis::KinesisConnector; -use crate::mqtt::MqttConnector; -use crate::polling_http::PollingHTTPConnector; -use crate::preview::PreviewConnector; -use crate::redis::RedisConnector; -use crate::single_file::SingleFileConnector; -use crate::stdout::StdoutConnector; -use crate::webhook::WebhookConnector; use anyhow::{anyhow, bail, Context}; use arroyo_operator::connector::ErasedConnector; use arroyo_rpc::api_types::connections::{ @@ -17,23 +6,13 @@ use arroyo_rpc::api_types::connections::{ use arroyo_rpc::primitive_to_sql; use arroyo_rpc::var_str::VarStr; use arroyo_types::string_to_map; -use blackhole::BlackholeConnector; -use fluvio::FluvioConnector; -use impulse::ImpulseConnector; -use nats::NatsConnector; -use nexmark::NexmarkConnector; -use rabbitmq::RabbitmqConnector; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; use serde::{Deserialize, Serialize}; -use sse::SSEConnector; use std::collections::HashMap; use std::time::Duration; use 
tokio::sync::mpsc::Sender; use tracing::warn; -use websocket::WebsocketConnector; - -use self::kafka::KafkaConnector; pub mod blackhole; pub mod confluent; @@ -57,26 +36,26 @@ pub mod websocket; pub fn connectors() -> HashMap<&'static str, Box> { let connectors: Vec> = vec![ - Box::new(BlackholeConnector {}), - Box::new(ConfluentConnector {}), - Box::new(DeltaLakeConnector {}), - Box::new(FileSystemConnector {}), - Box::new(FluvioConnector {}), - Box::new(ImpulseConnector {}), - Box::new(KafkaConnector {}), - Box::new(KinesisConnector {}), - Box::new(MqttConnector {}), - Box::new(NatsConnector {}), - Box::new(NexmarkConnector {}), - Box::new(PollingHTTPConnector {}), - Box::new(PreviewConnector {}), - Box::new(RabbitmqConnector {}), - Box::new(RedisConnector {}), - Box::new(SingleFileConnector {}), - Box::new(SSEConnector {}), - Box::new(StdoutConnector {}), - Box::new(WebhookConnector {}), - Box::new(WebsocketConnector {}), + Box::new(blackhole::BlackholeConnector {}), + Box::new(confluent::ConfluentConnector {}), + Box::new(filesystem::delta::DeltaLakeConnector {}), + Box::new(filesystem::FileSystemConnector {}), + Box::new(fluvio::FluvioConnector {}), + Box::new(impulse::ImpulseConnector {}), + Box::new(kafka::KafkaConnector {}), + Box::new(kinesis::KinesisConnector {}), + Box::new(mqtt::MqttConnector {}), + Box::new(nats::NatsConnector {}), + Box::new(nexmark::NexmarkConnector {}), + Box::new(polling_http::PollingHTTPConnector {}), + Box::new(preview::PreviewConnector {}), + Box::new(rabbitmq::RabbitmqConnector {}), + Box::new(redis::RedisConnector {}), + Box::new(single_file::SingleFileConnector {}), + Box::new(sse::SSEConnector {}), + Box::new(stdout::StdoutConnector {}), + Box::new(webhook::WebhookConnector {}), + Box::new(websocket::WebsocketConnector {}), ]; connectors.into_iter().map(|c| (c.name(), c)).collect() @@ -181,3 +160,24 @@ pub fn header_map(headers: Option) -> HashMap { ) .expect("Invalid header map") } + +#[cfg(test)] +mod test { + use 
arrow::array::RecordBatch; + use arroyo_operator::context::Collector; + use arroyo_types::Watermark; + use async_trait::async_trait; + + pub struct DummyCollector {} + + #[async_trait] + impl Collector for DummyCollector { + async fn collect(&mut self, _: RecordBatch) { + unreachable!() + } + + async fn broadcast_watermark(&mut self, _: Watermark) { + unreachable!() + } + } +} diff --git a/crates/arroyo-connectors/src/mqtt/mod.rs b/crates/arroyo-connectors/src/mqtt/mod.rs index d642b7c5f..8c035b98a 100644 --- a/crates/arroyo-connectors/src/mqtt/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/mod.rs @@ -11,7 +11,7 @@ use anyhow::{anyhow, bail}; use arrow::datatypes::DataType; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::connector::{Connection, Connector, MetadataDef}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -271,10 +271,10 @@ impl Connector for MqttConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let qos = table.qos(); Ok(match table.type_ { - TableType::Source {} => OperatorNode::from_source(Box::new(MqttSourceFunc { + TableType::Source {} => ConstructedOperator::from_source(Box::new(MqttSourceFunc { config: profile, topic: table.topic, qos, @@ -293,19 +293,21 @@ impl Connector for MqttConnector { subscribed: Arc::new(AtomicBool::new(false)), metadata_fields: config.metadata_fields, })), - TableType::Sink { retain } => OperatorNode::from_operator(Box::new(MqttSinkFunc { - config: profile, - qos, - topic: table.topic, - retain, - serializer: ArrowSerializer::new( - config - .format - .ok_or_else(|| anyhow!("format is required for mqtt sink"))?, - ), - stopped: Arc::new(AtomicBool::new(false)), - client: None, - })), + TableType::Sink { retain } => { + 
ConstructedOperator::from_operator(Box::new(MqttSinkFunc { + config: profile, + qos, + topic: table.topic, + retain, + serializer: ArrowSerializer::new( + config + .format + .ok_or_else(|| anyhow!("format is required for mqtt sink"))?, + ), + stopped: Arc::new(AtomicBool::new(false)), + client: None, + })) + } }) } } diff --git a/crates/arroyo-connectors/src/mqtt/sink/mod.rs b/crates/arroyo-connectors/src/mqtt/sink/mod.rs index 3bb1b1dc4..56bddda7a 100644 --- a/crates/arroyo-connectors/src/mqtt/sink/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/sink/mod.rs @@ -6,10 +6,9 @@ use std::time::Duration; use crate::mqtt::MqttConfig; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::formats::Format; -use arroyo_rpc::ControlResp; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::AsyncClient; use rumqttc::v5::ConnectionError; @@ -46,10 +45,10 @@ impl ArrowOperator for MqttSinkFunc { fn name(&self) -> String { format!("mqtt-producer-{}", self.topic) } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let mut attempts = 0; while attempts < 20 { - match super::create_connection(&self.config, ctx.task_info.task_index) { + match super::create_connection(&self.config, ctx.task_info.task_index as usize) { Ok((client, mut eventloop)) => { self.client = Some(client); let stopped = self.stopped.clone(); @@ -91,7 +90,12 @@ impl ArrowOperator for MqttSinkFunc { panic!("Failed to establish connection to mqtt after 20 retries"); } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { for v in self.serializer.serialize(&batch) { match self .client @@ -102,16 +106,8 @@ impl ArrowOperator for MqttSinkFunc { { Ok(_) => (), 
Err(e) => { - ctx.control_tx - .send(ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: "Could not write to mqtt".to_string(), - details: format!("{:?}", e), - }) - .await - .unwrap(); - + ctx.report_error("Could not write to mqtt", format!("{:?}", e)) + .await; panic!("Could not write to mqtt: {:?}", e); } } diff --git a/crates/arroyo-connectors/src/mqtt/sink/test.rs b/crates/arroyo-connectors/src/mqtt/sink/test.rs index 6c2e17678..60c1d03e6 100644 --- a/crates/arroyo-connectors/src/mqtt/sink/test.rs +++ b/crates/arroyo-connectors/src/mqtt/sink/test.rs @@ -2,9 +2,11 @@ use arrow::array::{RecordBatch, StringArray}; use std::collections::HashMap; use std::sync::Arc; +use super::MqttSinkFunc; use crate::mqtt::{create_connection, MqttConfig, Tls}; +use crate::test::DummyCollector; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::OperatorContext; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::{ @@ -20,8 +22,6 @@ use rumqttc::{ use serde::Deserialize; use tokio::sync::mpsc::channel; -use super::MqttSinkFunc; - fn schema() -> SchemaRef { Arc::new(Schema::new(vec![Field::new( "value", @@ -75,21 +75,17 @@ impl MqttTopicTester { Format::Json(JsonFormat::default()), ); - let (_, control_rx) = channel(128); let (command_tx, _) = channel(128); - let task_info = get_test_task_info(); + let task_info = Arc::new(get_test_task_info()); - let mut ctx = ArrowContext::new( + let mut ctx = OperatorContext::new( task_info, None, - control_rx, command_tx, 1, vec![ArroyoSchema::new_unkeyed(schema(), 0)], None, - None, - vec![vec![]], HashMap::new(), ) .await; @@ -102,7 +98,7 @@ impl MqttTopicTester { struct MqttSinkWithWrites { sink: MqttSinkFunc, - ctx: ArrowContext, + ctx: OperatorContext, } #[tokio::test] @@ -145,7 +141,7 @@ async fn test_mqtt() { sink_with_writes .sink - 
.process_batch(batch, &mut sink_with_writes.ctx) + .process_batch(batch, &mut sink_with_writes.ctx, &mut DummyCollector {}) .await; } diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 6f8f9f19d..6c2d51577 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -7,15 +7,15 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use arroyo_rpc::formats::{BadData, Format, Framing}; -use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, ControlResp, MetadataField}; -use arroyo_types::{ArrowMessage, SignalMessage, UserError, Watermark}; +use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage, MetadataField}; +use arroyo_types::{SignalMessage, UserError, Watermark}; use governor::{Quota, RateLimiter as GovernorRateLimiter}; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::{ConnectionError, Event as MqttEvent, Incoming}; use rumqttc::Outgoing; use crate::mqtt::{create_connection, MqttConfig}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::grpc::rpc::TableConfig; @@ -47,19 +47,15 @@ impl SourceOperator for MqttSourceFunc { arroyo_state::global_table_config("m", "mqtt source state") } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { - ctx.control_tx - .send(ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: e.name.clone(), - details: e.details.clone(), - }) - .await - .unwrap(); + ctx.report_error(&e.name, &e.details).await; panic!("{}: {}", e.name, e.details); } @@ -96,8 
+92,12 @@ impl MqttSourceFunc { self.subscribed.clone() } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { - ctx.initialize_deserializer( + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -109,14 +109,13 @@ impl MqttSourceFunc { ctx.task_info.operator_id, ctx.task_info.task_index ); - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; } let (client, mut eventloop) = - match create_connection(&self.config, ctx.task_info.task_index) { + match create_connection(&self.config, ctx.task_info.task_index as usize) { Ok(c) => c, Err(e) => { return Err(UserError { @@ -163,7 +162,7 @@ impl MqttSourceFunc { None }; - ctx.deserialize_slice(&p.payload, SystemTime::now(), connector_metadata.as_ref()).await?; + collector.deserialize_slice(&p.payload, SystemTime::now(), connector_metadata.as_ref()).await?; rate_limiter.until_ready().await; } Ok(MqttEvent::Outgoing(Outgoing::Subscribe(_))) => { @@ -190,15 +189,15 @@ impl MqttSourceFunc { } } _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } control_message = ctx.control_rx.recv() => { match control_message { Some(ControlMessage::Checkpoint(c)) => { tracing::debug!("starting checkpointing {}", ctx.task_info.task_index); - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } }, diff --git a/crates/arroyo-connectors/src/mqtt/source/test.rs b/crates/arroyo-connectors/src/mqtt/source/test.rs index f085a7c5e..4690cdc88 100644 --- a/crates/arroyo-connectors/src/mqtt/source/test.rs +++ b/crates/arroyo-connectors/src/mqtt/source/test.rs @@ -5,13 +5,15 @@ 
use std::sync::Arc; use crate::mqtt::{create_connection, MqttConfig, Tls}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use arroyo_operator::context::{batch_bounded, ArrowContext, BatchReceiver}; +use arroyo_operator::context::{ + batch_bounded, ArrowCollector, BatchReceiver, OperatorContext, SourceCollector, SourceContext, +}; use arroyo_operator::operator::SourceOperator; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{Format, JsonFormat}; use arroyo_rpc::var_str::VarStr; use arroyo_rpc::{ControlMessage, ControlResp}; -use arroyo_types::{ArrowMessage, TaskInfo}; +use arroyo_types::{ArrowMessage, ChainInfo, TaskInfo}; use rand::random; use rumqttc::v5::mqttbytes::QoS; use serde::{Deserialize, Serialize}; @@ -116,6 +118,7 @@ impl MqttTopicTester { async fn get_source_with_reader(&self, task_info: TaskInfo) -> MqttSourceWithReads { let config = self.get_config(); + let task_info = Arc::new(task_info); let mut mqtt = MqttSourceFunc::new( config, @@ -132,11 +135,10 @@ impl MqttTopicTester { let (command_tx, from_control_rx) = channel(128); let (data_tx, recv) = batch_bounded(128); - let mut ctx = ArrowContext::new( - task_info, - None, - control_rx, - command_tx, + let ctx = OperatorContext::new( + task_info.clone(), + None.as_ref(), + command_tx.clone(), 1, vec![], Some(ArroyoSchema::new_unkeyed( @@ -150,16 +152,35 @@ impl MqttTopicTester { ])), 0, )), - None, - vec![vec![data_tx]], mqtt.tables(), ) .await; + let chain_info = Arc::new(ChainInfo { + job_id: ctx.task_info.job_id.clone(), + node_id: ctx.task_info.node_id, + description: "mqtt source".to_string(), + task_index: ctx.task_info.task_index, + }); + + let mut ctx = SourceContext::from_operator(ctx, chain_info.clone(), control_rx); + let arrow_collector = ArrowCollector::new( + chain_info.clone(), + Some(ctx.out_schema.clone()), + vec![vec![data_tx]], + ); + let mut collector = SourceCollector::new( + ctx.out_schema.clone(), + arrow_collector, + command_tx, + &chain_info, + 
&task_info, + ); + let subscribed = mqtt.subscribed(); tokio::spawn(async move { mqtt.on_start(&mut ctx).await; - mqtt.run(&mut ctx).await; + mqtt.run(&mut ctx, &mut collector).await; }); MqttSourceWithReads { diff --git a/crates/arroyo-connectors/src/nats/mod.rs b/crates/arroyo-connectors/src/nats/mod.rs index 6b7a98fb7..a21ced01d 100644 --- a/crates/arroyo-connectors/src/nats/mod.rs +++ b/crates/arroyo-connectors/src/nats/mod.rs @@ -5,7 +5,7 @@ use anyhow::anyhow; use anyhow::bail; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -334,10 +334,10 @@ impl Connector for NatsConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { Ok(match table.connector_type { ConnectorType::Source { ref source_type } => { - OperatorNode::from_source(Box::new(NatsSourceFunc { + ConstructedOperator::from_source(Box::new(NatsSourceFunc { source_type: source_type .clone() .ok_or_else(|| anyhow!("`sourceType` is required"))?, @@ -361,7 +361,7 @@ impl Connector for NatsConnector { })) } ConnectorType::Sink { ref sink_type } => { - OperatorNode::from_operator(Box::new(NatsSinkFunc { + ConstructedOperator::from_operator(Box::new(NatsSinkFunc { sink_type: sink_type .clone() .ok_or_else(|| anyhow!("`sinkType` is required"))?, diff --git a/crates/arroyo-connectors/src/nats/sink/mod.rs b/crates/arroyo-connectors/src/nats/sink/mod.rs index 92255ea85..29ad12eb5 100644 --- a/crates/arroyo-connectors/src/nats/sink/mod.rs +++ b/crates/arroyo-connectors/src/nats/sink/mod.rs @@ -3,15 +3,12 @@ use super::NatsTable; use super::{get_nats_client, SinkType}; use arrow::array::RecordBatch; use arroyo_formats::ser::ArrowSerializer; -use 
arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::grpc::rpc::TableConfig; -use arroyo_rpc::ControlMessage; -use arroyo_rpc::ControlResp; use arroyo_types::*; use async_trait::async_trait; use std::collections::HashMap; -use tracing::warn; pub struct NatsSinkFunc { pub sink_type: SinkType, @@ -39,7 +36,7 @@ impl ArrowOperator for NatsSinkFunc { HashMap::new() } - async fn on_start(&mut self, _ctx: &mut ArrowContext) { + async fn on_start(&mut self, _ctx: &mut OperatorContext) { match get_nats_client(&self.connection).await { Ok(client) => { self.publisher = Some(client); @@ -50,15 +47,12 @@ impl ArrowOperator for NatsSinkFunc { } } - async fn on_close(&mut self, _: &Option, ctx: &mut ArrowContext) { - if let Some(ControlMessage::Commit { epoch, commit_data }) = ctx.control_rx.recv().await { - self.handle_commit(epoch, &commit_data, ctx).await; - } else { - warn!("No commit message received, not committing") - } - } - - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, _ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { // TODO: Implement checkpointing of in-progress data to avoid depending on // the downstream NATS availability to flush and checkpoint. 
let publisher = self @@ -74,7 +68,12 @@ impl ArrowOperator for NatsSinkFunc { } } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let SinkType::Subject(s) = &self.sink_type; let nats_subject = async_nats::Subject::from(s.clone()); for msg in self.serializer.serialize(&batch) { @@ -86,15 +85,7 @@ impl ArrowOperator for NatsSinkFunc { match publisher.publish(nats_subject.clone(), msg.into()).await { Ok(_) => {} Err(e) => { - ctx.control_tx - .send(ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: e.to_string(), - details: e.to_string(), - }) - .await - .expect("Something went wrong, data will never be received."); + ctx.report_error(e.to_string(), format!("{:?}", e)).await; panic!("Panicked while processing element: {}", e); } } diff --git a/crates/arroyo-connectors/src/nats/source/mod.rs b/crates/arroyo-connectors/src/nats/source/mod.rs index 9a81aee41..7cee2bc1e 100644 --- a/crates/arroyo-connectors/src/nats/source/mod.rs +++ b/crates/arroyo-connectors/src/nats/source/mod.rs @@ -5,7 +5,7 @@ use super::NatsState; use super::NatsTable; use super::ReplayPolicy; use super::{get_nats_client, SourceType}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::BadData; @@ -13,7 +13,6 @@ use arroyo_rpc::formats::{Format, Framing}; use arroyo_rpc::grpc::rpc::StopMode; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::ControlMessage; -use arroyo_rpc::ControlResp; use arroyo_rpc::OperatorConfig; use arroyo_types::UserError; use async_nats::jetstream::consumer; @@ -58,19 +57,15 @@ impl SourceOperator for NatsSourceFunc { arroyo_state::global_table_config("n", "NATS source state") } - 
async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + match self.run_int(ctx, collector).await { Ok(res) => res, Err(err) => { - ctx.control_tx - .send(ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: err.name.clone(), - details: err.details.clone(), - }) - .await - .unwrap(); + ctx.report_user_error(err.clone()).await; panic!("{}: {}", err.name, err.details); } } @@ -173,7 +168,7 @@ impl NatsSourceFunc { &mut self, stream: &async_nats::jetstream::stream::Stream, sequence_number: u64, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, ) -> consumer::Consumer { match sequence_number { 1 => info!( @@ -308,7 +303,7 @@ impl NatsSourceFunc { consumer } - async fn get_start_sequence_number(&self, ctx: &mut ArrowContext) -> anyhow::Result { + async fn get_start_sequence_number(&self, ctx: &mut SourceContext) -> anyhow::Result { let state: Vec<_> = ctx .table_manager .get_global_keyed_state::("n") @@ -329,8 +324,12 @@ impl NatsSourceFunc { } } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { - ctx.initialize_deserializer( + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -366,7 +365,7 @@ impl NatsSourceFunc { let payload = msg.payload.as_ref(); let message_info = msg.info().expect("Couldn't get message information"); let timestamp = message_info.published.into() ; - ctx.deserialize_slice(payload, timestamp, None).await?; + collector.deserialize_slice(payload, timestamp, None).await?; debug!("---------------------------------------------->"); debug!( @@ -398,8 +397,8 @@ impl NatsSourceFunc { message_info.delivered ); - if ctx.should_flush() { - 
ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } sequence_numbers.insert( @@ -449,7 +448,7 @@ impl NatsSourceFunc { .ok_or_else(|| UserError::new("No sequence number could be fetched from the state", "") ); - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } @@ -492,9 +491,9 @@ impl NatsSourceFunc { Some(msg) => { let payload = msg.payload.as_ref(); let timestamp = SystemTime::now(); - ctx.deserialize_slice(payload, timestamp, None).await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + collector.deserialize_slice(payload, timestamp, None).await?; + if collector.should_flush() { + collector.flush_buffer().await?; } }, None => { @@ -508,7 +507,7 @@ impl NatsSourceFunc { Some(ControlMessage::Checkpoint(c)) => { // TODO: Is checkpointing necessary for subjects? debug!("Starting checkpointing {}", ctx.task_info.task_index); - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Ok(SourceFinishType::Immediate); } } diff --git a/crates/arroyo-connectors/src/nexmark/mod.rs b/crates/arroyo-connectors/src/nexmark/mod.rs index 4c9ef358d..e7b7a859b 100644 --- a/crates/arroyo-connectors/src/nexmark/mod.rs +++ b/crates/arroyo-connectors/src/nexmark/mod.rs @@ -5,7 +5,7 @@ mod test; use anyhow::{anyhow, bail}; use arrow::datatypes::{Field, Schema, TimeUnit}; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -230,8 +230,8 @@ impl Connector for NexmarkConnector { _: Self::ProfileT, table: Self::TableT, _: OperatorConfig, - ) -> anyhow::Result { - Ok(OperatorNode::from_source(Box::new( + ) -> anyhow::Result { + Ok(ConstructedOperator::from_source(Box::new( 
NexmarkSourceFunc::from_config(&table), ))) } diff --git a/crates/arroyo-connectors/src/nexmark/operator.rs b/crates/arroyo-connectors/src/nexmark/operator.rs index ecc7ab127..a32ebd8d8 100644 --- a/crates/arroyo-connectors/src/nexmark/operator.rs +++ b/crates/arroyo-connectors/src/nexmark/operator.rs @@ -3,7 +3,7 @@ use arrow::array::{ Int64Builder, RecordBatch, StringBuilder, StructBuilder, TimestampNanosecondBuilder, }; use arroyo_formats::should_flush; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::grpc::rpc::{StopMode, TableConfig}; @@ -210,7 +210,7 @@ impl SourceOperator for NexmarkSourceFunc { arroyo_state::global_table_config("s", "nexmark source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut SourceContext) { // load state self.state = Some({ let ss = ctx @@ -219,12 +219,12 @@ impl SourceOperator for NexmarkSourceFunc { .await .expect("should be able to read state"); let saved_states = ss.get_all().len(); - if saved_states != ctx.task_info.parallelism { + if saved_states != ctx.task_info.parallelism as usize { let config = GeneratorConfig::new( NexmarkConfig::new( self.first_event_rate, self.num_events, - ctx.task_info.parallelism, + ctx.task_info.parallelism as usize, ), SystemTime::now(), 1, @@ -233,16 +233,22 @@ impl SourceOperator for NexmarkSourceFunc { ); let splits = config.split(ctx.task_info.parallelism as u64); NexmarkSourceState { - config: splits[ctx.task_info.task_index].clone(), + config: splits[ctx.task_info.task_index as usize].clone(), event_count: 0, } } else { - ss.get(&ctx.task_info.task_index).unwrap().clone() + ss.get(&(ctx.task_info.task_index as usize)) + .unwrap() + .clone() } }); } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + 
collector: &mut SourceCollector, + ) -> SourceFinishType { let state = self.state.as_ref().unwrap().clone(); let mut generator = NexmarkGenerator::from_config(&state.config, state.event_count as u64); @@ -271,19 +277,20 @@ impl SourceOperator for NexmarkSourceFunc { timestamp_builder.append_value(to_nanos(next_event.event_timetamp) as i64); if should_flush(records, flush_time) { - ctx.collect( - RecordBatch::try_new( - ctx.out_schema.as_ref().unwrap().schema.clone(), - vec![ - Arc::new(person_builder.finish()), - Arc::new(auction_builder.finish()), - Arc::new(bid_builder.finish()), - Arc::new(timestamp_builder.finish()), - ], + collector + .collect( + RecordBatch::try_new( + ctx.out_schema.schema.clone(), + vec![ + Arc::new(person_builder.finish()), + Arc::new(auction_builder.finish()), + Arc::new(bid_builder.finish()), + Arc::new(timestamp_builder.finish()), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; records = 0; flush_time = Instant::now(); } @@ -298,7 +305,7 @@ impl SourceOperator for NexmarkSourceFunc { .await .expect("should be able to get nexmark state") .insert( - ctx.task_info.task_index, + ctx.task_info.task_index as usize, NexmarkSourceState { config: state.config.clone(), event_count: generator.events_count_so_far as usize, @@ -306,7 +313,7 @@ impl SourceOperator for NexmarkSourceFunc { ) .await; debug!("starting checkpointing {}", ctx.task_info.task_index); - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return SourceFinishType::Immediate; } } diff --git a/crates/arroyo-connectors/src/polling_http/mod.rs b/crates/arroyo-connectors/src/polling_http/mod.rs index 7db361e04..34c3f5452 100644 --- a/crates/arroyo-connectors/src/polling_http/mod.rs +++ b/crates/arroyo-connectors/src/polling_http/mod.rs @@ -21,7 +21,7 @@ use crate::{construct_http_client, pull_opt, pull_option_to_i64, EmptyConfig}; use crate::polling_http::operator::{PollingHttpSourceFunc, PollingHttpSourceState}; use 
arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; const TABLE_SCHEMA: &str = include_str!("./table.json"); const DEFAULT_POLLING_INTERVAL: Duration = Duration::from_secs(1); @@ -241,7 +241,7 @@ impl Connector for PollingHTTPConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let headers = string_to_map( &table .headers @@ -262,31 +262,33 @@ impl Connector for PollingHTTPConnector { }) .collect(); - Ok(OperatorNode::from_source(Box::new(PollingHttpSourceFunc { - state: PollingHttpSourceState::default(), - client: reqwest::ClientBuilder::new() - .default_headers(headers) - .timeout(Duration::from_secs(5)) - .build() - .expect("could not construct http client"), - endpoint: url::Url::from_str(&table.endpoint).expect("invalid endpoint"), - method: match table.method { - None | Some(Method::Get) => reqwest::Method::GET, - Some(Method::Post) => reqwest::Method::POST, - Some(Method::Put) => reqwest::Method::PUT, - Some(Method::Patch) => reqwest::Method::PATCH, + Ok(ConstructedOperator::from_source(Box::new( + PollingHttpSourceFunc { + state: PollingHttpSourceState::default(), + client: reqwest::ClientBuilder::new() + .default_headers(headers) + .timeout(Duration::from_secs(5)) + .build() + .expect("could not construct http client"), + endpoint: url::Url::from_str(&table.endpoint).expect("invalid endpoint"), + method: match table.method { + None | Some(Method::Get) => reqwest::Method::GET, + Some(Method::Post) => reqwest::Method::POST, + Some(Method::Put) => reqwest::Method::PUT, + Some(Method::Patch) => reqwest::Method::PATCH, + }, + body: table.body.map(|b| b.into()), + polling_interval: table + .poll_interval_ms + .map(|d| Duration::from_millis(d as u64)) + .unwrap_or(DEFAULT_POLLING_INTERVAL), + emit_behavior: table.emit_behavior.unwrap_or(EmitBehavior::All), + format: config + .format + .expect("PollingHTTP 
source must have a format"), + framing: config.framing, + bad_data: config.bad_data, }, - body: table.body.map(|b| b.into()), - polling_interval: table - .poll_interval_ms - .map(|d| Duration::from_millis(d as u64)) - .unwrap_or(DEFAULT_POLLING_INTERVAL), - emit_behavior: table.emit_behavior.unwrap_or(EmitBehavior::All), - format: config - .format - .expect("PollingHTTP source must have a format"), - framing: config.framing, - bad_data: config.bad_data, - }))) + ))) } } diff --git a/crates/arroyo-connectors/src/polling_http/operator.rs b/crates/arroyo-connectors/src/polling_http/operator.rs index ec150cebd..f6217421c 100644 --- a/crates/arroyo-connectors/src/polling_http/operator.rs +++ b/crates/arroyo-connectors/src/polling_http/operator.rs @@ -8,13 +8,13 @@ use std::time::Duration; use std::time::SystemTime; use arroyo_rpc::ControlMessage; -use arroyo_types::{ArrowMessage, SignalMessage, UserError, Watermark}; +use arroyo_types::{SignalMessage, UserError, Watermark}; use tokio::select; use tokio::time::MissedTickBehavior; use crate::polling_http::EmitBehavior; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, Framing}; @@ -52,7 +52,7 @@ impl SourceOperator for PollingHttpSourceFunc { arroyo_state::global_table_config("s", "polling http source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut SourceContext) { let s: &mut GlobalKeyedView<(), PollingHttpSourceState> = ctx .table_manager .get_global_keyed_state("s") @@ -64,8 +64,12 @@ impl SourceOperator for PollingHttpSourceFunc { } } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + match 
self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -79,7 +83,8 @@ impl SourceOperator for PollingHttpSourceFunc { impl PollingHttpSourceFunc { async fn our_handle_control_message( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, msg: Option, ) -> Option { match msg? { @@ -93,7 +98,7 @@ impl PollingHttpSourceFunc { .expect("should be able to get http state"); s.insert((), state).await; - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Some(SourceFinishType::Immediate); } } @@ -194,8 +199,12 @@ impl PollingHttpSourceFunc { } } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { - ctx.initialize_deserializer( + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -215,10 +224,10 @@ impl PollingHttpSourceFunc { continue; } - ctx.deserialize_slice(&buf, SystemTime::now(), None).await?; + collector.deserialize_slice(&buf, SystemTime::now(), None).await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } self.state.last_message = Some(buf); @@ -229,7 +238,7 @@ impl PollingHttpSourceFunc { } } control_message = ctx.control_rx.recv() => { - if let Some(r) = self.our_handle_control_message(ctx, control_message).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, control_message).await { return Ok(r); } } @@ -237,13 +246,12 @@ impl PollingHttpSourceFunc { } } else { // otherwise set idle and just process control messages - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; loop { let msg = 
ctx.control_rx.recv().await; - if let Some(r) = self.our_handle_control_message(ctx, msg).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, msg).await { return Ok(r); } } diff --git a/crates/arroyo-connectors/src/preview/mod.rs b/crates/arroyo-connectors/src/preview/mod.rs index 6b3672af4..3f15eb57b 100644 --- a/crates/arroyo-connectors/src/preview/mod.rs +++ b/crates/arroyo-connectors/src/preview/mod.rs @@ -15,7 +15,7 @@ use crate::EmptyConfig; use crate::preview::operator::PreviewSink; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; pub struct PreviewConnector {} @@ -115,7 +115,9 @@ impl Connector for PreviewConnector { _: Self::ProfileT, _: Self::TableT, _: OperatorConfig, - ) -> anyhow::Result { - Ok(OperatorNode::from_operator(Box::::default())) + ) -> anyhow::Result { + Ok(ConstructedOperator::from_operator( + Box::::default(), + )) } } diff --git a/crates/arroyo-connectors/src/preview/operator.rs b/crates/arroyo-connectors/src/preview/operator.rs index 049d6ddae..6634a59c4 100644 --- a/crates/arroyo-connectors/src/preview/operator.rs +++ b/crates/arroyo-connectors/src/preview/operator.rs @@ -3,7 +3,7 @@ use arrow::json::writer::JsonArray; use arrow::json::{Writer, WriterBuilder}; use std::collections::HashMap; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::config::config; use arroyo_rpc::grpc::rpc::controller_grpc_client::ControllerGrpcClient; @@ -31,7 +31,7 @@ impl ArrowOperator for PreviewSink { ) } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let table = ctx.table_manager.get_global_keyed_state("s").await.unwrap(); self.row = *table.get(&ctx.task_info.task_index).unwrap_or(&0); @@ -43,7 +43,12 @@ impl ArrowOperator for PreviewSink { ); } - async fn 
process_batch(&mut self, mut batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + mut batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let ts = ctx.in_schemas[0].timestamp_index; let timestamps: Vec<_> = batch .column(ts) @@ -73,7 +78,7 @@ impl ArrowOperator for PreviewSink { .send_sink_data(SinkDataReq { job_id: ctx.task_info.job_id.clone(), operator_id: ctx.task_info.operator_id.clone(), - subtask_index: ctx.task_info.task_index as u32, + subtask_index: ctx.task_info.task_index, timestamps, batch: String::from_utf8(buf).unwrap_or_else(|_| String::new()), start_id: self.row as u64, @@ -85,24 +90,34 @@ impl ArrowOperator for PreviewSink { self.row += batch.num_rows(); } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let table = ctx .table_manager - .get_global_keyed_state::("s") + .get_global_keyed_state::("s") .await .unwrap(); table.insert(ctx.task_info.task_index, self.row).await; } - async fn on_close(&mut self, _: &Option, ctx: &mut ArrowContext) { + async fn on_close( + &mut self, + _: &Option, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.client .as_mut() .unwrap() .send_sink_data(SinkDataReq { job_id: ctx.task_info.job_id.clone(), operator_id: ctx.task_info.operator_id.clone(), - subtask_index: ctx.task_info.task_index as u32, + subtask_index: ctx.task_info.task_index, timestamps: vec![], batch: "[]".to_string(), start_id: self.row as u64, diff --git a/crates/arroyo-connectors/src/rabbitmq/mod.rs b/crates/arroyo-connectors/src/rabbitmq/mod.rs index 04373e19e..2331b0547 100644 --- a/crates/arroyo-connectors/src/rabbitmq/mod.rs +++ b/crates/arroyo-connectors/src/rabbitmq/mod.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, bail}; use arroyo_operator::connector::{Connection, Connector}; -use 
arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::{api_types::connections::TestSourceMessage, OperatorConfig}; use rabbitmq_stream_client::types::OffsetSpecification; use rabbitmq_stream_client::{Environment, TlsConfiguration}; @@ -205,9 +205,9 @@ impl Connector for RabbitmqConnector { profile: Self::ProfileT, table: Self::TableT, config: arroyo_rpc::OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { match table.type_ { - TableType::Source { offset } => Ok(OperatorNode::from_source(Box::new( + TableType::Source { offset } => Ok(ConstructedOperator::from_source(Box::new( RabbitmqStreamSourceFunc { config: profile, stream: table.stream, diff --git a/crates/arroyo-connectors/src/rabbitmq/source.rs b/crates/arroyo-connectors/src/rabbitmq/source.rs index c810339ff..05d7a53e5 100644 --- a/crates/arroyo-connectors/src/rabbitmq/source.rs +++ b/crates/arroyo-connectors/src/rabbitmq/source.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use std::time::{Duration, SystemTime}; -use arroyo_operator::{context::ArrowContext, operator::SourceOperator, SourceFinishType}; +use super::{RabbitmqStreamConfig, SourceOffset}; +use arroyo_operator::context::SourceContext; +use arroyo_operator::{context::SourceCollector, operator::SourceOperator, SourceFinishType}; use arroyo_rpc::formats::{BadData, Format, Framing}; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage}; @@ -15,8 +17,6 @@ use tokio::{select, time::MissedTickBehavior}; use tokio_stream::StreamExt; use tracing::{debug, error, info}; -use super::{RabbitmqStreamConfig, SourceOffset}; - pub struct RabbitmqStreamSourceFunc { pub config: RabbitmqStreamConfig, pub stream: String, @@ -41,16 +41,18 @@ impl SourceOperator for RabbitmqStreamSourceFunc { arroyo_state::global_table_config("s", "rabbitmq stream source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { - ctx.initialize_deserializer( + async fn 
run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), ); - } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -62,7 +64,7 @@ impl SourceOperator for RabbitmqStreamSourceFunc { } impl RabbitmqStreamSourceFunc { - async fn get_consumer(&mut self, ctx: &mut ArrowContext) -> anyhow::Result { + async fn get_consumer(&mut self, ctx: &mut SourceContext) -> anyhow::Result { info!( "Creating rabbitmq stream consumer for {}", self.config.host.clone().unwrap_or("localhost".to_string()) @@ -91,7 +93,11 @@ impl RabbitmqStreamSourceFunc { Ok(consumer) } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let mut consumer = self.get_consumer(ctx).await.map_err(|e| { UserError::new( "Could not create RabbitMQ Stream consumer", @@ -113,11 +119,11 @@ impl RabbitmqStreamSourceFunc { if let Some(data) = message.data() { let timestamp = SystemTime::now(); - ctx.deserialize_slice(data, timestamp, None).await?; + collector.deserialize_slice(data, timestamp, None).await?; } - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } offset = delivery.offset(); @@ -132,8 +138,8 @@ impl RabbitmqStreamSourceFunc { }, _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } }, control_message = ctx.control_rx.recv() => { @@ -149,7 +155,7 @@ impl RabbitmqStreamSourceFunc { offset }).await; - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, 
collector).await { return Ok(SourceFinishType::Immediate); } }, diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index cf9d4e045..7a78e88d0 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -3,7 +3,7 @@ mod operator; use anyhow::{anyhow, bail}; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::connector::{Connection, Connector}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::var_str::VarStr; use redis::aio::ConnectionManager; use redis::cluster::ClusterClient; @@ -397,23 +397,25 @@ impl Connector for RedisConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let client = RedisClient::new(&profile)?; let (tx, cmd_rx) = tokio::sync::mpsc::channel(128); let (cmd_tx, rx) = tokio::sync::mpsc::channel(128); - Ok(OperatorNode::from_operator(Box::new(RedisSinkFunc { - serializer: ArrowSerializer::new( - config.format.expect("redis table must have a format"), - ), - table, - client, - cmd_q: Some((cmd_tx, cmd_rx)), - tx, - rx, - key_index: None, - hash_index: None, - }))) + Ok(ConstructedOperator::from_operator(Box::new( + RedisSinkFunc { + serializer: ArrowSerializer::new( + config.format.expect("redis table must have a format"), + ), + table, + client, + cmd_q: Some((cmd_tx, cmd_rx)), + tx, + rx, + key_index: None, + hash_index: None, + }, + ))) } } diff --git a/crates/arroyo-connectors/src/redis/operator/sink.rs b/crates/arroyo-connectors/src/redis/operator/sink.rs index 5fe10989a..93d9e76ba 100644 --- a/crates/arroyo-connectors/src/redis/operator/sink.rs +++ b/crates/arroyo-connectors/src/redis/operator/sink.rs @@ -1,7 +1,7 @@ use crate::redis::{ListOperation, RedisClient, RedisTable, TableType, Target}; use arrow::array::{AsArray, RecordBatch}; use arroyo_formats::ser::ArrowSerializer; -use 
arroyo_operator::context::{ArrowContext, ErrorReporter}; +use arroyo_operator::context::{Collector, ErrorReporter, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_types::CheckpointBarrier; use async_trait::async_trait; @@ -228,7 +228,7 @@ impl ArrowOperator for RedisSinkFunc { "RedisSink".to_string() } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { match &self.table.connector_type { TableType::Target(Target::ListTable { list_key_column: Some(key), @@ -321,7 +321,12 @@ impl ArrowOperator for RedisSinkFunc { panic!("Failed to establish connection to redis after 20 retries"); } - async fn process_batch(&mut self, batch: RecordBatch, _: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { for (i, value) in self.serializer.serialize(&batch).enumerate() { match &self.table.connector_type { TableType::Target(target) => match &target { @@ -360,7 +365,12 @@ impl ArrowOperator for RedisSinkFunc { } } - async fn handle_checkpoint(&mut self, checkpoint: CheckpointBarrier, _ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + checkpoint: CheckpointBarrier, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.tx .send(RedisCmd::Flush(checkpoint.epoch)) .await diff --git a/crates/arroyo-connectors/src/single_file/mod.rs b/crates/arroyo-connectors/src/single_file/mod.rs index c5cb0cf89..5d41f1024 100644 --- a/crates/arroyo-connectors/src/single_file/mod.rs +++ b/crates/arroyo-connectors/src/single_file/mod.rs @@ -15,7 +15,7 @@ use crate::{pull_opt, EmptyConfig}; use crate::single_file::sink::SingleFileSink; use crate::single_file::source::SingleFileSourceFunc; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; const TABLE_SCHEMA: &str = include_str!("./table.json"); @@ -155,27 
+155,31 @@ impl Connector for SingleFileConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> Result { + ) -> Result { match table.table_type { - TableType::Source => Ok(OperatorNode::from_source(Box::new(SingleFileSourceFunc { - input_file: table.path, - lines_read: 0, - format: config - .format - .expect("Format must be set for Single File Source"), - framing: config.framing, - bad_data: config.bad_data, - wait_for_control: table.wait_for_control.unwrap_or(true), - }))), - TableType::Sink => Ok(OperatorNode::from_operator(Box::new(SingleFileSink { - output_path: table.path, - file: None, - serializer: ArrowSerializer::new( - config + TableType::Source => Ok(ConstructedOperator::from_source(Box::new( + SingleFileSourceFunc { + input_file: table.path, + lines_read: 0, + format: config .format - .expect("Format must be set for Single File Sink"), - ), - }))), + .expect("Format must be set for Single File Source"), + framing: config.framing, + bad_data: config.bad_data, + wait_for_control: table.wait_for_control.unwrap_or(true), + }, + ))), + TableType::Sink => Ok(ConstructedOperator::from_operator(Box::new( + SingleFileSink { + output_path: table.path, + file: None, + serializer: ArrowSerializer::new( + config + .format + .expect("Format must be set for Single File Sink"), + ), + }, + ))), } } } diff --git a/crates/arroyo-connectors/src/single_file/sink.rs b/crates/arroyo-connectors/src/single_file/sink.rs index feaff9f46..8b759d1cd 100644 --- a/crates/arroyo-connectors/src/single_file/sink.rs +++ b/crates/arroyo-connectors/src/single_file/sink.rs @@ -8,7 +8,7 @@ use arroyo_types::{CheckpointBarrier, SignalMessage}; use async_trait::async_trait; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use tokio::{ fs::{self, File, OpenOptions}, @@ -31,7 +31,12 @@ impl ArrowOperator for SingleFileSink { arroyo_state::global_table_config("f", 
"file_sink") } - async fn process_batch(&mut self, batch: RecordBatch, _ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let values = self.serializer.serialize(&batch); let file = self.file.as_mut().unwrap(); for value in values { @@ -40,7 +45,7 @@ impl ArrowOperator for SingleFileSink { } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let file_path = Path::new(&self.output_path); let parent = file_path.parent().unwrap(); fs::create_dir_all(&parent).await.unwrap(); @@ -72,13 +77,23 @@ impl ArrowOperator for SingleFileSink { self.file = Some(file); } - async fn on_close(&mut self, final_message: &Option, _ctx: &mut ArrowContext) { + async fn on_close( + &mut self, + final_message: &Option, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { if let Some(SignalMessage::EndOfData) = final_message { self.file.as_mut().unwrap().flush().await.unwrap(); } } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _b: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.file.as_mut().unwrap().flush().await.unwrap(); let state = ctx.table_manager.get_global_keyed_state("f").await.unwrap(); state diff --git a/crates/arroyo-connectors/src/single_file/source.rs b/crates/arroyo-connectors/src/single_file/source.rs index 3ea8063a5..800d21fd9 100644 --- a/crates/arroyo-connectors/src/single_file/source.rs +++ b/crates/arroyo-connectors/src/single_file/source.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, time::SystemTime}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::{ @@ -26,73 +26,22 @@ pub struct SingleFileSourceFunc { } impl 
SingleFileSourceFunc { - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - if ctx.task_info.task_index != 0 { - return SourceFinishType::Final; - } - ctx.initialize_deserializer( - self.format.clone(), - self.framing.clone(), - self.bad_data.clone(), - ); - - let state: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView = - ctx.table_manager.get_global_keyed_state("f").await.unwrap(); - - self.lines_read = state.get(&self.input_file).copied().unwrap_or_default(); - - let file = File::open(&self.input_file).await.expect(&self.input_file); - let mut lines = BufReader::new(file).lines(); - - let mut i = 0; - - while let Some(s) = lines.next_line().await.unwrap() { - if i < self.lines_read { - i += 1; - continue; - } - ctx.deserialize_slice(s.as_bytes(), SystemTime::now(), None) - .await - .unwrap(); - if ctx.should_flush() { - ctx.flush_buffer().await.unwrap(); - } - - self.lines_read += 1; - i += 1; - - // wait for a control message after each line - let return_type = if self.wait_for_control { - self.handle_control(ctx.control_rx.recv().await, ctx).await - } else { - self.handle_control(ctx.control_rx.try_recv().ok(), ctx) - .await - }; - - if let Some(value) = return_type { - return value; - } - } - ctx.flush_buffer().await.unwrap(); - info!("file source finished"); - SourceFinishType::Final - } - async fn handle_control( &mut self, msg: Option, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, ) -> Option { match msg { Some(ControlMessage::Checkpoint(c)) => { - ctx.flush_buffer().await.unwrap(); + collector.flush_buffer().await.unwrap(); let state: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView< String, usize, > = ctx.table_manager.get_global_keyed_state("f").await.unwrap(); state.insert(self.input_file.clone(), self.lines_read).await; // checkpoint our state - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return 
Some(SourceFinishType::Immediate); } } @@ -127,7 +76,7 @@ impl SourceOperator for SingleFileSourceFunc { arroyo_state::global_table_config("f", "file_source") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut SourceContext) { let s: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView = ctx .table_manager .get_global_keyed_state("f") @@ -138,7 +87,63 @@ impl SourceOperator for SingleFileSourceFunc { self.lines_read = *state; } } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - self.run(ctx).await + + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + if ctx.task_info.task_index != 0 { + return SourceFinishType::Final; + } + collector.initialize_deserializer( + self.format.clone(), + self.framing.clone(), + self.bad_data.clone(), + ); + + let state: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView = + ctx.table_manager.get_global_keyed_state("f").await.unwrap(); + + self.lines_read = state.get(&self.input_file).copied().unwrap_or_default(); + + let file = File::open(&self.input_file).await.expect(&self.input_file); + let mut lines = BufReader::new(file).lines(); + + let mut i = 0; + + while let Some(s) = lines.next_line().await.unwrap() { + if i < self.lines_read { + i += 1; + continue; + } + collector + .deserialize_slice(s.as_bytes(), SystemTime::now(), None) + .await + .unwrap(); + if collector.should_flush() { + collector.flush_buffer().await.unwrap(); + } + + self.lines_read += 1; + i += 1; + + // wait for a control message after each line + let return_type = if self.wait_for_control { + self.handle_control(ctx.control_rx.recv().await, ctx, collector) + .await + } else { + self.handle_control(ctx.control_rx.try_recv().ok(), ctx, collector) + .await + }; + + if let Some(value) = return_type { + return value; + } + } + + collector.flush_buffer().await.unwrap(); + info!("file source finished"); + 
SourceFinishType::Final } } diff --git a/crates/arroyo-connectors/src/sse/mod.rs b/crates/arroyo-connectors/src/sse/mod.rs index 0e64b5583..ed828e9f3 100644 --- a/crates/arroyo-connectors/src/sse/mod.rs +++ b/crates/arroyo-connectors/src/sse/mod.rs @@ -12,7 +12,7 @@ use tokio::sync::mpsc::Sender; use typify::import_types; use arroyo_operator::connector::Connection; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; @@ -150,7 +150,7 @@ impl Connector for SSEConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { SSESourceFunc::new_operator(table, config) } } diff --git a/crates/arroyo-connectors/src/sse/operator.rs b/crates/arroyo-connectors/src/sse/operator.rs index b78654d8d..df00ee4f9 100644 --- a/crates/arroyo-connectors/src/sse/operator.rs +++ b/crates/arroyo-connectors/src/sse/operator.rs @@ -1,12 +1,12 @@ use crate::sse::SseTable; -use arroyo_operator::context::ArrowContext; -use arroyo_operator::operator::{OperatorNode, SourceOperator}; +use arroyo_operator::context::{SourceCollector, SourceContext}; +use arroyo_operator::operator::{ConstructedOperator, SourceOperator}; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, Framing}; use arroyo_rpc::grpc::rpc::{StopMode, TableConfig}; -use arroyo_rpc::{ControlMessage, ControlResp, OperatorConfig}; +use arroyo_rpc::{ControlMessage, OperatorConfig}; use arroyo_state::tables::global_keyed_map::GlobalKeyedView; -use arroyo_types::{string_to_map, ArrowMessage, SignalMessage, UserError, Watermark}; +use arroyo_types::{string_to_map, SignalMessage, UserError, Watermark}; use async_trait::async_trait; use bincode::{Decode, Encode}; use eventsource_client::{Client, Error, SSE}; @@ -33,13 +33,16 @@ pub struct SSESourceFunc { } impl SSESourceFunc { - pub fn 
new_operator(table: SseTable, config: OperatorConfig) -> anyhow::Result { + pub fn new_operator( + table: SseTable, + config: OperatorConfig, + ) -> anyhow::Result { let headers = table .headers .as_ref() .map(|s| s.sub_env_vars().expect("Failed to substitute env vars")); - Ok(OperatorNode::from_source(Box::new(SSESourceFunc { + Ok(ConstructedOperator::from_source(Box::new(SSESourceFunc { url: table.endpoint, headers: string_to_map(&headers.unwrap_or("".to_string()), ':') .expect("Invalid header map") @@ -67,7 +70,11 @@ impl SourceOperator for SSESourceFunc { arroyo_state::global_table_config("e", "sse source state") } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { let s: &mut GlobalKeyedView<(), SSESourceState> = ctx .table_manager .get_global_keyed_state("e") @@ -78,7 +85,7 @@ impl SourceOperator for SSESourceFunc { self.state = state.clone(); } - match self.run_int(ctx).await { + match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -92,7 +99,8 @@ impl SourceOperator for SSESourceFunc { impl SSESourceFunc { async fn our_handle_control_message( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, msg: Option, ) -> Option { match msg? 
{ @@ -105,7 +113,7 @@ impl SSESourceFunc { .expect("should be able to get SSE state"); s.insert((), self.state.clone()).await; - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Some(SourceFinishType::Immediate); } } @@ -132,8 +140,12 @@ impl SSESourceFunc { None } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { - ctx.initialize_deserializer( + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), @@ -171,11 +183,11 @@ impl SSESourceFunc { } if events.is_empty() || events.contains(&event.event_type) { - ctx.deserialize_slice( + collector.deserialize_slice( event.data.as_bytes(), SystemTime::now(), None).await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } } @@ -196,13 +208,10 @@ impl SSESourceFunc { last_eof = Instant::now(); } Some(Err(e)) => { - ctx.control_tx.send( - ControlResp::Error { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, - message: "Error while reading from EventSource".to_string(), - details: format!("{:?}", e)} - ).await.unwrap(); + ctx.report_user_error(UserError::new( + "Error while reading from EventSource", + format!("{:?}", e) + )).await; panic!("Error while reading from EventSource: {:?}", e); } None => { @@ -212,27 +221,26 @@ impl SSESourceFunc { } } control_message = ctx.control_rx.recv() => { - if let Some(r) = self.our_handle_control_message(ctx, control_message).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, control_message).await { return Ok(r); } } _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } } } } else { // otherwise set idle and just process control 
messages - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; loop { let msg = ctx.control_rx.recv().await; - if let Some(r) = self.our_handle_control_message(ctx, msg).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, msg).await { return Ok(r); } } diff --git a/crates/arroyo-connectors/src/stdout/mod.rs b/crates/arroyo-connectors/src/stdout/mod.rs index 7fbe84f2d..f0e855333 100644 --- a/crates/arroyo-connectors/src/stdout/mod.rs +++ b/crates/arroyo-connectors/src/stdout/mod.rs @@ -17,7 +17,7 @@ use crate::EmptyConfig; use crate::stdout::operator::StdoutSink; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::formats::{Format, JsonFormat}; pub struct StdoutConnector {} @@ -126,11 +126,11 @@ impl Connector for StdoutConnector { _: Self::ProfileT, _: Self::TableT, c: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let format = c .format .unwrap_or_else(|| Format::Json(JsonFormat::default())); - Ok(OperatorNode::from_operator(Box::new(StdoutSink { + Ok(ConstructedOperator::from_operator(Box::new(StdoutSink { stdout: BufWriter::new(tokio::io::stdout()), serializer: ArrowSerializer::new(format), }))) diff --git a/crates/arroyo-connectors/src/stdout/operator.rs b/crates/arroyo-connectors/src/stdout/operator.rs index 422ea72af..cf1632fc5 100644 --- a/crates/arroyo-connectors/src/stdout/operator.rs +++ b/crates/arroyo-connectors/src/stdout/operator.rs @@ -2,7 +2,7 @@ use arrow::array::RecordBatch; use arroyo_formats::ser::ArrowSerializer; use tokio::io::{AsyncWriteExt, BufWriter, Stdout}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_types::SignalMessage; @@ -17,14 +17,25 @@ impl 
ArrowOperator for StdoutSink { "Stdout".to_string() } - async fn process_batch(&mut self, batch: RecordBatch, _: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { for value in self.serializer.serialize(&batch) { self.stdout.write_all(&value).await.unwrap(); self.stdout.write_u8(b'\n').await.unwrap(); } self.stdout.flush().await.unwrap(); } - async fn on_close(&mut self, _: &Option, _: &mut ArrowContext) { + + async fn on_close( + &mut self, + _: &Option, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.stdout.flush().await.unwrap(); } } diff --git a/crates/arroyo-connectors/src/webhook/mod.rs b/crates/arroyo-connectors/src/webhook/mod.rs index b8224e392..ec6e1ff97 100644 --- a/crates/arroyo-connectors/src/webhook/mod.rs +++ b/crates/arroyo-connectors/src/webhook/mod.rs @@ -24,7 +24,7 @@ use crate::{construct_http_client, pull_opt, EmptyConfig}; use crate::webhook::operator::WebhookSinkFunc; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; const TABLE_SCHEMA: &str = include_str!("./table.json"); @@ -210,25 +210,27 @@ impl Connector for WebhookConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { let url = table.endpoint.sub_env_vars()?; - Ok(OperatorNode::from_operator(Box::new(WebhookSinkFunc { - url: Arc::new(url.clone()), - client: construct_http_client( - &url, - table - .headers - .as_ref() - .map(|s| s.sub_env_vars()) - .transpose()?, - )?, - semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT as usize)), - serializer: ArrowSerializer::new( - config - .format - .expect("No format configured for webhook sink"), - ), - last_reported_error_at: Arc::new(Mutex::new(SystemTime::UNIX_EPOCH)), - }))) + Ok(ConstructedOperator::from_operator(Box::new( + WebhookSinkFunc { + url: Arc::new(url.clone()), + client: 
construct_http_client( + &url, + table + .headers + .as_ref() + .map(|s| s.sub_env_vars()) + .transpose()?, + )?, + semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT as usize)), + serializer: ArrowSerializer::new( + config + .format + .expect("No format configured for webhook sink"), + ), + last_reported_error_at: Arc::new(Mutex::new(SystemTime::UNIX_EPOCH)), + }, + ))) } } diff --git a/crates/arroyo-connectors/src/webhook/operator.rs b/crates/arroyo-connectors/src/webhook/operator.rs index f7a78fa6a..b512a0295 100644 --- a/crates/arroyo-connectors/src/webhook/operator.rs +++ b/crates/arroyo-connectors/src/webhook/operator.rs @@ -13,7 +13,7 @@ use tracing::warn; use crate::webhook::MAX_INFLIGHT; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::ControlResp; @@ -37,7 +37,12 @@ impl ArrowOperator for WebhookSinkFunc { global_table_config("s", "webhook sink state") } - async fn process_batch(&mut self, record: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + record: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { for body in self.serializer.serialize(&record) { let permit = self .semaphore @@ -55,7 +60,8 @@ impl ArrowOperator for WebhookSinkFunc { // these are just used for (potential) error reporting and we don't need to clone them let operator_id = ctx.task_info.operator_id.clone(); - let task_index = ctx.task_info.task_index; + let task_index = ctx.task_info.task_index as usize; + let node_id = ctx.task_info.node_id; tokio::task::spawn(async move { // move the permit into the task @@ -88,6 +94,7 @@ impl ArrowOperator for WebhookSinkFunc { control_tx .send(ControlResp::Error { + node_id, operator_id: operator_id.clone(), task_index, message: format!("webhook failed (retry {})", retries), @@ -113,7 +120,12 @@ impl 
ArrowOperator for WebhookSinkFunc { } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, _ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { // wait to acquire all of the permits (effectively blocking until all inflight requests are done) let _permits = self.semaphore.acquire_many(MAX_INFLIGHT).await.unwrap(); diff --git a/crates/arroyo-connectors/src/websocket/mod.rs b/crates/arroyo-connectors/src/websocket/mod.rs index 9bf8cc967..2271ca517 100644 --- a/crates/arroyo-connectors/src/websocket/mod.rs +++ b/crates/arroyo-connectors/src/websocket/mod.rs @@ -23,7 +23,7 @@ use crate::{header_map, pull_opt, EmptyConfig}; use crate::websocket::operator::{WebsocketSourceFunc, WebsocketSourceState}; use arroyo_operator::connector::Connector; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::operator::ConstructedOperator; mod operator; @@ -313,7 +313,7 @@ impl Connector for WebsocketConnector { _: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result { + ) -> anyhow::Result { // Include subscription_message for backwards compatibility let mut subscription_messages = vec![]; if let Some(message) = table.subscription_message { @@ -331,16 +331,18 @@ impl Connector for WebsocketConnector { .map(|(k, v)| ((&k).into(), (&v).into())) .collect(); - Ok(OperatorNode::from_source(Box::new(WebsocketSourceFunc { - url: table.endpoint, - headers, - subscription_messages, - format: config - .format - .ok_or_else(|| anyhow!("format required for websocket source"))?, - framing: config.framing, - bad_data: config.bad_data, - state: WebsocketSourceState::default(), - }))) + Ok(ConstructedOperator::from_source(Box::new( + WebsocketSourceFunc { + url: table.endpoint, + headers, + subscription_messages, + format: config + .format + .ok_or_else(|| anyhow!("format required for websocket source"))?, + framing: config.framing, + bad_data: 
config.bad_data, + state: WebsocketSourceState::default(), + }, + ))) } } diff --git a/crates/arroyo-connectors/src/websocket/operator.rs b/crates/arroyo-connectors/src/websocket/operator.rs index c9c41b073..c009dd548 100644 --- a/crates/arroyo-connectors/src/websocket/operator.rs +++ b/crates/arroyo-connectors/src/websocket/operator.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::time::SystemTime; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{SourceCollector, SourceContext}; use arroyo_operator::operator::SourceOperator; use arroyo_operator::SourceFinishType; use arroyo_rpc::formats::{BadData, Format, Framing}; @@ -11,7 +11,7 @@ use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::{grpc::rpc::StopMode, ControlMessage}; use arroyo_state::global_table_config; use arroyo_state::tables::global_keyed_map::GlobalKeyedView; -use arroyo_types::{ArrowMessage, SignalMessage, UserError, Watermark}; +use arroyo_types::{SignalMessage, UserError, Watermark}; use bincode::{Decode, Encode}; use futures::{SinkExt, StreamExt}; use tokio::select; @@ -44,7 +44,7 @@ impl SourceOperator for WebsocketSourceFunc { global_table_config("e", "websocket source state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut SourceContext) { let s: &mut GlobalKeyedView<(), WebsocketSourceState> = ctx .table_manager .get_global_keyed_state("e") @@ -54,16 +54,20 @@ impl SourceOperator for WebsocketSourceFunc { if let Some(state) = s.get(&()) { self.state = state.clone(); } + } - ctx.initialize_deserializer( + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { + collector.initialize_deserializer( self.format.clone(), self.framing.clone(), self.bad_data.clone(), ); - } - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType { - match self.run_int(ctx).await { + match self.run_int(ctx, collector).await { Ok(r) 
=> r, Err(e) => { ctx.report_error(e.name.clone(), e.details.clone()).await; @@ -77,7 +81,8 @@ impl SourceOperator for WebsocketSourceFunc { impl WebsocketSourceFunc { async fn our_handle_control_message( &mut self, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, msg: Option, ) -> Option { match msg? { @@ -90,7 +95,7 @@ impl WebsocketSourceFunc { .expect("couldn't get state for websocket"); s.insert((), self.state.clone()).await; - if self.start_checkpoint(c, ctx).await { + if self.start_checkpoint(c, ctx, collector).await { return Some(SourceFinishType::Immediate); } } @@ -120,18 +125,24 @@ impl WebsocketSourceFunc { async fn handle_message( &mut self, msg: &[u8], - ctx: &mut ArrowContext, + collector: &mut SourceCollector, ) -> Result<(), UserError> { - ctx.deserialize_slice(msg, SystemTime::now(), None).await?; + collector + .deserialize_slice(msg, SystemTime::now(), None) + .await?; - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } Ok(()) } - async fn run_int(&mut self, ctx: &mut ArrowContext) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let uri = match Uri::from_str(&self.url.to_string()) { Ok(uri) => uri, Err(e) => { @@ -212,10 +223,10 @@ impl WebsocketSourceFunc { Some(Ok(msg)) => { match msg { tungstenite::Message::Text(t) => { - self.handle_message(t.as_bytes(), ctx).await? + self.handle_message(t.as_bytes(), collector).await? }, tungstenite::Message::Binary(bs) => { - self.handle_message(&bs, ctx).await? + self.handle_message(&bs, collector).await? 
}, tungstenite::Message::Ping(d) => { tx.send(tungstenite::Message::Pong(d)).await @@ -245,12 +256,12 @@ impl WebsocketSourceFunc { } } _ = flush_ticker.tick() => { - if ctx.should_flush() { - ctx.flush_buffer().await?; + if collector.should_flush() { + collector.flush_buffer().await?; } } control_message = ctx.control_rx.recv() => { - if let Some(r) = self.our_handle_control_message(ctx, control_message).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, control_message).await { return Ok(r); } } @@ -258,13 +269,12 @@ impl WebsocketSourceFunc { } } else { // otherwise set idle and just process control messages - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; loop { let msg = ctx.control_rx.recv().await; - if let Some(r) = self.our_handle_control_message(ctx, msg).await { + if let Some(r) = self.our_handle_control_message(ctx, collector, msg).await { return Ok(r); } } diff --git a/crates/arroyo-controller/Cargo.toml b/crates/arroyo-controller/Cargo.toml index 45fe8c1d8..2a0978f1b 100644 --- a/crates/arroyo-controller/Cargo.toml +++ b/crates/arroyo-controller/Cargo.toml @@ -66,6 +66,7 @@ uuid = "1.3.3" async-stream = "0.3.5" base64 = "0.22" rusqlite = { version = "0.31.0", features = ["serde_json", "time"] } +log = "0.4.22" [build-dependencies] cornucopia = { workspace = true } diff --git a/crates/arroyo-controller/src/job_controller/checkpointer.rs b/crates/arroyo-controller/src/job_controller/checkpointer.rs deleted file mode 100644 index 9fba2501b..000000000 --- a/crates/arroyo-controller/src/job_controller/checkpointer.rs +++ /dev/null @@ -1,16 +0,0 @@ -use arroyo_state::checkpoint_state::CheckpointState; -use arroyo_state::committing_state::CommittingState; - -pub enum CheckpointingOrCommittingState { - Checkpointing(CheckpointState), - Committing(CommittingState), -} - -impl CheckpointingOrCommittingState { - 
pub(crate) fn done(&self) -> bool { - match self { - CheckpointingOrCommittingState::Checkpointing(checkpointing) => checkpointing.done(), - CheckpointingOrCommittingState::Committing(committing) => committing.done(), - } - } -} diff --git a/crates/arroyo-controller/src/job_controller/job_metrics.rs b/crates/arroyo-controller/src/job_controller/job_metrics.rs index fc7a1cece..5db369a41 100644 --- a/crates/arroyo-controller/src/job_controller/job_metrics.rs +++ b/crates/arroyo-controller/src/job_controller/job_metrics.rs @@ -29,7 +29,7 @@ pub fn get_metric_name(name: &str) -> Option { #[derive(Copy, Clone, Eq, PartialEq, Hash)] pub struct TaskKey { - pub operator_id: u32, + pub node_id: u32, pub subtask_idx: u32, } @@ -46,7 +46,7 @@ impl JobMetrics { for i in 0..program.graph[op].parallelism { tasks.insert( TaskKey { - operator_id: op.index() as u32, + node_id: op.index() as u32, subtask_idx: i as u32, }, TaskMetrics::new(), @@ -60,23 +60,18 @@ impl JobMetrics { } } - pub async fn update( - &self, - operator_id: u32, - subtask_idx: u32, - values: &HashMap, - ) { + pub async fn update(&self, node_id: u32, subtask_idx: u32, values: &HashMap) { let now = SystemTime::now(); let key = TaskKey { - operator_id, + node_id, subtask_idx, }; let mut tasks = self.tasks.write().await; let Some(task) = tasks.get_mut(&key) else { warn!( "Task not found for operator_id: {}, subtask_idx: {}", - operator_id, subtask_idx + node_id, subtask_idx ); return; }; @@ -106,7 +101,7 @@ impl JobMetrics { HashMap::new(); for (k, v) in self.tasks.read().await.iter() { - let op = metric_groups.entry(k.operator_id).or_default(); + let op = metric_groups.entry(k.node_id).or_default(); for (metric, rate) in &v.rates { op.entry(*metric).or_default().push(SubtaskMetrics { @@ -139,15 +134,14 @@ impl JobMetrics { metric_groups .into_iter() .map(|(op_id, metrics)| { - let operator_id = self + let node_id = self .program .graph .node_weight(NodeIndex::new(op_id as usize)) .unwrap() - .operator_id - 
.clone(); + .node_id; OperatorMetricGroup { - operator_id, + node_id, metric_groups: metrics .into_iter() .map(|(name, subtasks)| MetricGroup { diff --git a/crates/arroyo-controller/src/job_controller/mod.rs b/crates/arroyo-controller/src/job_controller/mod.rs index 110145447..716b2c74e 100644 --- a/crates/arroyo-controller/src/job_controller/mod.rs +++ b/crates/arroyo-controller/src/job_controller/mod.rs @@ -33,15 +33,26 @@ use tokio::{sync::mpsc::Receiver, task::JoinHandle}; use tonic::{transport::Channel, Request}; use tracing::{debug, error, info, warn}; -use self::checkpointer::CheckpointingOrCommittingState; - -mod checkpointer; pub mod job_metrics; const CHECKPOINTS_TO_KEEP: u32 = 4; const CHECKPOINT_ROWS_TO_KEEP: u32 = 100; const COMPACT_EVERY: u32 = 2; +pub enum CheckpointingOrCommittingState { + Checkpointing(CheckpointState), + Committing(CommittingState), +} + +impl CheckpointingOrCommittingState { + pub(crate) fn done(&self) -> bool { + match self { + CheckpointingOrCommittingState::Checkpointing(checkpointing) => checkpointing.done(), + CheckpointingOrCommittingState::Committing(committing) => committing.done(), + } + } +} + #[derive(Debug, PartialEq, Eq)] pub enum WorkerState { Running, @@ -90,8 +101,8 @@ pub struct RunningJobModel { min_epoch: u32, last_checkpoint: Instant, workers: HashMap, - tasks: HashMap<(String, u32), TaskStatus>, - operator_parallelism: HashMap, + tasks: HashMap<(u32, u32), TaskStatus>, + operator_parallelism: HashMap, metrics: JobMetrics, metric_update_task: Option>, last_updated_metrics: Instant, @@ -233,28 +244,28 @@ impl RunningJobModel { RunningMessage::TaskFinished { worker_id: _, time: _, - operator_id, + node_id, subtask_index, } => { - let key = (operator_id, subtask_index); + let key = (node_id, subtask_index); if let Some(status) = self.tasks.get_mut(&key) { status.state = TaskState::Finished; } else { warn!( message = "Received task finished for unknown task", job_id = *self.job_id, - operator_id = key.0, + node_id 
= key.0, subtask_index ); } } RunningMessage::TaskFailed { - operator_id, + node_id, subtask_index, reason, .. } => { - let key = (operator_id, subtask_index); + let key = (node_id, subtask_index); if let Some(status) = self.tasks.get_mut(&key) { status.state = TaskState::Failed(reason); } else { @@ -382,27 +393,30 @@ impl RunningJobModel { let mut worker_clients: Vec> = self.workers.values().map(|w| w.connect.clone()).collect(); - for operator_id in self.operator_parallelism.keys() { - let compacted_tables = ParquetBackend::compact_operator( - // compact the operator's state and notify the workers to load the new files - self.job_id.clone(), - operator_id.clone(), - self.epoch, - ) - .await?; + for node in self.program.graph.node_weights() { + for (op, _) in node.operator_chain.iter() { + let compacted_tables = ParquetBackend::compact_operator( + // compact the operator's state and notify the workers to load the new files + self.job_id.clone(), + &op.operator_id, + self.epoch, + ) + .await?; - if compacted_tables.is_empty() { - continue; - } + if compacted_tables.is_empty() { + continue; + } - // TODO: these should be put on separate tokio tasks. - for worker_client in &mut worker_clients { - worker_client - .load_compacted_data(LoadCompactedDataReq { - operator_id: operator_id.clone(), - compacted_metadata: compacted_tables.clone(), - }) - .await?; + // TODO: these should be put on separate tokio tasks. 
+ for worker_client in &mut worker_clients { + worker_client + .load_compacted_data(LoadCompactedDataReq { + node_id: node.node_id, + operator_id: op.operator_id.clone(), + compacted_metadata: compacted_tables.clone(), + }) + .await?; + } } } @@ -450,6 +464,7 @@ impl RunningJobModel { DbCheckpointState::committing, ) .await?; + let committing_data = committing_state.committing_data(); self.checkpoint_state = Some(CheckpointingOrCommittingState::Committing(committing_state)); @@ -526,7 +541,7 @@ impl RunningJobModel { let source_tasks = self.program.sources(); self.tasks.iter().any(|((operator, _), t)| { - source_tasks.contains(operator.as_str()) && t.state == TaskState::Finished + source_tasks.contains(operator) && t.state == TaskState::Finished }) } @@ -605,7 +620,7 @@ impl JobController { .flat_map(|node| { (0..node.parallelism).map(|idx| { ( - (node.operator_id.clone(), idx as u32), + (node.node_id, idx as u32), TaskStatus { state: TaskState::Running, }, @@ -613,7 +628,7 @@ impl JobController { }) }) .collect(), - operator_parallelism: program.tasks_per_operator(), + operator_parallelism: program.tasks_per_node(), metrics, metric_update_task: None, last_updated_metrics: Instant::now(), @@ -653,6 +668,13 @@ impl JobController { .map(|(id, w)| (*id, w.connect.clone())) .collect(); let program = self.model.program.clone(); + let operator_indices: Arc> = Arc::new( + program + .graph + .node_indices() + .map(|idx| (program.graph[idx].node_id, idx.index() as u32)) + .collect(), + ); self.model.metric_update_task = Some(tokio::spawn(async move { let mut metrics: HashMap<(u32, u32), HashMap> = HashMap::new(); @@ -679,12 +701,12 @@ impl JobController { .into_iter() .filter_map(|f| Some((get_metric_name(&f.name?)?, f.metric))) .flat_map(|(metric, values)| { - let program = program.clone(); + let operator_indices = operator_indices.clone(); values.into_iter().filter_map(move |m| { let subtask_idx = u32::from_str(find_label(&m.label, "subtask_idx")?).ok()?; - let 
operator_idx = - program.operator_index(find_label(&m.label, "operator_id")?)?; + let operator_idx = *operator_indices + .get(&u32::from_str(find_label(&m.label, "node_id")?).ok()?)?; let value = m .counter .map(|c| c.value) @@ -861,8 +883,8 @@ impl JobController { } } - pub fn operator_parallelism(&self, op: &str) -> Option { - self.model.operator_parallelism.get(op).cloned() + pub fn operator_parallelism(&self, node_id: u32) -> Option { + self.model.operator_parallelism.get(&node_id).cloned() } fn start_cleanup(&mut self, new_min: u32) -> JoinHandle> { diff --git a/crates/arroyo-controller/src/lib.rs b/crates/arroyo-controller/src/lib.rs index b2cf40929..7754621c0 100644 --- a/crates/arroyo-controller/src/lib.rs +++ b/crates/arroyo-controller/src/lib.rs @@ -28,6 +28,7 @@ use states::{Created, State, StateMachine}; use std::collections::{HashMap, HashSet}; use std::env; use std::net::SocketAddr; +use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; use time::OffsetDateTime; @@ -83,7 +84,7 @@ pub struct JobConfig { stop_mode: StopMode, checkpoint_interval: Duration, ttl: Option, - parallelism_overrides: HashMap, + parallelism_overrides: HashMap, restart_nonce: i32, restart_mode: RestartMode, } @@ -146,12 +147,12 @@ pub enum RunningMessage { TaskFinished { worker_id: WorkerId, time: SystemTime, - operator_id: String, + node_id: u32, subtask_index: u32, }, TaskFailed { worker_id: WorkerId, - operator_id: String, + node_id: u32, subtask_index: u32, reason: String, }, @@ -176,7 +177,7 @@ pub enum JobMessage { }, TaskStarted { worker_id: WorkerId, - operator_id: String, + node_id: u32, operator_subtask: u64, }, RunningMessage(RunningMessage), @@ -249,7 +250,7 @@ impl ControllerGrpc for ControllerServer { &req.job_id, JobMessage::TaskStarted { worker_id: WorkerId(req.worker_id), - operator_id: req.operator_id, + node_id: req.node_id, operator_subtask: req.operator_subtask, }, ) @@ -304,7 +305,7 @@ impl ControllerGrpc for 
ControllerServer { JobMessage::RunningMessage(RunningMessage::TaskFinished { worker_id: WorkerId(req.worker_id), time: from_micros(req.time), - operator_id: req.operator_id, + node_id: req.node_id, subtask_index: req.operator_subtask as u32, }), ) @@ -323,7 +324,7 @@ impl ControllerGrpc for ControllerServer { &req.job_id, JobMessage::RunningMessage(RunningMessage::TaskFailed { worker_id: WorkerId(req.worker_id), - operator_id: req.operator_id, + node_id: req.node_id, subtask_index: req.operator_subtask as u32, reason: req.error, }), @@ -569,7 +570,9 @@ impl ControllerServer { .as_object() .unwrap() .into_iter() - .map(|(k, v)| (k.clone(), v.as_u64().unwrap() as usize)) + .filter_map(|(k, v)| { + Some((u32::from_str(k).ok()?, v.as_u64()? as usize)) + }) .collect(), restart_nonce: p.config_restart_nonce, restart_mode: p.restart_mode, diff --git a/crates/arroyo-controller/src/states/running.rs b/crates/arroyo-controller/src/states/running.rs index 2de0e1d8a..870ef5bcc 100644 --- a/crates/arroyo-controller/src/states/running.rs +++ b/crates/arroyo-controller/src/states/running.rs @@ -61,8 +61,8 @@ impl State for Running { let job_controller = ctx.job_controller.as_mut().unwrap(); - for (op, p) in &c.parallelism_overrides { - if let Some(actual) = job_controller.operator_parallelism(op){ + for (node_id, p) in &c.parallelism_overrides { + if let Some(actual) = job_controller.operator_parallelism(*node_id){ if actual != *p { return Ok(Transition::next( *self, diff --git a/crates/arroyo-controller/src/states/scheduling.rs b/crates/arroyo-controller/src/states/scheduling.rs index 5b7727a8b..7bd62bdb1 100644 --- a/crates/arroyo-controller/src/states/scheduling.rs +++ b/crates/arroyo-controller/src/states/scheduling.rs @@ -64,8 +64,8 @@ fn compute_assignments( for i in 0..node.parallelism { assignments.push(TaskAssignment { - operator_id: node.operator_id.clone(), - operator_subtask: i as u64, + node_id: node.node_id, + subtask_idx: i as u32, worker_id: 
workers[worker_idx].id.0, worker_addr: workers[worker_idx].data_address.clone(), }); @@ -363,6 +363,7 @@ impl State for Scheduling { metadata.min_epoch = min_epoch; if needs_commits { let mut commit_subtasks = HashSet::new(); + // (operator_id => (table_name => (subtask => data))) let mut committing_data: HashMap>>> = HashMap::new(); for operator_id in &metadata.operator_ids { @@ -423,7 +424,9 @@ impl State for Scheduling { .program .graph .node_weights() - .find(|node| node.operator_id == *operator_id) + .find(|node| { + node.operator_chain.first().operator_id == *operator_id + }) .unwrap(); for subtask_index in 0..program_node.parallelism { commit_subtasks.insert((operator_id.clone(), subtask_index as u32)); @@ -516,15 +519,15 @@ impl State for Scheduling { v = ctx.rx.recv() => { match v { Some(JobMessage::TaskStarted { - operator_id, + node_id, operator_subtask, .. }) => { - started_tasks.insert((operator_id, operator_subtask)); + started_tasks.insert((node_id, operator_subtask)); } - Some(JobMessage::RunningMessage(RunningMessage::TaskFailed {worker_id, operator_id, subtask_index, reason})) => { + Some(JobMessage::RunningMessage(RunningMessage::TaskFailed {worker_id, node_id, subtask_index, reason})) => { return Err(ctx.retryable(self, "task failed on startup", - anyhow!("task failed on job startup on {:?}: {}:{}: {}", worker_id, operator_id, subtask_index, reason), 10)); + anyhow!("task failed on job startup on {:?}: {}:{}: {}", worker_id, node_id, subtask_index, reason), 10)); } Some(JobMessage::ConfigUpdate(c)) => { stop_if_desired_non_running!(self, &c); diff --git a/crates/arroyo-datastream/Cargo.toml b/crates/arroyo-datastream/Cargo.toml index 1ec49eab8..501e3db57 100644 --- a/crates/arroyo-datastream/Cargo.toml +++ b/crates/arroyo-datastream/Cargo.toml @@ -30,3 +30,4 @@ regex = "1.9.5" serde_json = "1.0.108" strum = { version = "0.25.0", features = ["derive"] } datafusion-proto = { workspace = true } +itertools = "0.13.0" diff --git 
a/crates/arroyo-datastream/src/lib.rs b/crates/arroyo-datastream/src/lib.rs index 238270571..0d2dec2da 100644 --- a/crates/arroyo-datastream/src/lib.rs +++ b/crates/arroyo-datastream/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::comparison_chain)] pub mod logical; +pub mod optimizers; use arroyo_rpc::config::{config, DefaultSink}; use arroyo_rpc::grpc::api; diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index 6580748c6..f4689c995 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -1,5 +1,7 @@ use datafusion_proto::protobuf::ArrowType; +use itertools::Itertools; +use crate::optimizers::Optimizer; use anyhow::anyhow; use arrow_schema::DataType; use arroyo_rpc::api_types::pipelines::{PipelineEdge, PipelineGraph, PipelineNode}; @@ -18,6 +20,7 @@ use std::collections::hash_map::DefaultHasher; use std::collections::{HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hasher; +use std::str::FromStr; use std::sync::Arc; use strum::{Display, EnumString}; @@ -87,14 +90,17 @@ impl TryFrom for PipelineGraph { .graph .node_weights() .map(|node| Ok(PipelineNode { - node_id: node.operator_id.to_string(), - operator: match node.operator_name { - OperatorName::ConnectorSource | OperatorName::ConnectorSink => { - ConnectorOp::decode(&node.operator_config[..]) - .map_err(|_| anyhow!("invalid graph: could not decode connector configuration for {}", node.operator_id))? + node_id: node.node_id, + operator: match node.operator_chain.operators.first() { + Some(ChainedLogicalOperator { operator_name: OperatorName::ConnectorSource | OperatorName::ConnectorSink, operator_config, .. }) => { + ConnectorOp::decode(&operator_config[..]) + .map_err(|_| anyhow!("invalid graph: could not decode connector configuration for {}", node.node_id))? 
.connector } - op => op.to_string(), + Some(op) if node.operator_chain.operators.len() == 1 => { + op.operator_id.to_string() + } + _ => "chained_op".to_string(), }, description: node.description.clone(), parallelism: node.parallelism as u32, @@ -108,8 +114,8 @@ impl TryFrom for PipelineGraph { let src = value.graph.node_weight(edge.source()).unwrap(); let target = value.graph.node_weight(edge.target()).unwrap(); PipelineEdge { - src_id: src.operator_id.to_string(), - dest_id: target.operator_id.to_string(), + src_id: src.node_id, + dest_id: target.node_id, key_type: "()".to_string(), value_type: "()".to_string(), edge_type: format!("{:?}", edge.weight().edge_type), @@ -128,40 +134,111 @@ impl TryFrom for PipelineGraph { pub struct LogicalEdge { pub edge_type: LogicalEdgeType, pub schema: ArroyoSchema, - pub projection: Option>, } impl LogicalEdge { - pub fn new( - edge_type: LogicalEdgeType, - schema: ArroyoSchema, - projection: Option>, - ) -> Self { - LogicalEdge { - edge_type, - schema, - projection, - } + pub fn new(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { + LogicalEdge { edge_type, schema } } pub fn project_all(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { - edge_type, - schema, - projection: None, + LogicalEdge { edge_type, schema } + } +} + +#[derive(Clone, Debug)] +pub struct ChainedLogicalOperator { + pub operator_id: String, + pub operator_name: OperatorName, + pub operator_config: Vec, +} + +#[derive(Clone, Debug)] +pub struct OperatorChain { + pub(crate) operators: Vec, + pub(crate) edges: Vec, +} + +impl OperatorChain { + pub fn new(operator: ChainedLogicalOperator) -> Self { + Self { + operators: vec![operator], + edges: vec![], } } + + pub fn iter(&self) -> impl Iterator)> { + self.operators + .iter() + .zip_longest(self.edges.iter()) + .map(|e| e.left_and_right()) + .map(|(l, r)| (l.unwrap(), r)) + } + + pub fn iter_mut( + &mut self, + ) -> impl Iterator)> { + self.operators + .iter_mut() + 
.zip_longest(self.edges.iter_mut()) + .map(|e| e.left_and_right()) + .map(|(l, r)| (l.unwrap(), r)) + } + + pub fn first(&self) -> &ChainedLogicalOperator { + &self.operators[0] + } + + pub fn len(&self) -> usize { + self.operators.len() + } + + pub fn is_empty(&self) -> bool { + self.operators.is_empty() + } + + pub fn is_source(&self) -> bool { + self.operators[0].operator_name == OperatorName::ConnectorSource + } + + pub fn is_sink(&self) -> bool { + self.operators[0].operator_name == OperatorName::ConnectorSink + } } #[derive(Clone)] pub struct LogicalNode { - pub operator_id: String, + pub node_id: u32, pub description: String, - pub operator_name: OperatorName, - pub operator_config: Vec, + pub operator_chain: OperatorChain, pub parallelism: usize, } +impl LogicalNode { + pub fn single( + id: u32, + operator_id: String, + name: OperatorName, + config: Vec, + description: String, + parallelism: usize, + ) -> Self { + Self { + node_id: id, + description, + operator_chain: OperatorChain { + operators: vec![ChainedLogicalOperator { + operator_id, + operator_name: name, + operator_config: config, + }], + edges: vec![], + }, + parallelism, + } + } +} + impl Display for LogicalNode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.description) @@ -170,7 +247,17 @@ impl Display for LogicalNode { impl Debug for LogicalNode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.operator_id) + write!( + f, + "{}[{}]", + self.operator_chain + .operators + .iter() + .map(|op| op.operator_id.clone()) + .collect::>() + .join(" -> "), + self.parallelism + ) } } @@ -203,26 +290,23 @@ pub struct ProgramConfig { pub struct LogicalProgram { pub graph: LogicalGraph, pub program_config: ProgramConfig, - pub operator_indices: HashMap, } impl LogicalProgram { pub fn new(graph: LogicalGraph, program_config: ProgramConfig) -> Self { - let operator_indices = graph - .node_indices() - .map(|idx| 
(graph[idx].operator_id.clone(), idx.index() as u32)) - .collect(); - Self { graph, program_config, - operator_indices, } } - pub fn update_parallelism(&mut self, overrides: &HashMap) { + pub fn optimize(&mut self, optimizer: &dyn Optimizer) { + optimizer.optimize(&mut self.graph); + } + + pub fn update_parallelism(&mut self, overrides: &HashMap) { for node in self.graph.node_weights_mut() { - if let Some(p) = overrides.get(&node.operator_id) { + if let Some(p) = overrides.get(&node.node_id) { node.parallelism = *p; } } @@ -237,11 +321,11 @@ impl LogicalProgram { self.graph.node_weights().map(|nw| nw.parallelism).sum() } - pub fn sources(&self) -> HashSet<&str> { + pub fn sources(&self) -> HashSet { // TODO: this can be memoized self.graph .externals(Direction::Incoming) - .map(|t| self.graph.node_weight(t).unwrap().operator_id.as_str()) + .map(|t| self.graph.node_weight(t).unwrap().node_id) .collect() } @@ -261,50 +345,62 @@ impl LogicalProgram { .collect() } - pub fn operator_index(&self, name: &str) -> Option { - self.operator_indices.get(name).cloned() - } - pub fn tasks_per_operator(&self) -> HashMap { let mut tasks_per_operator = HashMap::new(); for node in self.graph.node_weights() { - tasks_per_operator.insert(node.operator_id.clone(), node.parallelism); + for op in &node.operator_chain.operators { + tasks_per_operator.insert(op.operator_id.clone(), node.parallelism); + } } tasks_per_operator } + pub fn tasks_per_node(&self) -> HashMap { + let mut tasks_per_node = HashMap::new(); + for node in self.graph.node_weights() { + tasks_per_node.insert(node.node_id, node.parallelism); + } + tasks_per_node + } + pub fn features(&self) -> HashSet { let mut s = HashSet::new(); - for t in self.graph.node_weights() { - let feature = match &t.operator_name { - OperatorName::AsyncUdf => "async-udf".to_string(), - OperatorName::ExpressionWatermark - | OperatorName::ArrowValue - | OperatorName::ArrowKey => continue, - OperatorName::Join => 
"join-with-expiration".to_string(), - OperatorName::InstantJoin => "windowed-join".to_string(), - OperatorName::WindowFunction => "sql-window-function".to_string(), - OperatorName::TumblingWindowAggregate => { - "sql-tumbling-window-aggregate".to_string() - } - OperatorName::SlidingWindowAggregate => "sql-sliding-window-aggregate".to_string(), - OperatorName::SessionWindowAggregate => "sql-session-window-aggregate".to_string(), - OperatorName::UpdatingAggregate => "sql-updating-aggregate".to_string(), - OperatorName::ConnectorSource => { - let Ok(connector_op) = ConnectorOp::decode(&t.operator_config[..]) else { - continue; - }; - format!("{}-source", connector_op.connector) - } - OperatorName::ConnectorSink => { - let Ok(connector_op) = ConnectorOp::decode(&t.operator_config[..]) else { - continue; - }; - format!("{}-sink", connector_op.connector) - } - }; - s.insert(feature); + for n in self.graph.node_weights() { + for t in &n.operator_chain.operators { + let feature = match &t.operator_name { + OperatorName::AsyncUdf => "async-udf".to_string(), + OperatorName::ExpressionWatermark + | OperatorName::ArrowValue + | OperatorName::ArrowKey => continue, + OperatorName::Join => "join-with-expiration".to_string(), + OperatorName::InstantJoin => "windowed-join".to_string(), + OperatorName::WindowFunction => "sql-window-function".to_string(), + OperatorName::TumblingWindowAggregate => { + "sql-tumbling-window-aggregate".to_string() + } + OperatorName::SlidingWindowAggregate => { + "sql-sliding-window-aggregate".to_string() + } + OperatorName::SessionWindowAggregate => { + "sql-session-window-aggregate".to_string() + } + OperatorName::UpdatingAggregate => "sql-updating-aggregate".to_string(), + OperatorName::ConnectorSource => { + let Ok(connector_op) = ConnectorOp::decode(&t.operator_config[..]) else { + continue; + }; + format!("{}-source", connector_op.connector) + } + OperatorName::ConnectorSink => { + let Ok(connector_op) = ConnectorOp::decode(&t.operator_config[..]) 
else { + continue; + }; + format!("{}-sink", connector_op.connector) + } + }; + s.insert(feature); + } } s @@ -323,10 +419,26 @@ impl TryFrom for LogicalProgram { id_map.insert( node.node_index, graph.add_node(LogicalNode { - operator_id: node.node_id, + node_id: node.node_id, description: node.description, - operator_name: OperatorName::try_from(node.operator_name.as_str())?, - operator_config: node.operator_config, + operator_chain: OperatorChain { + operators: node + .operators + .into_iter() + .map(|op| { + Ok(ChainedLogicalOperator { + operator_id: op.operator_id, + operator_name: OperatorName::from_str(&op.operator_name)?, + operator_config: op.operator_config, + }) + }) + .collect::>>()?, + edges: node + .edges + .into_iter() + .map(|e| Ok(e.try_into()?)) + .collect::>>()?, + }, parallelism: node.parallelism as usize, }), ); @@ -343,11 +455,6 @@ impl TryFrom for LogicalProgram { LogicalEdge { edge_type: edge.edge_type().into(), schema: schema.clone().try_into()?, - projection: if edge.projection.is_empty() { - None - } else { - Some(edge.projection.iter().map(|p| *p as usize).collect()) - }, }, ); } @@ -497,11 +604,25 @@ impl From for ArrowProgram { let node = graph.node_weight(idx).unwrap(); api::ArrowNode { node_index: idx.index() as i32, - node_id: node.operator_id.clone(), + node_id: node.node_id, parallelism: node.parallelism as u32, description: node.description.clone(), - operator_name: node.operator_name.to_string(), - operator_config: node.operator_config.clone(), + operators: node + .operator_chain + .operators + .iter() + .map(|op| api::ChainedOperator { + operator_id: op.operator_id.clone(), + operator_name: op.operator_name.to_string(), + operator_config: op.operator_config.clone(), + }) + .collect(), + edges: node + .operator_chain + .edges + .iter() + .map(|edge| edge.clone().into()) + .collect(), } }) .collect(); @@ -518,11 +639,6 @@ impl From for ArrowProgram { target: target.index() as i32, schema: Some(edge.schema.clone().into()), 
edge_type: edge_type as i32, - projection: edge - .projection - .as_ref() - .map(|p| p.iter().map(|v| *v as u32).collect()) - .unwrap_or_default(), } }) .collect(); diff --git a/crates/arroyo-datastream/src/optimizers.rs b/crates/arroyo-datastream/src/optimizers.rs new file mode 100644 index 000000000..8c46754d8 --- /dev/null +++ b/crates/arroyo-datastream/src/optimizers.rs @@ -0,0 +1,105 @@ +use crate::logical::{LogicalEdgeType, LogicalGraph}; +use petgraph::prelude::*; +use petgraph::visit::NodeRef; +use std::mem; + +pub trait Optimizer { + fn optimize_once(&self, plan: &mut LogicalGraph) -> bool; + + fn optimize(&self, plan: &mut LogicalGraph) { + loop { + if !self.optimize_once(plan) { + break; + } + } + } +} + +pub struct ChainingOptimizer {} + +fn remove_in_place(graph: &mut DiGraph, node: NodeIndex) { + let incoming = graph.edges_directed(node, Incoming).next().unwrap(); + + let parent = incoming.source().id(); + let incoming = incoming.id(); + graph.remove_edge(incoming); + + let outgoing: Vec<_> = graph + .edges_directed(node, Outgoing) + .map(|e| (e.id(), e.target().id())) + .collect(); + + for (edge, target) in outgoing { + let weight = graph.remove_edge(edge).unwrap(); + graph.add_edge(parent, target, weight); + } + + graph.remove_node(node); +} + +impl Optimizer for ChainingOptimizer { + fn optimize_once(&self, plan: &mut LogicalGraph) -> bool { + let node_indices: Vec = plan.node_indices().collect(); + + for &node_idx in &node_indices { + let cur = plan.node_weight(node_idx).unwrap(); + + // sources can't be chained + if cur.operator_chain.is_source() { + continue; + } + + let mut successors = plan.edges_directed(node_idx, Outgoing).collect::>(); + + if successors.len() != 1 { + continue; + } + + let edge = successors.remove(0); + let edge_type = edge.weight().edge_type; + + if edge_type != LogicalEdgeType::Forward { + continue; + } + + let successor_idx = edge.target(); + + let successor_node = plan.node_weight(successor_idx).unwrap(); + + // skip if 
parallelism doesn't match or successor is a sink + if cur.parallelism != successor_node.parallelism + || successor_node.operator_chain.is_sink() + { + continue; + } + + // skip successors with multiple predecessors + if plan.edges_directed(successor_idx, Incoming).count() > 1 { + continue; + } + + // construct the new node + let mut new_cur = cur.clone(); + + new_cur.description = format!("{} -> {}", cur.description, successor_node.description); + + new_cur + .operator_chain + .operators + .extend(successor_node.operator_chain.operators.clone()); + + new_cur + .operator_chain + .edges + .push(edge.weight().schema.clone()); + + mem::swap(&mut new_cur, plan.node_weight_mut(node_idx).unwrap()); + + // remove the old successor + remove_in_place(plan, successor_idx); + return true; + } + + false + } +} diff --git a/crates/arroyo-metrics/src/lib.rs b/crates/arroyo-metrics/src/lib.rs index ea4388633..54b0d8f43 100644 --- a/crates/arroyo-metrics/src/lib.rs +++ b/crates/arroyo-metrics/src/lib.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use arroyo_types::{ - TaskInfo, BATCHES_RECV, BATCHES_SENT, BYTES_RECV, BYTES_SENT, DESERIALIZATION_ERRORS, + ChainInfo, BATCHES_RECV, BATCHES_SENT, BYTES_RECV, BYTES_SENT, DESERIALIZATION_ERRORS, MESSAGES_RECV, MESSAGES_SENT, }; use lazy_static::lazy_static; @@ -12,13 +12,13 @@ use prometheus::{ }; pub fn gauge_for_task( - task_info: &TaskInfo, + chain_info: &ChainInfo, name: &'static str, help: &'static str, mut labels: HashMap, ) -> Option { let mut opts = Opts::new(name, help); - labels.extend(task_info.metric_label_map()); + labels.extend(chain_info.metric_label_map()); opts.const_labels = labels; @@ -26,13 +26,13 @@ pub fn gauge_for_task( } pub fn histogram_for_task( - task_info: &TaskInfo, + chain_info: &ChainInfo, name: &'static str, help: &'static str, mut labels: HashMap, buckets: Vec, ) -> Option { - labels.extend(task_info.metric_label_map()); + 
labels.extend(chain_info.metric_label_map()); let opts = HistogramOpts::new(name, help) .const_labels(labels) .buckets(buckets); @@ -42,7 +42,7 @@ pub fn histogram_for_task( lazy_static! { pub static ref TASK_METRIC_LABELS: Vec<&'static str> = - vec!["operator_id", "subtask_idx", "operator_name"]; + vec!["node_id", "subtask_idx", "operator_name"]; pub static ref MESSAGE_RECV_COUNTER: IntCounterVec = register_int_counter_vec!( MESSAGES_RECV, "Count of messages received by this subtask", @@ -128,25 +128,25 @@ impl TaskCounters { } } - pub fn for_task(&self, task_info: &Arc, f: F) + pub fn for_task(&self, chain_info: &Arc, f: F) where F: Fn(&IntCounter), { - static CACHE: OnceLock), IntCounter>>>> = + static CACHE: OnceLock), IntCounter>>>> = OnceLock::new(); let cache = CACHE.get_or_init(|| Arc::new(RwLock::new(HashMap::new()))); { - if let Some(counter) = cache.read().unwrap().get(&(*self, task_info.clone())) { + if let Some(counter) = cache.read().unwrap().get(&(*self, chain_info.clone())) { f(counter); return; } } let counter = self.metric().with_label_values(&[ - &task_info.operator_id, - &task_info.task_index.to_string(), - &task_info.operator_name, + &chain_info.node_id.to_string(), + &chain_info.task_index.to_string(), + &chain_info.description.to_string(), ]); f(&counter); @@ -154,7 +154,7 @@ impl TaskCounters { cache .write() .unwrap() - .insert((*self, task_info.clone()), counter); + .insert((*self, chain_info.clone()), counter); } } @@ -163,7 +163,7 @@ pub type QueueGauges = Vec>>; pub fn register_queue_gauge( name: &'static str, help: &'static str, - task_info: &TaskInfo, + chain_info: &ChainInfo, out_qs: &[Vec], initial: i64, ) -> QueueGauges { @@ -175,7 +175,7 @@ pub fn register_queue_gauge( .enumerate() .map(|(j, _)| { let mut g = gauge_for_task( - task_info, + chain_info, name, help, labels! 
{ diff --git a/crates/arroyo-operator/src/connector.rs b/crates/arroyo-operator/src/connector.rs index 32745309b..d079879f5 100644 --- a/crates/arroyo-operator/src/connector.rs +++ b/crates/arroyo-operator/src/connector.rs @@ -1,4 +1,4 @@ -use crate::operator::OperatorNode; +use crate::operator::ConstructedOperator; use anyhow::{anyhow, bail}; use arrow::datatypes::{DataType, Field}; use arroyo_rpc::api_types::connections::{ @@ -117,7 +117,7 @@ pub trait Connector: Send { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - ) -> anyhow::Result; + ) -> anyhow::Result; } #[allow(clippy::type_complexity)] #[allow(clippy::wrong_self_convention)] @@ -186,7 +186,7 @@ pub trait ErasedConnector: Send { schema: Option<&ConnectionSchema>, ) -> anyhow::Result; - fn make_operator(&self, config: OperatorConfig) -> anyhow::Result; + fn make_operator(&self, config: OperatorConfig) -> anyhow::Result; } impl ErasedConnector for C { @@ -320,7 +320,7 @@ impl ErasedConnector for C { ) } - fn make_operator(&self, config: OperatorConfig) -> anyhow::Result { + fn make_operator(&self, config: OperatorConfig) -> anyhow::Result { self.make_operator( self.parse_config(&config.connection).map_err(|e| { anyhow!( diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index 33268923f..32fb3610f 100644 --- a/crates/arroyo-operator/src/context.rs +++ b/crates/arroyo-operator/src/context.rs @@ -12,10 +12,11 @@ use arroyo_rpc::grpc::rpc::{CheckpointMetadata, TableConfig, TaskCheckpointEvent use arroyo_rpc::schema_resolver::SchemaResolver; use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp}; use arroyo_state::tables::table_manager::TableManager; -use arroyo_state::{BackingStore, StateBackend}; use arroyo_types::{ - from_micros, ArrowMessage, CheckpointBarrier, SourceError, TaskInfo, UserError, Watermark, + ArrowMessage, ChainInfo, CheckpointBarrier, SignalMessage, SourceError, TaskInfo, UserError, + Watermark, }; +use 
async_trait::async_trait; use datafusion::common::hash_utils; use rand::Rng; use std::collections::HashMap; @@ -26,7 +27,7 @@ use std::time::{Instant, SystemTime}; use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::Notify; -use tracing::warn; +use tracing::{trace, warn}; pub type QueueItem = ArrowMessage; @@ -233,29 +234,276 @@ impl ContextBuffer { should_flush(self.size(), self.created) } - pub fn finish(self) -> RecordBatch { + pub fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema, - self.buffer.into_iter().map(|mut a| a.finish()).collect(), + self.schema.clone(), + self.buffer.iter_mut().map(|a| a.finish()).collect(), ) .unwrap() } } -pub struct ArrowContext { - pub task_info: Arc, +pub struct SourceContext { + pub out_schema: ArroyoSchema, + pub error_reporter: ErrorReporter, + pub control_tx: Sender, pub control_rx: Receiver, + pub chain_info: Arc, + pub task_info: Arc, + pub table_manager: TableManager, + pub watermarks: WatermarkHolder, +} + +impl SourceContext { + pub fn from_operator( + ctx: OperatorContext, + chain_info: Arc, + control_rx: Receiver, + ) -> Self { + Self { + out_schema: ctx.out_schema.expect("sources must have downstream nodes"), + error_reporter: ErrorReporter { + tx: ctx.control_tx.clone(), + task_info: ctx.task_info.clone(), + }, + control_tx: ctx.control_tx, + control_rx, + chain_info, + task_info: ctx.task_info, + table_manager: ctx.table_manager, + watermarks: ctx.watermarks, + } + } + + pub async fn load_compacted(&mut self, compaction: CompactionResult) { + //TODO: support compaction in the table manager + self.table_manager + .load_compacted(&compaction) + .await + .expect("should be able to load compacted"); + } + + pub async fn report_error(&mut self, message: impl Into, details: impl Into) { + self.error_reporter.report_error(message, details).await; + } + + pub async fn report_user_error(&mut self, 
error: UserError) { + self.control_tx + .send(ControlResp::Error { + node_id: self.task_info.node_id, + task_index: self.task_info.task_index as usize, + operator_id: self.task_info.operator_id.clone(), + message: error.name, + details: error.details, + }) + .await + .unwrap(); + } +} + +pub struct SourceCollector { + deserializer: Option, + buffer: ContextBuffer, + buffered_error: Option, + error_rate_limiter: RateLimiter, + pub out_schema: ArroyoSchema, + pub(crate) collector: ArrowCollector, + control_tx: Sender, + chain_info: Arc, + task_info: Arc, +} + +impl SourceCollector { + pub fn new( + out_schema: ArroyoSchema, + collector: ArrowCollector, + control_tx: Sender, + chain_info: &Arc, + task_info: &Arc, + ) -> Self { + Self { + buffer: ContextBuffer::new(out_schema.schema.clone()), + out_schema, + collector, + control_tx, + chain_info: chain_info.clone(), + task_info: task_info.clone(), + deserializer: None, + buffered_error: None, + error_rate_limiter: RateLimiter::new(), + } + } + + pub async fn collect(&mut self, record: RecordBatch) { + self.collector.collect(record).await; + } + + pub fn initialize_deserializer_with_resolver( + &mut self, + format: Format, + framing: Option, + bad_data: Option, + schema_resolver: Arc, + ) { + self.deserializer = Some(ArrowDeserializer::with_schema_resolver( + format, + framing, + self.out_schema.clone(), + bad_data.unwrap_or_default(), + schema_resolver, + )); + } + + pub fn initialize_deserializer( + &mut self, + format: Format, + framing: Option, + bad_data: Option, + ) { + if self.deserializer.is_some() { + panic!("Deserialize already initialized"); + } + + self.deserializer = Some(ArrowDeserializer::new( + format, + self.out_schema.clone(), + framing, + bad_data.unwrap_or_default(), + )); + } + + pub fn should_flush(&self) -> bool { + self.buffer.should_flush() + || self + .deserializer + .as_ref() + .map(|d| d.should_flush()) + .unwrap_or(false) + } + + pub async fn deserialize_slice( + &mut self, + msg: &[u8], + 
time: SystemTime, + additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + ) -> Result<(), UserError> { + let deserializer = self + .deserializer + .as_mut() + .expect("deserializer not initialized!"); + + let errors = deserializer + .deserialize_slice(&mut self.buffer.buffer, msg, time, additional_fields) + .await; + self.collect_source_errors(errors).await?; + + Ok(()) + } + + /// Handling errors and rate limiting error reporting. + /// Considers the `bad_data` option to determine whether to drop or fail on bad data. + async fn collect_source_errors(&mut self, errors: Vec) -> Result<(), UserError> { + let bad_data = self + .deserializer + .as_ref() + .expect("deserializer not initialized") + .bad_data(); + for error in errors { + match error { + SourceError::BadData { details } => match bad_data { + BadData::Drop {} => { + self.error_rate_limiter + .rate_limit(|| async { + warn!("Dropping invalid data: {}", details.clone()); + self.control_tx + .send(ControlResp::Error { + node_id: self.task_info.node_id, + operator_id: self.task_info.operator_id.clone(), + task_index: self.task_info.task_index as usize, + message: "Dropping invalid data".to_string(), + details, + }) + .await + .unwrap(); + }) + .await; + TaskCounters::DeserializationErrors.for_task(&self.chain_info, |c| c.inc()) + } + BadData::Fail {} => { + return Err(UserError::new("Deserialization error", details)); + } + }, + SourceError::Other { name, details } => { + return Err(UserError::new(name, details)); + } + } + } + + Ok(()) + } + + pub async fn flush_buffer(&mut self) -> Result<(), UserError> { + if self.buffer.size() > 0 { + let batch = self.buffer.finish(); + self.collector.collect(batch).await; + } + + if let Some(deserializer) = self.deserializer.as_mut() { + if let Some(buffer) = deserializer.flush_buffer() { + match buffer { + Ok(batch) => { + self.collector.collect(batch).await; + } + Err(e) => { + self.collect_source_errors(vec![e]).await?; + } + } + } + } + + if let 
Some(error) = self.buffered_error.take() { + return Err(error); + } + + Ok(()) + } + + pub async fn broadcast(&mut self, message: SignalMessage) { + if let Err(e) = self.flush_buffer().await { + self.buffered_error.replace(e); + } + self.collector.broadcast(message).await; + } +} + +pub async fn send_checkpoint_event( + tx: &Sender, + info: &TaskInfo, + barrier: CheckpointBarrier, + event_type: TaskCheckpointEventType, +) { + // These messages are received by the engine control thread, + // which then sends a TaskCheckpointEventReq to the controller. + tx.send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { + checkpoint_epoch: barrier.epoch, + node_id: info.node_id, + operator_id: info.operator_id.clone(), + subtask_index: info.task_index, + time: SystemTime::now(), + event_type, + })) + .await + .unwrap(); +} + +pub struct OperatorContext { + pub task_info: Arc, pub control_tx: Sender, - pub error_reporter: ErrorReporter, pub watermarks: WatermarkHolder, pub in_schemas: Vec, pub out_schema: Option, - pub collector: ArrowCollector, - buffer: Option, - buffered_error: Option, - error_rate_limiter: RateLimiter, - deserializer: Option, pub table_manager: TableManager, + pub error_reporter: ErrorReporter, } #[derive(Clone)] @@ -268,8 +516,9 @@ impl ErrorReporter { pub async fn report_error(&mut self, message: impl Into, details: impl Into) { self.tx .send(ControlResp::Error { + node_id: self.task_info.node_id, operator_id: self.task_info.operator_id.clone(), - task_index: self.task_info.task_index, + task_index: self.task_info.task_index as usize, message: message.into(), details: details.into(), }) @@ -278,11 +527,16 @@ impl ErrorReporter { } } +#[async_trait] +pub trait Collector: Send { + async fn collect(&mut self, batch: RecordBatch); + async fn broadcast_watermark(&mut self, watermark: Watermark); +} + #[derive(Clone)] pub struct ArrowCollector { - task_info: Arc, + pub chain_info: Arc, out_schema: Option, - projection: Option>, out_qs: Vec>, 
tx_queue_rem_gauges: QueueGauges, tx_queue_size_gauges: QueueGauges, @@ -345,36 +599,26 @@ fn repartition<'a>( } } -impl ArrowCollector { - pub async fn collect(&mut self, record: RecordBatch) { +#[async_trait] +impl Collector for ArrowCollector { + async fn collect(&mut self, record: RecordBatch) { TaskCounters::MessagesSent - .for_task(&self.task_info, |c| c.inc_by(record.num_rows() as u64)); - TaskCounters::BatchesSent.for_task(&self.task_info, |c| c.inc()); - TaskCounters::BytesSent.for_task(&self.task_info, |c| { + .for_task(&self.chain_info, |c| c.inc_by(record.num_rows() as u64)); + TaskCounters::BatchesSent.for_task(&self.chain_info, |c| c.inc()); + TaskCounters::BytesSent.for_task(&self.chain_info, |c| { c.inc_by(record.get_array_memory_size() as u64) }); let out_schema = self .out_schema .as_ref() - .unwrap_or_else(|| panic!("No out-schema in {}!", self.task_info.operator_name)); - - let record = if let Some(projection) = &self.projection { - record.project(projection).unwrap_or_else(|e| { - panic!( - "failed to project for operator {}: {}", - self.task_info.operator_id, e - ) - }) - } else { - record - }; + .unwrap_or_else(|| panic!("No out-schema in {}!", self.chain_info)); let record = RecordBatch::try_new(out_schema.schema.clone(), record.columns().to_vec()) .unwrap_or_else(|e| { panic!( "Data does not match expected schema for {}: {:?}. 
expected schema:\n{:#?}\n, actual schema:\n{:#?}", - self.task_info.operator_id, e, out_schema.schema, record.schema() + self.chain_info, e, out_schema.schema, record.schema() ); }); @@ -402,64 +646,21 @@ impl ArrowCollector { } } - pub async fn broadcast(&mut self, message: ArrowMessage) { - for out_node in &self.out_qs { - for q in out_node { - q.send(message.clone()).await.unwrap_or_else(|e| { - panic!( - "failed to broadcast message <{:?}> for operator {}: {}", - message, self.task_info.operator_id, e - ) - }); - } - } + async fn broadcast_watermark(&mut self, watermark: Watermark) { + self.broadcast(SignalMessage::Watermark(watermark)).await; } } -impl ArrowContext { - #[allow(clippy::too_many_arguments)] - pub async fn new( - task_info: TaskInfo, - restore_from: Option, - control_rx: Receiver, - control_tx: Sender, - input_partitions: usize, - in_schemas: Vec, +impl ArrowCollector { + pub fn new( + chain_info: Arc, out_schema: Option, - projection: Option>, out_qs: Vec>, - tables: HashMap, ) -> Self { - let (watermark, metadata) = if let Some(metadata) = restore_from { - let (watermark, operator_metadata) = { - let metadata = StateBackend::load_operator_metadata( - &task_info.job_id, - &task_info.operator_id, - metadata.epoch, - ) - .await - .expect("lookup should succeed") - .expect("require metadata"); - ( - metadata - .operator_metadata - .as_ref() - .unwrap() - .min_watermark - .map(from_micros), - metadata, - ) - }; - - (watermark, Some(operator_metadata)) - } else { - (None, None) - }; - let tx_queue_size_gauges = register_queue_gauge( "arroyo_worker_tx_queue_size", "Size of a tx queue", - &task_info, + &chain_info, &out_qs, config().worker.queue_size as i64, ); @@ -467,7 +668,7 @@ impl ArrowContext { let tx_queue_rem_gauges = register_queue_gauge( "arroyo_worker_tx_queue_rem", "Remaining space in a tx queue", - &task_info, + &chain_info, &out_qs, config().worker.queue_size as i64, ); @@ -475,27 +676,61 @@ impl ArrowContext { let tx_queue_bytes_gauges = 
register_queue_gauge( "arroyo_worker_tx_bytes", "Number of bytes queued in a tx queue", - &task_info, + &chain_info, &out_qs, 0, ); - let task_info = Arc::new(task_info); - // initialize counters so that tasks that never produce data still report 0 for m in TaskCounters::variants() { - // just initialize it - m.for_task(&task_info, |_| {}); + m.for_task(&chain_info, |_| {}); + } + + Self { + chain_info, + out_schema, + out_qs, + tx_queue_rem_gauges, + tx_queue_size_gauges, + tx_queue_bytes_gauges, + } + } + + pub async fn broadcast(&mut self, message: SignalMessage) { + trace!("[{}] Broadcast {:?}", self.chain_info, message); + for out_node in &self.out_qs { + for q in out_node { + q.send(ArrowMessage::Signal(message.clone())) + .await + .unwrap_or_else(|e| { + panic!( + "failed to broadcast message <{:?}> for operator {}: {}", + message, self.chain_info, e + ) + }); + } } + } +} - let table_manager = - TableManager::new(task_info.clone(), tables, control_tx.clone(), metadata) +impl OperatorContext { + #[allow(clippy::too_many_arguments)] + pub async fn new( + task_info: Arc, + restore_from: Option<&CheckpointMetadata>, + control_tx: Sender, + input_partitions: usize, + in_schemas: Vec, + out_schema: Option, + tables: HashMap, + ) -> Self { + let (table_manager, watermark) = + TableManager::load(task_info.clone(), tables, control_tx.clone(), restore_from) .await .expect("should be able to create TableManager"); Self { task_info: task_info.clone(), - control_rx, control_tx: control_tx.clone(), watermarks: WatermarkHolder::new(vec![ watermark.map(Watermark::EventTime); @@ -503,24 +738,11 @@ impl ArrowContext { ]), in_schemas, out_schema: out_schema.clone(), - collector: ArrowCollector { - task_info: task_info.clone(), - out_qs, - tx_queue_rem_gauges, - tx_queue_size_gauges, - tx_queue_bytes_gauges, - out_schema: out_schema.clone(), - projection, - }, + table_manager, error_reporter: ErrorReporter { tx: control_tx, task_info, }, - buffer: None, - error_rate_limiter: 
RateLimiter::new(), - deserializer: None, - buffered_error: None, - table_manager, } } @@ -532,99 +754,7 @@ impl ArrowContext { self.watermarks.last_present_watermark() } - pub async fn flush_buffer(&mut self) -> Result<(), UserError> { - if self.buffer.is_none() { - return Ok(()); - } - - if self.buffer.as_ref().unwrap().size() > 0 { - let buffer = self.buffer.take().unwrap(); - let batch = buffer.finish(); - self.collector.collect(batch).await; - self.buffer = Some(ContextBuffer::new( - self.out_schema.as_ref().map(|t| t.schema.clone()).unwrap(), - )); - } - - if let Some(deserializer) = self.deserializer.as_mut() { - if let Some(buffer) = deserializer.flush_buffer() { - match buffer { - Ok(batch) => { - self.collector.collect(batch).await; - } - Err(e) => { - self.collect_source_errors(vec![e]).await?; - } - } - } - } - - if let Some(error) = self.buffered_error.take() { - return Err(error); - } - - Ok(()) - } - - pub async fn collect(&mut self, record: RecordBatch) { - self.collector.collect(record).await; - } - - pub fn should_flush(&self) -> bool { - self.buffer - .as_ref() - .map(|b| b.should_flush()) - .unwrap_or(false) - || self - .deserializer - .as_ref() - .map(|d| d.should_flush()) - .unwrap_or(false) - } - - pub async fn broadcast(&mut self, message: ArrowMessage) { - if let Err(e) = self.flush_buffer().await { - self.buffered_error.replace(e); - } - self.collector.broadcast(message).await; - } - - pub async fn report_error(&mut self, message: impl Into, details: impl Into) { - self.error_reporter.report_error(message, details).await; - } - - pub async fn report_user_error(&mut self, error: UserError) { - self.control_tx - .send(ControlResp::Error { - operator_id: self.task_info.operator_id.clone(), - task_index: self.task_info.task_index, - message: error.name, - details: error.details, - }) - .await - .unwrap(); - } - - pub async fn send_checkpoint_event( - &mut self, - barrier: CheckpointBarrier, - event_type: TaskCheckpointEventType, - ) { - // 
These messages are received by the engine control thread, - // which then sends a TaskCheckpointEventReq to the controller. - self.control_tx - .send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { - checkpoint_epoch: barrier.epoch, - operator_id: self.task_info.operator_id.clone(), - subtask_index: self.task_info.task_index as u32, - time: SystemTime::now(), - event_type, - })) - .await - .unwrap(); - } - - pub async fn load_compacted(&mut self, compaction: CompactionResult) { + pub async fn load_compacted(&mut self, compaction: &CompactionResult) { //TODO: support compaction in the table manager self.table_manager .load_compacted(compaction) @@ -632,110 +762,8 @@ impl ArrowContext { .expect("should be able to load compacted"); } - pub fn initialize_deserializer( - &mut self, - format: Format, - framing: Option, - bad_data: Option, - ) { - if self.deserializer.is_some() { - panic!("Deserialize already initialized"); - } - - self.deserializer = Some(ArrowDeserializer::new( - format, - self.out_schema.as_ref().expect("no out schema").clone(), - framing, - bad_data.unwrap_or_default(), - )); - } - - pub fn initialize_deserializer_with_resolver( - &mut self, - format: Format, - framing: Option, - bad_data: Option, - schema_resolver: Arc, - ) { - self.deserializer = Some(ArrowDeserializer::with_schema_resolver( - format, - framing, - self.out_schema.as_ref().expect("no out schema").clone(), - bad_data.unwrap_or_default(), - schema_resolver, - )); - } - - pub async fn deserialize_slice( - &mut self, - msg: &[u8], - time: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, - ) -> Result<(), UserError> { - let deserializer = self - .deserializer - .as_mut() - .expect("deserializer not initialized!"); - - if self.buffer.is_none() { - self.buffer = self - .out_schema - .as_ref() - .map(|t| ContextBuffer::new(t.schema.clone())); - } - - let errors = deserializer - .deserialize_slice( - &mut self.buffer.as_mut().expect("no out 
schema").buffer, - msg, - time, - additional_fields, - ) - .await; - self.collect_source_errors(errors).await?; - - Ok(()) - } - - /// Handling errors and rate limiting error reporting. - /// Considers the `bad_data` option to determine whether to drop or fail on bad data. - async fn collect_source_errors(&mut self, errors: Vec) -> Result<(), UserError> { - let bad_data = self - .deserializer - .as_ref() - .expect("deserializer not initialized") - .bad_data(); - for error in errors { - match error { - SourceError::BadData { details } => match bad_data { - BadData::Drop {} => { - self.error_rate_limiter - .rate_limit(|| async { - warn!("Dropping invalid data: {}", details.clone()); - self.control_tx - .send(ControlResp::Error { - operator_id: self.task_info.operator_id.clone(), - task_index: self.task_info.task_index, - message: "Dropping invalid data".to_string(), - details, - }) - .await - .unwrap(); - }) - .await; - TaskCounters::DeserializationErrors.for_task(&self.task_info, |c| c.inc()) - } - BadData::Fail {} => { - return Err(UserError::new("Deserialization error", details)); - } - }, - SourceError::Other { name, details } => { - return Err(UserError::new(name, details)); - } - } - } - - Ok(()) + pub async fn report_error(&mut self, message: impl Into, details: impl Into) { + self.error_reporter.report_error(message, details).await; } } @@ -804,13 +832,11 @@ mod tests { let record = RecordBatch::try_new(schema.clone(), columns).unwrap(); - let task_info = Arc::new(TaskInfo { + let chain_info = Arc::new(ChainInfo { job_id: "test-job".to_string(), - operator_name: "test-operator".to_string(), - operator_id: "test-operator-1".to_string(), + node_id: 1, + description: "test-operator".to_string(), task_index: 0, - parallelism: 1, - key_range: 0..=1, }); let out_qs = vec![vec![tx1, tx2]]; @@ -818,7 +844,7 @@ mod tests { let tx_queue_size_gauges = register_queue_gauge( "arroyo_worker_tx_queue_size", "Size of a tx queue", - &task_info, + &chain_info, &out_qs, 0, ); 
@@ -826,7 +852,7 @@ mod tests { let tx_queue_rem_gauges = register_queue_gauge( "arroyo_worker_tx_queue_rem", "Remaining space in a tx queue", - &task_info, + &chain_info, &out_qs, 0, ); @@ -834,15 +860,14 @@ mod tests { let tx_queue_bytes_gauges = register_queue_gauge( "arroyo_worker_tx_bytes", "Number of bytes queued in a tx queue", - &task_info, + &chain_info, &out_qs, 0, ); let mut collector = ArrowCollector { - task_info, + chain_info, out_schema: Some(ArroyoSchema::new_keyed(schema, 1, vec![0])), - projection: None, out_qs, tx_queue_rem_gauges, tx_queue_size_gauges, diff --git a/crates/arroyo-operator/src/lib.rs b/crates/arroyo-operator/src/lib.rs index ded2a1496..602414aef 100644 --- a/crates/arroyo-operator/src/lib.rs +++ b/crates/arroyo-operator/src/lib.rs @@ -9,12 +9,12 @@ use crate::inq_reader::InQReader; use arrow::array::types::{TimestampNanosecondType, UInt64Type}; use arrow::array::{Array, PrimitiveArray, RecordBatch, UInt64Array}; use arrow::compute::kernels::numeric::{div, rem}; -use arroyo_types::{ArrowMessage, CheckpointBarrier, Data, SignalMessage, TaskInfoRef}; +use arroyo_types::{ArrowMessage, CheckpointBarrier, Data, SignalMessage, TaskInfo}; use bincode::{Decode, Encode}; -use crate::context::ArrowContext; -use crate::operator::Registry; -use operator::{OperatorConstructor, OperatorNode}; +use crate::context::OperatorContext; +use crate::operator::{ConstructedOperator, Registry}; +use operator::OperatorConstructor; use tokio_stream::Stream; pub mod connector; @@ -64,6 +64,7 @@ pub enum ControlOutcome { Stop, StopAndSendStop, Finish, + StopAfterCommit, } #[derive(Debug)] @@ -113,7 +114,7 @@ impl CheckpointCounter { #[allow(unused)] pub struct RunContext + Send + Sync> { - pub task_info: TaskInfoRef, + pub task_info: Arc, pub name: String, pub counter: CheckpointCounter, pub closed: HashSet, @@ -132,26 +133,26 @@ pub struct ArrowTimerValue { } pub trait ErasedConstructor: Send { - fn with_config(&self, config: Vec, registry: Arc) - -> 
anyhow::Result; + fn with_config( + &self, + config: &[u8], + registry: Arc, + ) -> anyhow::Result; } impl ErasedConstructor for T { fn with_config( &self, - config: Vec, + config: &[u8], registry: Arc, - ) -> anyhow::Result { - self.with_config( - prost::Message::decode(&mut config.as_slice()).unwrap(), - registry, - ) + ) -> anyhow::Result { + self.with_config(prost::Message::decode(config).unwrap(), registry) } } pub fn get_timestamp_col<'a>( batch: &'a RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) -> &'a PrimitiveArray { batch .column(ctx.out_schema.as_ref().unwrap().timestamp_index) diff --git a/crates/arroyo-operator/src/operator.rs b/crates/arroyo-operator/src/operator.rs index c5d5498e9..179d8dcc9 100644 --- a/crates/arroyo-operator/src/operator.rs +++ b/crates/arroyo-operator/src/operator.rs @@ -1,4 +1,7 @@ -use crate::context::{ArrowContext, BatchReceiver}; +use crate::context::{ + send_checkpoint_event, ArrowCollector, BatchReceiver, BatchSender, Collector, OperatorContext, + SourceCollector, SourceContext, +}; use crate::inq_reader::InQReader; use crate::udfs::{ArroyoUdaf, UdafArg}; use crate::{CheckpointCounter, ControlOutcome, SourceFinishType}; @@ -8,10 +11,14 @@ use arrow::datatypes::DataType; use arrow::datatypes::Schema; use arroyo_datastream::logical::{DylibUdfConfig, PythonUdfConfig}; use arroyo_metrics::TaskCounters; +use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::rpc::{TableConfig, TaskCheckpointEventType}; use arroyo_rpc::{ControlMessage, ControlResp}; +use arroyo_state::tables::table_manager::TableManager; use arroyo_storage::StorageProvider; -use arroyo_types::{ArrowMessage, CheckpointBarrier, SignalMessage, Watermark}; +use arroyo_types::{ + ArrowMessage, ChainInfo, CheckpointBarrier, SignalMessage, TaskInfo, Watermark, +}; use arroyo_udf_host::parse::inner_type; use arroyo_udf_host::{ContainerOrLocal, LocalUdf, SyncUdfDylib, UdfDylib, UdfInterface}; use arroyo_udf_python::PythonUDF; @@ -26,6 +33,7 @@ use 
datafusion::logical_expr::{ use datafusion::physical_plan::{displayable, ExecutionPlan}; use dlopen2::wrapper::Container; use futures::future::OptionFuture; +use futures::stream::FuturesUnordered; use std::any::Any; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -35,9 +43,11 @@ use std::io::ErrorKind; use std::path::Path; use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, SystemTime}; use tokio::fs::OpenOptions; use tokio::io::AsyncWriteExt; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::Barrier; use tokio_stream::StreamExt; use tracing::{debug, error, info, trace, warn, Instrument}; @@ -48,133 +58,236 @@ pub trait OperatorConstructor: Send { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result; + ) -> anyhow::Result; } -pub enum OperatorNode { +pub struct SourceNode { + pub operator: Box, + pub context: OperatorContext, +} + +pub enum ConstructedOperator { Source(Box), Operator(Box), } -// TODO: this is required currently because the FileSystemSink code isn't sync -unsafe impl Sync for OperatorNode {} - -impl OperatorNode { +impl ConstructedOperator { pub fn from_source(source: Box) -> Self { - OperatorNode::Source(source) + Self::Source(source) } pub fn from_operator(operator: Box) -> Self { - OperatorNode::Operator(operator) + Self::Operator(operator) } pub fn name(&self) -> String { match self { - OperatorNode::Source(s) => s.name(), - OperatorNode::Operator(s) => s.name(), + Self::Source(s) => s.name(), + Self::Operator(s) => s.name(), } } pub fn display(&self) -> DisplayableOperator { match self { - OperatorNode::Source(_) => DisplayableOperator { + Self::Source(_) => DisplayableOperator { name: self.name().into(), fields: vec![], }, - OperatorNode::Operator(op) => op.display(), + Self::Operator(op) => op.display(), + } + } +} + +pub enum OperatorNode { + Source(SourceNode), + Chained(ChainedOperator), +} + +// TODO: this is required currently because the 
FileSystemSink code isn't sync +unsafe impl Sync for OperatorNode {} + +impl OperatorNode { + pub fn name(&self) -> String { + match self { + OperatorNode::Source(s) => s.operator.name(), + OperatorNode::Chained(s) => { + format!( + "[{}]", + s.iter() + .map(|(o, _)| o.name()) + .collect::>() + .join(" -> ") + ) + } } } - pub fn tables(&self) -> HashMap { + pub fn task_info(&self) -> &Arc { match self { - OperatorNode::Source(s) => s.tables(), - OperatorNode::Operator(s) => s.tables(), + OperatorNode::Source(s) => &s.context.task_info, + OperatorNode::Chained(s) => &s.context.task_info, + } + } + + pub fn operator_ids(&self) -> Vec { + match self { + OperatorNode::Source(s) => vec![s.context.task_info.operator_id.clone()], + OperatorNode::Chained(s) => s + .iter() + .map(|(_, ctx)| ctx.task_info.operator_id.clone()) + .collect(), } } async fn run_behavior( - &mut self, - ctx: &mut ArrowContext, + self, + chain_info: &Arc, + control_tx: Sender, + control_rx: Receiver, in_qs: &mut [BatchReceiver], ready: Arc, - ) -> Option { + mut collector: ArrowCollector, + ) { match self { - OperatorNode::Source(s) => { - s.on_start(ctx).await; + OperatorNode::Source(mut s) => { + let mut source_context = + SourceContext::from_operator(s.context, chain_info.clone(), control_rx); + + let mut collector = SourceCollector::new( + source_context.out_schema.clone(), + collector, + control_tx.clone(), + &source_context.chain_info, + &source_context.task_info, + ); + + s.operator.on_start(&mut source_context).await; ready.wait().await; info!( "Running source {}-{}", - ctx.task_info.operator_name, ctx.task_info.task_index + source_context.task_info.operator_name, source_context.task_info.operator_name ); - ctx.control_tx + source_context + .control_tx .send(ControlResp::TaskStarted { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, + node_id: source_context.task_info.node_id, + task_index: source_context.task_info.task_index as usize, start_time: 
SystemTime::now(), }) .await .unwrap(); - let result = s.run(ctx).await; + let result = s.operator.run(&mut source_context, &mut collector).await; - s.on_close(ctx).await; + s.operator + .on_close(&mut source_context, &mut collector) + .await; - result.into() + if let Some(final_message) = result.into() { + collector.broadcast(final_message).await; + } + } + OperatorNode::Chained(mut o) => { + let result = operator_run_behavior( + &mut o, + in_qs, + control_tx, + control_rx, + &mut collector, + ready, + ) + .await; + if let Some(final_message) = result { + collector.broadcast(final_message).await; + } } - OperatorNode::Operator(o) => operator_run_behavior(o, ctx, in_qs, ready).await, + } + } + + fn node_id(&self) -> u32 { + match self { + OperatorNode::Source(s) => s.context.task_info.node_id, + OperatorNode::Chained(s) => s.context.task_info.node_id, + } + } + + fn task_index(&self) -> u32 { + match self { + OperatorNode::Source(s) => s.context.task_info.task_index, + OperatorNode::Chained(s) => s.context.task_info.task_index, } } pub async fn start( - mut self: Box, - mut ctx: ArrowContext, + self: Box, + control_tx: Sender, + control_rx: Receiver, mut in_qs: Vec, + out_qs: Vec>, + out_schema: Option, ready: Arc, ) { info!( - "Starting task {}-{}", - ctx.task_info.operator_name, ctx.task_info.task_index + "Starting node {}-{} ({})", + self.node_id(), + self.task_index(), + self.name() ); - let final_message = self.run_behavior(&mut ctx, &mut in_qs, ready).await; - - if let Some(final_message) = final_message { - ctx.broadcast(ArrowMessage::Signal(final_message)).await; - } + let chain_info = Arc::new(ChainInfo { + job_id: self.task_info().job_id.clone(), + node_id: self.node_id(), + description: self.name(), + task_index: self.task_index(), + }); + + let collector = ArrowCollector::new(chain_info.clone(), out_schema, out_qs); + + self.run_behavior( + &chain_info, + control_tx.clone(), + control_rx, + &mut in_qs, + ready, + collector, + ) + .await; info!( - "Task 
finished {}-{}", - ctx.task_info.operator_name, ctx.task_info.task_index + "Task finished {}-{} ({})", + chain_info.node_id, chain_info.task_index, chain_info.description ); - ctx.control_tx + control_tx .send(ControlResp::TaskFinished { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, + node_id: chain_info.node_id, + task_index: chain_info.task_index as usize, }) .await .expect("control response unwrap"); } } -async fn run_checkpoint(checkpoint_barrier: CheckpointBarrier, ctx: &mut ArrowContext) -> bool { - let watermark = ctx.watermarks.last_present_watermark(); - - ctx.table_manager +async fn run_checkpoint( + checkpoint_barrier: CheckpointBarrier, + task_info: &TaskInfo, + watermark: Option, + table_manager: &mut TableManager, + control_tx: &Sender, +) { + table_manager .checkpoint(checkpoint_barrier, watermark) .await; - ctx.send_checkpoint_event(checkpoint_barrier, TaskCheckpointEventType::FinishedSync) - .await; - - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Barrier( + send_checkpoint_event( + control_tx, + task_info, checkpoint_barrier, - ))) + TaskCheckpointEventType::FinishedSync, + ) .await; - - checkpoint_barrier.then_stop } #[async_trait] @@ -186,52 +299,592 @@ pub trait SourceOperator: Send + 'static { } #[allow(unused_variables)] - async fn on_start(&mut self, ctx: &mut ArrowContext) {} + async fn on_start(&mut self, ctx: &mut SourceContext) {} - async fn run(&mut self, ctx: &mut ArrowContext) -> SourceFinishType; + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType; #[allow(unused_variables)] - async fn on_close(&mut self, ctx: &mut ArrowContext) {} + async fn on_close(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) {} async fn start_checkpoint( &mut self, checkpoint_barrier: CheckpointBarrier, - ctx: &mut ArrowContext, + ctx: &mut SourceContext, + collector: &mut SourceCollector, ) -> bool { - ctx.send_checkpoint_event( 
+ send_checkpoint_event( + &ctx.control_tx, + &ctx.task_info, + checkpoint_barrier, + TaskCheckpointEventType::StartedCheckpointing, + ) + .await; + + run_checkpoint( checkpoint_barrier, + &ctx.task_info, + ctx.watermarks.last_present_watermark(), + &mut ctx.table_manager, + &ctx.control_tx, + ) + .await; + + collector + .broadcast(SignalMessage::Barrier(checkpoint_barrier)) + .await; + + checkpoint_barrier.then_stop + } +} + +macro_rules! call_with_collector { + ($self:expr, $final_collector:expr, $name:ident, $arg:expr) => { + match &mut $self.next { + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + index: 0, + in_partitions: 1, + final_collector: $final_collector, + }; + + $self + .operator + .$name($arg, &mut $self.context, &mut collector) + .await; + } + None => { + $self + .operator + .$name($arg, &mut $self.context, $final_collector) + .await; + } + } + }; +} + +pub struct ChainedCollector<'a, 'b> { + cur: &'a mut ChainedOperator, + final_collector: &'b mut ArrowCollector, + index: usize, + in_partitions: usize, +} + +#[async_trait] +impl<'a, 'b> Collector for ChainedCollector<'a, 'b> +where + 'b: 'a, + 'a: 'b, +{ + async fn collect(&mut self, batch: RecordBatch) { + if let Some(next) = &mut self.cur.next { + let mut collector = ChainedCollector { + cur: next, + final_collector: self.final_collector, + // all chained operators (other than the first one) must have a single input + index: 0, + in_partitions: 1, + }; + + self.cur + .operator + .process_batch_index( + self.index, + self.in_partitions, + batch, + &mut self.cur.context, + &mut collector, + ) + .await; + } else { + self.cur + .operator + .process_batch_index( + self.index, + self.in_partitions, + batch, + &mut self.cur.context, + self.final_collector, + ) + .await; + }; + } + + async fn broadcast_watermark(&mut self, watermark: Watermark) { + self.cur + .handle_watermark(watermark, self.index, self.final_collector) + .await; + } +} + +pub struct ChainedOperator { + pub 
operator: Box, + pub context: OperatorContext, + pub next: Option>, +} + +impl ChainedOperator { + pub fn new(operator: Box, context: OperatorContext) -> Self { + Self { + operator, + context, + next: None, + } + } +} + +pub struct ChainIteratorMut<'a> { + current: Option<&'a mut ChainedOperator>, +} + +impl<'a> Iterator for ChainIteratorMut<'a> { + type Item = (&'a mut dyn ArrowOperator, &'a mut OperatorContext); + + fn next(&mut self) -> Option { + if let Some(current) = self.current.take() { + let next = current.next.as_deref_mut(); + self.current = next; + Some((current.operator.as_mut(), &mut current.context)) + } else { + None + } + } +} + +pub struct ChainIterator<'a> { + current: Option<&'a ChainedOperator>, +} + +impl<'a> Iterator for ChainIterator<'a> { + type Item = (&'a dyn ArrowOperator, &'a OperatorContext); + + fn next(&mut self) -> Option { + if let Some(current) = self.current.take() { + let next = current.next.as_deref(); + self.current = next; + Some((current.operator.as_ref(), ¤t.context)) + } else { + None + } + } +} + +struct IndexedFuture { + f: Pin> + Send>>, + i: usize, +} + +impl Future for IndexedFuture { + type Output = (usize, Box); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.f.as_mut().poll(cx) { + Poll::Ready(r) => Poll::Ready((self.i, r)), + Poll::Pending => Poll::Pending, + } + } +} + +impl ChainedOperator { + async fn handle_controller_message( + &mut self, + control_message: &ControlMessage, + shutdown_after_commit: bool, + ) -> bool { + match control_message { + ControlMessage::Checkpoint(_) => { + error!("shouldn't receive checkpoint") + } + ControlMessage::Stop { .. 
} => { + error!("shouldn't receive stop") + } + ControlMessage::Commit { epoch, commit_data } => { + assert!( + self.next.is_none(), + "can only commit sinks, which cannot be chained" + ); + self.operator + .handle_commit(*epoch, commit_data, &mut self.context) + .await; + return shutdown_after_commit; + } + ControlMessage::LoadCompacted { compacted } => { + self.iter_mut() + .find(|(_, ctx)| ctx.task_info.operator_id == compacted.operator_id) + .unwrap_or_else(|| { + panic!( + "could not load compacted data for unknown operator '{}'", + compacted.operator_id + ) + }) + .1 + .load_compacted(compacted) + .await; + } + ControlMessage::NoOp => {} + } + + false + } + + pub fn iter(&self) -> ChainIterator { + ChainIterator { + current: Some(self), + } + } + + pub fn iter_mut(&mut self) -> ChainIteratorMut { + ChainIteratorMut { + current: Some(self), + } + } + + fn name(&self) -> String { + self.iter() + .map(|(op, _)| op.name()) + .collect::>() + .join(" -> ") + } + + fn tick_interval(&self) -> Option { + self.iter().filter_map(|(op, _)| op.tick_interval()).min() + } + + async fn on_start(&mut self) { + for (op, ctx) in self.iter_mut() { + op.on_start(ctx).await; + } + } + + async fn process_batch_index<'a, 'b>( + &'a mut self, + index: usize, + in_partitions: usize, + batch: RecordBatch, + final_collector: &'b mut ArrowCollector, + ) where + 'a: 'b, + { + let mut collector = ChainedCollector { + cur: self, + index, + in_partitions, + final_collector, + }; + + collector.collect(batch).await; + } + + #[allow(clippy::type_complexity)] + fn future_to_poll( + &mut self, + ) -> Option)> + Send>>> { + let futures = self + .iter_mut() + .enumerate() + .filter_map(|(i, (op, _))| { + Some(IndexedFuture { + f: op.future_to_poll()?, + i, + }) + }) + .collect::>(); + + match futures.len() { + 0 => None, + 1 => Some(Box::pin(futures.into_iter().next().unwrap())), + _ => { + Some(Box::pin(async move { + let mut futures = FuturesUnordered::from_iter(futures); + // we've guaranteed 
that the future unordered has at least one element so this + // will never unwrap + futures.next().await.unwrap() + })) + } + } + } + + #[allow(clippy::too_many_arguments)] + async fn handle_control_message( + &mut self, + idx: usize, + message: &SignalMessage, + counter: &mut CheckpointCounter, + closed: &mut HashSet, + in_partitions: usize, + control_tx: &Sender, + chain_info: &ChainInfo, + collector: &mut ArrowCollector, + ) -> ControlOutcome { + match message { + SignalMessage::Barrier(t) => { + debug!("received barrier in {}[{}]", chain_info, idx); + + if counter.all_clear() { + control_tx + .send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { + checkpoint_epoch: t.epoch, + node_id: chain_info.node_id, + operator_id: self.context.task_info.operator_id.clone(), + subtask_index: chain_info.task_index, + time: SystemTime::now(), + event_type: TaskCheckpointEventType::StartedAlignment, + })) + .await + .unwrap(); + } + + if counter.mark(idx, t) { + debug!("Checkpointing {chain_info}"); + + self.run_checkpoint(t, control_tx, collector).await; + + collector.broadcast(SignalMessage::Barrier(*t)).await; + + if t.then_stop { + // if this is a committing operator, we need to wait for the commit message + // before shutting down; otherwise we just stop + return if self.operator.is_committing() { + ControlOutcome::StopAfterCommit + } else { + return ControlOutcome::Stop; + }; + } + } + } + SignalMessage::Watermark(watermark) => { + debug!("received watermark {:?} in {}", watermark, chain_info,); + + self.handle_watermark(*watermark, idx, collector).await; + } + SignalMessage::Stop => { + closed.insert(idx); + if closed.len() == in_partitions { + return ControlOutcome::StopAndSendStop; + } + } + SignalMessage::EndOfData => { + closed.insert(idx); + if closed.len() == in_partitions { + return ControlOutcome::Finish; + } + } + } + ControlOutcome::Continue + } + + async fn handle_watermark( + &mut self, + watermark: Watermark, + index: usize, + final_collector: 
&mut ArrowCollector, + ) { + trace!( + "handling watermark {:?} for {}", + watermark, + self.context.task_info, + ); + + let watermark = self + .context + .watermarks + .set(index, watermark) + .expect("watermark index is too big"); + + let Some(watermark) = watermark else { + return; + }; + + if let Watermark::EventTime(_t) = watermark { + // TOOD: pass to table_manager + } + + match &mut self.next { + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + index: 0, + in_partitions: 1, + final_collector, + }; + + let watermark = self + .operator + .handle_watermark(watermark, &mut self.context, &mut collector) + .await; + + if let Some(watermark) = watermark { + Box::pin(next.handle_watermark(watermark, 0, final_collector)).await; + } + } + None => { + let watermark = self + .operator + .handle_watermark(watermark, &mut self.context, final_collector) + .await; + if let Some(watermark) = watermark { + final_collector + .broadcast(SignalMessage::Watermark(watermark)) + .await; + } + } + } + } + + async fn handle_future_result( + &mut self, + op_index: usize, + result: Box, + final_collector: &mut ArrowCollector, + ) { + let mut op = self; + for _ in 0..op_index { + op = op + .next + .as_mut() + .expect("Future produced from operator index larger than chain size"); + } + + match &mut op.next { + None => { + op.operator + .handle_future_result(result, &mut op.context, final_collector) + .await; + } + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + final_collector, + index: 0, + in_partitions: 1, + }; + op.operator + .handle_future_result(result, &mut op.context, &mut collector) + .await; + } + } + } + + async fn run_checkpoint( + &mut self, + t: &CheckpointBarrier, + control_tx: &Sender, + final_collector: &mut ArrowCollector, + ) { + send_checkpoint_event( + control_tx, + &self.context.task_info, + *t, TaskCheckpointEventType::StartedCheckpointing, ) .await; - run_checkpoint(checkpoint_barrier, ctx).await + 
call_with_collector!(self, final_collector, handle_checkpoint, *t); + + send_checkpoint_event( + control_tx, + &self.context.task_info, + *t, + TaskCheckpointEventType::FinishedOperatorSetup, + ) + .await; + + let last_watermark = self.context.watermarks.last_present_watermark(); + + run_checkpoint( + *t, + &self.context.task_info, + last_watermark, + &mut self.context.table_manager, + control_tx, + ) + .await; + + if let Some(next) = &mut self.next { + Box::pin(next.run_checkpoint(t, control_tx, final_collector)).await; + } + } + + async fn handle_tick(&mut self, tick: u64, final_collector: &mut ArrowCollector) { + match &mut self.next { + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + index: 0, + in_partitions: 1, + final_collector, + }; + self.operator + .handle_tick(tick, &mut self.context, &mut collector) + .await; + Box::pin(next.handle_tick(tick, final_collector)).await; + } + None => { + self.operator + .handle_tick(tick, &mut self.context, final_collector) + .await; + } + } + } + + async fn on_close( + &mut self, + final_message: &Option, + final_collector: &mut ArrowCollector, + ) { + match &mut self.next { + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + index: 0, + in_partitions: 1, + final_collector, + }; + + self.operator + .on_close(final_message, &mut self.context, &mut collector) + .await; + + Box::pin(next.on_close(final_message, final_collector)).await; + } + None => { + self.operator + .on_close(final_message, &mut self.context, final_collector) + .await; + } + } } } async fn operator_run_behavior( - this: &mut Box, - ctx: &mut ArrowContext, + this: &mut ChainedOperator, in_qs: &mut [BatchReceiver], + control_tx: Sender, + mut control_rx: Receiver, + collector: &mut ArrowCollector, ready: Arc, ) -> Option { - this.on_start(ctx).await; + this.on_start().await; + + let chain_info = &mut collector.chain_info.clone(); ready.wait().await; - info!( - "Running operator {}-{}", - 
ctx.task_info.operator_name, ctx.task_info.task_index - ); - ctx.control_tx + info!("Running node {}", chain_info); + + control_tx .send(ControlResp::TaskStarted { - operator_id: ctx.task_info.operator_id.clone(), - task_index: ctx.task_info.task_index, + node_id: chain_info.node_id, + task_index: chain_info.task_index as usize, start_time: SystemTime::now(), }) .await .unwrap(); - let task_info = ctx.task_info.clone(); let name = this.name(); let mut counter = CheckpointCounter::new(in_qs.len()); let mut closed: HashSet = HashSet::new(); @@ -241,7 +894,7 @@ async fn operator_run_behavior( for (i, q) in in_qs.iter_mut().enumerate() { let stream = async_stream::stream! { while let Some(item) = q.recv().await { - yield(i,item); + yield(i, item); } }; sel.push(Box::pin(stream)); @@ -252,13 +905,18 @@ async fn operator_run_behavior( let mut ticks = 0u64; let mut interval = tokio::time::interval(this.tick_interval().unwrap_or(Duration::from_secs(60))); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + let mut shutdown_after_commit = false; + loop { let operator_future: OptionFuture<_> = this.future_to_poll().into(); tokio::select! 
{ - Some(control_message) = ctx.control_rx.recv() => { - this.handle_controller_message(control_message, ctx).await; + Some(control_message) = control_rx.recv() => { + if this.handle_controller_message(&control_message, shutdown_after_commit).await { + break; + } } p = sel.next() => { @@ -267,28 +925,32 @@ async fn operator_run_behavior( let local_idx = idx; trace!("[{}] Handling message {}-{}, {:?}", - ctx.task_info.operator_name, 0, local_idx, message); + chain_info.node_id, 0, local_idx, message); match message { ArrowMessage::Data(record) => { - TaskCounters::BatchesReceived.for_task(&ctx.task_info, |c| c.inc()); - TaskCounters::MessagesReceived.for_task(&ctx.task_info, |c| c.inc_by(record.num_rows() as u64)); - TaskCounters::BytesReceived.for_task(&ctx.task_info, |c| c.inc_by(record.get_array_memory_size() as u64)); - this.process_batch_index(idx, in_partitions, record, ctx) + TaskCounters::BatchesReceived.for_task(chain_info, |c| c.inc()); + TaskCounters::MessagesReceived.for_task(chain_info, |c| c.inc_by(record.num_rows() as u64)); + TaskCounters::BytesReceived.for_task(chain_info, |c| c.inc_by(record.get_array_memory_size() as u64)); + this.process_batch_index(idx, in_partitions, record, collector) .instrument(tracing::trace_span!("handle_fn", name, - operator_id = task_info.operator_id, - subtask_idx = task_info.task_index) + node_id = chain_info.node_id, + subtask_idx = chain_info.task_index) ).await; } ArrowMessage::Signal(signal) => { - match this.handle_control_message(idx, &signal, &mut counter, &mut closed, in_partitions, ctx).await { + match this.handle_control_message(idx,&signal, &mut counter, &mut closed, in_partitions, + &control_tx, chain_info, collector).await { ControlOutcome::Continue => {} ControlOutcome::Stop => { // just stop; the stop will have already been broadcast for example by // a final checkpoint break; } + ControlOutcome::StopAfterCommit => { + shutdown_after_commit = true; + } ControlOutcome::Finish => { final_message = 
Some(SignalMessage::EndOfData); break; @@ -313,21 +975,23 @@ async fn operator_run_behavior( } } None => { - info!("[{}] Stream completed",ctx.task_info.operator_name); - break; + info!("[{}] Stream completed", chain_info); + if !shutdown_after_commit { + break; + } } } } Some(val) = operator_future => { - this.handle_future_result(val, ctx).await; + this.handle_future_result(val.0, val.1, collector).await; } _ = interval.tick() => { - this.handle_tick(ticks, ctx).await; + this.handle_tick(ticks, collector).await; ticks += 1; } } } - this.on_close(&final_message, ctx).await; + this.on_close(&final_message, collector).await; final_message } @@ -400,147 +1064,16 @@ pub struct DisplayableOperator<'a> { #[async_trait::async_trait] pub trait ArrowOperator: Send + 'static { - async fn handle_watermark_int(&mut self, watermark: Watermark, ctx: &mut ArrowContext) { - // process timers - tracing::trace!( - "handling watermark {:?} for {}-{}", - watermark, - ctx.task_info.operator_name, - ctx.task_info.task_index - ); - - if let Watermark::EventTime(_t) = watermark { - // let finished = ProcessFnUtils::finished_timers(t, ctx).await; - // - // for (k, tv) in finished { - // self.handle_timer(k, tv.data, ctx).await; - // } - } - - if let Some(watermark) = self.handle_watermark(watermark, ctx).await { - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark(watermark))) - .await; - } - } - - async fn handle_controller_message( - &mut self, - control_message: ControlMessage, - ctx: &mut ArrowContext, - ) { - match control_message { - ControlMessage::Checkpoint(_) => { - error!("shouldn't receive checkpoint") - } - ControlMessage::Stop { .. 
} => { - error!("shouldn't receive stop") - } - ControlMessage::Commit { epoch, commit_data } => { - self.handle_commit(epoch, &commit_data, ctx).await; - } - ControlMessage::LoadCompacted { compacted } => { - ctx.load_compacted(compacted).await; - } - ControlMessage::NoOp => {} - } - } - - async fn handle_control_message( - &mut self, - idx: usize, - message: &SignalMessage, - counter: &mut CheckpointCounter, - closed: &mut HashSet, - in_partitions: usize, - ctx: &mut ArrowContext, - ) -> ControlOutcome { - match message { - SignalMessage::Barrier(t) => { - debug!( - "received barrier in {}-{}-{}-{}", - self.name(), - ctx.task_info.operator_id, - ctx.task_info.task_index, - idx - ); - - if counter.all_clear() { - ctx.control_tx - .send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { - checkpoint_epoch: t.epoch, - operator_id: ctx.task_info.operator_id.clone(), - subtask_index: ctx.task_info.task_index as u32, - time: SystemTime::now(), - event_type: TaskCheckpointEventType::StartedAlignment, - })) - .await - .unwrap(); - } - - if counter.mark(idx, t) { - debug!( - "Checkpointing {}-{}-{}", - self.name(), - ctx.task_info.operator_id, - ctx.task_info.task_index - ); - - ctx.send_checkpoint_event(*t, TaskCheckpointEventType::StartedCheckpointing) - .await; - - self.handle_checkpoint(*t, ctx).await; - - ctx.send_checkpoint_event(*t, TaskCheckpointEventType::FinishedOperatorSetup) - .await; - - if run_checkpoint(*t, ctx).await { - return ControlOutcome::Stop; - } - } - } - SignalMessage::Watermark(watermark) => { - debug!( - "received watermark {:?} in {}-{}", - watermark, - self.name(), - ctx.task_info.task_index - ); - - let watermark = ctx - .watermarks - .set(idx, *watermark) - .expect("watermark index is too big"); - - if let Some(watermark) = watermark { - if let Watermark::EventTime(_t) = watermark { - // TOOD: pass to table_manager - } - - self.handle_watermark_int(watermark, ctx).await; - } - } - SignalMessage::Stop => { - closed.insert(idx); - if 
closed.len() == in_partitions { - return ControlOutcome::StopAndSendStop; - } - } - SignalMessage::EndOfData => { - closed.insert(idx); - if closed.len() == in_partitions { - return ControlOutcome::Finish; - } - } - } - ControlOutcome::Continue - } - fn name(&self) -> String; fn tables(&self) -> HashMap { HashMap::new() } + fn is_committing(&self) -> bool { + false + } + fn tick_interval(&self) -> Option { None } @@ -553,19 +1086,26 @@ pub trait ArrowOperator: Send + 'static { } #[allow(unused_variables)] - async fn on_start(&mut self, ctx: &mut ArrowContext) {} + async fn on_start(&mut self, ctx: &mut OperatorContext) {} + #[allow(unused_variables)] async fn process_batch_index( &mut self, - _index: usize, - _in_partitions: usize, + index: usize, + in_partitions: usize, batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) { - self.process_batch(batch, ctx).await + self.process_batch(batch, ctx, collector).await } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext); + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ); #[allow(clippy::type_complexity)] fn future_to_poll( @@ -575,37 +1115,63 @@ pub trait ArrowOperator: Send + 'static { } #[allow(unused_variables)] - async fn handle_future_result(&mut self, result: Box, ctx: &mut ArrowContext) {} + async fn handle_future_result( + &mut self, + result: Box, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + } #[allow(unused_variables)] - async fn handle_timer(&mut self, key: Vec, value: Vec, ctx: &mut ArrowContext) {} + async fn handle_timer(&mut self, key: Vec, value: Vec, ctx: &mut OperatorContext) {} + #[allow(unused_variables)] async fn handle_watermark( &mut self, watermark: Watermark, - _ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { Some(watermark) } #[allow(unused_variables)] - async 
fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut ArrowContext) {} + async fn handle_checkpoint( + &mut self, + b: CheckpointBarrier, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + } #[allow(unused_variables)] async fn handle_commit( &mut self, epoch: u32, commit_data: &HashMap>>, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) { warn!("default handling of commit with epoch {:?}", epoch); } #[allow(unused_variables)] - async fn handle_tick(&mut self, tick: u64, ctx: &mut ArrowContext) {} + async fn handle_tick( + &mut self, + tick: u64, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + } #[allow(unused_variables)] - async fn on_close(&mut self, final_message: &Option, ctx: &mut ArrowContext) {} + async fn on_close( + &mut self, + final_message: &Option, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + } } #[derive(Default)] diff --git a/crates/arroyo-planner/src/extension/aggregate.rs b/crates/arroyo-planner/src/extension/aggregate.rs index 8840d8ded..2454a9939 100644 --- a/crates/arroyo-planner/src/extension/aggregate.rs +++ b/crates/arroyo-planner/src/extension/aggregate.rs @@ -105,13 +105,14 @@ impl AggregateExtension { final_projection: Some(final_physical_plan_node.encode_to_vec()), }; - Ok(LogicalNode { - operator_id: format!("tumbling_{}", index), - operator_name: OperatorName::TumblingWindowAggregate, - operator_config: config.encode_to_vec(), - description: format!("TumblingWindow<{}>", config.name), - parallelism: 1, - }) + Ok(LogicalNode::single( + index as u32, + format!("tumbling_{}", index), + OperatorName::TumblingWindowAggregate, + config.encode_to_vec(), + format!("TumblingWindow<{}>", config.name), + 1, + )) } pub fn sliding_window_config( @@ -154,13 +155,15 @@ impl AggregateExtension { final_projection: final_physical_plan_node.encode_to_vec(), // TODO add final aggregation. 
}; - Ok(LogicalNode { - operator_id: format!("sliding_window_{}", index), - description: "sliding window".to_string(), - operator_name: OperatorName::SlidingWindowAggregate, - operator_config: config.encode_to_vec(), - parallelism: 1, - }) + + Ok(LogicalNode::single( + index as u32, + format!("sliding_window_{}", index), + OperatorName::SlidingWindowAggregate, + config.encode_to_vec(), + "sliding window".to_string(), + 1, + )) } pub fn session_window_config( @@ -216,13 +219,14 @@ impl AggregateExtension { final_aggregation_plan: physical_plan_node.encode_to_vec(), }; - Ok(LogicalNode { - operator_id: config.name.clone(), - description: format!("SessionWindow<{:?}>", gap), - operator_name: OperatorName::SessionWindowAggregate, - operator_config: config.encode_to_vec(), - parallelism: 1, - }) + Ok(LogicalNode::single( + index as u32, + config.name.clone(), // operator_id: same value the removed struct literal used + OperatorName::SessionWindowAggregate, + config.encode_to_vec(), + format!("SessionWindow<{:?}>", gap), // description + 1, + )) } pub fn instant_window_config( @@ -273,13 +277,14 @@ impl AggregateExtension { final_projection, }; - Ok(LogicalNode { - operator_id: format!("instant_window_{}", index), - description: "instant window".to_string(), - operator_name: OperatorName::TumblingWindowAggregate, - operator_config: config.encode_to_vec(), - parallelism: 1, - }) + Ok(LogicalNode::single( + index as u32, + format!("instant_window_{}", index), + OperatorName::TumblingWindowAggregate, + config.encode_to_vec(), + "instant window".to_string(), + 1, + )) } // projection assuming that _timestamp has been populated with the start of the bin. 
diff --git a/crates/arroyo-planner/src/extension/join.rs b/crates/arroyo-planner/src/extension/join.rs index 3b8582eec..5695bcc26 100644 --- a/crates/arroyo-planner/src/extension/join.rs +++ b/crates/arroyo-planner/src/extension/join.rs @@ -60,13 +60,14 @@ impl ArroyoExtension for JoinExtension { ttl_micros: self.ttl.map(|t| t.as_micros() as u64), }; - let logical_node = LogicalNode { - operator_id: format!("join_{}", index), - description: "join".to_string(), + let logical_node = LogicalNode::single( + index as u32, + format!("join_{}", index), operator_name, - operator_config: config.encode_to_vec(), - parallelism: 1, - }; + config.encode_to_vec(), + "join".to_string(), + 1, + ); let left_edge = LogicalEdge::project_all(LogicalEdgeType::LeftJoin, left_schema.as_ref().clone()); diff --git a/crates/arroyo-planner/src/extension/key_calculation.rs b/crates/arroyo-planner/src/extension/key_calculation.rs index dda3e4dd5..a12b7c34f 100644 --- a/crates/arroyo-planner/src/extension/key_calculation.rs +++ b/crates/arroyo-planner/src/extension/key_calculation.rs @@ -97,13 +97,14 @@ impl ArroyoExtension for KeyCalculationExtension { physical_plan: physical_plan_node.encode_to_vec(), key_fields: self.keys.iter().map(|k: &usize| *k as u64).collect(), }; - let node = LogicalNode { - operator_id: format!("key_{}", index), - operator_name: OperatorName::ArrowKey, - operator_config: config.encode_to_vec(), - description: format!("ArrowKey<{}>", config.name), - parallelism: 1, - }; + let node = LogicalNode::single( + index as u32, + format!("key_{}", index), + OperatorName::ArrowKey, + config.encode_to_vec(), + format!("ArrowKey<{}>", config.name), + 1, + ); let edge = LogicalEdge::project_all(LogicalEdgeType::Forward, (*input_schema).clone()); Ok(NodeWithIncomingEdges { node, diff --git a/crates/arroyo-planner/src/extension/mod.rs b/crates/arroyo-planner/src/extension/mod.rs index cd2a27ca7..7b576b2a0 100644 --- a/crates/arroyo-planner/src/extension/mod.rs +++ 
b/crates/arroyo-planner/src/extension/mod.rs @@ -250,13 +250,14 @@ impl ArroyoExtension for AsyncUDFExtension { timeout_micros: self.timeout.as_micros() as u64, }; - let node = LogicalNode { - operator_id: format!("async_udf_{}", index), - description: format!("async_udf<{}>", self.name), - operator_name: OperatorName::AsyncUdf, - operator_config: config.encode_to_vec(), - parallelism: 1, - }; + let node = LogicalNode::single( + index as u32, + format!("async_udf_{}", index), + OperatorName::AsyncUdf, + config.encode_to_vec(), + format!("async_udf<{}>", self.name), + 1, + ); let incoming_edge = LogicalEdge::project_all(LogicalEdgeType::Forward, input_schemas[0].as_ref().clone()); diff --git a/crates/arroyo-planner/src/extension/remote_table.rs b/crates/arroyo-planner/src/extension/remote_table.rs index 97d817a76..3fe283ed3 100644 --- a/crates/arroyo-planner/src/extension/remote_table.rs +++ b/crates/arroyo-planner/src/extension/remote_table.rs @@ -73,13 +73,15 @@ impl ArroyoExtension for RemoteTableExtension { name: format!("value_calculation({})", self.name), physical_plan: physical_plan_node.encode_to_vec(), }; - let node = LogicalNode { - operator_id: format!("value_{}", index), - description: self.name.to_string(), - operator_name: OperatorName::ArrowValue, - parallelism: 1, - operator_config: config.encode_to_vec(), - }; + let node = LogicalNode::single( + index as u32, + format!("value_{}", index), + OperatorName::ArrowValue, + config.encode_to_vec(), + self.name.to_string(), + 1, + ); + let edges = input_schemas .into_iter() .map(|schema| LogicalEdge::project_all(LogicalEdgeType::Forward, (*schema).clone())) diff --git a/crates/arroyo-planner/src/extension/sink.rs b/crates/arroyo-planner/src/extension/sink.rs index a7e20047b..0e559d175 100644 --- a/crates/arroyo-planner/src/extension/sink.rs +++ b/crates/arroyo-planner/src/extension/sink.rs @@ -157,13 +157,16 @@ impl ArroyoExtension for SinkExtension { .connector_op() .map_err(|e| e.context("connector 
op"))?) .encode_to_vec(); - let node = LogicalNode { - operator_id: format!("sink_{}_{}", self.name, index), - description: self.table.connector_op().unwrap().description.clone(), - operator_name: OperatorName::ConnectorSink, - parallelism: 1, + + let node = LogicalNode::single( + index as u32, + format!("sink_{}_{}", self.name, index), + OperatorName::ConnectorSink, operator_config, - }; + self.table.connector_op().unwrap().description.clone(), + 1, + ); + let edges = input_schemas .into_iter() .map(|input_schema| { diff --git a/crates/arroyo-planner/src/extension/table_source.rs b/crates/arroyo-planner/src/extension/table_source.rs index ef8dd62d9..cf4a0334b 100644 --- a/crates/arroyo-planner/src/extension/table_source.rs +++ b/crates/arroyo-planner/src/extension/table_source.rs @@ -101,13 +101,14 @@ impl ArroyoExtension for TableSourceExtension { return plan_err!("TableSourceExtension should not have inputs"); } let sql_source = self.table.as_sql_source()?; - let node = LogicalNode { - operator_id: format!("source_{}_{}", self.name, index), - description: sql_source.source.config.description.clone(), - operator_name: OperatorName::ConnectorSource, - operator_config: sql_source.source.config.encode_to_vec(), - parallelism: 1, - }; + let node = LogicalNode::single( + index as u32, + format!("source_{}_{}", self.name, index), + OperatorName::ConnectorSource, + sql_source.source.config.encode_to_vec(), + sql_source.source.config.description.clone(), + 1, + ); Ok(NodeWithIncomingEdges { node, edges: vec![], diff --git a/crates/arroyo-planner/src/extension/updating_aggregate.rs b/crates/arroyo-planner/src/extension/updating_aggregate.rs index 1d127a1fd..e109fa2e9 100644 --- a/crates/arroyo-planner/src/extension/updating_aggregate.rs +++ b/crates/arroyo-planner/src/extension/updating_aggregate.rs @@ -241,13 +241,14 @@ impl ArroyoExtension for UpdatingAggregateExtension { ttl_micros: self.ttl.as_micros() as u64, }; - let node = LogicalNode { - operator_id: 
format!("updating_aggregate_{}", index), - description: "UpdatingAggregate".to_string(), - operator_name: OperatorName::UpdatingAggregate, - operator_config: config.encode_to_vec(), - parallelism: 1, - }; + let node = LogicalNode::single( + index as u32, + format!("updating_aggregate_{}", index), + OperatorName::UpdatingAggregate, + config.encode_to_vec(), + "UpdatingAggregate".to_string(), + 1, + ); let edge = LogicalEdge::project_all(LogicalEdgeType::Shuffle, (*input_schema).clone()); diff --git a/crates/arroyo-planner/src/extension/watermark_node.rs b/crates/arroyo-planner/src/extension/watermark_node.rs index aa44acfa8..6d4ea98da 100644 --- a/crates/arroyo-planner/src/extension/watermark_node.rs +++ b/crates/arroyo-planner/src/extension/watermark_node.rs @@ -89,19 +89,21 @@ impl ArroyoExtension for WatermarkNode { ) -> Result { let expression = planner.create_physical_expr(&self.watermark_expression, &self.schema)?; let expression = serialize_physical_expr(&expression, &DefaultPhysicalExtensionCodec {})?; - let node = LogicalNode { - operator_id: format!("watermark_{}", index), - description: "watermark".to_string(), - operator_name: OperatorName::ExpressionWatermark, - parallelism: 1, - operator_config: ExpressionWatermarkConfig { + let node = LogicalNode::single( + index as u32, + format!("watermark_{}", index), + OperatorName::ExpressionWatermark, + ExpressionWatermarkConfig { period_micros: 1_000_000, idle_time_micros: None, expression: expression.encode_to_vec(), input_schema: Some(self.arroyo_schema().into()), } .encode_to_vec(), - }; + "watermark".to_string(), + 1, + ); + let incoming_edge = LogicalEdge::project_all(LogicalEdgeType::Forward, input_schemas[0].as_ref().clone()); Ok(NodeWithIncomingEdges { diff --git a/crates/arroyo-planner/src/extension/window_fn.rs b/crates/arroyo-planner/src/extension/window_fn.rs index 4c8cab17e..ebc6a3fc6 100644 --- a/crates/arroyo-planner/src/extension/window_fn.rs +++ 
b/crates/arroyo-planner/src/extension/window_fn.rs @@ -96,13 +96,14 @@ impl ArroyoExtension for WindowFunctionExtension { window_function_plan: window_plan_proto.encode_to_vec(), }; - let logical_node = LogicalNode { - operator_id: format!("window_function_{}", index), - description: "window function".to_string(), - operator_name: OperatorName::WindowFunction, - operator_config: config.encode_to_vec(), - parallelism: 1, - }; + let logical_node = LogicalNode::single( + index as u32, + format!("window_function_{}", index), + OperatorName::WindowFunction, + config.encode_to_vec(), + "window function".to_string(), + 1, + ); let edge = arroyo_datastream::logical::LogicalEdge::project_all( // TODO: detect when this shuffle is unnecessary diff --git a/crates/arroyo-planner/src/lib.rs b/crates/arroyo-planner/src/lib.rs index a1fa3379d..980cd2880 100644 --- a/crates/arroyo-planner/src/lib.rs +++ b/crates/arroyo-planner/src/lib.rs @@ -62,6 +62,7 @@ use crate::rewriters::{SourceMetadataVisitor, TimeWindowUdfChecker, UnnestRewrit use crate::udafs::EmptyUdaf; use arrow::compute::kernels::cast_utils::parse_interval_day_time; use arroyo_datastream::logical::LogicalProgram; +use arroyo_datastream::optimizers::ChainingOptimizer; use arroyo_operator::connector::Connection; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::TIMESTAMP_FIELD; @@ -851,7 +852,7 @@ pub async fn parse_and_get_arrow_program( } let graph = plan_to_graph_visitor.into_graph(); - let program = LogicalProgram::new( + let mut program = LogicalProgram::new( graph, ProgramConfig { udf_dylibs: schema_provider.dylib_udfs.clone(), @@ -859,6 +860,10 @@ pub async fn parse_and_get_arrow_program( }, ); + if arroyo_rpc::config::config().pipeline.chaining.enabled { + program.optimize(&ChainingOptimizer {}); + } + Ok(CompiledSql { program, connection_ids: used_connections.into_iter().collect(), diff --git a/crates/arroyo-rpc/default.toml b/crates/arroyo-rpc/default.toml index 13c81b537..f13ea7d79 100644 --- 
a/crates/arroyo-rpc/default.toml +++ b/crates/arroyo-rpc/default.toml @@ -10,6 +10,7 @@ worker-heartbeat-timeout = "30s" healthy-duration = "2m" worker-startup-time = "10m" task-startup-time = "2m" +chaining.enabled = false [pipeline.compaction] enabled = false diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index a20a2c4cd..24c120c11 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -243,7 +243,6 @@ message OperatorCheckpointDetail { } - message DylibUdfConfig { string dylib_path = 1; repeated bytes arg_types = 2; @@ -279,13 +278,19 @@ message ArroyoSchema { bool has_keys = 4; } +message ChainedOperator { + string operator_id = 1; + string operator_name = 2; + bytes operator_config = 3; +} + message ArrowNode { int32 node_index = 1; - string node_id = 2; + uint32 node_id = 2; uint32 parallelism = 3; string description = 4; - string operator_name = 5; - bytes operator_config = 6; + repeated ChainedOperator operators = 5; + repeated ArroyoSchema edges = 6; } message ArrowEdge { @@ -293,5 +298,4 @@ message ArrowEdge { int32 target = 2; ArroyoSchema schema = 4; EdgeType edge_type = 5; - repeated uint32 projection = 6; } diff --git a/crates/arroyo-rpc/proto/rpc.proto b/crates/arroyo-rpc/proto/rpc.proto index 2f8fdcb3d..a921d43f0 100644 --- a/crates/arroyo-rpc/proto/rpc.proto +++ b/crates/arroyo-rpc/proto/rpc.proto @@ -49,6 +49,7 @@ message TaskCheckpointEventReq { uint64 worker_id = 1; uint64 time = 2; string job_id = 3; + uint32 node_id = 8; string operator_id = 4; uint32 subtask_index = 5; uint32 epoch = 6; @@ -62,6 +63,7 @@ message TaskCheckpointCompletedReq { uint64 worker_id = 1; uint64 time = 2; string job_id = 3; + uint32 node_id = 8; string operator_id = 4; uint32 epoch = 5; SubtaskCheckpointMetadata metadata = 6; @@ -75,7 +77,7 @@ message TaskFinishedReq { uint64 worker_id = 1; uint64 time = 2; string job_id = 3; - string operator_id = 4; + uint32 node_id = 4; uint64 operator_subtask = 
5; } @@ -86,7 +88,7 @@ message TaskFailedReq { uint64 worker_id = 1; uint64 time = 2; string job_id = 3; - string operator_id = 4; + uint32 node_id = 4; uint64 operator_subtask = 5; string error = 6; } @@ -99,7 +101,7 @@ message TaskStartedReq { uint64 worker_id = 1; uint64 time = 2; string job_id = 3; - string operator_id = 4; + uint32 node_id = 4; uint64 operator_subtask = 5; } @@ -161,6 +163,7 @@ message OutputData { message WorkerErrorReq { string job_id = 1; + uint32 node_id = 6; string operator_id = 2; uint32 task_index = 3; string message = 4; @@ -319,8 +322,8 @@ message ArroyoSchema { // Worker message TaskAssignment { - string operator_id = 1; - uint64 operator_subtask = 2; + uint32 node_id = 1; + uint32 subtask_idx = 2; uint64 worker_id = 4; string worker_addr = 5; } @@ -365,6 +368,7 @@ message TableCommitData { } message LoadCompactedDataReq { + uint32 node_id = 3; string operator_id = 1; map compacted_metadata = 2; } diff --git a/crates/arroyo-rpc/src/api_types/metrics.rs b/crates/arroyo-rpc/src/api_types/metrics.rs index cef688d8c..9bbe21fb7 100644 --- a/crates/arroyo-rpc/src/api_types/metrics.rs +++ b/crates/arroyo-rpc/src/api_types/metrics.rs @@ -41,6 +41,6 @@ pub struct MetricGroup { #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct OperatorMetricGroup { - pub operator_id: String, + pub node_id: u32, pub metric_groups: Vec, } diff --git a/crates/arroyo-rpc/src/api_types/pipelines.rs b/crates/arroyo-rpc/src/api_types/pipelines.rs index daf820413..5a43ec4ce 100644 --- a/crates/arroyo-rpc/src/api_types/pipelines.rs +++ b/crates/arroyo-rpc/src/api_types/pipelines.rs @@ -77,7 +77,7 @@ pub struct PipelineGraph { #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct PipelineNode { - pub node_id: String, + pub node_id: u32, pub operator: String, pub description: String, pub parallelism: u32, @@ -86,8 +86,8 @@ pub struct PipelineNode { #[derive(Serialize, 
Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct PipelineEdge { - pub src_id: String, - pub dest_id: String, + pub src_id: u32, + pub dest_id: u32, pub key_type: String, pub value_type: String, pub edge_type: String, diff --git a/crates/arroyo-rpc/src/config.rs b/crates/arroyo-rpc/src/config.rs index 9dc0442b2..0a3d43c2a 100644 --- a/crates/arroyo-rpc/src/config.rs +++ b/crates/arroyo-rpc/src/config.rs @@ -427,9 +427,18 @@ pub struct PipelineConfig { #[serde(default)] pub default_sink: DefaultSink, + pub chaining: ChainingConfig, + pub compaction: CompactionConfig, } +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct ChainingConfig { + /// Whether to enable operator chaining + pub enabled: bool, +} + #[derive(Debug, Deserialize, Serialize, Eq, PartialEq, Clone)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub enum DatabaseType { diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index da27c32d9..5bdb01bf2 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -91,6 +91,7 @@ impl From for CompactionResult { #[derive(Debug, Clone)] pub struct CheckpointCompleted { pub checkpoint_epoch: u32, + pub node_id: u32, pub operator_id: String, pub subtask_metadata: SubtaskCheckpointMetadata, } @@ -98,6 +99,7 @@ pub struct CheckpointCompleted { #[derive(Debug, Clone)] pub struct CheckpointEvent { pub checkpoint_epoch: u32, + pub node_id: u32, pub operator_id: String, pub subtask_index: u32, pub time: SystemTime, @@ -109,20 +111,21 @@ pub enum ControlResp { CheckpointEvent(CheckpointEvent), CheckpointCompleted(CheckpointCompleted), TaskStarted { - operator_id: String, + node_id: u32, task_index: usize, start_time: SystemTime, }, TaskFinished { - operator_id: String, + node_id: u32, task_index: usize, }, TaskFailed { - operator_id: String, + node_id: u32, task_index: usize, error: String, }, Error { + node_id: u32, 
operator_id: String, task_index: usize, message: String, diff --git a/crates/arroyo-sql-testing/src/smoke_tests.rs b/crates/arroyo-sql-testing/src/smoke_tests.rs index b1db2f309..82fdc2f5e 100644 --- a/crates/arroyo-sql-testing/src/smoke_tests.rs +++ b/crates/arroyo-sql-testing/src/smoke_tests.rs @@ -12,7 +12,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; use std::{env, time::SystemTime}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::{channel, Receiver}; use crate::udfs::get_udfs; use arroyo_rpc::config; @@ -21,7 +21,7 @@ use arroyo_rpc::{CompactionResult, ControlMessage, ControlResp}; use arroyo_state::checkpoint_state::CheckpointState; use arroyo_types::{to_micros, CheckpointBarrier}; use arroyo_udf_host::LocalUdf; -use arroyo_worker::engine::{Engine, StreamConfig}; +use arroyo_worker::engine::Engine; use arroyo_worker::engine::{Program, RunningEngine}; use petgraph::{Direction, Graph}; use serde_json::Value; @@ -151,6 +151,7 @@ async fn checkpoint(ctx: &mut SmokeTestContext<'_>, epoch: u32) { worker_id: 1, time: to_micros(c.time), job_id: (*ctx.job_id).clone(), + node_id: c.node_id, operator_id: c.operator_id, subtask_index: c.subtask_index, epoch: c.checkpoint_epoch, @@ -163,6 +164,7 @@ async fn checkpoint(ctx: &mut SmokeTestContext<'_>, epoch: u32) { worker_id: 1, time: c.subtask_metadata.finish_time, job_id: (*ctx.job_id).clone(), + node_id: c.node_id, operator_id: c.operator_id, epoch: c.checkpoint_epoch, needs_commit: false, @@ -185,12 +187,14 @@ async fn compact( tasks_per_operator: HashMap, epoch: u32, ) { + let operator_to_node = running_engine.operator_to_node(); let operator_controls = running_engine.operator_controls(); for (operator, _) in tasks_per_operator { if let Ok(compacted) = - ParquetBackend::compact_operator(job_id.clone(), operator.clone(), epoch).await + ParquetBackend::compact_operator(job_id.clone(), &operator, epoch).await { - let operator_controls = 
operator_controls.get(&operator).unwrap(); + let node_id = operator_to_node.get(&operator).unwrap(); + let operator_controls = operator_controls.get(node_id).unwrap(); for s in operator_controls { s.send(ControlMessage::LoadCompacted { compacted: CompactionResult { @@ -229,31 +233,44 @@ fn set_internal_parallelism(graph: &mut Graph, paralle let watermark_nodes: HashSet<_> = graph .node_indices() .filter(|index| { - let operator_name = graph.node_weight(*index).unwrap().operator_name; - matches!(operator_name, OperatorName::ExpressionWatermark) + graph + .node_weight(*index) + .unwrap() + .operator_chain + .iter() + .any(|(c, _)| c.operator_name == OperatorName::ExpressionWatermark) }) .collect(); + let indices: Vec<_> = graph .node_indices() - .filter( - |index| match graph.node_weight(*index).unwrap().operator_name { - OperatorName::ExpressionWatermark - | OperatorName::ConnectorSource - | OperatorName::ConnectorSink => false, - _ => { - for watermark_node in watermark_nodes.iter() { - if has_path_connecting(&graph.clone(), *watermark_node, *index, None) { - return true; + .filter(|index| { + !watermark_nodes.contains(index) + && graph + .node_weight(*index) + .unwrap() + .operator_chain + .iter() + .any(|(c, _)| match c.operator_name { + OperatorName::ExpressionWatermark + | OperatorName::ConnectorSource + | OperatorName::ConnectorSink => false, + _ => { + for watermark_node in watermark_nodes.iter() { + if has_path_connecting(&*graph, *watermark_node, *index, None) { + return true; + } + } + false } - } - false - } - }, - ) + }) + }) .collect(); + for node in indices { graph.node_weight_mut(node).unwrap().parallelism = parallelism; } + if parallelism > 1 { let mut edges_to_make_shuffle = vec![]; for node in graph.externals(Direction::Outgoing) { @@ -262,7 +279,13 @@ fn set_internal_parallelism(graph: &mut Graph, paralle } } for node in graph.node_indices() { - if graph.node_weight(node).unwrap().operator_name == OperatorName::ExpressionWatermark { + if graph + 
.node_weight(node) + .unwrap() + .operator_chain + .iter() + .any(|(c, _)| c.operator_name == OperatorName::ExpressionWatermark) + { for edge in graph.edges_directed(node, Direction::Outgoing) { edges_to_make_shuffle.push(edge.id()); } @@ -274,14 +297,15 @@ fn set_internal_parallelism(graph: &mut Graph, paralle } } -async fn run_and_checkpoint(job_id: Arc, program: Program, checkpoint_interval: i32) { - let tasks_per_operator = program.tasks_per_operator(); +async fn run_and_checkpoint( + job_id: Arc, + program: Program, + tasks_per_operator: HashMap, + control_rx: &mut Receiver, + checkpoint_interval: i32, +) { let engine = Engine::for_local(program, job_id.to_string()); - let (running_engine, mut control_rx) = engine - .start(StreamConfig { - restore_epoch: None, - }) - .await; + let running_engine = engine.start().await; info!("Smoke test checkpointing enabled"); env::set_var( "ARROYO__CONTROLLER__COMPACTION__CHECKPOINTS_TO_COMPACT", @@ -291,7 +315,7 @@ async fn run_and_checkpoint(job_id: Arc, program: Program, checkpoint_in let ctx = &mut SmokeTestContext { job_id: job_id.clone(), engine: &running_engine, - control_rx: &mut control_rx, + control_rx, tasks_per_operator: tasks_per_operator.clone(), }; @@ -316,21 +340,33 @@ async fn run_and_checkpoint(job_id: Arc, program: Program, checkpoint_in .await .unwrap(); } - run_until_finished(&running_engine, &mut control_rx).await; + run_until_finished(&running_engine, control_rx).await; } -async fn finish_from_checkpoint(job_id: &str, program: Program) { +async fn finish_from_checkpoint( + job_id: &str, + program: Program, + control_rx: &mut Receiver, +) { let engine = Engine::for_local(program, job_id.to_string()); - let (running_engine, mut control_rx) = engine - .start(StreamConfig { - restore_epoch: Some(3), - }) - .await; + let running_engine = engine.start().await; info!("Restored engine, running until finished"); - run_until_finished(&running_engine, &mut control_rx).await; + run_until_finished(&running_engine, 
control_rx).await; +} + +fn tasks_per_operator(graph: &LogicalGraph) -> HashMap { + graph + .node_weights() + .flat_map(|node| { + node.operator_chain + .iter() + .map(|(op, _)| (op.operator_id.clone(), node.parallelism)) + }) + .collect() } +#[allow(clippy::too_many_arguments)] async fn run_pipeline_and_assert_outputs( job_id: &str, mut graph: LogicalGraph, @@ -345,15 +381,16 @@ async fn run_pipeline_and_assert_outputs( std::fs::remove_file(&output_location).unwrap(); } - let get_program = - |graph: &LogicalGraph| Program::local_from_logical(job_id.to_string(), graph, udfs); + println!("Running completely"); + let (control_tx, mut control_rx) = channel(128); run_completely( job_id, - get_program(&graph), + Program::local_from_logical(job_id.to_string(), &graph, udfs, None, control_tx).await, output_location.clone(), golden_output_location.clone(), primary_keys, + &mut control_rx, ) .await; @@ -363,9 +400,14 @@ async fn run_pipeline_and_assert_outputs( set_internal_parallelism(&mut graph, 2); } + let (control_tx, mut control_rx) = channel(128); + + println!("Run and checkpoint"); run_and_checkpoint( Arc::new(job_id.to_string()), - get_program(&graph), + Program::local_from_logical(job_id.to_string(), &graph, udfs, None, control_tx).await, + tasks_per_operator(&graph), + &mut control_rx, checkpoint_interval, ) .await; @@ -374,7 +416,15 @@ async fn run_pipeline_and_assert_outputs( set_internal_parallelism(&mut graph, 3); } - finish_from_checkpoint(job_id, get_program(&graph)).await; + let (control_tx, mut control_rx) = channel(128); + + println!("Finish from checkpoint"); + finish_from_checkpoint( + job_id, + Program::local_from_logical(job_id.to_string(), &graph, udfs, Some(3), control_tx).await, + &mut control_rx, + ) + .await; check_output_files( "resuming from checkpointing", @@ -391,15 +441,12 @@ async fn run_completely( output_location: String, golden_output_location: String, primary_keys: Option<&[&str]>, + control_rx: &mut Receiver, ) { let engine = 
Engine::for_local(program, job_id.to_string()); - let (running_engine, mut control_rx) = engine - .start(StreamConfig { - restore_epoch: None, - }) - .await; + let running_engine = engine.start().await; - run_until_finished(&running_engine, &mut control_rx).await; + run_until_finished(&running_engine, control_rx).await; check_output_files( "initial run", diff --git a/crates/arroyo-state/src/committing_state.rs b/crates/arroyo-state/src/committing_state.rs index ef09b80f0..12e5138c1 100644 --- a/crates/arroyo-state/src/committing_state.rs +++ b/crates/arroyo-state/src/committing_state.rs @@ -40,12 +40,13 @@ impl CommittingState { .iter() .map(|(operator_id, _subtask_id)| operator_id.clone()) .collect(); + operators_to_commit .into_iter() - .map(|operator_id| { + .map(|node_id| { let committing_data = self .committing_data - .get(&operator_id) + .get(&node_id) .map(|table_map| { table_map .iter() @@ -60,7 +61,7 @@ impl CommittingState { .collect() }) .unwrap_or_default(); - (operator_id, OperatorCommitData { committing_data }) + (node_id, OperatorCommitData { committing_data }) }) .collect() } diff --git a/crates/arroyo-state/src/metrics.rs b/crates/arroyo-state/src/metrics.rs index 0816d9e6e..793dca9eb 100644 --- a/crates/arroyo-state/src/metrics.rs +++ b/crates/arroyo-state/src/metrics.rs @@ -2,15 +2,14 @@ use lazy_static::lazy_static; use prometheus::{register_gauge_vec, GaugeVec}; lazy_static! 
{ - pub static ref WORKER_LABELS_NAMES: Vec<&'static str> = vec!["operator_id", "task_id"]; + pub static ref WORKER_LABELS_NAMES: Vec<&'static str> = vec!["node_id", "task_id"]; pub static ref CURRENT_FILES_GAUGE: GaugeVec = register_gauge_vec!( "arroyo_worker_current_files", "Number of parquet files in the checkpoint", &WORKER_LABELS_NAMES ) .unwrap(); - pub static ref TABLE_LABELS_NAMES: Vec<&'static str> = - vec!["operator_id", "task_id", "table_char"]; + pub static ref TABLE_LABELS_NAMES: Vec<&'static str> = vec!["node_id", "task_id", "table_char"]; pub static ref TABLE_SIZE_GAUGE: GaugeVec = register_gauge_vec!( "arroyo_worker_table_size_keys", "Number of keys in the table", diff --git a/crates/arroyo-state/src/parquet.rs b/crates/arroyo-state/src/parquet.rs index ca772dd01..d829bb252 100644 --- a/crates/arroyo-state/src/parquet.rs +++ b/crates/arroyo-state/src/parquet.rs @@ -158,13 +158,13 @@ impl ParquetBackend { /// Called after a checkpoint is committed pub async fn compact_operator( job_id: Arc, - operator_id: String, + operator_id: &str, epoch: u32, ) -> Result> { let min_files_to_compact = config().pipeline.compaction.checkpoints_to_compact as usize; let operator_checkpoint_metadata = - Self::load_operator_metadata(&job_id, &operator_id, epoch) + Self::load_operator_metadata(&job_id, operator_id, epoch) .await? 
.expect("expect operator metadata to still be present"); let storage_provider = get_storage_provider().await?; diff --git a/crates/arroyo-state/src/tables/expiring_time_key_map.rs b/crates/arroyo-state/src/tables/expiring_time_key_map.rs index 70476de66..f9f8fd422 100644 --- a/crates/arroyo-state/src/tables/expiring_time_key_map.rs +++ b/crates/arroyo-state/src/tables/expiring_time_key_map.rs @@ -24,7 +24,7 @@ use arroyo_rpc::{ }; use arroyo_storage::StorageProviderRef; use arroyo_types::{ - from_micros, from_nanos, print_time, server_for_hash, to_micros, to_nanos, TaskInfoRef, + from_micros, from_nanos, print_time, server_for_hash, to_micros, to_nanos, TaskInfo, }; use datafusion::parquet::arrow::async_reader::ParquetObjectReader; @@ -50,7 +50,7 @@ use super::{table_checkpoint_path, CompactionConfig, Table, TableEpochCheckpoint #[derive(Debug, Clone)] pub struct ExpiringTimeKeyTable { table_name: String, - task_info: TaskInfoRef, + task_info: Arc, schema: SchemaWithHashAndOperation, retention: Duration, storage_provider: StorageProviderRef, @@ -258,7 +258,7 @@ impl Table for ExpiringTimeKeyTable { fn from_config( config: Self::ConfigMessage, - task_info: arroyo_types::TaskInfoRef, + task_info: Arc, storage_provider: arroyo_storage::StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result { @@ -343,7 +343,7 @@ impl Table for ExpiringTimeKeyTable { table_metadata: Self::TableCheckpointMessage, ) -> anyhow::Result> { Ok(Some(ExpiringKeyedTimeSubtaskCheckpointMetadata { - subtask_index: self.task_info.task_index as u32, + subtask_index: self.task_info.task_index, watermark: None, files: table_metadata.files, })) @@ -353,7 +353,7 @@ impl Table for ExpiringTimeKeyTable { TableEnum::ExpiringKeyedTimeTable } - fn task_info(&self) -> TaskInfoRef { + fn task_info(&self) -> Arc { self.task_info.clone() } @@ -641,7 +641,7 @@ impl ExpiringTimeKeyTableCheckpointer { &parent.task_info.job_id, &parent.task_info.operator_id, &parent.table_name, - 
parent.task_info.task_index, + parent.task_info.task_index as usize, epoch, false, ); @@ -736,7 +736,7 @@ impl TableEpochCheckpointer for ExpiringTimeKeyTableCheckpointer { } else { Ok(Some(( ExpiringKeyedTimeSubtaskCheckpointMetadata { - subtask_index: self.parent.task_info.task_index as u32, + subtask_index: self.parent.task_info.task_index, watermark: checkpoint.watermark.map(to_micros), files, }, @@ -750,7 +750,7 @@ impl TableEpochCheckpointer for ExpiringTimeKeyTableCheckpointer { } fn subtask_index(&self) -> u32 { - self.parent.task_info.task_index as u32 + self.parent.task_info.task_index } } diff --git a/crates/arroyo-state/src/tables/global_keyed_map.rs b/crates/arroyo-state/src/tables/global_keyed_map.rs index 275fd841a..2f359b0d9 100644 --- a/crates/arroyo-state/src/tables/global_keyed_map.rs +++ b/crates/arroyo-state/src/tables/global_keyed_map.rs @@ -7,7 +7,7 @@ use arroyo_rpc::grpc::rpc::{ OperatorMetadata, TableEnum, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::{to_micros, Data, Key, TaskInfoRef}; +use arroyo_types::{to_micros, Data, Key, TaskInfo}; use bincode::config; use once_cell::sync::Lazy; @@ -41,7 +41,7 @@ static GLOBAL_KEY_VALUE_SCHEMA: Lazy> = Lazy::new(|| { #[derive(Debug, Clone)] pub struct GlobalKeyedTable { table_name: String, - pub task_info: TaskInfoRef, + pub task_info: Arc, storage_provider: StorageProviderRef, pub files: Vec, } @@ -125,7 +125,7 @@ impl Table for GlobalKeyedTable { fn from_config( config: Self::ConfigMessage, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result { @@ -184,7 +184,7 @@ impl Table for GlobalKeyedTable { TableEnum::GlobalKeyValue } - fn task_info(&self) -> TaskInfoRef { + fn task_info(&self) -> Arc { self.task_info.clone() } @@ -194,12 +194,13 @@ impl Table for GlobalKeyedTable { ) -> Result> { Ok(checkpoint.files.into_iter().collect()) } + fn committing_data( config: Self::ConfigMessage, table_metadata: 
Self::TableCheckpointMessage, ) -> Option>> { if config.uses_two_phase_commit { - Some(table_metadata.commit_data_by_subtask.clone()) + Some(table_metadata.commit_data_by_subtask) } else { None } @@ -227,7 +228,7 @@ impl Table for GlobalKeyedTable { pub struct GlobalKeyedCheckpointer { table_name: String, epoch: u32, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, latest_values: BTreeMap, Vec>, commit_data: Option>, @@ -288,7 +289,7 @@ impl TableEpochCheckpointer for GlobalKeyedCheckpointer { &self.task_info.job_id, &self.task_info.operator_id, &self.table_name, - self.task_info.task_index, + self.task_info.task_index as usize, self.epoch, false, ); @@ -298,7 +299,7 @@ impl TableEpochCheckpointer for GlobalKeyedCheckpointer { let _finish_time = to_micros(SystemTime::now()); Ok(Some(( GlobalKeyedTableSubtaskCheckpointMetadata { - subtask_index: self.task_info.task_index as u32, + subtask_index: self.task_info.task_index, commit_data: self.commit_data, file: Some(path), }, @@ -311,7 +312,7 @@ impl TableEpochCheckpointer for GlobalKeyedCheckpointer { } fn subtask_index(&self) -> u32 { - self.task_info.task_index as u32 + self.task_info.task_index } } diff --git a/crates/arroyo-state/src/tables/mod.rs b/crates/arroyo-state/src/tables/mod.rs index 6777b6f89..68185be61 100644 --- a/crates/arroyo-state/src/tables/mod.rs +++ b/crates/arroyo-state/src/tables/mod.rs @@ -5,10 +5,11 @@ use arroyo_rpc::grpc::rpc::{ TableSubtaskCheckpointMetadata, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::TaskInfoRef; +use arroyo_types::TaskInfo; use prost::Message; use std::any::Any; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use std::time::SystemTime; use tracing::debug; @@ -79,7 +80,7 @@ pub(crate) trait Table: Send + Sync + 'static + Clone { // * checkpoint_message: If restoring from a checkpoint, the checkpoint data for that checkpoint's epoch. 
fn from_config( config: Self::ConfigMessage, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result @@ -117,7 +118,7 @@ pub(crate) trait Table: Send + Sync + 'static + Clone { fn table_type() -> TableEnum; - fn task_info(&self) -> TaskInfoRef; + fn task_info(&self) -> Arc; fn files_to_keep( config: Self::ConfigMessage, @@ -155,7 +156,7 @@ pub trait ErasedTable: Send + Sync + 'static { // * checkpoint_message: If restoring from a checkpoint, the checkpoint data for that checkpoint's epoch. fn from_config( config: TableConfig, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result @@ -241,7 +242,7 @@ pub trait ErasedTable: Send + Sync + 'static { impl ErasedTable for T { fn from_config( config: TableConfig, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result @@ -302,7 +303,7 @@ impl ErasedTable for T { let subtask_metadata = self.subtask_metadata_from_table(table_metadata)?; Ok( subtask_metadata.map(|metadata| TableSubtaskCheckpointMetadata { - subtask_index: self.task_info().task_index as u32, + subtask_index: self.task_info().task_index, table_type: T::table_type().into(), data: metadata.encode_to_vec(), }), @@ -324,7 +325,7 @@ impl ErasedTable for T { let result = self.apply_compacted_checkpoint(epoch, compacted_checkpoint, subtask_metadata)?; Ok(TableSubtaskCheckpointMetadata { - subtask_index: self.task_info().task_index as u32, + subtask_index: self.task_info().task_index, table_type: T::table_type().into(), data: result.encode_to_vec(), }) diff --git a/crates/arroyo-state/src/tables/table_manager.rs b/crates/arroyo-state/src/tables/table_manager.rs index 58614ef5b..41eddaaf4 100644 --- a/crates/arroyo-state/src/tables/table_manager.rs +++ b/crates/arroyo-state/src/tables/table_manager.rs @@ -6,22 +6,24 @@ use anyhow::{anyhow, bail, 
Result}; use arroyo_rpc::CompactionResult; use arroyo_rpc::{ grpc::rpc::{ - OperatorCheckpointMetadata, SubtaskCheckpointMetadata, TableConfig, TableEnum, - TableSubtaskCheckpointMetadata, + SubtaskCheckpointMetadata, TableConfig, TableEnum, TableSubtaskCheckpointMetadata, }, CheckpointCompleted, ControlResp, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::{to_micros, CheckpointBarrier, Data, Key, TaskInfoRef}; +use arroyo_types::{from_micros, to_micros, CheckpointBarrier, Data, Key, TaskInfo}; use tokio::sync::{ mpsc::{self, Receiver, Sender}, oneshot, }; -use tracing::{debug, error, info, warn}; - -use crate::{get_storage_provider, tables::global_keyed_map::GlobalKeyedTable, StateMessage}; +use crate::{ + get_storage_provider, tables::global_keyed_map::GlobalKeyedTable, BackingStore, StateBackend, + StateMessage, +}; use crate::{CheckpointMessage, TableData}; +use arroyo_rpc::grpc::rpc::CheckpointMetadata; +use tracing::{debug, error, info, warn}; use super::expiring_time_key_map::{ ExpiringTimeKeyTable, ExpiringTimeKeyView, KeyTimeView, LastKeyValueView, @@ -34,9 +36,9 @@ pub struct TableManager { epoch: u32, min_epoch: u32, // ordered by table, then epoch. 
- tables: HashMap>>, + tables: HashMap>, writer: BackendWriter, - task_info: TaskInfoRef, + task_info: Arc, storage: StorageProviderRef, caches: HashMap>, } @@ -53,8 +55,8 @@ pub struct BackendFlusher { storage: StorageProviderRef, control_tx: Sender, finish_tx: Option>, - task_info: TaskInfoRef, - tables: HashMap>>, + task_info: Arc, + tables: HashMap>, table_configs: HashMap, table_checkpointers: HashMap>, current_epoch: u32, @@ -75,8 +77,8 @@ impl BackendFlusher { error!("Failed to flush state file: {:?}", err); self.control_tx .send(ControlResp::TaskFailed { - operator_id: self.task_info.operator_id.clone(), - task_index: self.task_info.task_index, + node_id: self.task_info.node_id, + task_index: self.task_info.task_index as usize, error: err.to_string(), }) .await @@ -163,7 +165,7 @@ impl BackendFlusher { // send controller the subtask metadata let subtask_metadata = SubtaskCheckpointMetadata { - subtask_index: self.task_info.task_index as u32, + subtask_index: self.task_info.task_index, start_time: to_micros(cp.time), finish_time: to_micros(SystemTime::now()), watermark: cp.watermark.map(to_micros), @@ -174,6 +176,7 @@ impl BackendFlusher { self.control_tx .send(ControlResp::CheckpointCompleted(CheckpointCompleted { checkpoint_epoch: cp.epoch, + node_id: self.task_info.node_id, operator_id: self.task_info.operator_id.clone(), subtask_metadata, })) @@ -192,10 +195,10 @@ impl BackendFlusher { impl BackendWriter { fn new( - task_info: TaskInfoRef, + task_info: Arc, control_tx: Sender, table_configs: HashMap, - tables: HashMap>>, + tables: HashMap>, storage: StorageProviderRef, current_epoch: u32, last_epoch_checkpoints: HashMap, @@ -225,12 +228,38 @@ impl BackendWriter { } impl TableManager { - pub async fn new( - task_info: TaskInfoRef, + pub async fn load( + task_info: Arc, table_configs: HashMap, tx: Sender, - checkpoint_metadata: Option, - ) -> Result { + restore_from: Option<&CheckpointMetadata>, + ) -> Result<(Self, Option)> { + let (watermark, 
checkpoint_metadata) = if let Some(metadata) = restore_from { + let (watermark, operator_metadata) = { + let metadata = StateBackend::load_operator_metadata( + &task_info.job_id, + &task_info.operator_id, + metadata.epoch, + ) + .await + .expect("lookup should succeed") + .expect("require metadata"); + ( + metadata + .operator_metadata + .as_ref() + .unwrap() + .min_watermark + .map(from_micros), + metadata, + ) + }; + + (watermark, Some(operator_metadata)) + } else { + (None, None) + }; + let storage = get_storage_provider().await?; let tables = table_configs @@ -242,23 +271,23 @@ impl TableManager { let erased_table = match table_config.table_type() { TableEnum::MissingTableType => bail!("should have table type"), TableEnum::GlobalKeyValue => { - Box::new(::from_config( + Arc::new(::from_config( table_config.clone(), task_info.clone(), storage.clone(), table_restore_from, - )?) as Box + )?) as Arc } TableEnum::ExpiringKeyedTimeTable => { - Box::new(::from_config( + Arc::new(::from_config( table_config.clone(), task_info.clone(), storage.clone(), table_restore_from, - )?) as Box + )?) 
as Arc } }; - Ok((table_name.to_string(), Arc::new(erased_table))) + Ok((table_name.to_string(), erased_table)) }) .collect::>>()?; @@ -299,15 +328,18 @@ impl TableManager { epoch, last_epoch_checkpoints, ); - Ok(Self { - epoch, - min_epoch, - tables, - writer, - task_info, - storage: Arc::clone(storage), - caches: HashMap::new(), - }) + Ok(( + Self { + epoch, + min_epoch, + tables, + writer, + task_info, + storage: Arc::clone(storage), + caches: HashMap::new(), + }, + watermark, + )) } pub async fn checkpoint(&mut self, barrier: CheckpointBarrier, watermark: Option) { @@ -330,13 +362,13 @@ impl TableManager { } } - pub async fn load_compacted(&mut self, compacted: CompactionResult) -> Result<()> { + pub async fn load_compacted(&mut self, compacted: &CompactionResult) -> Result<()> { if compacted.operator_id != self.task_info.operator_id { bail!("shouldn't be loading compaction for other operator"); } self.writer .sender - .send(StateMessage::Compaction(compacted.compacted_tables)) + .send(StateMessage::Compaction(compacted.compacted_tables.clone())) .await?; Ok(()) } diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs index 95b37bb36..bcc5df76d 100644 --- a/crates/arroyo-types/src/lib.rs +++ b/crates/arroyo-types/src/lib.rs @@ -8,7 +8,6 @@ use std::convert::TryFrom; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use std::ops::{Range, RangeInclusive}; -use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Copy, Hash, Debug, Clone, Eq, PartialEq, Encode, Decode, PartialOrd, Ord, Deserialize)] @@ -372,23 +371,60 @@ pub trait RecordBatchBuilder: Default + Debug + Sync + Send + 'static { fn schema(&self) -> SchemaRef; } -/// A reference-counted reference to a [TaskInfo]. 
-pub type TaskInfoRef = Arc; - #[derive(Eq, PartialEq, Hash, Debug, Clone, Encode, Decode)] pub struct TaskInfo { pub job_id: String, + pub node_id: u32, pub operator_name: String, pub operator_id: String, - pub task_index: usize, - pub parallelism: usize, + pub task_index: u32, + pub parallelism: u32, pub key_range: RangeInclusive, } +impl Display for TaskInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Task_{}-{}/{}", + self.operator_id, self.task_index, self.parallelism + ) + } +} + +#[derive(Eq, PartialEq, Hash, Debug, Clone, Encode, Decode)] +pub struct ChainInfo { + pub job_id: String, + pub node_id: u32, + pub description: String, + pub task_index: u32, +} + +impl Display for ChainInfo { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "TaskChain{}-{} ({})", + self.node_id, self.task_index, self.description + ) + } +} + +impl ChainInfo { + pub fn metric_label_map(&self) -> HashMap { + let mut labels = HashMap::new(); + labels.insert("node_id".to_string(), self.node_id.to_string()); + labels.insert("subtask_idx".to_string(), self.task_index.to_string()); + labels.insert("node_description".to_string(), self.description.to_string()); + labels + } +} + impl TaskInfo { pub fn for_test(job_id: &str, operator_id: &str) -> Self { Self { job_id: job_id.to_string(), + node_id: 1, operator_name: "op".to_string(), operator_id: operator_id.to_string(), task_index: 0, @@ -396,19 +432,12 @@ impl TaskInfo { key_range: 0..=u64::MAX, } } - - pub fn metric_label_map(&self) -> HashMap { - let mut labels = HashMap::new(); - labels.insert("operator_id".to_string(), self.operator_id.clone()); - labels.insert("subtask_idx".to_string(), format!("{}", self.task_index)); - labels.insert("operator_name".to_string(), self.operator_name.clone()); - labels - } } pub fn get_test_task_info() -> TaskInfo { TaskInfo { job_id: "instance-1".to_string(), + node_id: 1, operator_name: "test-operator".to_string(), operator_id: 
"test-operator-1".to_string(), task_index: 0, diff --git a/crates/arroyo-worker/src/arrow/async_udf.rs b/crates/arroyo-worker/src/arrow/async_udf.rs index 5220a0d34..6ad69f1c7 100644 --- a/crates/arroyo-worker/src/arrow/async_udf.rs +++ b/crates/arroyo-worker/src/arrow/async_udf.rs @@ -4,14 +4,15 @@ use arrow_array::{make_array, Array, RecordBatch, UInt64Array}; use arrow_schema::{Field, Schema}; use arroyo_datastream::logical::DylibUdfConfig; use arroyo_df::ASYNC_RESULT_FIELD; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, + Registry, }; use arroyo_rpc::grpc::api; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_state::global_table_config; -use arroyo_types::{ArrowMessage, CheckpointBarrier, SignalMessage, Watermark}; +use arroyo_types::{CheckpointBarrier, SignalMessage, Watermark}; use arroyo_udf_host::AsyncUdfDylib; use async_trait::async_trait; use bincode::{Decode, Encode}; @@ -64,7 +65,7 @@ impl OperatorConstructor for AsyncUdfConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let udf_config: DylibUdfConfig = config .udf .clone() @@ -83,24 +84,26 @@ impl OperatorConstructor for AsyncUdfConstructor { ) })?; - Ok(OperatorNode::from_operator(Box::new(AsyncUdfOperator { - name: config.name.clone(), - udf: (&*udf).try_into()?, - ordered, - allowed_in_flight: config.max_concurrency, - timeout: Duration::from_micros(config.timeout_micros), - config, - registry, - input_exprs: vec![], - final_exprs: vec![], - next_id: 0, - inputs: BTreeMap::new(), - outputs: BTreeMap::new(), - watermarks: VecDeque::new(), - input_row_converter: RowConverter::new(vec![]).unwrap(), - output_row_converter: RowConverter::new(vec![]).unwrap(), - 
input_schema: None, - }))) + Ok(ConstructedOperator::from_operator(Box::new( + AsyncUdfOperator { + name: config.name.clone(), + udf: (&*udf).try_into()?, + ordered, + allowed_in_flight: config.max_concurrency, + timeout: Duration::from_micros(config.timeout_micros), + config, + registry, + input_exprs: vec![], + final_exprs: vec![], + next_id: 0, + inputs: BTreeMap::new(), + outputs: BTreeMap::new(), + watermarks: VecDeque::new(), + input_row_converter: RowConverter::new(vec![]).unwrap(), + output_row_converter: RowConverter::new(vec![]).unwrap(), + input_schema: None, + }, + ))) } } @@ -147,7 +150,7 @@ impl ArrowOperator for AsyncUdfOperator { global_table_config("a", "AsyncMapOperator state") } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { info!("Starting async UDF with timeout {:?}", self.timeout); self.input_row_converter = RowConverter::new( ctx.in_schemas[0] @@ -234,7 +237,8 @@ impl ArrowOperator for AsyncUdfOperator { gs.get_all() .iter() .filter(|(task_index, _)| { - **task_index % ctx.task_info.parallelism == ctx.task_info.task_index + **task_index % ctx.task_info.parallelism as usize + == ctx.task_info.task_index as usize }) .for_each(|(_, state)| { for (k, v) in &state.inputs { @@ -290,7 +294,12 @@ impl ArrowOperator for AsyncUdfOperator { Some(Duration::from_millis(50)) } - async fn process_batch(&mut self, batch: RecordBatch, _: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { let arg_batch: Vec<_> = self .input_exprs .iter() @@ -325,7 +334,12 @@ impl ArrowOperator for AsyncUdfOperator { } } - async fn handle_tick(&mut self, _: u64, ctx: &mut ArrowContext) { + async fn handle_tick( + &mut self, + _: u64, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let Some((ids, results)) = self .udf .drain_results() @@ -368,19 +382,25 @@ impl ArrowOperator for AsyncUdfOperator { 
self.inputs.remove(id); } - self.flush_output(ctx).await; + self.flush_output(ctx, collector).await; } async fn handle_watermark( &mut self, watermark: Watermark, - _ctx: &mut ArrowContext, + _: &mut OperatorContext, + _: &mut dyn Collector, ) -> Option { self.watermarks.push_back((self.next_id, watermark)); None } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let gs = ctx.table_manager.get_global_keyed_state("a").await.unwrap(); let state = AsyncUdfState { @@ -400,10 +420,15 @@ impl ArrowOperator for AsyncUdfOperator { gs.insert(ctx.task_info.task_index, state).await; } - async fn on_close(&mut self, final_message: &Option, ctx: &mut ArrowContext) { + async fn on_close( + &mut self, + final_message: &Option, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { if let Some(SignalMessage::EndOfData) = final_message { while !self.inputs.is_empty() && !self.outputs.is_empty() { - self.handle_tick(0, ctx).await; + self.handle_tick(0, ctx, collector).await; tokio::time::sleep(Duration::from_millis(50)).await; } } @@ -411,7 +436,7 @@ impl ArrowOperator for AsyncUdfOperator { } impl AsyncUdfOperator { - async fn flush_output(&mut self, ctx: &mut ArrowContext) { + async fn flush_output(&mut self, ctx: &mut OperatorContext, collector: &mut dyn Collector) { // check if we can emit any records -- these are ones received before our most recent // watermark -- once all records from before a watermark have been processed, we can // remove and emit the watermark @@ -451,7 +476,7 @@ impl AsyncUdfOperator { let batch = RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), cols) .expect("failed to construct record batch"); - ctx.collect(batch).await; + collector.collect(batch).await; let Some(watermark) = watermark else { break; @@ -461,8 +486,7 @@ impl AsyncUdfOperator { if 
watermark_id <= oldest_unprocessed { // we've processed everything before this watermark, we can emit and drop it - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark(watermark))) - .await; + collector.broadcast_watermark(watermark).await; } else { // we still have messages preceding this watermark to work on self.watermarks.push_front((watermark_id, watermark)); diff --git a/crates/arroyo-worker/src/arrow/instant_join.rs b/crates/arroyo-worker/src/arrow/instant_join.rs index a111a1a15..45dc474b7 100644 --- a/crates/arroyo-worker/src/arrow/instant_join.rs +++ b/crates/arroyo-worker/src/arrow/instant_join.rs @@ -3,9 +3,9 @@ use anyhow::Result; use arrow::compute::{max, min, partition, sort_to_indices, take}; use arrow_array::{RecordBatch, TimestampNanosecondArray}; use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ - ArrowOperator, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, ConstructedOperator, DisplayableOperator, OperatorConstructor, Registry, }; use arroyo_rpc::{ df::{ArroyoSchema, ArroyoSchemaRef}, @@ -110,7 +110,7 @@ impl InstantJoin { &mut self, side: Side, batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) -> Result<()> { let table = ctx .table_manager @@ -171,7 +171,7 @@ impl InstantJoin { async fn process_left( &mut self, record_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) -> Result<()> { self.process_side(Side::Left, record_batch, ctx).await } @@ -179,7 +179,7 @@ impl InstantJoin { async fn process_right( &mut self, right_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) -> Result<()> { self.process_side(Side::Right, right_batch, ctx).await } @@ -200,7 +200,7 @@ impl ArrowOperator for InstantJoin { } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async 
fn on_start(&mut self, ctx: &mut OperatorContext) { let watermark = ctx.last_present_watermark(); let left_table = ctx .table_manager @@ -232,15 +232,22 @@ impl ArrowOperator for InstantJoin { } } - async fn process_batch(&mut self, _record_batch: RecordBatch, _ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + _: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { unreachable!(); } + async fn process_batch_index( &mut self, index: usize, total_inputs: usize, record_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + _: &mut dyn Collector, ) { match index / (total_inputs / 2) { 0 => self @@ -256,11 +263,12 @@ impl ArrowOperator for InstantJoin { } async fn handle_watermark( &mut self, - int_watermark: Watermark, - ctx: &mut ArrowContext, + watermark: Watermark, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { let Some(watermark) = ctx.last_present_watermark() else { - return Some(int_watermark); + return Some(watermark); }; let futures_to_drain = { let mut futures_to_drain = vec![]; @@ -278,7 +286,7 @@ impl ArrowOperator for InstantJoin { while let (_time, Some((batch, new_exec))) = future.await { match batch { Ok(batch) => { - ctx.collect(batch).await; + collector.collect(batch).await; } Err(err) => { panic!("error in future: {:?}", err); @@ -287,10 +295,15 @@ impl ArrowOperator for InstantJoin { future = new_exec; } } - Some(int_watermark) + Some(Watermark::EventTime(watermark)) } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let watermark = ctx.last_present_watermark(); ctx.table_manager .get_expiring_time_key_table("left", watermark) @@ -346,7 +359,12 @@ impl ArrowOperator for InstantJoin { })) } - async fn handle_future_result(&mut self, result: Box, ctx: &mut ArrowContext) { + async fn handle_future_result( 
+ &mut self, + result: Box, + _: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let data: Box> = result.downcast().expect("invalid data in future"); if let Some((bin, batch_option)) = *data { match batch_option { @@ -356,7 +374,8 @@ impl ArrowOperator for InstantJoin { Some((batch, future)) => match self.execs.get_mut(&bin) { Some(exec) => { exec.active_exec = future.clone(); - ctx.collect(batch.expect("should compute batch in future")) + collector + .collect(batch.expect("should compute batch in future")) .await; self.futures.lock().await.push(future); } @@ -376,7 +395,7 @@ impl OperatorConstructor for InstantJoinConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; let left_input_schema: Arc = @@ -399,7 +418,7 @@ impl OperatorConstructor for InstantJoinConstructor { &codec, )?; - Ok(OperatorNode::from_operator(Box::new(InstantJoin { + Ok(ConstructedOperator::from_operator(Box::new(InstantJoin { left_input_schema, right_input_schema, execs: BTreeMap::new(), diff --git a/crates/arroyo-worker/src/arrow/join_with_expiration.rs b/crates/arroyo-worker/src/arrow/join_with_expiration.rs index b8e26ee7d..14e4b338a 100644 --- a/crates/arroyo-worker/src/arrow/join_with_expiration.rs +++ b/crates/arroyo-worker/src/arrow/join_with_expiration.rs @@ -2,9 +2,10 @@ use anyhow::Result; use arrow::compute::concat_batches; use arrow_array::RecordBatch; use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, + Registry, }; use arroyo_rpc::{ df::ArroyoSchema, @@ -41,7 +42,8 @@ impl 
JoinWithExpiration { async fn process_left( &mut self, record_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Result<()> { let left_table = ctx .table_manager @@ -70,7 +72,7 @@ impl JoinWithExpiration { self.compute_pair( self.left_input_schema.unkeyed_batch(&record_batch)?, right_batch, - ctx, + collector, ) .await; Ok(()) @@ -79,7 +81,8 @@ impl JoinWithExpiration { async fn process_right( &mut self, right_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Result<()> { let right_table = ctx .table_manager @@ -108,7 +111,7 @@ impl JoinWithExpiration { self.compute_pair( left_batch, self.right_input_schema.unkeyed_batch(&right_batch)?, - ctx, + collector, ) .await; Ok(()) @@ -118,7 +121,7 @@ impl JoinWithExpiration { &mut self, left: RecordBatch, right: RecordBatch, - ctx: &mut ArrowContext, + collector: &mut dyn Collector, ) { { self.right_passer.write().unwrap().replace(right); @@ -131,7 +134,7 @@ impl JoinWithExpiration { .expect("successfully computed?"); while let Some(batch) = records.next().await { let batch = batch.expect("should be able to compute batch"); - ctx.collect(batch).await; + collector.collect(batch).await; } } } @@ -162,7 +165,12 @@ impl ArrowOperator for JoinWithExpiration { } } - async fn process_batch(&mut self, _record_batch: RecordBatch, _ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + _record_batch: RecordBatch, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { unreachable!(); } async fn process_batch_index( @@ -170,15 +178,16 @@ impl ArrowOperator for JoinWithExpiration { index: usize, total_inputs: usize, record_batch: RecordBatch, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) { match index / (total_inputs / 2) { 0 => self - .process_left(record_batch, ctx) + .process_left(record_batch, ctx, collector) .await .expect("should process left"), 
1 => self - .process_right(record_batch, ctx) + .process_right(record_batch, ctx, collector) .await .expect("should process right"), _ => unreachable!(), @@ -218,7 +227,7 @@ impl OperatorConstructor for JoinWithExpirationConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let left_passer = Arc::new(RwLock::new(None)); let right_passer = Arc::new(RwLock::new(None)); @@ -251,16 +260,18 @@ impl OperatorConstructor for JoinWithExpirationConstructor { ttl = Duration::from_secs(24 * 60 * 60); } - Ok(OperatorNode::from_operator(Box::new(JoinWithExpiration { - left_expiration: ttl, - right_expiration: ttl, - left_input_schema, - right_input_schema, - left_schema, - right_schema, - left_passer, - right_passer, - join_execution_plan, - }))) + Ok(ConstructedOperator::from_operator(Box::new( + JoinWithExpiration { + left_expiration: ttl, + right_expiration: ttl, + left_input_schema, + right_input_schema, + left_schema, + right_schema, + left_passer, + right_passer, + join_execution_plan, + }, + ))) } } diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs index 381bf5f45..0ea7f57e6 100644 --- a/crates/arroyo-worker/src/arrow/mod.rs +++ b/crates/arroyo-worker/src/arrow/mod.rs @@ -2,9 +2,10 @@ use arrow::datatypes::SchemaRef; use arrow_array::RecordBatch; use arroyo_df::physical::ArroyoPhysicalExtensionCodec; use arroyo_df::physical::DecodingContext; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, + Registry, }; use arroyo_rpc::grpc::api; use datafusion::common::DataFusionError; @@ -47,9 +48,9 @@ impl OperatorConstructor for ValueExecutionConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> 
anyhow::Result { + ) -> anyhow::Result { let executor = StatelessPhysicalExecutor::new(&config.physical_plan, ®istry)?; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( ValueExecutionOperator { name: config.name, executor, @@ -71,11 +72,16 @@ impl ArrowOperator for ValueExecutionOperator { } } - async fn process_batch(&mut self, record_batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + record_batch: RecordBatch, + _: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let mut records = self.executor.process_batch(record_batch).await; while let Some(batch) = records.next().await { let batch = batch.expect("should be able to compute batch"); - ctx.collect(batch).await; + collector.collect(batch).await; } } } @@ -163,10 +169,10 @@ impl OperatorConstructor for KeyExecutionConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let executor = StatelessPhysicalExecutor::new(&config.physical_plan, ®istry)?; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( KeyExecutionOperator { name: config.name, executor, @@ -196,13 +202,18 @@ impl ArrowOperator for KeyExecutionOperator { } } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let mut records = self.executor.process_batch(batch).await; while let Some(batch) = records.next().await { let batch = batch.expect("should be able to compute batch"); //TODO: sort by the key //info!("batch {:?}", batch); - ctx.collect(batch).await; + collector.collect(batch).await; } } } diff --git a/crates/arroyo-worker/src/arrow/session_aggregating_window.rs b/crates/arroyo-worker/src/arrow/session_aggregating_window.rs index 68f19b3a2..5d14387f5 100644 --- 
a/crates/arroyo-worker/src/arrow/session_aggregating_window.rs +++ b/crates/arroyo-worker/src/arrow/session_aggregating_window.rs @@ -19,8 +19,8 @@ use arrow_array::{ use arrow_schema::{DataType, Field, FieldRef}; use arroyo_df::schemas::window_arrow_struct; use arroyo_operator::{ - context::ArrowContext, - operator::{ArrowOperator, OperatorConstructor, OperatorNode}, + context::OperatorContext, + operator::{ArrowOperator, ConstructedOperator, OperatorConstructor}, }; use arroyo_rpc::{ grpc::{api, rpc::TableConfig}, @@ -33,6 +33,7 @@ use arroyo_types::{from_nanos, print_time, to_nanos, CheckpointBarrier, Watermar use datafusion::{execution::context::SessionContext, physical_plan::ExecutionPlan}; use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; +use arroyo_operator::context::Collector; use arroyo_operator::operator::Registry; use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; use datafusion::execution::{ @@ -45,7 +46,6 @@ use std::time::Duration; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio_stream::StreamExt; use tracing::{debug, warn}; - // TODO: advance futures outside of method calls. 
pub struct SessionAggregatingWindowFunc { @@ -71,7 +71,11 @@ impl SessionAggregatingWindowFunc { result } - async fn advance(&mut self, ctx: &mut ArrowContext) -> Result<()> { + async fn advance( + &mut self, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) -> Result<()> { let Some(watermark) = ctx.last_present_watermark() else { debug!("no watermark, not advancing"); return Ok(()); @@ -85,7 +89,7 @@ impl SessionAggregatingWindowFunc { .to_record_batch(results, ctx) .context("should convert to record batch")?; debug!("emitting session batch of size {}", result_batch.num_rows()); - ctx.collect(result_batch).await; + collector.collect(result_batch).await; } Ok(()) @@ -294,7 +298,7 @@ impl SessionAggregatingWindowFunc { fn to_record_batch( &self, results: Vec<(OwnedRow, Vec)>, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, ) -> Result { debug!("first result is {:#?}", results[0]); let (rows, results): (Vec<_>, Vec<_>) = results @@ -684,7 +688,7 @@ impl OperatorConstructor for SessionAggregatingWindowConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let window_field = Arc::new(Field::new( config.window_field_name, window_arrow_struct(), @@ -735,7 +739,7 @@ impl OperatorConstructor for SessionAggregatingWindowConstructor { receiver, }; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( SessionAggregatingWindowFunc { config: Arc::new(config), keys_by_next_watermark_action: BTreeMap::new(), @@ -754,8 +758,8 @@ impl ArrowOperator for SessionAggregatingWindowFunc { "session_window".to_string() } - async fn on_start(&mut self, ctx: &mut ArrowContext) { - let start_times_map: &mut GlobalKeyedView> = + async fn on_start(&mut self, ctx: &mut OperatorContext) { + let start_times_map: &mut GlobalKeyedView> = ctx.table_manager.get_global_keyed_state("e").await.unwrap(); let start_time = start_times_map .get_all() @@ -809,7 +813,12 @@ impl ArrowOperator for 
SessionAggregatingWindowFunc { } // TODO: filter out late data - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { debug!("received batch {:?}", batch); let current_watermark = ctx.last_present_watermark(); let batch = if let Some(watermark) = current_watermark { @@ -857,13 +866,19 @@ impl ArrowOperator for SessionAggregatingWindowFunc { async fn handle_watermark( &mut self, watermark: Watermark, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { - self.advance(ctx).await.unwrap(); + self.advance(ctx, collector).await.unwrap(); Some(watermark) } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let watermark = ctx.last_present_watermark(); let table = ctx .table_manager diff --git a/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs b/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs index c0bc2a204..0c97fdfbb 100644 --- a/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs +++ b/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs @@ -3,8 +3,8 @@ use arrow::compute::{partition, sort_to_indices, take}; use arrow_array::{types::TimestampNanosecondType, Array, PrimitiveArray, RecordBatch}; use arrow_schema::SchemaRef; use arroyo_operator::{ - context::ArrowContext, - operator::{ArrowOperator, OperatorConstructor, OperatorNode}, + context::OperatorContext, + operator::{ArrowOperator, ConstructedOperator, OperatorConstructor}, }; use arroyo_rpc::grpc::{api, rpc::TableConfig}; use arroyo_state::timestamp_table_config; @@ -21,7 +21,9 @@ use std::{ use futures::stream::FuturesUnordered; +use super::sync::streams::KeyedCloneableStreamFuture; use 
arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; +use arroyo_operator::context::Collector; use arroyo_operator::operator::{AsDisplayable, DisplayableOperator, Registry}; use arroyo_rpc::df::ArroyoSchema; use datafusion::execution::{ @@ -40,8 +42,6 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio_stream::StreamExt; use tracing::info; -use super::sync::streams::KeyedCloneableStreamFuture; - pub struct SlidingAggregatingWindowFunc { slide: Duration, width: Duration, @@ -113,7 +113,11 @@ impl SlidingAggregatingWindowFunc { } } - async fn advance(&mut self, ctx: &mut ArrowContext) -> Result<()> { + async fn advance( + &mut self, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) -> Result<()> { let bin_start = match self.state { SlidingWindowState::NoData => unreachable!(), SlidingWindowState::OnlyBufferedData { earliest_bin_time } => earliest_bin_time, @@ -204,7 +208,7 @@ impl SlidingAggregatingWindowFunc { .execute(0, SessionContext::new().task_ctx())?; while let Some(batch) = final_projection_exec.next().await { let batch = batch.expect("should be able to compute batch"); - ctx.collector.collect(batch).await; + collector.collect(batch).await; } Ok(()) @@ -453,7 +457,7 @@ impl OperatorConstructor for SlidingAggregatingWindowConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let width = Duration::from_micros(config.width_micros); let input_schema: ArroyoSchema = config .input_schema @@ -505,7 +509,7 @@ impl OperatorConstructor for SlidingAggregatingWindowConstructor { &final_codec, )?; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( SlidingAggregatingWindowFunc { slide, width, @@ -554,7 +558,7 @@ impl ArrowOperator for SlidingAggregatingWindowFunc { } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let watermark = 
ctx.last_present_watermark(); let table = ctx .table_manager @@ -596,7 +600,12 @@ impl ArrowOperator for SlidingAggregatingWindowFunc { } // TODO: filter out late data - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let bin = self .binning_function .evaluate(&batch) @@ -671,18 +680,24 @@ impl ArrowOperator for SlidingAggregatingWindowFunc { async fn handle_watermark( &mut self, watermark: Watermark, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { let last_watermark = ctx.last_present_watermark()?; while self.should_advance(last_watermark) { - self.advance(ctx).await.unwrap(); + self.advance(ctx, collector).await.unwrap(); } Some(watermark) } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let watermark = ctx .watermark() .and_then(|watermark: Watermark| match watermark { diff --git a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs index 0131e40e0..366adde3f 100644 --- a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs +++ b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs @@ -3,9 +3,10 @@ use arrow::compute::{partition, sort_to_indices, take}; use arrow_array::{types::TimestampNanosecondType, Array, PrimitiveArray, RecordBatch}; use arrow_schema::SchemaRef; use arroyo_df::schemas::add_timestamp_field_arrow; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, 
DisplayableOperator, OperatorConstructor, + Registry, }; use arroyo_rpc::grpc::{api, rpc::TableConfig}; use arroyo_state::timestamp_table_config; @@ -40,7 +41,7 @@ use prost::Message; use std::time::Duration; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::Mutex; -use tracing::{debug, warn}; +use tracing::warn; use super::sync::streams::KeyedCloneableStreamFuture; type NextBatchFuture = KeyedCloneableStreamFuture; @@ -114,7 +115,7 @@ impl OperatorConstructor for TumblingAggregateWindowConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let width = Duration::from_micros(config.width_micros); let input_schema: ArroyoSchema = config .input_schema @@ -182,7 +183,7 @@ impl OperatorConstructor for TumblingAggregateWindowConstructor { let aggregate_with_timestamp_schema = add_timestamp_field_arrow(finish_execution_plan.schema()); - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( TumblingAggregatingWindowFunc { width, binning_function, @@ -230,7 +231,7 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let watermark = ctx.last_present_watermark(); let table = ctx .table_manager @@ -246,7 +247,12 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { } } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let bin = self .binning_function .evaluate(&batch) @@ -312,7 +318,8 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { async fn handle_watermark( &mut self, watermark: Watermark, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { if let Some(watermark) = ctx.last_present_watermark() { let bin = 
self.bin_start(watermark); @@ -359,7 +366,7 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { if self.final_projection.is_some() { aggregate_results.push(with_timestamp); } else { - ctx.collect(with_timestamp).await; + collector.collect(with_timestamp).await; } } if let Some(final_projection) = self.final_projection.as_ref() { @@ -373,7 +380,7 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { .unwrap(); while let Some(batch) = final_projection_exec.next().await { let batch = batch.expect("should be able to compute batch"); - ctx.collect(batch).await; + collector.collect(batch).await; } } } else { @@ -399,28 +406,33 @@ impl ArrowOperator for TumblingAggregatingWindowFunc { })) } - async fn handle_future_result(&mut self, result: Box, _: &mut ArrowContext) { + async fn handle_future_result( + &mut self, + result: Box, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { let data: Box> = result.downcast().expect("invalid data in future"); - if let Some((bin, batch_option)) = *data { - match batch_option { + if let Some((bin, Some((batch, future)))) = *data { + match self.execs.get_mut(&bin) { + Some(exec) => { + exec.finished_batches + .push(batch.expect("should've been able to compute a batch")); + self.futures.lock().await.push(future); + } None => { - debug!("future for {} was finished elsewhere", print_time(bin)); + unreachable!("FuturesUnordered returned a batch, but we can't find the exec") } - Some((batch, future)) => match self.execs.get_mut(&bin) { - Some(exec) => { - exec.finished_batches - .push(batch.expect("should've been able to compute a batch")); - self.futures.lock().await.push(future); - } - None => unreachable!( - "FuturesUnordered returned a batch, but we can't find the exec" - ), - }, } } } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let watermark = ctx 
.watermark() .and_then(|watermark: Watermark| match watermark { diff --git a/crates/arroyo-worker/src/arrow/updating_aggregator.rs b/crates/arroyo-worker/src/arrow/updating_aggregator.rs index fa64c4cc5..b51e1d85c 100644 --- a/crates/arroyo-worker/src/arrow/updating_aggregator.rs +++ b/crates/arroyo-worker/src/arrow/updating_aggregator.rs @@ -13,11 +13,12 @@ use arrow_array::{Array, BooleanArray, RecordBatch, StructArray}; use arrow_array::cast::AsArray; use arrow_schema::SchemaRef; use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; +use arroyo_operator::context::Collector; use arroyo_operator::{ - context::ArrowContext, + context::OperatorContext, operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, - Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, + OperatorConstructor, Registry, }, }; use arroyo_rpc::df::ArroyoSchemaRef; @@ -57,7 +58,11 @@ pub struct UpdatingAggregatingFunc { } impl UpdatingAggregatingFunc { - async fn flush(&mut self, ctx: &mut ArrowContext) -> Result<()> { + async fn flush( + &mut self, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) -> Result<()> { if self.sender.is_none() { return Ok(()); } @@ -158,11 +163,12 @@ impl UpdatingAggregatingFunc { } if !batches_to_write.is_empty() { - ctx.collect(concat_batches( - &batches_to_write[0].schema(), - batches_to_write.iter(), - )?) - .await; + collector + .collect(concat_batches( + &batches_to_write[0].schema(), + batches_to_write.iter(), + )?) 
+ .await; } Ok(()) @@ -242,15 +248,25 @@ impl ArrowOperator for UpdatingAggregatingFunc { } } - async fn process_batch(&mut self, batch: RecordBatch, _ctx: &mut ArrowContext) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { if self.sender.is_none() { self.init_exec(); } self.sender.as_ref().unwrap().send(batch).unwrap(); } - async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { - self.flush(ctx).await.unwrap(); + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + self.flush(ctx, collector).await.unwrap(); } fn tables(&self) -> HashMap { @@ -283,14 +299,20 @@ impl ArrowOperator for UpdatingAggregatingFunc { Some(self.flush_interval) } - async fn handle_tick(&mut self, _tick: u64, ctx: &mut ArrowContext) { - self.flush(ctx).await.unwrap(); + async fn handle_tick( + &mut self, + _tick: u64, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + self.flush(ctx, collector).await.unwrap(); } async fn handle_watermark( &mut self, watermark: Watermark, - ctx: &mut ArrowContext, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { let last_watermark = ctx.last_present_watermark(); let partial_table = ctx @@ -299,7 +321,7 @@ impl ArrowOperator for UpdatingAggregatingFunc { .await .expect("should have partial table"); if partial_table.would_expire(last_watermark) { - self.flush(ctx).await.unwrap(); + self.flush(ctx, collector).await.unwrap(); } let partial_table = ctx .table_manager @@ -331,17 +353,18 @@ impl ArrowOperator for UpdatingAggregatingFunc { })) } - async fn handle_future_result(&mut self, _result: Box, _: &mut ArrowContext) { - //unreachable!("should not have future result") - } - - async fn on_close(&mut self, final_mesage: &Option, ctx: &mut ArrowContext) { - if let Some(SignalMessage::EndOfData) = final_mesage { - 
self.flush(ctx).await.unwrap(); + async fn on_close( + &mut self, + final_message: &Option, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + if let Some(SignalMessage::EndOfData) = final_message { + self.flush(ctx, collector).await.unwrap(); } } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { // fetch the tables so they are ready to be queried. ctx.table_manager .get_last_key_value_table("f", ctx.last_present_watermark()) @@ -363,7 +386,7 @@ impl OperatorConstructor for UpdatingAggregatingConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let receiver = Arc::new(RwLock::new(None)); let codec = ArroyoPhysicalExtensionCodec { @@ -407,7 +430,7 @@ impl OperatorConstructor for UpdatingAggregatingConstructor { config.ttl_micros }; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( UpdatingAggregatingFunc { partial_aggregation_plan, partial_schema: Arc::new(partial_schema), diff --git a/crates/arroyo-worker/src/arrow/watermark_generator.rs b/crates/arroyo-worker/src/arrow/watermark_generator.rs index 8eefd3343..3ea91b41c 100644 --- a/crates/arroyo-worker/src/arrow/watermark_generator.rs +++ b/crates/arroyo-worker/src/arrow/watermark_generator.rs @@ -1,17 +1,16 @@ use arrow::compute::kernels; use arrow_array::RecordBatch; -use arroyo_operator::context::ArrowContext; +use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::get_timestamp_col; use arroyo_operator::operator::{ - ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, OperatorNode, Registry, + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, + Registry, }; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::api::ExpressionWatermarkConfig; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_state::global_table_config; -use arroyo_types::{ - 
from_nanos, to_millis, ArrowMessage, CheckpointBarrier, SignalMessage, Watermark, -}; +use arroyo_types::{from_nanos, to_millis, CheckpointBarrier, SignalMessage, Watermark}; use async_trait::async_trait; use bincode::{Decode, Encode}; use datafusion::physical_expr::PhysicalExpr; @@ -68,7 +67,7 @@ impl OperatorConstructor for WatermarkGeneratorConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let input_schema: ArroyoSchema = config.input_schema.unwrap().try_into()?; let expression = PhysicalExprNode::decode(&mut config.expression.as_slice())?; let expression = parse_physical_expr( @@ -78,7 +77,7 @@ impl OperatorConstructor for WatermarkGeneratorConstructor { &DefaultPhysicalExtensionCodec {}, )?; - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( WatermarkGenerator::expression( Duration::from_micros(config.period_micros), config.idle_time_micros.map(Duration::from_micros), @@ -113,7 +112,7 @@ impl ArrowOperator for WatermarkGenerator { Some(Duration::from_secs(1)) } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + async fn on_start(&mut self, ctx: &mut OperatorContext) { let gs = ctx .table_manager .get_global_keyed_state("s") @@ -131,21 +130,31 @@ impl ArrowOperator for WatermarkGenerator { self.state_cache = state; } - async fn on_close(&mut self, final_message: &Option, ctx: &mut ArrowContext) { + async fn on_close( + &mut self, + final_message: &Option, + _: &mut OperatorContext, + collector: &mut dyn Collector, + ) { if let Some(SignalMessage::EndOfData) = final_message { // send final watermark on close - ctx.collector - .broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - // this is in the year 2554, far enough out be close to inifinity, + collector + .broadcast_watermark( + // this is in the year 2554, far enough out be close to infinity, // but can still be formatted. 
Watermark::EventTime(from_nanos(u64::MAX as u128)), - ))) + ) .await; } } - async fn process_batch(&mut self, record: RecordBatch, ctx: &mut ArrowContext) { - ctx.collector.collect(record.clone()).await; + async fn process_batch( + &mut self, + record: RecordBatch, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + collector.collect(record.clone()).await; self.last_event = SystemTime::now(); let timestamp_column = get_timestamp_col(&record, ctx); @@ -181,17 +190,20 @@ impl ArrowOperator for WatermarkGenerator { ctx.task_info.task_index, to_millis(watermark) ); - ctx.collector - .broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::EventTime(watermark), - ))) + collector + .broadcast_watermark(Watermark::EventTime(watermark)) .await; self.state_cache.last_watermark_emitted_at = max_timestamp; self.idle = false; } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let gs = ctx .table_manager .get_global_keyed_state("s") @@ -201,17 +213,19 @@ impl ArrowOperator for WatermarkGenerator { gs.insert(ctx.task_info.task_index, self.state_cache).await; } - async fn handle_tick(&mut self, _: u64, ctx: &mut ArrowContext) { + async fn handle_tick( + &mut self, + _: u64, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { if let Some(idle_time) = self.idle_time { if self.last_event.elapsed().unwrap_or(Duration::ZERO) > idle_time && !self.idle { info!( "Setting partition {} to idle after {:?}", ctx.task_info.task_index, idle_time ); - ctx.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector.broadcast_watermark(Watermark::Idle).await; self.idle = true; } } diff --git a/crates/arroyo-worker/src/arrow/window_fn.rs b/crates/arroyo-worker/src/arrow/window_fn.rs index 600e1b8b6..0462634b3 100644 --- 
a/crates/arroyo-worker/src/arrow/window_fn.rs +++ b/crates/arroyo-worker/src/arrow/window_fn.rs @@ -7,8 +7,10 @@ use anyhow::{anyhow, Result}; use arrow::compute::{max, min}; use arrow_array::RecordBatch; use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; -use arroyo_operator::context::ArrowContext; -use arroyo_operator::operator::{ArrowOperator, OperatorConstructor, OperatorNode, Registry}; +use arroyo_operator::context::{Collector, OperatorContext}; +use arroyo_operator::operator::{ + ArrowOperator, ConstructedOperator, OperatorConstructor, Registry, +}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_rpc::{df::ArroyoSchemaRef, grpc::api}; @@ -123,7 +125,8 @@ impl ArrowOperator for WindowFunctionOperator { fn name(&self) -> String { "WindowFunction".to_string() } - async fn on_start(&mut self, ctx: &mut ArrowContext) { + + async fn on_start(&mut self, ctx: &mut OperatorContext) { let watermark = ctx.last_present_watermark(); let table = ctx .table_manager @@ -137,7 +140,13 @@ impl ArrowOperator for WindowFunctionOperator { } } } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut ArrowContext) { + + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let current_watermark = ctx.last_present_watermark(); let table = ctx .table_manager @@ -156,11 +165,12 @@ impl ArrowOperator for WindowFunctionOperator { async fn handle_watermark( &mut self, - watermark_message: Watermark, - ctx: &mut ArrowContext, + watermark: Watermark, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, ) -> Option { let Some(watermark) = ctx.last_present_watermark() else { - return Some(watermark_message); + return Some(watermark); }; loop { let finished = { @@ -180,13 +190,18 @@ impl ArrowOperator for WindowFunctionOperator { while let (_timestamp, Some((batch, new_exec))) = active_exec.await { active_exec = new_exec; let batch = 
batch.expect("batch should be computable"); - ctx.collect(batch).await; + collector.collect(batch).await; } } - Some(watermark_message) + Some(Watermark::EventTime(watermark)) } - async fn handle_checkpoint(&mut self, _cb: CheckpointBarrier, ctx: &mut ArrowContext) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let watermark = ctx.last_present_watermark(); ctx.table_manager .get_expiring_time_key_table("input", watermark) @@ -220,7 +235,7 @@ impl OperatorConstructor for WindowFunctionConstructor { &self, config: Self::ConfigT, registry: Arc, - ) -> anyhow::Result { + ) -> anyhow::Result { let window_exec = PhysicalPlanNode::decode(&mut config.window_function_plan.as_slice())?; let input_schema = Arc::new(ArroyoSchema::try_from( config @@ -239,7 +254,7 @@ impl OperatorConstructor for WindowFunctionConstructor { let input_schema_unkeyed = Arc::new(ArroyoSchema::from_schema_unkeyed( input_schema.schema.clone(), )?); - Ok(OperatorNode::from_operator(Box::new( + Ok(ConstructedOperator::from_operator(Box::new( WindowFunctionOperator { input_schema, input_schema_unkeyed, diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs index 99a86e4e3..5bc60d3a4 100644 --- a/crates/arroyo-worker/src/engine.rs +++ b/crates/arroyo-worker/src/engine.rs @@ -1,17 +1,3 @@ -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{Debug, Formatter}; -use std::mem; -use std::sync::{Arc, RwLock}; - -use std::time::SystemTime; - -use arroyo_connectors::connectors; -use arroyo_rpc::df::ArroyoSchema; -use bincode::{Decode, Encode}; -use futures::stream::FuturesUnordered; -use futures::StreamExt; -use tracing::{info, warn}; - use crate::arrow::async_udf::AsyncUdfConstructor; use crate::arrow::instant_join::InstantJoinConstructor; use crate::arrow::join_with_expiration::JoinWithExpirationConstructor; @@ -23,15 +9,17 @@ use 
crate::arrow::watermark_generator::WatermarkGeneratorConstructor; use crate::arrow::window_fn::WindowFunctionConstructor; use crate::arrow::{KeyExecutionConstructor, ValueExecutionConstructor}; use crate::network_manager::{NetworkManager, Quad, Senders}; +use arroyo_connectors::connectors; use arroyo_datastream::logical::{ - LogicalEdge, LogicalEdgeType, LogicalGraph, LogicalNode, OperatorName, + LogicalEdge, LogicalEdgeType, LogicalGraph, LogicalNode, OperatorChain, OperatorName, }; use arroyo_df::physical::new_registry; -use arroyo_operator::context::{batch_bounded, ArrowContext, BatchReceiver, BatchSender}; -use arroyo_operator::operator::OperatorNode; +use arroyo_operator::context::{batch_bounded, BatchReceiver, BatchSender, OperatorContext}; use arroyo_operator::operator::Registry; +use arroyo_operator::operator::{ChainedOperator, ConstructedOperator, OperatorNode, SourceNode}; use arroyo_operator::ErasedConstructor; use arroyo_rpc::config::config; +use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::{ api, rpc::{CheckpointMetadata, TaskAssignment}, @@ -40,11 +28,20 @@ use arroyo_rpc::{ControlMessage, ControlResp}; use arroyo_state::{BackingStore, StateBackend}; use arroyo_types::{range_for_server, Key, TaskInfo, WorkerId}; use arroyo_udf_host::LocalUdf; +use bincode::{Decode, Encode}; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use petgraph::graph::{DiGraph, NodeIndex}; use petgraph::visit::EdgeRef; -use petgraph::Direction; +use petgraph::{dot, Direction}; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::sync::{Arc, RwLock}; +use std::time::SystemTime; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::Barrier; +use tracing::{debug, info, warn}; #[derive(Encode, Decode, Clone, Debug, PartialEq, Eq)] pub struct TimerValue { @@ -54,23 +51,29 @@ pub struct TimerValue { } pub struct SubtaskNode { - pub id: String, + pub node_id: u32, pub subtask_idx: usize, pub 
parallelism: usize, pub in_schemas: Vec, pub out_schema: Option, - pub projection: Option>, pub node: OperatorNode, } impl Debug for SubtaskNode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}-{}-{}", self.node.name(), self.id, self.subtask_idx) + write!( + f, + "{}-{}-{}", + self.node.name(), + self.node_id, + self.subtask_idx + ) } } pub struct QueueNode { - task_info: TaskInfo, + task_info: Arc, + operator_ids: Vec, tx: Sender, } @@ -111,20 +114,14 @@ impl Debug for PhysicalGraphEdge { } impl SubtaskOrQueueNode { - pub fn take_subtask(&mut self, job_id: String) -> (SubtaskNode, Receiver) { + pub fn take_subtask(&mut self) -> (SubtaskNode, Receiver) { let (mut qn, rx) = match self { SubtaskOrQueueNode::SubtaskNode(sn) => { let (tx, rx) = channel(16); let n = SubtaskOrQueueNode::QueueNode(QueueNode { - task_info: TaskInfo { - job_id, - operator_name: sn.node.name(), - operator_id: sn.id.clone(), - task_index: sn.subtask_idx, - parallelism: sn.parallelism, - key_range: range_for_server(sn.subtask_idx, sn.parallelism), - }, + task_info: sn.node.task_info().clone(), + operator_ids: sn.node.operator_ids(), tx, }); @@ -138,17 +135,17 @@ impl SubtaskOrQueueNode { (qn.unwrap_subtask(), rx) } - pub fn id(&self) -> &str { + pub fn id(&self) -> u32 { match self { - SubtaskOrQueueNode::SubtaskNode(n) => &n.id, - SubtaskOrQueueNode::QueueNode(n) => &n.task_info.operator_id, + SubtaskOrQueueNode::SubtaskNode(n) => n.node_id, + SubtaskOrQueueNode::QueueNode(n) => n.task_info.node_id, } } pub fn subtask_idx(&self) -> usize { match self { SubtaskOrQueueNode::SubtaskNode(n) => n.subtask_idx, - SubtaskOrQueueNode::QueueNode(n) => n.task_info.task_index, + SubtaskOrQueueNode::QueueNode(n) => n.task_info.task_index as usize, } } @@ -170,6 +167,7 @@ impl SubtaskOrQueueNode { pub struct Program { pub name: String, pub graph: Arc>>, + pub control_tx: Option>, } impl Program { @@ -177,17 +175,19 @@ impl Program { self.graph.read().unwrap().node_count() } - 
pub fn local_from_logical( - name: String, + pub async fn local_from_logical( + job_id: String, logical: &DiGraph, udfs: &[LocalUdf], + restore_epoch: Option, + control_tx: Sender, ) -> Self { let assignments = logical .node_weights() .flat_map(|weight| { (0..weight.parallelism).map(|index| TaskAssignment { - operator_id: weight.operator_id.clone(), - operator_subtask: index as u64, + node_id: weight.node_id, + subtask_idx: index as u32, worker_id: 0, worker_addr: "".into(), }) @@ -198,22 +198,49 @@ impl Program { for udf in udfs { registry.add_local_udf(udf); } - Self::from_logical(name, logical, &assignments, registry) + Self::from_logical( + "local".to_string(), + &job_id, + logical, + &assignments, + registry, + restore_epoch, + control_tx, + ) + .await } - pub fn from_logical( + pub async fn from_logical( name: String, + job_id: &str, logical: &LogicalGraph, assignments: &Vec, registry: Registry, + restore_epoch: Option, + control_tx: Sender, ) -> Program { let mut physical = DiGraph::new(); + let checkpoint_metadata = if let Some(epoch) = restore_epoch { + info!("Restoring checkpoint {} for job {}", epoch, job_id); + Some( + StateBackend::load_checkpoint_metadata(job_id, epoch) + .await + .unwrap_or_else(|_| { + panic!("failed to load checkpoint metadata for epoch {}", epoch) + }), + ) + } else { + None + }; + + debug!("Logical graph\n{:?}", dot::Dot::new(logical)); + let registry = Arc::new(registry); let mut parallelism_map = HashMap::new(); for task in assignments { - *(parallelism_map.entry(&task.operator_id).or_insert(0usize)) += 1; + *(parallelism_map.entry(&task.node_id).or_insert(0usize)) += 1; } for idx in logical.node_indices() { @@ -227,30 +254,44 @@ impl Program { .map(|edge| edge.weight().schema.clone()) .next(); - let projection = logical - .edges_directed(idx, Direction::Outgoing) - .map(|edge| edge.weight().projection.clone()) - .next() - .unwrap_or_default(); + let in_queue_count = logical + .edges_directed(idx, Direction::Incoming) + 
.map(|edge| match edge.weight().edge_type { + LogicalEdgeType::Forward => 1, + LogicalEdgeType::Shuffle + | LogicalEdgeType::LeftJoin + | LogicalEdgeType::RightJoin => { + logical.node_weight(edge.source()).unwrap().parallelism as u32 + } + }) + .sum(); let node = logical.node_weight(idx).unwrap(); - let parallelism = *parallelism_map.get(&node.operator_id).unwrap_or_else(|| { - warn!("no assignments for operator {}", node.operator_id); + let parallelism = *parallelism_map.get(&node.node_id).unwrap_or_else(|| { + warn!("no assignments for node {}", node.node_id); &node.parallelism }); for i in 0..parallelism { physical.add_node(SubtaskOrQueueNode::SubtaskNode(SubtaskNode { - id: node.operator_id.clone(), + node_id: node.node_id, subtask_idx: i, parallelism, in_schemas: in_schemas.clone(), out_schema: out_schema.clone(), - node: construct_operator( - node.operator_name, - node.operator_config.clone(), + node: construct_node( + node.operator_chain.clone(), + job_id, + node.node_id, + i as u32, + parallelism as u32, + in_queue_count, + in_schemas.clone(), + out_schema.clone(), + checkpoint_metadata.as_ref(), + control_tx.clone(), registry.clone(), - ), - projection: projection.clone(), + ) + .await, })); } } @@ -265,12 +306,12 @@ impl Program { let from_nodes: Vec<_> = physical .node_indices() - .filter(|n| physical.node_weight(*n).unwrap().id() == logical_in_node.operator_id) + .filter(|n| physical.node_weight(*n).unwrap().id() == logical_in_node.node_id) .collect(); assert_ne!(from_nodes.len(), 0, "failed to find from nodes"); let to_nodes: Vec<_> = physical .node_indices() - .filter(|n| physical.node_weight(*n).unwrap().id() == logical_out_node.operator_id) + .filter(|n| physical.node_weight(*n).unwrap().id() == logical_out_node.node_id) .collect(); assert_ne!(from_nodes.len(), 0, "failed to find to nodes"); @@ -318,17 +359,9 @@ impl Program { Program { name, graph: Arc::new(RwLock::new(physical)), + control_tx: Some(control_tx), } } - - pub fn 
tasks_per_operator(&self) -> HashMap { - let mut tasks_per_operator = HashMap::new(); - for node in self.graph.read().unwrap().node_weights() { - let entry = tasks_per_operator.entry(node.id().to_string()).or_insert(0); - *entry += 1; - } - tasks_per_operator - } } pub struct Engine { @@ -338,7 +371,7 @@ pub struct Engine { run_id: String, job_id: String, network_manager: NetworkManager, - assignments: HashMap<(String, usize), TaskAssignment>, + assignments: HashMap<(u32, usize), TaskAssignment>, } pub struct StreamConfig { @@ -347,7 +380,7 @@ pub struct StreamConfig { pub struct RunningEngine { program: Program, - assignments: HashMap<(String, usize), TaskAssignment>, + assignments: HashMap<(u32, usize), TaskAssignment>, worker_id: WorkerId, } @@ -359,7 +392,7 @@ impl RunningEngine { .filter(|idx| { let w = graph.node_weight(*idx).unwrap(); self.assignments - .get(&(w.id().to_string(), w.subtask_idx())) + .get(&(w.id(), w.subtask_idx())) .unwrap() .worker_id == self.worker_id.0 @@ -375,7 +408,7 @@ impl RunningEngine { .filter(|idx| { let w = graph.node_weight(*idx).unwrap(); self.assignments - .get(&(w.id().to_string(), w.subtask_idx())) + .get(&(w.id(), w.subtask_idx())) .unwrap() .worker_id == self.worker_id.0 @@ -384,7 +417,7 @@ impl RunningEngine { .collect() } - pub fn operator_controls(&self) -> HashMap>> { + pub fn operator_controls(&self) -> HashMap>> { let mut controls = HashMap::new(); let graph = self.program.graph.read().unwrap(); @@ -393,26 +426,34 @@ impl RunningEngine { .filter(|idx| { let w = graph.node_weight(*idx).unwrap(); self.assignments - .get(&(w.id().to_string(), w.subtask_idx())) + .get(&(w.id(), w.subtask_idx())) .unwrap() .worker_id == self.worker_id.0 }) .for_each(|idx| { let w = graph.node_weight(idx).unwrap(); - let assignment = self - .assignments - .get(&(w.id().to_string(), w.subtask_idx())) - .unwrap(); + let assignment = self.assignments.get(&(w.id(), w.subtask_idx())).unwrap(); let tx = 
graph.node_weight(idx).unwrap().as_queue().tx.clone(); controls - .entry(assignment.operator_id.clone()) + .entry(assignment.node_id) .or_insert(vec![]) .push(tx); }); controls } + + pub fn operator_to_node(&self) -> HashMap { + let program = self.program.graph.read().unwrap(); + let mut result = HashMap::new(); + for n in program.node_weights() { + for id in &n.as_queue().operator_ids { + result.insert(id.clone(), n.as_queue().task_info.node_id); + } + } + result + } } impl Engine { @@ -426,7 +467,7 @@ impl Engine { ) -> Self { let assignments = assignments .into_iter() - .map(|a| ((a.operator_id.to_string(), a.operator_subtask as usize), a)) + .map(|a| ((a.node_id, a.subtask_idx as usize), a)) .collect(); Self { @@ -455,10 +496,10 @@ impl Engine { .node_weights() .map(|n| { ( - (n.id().to_string(), n.subtask_idx()), + (n.id(), n.subtask_idx()), TaskAssignment { - operator_id: n.id().to_string(), - operator_subtask: n.subtask_idx() as u64, + node_id: n.id(), + subtask_idx: n.subtask_idx() as u32, worker_id: worker_id.0, worker_addr: "locahost:0".to_string(), }, @@ -476,25 +517,11 @@ impl Engine { } } - pub async fn start(mut self, config: StreamConfig) -> (RunningEngine, Receiver) { + pub async fn start(mut self) -> RunningEngine { info!("Starting job {}", self.job_id); - let checkpoint_metadata = if let Some(epoch) = config.restore_epoch { - info!("Restoring checkpoint {} for job {}", epoch, self.job_id); - Some( - StateBackend::load_checkpoint_metadata(&self.job_id, epoch) - .await - .unwrap_or_else(|_| { - panic!("failed to load checkpoint metadata for epoch {}", epoch) - }), - ) - } else { - None - }; - let node_indexes: Vec<_> = self.program.graph.read().unwrap().node_indices().collect(); - let (control_tx, control_rx) = channel(128); let worker_id = self.worker_id; let mut senders = Senders::new(); @@ -505,8 +532,7 @@ impl Engine { for idx in node_indexes { futures.push(self.schedule_node( - &checkpoint_metadata, - &control_tx, + 
self.program.control_tx.as_ref().unwrap(), idx, ready.clone(), )); @@ -524,19 +550,17 @@ impl Engine { n.tx = None; } - ( - RunningEngine { - program: self.program, - assignments: self.assignments, - worker_id, - }, - control_rx, - ) + self.program.control_tx = None; + + RunningEngine { + program: self.program, + assignments: self.assignments, + worker_id, + } } async fn schedule_node( &self, - checkpoint_metadata: &Option, control_tx: &Sender, idx: NodeIndex, ready: Arc, @@ -548,37 +572,29 @@ impl Engine { .unwrap() .node_weight_mut(idx) .unwrap() - .take_subtask(self.job_id.clone()); + .take_subtask(); let assignment = &self .assignments - .get(&(node.id.clone().to_string(), node.subtask_idx)) + .get(&(node.node_id, node.subtask_idx)) .cloned() .unwrap_or_else(|| { panic!( "Could not find assignment for node {}-{}", - node.id.clone(), - node.subtask_idx + node.node_id, node.subtask_idx ) }); let mut senders = Senders::new(); if assignment.worker_id == self.worker_id.0 { - self.run_locally( - checkpoint_metadata, - control_tx, - idx, - node, - control_rx, - ready, - ) - .await; + self.run_locally(control_tx, idx, node, control_rx, ready) + .await; } else { self.connect_to_remote_task( &mut senders, idx, - node.id.clone(), + node.node_id, node.subtask_idx, assignment, ) @@ -592,7 +608,7 @@ impl Engine { &self, senders: &mut Senders, idx: NodeIndex, - node_id: String, + node_id: u32, node_subtask_idx: usize, assignment: &TaskAssignment, ) { @@ -656,7 +672,6 @@ impl Engine { pub async fn run_locally( &self, - checkpoint_metadata: &Option, control_tx: &Sender, idx: NodeIndex, node: SubtaskNode, @@ -667,7 +682,7 @@ impl Engine { "[{:?}] Scheduling {}-{}-{} ({}/{})", self.worker_id, node.node.name(), - node.id, + node.node_id, node.subtask_idx, node.subtask_idx + 1, node.parallelism @@ -692,7 +707,7 @@ impl Engine { let _local = { let target = graph.node_weight(edge.target()).unwrap(); self.assignments - .get(&(target.id().to_string(), target.subtask_idx())) + 
.get(&(target.id(), target.subtask_idx())) .unwrap() .worker_id == self.worker_id.0 @@ -708,41 +723,39 @@ impl Engine { graph.node_weight(idx).unwrap().as_queue().task_info.clone() }; - let operator_id = task_info.operator_id.clone(); let task_index = task_info.task_index; - let tables = node.node.tables(); let in_qs: Vec<_> = in_qs_map.into_values().flatten().collect(); - let ctx = ArrowContext::new( - task_info, - checkpoint_metadata.clone(), - control_rx, - control_tx.clone(), - in_qs.len(), - node.in_schemas, - node.out_schema, - node.projection, - out_qs_map - .into_values() - .map(|v| v.into_values().collect()) - .collect(), - tables, - ) - .await; + let out_qs = out_qs_map + .into_values() + .map(|v| v.into_values().collect()) + .collect(); let operator = Box::new(node.node); - let join_task = tokio::spawn(async move { - operator.start(ctx, in_qs, ready).await; - }); + let join_task = { + let control_tx = control_tx.clone(); + tokio::spawn(async move { + operator + .start( + control_tx.clone(), + control_rx, + in_qs, + out_qs, + node.out_schema, + ready, + ) + .await; + }) + }; let send_copy = control_tx.clone(); tokio::spawn(async move { if let Err(error) = join_task.await { send_copy .send(ControlResp::TaskFailed { - operator_id, - task_index, + node_id: node.node_id, + task_index: task_index as usize, error: error.to_string(), }) .await @@ -752,11 +765,104 @@ impl Engine { } } +#[allow(clippy::too_many_arguments)] +pub async fn construct_node( + chain: OperatorChain, + job_id: &str, + node_id: u32, + subtask_idx: u32, + parallelism: u32, + input_partitions: u32, + in_schemas: Vec, + out_schema: Option, + restore_from: Option<&CheckpointMetadata>, + control_tx: Sender, + registry: Arc, +) -> OperatorNode { + if chain.is_source() { + let (head, _) = chain.iter().next().unwrap(); + let ConstructedOperator::Source(operator) = + construct_operator(head.operator_name, &head.operator_config, registry) + else { + unreachable!(); + }; + + let task_info = 
Arc::new(TaskInfo { + job_id: job_id.to_string(), + node_id, + operator_name: head.operator_name.to_string(), + operator_id: head.operator_id.clone(), + task_index: subtask_idx, + parallelism, + key_range: range_for_server(subtask_idx as usize, parallelism as usize), + }); + + OperatorNode::Source(SourceNode { + context: OperatorContext::new( + task_info, + restore_from, + control_tx, + 1, + vec![], + out_schema, + operator.tables(), + ) + .await, + operator, + }) + } else { + let mut head = None; + let mut cur: Option<&mut ChainedOperator> = None; + let mut input_partitions = input_partitions as usize; + for (node, edge) in chain.iter() { + let ConstructedOperator::Operator(op) = + construct_operator(node.operator_name, &node.operator_config, registry.clone()) + else { + unreachable!("sources must be the first node in a chain"); + }; + + let ctx = OperatorContext::new( + Arc::new(TaskInfo { + job_id: job_id.to_string(), + node_id, + operator_name: node.operator_name.to_string(), + operator_id: node.operator_id.clone(), + task_index: subtask_idx, + parallelism, + key_range: range_for_server(subtask_idx as usize, parallelism as usize), + }), + restore_from, + control_tx.clone(), + input_partitions, + if let Some(cur) = &mut cur { + vec![cur.context.out_schema.clone().unwrap()] + } else { + in_schemas.clone() + }, + edge.cloned().or(out_schema.clone()), + op.tables(), + ) + .await; + + if cur.is_none() { + head = Some(ChainedOperator::new(op, ctx)); + cur = head.as_mut(); + input_partitions = 1; + } else { + cur.as_mut().unwrap().next = Some(Box::new(ChainedOperator::new(op, ctx))); + cur = Some(cur.unwrap().next.as_mut().unwrap().as_mut()); + } + } + + OperatorNode::Chained(head.unwrap()) + } +} + pub fn construct_operator( operator: OperatorName, - config: Vec, + config: &[u8], registry: Arc, -) -> OperatorNode { +) -> ConstructedOperator { let ctor: Box = match operator { OperatorName::ArrowValue => Box::new(ValueExecutionConstructor), OperatorName::ArrowKey => 
Box::new(KeyExecutionConstructor), @@ -770,7 +876,7 @@ pub fn construct_operator( OperatorName::InstantJoin => Box::new(InstantJoinConstructor), OperatorName::WindowFunction => Box::new(WindowFunctionConstructor), OperatorName::ConnectorSource | OperatorName::ConnectorSink => { - let op: api::ConnectorOp = prost::Message::decode(&mut config.as_slice()).unwrap(); + let op: api::ConnectorOp = prost::Message::decode(config).unwrap(); return connectors() .get(op.connector.as_str()) .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) diff --git a/crates/arroyo-worker/src/lib.rs b/crates/arroyo-worker/src/lib.rs index f7cb684eb..87c013663 100644 --- a/crates/arroyo-worker/src/lib.rs +++ b/crates/arroyo-worker/src/lib.rs @@ -1,7 +1,7 @@ // TODO: factor out complex types #![allow(clippy::type_complexity)] -use crate::engine::{Engine, Program, StreamConfig, SubtaskNode}; +use crate::engine::{Engine, Program, SubtaskNode}; use crate::network_manager::NetworkManager; use anyhow::Result; @@ -28,7 +28,7 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; use tokio::net::TcpListener; use tokio::select; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio_stream::wrappers::TcpListenerStream; use tonic::{Request, Response, Status}; use tracing::{debug, error, info, warn}; @@ -93,40 +93,42 @@ impl Debug for LogicalNode { struct EngineState { sources: Vec>, sinks: Vec>, - operator_controls: HashMap>>, // operator_id -> vec of control tx + operator_to_node: HashMap, + operator_controls: HashMap>>, // node_id -> vec of control tx shutdown_guard: ShutdownGuard, } pub struct LocalRunner { program: Program, + control_rx: Receiver, } impl LocalRunner { - pub fn new(program: Program) -> Self { - Self { program } + pub fn new(program: Program, control_rx: Receiver) -> Self { + Self { + program, + control_rx, + } } - pub async fn run(self) { + pub async fn run(mut self) { let name = format!("{}-0", 
self.program.name); let total_nodes = self.program.total_nodes(); let engine = Engine::for_local(self.program, name); - let (_running_engine, mut control_rx) = engine - .start(StreamConfig { - restore_epoch: None, - }) - .await; + + let _running_engine = engine.start().await; let mut finished_nodes = HashSet::new(); loop { - while let Some(control_message) = control_rx.recv().await { + while let Some(control_message) = self.control_rx.recv().await { debug!("received {:?}", control_message); if let ControlResp::TaskFinished { - operator_id, + node_id, task_index, } = control_message { - finished_nodes.insert((operator_id, task_index)); + finished_nodes.insert((node_id, task_index)); if finished_nodes.len() == total_nodes { return; } @@ -295,6 +297,7 @@ impl WorkerServer { worker_id: worker_id.0, time: to_micros(c.time), job_id: job_id.clone(), + node_id: c.node_id, operator_id: c.operator_id, subtask_index: c.subtask_index, epoch: c.checkpoint_epoch, @@ -308,6 +311,7 @@ impl WorkerServer { worker_id: worker_id.0, time: c.subtask_metadata.finish_time, job_id: job_id.clone(), + node_id: c.node_id, operator_id: c.operator_id, epoch: c.checkpoint_epoch, needs_commit: false, @@ -315,34 +319,35 @@ impl WorkerServer { } )).await.err() } - Some(ControlResp::TaskFinished { operator_id, task_index }) => { - info!(message = "Task finished", operator_id, task_index); + Some(ControlResp::TaskFinished { node_id, task_index }) => { + info!(message = "Task finished", node_id, task_index); controller.task_finished(Request::new( TaskFinishedReq { worker_id: worker_id.0, job_id: job_id.clone(), time: to_micros(SystemTime::now()), - operator_id: operator_id.to_string(), + node_id, operator_subtask: task_index as u64, } )).await.err() } - Some(ControlResp::TaskFailed { operator_id, task_index, error }) => { + Some(ControlResp::TaskFailed { node_id, task_index, error }) => { controller.task_failed(Request::new( TaskFailedReq { worker_id: worker_id.0, job_id: job_id.clone(), time: 
to_micros(SystemTime::now()), - operator_id: operator_id.to_string(), + node_id, operator_subtask: task_index as u64, error, } )).await.err() } - Some(ControlResp::Error { operator_id, task_index, message, details}) => { + Some(ControlResp::Error { node_id, operator_id, task_index, message, details}) => { controller.worker_error(Request::new( WorkerErrorReq { job_id: job_id.clone(), + node_id, operator_id, task_index: task_index as u32, message, @@ -350,13 +355,13 @@ impl WorkerServer { } )).await.err() } - Some(ControlResp::TaskStarted {operator_id, task_index, start_time}) => { + Some(ControlResp::TaskStarted {node_id, task_index, start_time}) => { controller.task_started(Request::new( TaskStartedReq { worker_id: worker_id.0, job_id: job_id.clone(), time: to_micros(start_time), - operator_id: operator_id.to_string(), + node_id, operator_subtask: task_index as u64, } )).await.err() @@ -412,9 +417,12 @@ impl WorkerGrpc for WorkerServer { let logical = LogicalProgram::try_from(req.program.expect("Program is None")) .expect("Failed to create LogicalProgram"); - if let Ok(v) = to_d2(&logical).await { - debug!("Starting execution for graph\n{}", v); - } + debug!( + "Starting execution for graph\n{}", + to_d2(&logical) + .await + .unwrap_or_else(|e| format!("Failed to generate pipeline visualization: {e}")) + ); for (udf_name, dylib_config) in &logical.program_config.udf_dylibs { info!("Loading UDF {}", udf_name); @@ -441,11 +449,21 @@ impl WorkerGrpc for WorkerServer { })?; } - let (engine, control_rx) = { + let (control_tx, control_rx) = channel(128); + + let engine = { let network = { self.network.lock().unwrap().take().unwrap() }; - let program = - Program::from_logical(self.name.to_string(), &logical.graph, &req.tasks, registry); + let program = Program::from_logical( + self.name.to_string(), + &self.job_id, + &logical.graph, + &req.tasks, + registry, + req.restore_epoch, + control_tx.clone(), + ) + .await; let engine = Engine::new( program, @@ -455,11 +473,7 @@ 
impl WorkerGrpc for WorkerServer { network, req.tasks, ); - engine - .start(StreamConfig { - restore_epoch: req.restore_epoch, - }) - .await + engine.start().await }; self.shutdown_guard @@ -469,11 +483,13 @@ impl WorkerGrpc for WorkerServer { let sources = engine.source_controls(); let sinks = engine.sink_controls(); let operator_controls = engine.operator_controls(); + let operator_to_node = engine.operator_to_node(); let mut state = self.state.lock().unwrap(); *state = Some(EngineState { sources, sinks, + operator_to_node, operator_controls, shutdown_guard: self.shutdown_guard.child("engine-state"), }); @@ -542,6 +558,7 @@ impl WorkerGrpc for WorkerServer { async fn commit(&self, request: Request) -> Result, Status> { let req = request.into_inner(); debug!("received commit request {:?}", req); + let sender_commit_map_pairs = { let state_mutex = self.state.lock().unwrap(); let Some(state) = state_mutex.as_ref() else { @@ -549,9 +566,13 @@ impl WorkerGrpc for WorkerServer { "Worker has not yet started execution", )); }; + let mut sender_commit_map_pairs = vec![]; for (operator_id, commit_operator) in req.committing_data { - let nodes = state.operator_controls.get(&operator_id).unwrap().clone(); + let node_id = state.operator_to_node.get(&operator_id).unwrap_or_else(|| { + panic!("Could not find node for operator id {}", operator_id) + }); + let nodes = state.operator_controls.get(node_id).unwrap().clone(); let commit_map: HashMap<_, _> = commit_operator .committing_data .into_iter() @@ -561,6 +582,7 @@ impl WorkerGrpc for WorkerServer { } sender_commit_map_pairs }; + for (senders, commit_map) in sender_commit_map_pairs { for sender in senders { sender @@ -584,7 +606,7 @@ impl WorkerGrpc for WorkerServer { let nodes = { let state = self.state.lock().unwrap(); let s = state.as_ref().unwrap(); - s.operator_controls.get(&req.operator_id).unwrap().clone() + s.operator_controls.get(&req.node_id).unwrap().clone() }; let compacted: CompactionResult = req.into(); @@ -603,7 
+625,7 @@ impl WorkerGrpc for WorkerServer { } } - return Ok(Response::new(LoadCompactedDataRes {})); + Ok(Response::new(LoadCompactedDataRes {})) } async fn stop_execution( diff --git a/crates/arroyo-worker/src/utils.rs b/crates/arroyo-worker/src/utils.rs index 90385cfa2..357400553 100644 --- a/crates/arroyo-worker/src/utils.rs +++ b/crates/arroyo-worker/src/utils.rs @@ -1,7 +1,8 @@ use crate::engine::construct_operator; use arrow_schema::Schema; -use arroyo_datastream::logical::LogicalProgram; +use arroyo_datastream::logical::{ChainedLogicalOperator, LogicalEdgeType, LogicalProgram}; use arroyo_df::physical::new_registry; +use arroyo_operator::operator::Registry; use std::fmt::Write; use std::sync::Arc; @@ -13,6 +14,68 @@ fn format_arrow_schema_fields(schema: &Schema) -> Vec<(String, String)> { .collect() } +fn write_op(d2: &mut String, registry: &Arc, idx: usize, el: &ChainedLogicalOperator) { + let operator = construct_operator(el.operator_name, &el.operator_config, registry.clone()); + let display = operator.display(); + + let mut label = format!("### {} ({})", operator.name(), &display.name); + for (field, value) in display.fields { + label.push_str(&format!("\n## {}\n\n{}", field, value)); + } + + writeln!( + d2, + "{}: {{ + label: |markdown +{} + | + shape: rectangle +}}", + idx, label + ) + .unwrap(); +} + +fn write_edge( + d2: &mut String, + from: usize, + to: usize, + edge_idx: usize, + edge_type: &LogicalEdgeType, + schema: &Schema, +) { + let edge_label = format!("{}", edge_type); + + let schema_node_name = format!("schema_{}", edge_idx); + let schema_fields = format_arrow_schema_fields(schema); + + writeln!(d2, "{}: {{", schema_node_name).unwrap(); + writeln!(d2, " shape: sql_table").unwrap(); + + for (field_name, field_type) in schema_fields { + writeln!( + d2, + " \"{}\": \"{}\"", + field_name.replace("\"", "\\\""), + field_type.replace("\"", "\\\"") + ) + .unwrap(); + } + + writeln!(d2, "}}").unwrap(); + + writeln!( + d2, + "{} -> {}: \"{}\"", + 
from, + schema_node_name, + edge_label.replace("\"", "\\\"") + ) + .unwrap(); + + writeln!(d2, "{} -> {}", schema_node_name, to).unwrap(); +} + pub async fn to_d2(logical: &LogicalProgram) -> anyhow::Result { let mut registry = new_registry(); @@ -30,66 +93,40 @@ pub async fn to_d2(logical: &LogicalProgram) -> anyhow::Result { for idx in logical.graph.node_indices() { let node = logical.graph.node_weight(idx).unwrap(); - let operator = construct_operator( - node.operator_name, - node.operator_config.clone(), - registry.clone(), - ); - let display = operator.display(); - let mut label = format!("### {} ({})", operator.name(), &display.name); - for (field, value) in display.fields { - label.push_str(&format!("\n## {}\n\n{}", field, value)); + if node.operator_chain.len() == 1 { + let el = node.operator_chain.first(); + write_op(&mut d2, ®istry, idx.index(), el); + } else { + writeln!(d2, "{}: {{", idx.index()).unwrap(); + for (i, (el, edge)) in node.operator_chain.iter().enumerate() { + write_op(&mut d2, ®istry, i, el); + if let Some(edge) = edge { + write_edge( + &mut d2, + i, + i + 1, + i, + &LogicalEdgeType::Forward, + &edge.schema, + ); + } + } + writeln!(d2, "}}").unwrap(); } - - writeln!( - &mut d2, - "{}: {{ - label: |markdown -{} - | - shape: rectangle -}}", - idx.index(), - label - ) - .unwrap(); } for idx in logical.graph.edge_indices() { let edge = logical.graph.edge_weight(idx).unwrap(); let (from, to) = logical.graph.edge_endpoints(idx).unwrap(); - - let edge_label = format!("{}", edge.edge_type); - - let schema_node_name = format!("schema_{}", idx.index()); - let schema_fields = format_arrow_schema_fields(&edge.schema.schema); - - writeln!(&mut d2, "{}: {{", schema_node_name).unwrap(); - writeln!(&mut d2, " shape: sql_table").unwrap(); - - for (field_name, field_type) in schema_fields { - writeln!( - &mut d2, - " \"{}\": \"{}\"", - field_name.replace("\"", "\\\""), - field_type.replace("\"", "\\\"") - ) - .unwrap(); - } - - writeln!(&mut d2, 
"}}").unwrap(); - - writeln!( + write_edge( &mut d2, - "{} -> {}: \"{}\"", from.index(), - schema_node_name, - edge_label.replace("\"", "\\\"") - ) - .unwrap(); - - writeln!(&mut d2, "{} -> {}", schema_node_name, to.index()).unwrap(); + to.index(), + idx.index(), + &edge.edge_type, + &edge.schema.schema, + ); } Ok(d2) diff --git a/crates/integ/tests/api_tests.rs b/crates/integ/tests/api_tests.rs index b5391fefb..2700f5fa9 100644 --- a/crates/integ/tests/api_tests.rs +++ b/crates/integ/tests/api_tests.rs @@ -230,6 +230,16 @@ async fn basic_pipeline() { let (pipeline_id, job_id, _) = start_and_monitor(test_id, &query, &[], 10).await.unwrap(); + let sink_id = valid + .graph + .as_ref() + .unwrap() + .nodes + .iter() + .find(|n| n.description.contains("sink")) + .unwrap() + .node_id; + // get error messages let errors = api_client .get_job_errors() @@ -254,7 +264,7 @@ async fn basic_pipeline() { && metrics .data .iter() - .filter(|op| !op.operator_id.contains("sink")) + .filter(|op| !op.node_id == sink_id) .map(|op| { op.metric_groups .iter() diff --git a/webui/package.json b/webui/package.json index efec76f4f..b8382e68a 100644 --- a/webui/package.json +++ b/webui/package.json @@ -9,7 +9,7 @@ "preview": "vite preview", "format": "npx prettier --write src/ && npx eslint --fix --ext .js,.jsx,.ts,.tsx src", "check": "npx prettier --check src/ && npx eslint --ext .js,.jsx,.ts,.tsx src", - "openapi": "cargo build --package arroyo-openapi && npx openapi-typescript $(pwd)/../target/api-spec.json --output $(pwd)/src/gen/api-types.ts" + "openapi": "cargo build --package arroyo-openapi && pnpm exec openapi-typescript $(pwd)/../target/api-spec.json --output $(pwd)/src/gen/api-types.ts" }, "dependencies": { "@babel/core": "^7.26.0", @@ -47,7 +47,7 @@ "monaco-editor": "^0.34.1", "monaco-sql-languages": "^0.9.5", "openapi-fetch": "^0.6.2", - "openapi-typescript": "^6.7.6", + "openapi-typescript": "=6.2.8", "prop-types": "^15.8.1", "react": "^18.3.1", "react-dom": "^18.3.1", diff 
--git a/webui/pnpm-lock.yaml b/webui/pnpm-lock.yaml index df8c069f7..6d399b5aa 100644 --- a/webui/pnpm-lock.yaml +++ b/webui/pnpm-lock.yaml @@ -114,8 +114,8 @@ importers: specifier: ^0.6.2 version: 0.6.2 openapi-typescript: - specifier: ^6.7.6 - version: 6.7.6 + specifier: '=6.2.8' + version: 6.2.8 prop-types: specifier: ^15.8.1 version: 15.8.1 @@ -2830,8 +2830,8 @@ packages: openapi-fetch@0.6.2: resolution: {integrity: sha512-Faj29Kzh7oCbt1bz6vAGNKtRJlV/GolOQTx87eYUnfCK7eVXdN9jQVojroc7tcJ5OQgyhbeOqD7LS/8UtGBnMQ==} - openapi-typescript@6.7.6: - resolution: {integrity: sha512-c/hfooPx+RBIOPM09GSxABOZhYPblDoyaGhqBkD/59vtpN21jEuWKDlM0KYTvqJVlSYjKs0tBcIdeXKChlSPtw==} + openapi-typescript@6.2.8: + resolution: {integrity: sha512-yA+y5MHiu6cjmtsGfNLavzVuvGCKzjL3H+exgHDPK6bnp6ZVFibtAiafenNSRDWL0x+7Sw/VPv5SbaqiPLW46w==} hasBin: true optionator@0.9.4: @@ -6712,7 +6712,7 @@ snapshots: openapi-fetch@0.6.2: {} - openapi-typescript@6.7.6: + openapi-typescript@6.2.8: dependencies: ansi-colors: 4.1.3 fast-glob: 3.3.2 diff --git a/webui/post b/webui/post deleted file mode 100644 index a185aa4db..000000000 --- a/webui/post +++ /dev/null @@ -1,6 +0,0 @@ -Projection: COUNT(UInt8(1)) AS count, SUM(price) AS price_sum, AVG(price) AS avg_price, MIN(price) AS min_price, MAX(price) AS max_price, SUM(price / price) AS has_price, SUM(price * auction), MIN(price - auction), MAX(price % auction), AVG(price >> auction) - Aggregate: groupBy=[[hop(IntervalDayTime("2000"), IntervalDayTime("10000"))]], aggr=[[COUNT(UInt8(1)), SUM(price), AVG(price), MIN(price), MAX(price), SUM(price / price), SUM(price * auction), MIN(price - auction), MAX(price % auction), AVG(price >> auction)]] - Projection: auction, price - Projection: (nexmark_thousand.bid)[auction] AS auction, (nexmark_thousand.bid)[price] AS price - TableScan: nexmark_thousand - diff --git a/webui/pre b/webui/pre deleted file mode 100644 index a185aa4db..000000000 --- a/webui/pre +++ /dev/null @@ -1,6 +0,0 @@ -Projection: COUNT(UInt8(1)) AS 
count, SUM(price) AS price_sum, AVG(price) AS avg_price, MIN(price) AS min_price, MAX(price) AS max_price, SUM(price / price) AS has_price, SUM(price * auction), MIN(price - auction), MAX(price % auction), AVG(price >> auction) - Aggregate: groupBy=[[hop(IntervalDayTime("2000"), IntervalDayTime("10000"))]], aggr=[[COUNT(UInt8(1)), SUM(price), AVG(price), MIN(price), MAX(price), SUM(price / price), SUM(price * auction), MIN(price - auction), MAX(price % auction), AVG(price >> auction)]] - Projection: auction, price - Projection: (nexmark_thousand.bid)[auction] AS auction, (nexmark_thousand.bid)[price] AS price - TableScan: nexmark_thousand - diff --git a/webui/src/components/OperatorDetail.tsx b/webui/src/components/OperatorDetail.tsx index 8e2cd7052..c288848e8 100644 --- a/webui/src/components/OperatorDetail.tsx +++ b/webui/src/components/OperatorDetail.tsx @@ -18,10 +18,10 @@ import { components } from '../gen/api-types'; export interface OperatorDetailProps { pipelineId: string; jobId: string; - operatorId: string; + nodeId: number; } -const OperatorDetail: React.FC = ({ pipelineId, jobId, operatorId }) => { +const OperatorDetail: React.FC = ({ pipelineId, jobId, nodeId }) => { const { pipeline } = usePipeline(pipelineId); const { operatorMetricGroups, operatorMetricGroupsLoading, operatorMetricGroupsError } = useJobMetrics(pipelineId, jobId); @@ -39,8 +39,8 @@ const OperatorDetail: React.FC = ({ pipelineId, jobId, oper return ; } - const node = pipeline.graph.nodes.find(n => n.nodeId == operatorId); - const operatorMetricGroup = operatorMetricGroups.find(o => o.operatorId == operatorId); + const node = pipeline.graph.nodes.find(n => n.nodeId == nodeId); + const operatorMetricGroup = operatorMetricGroups.find(o => o.nodeId == nodeId); if (!operatorMetricGroup) { return ; @@ -107,7 +107,20 @@ const OperatorDetail: React.FC = ({ pipelineId, jobId, oper Backpressure: {backpressureBadge} - {node?.operator} + + + {node?.nodeId} + + {node?.description} + 
{Math.round(msgRecv)} eps rx {Math.round(msgSent)} eps tx diff --git a/webui/src/gen/api-types.ts b/webui/src/gen/api-types.ts index 114cc2622..b7eb96f3e 100644 --- a/webui/src/gen/api-types.ts +++ b/webui/src/gen/api-types.ts @@ -272,7 +272,7 @@ export interface components { createdAt: number; definition: string; description?: string | null; - dylibUrl: string; + dylibUrl?: string | null; id: string; language: components["schemas"]["UdfLanguage"]; name: string; @@ -355,7 +355,8 @@ export interface components { }; OperatorMetricGroup: { metricGroups: (components["schemas"]["MetricGroup"])[]; - operatorId: string; + /** Format: int32 */ + nodeId: number; }; OperatorMetricGroupCollection: { data: (components["schemas"]["OperatorMetricGroup"])[]; @@ -396,10 +397,12 @@ export interface components { hasMore: boolean; }; PipelineEdge: { - destId: string; + /** Format: int32 */ + destId: number; edgeType: string; keyType: string; - srcId: string; + /** Format: int32 */ + srcId: number; valueType: string; }; PipelineGraph: { @@ -408,7 +411,8 @@ export interface components { }; PipelineNode: { description: string; - nodeId: string; + /** Format: int32 */ + nodeId: number; operator: string; /** Format: int32 */ parallelism: number; @@ -469,6 +473,7 @@ export interface components { SourceField: { fieldName: string; fieldType: components["schemas"]["SourceFieldType"]; + metadataKey?: string | null; nullable: boolean; }; SourceFieldType: { diff --git a/webui/src/routes/connections/ConfluentSchemaEditor.tsx b/webui/src/routes/connections/ConfluentSchemaEditor.tsx index 70f751f8e..23ff528f9 100644 --- a/webui/src/routes/connections/ConfluentSchemaEditor.tsx +++ b/webui/src/routes/connections/ConfluentSchemaEditor.tsx @@ -22,6 +22,7 @@ export function ConfluentSchemaEditor({ }) { let formatEl = null; + // @ts-ignore if (state.schema!.format!.protobuf !== undefined) { formatEl = ( diff --git a/webui/src/routes/pipelines/PipelineDetails.tsx 
b/webui/src/routes/pipelines/PipelineDetails.tsx index 470293af2..368e47615 100644 --- a/webui/src/routes/pipelines/PipelineDetails.tsx +++ b/webui/src/routes/pipelines/PipelineDetails.tsx @@ -53,7 +53,7 @@ import { vs2015 } from 'react-syntax-highlighter/dist/esm/styles/hljs'; import { useNavbar } from '../../App'; export function PipelineDetails() { - const [activeOperator, setActiveOperator] = useState(undefined); + const [activeOperator, setActiveOperator] = useState(undefined); const { isOpen: configModalOpen, onOpen: onConfigModalOpen, @@ -103,9 +103,9 @@ export function PipelineDetails() { } let operatorDetail = undefined; - if (activeOperator) { + if (activeOperator != undefined) { operatorDetail = ( - + ); } diff --git a/webui/src/routes/pipelines/PipelineGraph.tsx b/webui/src/routes/pipelines/PipelineGraph.tsx index aa12cd408..a566b5332 100644 --- a/webui/src/routes/pipelines/PipelineGraph.tsx +++ b/webui/src/routes/pipelines/PipelineGraph.tsx @@ -10,7 +10,7 @@ function PipelineGraphNode({ }: { data: { node: PipelineNode; - setActiveOperator: (op: string) => void; + setActiveOperator: (op: number) => void; isActive: boolean; operatorBackpressure: number; }; @@ -47,15 +47,15 @@ export function PipelineGraphViewer({ }: { graph: PipelineGraph; operatorMetricGroups?: OperatorMetricGroup[]; - setActiveOperator: (op: string) => void; - activeOperator?: string; + setActiveOperator: (node: number) => void; + activeOperator?: number; }) { const nodeTypes = useMemo(() => ({ pipelineNode: PipelineGraphNode }), []); const nodes = graph.nodes.map(node => { let backpressure = 0; if (operatorMetricGroups && operatorMetricGroups.length > 0) { - const operatorMetricGroup = operatorMetricGroups.find(o => o.operatorId == node.nodeId); + const operatorMetricGroup = operatorMetricGroups.find(o => o.nodeId == node.nodeId); if (operatorMetricGroup) { const metricGroups = operatorMetricGroup.metricGroups; const backpressureMetrics = metricGroups.find(m => m.name == 
'backpressure'); @@ -64,12 +64,15 @@ export function PipelineGraphViewer({ } return { - id: node.nodeId, + id: String(node.nodeId), type: 'pipelineNode', data: { label: node.description, node: node, - setActiveOperator: setActiveOperator, + setActiveOperator: () => { + console.log(node); + return setActiveOperator(node.nodeId); + }, isActive: node.nodeId == activeOperator, operatorBackpressure: backpressure, }, @@ -87,8 +90,8 @@ export function PipelineGraphViewer({ const edges = graph.edges.map(edge => { return { id: `${edge.srcId}-${edge.destId}`, - source: edge.srcId, - target: edge.destId, + source: String(edge.srcId), + target: String(edge.destId), type: 'step', }; }); @@ -99,8 +102,8 @@ export function PipelineGraphViewer({ return {}; }); - nodes.forEach(node => g.setNode(node.id, node)); - edges.forEach(edge => g.setEdge(edge.source, edge.target)); + nodes.forEach(node => g.setNode(String(node.id), node)); + edges.forEach(edge => g.setEdge(String(edge.source), String(edge.target))); dagre.layout(g);