diff --git a/Cargo.toml b/Cargo.toml
index 0d5c14d..97a9fee 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Utility Library for Rust"
 repository = "https://github.com/tinrab/bomboni"
@@ -38,9 +38,9 @@ tokio = ["bomboni_common/tokio"]
 tonic = ["bomboni_proto/tonic", "bomboni_request/tonic"]
 
 [dependencies]
-bomboni_common = { path = "bomboni_common", version = "0.1.34" }
+bomboni_common = { path = "bomboni_common", version = "0.1.39" }
 
-bomboni_prost = { path = "bomboni_prost", version = "0.1.34", optional = true }
-bomboni_proto = { path = "bomboni_proto", version = "0.1.34", optional = true }
-bomboni_request = { path = "bomboni_request", version = "0.1.34", optional = true }
-bomboni_template = { path = "bomboni_template", version = "0.1.34", optional = true }
+bomboni_prost = { path = "bomboni_prost", version = "0.1.39", optional = true }
+bomboni_proto = { path = "bomboni_proto", version = "0.1.39", optional = true }
+bomboni_request = { path = "bomboni_request", version = "0.1.39", optional = true }
+bomboni_template = { path = "bomboni_template", version = "0.1.39", optional = true }
diff --git a/bomboni_common/Cargo.toml b/bomboni_common/Cargo.toml
index b938928..0b593e3 100644
--- a/bomboni_common/Cargo.toml
+++ b/bomboni_common/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_common"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Common things for Bomboni library."
 repository = "https://github.com/tinrab/bomboni"
diff --git a/bomboni_prost/Cargo.toml b/bomboni_prost/Cargo.toml
index 883cd1e..c4d74c4 100644
--- a/bomboni_prost/Cargo.toml
+++ b/bomboni_prost/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_prost"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Utilities for working with prost. Part of Bomboni library."
 repository = "https://github.com/tinrab/bomboni"
diff --git a/bomboni_proto/Cargo.toml b/bomboni_proto/Cargo.toml
index a65f6c0..90c02e0 100644
--- a/bomboni_proto/Cargo.toml
+++ b/bomboni_proto/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_proto"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Utilities for working with Protobuf/gRPC. Part of Bomboni library."
 repository = "https://github.com/tinrab/bomboni"
@@ -36,5 +36,5 @@ serde_json = { version = "1.0.108", optional = true }
 serde_json = "1.0.108"
 
 [build-dependencies]
-bomboni_prost = { path = "../bomboni_prost", version = "0.1.34" }
+bomboni_prost = { path = "../bomboni_prost", version = "0.1.39" }
 prost-build = "0.12.3"
diff --git a/bomboni_request/Cargo.toml b/bomboni_request/Cargo.toml
index 0716018..72160d9 100644
--- a/bomboni_request/Cargo.toml
+++ b/bomboni_request/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_request"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Utilities for working with API requests. Part of Bomboni library."
repository = "https://github.com/tinrab/bomboni" @@ -19,8 +19,8 @@ testing = [] tonic = ["bomboni_proto/tonic", "dep:tonic"] [dependencies] -bomboni_common = { path = "../bomboni_common", version = "0.1.34" } -bomboni_proto = { path = "../bomboni_proto", version = "0.1.34" } +bomboni_common = { path = "../bomboni_common", version = "0.1.39" } +bomboni_proto = { path = "../bomboni_proto", version = "0.1.39" } thiserror = "1.0.50" itertools = "0.12.0" time = { version = "0.3.30", features = ["formatting", "parsing"] } @@ -35,4 +35,8 @@ rand = "0.8.5" regex = "1.10.2" tonic = { version = "0.10.2", optional = true } -bomboni_request_derive = { path = "../bomboni_request_derive", version = "0.1.34", optional = true } +bomboni_request_derive = { path = "../bomboni_request_derive", version = "0.1.39", optional = true } + +[dev-dependencies] +serde = { version = "1.0.193", features = ["derive"] } +serde_json = "1.0.108" diff --git a/bomboni_request/src/error.rs b/bomboni_request/src/error.rs index ead9c44..7b31bd6 100644 --- a/bomboni_request/src/error.rs +++ b/bomboni_request/src/error.rs @@ -38,8 +38,6 @@ pub struct FieldError { #[derive(Error, Debug, Clone, PartialEq, Eq)] pub enum CommonError { - #[error(transparent)] - Query(#[from] QueryError), #[error("requested entity was not found")] ResourceNotFound, #[error("unauthorized")] @@ -193,7 +191,21 @@ impl RequestError { pub fn wrap_request(self, name: &str) -> Self { match self { Self::Field(error) => Self::bad_request(name, [(error.field, error.error)]), - err => err, + Self::Domain(error) => { + if let Some(error) = error.as_any().downcast_ref::() { + #[allow(trivial_casts)] + Self::bad_request( + name, + [( + error.get_violating_field_name(), + Box::new(error.clone()) as DomainErrorBox, + )], + ) + } else { + RequestError::Domain(error) + } + } + error => error, } } @@ -215,21 +227,39 @@ impl RequestError { error.error, )], ), - err => err, + Self::Domain(error) => { + if let Some(error) = error.as_any().downcast_ref::() { + #[allow(trivial_casts)] + Self::bad_request( + name, + [( + format!( + "{}.{}", + root_path.into_iter().map(|step| step.to_string()).join("."), + error.get_violating_field_name() + ), + Box::new(error.clone()) as DomainErrorBox, + )], + ) + } else { + RequestError::Domain(error) + } + } + error => error, } } pub fn downcast_domain_ref(&self) -> Option<&T> { - if let Self::Domain(err) = self { - err.as_any().downcast_ref::() + if let Self::Domain(error) = self { + error.as_any().downcast_ref::() } else { None } } pub fn downcast_domain(&self) -> Option { - if let Self::Domain(err) = self { - err.as_any().downcast_ref::().cloned() + if let Self::Domain(error) = self { + error.as_any().downcast_ref::().cloned() } else { None } @@ -369,4 +399,32 @@ mod tests { } ); } + + #[test] + fn query_error_metadata() { + assert_eq!( + serde_json::to_value(Status::from( + RequestError::from(QueryError::InvalidPageSize).wrap_request("List"), + )) + .unwrap(), + serde_json::from_str::( + r#"{ + "code": "INVALID_ARGUMENT", + "message": "invalid `List` request", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.BadRequest", + "fieldViolations": [ + { + "field": "page_size", + "description": "page size specified is invalid" + } + ] + } + ] + }"# + ) + .unwrap() + ); + } } diff --git a/bomboni_request/src/ordering/mod.rs b/bomboni_request/src/ordering/mod.rs index b7ac08e..dacb6d0 100644 --- a/bomboni_request/src/ordering/mod.rs +++ b/bomboni_request/src/ordering/mod.rs @@ -2,6 +2,7 @@ use std::{ cmp, collections::BTreeSet, 
     fmt::{self, Display, Formatter},
+    ops::{Deref, DerefMut},
 };
 
 use itertools::Itertools;
@@ -15,9 +16,7 @@ use super::schema::SchemaMapped;
 pub mod error;
 
 #[derive(Debug, Clone, PartialEq, Eq, Default)]
-pub struct Ordering {
-    pub terms: Vec<OrderingTerm>,
-}
+pub struct Ordering(Vec<OrderingTerm>);
 
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct OrderingTerm {
@@ -33,7 +32,7 @@ pub enum OrderingDirection {
 
 impl Ordering {
     pub fn new(terms: Vec<OrderingTerm>) -> Self {
-        Self { terms }
+        Self(terms)
     }
 
     pub fn parse(source: &str) -> OrderingResult<Self> {
@@ -65,14 +64,14 @@ impl Ordering {
             terms.push(OrderingTerm { name, direction });
         }
 
-        Ok(Self { terms })
+        Ok(Self(terms))
     }
 
     pub fn evaluate<T>(&self, lhs: &T, rhs: &T) -> Option<cmp::Ordering>
     where
         T: SchemaMapped,
     {
-        for term in &self.terms {
+        for term in self.iter() {
             let a = lhs.get_field(&term.name);
             let b = rhs.get_field(&term.name);
             match a.partial_cmp(&b)? {
@@ -95,7 +94,7 @@ impl Ordering {
     }
 
     pub fn is_valid(&self, schema: &Schema) -> bool {
-        for term in &self.terms {
+        for term in self.iter() {
             if let Some(field) = schema.get_field(&term.name) {
                 if !field.ordered {
                     return false;
@@ -108,9 +107,23 @@ impl Ordering {
     }
 }
 
+impl Deref for Ordering {
+    type Target = Vec<OrderingTerm>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl DerefMut for Ordering {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
 impl Display for Ordering {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        f.write_str(&self.terms.iter().map(ToString::to_string).join(", "))
+        f.write_str(&self.iter().map(ToString::to_string).join(", "))
     }
 }
@@ -141,18 +154,16 @@ mod tests {
     let ordering = Ordering::parse(" , user.displayName, task .userId desc").unwrap();
     assert_eq!(
         ordering,
-        Ordering {
-            terms: vec![
-                OrderingTerm {
-                    name: "user.displayName".into(),
-                    direction: Ascending,
-                },
-                OrderingTerm {
-                    name: "task.userId".into(),
-                    direction: Descending,
-                },
-            ]
-        }
+        Ordering(vec![
+            OrderingTerm {
+                name: "user.displayName".into(),
+                direction: Ascending,
+            },
+            OrderingTerm {
+                name: "task.userId".into(),
+                direction: Descending,
+            },
+        ])
     );
 }
diff --git a/bomboni_request/src/parse/helpers.rs b/bomboni_request/src/parse/helpers.rs
index d731c9f..4e00408 100644
--- a/bomboni_request/src/parse/helpers.rs
+++ b/bomboni_request/src/parse/helpers.rs
@@ -14,7 +14,10 @@ pub mod parse_id {
 
 #[cfg(test)]
 mod tests {
-    use crate::{error::RequestError, parse::RequestParse};
+    use crate::{
+        error::RequestError,
+        parse::{RequestParse, RequestResult},
+    };
     use bomboni_common::id::Id;
     use bomboni_request_derive::Parse;
diff --git a/bomboni_request/src/parse/mod.rs b/bomboni_request/src/parse/mod.rs
index 5cb3684..34d2b88 100644
--- a/bomboni_request/src/parse/mod.rs
+++ b/bomboni_request/src/parse/mod.rs
@@ -1,26 +1,22 @@
+use crate::error::RequestResult;
 use bomboni_common::id::Id;
 use time::OffsetDateTime;
+
 pub mod helpers;
 
 pub trait RequestParse<T>: Sized {
-    type Error;
-
-    fn parse(value: T) -> Result<Self, Self::Error>;
+    fn parse(value: T) -> RequestResult<Self>;
 }
 
 pub trait RequestParseInto<T>: Sized {
-    type Error;
-
-    fn parse_into(self) -> Result<T, Self::Error>;
+    fn parse_into(self) -> RequestResult<T>;
 }
 
-impl<T, U> RequestParseInto<U> for T
+impl<T, U> RequestParseInto<U> for T
 where
-    U: RequestParse<T>,
+    U: RequestParse<T>,
 {
-    type Error = U::Error;
-
-    fn parse_into(self) -> Result<U, Self::Error> {
+    fn parse_into(self) -> RequestResult<U> {
         U::parse(self)
     }
 }
@@ -41,14 +37,25 @@ pub struct ParsedResource {
 mod tests {
     use std::collections::{BTreeMap, HashMap};
 
+    use crate::ordering::Ordering;
+    use crate::query::page_token::PageTokenBuilder;
+    use crate::query::search::{SearchQuery, SearchQueryBuilder, SearchQueryConfig};
+    use crate::{
+        error::{CommonError, FieldError, RequestError, RequestResult},
+        filter::Filter,
+        ordering::{OrderingDirection, OrderingTerm},
+        query::{
+            list::{ListQuery, ListQueryBuilder, ListQueryConfig},
+            page_token::{plain::PlainPageTokenBuilder, FilterPageToken},
+        },
+        testing::schema::UserItem,
+    };
     use bomboni_common::{btree_map, btree_map_into, hash_map_into};
     use bomboni_proto::google::protobuf::{
         Int32Value, Int64Value, StringValue, Timestamp, UInt32Value,
     };
     use bomboni_request_derive::{impl_parse_into_map, parse_resource_name, Parse};
 
-    use crate::error::{CommonError, FieldError, RequestError, RequestResult};
-
     use super::*;
 
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -1253,4 +1260,547 @@ mod tests {
             }
         );
     }
+
+    #[test]
+    fn parse_optional_vec() {
+        use values_parse::ItemValues;
+
+        #[derive(Debug, Clone, PartialEq, Default)]
+        struct Item {
+            values: Option<ItemValues>,
+        }
+
+        #[derive(Debug, Clone, PartialEq, Default, Parse)]
+        #[parse(source = Item, write)]
+        struct ParsedItem {
+            #[parse(with = values_parse)]
+            values: Option<Vec<i32>>,
+        }
+
+        mod values_parse {
+            use super::*;
+
+            #[derive(Debug, Clone, PartialEq, Default)]
+            pub struct ItemValues {
+                pub values: Vec<i32>,
+            }
+
+            #[allow(clippy::unnecessary_wraps)]
+            pub fn parse(values: ItemValues) -> RequestResult<Vec<i32>> {
+                Ok(values.values)
+            }
+
+            pub fn write(values: Vec<i32>) -> ItemValues {
+                ItemValues { values }
+            }
+        }
+
+        assert_eq!(
+            ParsedItem::parse(Item {
+                values: Some(ItemValues {
+                    values: vec![1, 2, 3],
+                }),
+            })
+            .unwrap(),
+            ParsedItem {
+                values: Some(vec![1, 2, 3]),
+            }
+        );
+        assert_eq!(
+            Item::from(ParsedItem {
+                values: Some(vec![1, 2, 3]),
+            }),
+            Item {
+                values: Some(ItemValues {
+                    values: vec![1, 2, 3],
+                }),
+            }
+        );
+    }
+
+    #[test]
+    fn parse_generics() {
+        #[derive(Debug, Clone, PartialEq, Default)]
+        struct Item {
+            value: i32,
+        }
+
+        #[derive(Debug, Clone, PartialEq, Default, Parse)]
+        #[parse(source = Item, write)]
+        struct ParsedItem<T, S>
+        where
+            T: Clone + Default + RequestParse<i32> + Into<i32>,
+            S: ToString + Default,
+        {
+            value: T,
+            #[parse(skip)]
+            skipped: S,
+        }
+
+        #[derive(Debug, Clone, PartialEq)]
+        struct Union {
+            kind: Option<UnionKind>,
+        }
+
+        #[derive(Debug, Clone, PartialEq)]
+        enum UnionKind {
+            String(String),
+            Generic(i32),
+        }
+
+        impl UnionKind {
+            pub fn get_variant_name(&self) -> &'static str {
+                match self {
+                    Self::String(_) => "string",
+                    Self::Generic(_) => "generic",
+                }
+            }
+        }
+
+        #[derive(Debug, Clone, PartialEq, Parse)]
+        #[parse(source = UnionKind, write)]
+        enum ParsedUnionKind<T>
+        where
+            T: Clone + Default + RequestParse<i32> + Into<i32>,
+        {
+            String(String),
+            Generic(T),
+        }
+
+        #[derive(Debug, Clone, PartialEq, Parse)]
+        #[parse(source = Union, tagged_union { oneof = UnionKind, field = kind }, write)]
+        enum ParsedTaggedUnionKind<T>
+        where
+            T: Clone + Default + RequestParse<i32> + Into<i32>,
+        {
+            String(String),
+            Generic(T),
+        }
+
+        impl RequestParse<i32> for i32 {
+            fn parse(value: i32) -> RequestResult<Self> {
+                Ok(value)
+            }
+        }
+
+        assert_eq!(
+            ParsedItem::<i32, String>::parse(Item { value: 42 }).unwrap(),
+            ParsedItem::<i32, String> {
+                value: 42,
+                skipped: String::new(),
+            }
+        );
+        assert_eq!(
+            Item::from(ParsedItem::<i32, String> {
+                value: 42,
+                skipped: String::new(),
+            }),
+            Item { value: 42 }
+        );
+
+        assert_eq!(
+            ParsedUnionKind::<i32>::parse(UnionKind::Generic(42)).unwrap(),
+            ParsedUnionKind::<i32>::Generic(42)
+        );
+        assert_eq!(
+            UnionKind::from(ParsedUnionKind::<i32>::Generic(42)),
+            UnionKind::Generic(42)
+        );
+
+        assert_eq!(
+            ParsedTaggedUnionKind::<i32>::parse(Union {
+                kind: Some(UnionKind::Generic(42)),
+            })
+            .unwrap(),
+            ParsedTaggedUnionKind::<i32>::Generic(42)
+        );
+        assert_eq!(
+            Union::from(ParsedTaggedUnionKind::<i32>::Generic(42)),
+            Union {
+                kind: Some(UnionKind::Generic(42)),
+            }
+        );
+    }
+
+    #[test]
+    fn parse_query() {
+        #[derive(Debug, PartialEq, Default, Clone)]
+        struct Item {
+            query: String,
+            page_size: Option<u32>,
+            page_token: Option<String>,
+            filter: Option<String>,
+            order_by: Option<String>,
+            order: Option<String>,
+        }
+
+        #[derive(Parse, Debug, PartialEq)]
+        #[parse(source = Item, list_query { field = list_query }, write)]
+        struct ParsedListQuery {
+            list_query: ListQuery,
+        }
+
+        #[derive(Parse, Debug, PartialEq)]
+        #[parse(source = Item, list_query { filter = false }, write)]
+        struct ParsedNoFilter {
+            query: ListQuery,
+        }
+
+        #[derive(Parse, Debug, PartialEq)]
+        #[parse(source = Item, list_query, write)]
+        struct ParsedCustomToken {
+            query: ListQuery<u64>,
+        }
+
+        #[derive(Parse, Debug, PartialEq)]
+        #[parse(source = Item, search_query { field = search_query }, write)]
+        struct ParsedSearchQuery {
+            search_query: SearchQuery,
+        }
+
+        fn get_list_query_builder() -> &'static ListQueryBuilder<PlainPageTokenBuilder> {
+            use std::sync::OnceLock;
+            static SINGLETON: OnceLock<ListQueryBuilder<PlainPageTokenBuilder>> = OnceLock::new();
+            SINGLETON.get_or_init(|| {
+                ListQueryBuilder::<PlainPageTokenBuilder>::new(
+                    UserItem::get_schema(),
+                    ListQueryConfig {
+                        max_page_size: Some(20),
+                        default_page_size: 10,
+                        primary_ordering_term: Some(OrderingTerm {
+                            name: "id".into(),
+                            direction: OrderingDirection::Descending,
+                        }),
+                        max_filter_length: Some(50),
+                        max_ordering_length: Some(50),
+                    },
+                    PlainPageTokenBuilder {},
+                )
+            })
+        }
+
+        fn get_search_query_builder() -> &'static SearchQueryBuilder<PlainPageTokenBuilder> {
+            use std::sync::OnceLock;
+            static SINGLETON: OnceLock<SearchQueryBuilder<PlainPageTokenBuilder>> = OnceLock::new();
+            SINGLETON.get_or_init(|| {
+                SearchQueryBuilder::<PlainPageTokenBuilder>::new(
+                    UserItem::get_schema(),
+                    SearchQueryConfig {
+                        max_query_length: Some(50),
+                        max_page_size: Some(20),
+                        default_page_size: 10,
+                        primary_ordering_term: Some(OrderingTerm {
+                            name: "id".into(),
+                            direction: OrderingDirection::Descending,
+                        }),
+                        max_filter_length: Some(50),
+                        max_ordering_length: Some(50),
+                    },
+                    PlainPageTokenBuilder {},
+                )
+            })
+        }
+
+        struct CustomPageTokenBuilder {}
+
+        impl PageTokenBuilder for CustomPageTokenBuilder {
+            type PageToken = u64;
+
+            fn parse(
+                &self,
+                _filter: &Filter,
+                _ordering: &Ordering,
+                page_token: &str,
+            ) -> crate::query::error::QueryResult<Self::PageToken> {
+                Ok(page_token.parse().unwrap())
+            }
+
+            fn build_next<T: crate::schema::SchemaMapped>(
+                &self,
+                _filter: &Filter,
+                _ordering: &Ordering,
+                _next_item: &T,
+            ) -> crate::query::error::QueryResult<String> {
+                Ok("24".into())
+            }
+        }
+
+        let item = Item {
+            query: "hello".into(),
+            page_size: Some(42),
+            page_token: Some("true".into()),
+            filter: Some("true".into()),
+            order_by: Some("id".into()),
+            order: Some("id desc".into()),
+        };
+
+        assert_eq!(
+            ParsedListQuery::parse_list_query(item.clone(), get_list_query_builder()).unwrap(),
+            ParsedListQuery {
+                list_query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            },
+        );
+        assert_eq!(
+            Item::from(ParsedListQuery {
+                list_query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            }),
+            Item {
+                query: String::new(),
+                page_size: Some(20),
+                page_token: Some("true".into()),
+                filter: Some("true".into()),
+                order_by: Some("id asc".into()),
+                order: None,
+            },
+        );
+
+        assert_eq!(
+            ParsedNoFilter::parse_list_query(item.clone(), get_list_query_builder()).unwrap(),
+            ParsedNoFilter {
+                query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::default(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            },
+        );
+        assert_eq!(
+            Item::from(ParsedNoFilter {
+                query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::default(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            }),
+            Item {
+                query: String::new(),
+                page_size: Some(20),
+                page_token: Some("true".into()),
+                filter: None,
+                order_by: Some("id asc".into()),
+                order: None,
+            },
+        );
+
+        assert_eq!(
+            ParsedCustomToken::parse_list_query(
+                Item {
+                    page_token: Some("42".into()),
+                    ..item.clone()
+                },
+                &ListQueryBuilder::<CustomPageTokenBuilder>::new(
+                    UserItem::get_schema(),
+                    ListQueryConfig {
+                        max_page_size: Some(20),
+                        default_page_size: 10,
+                        primary_ordering_term: Some(OrderingTerm {
+                            name: "id".into(),
+                            direction: OrderingDirection::Descending,
+                        }),
+                        max_filter_length: Some(50),
+                        max_ordering_length: Some(50),
+                    },
+                    CustomPageTokenBuilder {},
+                )
+            )
+            .unwrap(),
+            ParsedCustomToken {
+                query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(42),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            },
+        );
+        assert_eq!(
+            Item::from(ParsedCustomToken {
+                query: ListQuery {
+                    page_size: 20,
+                    page_token: Some(42),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            }),
+            Item {
+                query: String::new(),
+                page_size: Some(20),
+                page_token: Some("42".into()),
+                filter: Some("true".into()),
+                order_by: Some("id asc".into()),
+                order: None,
+            },
+        );
+
+        assert_eq!(
+            ParsedSearchQuery::parse_search_query(item.clone(), get_search_query_builder())
+                .unwrap(),
+            ParsedSearchQuery {
+                search_query: SearchQuery {
+                    query: "hello".into(),
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            },
+        );
+        assert_eq!(
+            Item::from(ParsedSearchQuery {
+                search_query: SearchQuery {
+                    query: "hello".into(),
+                    page_size: 20,
+                    page_token: Some(FilterPageToken {
+                        filter: Filter::parse("true").unwrap(),
+                    }),
+                    filter: Filter::parse("true").unwrap(),
+                    ordering: Ordering::new(vec![OrderingTerm {
+                        name: "id".into(),
+                        direction: OrderingDirection::Ascending,
+                    }])
+                },
+            }),
+            Item {
+                query: "hello".into(),
+                page_size: Some(20),
+                page_token: Some("true".into()),
+                filter: Some("true".into()),
+                order_by: Some("id asc".into()),
+                order: None,
+            },
+        );
+    }
+
+    #[test]
+    fn parse_source_nested() {
+        #[derive(Debug, Clone, PartialEq, Default)]
+        struct Item {
+            name: String,
+            item: Option<NestedItem>,
+        }
+
+        #[derive(Debug, Clone, PartialEq, Default)]
+        struct NestedItem {
+            nested_value: Option<NestedValue>,
+        }
+
+        #[derive(Debug, Clone, PartialEq, Default)]
+        struct NestedValue {
+            value: i32,
+            default_value: Option<i32>,
+        }
+
+        #[derive(Debug, Clone, PartialEq, Default, Parse)]
+        #[parse(source = Item, write)]
+        struct ParsedItem {
+            #[parse(keep)]
+            name: String,
+            #[parse(source_name = "item.nested_value.value")]
+            value: i32,
+            #[parse(source_option, source_name = "item.nested_value.default_value")]
+            default_value: i32,
+        }
+
+        assert_eq!(
+            ParsedItem::parse(Item {
+                name: String::new(),
+                item: Some(NestedItem {
+                    nested_value: Some(NestedValue {
+                        value: 42,
+                        default_value: Some(42)
+                    }),
+                }),
+            })
+            .unwrap(),
+            ParsedItem {
+                name: String::new(),
+                value: 42,
+                default_value: 42
+            }
+        );
+        assert_eq!(
+            Item::from(ParsedItem {
+                name: String::new(),
+                value: 42,
+                default_value: 42
+            }),
+            Item {
+                name: String::new(),
+                item: Some(NestedItem {
+                    nested_value: Some(NestedValue {
+                        value: 42,
+                        default_value: Some(42)
+                    }),
+                }),
+            }
+        );
+
+        assert!(matches!(
+            ParsedItem::parse(Item {
+                name: String::new(),
+                item: None,
+            }).unwrap_err(),
+            RequestError::Field(FieldError {
+                error, field, ..
+            }) if matches!(
+                error.as_any().downcast_ref::<CommonError>().unwrap(),
+                CommonError::RequiredFieldMissing { .. }
+            ) && field == "item"
+        ));
+        assert!(matches!(
+            ParsedItem::parse(Item {
+                name: String::new(),
+                item: Some(NestedItem{ nested_value: None }),
+            }).unwrap_err(),
+            RequestError::Field(FieldError {
+                error, field, ..
+            }) if matches!(
+                error.as_any().downcast_ref::<CommonError>().unwrap(),
+                CommonError::RequiredFieldMissing { .. }
+            ) && field == "item.nested_value"
+        ));
+    }
 }
diff --git a/bomboni_request/src/query/error.rs b/bomboni_request/src/query/error.rs
index 90f7310..6947d80 100644
--- a/bomboni_request/src/query/error.rs
+++ b/bomboni_request/src/query/error.rs
@@ -1,6 +1,7 @@
+use bomboni_proto::google::rpc::Code;
 use thiserror::Error;
 
-use crate::{filter::error::FilterError, ordering::error::OrderingError};
+use crate::{error::DomainError, filter::error::FilterError, ordering::error::OrderingError};
 
 #[derive(Error, Debug, Clone, PartialEq, Eq)]
 pub enum QueryError {
@@ -28,6 +29,20 @@ pub enum QueryError {
 
 pub type QueryResult<T> = Result<T, QueryError>;
 
+impl QueryError {
+    pub fn get_violating_field_name(&self) -> &'static str {
+        match self {
+            Self::FilterError(_) | Self::FilterTooLong | Self::FilterSchemaMismatch => "filter",
+            Self::OrderingError(_) | Self::OrderingTooLong | Self::OrderingSchemaMismatch => {
+                "order_by"
+            }
+            Self::QueryTooLong => "query",
+            Self::InvalidPageToken | Self::PageTokenFailure => "page_token",
+            Self::InvalidPageSize => "page_size",
+        }
+    }
+}
+
 impl From<FilterError> for QueryError {
     fn from(err: FilterError) -> Self {
         Self::FilterError(err)
@@ -39,3 +54,13 @@ impl From<OrderingError> for QueryError {
         Self::OrderingError(err)
     }
 }
+
+impl DomainError for QueryError {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn code(&self) -> Code {
+        Code::InvalidArgument
+    }
+}
diff --git a/bomboni_request/src/query/list.rs b/bomboni_request/src/query/list.rs
index 70d31a7..565ca71 100644
--- a/bomboni_request/src/query/list.rs
+++ b/bomboni_request/src/query/list.rs
@@ -18,12 +18,12 @@ use super::{
 
 /// Represents a list query.
 /// List queries list paged, filtered and ordered items.
-#[derive(Debug, Clone)]
-pub struct ListQuery<P = FilterPageToken> {
-    pub filter: Filter,
-    pub ordering: Ordering,
+#[derive(Debug, Clone, PartialEq)]
+pub struct ListQuery<P = FilterPageToken> {
     pub page_size: i32,
     pub page_token: Option<P>,
+    pub filter: Filter,
+    pub ordering: Ordering,
 }
 
 /// Config for list query builder.
@@ -83,11 +83,10 @@ impl<P: PageTokenBuilder> ListQueryBuilder<P> {
 
         // This is needed for page tokens to work.
         if let Some(primary_ordering_term) = self.options.primary_ordering_term.as_ref() {
             if ordering
-                .terms
                 .iter()
                 .all(|term| term.name != primary_ordering_term.name)
             {
-                ordering.terms.insert(0, primary_ordering_term.clone());
+                ordering.insert(0, primary_ordering_term.clone());
             }
         }
@@ -114,10 +113,10 @@ impl<P: PageTokenBuilder> ListQueryBuilder<P> {
         };
 
         Ok(ListQuery {
-            filter,
-            ordering,
             page_size,
             page_token,
+            filter,
+            ordering,
         })
     }
 }
diff --git a/bomboni_request/src/query/page_token/mod.rs b/bomboni_request/src/query/page_token/mod.rs
index ee6bcc6..a681564 100644
--- a/bomboni_request/src/query/page_token/mod.rs
+++ b/bomboni_request/src/query/page_token/mod.rs
@@ -5,6 +5,8 @@
 //! To ensure that a valid token is used, we can encrypt it along with the query parameters and decrypt it before use.
 //! Encryption is also desirable to prevent users from guessing the next page of results, or to hide sensitive information.
 
+use std::fmt::{self, Display, Formatter};
+
 use crate::{filter::Filter, ordering::Ordering, schema::SchemaMapped};
 
 use super::error::QueryResult;
@@ -15,13 +17,13 @@ pub mod rsa;
 mod utility;
 
 /// A page token containing a filter.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct FilterPageToken {
     pub filter: Filter,
 }
 
 pub trait PageTokenBuilder {
-    type PageToken: Clone;
+    type PageToken: Clone + ToString;
 
     /// Parse a page token.
     /// [`QueryError::InvalidPageToken`] is returned if the page token is invalid for any reason.
@@ -44,3 +46,9 @@ pub trait PageTokenBuilder {
         next_item: &T,
     ) -> QueryResult<String>;
 }
+
+impl Display for FilterPageToken {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.filter)
+    }
+}
diff --git a/bomboni_request/src/query/page_token/utility.rs b/bomboni_request/src/query/page_token/utility.rs
index c411297..ca418c8 100644
--- a/bomboni_request/src/query/page_token/utility.rs
+++ b/bomboni_request/src/query/page_token/utility.rs
@@ -14,7 +14,7 @@ use crate::ordering::Ordering;
 pub fn get_page_filter<T: SchemaMapped>(ordering: &Ordering, next_item: &T) -> Filter {
     let mut filters = Vec::new();
 
-    for term in &ordering.terms {
+    for term in ordering.iter() {
         let term_argument = match next_item.get_field(&term.name) {
             Value::Integer(value) => Filter::Value(value.into()),
             Value::Float(value) => Filter::Value(value.into()),
diff --git a/bomboni_request/src/query/search.rs b/bomboni_request/src/query/search.rs
index bcf8bcf..8326d2a 100644
--- a/bomboni_request/src/query/search.rs
+++ b/bomboni_request/src/query/search.rs
@@ -14,13 +14,13 @@ use super::{
     utility::{parse_query_filter, parse_query_ordering},
 };
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct SearchQuery {
     pub query: String,
-    pub filter: Filter,
-    pub ordering: Ordering,
     pub page_size: i32,
     pub page_token: Option<FilterPageToken>,
+    pub filter: Filter,
+    pub ordering: Ordering,
 }
 
 /// Config for search query builder.
@@ -28,10 +28,10 @@ pub struct SearchQuery {
 /// `primary_ordering_term` should probably never be `None`.
 #[derive(Debug, Clone)]
 pub struct SearchQueryConfig {
+    pub max_query_length: Option<usize>,
     pub max_page_size: Option<i32>,
     pub default_page_size: i32,
     pub primary_ordering_term: Option<OrderingTerm>,
-    pub max_query_length: Option<usize>,
     pub max_filter_length: Option<usize>,
     pub max_ordering_length: Option<usize>,
 }
@@ -45,10 +45,10 @@ pub struct SearchQueryBuilder<P: PageTokenBuilder> {
 
 impl Default for SearchQueryConfig {
     fn default() -> Self {
         Self {
+            max_query_length: None,
             max_page_size: None,
             default_page_size: 20,
             primary_ordering_term: None,
-            max_query_length: None,
             max_filter_length: None,
             max_ordering_length: None,
         }
@@ -84,11 +84,10 @@ impl<P: PageTokenBuilder> SearchQueryBuilder<P> {
         // This is needed for page tokens to work.
         if let Some(primary_ordering_term) = self.options.primary_ordering_term.as_ref() {
             if ordering
-                .terms
                 .iter()
                 .all(|term| term.name != primary_ordering_term.name)
             {
-                ordering.terms.insert(0, primary_ordering_term.clone());
+                ordering.insert(0, primary_ordering_term.clone());
             }
         }
diff --git a/bomboni_request_derive/Cargo.toml b/bomboni_request_derive/Cargo.toml
index d37ff57..d77bc9b 100644
--- a/bomboni_request_derive/Cargo.toml
+++ b/bomboni_request_derive/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_request_derive"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Provides derive implementations for Bomboni library."
 repository = "https://github.com/tinrab/bomboni"
diff --git a/bomboni_request_derive/src/parse/message/parse.rs b/bomboni_request_derive/src/parse/message/parse.rs
index 3f023ca..06ebbc7 100644
--- a/bomboni_request_derive/src/parse/message/parse.rs
+++ b/bomboni_request_derive/src/parse/message/parse.rs
@@ -3,12 +3,16 @@ use itertools::Itertools;
 use proc_macro2::{Ident, Literal, TokenStream};
 use quote::{quote, ToTokens};
 
-use crate::parse::{DeriveOptions, ParseField, ParseOptions};
-use crate::utility::{get_proto_type_info, ProtoTypeInfo};
+use crate::parse::{DeriveOptions, ParseField, ParseOptions, QueryOptions};
+use crate::utility::{get_proto_type_info, get_query_field_token_type, ProtoTypeInfo};
 
 pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> syn::Result<TokenStream> {
-    let source = &options.source;
-    let ident = &options.ident;
+    if options.list_query.is_some() && options.search_query.is_some() {
+        return Err(syn::Error::new_spanned(
+            &options.ident,
+            "list and search query cannot be used together",
+        ));
+    }
 
     let mut parse_fields = quote!();
     // Set default for skipped fields
@@ -18,9 +22,10 @@ pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> syn::Result
@@ -44,12 +49,16 @@ pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> syn::Result
-    Ok(quote! {
-        impl RequestParse<#source> for #ident {
-            type Error = RequestError;
+    let source = &options.source;
+    let ident = &options.ident;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
+
+    if let Some(query_options) = options
+        .list_query
+        .as_ref()
+        .or(options.search_query.as_ref())
+    {
+        let query_field_ident = &query_options.field;
+        let parse_query = expand_parse_query(query_options, options.search_query.is_some());
+
+        let query_field = fields
+            .iter()
+            .find(|field| field.ident.as_ref().unwrap() == query_field_ident)
+            .unwrap();
+        let query_token_type = if let Some(token_type) = get_query_field_token_type(&query_field.ty)
+        {
+            quote! {
+                <PageToken = #token_type>
+            }
+        } else {
+            quote! {
+                <PageToken = FilterPageToken>
+            }
+        };
+
+        return Ok(if options.search_query.is_some() {
+            quote! {
+                impl #ident #type_params #where_clause {
+                    #[allow(clippy::ignored_unit_patterns)]
+                    fn parse_search_query<B: PageTokenBuilder #query_token_type>(
+                        source: #source,
+                        query_builder: &SearchQueryBuilder<B>
+                    ) -> Result<Self, RequestError> {
+                        Ok(Self {
+                            #query_field_ident: {
+                                #parse_query
+                                query
+                            },
+                            #parse_fields
+                            #skipped_fields
+                        })
+                    }
+                }
+            }
+        } else {
+            quote! {
+                impl #ident #type_params #where_clause {
+                    #[allow(clippy::ignored_unit_patterns)]
+                    fn parse_list_query<B: PageTokenBuilder #query_token_type>(
+                        source: #source,
+                        query_builder: &ListQueryBuilder<B>
+                    ) -> Result<Self, RequestError> {
+                        Ok(Self {
+                            #query_field_ident: {
+                                #parse_query
+                                query
+                            },
+                            #parse_fields
+                            #skipped_fields
+                        })
+                    }
+                }
+            }
+        });
+    }
+
     Ok(quote! {
+        impl #type_params RequestParse<#source> for #ident #type_params #where_clause {
             #[allow(clippy::ignored_unit_patterns)]
-            fn parse(source: #source) -> Result<Self, Self::Error> {
+            fn parse(source: #source) -> RequestResult<Self> {
                 Ok(Self {
                     #parse_fields
                     #skipped_fields
@@ -59,7 +165,7 @@ pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> syn::Result
     })
 }
 
-fn expand_parse_field(field: &ParseField) -> syn::Result<TokenStream> {
+fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result<TokenStream> {
     let target_ident = field.ident.as_ref().unwrap();
 
     if let Some(DeriveOptions { func, source_field }) = field.derive.as_ref() {
@@ -82,22 +188,6 @@ fn expand_parse_field(field: &ParseField) -> syn::Result
         });
     }
 
-    let field_name = if let Some(name) = field.name.as_ref() {
-        quote! { #name }
-    } else {
-        field
-            .ident
-            .as_ref()
-            .unwrap()
-            .to_string()
-            .into_token_stream()
-    };
-    let source_ident = if let Some(name) = field.source_name.as_ref() {
-        Ident::from_string(name).unwrap()
-    } else {
-        field.ident.clone().unwrap()
-    };
-
     let field_type = &field.ty;
     let ProtoTypeInfo {
         is_option,
         is_nested,
         is_string,
         is_box,
         is_vec,
+        is_generic,
         map_ident,
         ..
-    } = get_proto_type_info(field_type);
+    } = get_proto_type_info(options, field_type);
 
     if (field.with.is_some() || field.parse_with.is_some())
         && (field.enumeration || field.oneof || field.regex.is_some())
@@ -153,6 +244,18 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
         ));
     }
 
+    let field_name = if let Some(name) = field.name.as_ref() {
+        quote! { #name }
+    } else {
+        field
+            .ident
+            .as_ref()
+            .unwrap()
+            .to_string()
+            .into_token_stream()
+    };
+    let custom_parse = field.with.is_some() || field.parse_with.is_some();
+
     let mut parse_source = if field.keep {
         if is_box || field.source_box {
             quote! {
                 let target = *target;
             }
         } else {
             quote!()
         }
-    } else if field.with.is_some() || field.parse_with.is_some() {
+    } else if custom_parse {
         let parse_with = if let Some(with) = field.with.as_ref() {
             quote! {
                 #with::parse
@@ -197,7 +300,7 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
                 let target = target.try_into()
                     .map_err(|_| RequestError::field_index(#field_name, i, CommonError::InvalidEnumValue))?;
             });
-        } else if is_nested {
+        } else if is_nested || is_generic {
            parse_item.extend(quote! {
                 let target = target.parse_into()
                     .map_err(|err: RequestError| err.wrap_index(#field_name, i))?;
@@ -235,7 +338,7 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
                 let target = target.try_into()
                     .map_err(|_| RequestError::field(#field_name, CommonError::InvalidEnumValue))?;
             });
-        } else if is_nested {
+        } else if is_nested || is_generic {
             parse_item.extend(quote! {
                 let target = target.parse_into()
                     .map_err(|err: RequestError| err.wrap(#field_name))?;
@@ -274,7 +377,7 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
             #parse_source
             let target = target.parse_into()?;
         }
-    } else if is_nested {
+    } else if is_nested || is_generic {
         let parse_source = if is_box || field.source_box {
             quote! {
                 let target = *target;
@@ -326,9 +429,8 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
         });
     }
 
-    let mut parse = quote! {
-        let target = source.#source_ident;
-    };
+    let mut parse = expand_extract_source_field(field);
+
     if field.wrapper {
         match field_type.to_token_stream().to_string().as_str() {
             "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize" => {
@@ -360,6 +462,7 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
     let source_option = field.source_option
         || is_option
         || (is_nested
+            && !is_generic
            && (field.with.is_none() && field.parse_with.is_none())
             && !is_vec
             && map_ident.is_none()
@@ -367,36 +470,36 @@ fn expand_parse_field(options: &ParseOptions, field: &ParseField) -> syn::Result
 
     if is_option {
         if source_option {
-            if is_vec || is_string {
-                parse.extend(quote! {
+            parse.extend(if (is_vec || is_string) && !custom_parse {
+                quote! {
                     let target = if let Some(target) = target.filter(|target| !target.is_empty()) {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
-            } else if field.enumeration {
-                parse.extend(quote! {
+                }
+            } else if field.enumeration && !custom_parse {
+                quote! {
                     let target = if let Some(target) = target.filter(|e| *e != 0) {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
+                }
             } else {
-                parse.extend(quote! {
+                quote! {
                     let target = if let Some(target) = target {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
-            }
+                }
+            });
         } else {
-            parse.extend(if is_vec || is_string {
+            parse.extend(if (is_vec || is_string) && !custom_parse {
                 quote! {
                     let target = if target.is_empty() {
                         None
                     } else {
                         #parse_source
                         Some(target)
                     };
                 }
-            } else if !is_vec && field.enumeration {
+            } else if (!is_vec && field.enumeration) && !custom_parse {
                 quote! {
                     let target = if target == 0 {
                         None
@@ -467,6 +570,10 @@ fn expand_parse_resource_field(field: &ParseField) -> syn::Result
         || field.oneof
         || field.regex.is_some()
         || field.source_try_from.is_some()
+        || field.with.is_some()
+        || field.parse_with.is_some()
+        || field.write_with.is_some()
+        || field.keep
     {
         return Err(syn::Error::new_spanned(
             &field.ident,
@@ -549,3 +656,99 @@ fn expand_parse_resource_field(field: &ParseField) -> syn::Result
         },
     })
 }
+
+fn expand_parse_query(options: &QueryOptions, is_search: bool) -> TokenStream {
+    let mut parse = quote! {
+        let page_size: Option<i32> = None;
+        let page_token: Option<&str> = None;
+        let filter: Option<&str> = None;
+        let ordering: Option<&str> = None;
+    };
+    if options.query.parse && is_search {
+        let query_source_name = &options.query.source_name;
+        parse.extend(quote! {
+            let query_string = &source.#query_source_name;
+        });
+    }
+    if options.page_size.parse {
+        let source_name = &options.page_size.source_name;
+        parse.extend(quote! {
+            let page_size = source.#source_name.map(|i| i as i32);
+        });
+    }
+    if options.page_token.parse {
+        let source_name = &options.page_token.source_name;
+        parse.extend(quote! {
+            let page_token = source.#source_name.as_ref().map(|s| s.as_str());
+        });
+    }
+    if options.filter.parse {
+        let source_name = &options.filter.source_name;
+        parse.extend(quote! {
+            let filter = source.#source_name.as_ref().map(|s| s.as_str());
+        });
+    }
+    if options.ordering.parse {
+        let source_name = &options.ordering.source_name;
+        parse.extend(quote! {
+            let ordering = source.#source_name.as_ref().map(|s| s.as_str());
+        });
+    }
+
+    if is_search {
+        quote! {
+            #parse
+            let query = query_builder.build(query_string, page_size, page_token, filter, ordering)?;
+        }
+    } else {
+        quote! {
+            #parse
+            let query = query_builder.build(page_size, page_token, filter, ordering)?;
+        }
+    }
+}
+
+fn expand_extract_source_field(field: &ParseField) -> TokenStream {
+    if let Some(source_name) = field.source_name.as_ref() {
+        if source_name.contains('.') {
+            let parts = source_name.split('.').collect::<Vec<_>>();
+
+            let mut extract = quote!();
+            for (i, part) in parts.iter().enumerate() {
+                let part_ident = Ident::from_string(part).unwrap();
+                let part_literal = Literal::string(&parts.iter().take(i + 1).join("."));
+
+                extract.extend(if i < parts.len() - 1 {
+                    quote! {
+                        .#part_ident.ok_or_else(|| {
+                            RequestError::field(
+                                #part_literal,
+                                CommonError::RequiredFieldMissing,
+                            )
+                        })?
+                    }
+                } else {
+                    quote! {
+                        .#part_ident
+                    }
+                });
+            }
+
+            // Purposefully clone source on each parse.
+            // Could be optimized in the future.
+            quote! {
+                let target = source.clone() #extract;
+            }
+        } else {
+            let source_ident = Ident::from_string(source_name).unwrap();
+            quote! {
+                let target = source.#source_ident;
+            }
+        }
+    } else {
+        let source_ident = field.ident.clone().unwrap();
+        quote! {
+            let target = source.#source_ident;
+        }
+    }
+}
diff --git a/bomboni_request_derive/src/parse/message/write.rs b/bomboni_request_derive/src/parse/message/write.rs
index 048229b..0d6b6c0 100644
--- a/bomboni_request_derive/src/parse/message/write.rs
+++ b/bomboni_request_derive/src/parse/message/write.rs
@@ -2,13 +2,10 @@ use darling::FromMeta;
 use proc_macro2::{Ident, TokenStream};
 use quote::{quote, ToTokens};
 
-use crate::parse::{ParseField, ParseOptions};
+use crate::parse::{ParseField, ParseOptions, QueryOptions};
 use crate::utility::{get_proto_type_info, ProtoTypeInfo};
 
 pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> TokenStream {
-    let source = &options.source;
-    let ident = &options.ident;
-
     let mut write_fields = quote!();
 
     for field in fields {
@@ -16,43 +13,72 @@ pub fn expand(options: &ParseOptions, fields: &[ParseField]) -> TokenStream {
             continue;
         }
 
+        // Skip query fields
+        if let Some(list_query) = options.list_query.as_ref() {
+            if &list_query.field == field.ident.as_ref().unwrap() {
+                continue;
+            }
+        } else if let Some(search_query) = options.search_query.as_ref() {
+            if &search_query.field == field.ident.as_ref().unwrap() {
+                continue;
+            }
+        }
+
         if field.resource.is_some() {
             write_fields.extend(expand_write_resource(field));
         } else {
-            write_fields.extend(expand_write_field(field));
+            write_fields.extend(expand_write_field(options, field));
         }
     }
 
+    if let Some(query_options) = options
+        .list_query
+        .as_ref()
+        .or(options.search_query.as_ref())
+    {
+        write_fields.extend(expand_query_resource(
+            query_options,
+            options.search_query.is_some(),
+        ));
+    }
+
+    let source = &options.source;
+    let ident = &options.ident;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
+
     quote! {
-        impl From<#ident> for #source {
+        impl #type_params From<#ident #type_params> for #source #where_clause {
             #[allow(clippy::needless_update)]
-            fn from(value: #ident) -> Self {
-                #source {
-                    #write_fields
-                    ..Default::default()
-                }
+            fn from(value: #ident #type_params) -> Self {
+                let mut source: #source = Default::default();
+                #write_fields
+                source
             }
         }
     }
 }
 
-fn expand_write_field(field: &ParseField) -> TokenStream {
-    let target_ident = field.ident.as_ref().unwrap();
-    let source_ident = if let Some(name) = field.source_name.as_ref() {
-        Ident::from_string(name).unwrap()
-    } else {
-        field.ident.clone().unwrap()
-    };
-
+fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream {
     let field_type = &field.ty;
     let ProtoTypeInfo {
         is_option,
         is_nested,
         is_vec,
         is_box,
+        is_generic,
         map_ident,
         ..
-    } = get_proto_type_info(field_type);
+    } = get_proto_type_info(options, field_type);
 
     let mut write_target = if field.keep {
         if is_box {
@@ -79,7 +105,7 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
             write_item.extend(quote! {
                 let source = source as i32;
             });
-        } else if is_nested {
+        } else if is_nested || is_generic {
             write_item.extend(quote! {
                 let source = source.into();
             });
@@ -98,7 +124,7 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
             write_item.extend(quote! {
                 let source = source as i32;
             });
-        } else if is_nested {
+        } else if is_nested || is_generic {
             write_item.extend(quote! {
                 let source = source.into();
             });
@@ -115,7 +141,7 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
         quote! {
             let source = source as i32;
         }
-    } else if field.oneof || is_nested {
+    } else if field.oneof || is_nested || is_generic {
         let write_target = if is_box {
             quote! {
                 let source = *source;
@@ -136,13 +162,15 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
     };
 
     if let Some(source_try_from) = field.source_try_from.as_ref() {
-        let err_literal = format!("failed to convert `{source_ident}` to `{source_try_from}`");
+        let field_ident = field.ident.as_ref().unwrap();
+        let err_literal = format!("failed to convert `{field_ident}` to `{source_try_from}`");
         write_target.extend(quote! {
             let source: #source_try_from = source.try_into()
                 .expect(#err_literal);
         });
     }
 
+    let target_ident = field.ident.as_ref().unwrap();
     let mut write = quote! {
         let source = value.#target_ident;
     };
@@ -156,6 +184,7 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
     let source_option = field.source_option
         || is_option
         || (is_nested
+            && !is_generic
             && (field.with.is_none() && field.parse_with.is_none())
             && !is_vec
             && map_ident.is_none()
@@ -207,11 +236,43 @@ fn expand_write_field(options: &ParseOptions, field: &ParseField) -> TokenStream
         });
     }
 
+    let source = if let Some(source_name) = field.source_name.as_ref() {
+        if source_name.contains('.') {
+            let parts: Vec<_> = source_name.split('.').collect();
+            let mut inject = quote!();
+            for (i, part) in parts.iter().enumerate() {
+                let part_ident = Ident::from_string(part).unwrap();
+                inject.extend(if i < parts.len() - 1 {
+                    quote! {
+                        .#part_ident
+                        .get_or_insert(Default::default())
+                    }
+                } else {
+                    quote! {
+                        .#part_ident
+                    }
+                });
+            }
+            quote! {
+                source #inject
+            }
+        } else {
+            let source_ident = Ident::from_string(source_name).unwrap();
+            quote! {
+                source.#source_ident
+            }
+        }
+    } else {
+        let source_ident = field.ident.clone().unwrap();
+        quote! {
+            source.#source_ident
+        }
+    };
     quote! {
-        #source_ident: {
+        #source = {
             #write
             source
-        },
+        };
     }
 }
@@ -223,32 +284,78 @@ fn expand_write_resource(field: &ParseField) -> TokenStream {
 
     if options.fields.name {
         result.extend(quote! {
-            name: value.#ident.name,
+            source.name = value.#ident.name;
         });
     }
     if options.fields.create_time {
         result.extend(quote! {
-            create_time: value.#ident.create_time.map(Into::into),
+            source.create_time = value.#ident.create_time.map(Into::into);
         });
     }
     if options.fields.update_time {
         result.extend(quote! {
-            update_time: value.#ident.update_time.map(Into::into),
+            source.update_time = value.#ident.update_time.map(Into::into);
         });
     }
     if options.fields.delete_time {
         result.extend(quote! {
-            delete_time: value.#ident.delete_time.map(Into::into),
+            source.delete_time = value.#ident.delete_time.map(Into::into);
         });
     }
     if options.fields.deleted {
         result.extend(quote! {
-            deleted: value.#ident.deleted,
+            source.deleted = value.#ident.deleted;
         });
     }
     if options.fields.etag {
         result.extend(quote! {
-            etag: value.#ident.etag,
+            source.etag = value.#ident.etag;
         });
     }
 
     result
 }
+
+fn expand_query_resource(options: &QueryOptions, is_search: bool) -> TokenStream {
+    let ident = &options.field;
+    let mut result = quote!();
+
+    if options.query.parse && is_search {
+        let source_name = &options.query.source_name;
+        result.extend(quote! {
+            source.#source_name = value.#ident.query;
+        });
+    }
+    if options.page_size.parse {
+        let source_name = &options.page_size.source_name;
+        result.extend(quote! {
+            source.#source_name = Some(value.#ident.page_size.try_into().unwrap());
+        });
+    }
+    if options.page_token.parse {
+        let source_name = &options.page_token.source_name;
+        result.extend(quote! {
+            source.#source_name = value.#ident.page_token.map(|page_token| page_token.to_string());
+        });
+    }
+    if options.filter.parse {
+        let source_name = &options.filter.source_name;
+        result.extend(quote! {
+            source.#source_name = if value.#ident.filter.is_empty() {
+                None
+            } else {
+                Some(value.#ident.filter.to_string())
+            };
+        });
+    }
+    if options.ordering.parse {
+        let source_name = &options.ordering.source_name;
+        result.extend(quote! {
+            source.#source_name = if value.#ident.ordering.is_empty() {
+                None
+            } else {
+                Some(value.#ident.ordering.to_string())
+            };
        });
    }
diff --git a/bomboni_request_derive/src/parse/mod.rs b/bomboni_request_derive/src/parse/mod.rs
index b578623..e577a3f 100644
--- a/bomboni_request_derive/src/parse/mod.rs
+++ b/bomboni_request_derive/src/parse/mod.rs
@@ -2,7 +2,10 @@ use darling::util::parse_expr;
 use darling::{ast, FromDeriveInput, FromField, FromMeta, FromVariant};
 use proc_macro2::{Ident, TokenStream};
 use quote::quote;
-use syn::{self, DeriveInput, Expr, ExprArray, ExprPath, Meta, MetaNameValue, Path, Type};
+use syn::{
+    self, DeriveInput, Expr, ExprArray, ExprPath, Generics, Meta, MetaNameValue, Path, Type,
+    WhereClause,
+};
 
 mod message;
 mod oneof;
@@ -13,6 +16,8 @@ pub mod parse_resource_name;
 #[darling(attributes(parse))]
 pub struct ParseOptions {
     pub ident: Ident,
+    pub generics: Generics,
+    pub where_clause: Option<WhereClause>,
     pub data: ast::Data<ParseVariant, ParseField>,
     /// Source proto type.
     pub source: Path,
@@ -22,6 +27,12 @@ pub struct ParseOptions {
     /// Used to create tagged unions.
     #[darling(default)]
     pub tagged_union: Option<ParseTaggedUnion>,
+    /// Parse list query fields.
+    #[darling(default)]
+    pub list_query: Option<QueryOptions>,
+    /// Parse search query fields.
+    #[darling(default)]
+    pub search_query: Option<QueryOptions>,
 }
 
 #[derive(FromMeta, Debug)]
@@ -40,6 +51,7 @@ pub struct ParseField {
     #[darling(with = parse_expr::parse_str_literal, map = Some)]
     pub name: Option<Expr>,
     /// Source field name.
+    /// Can be a path to a nested field.
     pub source_name: Option<String>,
     /// Skip parsing field.
     #[darling(default)]
@@ -156,6 +168,22 @@ pub struct ResourceFields {
     pub etag: bool,
 }
 
+#[derive(Debug)]
+pub struct QueryOptions {
+    pub field: Ident,
+    pub query: QueryFieldOptions,
+    pub page_size: QueryFieldOptions,
+    pub page_token: QueryFieldOptions,
+    pub filter: QueryFieldOptions,
+    pub ordering: QueryFieldOptions,
+}
+
+#[derive(Debug)]
+pub struct QueryFieldOptions {
+    pub parse: bool,
+    pub source_name: Ident,
+}
+
 #[derive(Debug)]
 pub struct DeriveOptions {
     /// The function must have the signature `fn(source: &Source) -> RequestResult`.
@@ -303,3 +331,94 @@ impl FromMeta for ResourceFields {
         Ok(fields)
     }
 }
+
+impl Default for QueryOptions {
+    fn default() -> Self {
+        Self {
+            field: Ident::from_string("query").unwrap(),
+            query: QueryFieldOptions {
+                parse: true,
+                source_name: Ident::from_string("query").unwrap(),
+            },
+            page_size: QueryFieldOptions {
+                parse: true,
+                source_name: Ident::from_string("page_size").unwrap(),
+            },
+            page_token: QueryFieldOptions {
+                parse: true,
+                source_name: Ident::from_string("page_token").unwrap(),
+            },
+            filter: QueryFieldOptions {
+                parse: true,
+                source_name: Ident::from_string("filter").unwrap(),
+            },
+            ordering: QueryFieldOptions {
+                parse: true,
+                source_name: Ident::from_string("order_by").unwrap(),
+            },
+        }
+    }
+}
+
+impl FromMeta for QueryOptions {
+    fn from_list(items: &[ast::NestedMeta]) -> darling::Result<Self> {
+        let mut options = Self::default();
+
+        macro_rules! impl_field_option {
+            ($ident:ident, $meta:ident) => {
+                if let Ok(parse) = bool::from_meta($meta) {
+                    options.$ident.parse = parse;
+                } else if let Ok(source_name) = Ident::from_meta($meta) {
+                    options.$ident.source_name = source_name;
+                } else {
+                    return Err(darling::Error::custom(format!(
+                        "invalid query `{}` option value",
+                        stringify!($ident)
+                    ))
+                    .with_span($meta));
+                }
+            };
+        }
+
+        for item in items {
+            match item {
+                ast::NestedMeta::Meta(meta) => {
+                    let ident = meta.path().get_ident().unwrap();
+                    match ident.to_string().as_str() {
+                        "field" => {
+                            options.field = Ident::from_meta(meta)?;
+                        }
+                        "query" => {
+                            if let Ok(source_name) = Ident::from_meta(meta) {
+                                options.query.source_name = source_name;
+                            } else {
+                                return Err(darling::Error::custom(
+                                    "invalid query `query` option value",
+                                )
+                                .with_span(meta));
+                            }
+                        }
+                        "page_size" => impl_field_option!(page_size, meta),
+                        "page_token" => impl_field_option!(page_token, meta),
+                        "filter" => impl_field_option!(filter, meta),
+                        "ordering" => impl_field_option!(ordering, meta),
+                        _ => {
+                            return Err(
+                                darling::Error::custom("unknown query option").with_span(ident)
+                            );
+                        }
+                    }
+                }
+                ast::NestedMeta::Lit(lit) => {
+                    return Err(darling::Error::custom("unexpected literal").with_span(lit));
+                }
+            }
+        }
+
+        Ok(options)
+    }
+
+    fn from_word() -> darling::Result<Self> {
+        Ok(Self::default())
+    }
+}
diff --git a/bomboni_request_derive/src/parse/oneof/mod.rs b/bomboni_request_derive/src/parse/oneof/mod.rs
index ad17fa3..83c70e9 100644
--- a/bomboni_request_derive/src/parse/oneof/mod.rs
+++ b/bomboni_request_derive/src/parse/oneof/mod.rs
@@ -6,9 +6,17 @@ mod parse;
 mod write;
 
 pub fn expand(options: &ParseOptions, variants: &[ParseVariant]) -> syn::Result<TokenStream> {
+    if options.list_query.is_some() || options.search_query.is_some() {
+        return Err(syn::Error::new_spanned(
+            &options.ident,
+            "enums cannot be used with `list_query` or `search_query`",
+        ));
+    }
+
     let mut result = parse::expand(options, variants)?;
+
     if options.write {
         result.extend(write::expand(options, variants));
     }
+
     Ok(result)
 }
diff --git a/bomboni_request_derive/src/parse/oneof/parse.rs b/bomboni_request_derive/src/parse/oneof/parse.rs
index bb13ac4..c13f400 100644
--- a/bomboni_request_derive/src/parse/oneof/parse.rs
+++ b/bomboni_request_derive/src/parse/oneof/parse.rs
@@ -16,8 +16,20 @@ pub fn expand(options: &ParseOptions, variants: &[ParseVariant]) -> syn::Result
 fn expand_parse(options: &ParseOptions, variants: &[ParseVariant]) -> syn::Result<TokenStream> {
     let source = &options.source;
     let ident = &options.ident;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
 
     let mut parse_variants = quote!();
+
     for variant in variants {
         if variant.skip {
             continue;
@@ -45,7 +57,7 @@ fn expand_parse(options: &ParseOptions, variants: &[ParseVariant]) -> syn::Result
                 }
             });
         } else {
-            let parse_variant = expand_parse_variant(variant)?;
+            let parse_variant = expand_parse_variant(options, variant)?;
             parse_variants.extend(quote! {
                 #source::#source_variant_ident(source) => {
                     #ident::#target_variant_ident({
@@ -57,10 +69,8 @@ fn expand_parse(options: &ParseOptions, variants: &[ParseVariant]) -> syn::Result
     }
 
     Ok(quote! {
-        impl RequestParse<#source> for #ident {
-            type Error = RequestError;
-
-            fn parse(source: #source) -> Result<Self, Self::Error> {
+        impl #type_params RequestParse<#source> for #ident #type_params #where_clause {
+            fn parse(source: #source) -> RequestResult<Self> {
                 let variant_name = source.get_variant_name();
                 Ok(match source {
                     #parse_variants
@@ -78,11 +88,8 @@ fn expand_tagged_union(
     variants: &[ParseVariant],
     tagged_union: &ParseTaggedUnion,
 ) -> syn::Result<TokenStream> {
-    let source = &options.source;
     let ident = &options.ident;
     let oneof_ident = &tagged_union.oneof;
-    let field_ident = &tagged_union.field;
-    let field_literal = Literal::string(&tagged_union.field.to_string());
 
     let mut parse_variants = quote!();
     for variant in variants {
@@ -112,7 +119,7 @@ fn expand_tagged_union(
                 }
             });
         } else {
-            let parse_variant = expand_parse_variant(variant)?;
+            let parse_variant = expand_parse_variant(options, variant)?;
             parse_variants.extend(quote! {
                 #oneof_ident::#source_variant_ident(source) => {
                     #ident::#target_variant_ident({
@@ -123,12 +130,25 @@ fn expand_tagged_union(
         }
     }
 
-    Ok(quote! {
-        impl RequestParse<#source> for #ident {
-            type Error = RequestError;
+    let field_ident = &tagged_union.field;
+    let field_literal = Literal::string(&tagged_union.field.to_string());
+    let source = &options.source;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
+    Ok(quote! {
+        impl #type_params RequestParse<#source> for #ident #type_params #where_clause {
             #[allow(ignored_unit_patterns)]
-            fn parse(source: #source) -> Result<Self, Self::Error> {
+            fn parse(source: #source) -> RequestResult<Self> {
                 let source = source.#field_ident
                     .ok_or_else(|| RequestError::field(#field_literal, CommonError::RequiredFieldMissing))?;
                 let variant_name = source.get_variant_name();
@@ -143,7 +163,10 @@ fn expand_tagged_union(
     })
 }
 
-fn expand_parse_variant(variant: &ParseVariant) -> syn::Result<TokenStream> {
+fn expand_parse_variant(
+    options: &ParseOptions,
+    variant: &ParseVariant,
+) -> syn::Result<TokenStream> {
     if (variant.with.is_some() || variant.parse_with.is_some()) && variant.regex.is_some() {
         return Err(syn::Error::new_spanned(
             &variant.ident,
@@ -191,8 +214,9 @@ fn expand_parse_variant(
         is_nested,
         is_string,
         is_box,
+        is_generic,
         ..
-    } = get_proto_type_info(variant_type);
+    } = get_proto_type_info(options, variant_type);
 
     if variant.regex.is_some() && !is_string {
         return Err(syn::Error::new_spanned(
             &variant.ident,
         ));
     }
 
+    let custom_parse = variant.with.is_some() || variant.parse_with.is_some();
     let mut parse_source = if variant.keep {
         if is_box || variant.source_box {
             quote! {
                 let target = *target;
             }
         } else {
             quote!()
         }
-    } else if variant.with.is_some() || variant.parse_with.is_some() {
+    } else if custom_parse {
         let parse_with = if let Some(with) = variant.with.as_ref() {
             quote! {
                 #with::parse
@@ -226,7 +251,7 @@ fn expand_parse_variant(
             let target = target.try_into()
                 .map_err(|_| RequestError::field(variant_name, CommonError::InvalidEnumValue))?;
         }
-    } else if is_nested {
+    } else if is_nested || is_generic {
         let parse_source = if is_box || variant.source_box {
             quote! {
                 let target = *target;
@@ -306,36 +331,36 @@ fn expand_parse_variant(
 
     if is_option {
         if source_option {
-            if is_string {
-                parse.extend(quote! {
+            parse.extend(if is_string && !custom_parse {
+                quote! {
                     let target = if let Some(target) = target.filter(|target| !target.is_empty()) {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
-            } else if variant.enumeration {
-                parse.extend(quote! {
+                }
+            } else if variant.enumeration && !custom_parse {
+                quote! {
                     let target = if let Some(target) = target.filter(|e| *e != 0) {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
+                }
             } else {
-                parse.extend(quote! {
+                quote! {
                     let target = if let Some(target) = target {
                         #parse_source
                         Some(target)
                     } else {
                         None
                     };
-                });
-            }
+                }
+            });
         } else {
-            parse.extend(if is_string {
+            parse.extend(if is_string && !custom_parse {
                 quote! {
                     let target = if target.is_empty() {
                         None
                     } else {
                         #parse_source
                         Some(target)
                     };
                 }
-            } else if variant.enumeration {
+            } else if variant.enumeration && !custom_parse {
                 quote! {
diff --git a/bomboni_request_derive/src/parse/oneof/write.rs b/bomboni_request_derive/src/parse/oneof/write.rs
index de3262f..b63d0b8 100644
--- a/bomboni_request_derive/src/parse/oneof/write.rs
+++ b/bomboni_request_derive/src/parse/oneof/write.rs
@@ -16,6 +16,17 @@ pub fn expand(options: &ParseOptions, variants: &[ParseVariant]) -> TokenStream
 fn expand_write(options: &ParseOptions, variants: &[ParseVariant]) -> TokenStream {
     let source = &options.source;
     let ident = &options.ident;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
 
     let mut write_variants = quote!();
 
@@ -46,7 +57,7 @@ fn expand_write(options: &ParseOptions, variants: &[ParseVariant]) -> TokenStrea
                 }
             });
         } else {
-            let write_variant = expand_write_variant(variant);
+            let write_variant = expand_write_variant(options, variant);
             write_variants.extend(quote! {
                 #ident::#target_variant_ident(value) => {
                     #source::#source_variant_ident({
@@ -58,8 +69,8 @@ fn expand_write(options: &ParseOptions, variants: &[ParseVariant]) -> TokenStrea
     }
 
     quote! {
-        impl From<#ident> for #source {
-            fn from(value: #ident) -> Self {
+        impl #type_params From<#ident #type_params> for #source #where_clause {
+            fn from(value: #ident #type_params) -> Self {
                 match value {
                     #write_variants
                     _ => panic!("unknown oneof variant"),
@@ -78,6 +89,17 @@ fn expand_write_tagged_union(
     let ident = &options.ident;
     let oneof_ident = &tagged_union.oneof;
     let field_ident = &tagged_union.field;
+    let type_params = {
+        let type_params = options.generics.type_params().map(|param| &param.ident);
+        quote! {
+            <#(#type_params),*>
+        }
+    };
+    let where_clause = if let Some(where_clause) = &options.generics.where_clause {
+        quote! { #where_clause }
+    } else {
+        quote!()
+    };
 
     let mut write_variants = quote!();
     for variant in variants {
@@ -107,7 +129,7 @@ fn expand_write_tagged_union(
                 }
             });
         } else {
-            let write_variant = expand_write_variant(variant);
+            let write_variant = expand_write_variant(options, variant);
             write_variants.extend(quote! {
                 #ident::#target_variant_ident(value) => {
                     #oneof_ident::#source_variant_ident({
@@ -119,8 +141,8 @@ fn expand_write_tagged_union(
     }
 
     quote! {
-        impl From<#ident> for #source {
-            fn from(value: #ident) -> Self {
+        impl #type_params From<#ident #type_params> for #source #where_clause {
+            fn from(value: #ident #type_params) -> Self {
                 #source {
                     #field_ident: Some(match value {
                         #write_variants
@@ -132,14 +154,15 @@ fn expand_write_tagged_union(
     }
 }
 
-fn expand_write_variant(variant: &ParseVariant) -> TokenStream {
+fn expand_write_variant(options: &ParseOptions, variant: &ParseVariant) -> TokenStream {
     let variant_type = variant.fields.iter().next().unwrap();
     let ProtoTypeInfo {
         is_option,
         is_nested,
         is_box,
+        is_generic,
         ..
-    } = get_proto_type_info(variant_type);
+    } = get_proto_type_info(options, variant_type);
 
     let mut write_target = if variant.keep {
         if is_box {
@@ -164,7 +187,7 @@ fn expand_write_variant(variant: &ParseVariant) -> TokenStream {
         quote! {
             let source = source as i32;
         }
-    } else if is_nested {
+    } else if is_nested || is_generic {
         let write_target = if is_box {
             quote! {
                 let source = *source;
diff --git a/bomboni_request_derive/src/utility.rs b/bomboni_request_derive/src/utility.rs
index 7d3c6a4..9fab077 100644
--- a/bomboni_request_derive/src/utility.rs
+++ b/bomboni_request_derive/src/utility.rs
@@ -1,6 +1,8 @@
 use proc_macro2::Ident;
 use syn::{GenericArgument, Path, PathArguments, PathSegment, Type, TypePath};
 
+use crate::parse::ParseOptions;
+
 pub fn is_option_type(ty: &Type) -> bool {
     if let Type::Path(TypePath { path, .. }) = ty {
         path.segments.len() == 1 && path.segments[0].ident == "Option"
@@ -16,19 +18,24 @@ pub struct ProtoTypeInfo {
     pub is_string: bool,
     pub is_box: bool,
     pub is_vec: bool,
+    pub is_generic: bool,
     pub map_ident: Option<Ident>,
 }
 
-pub fn get_proto_type_info(ty: &Type) -> ProtoTypeInfo {
+pub fn get_proto_type_info(options: &ParseOptions, ty: &Type) -> ProtoTypeInfo {
     let mut info = ProtoTypeInfo::default();
     if let Type::Path(type_path) = ty {
         let segment = type_path.path.segments.first().unwrap();
-        update_proto_type_segment(&mut info, segment);
+        update_proto_type_segment(&mut info, options, segment);
     }
     info
 }
 
-fn update_proto_type_segment(info: &mut ProtoTypeInfo, segment: &PathSegment) {
+fn update_proto_type_segment(
+    info: &mut ProtoTypeInfo,
+    options: &ParseOptions,
+    segment: &PathSegment,
+) {
     if segment.ident == "Option" {
         info.is_option = true;
     } else if segment.ident == "Box" {
@@ -39,6 +46,14 @@ fn update_proto_type_segment(info: &mut ProtoTypeInfo, segment: &PathSegment) {
         info.map_ident = Some(segment.ident.clone());
     } else if segment.ident == "String" {
         info.is_string = true;
+    } else if options.generics.params.iter().any(|param| {
+        if let syn::GenericParam::Type(type_param) = param {
+            type_param.ident == segment.ident
+        } else {
+            false
+        }
+    }) {
+        info.is_generic = true;
     } else {
         // Assume nested messages begin with a capital letter
         info.is_nested = !info.is_nested
@@ -60,7 +75,22 @@ fn update_proto_type_segment(info: &mut ProtoTypeInfo, segment: &PathSegment) {
             _ => args.args.first().unwrap(),
         } {
             let nested_segment = segments.first().unwrap();
-            update_proto_type_segment(info, nested_segment);
+            update_proto_type_segment(info, options, nested_segment);
+        }
+    }
+}
+
+pub fn get_query_field_token_type(ty: &Type) -> Option<&Type> {
+    if let Type::Path(TypePath { path, .. }) = ty {
+        if path.segments.len() == 1
+            && (path.segments[0].ident == "ListQuery" || path.segments[0].ident == "SearchQuery")
+        {
+            if let PathArguments::AngleBracketed(args) = &path.segments[0].arguments {
+                if let GenericArgument::Type(ty) = args.args.first().unwrap() {
+                    return Some(ty);
+                }
+            }
         }
     }
+    None
 }
diff --git a/bomboni_template/Cargo.toml b/bomboni_template/Cargo.toml
index 17c9a3c..e204f4b 100644
--- a/bomboni_template/Cargo.toml
+++ b/bomboni_template/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bomboni_template"
-version = "0.1.34"
+version = "0.1.39"
 authors = ["Tin Rabzelj "]
 description = "Utilities for working Handlebars templates. Part of Bomboni library."
 repository = "https://github.com/tinrab/bomboni"
@@ -17,8 +17,8 @@ path = "src/lib.rs"
 testing = []
 
 [dependencies]
-bomboni_common = { path = "../bomboni_common", version = "0.1.34" }
-bomboni_proto = { version = "0.1.34", path = "../bomboni_proto", features = [
+bomboni_common = { path = "../bomboni_common", version = "0.1.39" }
+bomboni_proto = { version = "0.1.39", path = "../bomboni_proto", features = [
     "json",
 ] }
 thiserror = "1.0.50"
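The new `get_query_field_token_type` helper in utility.rs recovers the page-token type argument from a field declared as `ListQuery<T>` or `SearchQuery<T>`. A small usage sketch, assuming the helper is in scope (`MyPageToken` is an illustrative name):

    use syn::{parse_quote, Type};

    fn main() {
        let query_field: Type = parse_quote!(ListQuery<MyPageToken>);
        let plain_field: Type = parse_quote!(String);
        // Returns the `T` out of `ListQuery<T>`/`SearchQuery<T>`...
        assert!(get_query_field_token_type(&query_field).is_some());
        // ...and `None` for any other field type.
        assert!(get_query_field_token_type(&plain_field).is_none());
    }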