Skip to content

Commit

Permalink
Add derive macros
Browse files Browse the repository at this point in the history
  • Loading branch information
Sufflope committed Sep 25, 2024
1 parent 13fcbc7 commit bedadec
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ name = "approx"

[features]
default = ["std"]
derive = ["dep:approx-derive"]
std = []

[dependencies]
approx-derive = { version = "0.5.1", path = "approx-derive", optional = true }
num-traits = { version = "0.2.0", default_features = false }
num-complex = { version = "0.4.0", optional = true }

[workspace]
members = ["approx-derive"]
14 changes: 14 additions & 0 deletions approx-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "approx-derive"
version = "0.5.1"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
darling = "0.20"
itertools = "0.13"
proc-macro2 = "1"
quote = "1"
syn = "2"
150 changes: 150 additions & 0 deletions approx-derive/src/abs_diff_eq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use darling::{
ast::{Data, Fields, Style},
FromDeriveInput,
};
use itertools::Itertools;
use proc_macro2::{Literal, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::{DeriveInput, Expr, Path};

use crate::{Field, Variant};

#[derive(FromDeriveInput)]
#[darling(attributes(approx))]
struct Opts {
epsilon: Path,
default: Expr,
data: Data<Variant, Field>,
}

pub(crate) fn derive(item: DeriveInput) -> syn::Result<TokenStream> {
let Opts {
epsilon,
default,
data,
} = Opts::from_derive_input(&item)?;

let ident = item.ident;

let comparisons = data
.map_enum_variants(|variant| {
let ident = variant.ident;
let fields = variant.fields;
match fields.style {
Style::Tuple => {
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
fields
.fields
.into_iter()
.enumerate()
.map(|(i, field)| {
if field.skip() {
(None, format_ident!("_"), format_ident!("_"))
} else {
let one = format_ident!("_{}", i);
let other = format_ident!("other_{}", i);
(Some(compare(&one, &other, field.exact)), one, other)
}
})
.multiunzip();
let comps = comps.iter().flatten();
quote! {
Self::#ident(#(#self_extractors),*) => match other {
Self::#ident(#(#other_extractors),*) => #(#comps)&&*,
_ => false
}
}
}
Style::Struct => {
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
fields
.fields
.into_iter()
.filter_map(|field| {
if field.skip() {
None
} else {
let one = field.ident.clone().unwrap();
let other = format_ident!("other_{}", one);
Some((
compare(&one, &other, field.exact),
one.clone(),
quote! { #one: #other },
))
}
})
.multiunzip();
quote! {
Self::#ident { #(#self_extractors),*, .. } => match other {
Self::#ident {#(#other_extractors),*, ..} => #(#comps)&&*,
_ => false
}
}
}
Style::Unit => quote! { Self::#ident => self == other },
}
})
.map_struct(|fields| {
Fields::<TokenStream>::from((
fields.style,
fields
.into_iter()
.enumerate()
.filter_map(|(i, field)| {
if field.skip() {
None
} else {
let ident = match field.ident {
None => Literal::usize_unsuffixed(i).to_token_stream(),
Some(ident) => quote! { #ident },
};
Some(compare(
quote! { self.#ident },
quote! { other.#ident },
field.exact,
))
}
})
.collect::<Vec<_>>(),
))
});

let comparisons = if comparisons.is_enum() {
let comparisons = comparisons.take_enum().unwrap();
quote! {
match self {
#(#comparisons),*
}
}
} else {
let comparisons = comparisons.take_struct().unwrap().fields.into_iter();
quote!(#(#comparisons)&&*)
};

Ok(quote! {
#[automatically_derived]
impl AbsDiffEq for #ident {
type Epsilon = #epsilon;

fn default_epsilon() -> Self::Epsilon {
#default
}

fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
#comparisons
}
}
})
}

fn compare<One, Other>(one: One, other: Other, exact: Option<bool>) -> TokenStream
where
One: ToTokens,
Other: ToTokens,
{
if exact.unwrap_or(true) {
quote! { #one == #other }
} else {
quote! { #one.abs_diff_eq(&#other, epsilon) }
}
}
36 changes: 36 additions & 0 deletions approx-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
mod abs_diff_eq;

use darling::{ast::Fields, FromField, FromVariant};
use proc_macro2::TokenStream;
use syn::{parse_macro_input, DeriveInput, Ident};

#[derive(FromVariant)]
#[darling(attributes(approx))]
struct Variant {
ident: Ident,
fields: Fields<Field>,
}

#[derive(FromField)]
#[darling(attributes(approx))]
struct Field {
ident: Option<Ident>,
skip: Option<bool>,
exact: Option<bool>,
}

impl Field {
fn skip(&self) -> bool {
self.skip.unwrap_or(false)
}
}

#[proc_macro_derive(AbsDiffEq, attributes(approx))]
pub fn abs_diff_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let item = parse_macro_input!(item as DeriveInput);
convert(abs_diff_eq::derive(item))
}

fn convert(tokens: syn::Result<TokenStream>) -> proc_macro::TokenStream {
tokens.unwrap_or_else(syn::Error::into_compile_error).into()
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
#![no_std]
#![allow(clippy::transmute_float_to_int)]

#[cfg(feature = "derive")]
extern crate approx_derive;
#[cfg(feature = "num-complex")]
extern crate num_complex;
extern crate num_traits;
Expand All @@ -171,6 +173,9 @@ pub use abs_diff_eq::AbsDiffEq;
pub use relative_eq::RelativeEq;
pub use ulps_eq::UlpsEq;

#[cfg(feature = "derive")]
pub use approx_derive::AbsDiffEq;

/// The requisite parameters for testing for approximate equality using a
/// absolute difference based comparison.
///
Expand Down
Loading

0 comments on commit bedadec

Please sign in to comment.