Skip to content

Commit

Permalink
Change LoadBalancingStrategy trait to always accept the list of serve…
Browse files Browse the repository at this point in the history
…rs from the outside
  • Loading branch information
madchicken committed Jan 3, 2025
1 parent 73b9591 commit 7867098
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 25 deletions.
13 changes: 11 additions & 2 deletions lib/src/routing/connection_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ impl ConnectionRegistry {
pub fn mark_unavailable(&self, server: &BoltServer) {
self.connections.remove(server);
}

pub fn servers(&self) -> Vec<BoltServer> {
self.connections.iter().map(|entry| entry.key().clone()).collect()
}
}

#[cfg(test)]
Expand Down Expand Up @@ -176,17 +180,22 @@ mod tests {
.unwrap();
assert_eq!(registry.connections.len(), 5);
let strategy = RoundRobinStrategy::new(&cluster_routing_table.resolve());
let router = strategy.select_router().unwrap();
let router = strategy.select_router(&registry.servers()).unwrap();
assert_eq!(
format!("{}:{}", router.address, router.port),
routers[0].addresses[0]
);
registry.mark_unavailable(BoltServer::resolve(&writers[0]).first().unwrap());
assert_eq!(registry.connections.len(), 4);
let writer = strategy.select_writer().unwrap();
let writer = strategy.select_writer(&registry.servers()).unwrap();
assert_eq!(
format!("{}:{}", writer.address, writer.port),
writers[1].addresses[0]
);

registry.mark_unavailable(BoltServer::resolve(&writers[1]).first().unwrap());
assert_eq!(registry.connections.len(), 3);
let writer = strategy.select_writer(&registry.servers());
assert!(writer.is_none());
}
}
6 changes: 3 additions & 3 deletions lib/src/routing/load_balancing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub(crate) mod round_robin_strategy;
use crate::routing::connection_registry::BoltServer;

pub trait LoadBalancingStrategy: Sync + Send {
fn select_reader(&self) -> Option<BoltServer>;
fn select_writer(&self) -> Option<BoltServer>;
fn select_router(&self) -> Option<BoltServer>;
fn select_reader(&self, servers: &[BoltServer]) -> Option<BoltServer>;
fn select_writer(&self, servers: &[BoltServer]) -> Option<BoltServer>;
fn select_router(&self, servers: &[BoltServer]) -> Option<BoltServer>;
}
41 changes: 25 additions & 16 deletions lib/src/routing/load_balancing/round_robin_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ use crate::routing::load_balancing::LoadBalancingStrategy;
use std::sync::atomic::AtomicUsize;

pub struct RoundRobinStrategy {
readers: Vec<BoltServer>,
writers: Vec<BoltServer>,
routers: Vec<BoltServer>,
reader_index: AtomicUsize,
writer_index: AtomicUsize,
router_index: AtomicUsize,
Expand All @@ -32,9 +29,6 @@ impl RoundRobinStrategy {
let writer_index = AtomicUsize::new(writers.len());
let router_index = AtomicUsize::new(routers.len());
RoundRobinStrategy {
readers,
writers,
routers,
reader_index,
writer_index,
router_index,
Expand Down Expand Up @@ -66,16 +60,31 @@ impl RoundRobinStrategy {
}

impl LoadBalancingStrategy for RoundRobinStrategy {
fn select_reader(&self) -> Option<BoltServer> {
Self::select(&self.readers, &self.reader_index)
fn select_reader(&self, servers: &[BoltServer]) -> Option<BoltServer> {
let readers = servers
.iter()
.filter(|s| s.role == "READ")
.cloned()
.collect::<Vec<BoltServer>>();
Self::select(&readers, &self.reader_index)
}

fn select_writer(&self) -> Option<BoltServer> {
Self::select(&self.writers, &self.writer_index)
fn select_writer(&self, servers: &[BoltServer]) -> Option<BoltServer> {
let writers = servers
.iter()
.filter(|s| s.role == "WRITE")
.cloned()
.collect::<Vec<BoltServer>>();
Self::select(&writers, &self.writer_index)
}

fn select_router(&self) -> Option<BoltServer> {
Self::select(&self.routers, &self.router_index)
fn select_router(&self, servers: &[BoltServer]) -> Option<BoltServer> {
let routers = servers
.iter()
.filter(|s| s.role == "ROUTE")
.cloned()
.collect::<Vec<BoltServer>>();
Self::select(&routers, &self.router_index)
}
}

Expand Down Expand Up @@ -116,22 +125,22 @@ mod tests {
assert_eq!(all_servers.len(), 4);
let strategy = RoundRobinStrategy::new(&cluster_routing_table.resolve());

let reader = strategy.select_reader().unwrap();
let reader = strategy.select_reader(&all_servers).unwrap();
assert_eq!(
format!("{}:{}", reader.address, reader.port),
readers[0].addresses[1]
);
let reader = strategy.select_reader().unwrap();
let reader = strategy.select_reader(&all_servers).unwrap();
assert_eq!(
format!("{}:{}", reader.address, reader.port),
readers[0].addresses[0]
);
let reader = strategy.select_reader().unwrap();
let reader = strategy.select_reader(&all_servers).unwrap();
assert_eq!(
format!("{}:{}", reader.address, reader.port),
readers[0].addresses[1]
);
let writer = strategy.select_writer().unwrap();
let writer = strategy.select_writer(&all_servers).unwrap();
assert_eq!(
format!("{}:{}", writer.address, writer.port),
writers[0].addresses[0]
Expand Down
20 changes: 16 additions & 4 deletions lib/src/routing/routed_connection_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::pool::ManagedConnection;
use crate::routing::connection_registry::ConnectionRegistry;
use crate::routing::connection_registry::{BoltServer, ConnectionRegistry};
use crate::routing::load_balancing::LoadBalancingStrategy;
use crate::{Config, Error, Operation};
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
Expand Down Expand Up @@ -48,7 +48,7 @@ impl RoutedConnectionManager {
}

pub async fn refresh_routing_table(&self) -> Result<RoutingTable, Error> {
while let Some(router) = self.load_balancing_strategy.select_router() {
while let Some(router) = self.select_router() {
if let Some(pool) = self.registry.get_pool(&router) {
if let Ok(mut connection) = pool.get().await {
info!("Refreshing routing table from router {}", router.address);
Expand Down Expand Up @@ -99,8 +99,8 @@ impl RoutedConnectionManager {

let op = operation.unwrap_or(Operation::Write);
while let Some(server) = match op {
Operation::Write => self.load_balancing_strategy.select_writer(),
_ => self.load_balancing_strategy.select_reader(),
Operation::Write => self.select_writer(),
_ => self.select_reader(),
} {
debug!("requesting connection for server: {:?}", server);
if let Some(pool) = self.registry.get_pool(&server) {
Expand Down Expand Up @@ -132,4 +132,16 @@ impl RoutedConnectionManager {
pub(crate) fn backoff(&self) -> ExponentialBackoff {
self.backoff.as_ref().clone()
}

fn select_reader(&self) -> Option<BoltServer> {
self.load_balancing_strategy.select_reader(&self.registry.servers())
}

fn select_writer(&self) -> Option<BoltServer> {
self.load_balancing_strategy.select_writer(&self.registry.servers())
}

fn select_router(&self) -> Option<BoltServer> {
self.load_balancing_strategy.select_router(&self.registry.servers())
}
}

0 comments on commit 7867098

Please sign in to comment.