From 15fb25e4e62fd6aa23e6dd8cef7fc65808bf29ca Mon Sep 17 00:00:00 2001 From: Arpad Borsos Date: Thu, 11 Apr 2024 11:36:35 +0200 Subject: [PATCH 1/3] Update to Rust 2021, and auto-apply clippy lints --- Cargo.toml | 2 +- src/client.rs | 30 ++++++----- src/connection.rs | 34 ++++++------- src/lib.rs | 2 +- src/protocol/ascii.rs | 39 +++++++-------- tests/tests.rs | 113 +++++++++++++++++++----------------------- 6 files changed, 103 insertions(+), 117 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9556522..b80dfe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ readme = "README.md" license = "MIT" description = "memcached client for rust" keywords = ["memcache", "memcached", "driver", "cache", "database"] -edition = "2018" +edition = "2021" [features] default = ["tls"] diff --git a/src/client.rs b/src/client.rs index 38274d0..cd05913 100644 --- a/src/client.rs +++ b/src/client.rs @@ -102,7 +102,7 @@ impl Client { let parsed = Url::parse(url.as_str())?; let timeout = parsed .query_pairs() - .find(|&(ref k, ref _v)| k == "connect_timeout") + .find(|(k, _v)| k == "connect_timeout") .and_then(|(ref _k, ref v)| v.parse::().ok()) .map(Duration::from_secs_f64); let builder = r2d2::Pool::builder().max_size(size); @@ -256,7 +256,7 @@ impl Client { for key in keys { let connection_index = (self.hash_function)(key) as usize % connections_count; - let array = con_keys.entry(connection_index).or_insert_with(Vec::new); + let array = con_keys.entry(connection_index).or_default(); array.push(key); } for (&connection_index, keys) in con_keys.iter() { @@ -467,6 +467,12 @@ pub struct ClientBuilder { hash_function: fn(&str) -> u64, } +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + impl ClientBuilder { /// Create an empty client builder. pub fn new() -> Self { @@ -486,7 +492,7 @@ impl ClientBuilder { pub fn add_server(mut self, target: C) -> Result { let targets = target.get_urls(); - if targets.len() == 0 { + if targets.is_empty() { return Err(MemcacheError::BadURL("No servers specified".to_string())); } @@ -540,7 +546,7 @@ impl ClientBuilder { pub fn build(self) -> Result { let urls = self.targets; - if urls.len() == 0 { + if urls.is_empty() { return Err(MemcacheError::BadURL("No servers specified".to_string())); } @@ -572,7 +578,7 @@ impl ClientBuilder { let connection = builder .build(ConnectionManager::new(url)) - .map_err(|e| MemcacheError::PoolError(e))?; + .map_err(MemcacheError::PoolError)?; connections.push(connection); } @@ -600,7 +606,7 @@ mod tests { .unwrap() .build() .unwrap(); - assert!(client.version().unwrap()[0].1 != ""); + assert!(!client.version().unwrap()[0].1.is_empty()); } #[test] @@ -710,14 +716,14 @@ mod tests { #[test] fn unix() { let client = super::Client::connect("memcache:///tmp/memcached.sock").unwrap(); - assert!(client.version().unwrap()[0].1 != ""); + assert!(!client.version().unwrap()[0].1.is_empty()); } #[cfg(feature = "tls")] #[test] fn ssl_noverify() { let client = super::Client::connect("memcache+tls://localhost:12350?verify_mode=none").unwrap(); - assert!(client.version().unwrap()[0].1 != ""); + assert!(!client.version().unwrap()[0].1.is_empty()); } #[cfg(feature = "tls")] @@ -726,22 +732,22 @@ mod tests { let client = super::Client::connect("memcache+tls://localhost:12350?ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt") .unwrap(); - assert!(client.version().unwrap()[0].1 != ""); + assert!(!client.version().unwrap()[0].1.is_empty()); } #[cfg(feature = "tls")] #[test] fn ssl_client_certs() { let client = super::Client::connect("memcache+tls://localhost:12351?key_path=tests/assets/client.key&cert_path=tests/assets/client.crt&ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt").unwrap(); - assert!(client.version().unwrap()[0].1 != ""); + assert!(!client.version().unwrap()[0].1.is_empty()); } #[test] fn delete() { let client = super::Client::connect("memcache://localhost:12345").unwrap(); client.set("an_exists_key", "value", 0).unwrap(); - assert_eq!(client.delete("an_exists_key").unwrap(), true); - assert_eq!(client.delete("a_not_exists_key").unwrap(), false); + assert!(client.delete("an_exists_key").unwrap()); + assert!(!client.delete("a_not_exists_key").unwrap()); } #[test] diff --git a/src/connection.rs b/src/connection.rs index 200fe02..5f1dcba 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -96,16 +96,13 @@ struct TcpOptions { #[cfg(feature = "tls")] fn get_param(url: &Url, key: &str) -> Option { - return url - .query_pairs() - .find(|&(ref k, ref _v)| k == key) - .map(|(_k, v)| v.to_string()); + url.query_pairs().find(|(k, _v)| k == key).map(|(_k, v)| v.to_string()) } #[cfg(feature = "tls")] impl TlsOptions { fn from_url(url: &Url) -> Result { - let verify_mode = match get_param(url, "verify_mode").as_ref().map(String::as_str) { + let verify_mode = match get_param(url, "verify_mode").as_deref() { Some("none") => SslVerifyMode::NONE, Some("peer") => SslVerifyMode::PEER, Some(_) => { @@ -132,10 +129,10 @@ impl TlsOptions { Ok(TlsOptions { tcp_options: TcpOptions::from_url(url), - ca_path: ca_path, - key_path: key_path, - cert_path: cert_path, - verify_mode: verify_mode, + ca_path, + key_path, + cert_path, + verify_mode, }) } } @@ -147,21 +144,18 @@ impl TcpOptions { .any(|(ref k, ref v)| k == "tcp_nodelay" && v == "false"); let timeout = url .query_pairs() - .find(|&(ref k, ref _v)| k == "timeout") + .find(|(k, _v)| k == "timeout") .and_then(|(ref _k, ref v)| v.parse::().ok()) .map(Duration::from_secs_f64); - TcpOptions { - nodelay: nodelay, - timeout: timeout, - } + TcpOptions { nodelay, timeout } } } impl Transport { fn from_url(url: &Url) -> Result { - let mut parts = url.scheme().splitn(2, "+"); + let mut parts = url.scheme().splitn(2, '+'); match parts.next() { - Some(part) if part == "memcache" => (), + Some("memcache") => (), _ => { return Err(MemcacheError::BadURL( "memcache URL's scheme should start with 'memcache'".into(), @@ -191,7 +185,7 @@ impl Transport { #[cfg(unix)] { - if url.host().is_none() && url.port() == None { + if url.host().is_none() && url.port().is_none() { return Ok(Transport::Unix); } } @@ -254,12 +248,12 @@ impl Connection { let protocol = if is_ascii { Protocol::Ascii(AsciiProtocol::new(stream)) } else { - Protocol::Binary(BinaryProtocol { stream: stream }) + Protocol::Binary(BinaryProtocol { stream }) }; Ok(Connection { url: Arc::new(url.to_string()), - protocol: protocol, + protocol, }) } } @@ -273,7 +267,7 @@ mod tests { use url::Url; match Transport::from_url(&Url::parse("memcache:///tmp/memcached.sock").unwrap()).unwrap() { Transport::Unix => (), - _ => assert!(false, "transport is not unix"), + _ => panic!("transport is not unix"), } } } diff --git a/src/lib.rs b/src/lib.rs index 69e8473..fc69192 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,7 +63,7 @@ assert_eq!(answer, 42); ``` !*/ -#![cfg_attr(feature = "cargo-clippy", allow(clippy::needless_return))] +#![allow(clippy::needless_return)] extern crate byteorder; extern crate enum_dispatch; diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index f92d55d..ad84a8a 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -13,7 +13,6 @@ use std::borrow::Cow; pub struct Options { pub noreply: bool, pub exptime: u32, - pub flags: u32, pub cas: Option, } @@ -27,7 +26,7 @@ enum StoreCommand { Prepend, } -const END: &'static str = "END\r\n"; +const END: &str = "END\r\n"; impl fmt::Display for StoreCommand { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -45,15 +44,13 @@ impl fmt::Display for StoreCommand { struct CappedLineReader { inner: C, filled: usize, - buf: [u8; 2048], + buf: Vec, } fn get_line(buf: &[u8]) -> Option { for (i, r) in buf.iter().enumerate() { - if *r == b'\r' { - if buf.get(i + 1) == Some(&b'\n') { - return Some(i + 2); - } + if *r == b'\r' && buf.get(i + 1) == Some(&b'\n') { + return Some(i + 2); } } None @@ -64,7 +61,7 @@ impl CappedLineReader { Self { inner, filled: 0, - buf: [0x0; 2048], + buf: vec![0x0; 2048], } } @@ -77,7 +74,7 @@ impl CappedLineReader { let (to_fill, rest) = buf.split_at_mut(min); to_fill.copy_from_slice(&self.buf[..min]); self.consume(min); - if rest.len() != 0 { + if !rest.is_empty() { self.inner.read_exact(&mut rest[..])?; } Ok(()) @@ -98,7 +95,7 @@ impl CappedLineReader { } loop { let (_filled, buf) = self.buf.split_at_mut(self.filled); - if buf.len() == 0 { + if buf.is_empty() { return Err(ClientError::Error(Cow::Borrowed("Ascii protocol response too long")))?; } let read = self.inner.read(&mut buf[..])?; @@ -131,7 +128,7 @@ impl ProtocolTrait for AsciiProtocol { } fn version(&mut self) -> Result { - self.reader.get_mut().write(b"version\r\n")?; + self.reader.get_mut().write_all(b"version\r\n")?; self.reader.get_mut().flush()?; self.reader.read_line(|response| { let response = MemcacheError::try_from(response)?; @@ -294,7 +291,7 @@ impl ProtocolTrait for AsciiProtocol { } fn stats(&mut self) -> Result { - self.reader.get_mut().write(b"stats\r\n")?; + self.reader.get_mut().write_all(b"stats\r\n")?; self.reader.get_mut().flush()?; enum Loop { @@ -312,7 +309,7 @@ impl ProtocolTrait for AsciiProtocol { if !s.starts_with("STAT") { return Err(ServerError::BadResponse(Cow::Owned(s.into())))?; } - let stat: Vec<_> = s.trim_end_matches("\r\n").split(" ").collect(); + let stat: Vec<_> = s.trim_end_matches("\r\n").split(' ').collect(); if stat.len() < 3 { return Err(ServerError::BadResponse(Cow::Owned(s.into())).into()); } @@ -348,12 +345,10 @@ impl AsciiProtocol { value: V, options: &Options, ) -> Result { - if command == StoreCommand::Cas { - if options.cas.is_none() { - Err(ClientError::Error(Cow::Borrowed( - "cas_id should be present when using cas command", - )))?; - } + if command == StoreCommand::Cas && options.cas.is_none() { + Err(ClientError::Error(Cow::Borrowed( + "cas_id should be present when using cas command", + )))?; } let noreply = if options.noreply { " noreply" } else { "" }; if options.cas.is_some() { @@ -382,7 +377,7 @@ impl AsciiProtocol { } value.write_to(self.reader.get_mut())?; - self.reader.get_mut().write(b"\r\n")?; + self.reader.get_mut().write_all(b"\r\n")?; self.reader.get_mut().flush()?; if options.noreply { @@ -424,7 +419,7 @@ impl AsciiProtocol { if !buf.starts_with("VALUE") { return Err(ServerError::BadResponse(Cow::Owned(buf.into())))?; } - let mut header = buf.trim_end_matches("\r\n").split(" "); + let mut header = buf.trim_end_matches("\r\n").split(' '); let mut next_or_err = || { header .next() @@ -481,7 +476,7 @@ mod tests { match self.reads.pop_front() { Some(range) => { let range = &self.data[range]; - (&mut buf[0..range.len()]).copy_from_slice(&range); + buf[0..range.len()].copy_from_slice(range); Ok(range.len()) } None => Err(std::io::ErrorKind::WouldBlock.into()), diff --git a/tests/tests.rs b/tests/tests.rs index 947084b..bff9d07 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -5,7 +5,6 @@ use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use std::iter; use std::thread; -use std::thread::JoinHandle; use std::time; fn gen_random_key() -> String { @@ -13,7 +12,7 @@ fn gen_random_key() -> String { .map(|()| thread_rng().sample(Alphanumeric)) .take(10) .collect::>(); - return String::from_utf8(bs).unwrap(); + String::from_utf8(bs).unwrap() } #[test] @@ -112,7 +111,7 @@ fn udp_test() { client.set("foo", "bar", 0).unwrap(); let value = client.add("foo", "baz", 0); - assert_eq!(value.is_err(), true); + assert!(value.is_err()); client.delete("foo").unwrap(); let value: Option = client.get("foo").unwrap(); @@ -143,12 +142,12 @@ fn udp_test() { let value: Option = client.get("fooo").unwrap(); assert_eq!(value, Some(String::from("0"))); - assert_eq!(client.touch("foooo", 123).unwrap(), false); - assert_eq!(client.touch("fooo", 12345).unwrap(), true); + assert!(!client.touch("foooo", 123).unwrap()); + assert!(client.touch("fooo", 12345).unwrap()); // gets is not supported for udp let value: Result, _> = client.gets(&["foo", "fooo"]); - assert_eq!(value.is_ok(), false); + assert!(value.is_err()); let mut keys: Vec = Vec::new(); for _ in 0..1000 { @@ -164,46 +163,47 @@ fn udp_test() { } // test with multiple udp connections - let mut handles: Vec>> = Vec::new(); - for i in 0..10 { - handles.push(Some(thread::spawn(move || { - let key = format!("key{}", i); - let value = format!("value{}", i); - let client = memcache::Client::connect("memcache://localhost:22345?udp=true").unwrap(); - for j in 0..50 { - let value = format!("{}{}", value, j); - client.set(key.as_str(), &value, 0).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result.as_ref(), Some(&value)); - - let result = client.add(key.as_str(), &value, 0); - assert_eq!(result.is_err(), true); - - client.delete(key.as_str()).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result, None); - - client.add(key.as_str(), &value, 0).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result.as_ref(), Some(&value)); - - client.replace(key.as_str(), &value, 0).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result.as_ref(), Some(&value)); - - client.append(key.as_str(), &value).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result, Some(format!("{}{}", value, value))); - - client.prepend(key.as_str(), &value).unwrap(); - let result: Option = client.get(key.as_str()).unwrap(); - assert_eq!(result, Some(format!("{}{}{}", value, value, value))); - } - }))); - } - - for i in 0..10 { - handles[i].take().unwrap().join().unwrap(); + let handles: Vec<_> = (0..10) + .map(|i| { + thread::spawn(move || { + let key = format!("key{}", i); + let value = format!("value{}", i); + let client = memcache::Client::connect("memcache://localhost:22345?udp=true").unwrap(); + for j in 0..50 { + let value = format!("{}{}", value, j); + client.set(key.as_str(), &value, 0).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result.as_ref(), Some(&value)); + + let result = client.add(key.as_str(), &value, 0); + assert!(result.is_err()); + + client.delete(key.as_str()).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result, None); + + client.add(key.as_str(), &value, 0).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result.as_ref(), Some(&value)); + + client.replace(key.as_str(), &value, 0).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result.as_ref(), Some(&value)); + + client.append(key.as_str(), &value).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result, Some(format!("{}{}", value, value))); + + client.prepend(key.as_str(), &value).unwrap(); + let result: Option = client.get(key.as_str()).unwrap(); + assert_eq!(result, Some(format!("{}{}{}", value, value, value))); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); } } @@ -232,21 +232,12 @@ fn test_cas() { assert!(ascii_foo_value.2.is_some()); assert!(ascii_baz_value.2.is_some()); - assert_eq!( - true, - client.cas("ascii_foo", "bar2", 0, ascii_foo_value.2.unwrap()).unwrap() - ); - assert_eq!( - false, - client.cas("ascii_foo", "bar3", 0, ascii_foo_value.2.unwrap()).unwrap() - ); - - assert_eq!( - false, - client - .cas("not_exists_key", "bar", 0, ascii_foo_value.2.unwrap()) - .unwrap() - ); + assert!(client.cas("ascii_foo", "bar2", 0, ascii_foo_value.2.unwrap()).unwrap()); + assert!(!client.cas("ascii_foo", "bar3", 0, ascii_foo_value.2.unwrap()).unwrap()); + + assert!(!client + .cas("not_exists_key", "bar", 0, ascii_foo_value.2.unwrap()) + .unwrap()); client.flush().unwrap(); } } From 40baf2124521bdd2ac6de0161f3f23065746d202 Mon Sep 17 00:00:00 2001 From: Arpad Borsos Date: Thu, 11 Apr 2024 11:56:03 +0200 Subject: [PATCH 2/3] Use more idiomatic Rust Avoids a bunch of `return` statements in favor of shorthands. Removes a useless `Arc` indirection, and avoids `Arc::clone` within `get_connection`. --- src/client.rs | 46 +++++++++++++++++++++++----------------------- src/connection.rs | 7 +++---- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/client.rs b/src/client.rs index cd05913..446b86d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -75,7 +75,7 @@ unsafe impl Send for Client {} fn default_hash_function(key: &str) -> u64 { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); - return hasher.finish(); + hasher.finish() } pub(crate) fn check_key_len(key: &str) -> Result<(), MemcacheError> { @@ -88,7 +88,7 @@ pub(crate) fn check_key_len(key: &str) -> Result<(), MemcacheError> { impl Client { #[deprecated(since = "0.10.0", note = "please use `connect` instead")] pub fn new(target: C) -> Result { - return Self::connect(target); + Self::connect(target) } pub fn builder() -> ClientBuilder { @@ -131,9 +131,9 @@ impl Client { Self::builder().add_server(target)?.build() } - fn get_connection(&self, key: &str) -> Pool { + fn get_connection(&self, key: &str) -> &Pool { let connections_count = self.connections.len(); - return self.connections[(self.hash_function)(key) as usize % connections_count].clone(); + &self.connections[(self.hash_function)(key) as usize % connections_count] } /// Set the socket read timeout for TCP connections. @@ -145,7 +145,7 @@ impl Client { /// client.set_read_timeout(Some(::std::time::Duration::from_secs(3))).unwrap(); /// ``` pub fn set_read_timeout(&self, timeout: Option) -> Result<(), MemcacheError> { - for conn in self.connections.iter() { + for conn in &self.connections { let mut conn = conn.get()?; match **conn { Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?, @@ -164,7 +164,7 @@ impl Client { /// client.set_write_timeout(Some(::std::time::Duration::from_secs(3))).unwrap(); /// ``` pub fn set_write_timeout(&self, timeout: Option) -> Result<(), MemcacheError> { - for conn in self.connections.iter() { + for conn in &self.connections { let mut conn = conn.get()?; match **conn { Protocol::Ascii(ref mut protocol) => protocol.stream().set_write_timeout(timeout)?, @@ -184,7 +184,7 @@ impl Client { /// ``` pub fn version(&self) -> Result, MemcacheError> { let mut result = Vec::with_capacity(self.connections.len()); - for connection in self.connections.iter() { + for connection in &self.connections { let mut connection = connection.get()?; let url = connection.get_url(); result.push((url, connection.version()?)); @@ -201,10 +201,10 @@ impl Client { /// client.flush().unwrap(); /// ``` pub fn flush(&self) -> Result<(), MemcacheError> { - for connection in self.connections.iter() { + for connection in &self.connections { connection.get()?.flush()?; } - return Ok(()); + Ok(()) } /// Flush all cache on memcached server with a delay seconds. @@ -216,10 +216,10 @@ impl Client { /// client.flush_with_delay(10).unwrap(); /// ``` pub fn flush_with_delay(&self, delay: u32) -> Result<(), MemcacheError> { - for connection in self.connections.iter() { + for connection in &self.connections { connection.get()?.flush_with_delay(delay)?; } - return Ok(()); + Ok(()) } /// Get a key from memcached server. @@ -232,7 +232,7 @@ impl Client { /// ``` pub fn get(&self, key: &str) -> Result, MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.get(key); + self.get_connection(key).get()?.get(key) } /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce network workloads. @@ -263,7 +263,7 @@ impl Client { let connection = self.connections[connection_index].clone(); result.extend(connection.get()?.gets(keys)?); } - return Ok(result); + Ok(result) } /// Set a key with associate value into memcached server with expiration seconds. @@ -277,7 +277,7 @@ impl Client { /// ``` pub fn set>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.set(key, value, expiration); + self.get_connection(key).get()?.set(key, value, expiration) } /// Compare and swap a key with the associate value into memcached server with expiration seconds. @@ -319,7 +319,7 @@ impl Client { /// ``` pub fn add>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.add(key, value, expiration); + self.get_connection(key).get()?.add(key, value, expiration) } /// Replace a key with associate value into memcached server with expiration seconds. @@ -340,7 +340,7 @@ impl Client { expiration: u32, ) -> Result<(), MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.replace(key, value, expiration); + self.get_connection(key).get()?.replace(key, value, expiration) } /// Append value to the key. @@ -358,7 +358,7 @@ impl Client { /// ``` pub fn append>(&self, key: &str, value: V) -> Result<(), MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.append(key, value); + self.get_connection(key).get()?.append(key, value) } /// Prepend value to the key. @@ -376,7 +376,7 @@ impl Client { /// ``` pub fn prepend>(&self, key: &str, value: V) -> Result<(), MemcacheError> { check_key_len(key)?; - return self.get_connection(key).get()?.prepend(key, value); + self.get_connection(key).get()?.prepend(key, value) } /// Delete a key from memcached server. @@ -390,7 +390,7 @@ impl Client { /// ``` pub fn delete(&self, key: &str) -> Result { check_key_len(key)?; - return self.get_connection(key).get()?.delete(key); + self.get_connection(key).get()?.delete(key) } /// Increment the value with amount. @@ -404,7 +404,7 @@ impl Client { /// ``` pub fn increment(&self, key: &str, amount: u64) -> Result { check_key_len(key)?; - return self.get_connection(key).get()?.increment(key, amount); + self.get_connection(key).get()?.increment(key, amount) } /// Decrement the value with amount. @@ -418,7 +418,7 @@ impl Client { /// ``` pub fn decrement(&self, key: &str, amount: u64) -> Result { check_key_len(key)?; - return self.get_connection(key).get()?.decrement(key, amount); + self.get_connection(key).get()?.decrement(key, amount) } /// Set a new expiration time for a exist key. @@ -434,7 +434,7 @@ impl Client { /// ``` pub fn touch(&self, key: &str, expiration: u32) -> Result { check_key_len(key)?; - return self.get_connection(key).get()?.touch(key, expiration); + self.get_connection(key).get()?.touch(key, expiration) } /// Get all servers' statistics. @@ -452,7 +452,7 @@ impl Client { let url = connection.get_url(); result.push((url, stats_info)); } - return Ok(result); + Ok(result) } } diff --git a/src/connection.rs b/src/connection.rs index 5f1dcba..b99b007 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -2,7 +2,6 @@ use std::net::TcpStream; use std::ops::{Deref, DerefMut}; #[cfg(unix)] use std::os::unix::net::UnixStream; -use std::sync::Arc; use std::time::Duration; use url::Url; @@ -18,7 +17,7 @@ use r2d2::ManageConnection; /// A connection to the memcached server pub struct Connection { pub protocol: Protocol, - pub url: Arc, + pub url: String, } impl DerefMut for Connection { @@ -206,7 +205,7 @@ fn tcp_stream(url: &Url, opts: &TcpOptions) -> Result impl Connection { pub(crate) fn get_url(&self) -> String { - self.url.to_string() + self.url.clone() } pub(crate) fn connect(url: &Url) -> Result { @@ -252,7 +251,7 @@ impl Connection { }; Ok(Connection { - url: Arc::new(url.to_string()), + url: url.to_string(), protocol, }) } From cb8fe1614cc40a0d44bfc04a1a0a3321e7f3f075 Mon Sep 17 00:00:00 2001 From: Arpad Borsos Date: Thu, 11 Apr 2024 12:16:55 +0200 Subject: [PATCH 3/3] Avoid intermediate `HashMap` in `gets` This uses a `Vec` instead, plus iterator chaining, and avoids a `Clone` of the `Connection`. --- src/client.rs | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/client.rs b/src/client.rs index 446b86d..f8e63ed 100644 --- a/src/client.rs +++ b/src/client.rs @@ -136,6 +136,26 @@ impl Client { &self.connections[(self.hash_function)(key) as usize % connections_count] } + /// Distributes the input `keys` to the available `connections`. + /// + /// This uses the `hash_function` internally, and the returned [`Vec`] matches + /// the available `connections`. + fn distribute_keys<'a>(&self, keys: &[&'a str]) -> Result>, MemcacheError> { + for key in keys { + check_key_len(key)?; + } + + let connections_count = self.connections.len(); + let mut con_keys = Vec::new(); + con_keys.resize_with(connections_count, Vec::new); + for key in keys { + let connection_index = (self.hash_function)(key) as usize % connections_count; + con_keys[connection_index].push(*key); + } + + Ok(con_keys) + } + /// Set the socket read timeout for TCP connections. /// /// Example: @@ -247,21 +267,13 @@ impl Client { /// assert_eq!(result["foo"], "42"); /// ``` pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { - for key in keys { - check_key_len(key)?; - } - let mut con_keys: HashMap> = HashMap::new(); - let mut result: HashMap = HashMap::new(); - let connections_count = self.connections.len(); + let distributed_keys = self.distribute_keys(keys)?; - for key in keys { - let connection_index = (self.hash_function)(key) as usize % connections_count; - let array = con_keys.entry(connection_index).or_default(); - array.push(key); - } - for (&connection_index, keys) in con_keys.iter() { - let connection = self.connections[connection_index].clone(); - result.extend(connection.get()?.gets(keys)?); + let mut result: HashMap = HashMap::new(); + for (connection, keys) in self.connections.iter().zip(distributed_keys) { + if !keys.is_empty() { + result.extend(connection.get()?.gets(&keys)?); + } } Ok(result) }