Skip to content

Commit

Permalink
Updated parser to remove custom serde functions
Browse files Browse the repository at this point in the history
  • Loading branch information
FloppyDisck committed Oct 14, 2024
1 parent 5238713 commit abbfcce
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/proto/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "archway-proto"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
description = "Rust build of Archway's ProtoBuf definitions"
authors.workspace = true
Expand Down
85 changes: 85 additions & 0 deletions packages/proto/src/any.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,58 @@
use prost::bytes::{Buf, BufMut};
use prost::{Message, Name};
use serde::{ser, Deserialize, Deserializer, Serialize, Serializer};

/// Works as a wrapper for Vec<u8> when working with structures that have an undefined/unknown type.
#[derive(:: serde :: Serialize, :: serde :: Deserialize, Clone, PartialEq, Debug, Default)]
pub struct GenericData(pub Vec<u8>);

impl Message for GenericData {
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
{
buf.put_slice(self.0.as_slice());
}

fn merge_field<B>(
&mut self,
tag: u32,
_wire_type: prost::encoding::WireType,
buf: &mut B,
_ctx: prost::encoding::DecodeContext,
) -> Result<(), prost::DecodeError>
where
B: Buf,
{
if tag == 1 {
self.0.push(10u8);
while buf.has_remaining() {
self.0.push(buf.get_u8());
}
Ok(())
} else {
Err(prost::DecodeError::new("invalid tag"))
}
}

fn encoded_len(&self) -> usize {
self.0.len()
}

fn clear(&mut self) {
self.0.clear();
}
}

impl Name for GenericData {
const NAME: &'static str = "";
const PACKAGE: &'static str = "";

fn full_name() -> String {
format!("{}{}", Self::PACKAGE, Self::NAME)
}
}

// An improved any type that allows you to implement typing directly into it
#[derive(Clone, PartialEq, Serialize, Deserialize, Message)]
pub struct Any<T: Message + PartialEq + Default> {
Expand All @@ -17,6 +69,13 @@ impl<T: Message + Name + PartialEq + Default> Any<T> {
value,
}
}

pub fn generic(value: T) -> Any<GenericData> {
Any {
type_url: T::full_name(),
value: GenericData(value.encode_to_vec()),
}
}
}

// Deserialize data shown inside Any<T> and return T
Expand Down Expand Up @@ -146,6 +205,7 @@ pub mod vec {

#[cfg(test)]
mod test {
use crate::any::{Any, GenericData};
use prost::{Message, Name};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -239,6 +299,31 @@ mod test {
}
}

#[test]
fn generic_equal_any() {
let val = AnyValue {
value: "testinnnng".to_string(),
number: 33,
};

dbg!(val.encode_to_vec());
dbg!(GenericData(val.encode_to_vec()).encode_to_vec());

let real_any = Any::new(val.clone());
let bytes_any = Any::generic(val);

let mut real_encoded = real_any.encode_to_vec();
let mut bytes_encoded = bytes_any.encode_to_vec();

assert_eq!(real_encoded, bytes_encoded);

let real_decoded = Any::<AnyValue>::decode(bytes_encoded.as_slice()).unwrap();
let bytes_decoded = Any::<GenericData>::decode(real_encoded.as_slice()).unwrap();

assert_eq!(real_decoded.type_url, bytes_decoded.type_url);
assert_eq!(real_decoded.value.encode_to_vec(), bytes_decoded.value.0);
}

#[test]
fn test_ser_de() {
let test = Test {
Expand Down
175 changes: 68 additions & 107 deletions proto-build/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use crate::patch_file;
use glob::glob;
use proc_macro2::{Literal, Punct, Spacing, Span, TokenStream};
use quote::{quote, TokenStreamExt};
use proc_macro2::Span;
use quote::quote;
use regex::Regex;
use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::fs;
use std::path::PathBuf;
use syn::punctuated::Punctuated;
use syn::token::{Paren, PathSep};
use syn::token::PathSep;
use syn::{
AngleBracketedGenericArguments, AttrStyle, Attribute, Field, Fields, FieldsNamed, File,
GenericArgument, GenericParam, Ident, Item, ItemStruct, MacroDelimiter, Meta, MetaList, Path,
PathArguments, PathSegment, TraitBound, TraitBoundModifier, Type, TypeParam, TypeParamBound,
TypePath,
AngleBracketedGenericArguments, Field, Fields, FieldsNamed, File, GenericArgument,
GenericParam, Ident, Item, ItemStruct, Path, PathArguments, PathSegment, TraitBound,
TraitBoundModifier, Type, TypeParam, TypeParamBound, TypePath,
};

fn as_struct(item: &mut Item) -> Option<&mut ItemStruct> {
Expand Down Expand Up @@ -78,6 +77,63 @@ fn is_important(field: &mut Field) -> Option<(FoundEnclosure, &mut TypePath)> {
// prost Message implements a generic B which was conflicting with this script
const GENERICS: [&str; 10] = ["A", "BB", "C", "D", "E", "F", "G", "H", "I", "J"];

fn gen_any(name: &str) -> Path {
let mut paths = create_punctuated(vec!["crate", "any"]);

let mut punctuation = Punctuated::new();
punctuation.push(GenericArgument::Type(Type::Path(TypePath {
qself: None,
path: gen_generic(name),
})));

paths.push(PathSegment {
ident: Ident::new("Any", Span::call_site()),
arguments: PathArguments::AngleBracketed(AngleBracketedGenericArguments {
colon2_token: None,
lt_token: Default::default(),
args: punctuation,
gt_token: Default::default(),
}),
});

Path {
leading_colon: None,
segments: paths,
}
}

/// Find any type of Any<GENERIC>
fn contains_any_generic(segment: &PathSegment) -> bool {
if segment.ident == "Any" {
if let PathArguments::AngleBracketed(arguments) = &segment.arguments {
if let Some(GenericArgument::Type(Type::Path(type_path))) = arguments.args.last() {
return GENERICS.contains(
&type_path
.path
.segments
.last()
.unwrap()
.ident
.to_string()
.as_str(),
);
}
}
}

false
}

/// Function assumes we have a type Any already and want to update the generic
fn replace_generic(segment: &mut PathSegment, name: &str) {
if let PathArguments::AngleBracketed(arguments) = &mut segment.arguments {
if let Some(GenericArgument::Type(Type::Path(type_path))) = arguments.args.last_mut() {
let last = type_path.path.segments.last_mut().unwrap();
last.ident = Ident::new(name, Span::call_site());
}
}
}

fn gen_generic(name: &str) -> Path {
Path {
leading_colon: None,
Expand Down Expand Up @@ -118,12 +174,6 @@ fn gen_unnamed_param(name: &str) -> TypeParam {
type_param
.bounds
.push(trait_param_bound(vec!["prost", "Message"]));
type_param
.bounds
.push(trait_param_bound(vec!["serde", "Serialize"]));
type_param
.bounds
.push(trait_param_bound(vec!["serde", "de", "DeserializeOwned"]));

type_param
}
Expand Down Expand Up @@ -160,11 +210,6 @@ fn load_and_patch_any(out_dir: &str) -> BTreeMap<String, (File, BTreeMap<String,
for src in src_files {
let current_file = fs::read_to_string(&src).unwrap();

// // Filter files that dont have `Any`
// if !any_filter.is_match(&current_file) {
// continue;
// }

let mut ast = syn::parse_file(&current_file).unwrap();

// Get all struct we might work with here
Expand All @@ -177,49 +222,14 @@ fn load_and_patch_any(out_dir: &str) -> BTreeMap<String, (File, BTreeMap<String,
// Find any fields and replace with generics
let fields = as_named_fields(&mut item.fields).unwrap();
for field in fields.named.iter_mut() {
if let Some((ty, path)) = is_important(field) {
if let Some((_, path)) = is_important(field) {
if path.path.segments.last().unwrap().ident == "Any" {
// Set struct generics
let generic = GENERICS[item.generics.params.len()];
path.path = gen_generic(generic);
path.path = gen_any(generic);
item.generics
.params
.push(GenericParam::Type(gen_type_param(generic)));

// Set serialization function
let serde_path = match ty {
FoundEnclosure::Option => "option",
FoundEnclosure::Vec => "vec",
};

let mut token_stream = TokenStream::new();
token_stream.append(Ident::new("serialize_with", Span::call_site()));
token_stream.append(Punct::new('=', Spacing::Alone));
token_stream.append(Literal::string(&format!(
"crate::any::{}::serialize",
serde_path
)));
token_stream.append(Punct::new(',', Spacing::Alone));
token_stream.append(Ident::new("deserialize_with", Span::call_site()));
token_stream.append(Punct::new('=', Spacing::Alone));
token_stream.append(Literal::string(&format!(
"crate::any::{}::deserialize",
serde_path
)));

field.attrs.push(Attribute {
pound_token: Default::default(),
style: AttrStyle::Outer,
bracket_token: Default::default(),
meta: Meta::List(MetaList {
path: Path {
leading_colon: None,
segments: create_punctuated(vec!["serde"]),
},
delimiter: MacroDelimiter::Paren(Paren::default()),
tokens: token_stream,
}),
});
}
}
}
Expand Down Expand Up @@ -277,8 +287,7 @@ fn patch_generics(files: &mut BTreeMap<String, (File, BTreeMap<String, usize>)>)
.iter_mut()
{
let name = field.ident.clone().unwrap().to_string();
let mut found_ty = None;
if let Some((field_ty, path)) = is_important(field) {
if let Some((_, path)) = is_important(field) {
let ty = path.path.segments.last_mut().unwrap();

let ident_name = ty.ident.to_string();
Expand Down Expand Up @@ -319,7 +328,6 @@ fn patch_generics(files: &mut BTreeMap<String, (File, BTreeMap<String, usize>)>)
key.items.get_mut(*s.get("GenesisState").unwrap()).unwrap(),
)
.unwrap();
found_ty = Some(field_ty);
new_total_generics = push_generics(ty_struct, ty, new_total_generics);
} else if let Some(i) = structs.get(&ident_name) {
let ty_item = match i.cmp(&idx) {
Expand All @@ -329,10 +337,9 @@ fn patch_generics(files: &mut BTreeMap<String, (File, BTreeMap<String, usize>)>)
};

let ty_struct = as_struct(ty_item).unwrap();
found_ty = Some(field_ty);
new_total_generics = push_generics(ty_struct, ty, new_total_generics);
} else if GENERICS.contains(&&*ident_name) {
ty.ident = Ident::new(GENERICS[new_total_generics], Span::call_site());
} else if contains_any_generic(ty) {
replace_generic(ty, GENERICS[new_total_generics]);
new_total_generics += 1;
} else if let Some((_, (other_ast, other_structs))) = files
.iter_mut()
Expand All @@ -345,55 +352,9 @@ fn patch_generics(files: &mut BTreeMap<String, (File, BTreeMap<String, usize>)>)
.unwrap(),
)
.unwrap();
found_ty = Some(field_ty);
new_total_generics = push_generics(ty_struct, ty, new_total_generics);
}
}

// Try to add field attrs
if let Some(found_ty) = found_ty {
let last = field.attrs.last().unwrap().clone();
if let Meta::List(meta_list) = &last.meta {
if meta_list.path.segments.last().unwrap().ident != "serde" {
// Set serialization function
let serde_path = match found_ty {
FoundEnclosure::Option => "option",
FoundEnclosure::Vec => "vec",
};

let mut token_stream = TokenStream::new();
token_stream
.append(Ident::new("serialize_with", Span::call_site()));
token_stream.append(Punct::new('=', Spacing::Alone));
token_stream.append(Literal::string(&format!(
"crate::any::{}::generic_serialize",
serde_path
)));
token_stream.append(Punct::new(',', Spacing::Alone));
token_stream
.append(Ident::new("deserialize_with", Span::call_site()));
token_stream.append(Punct::new('=', Spacing::Alone));
token_stream.append(Literal::string(&format!(
"crate::any::{}::generic_deserialize",
serde_path
)));

field.attrs.push(Attribute {
pound_token: Default::default(),
style: AttrStyle::Outer,
bracket_token: Default::default(),
meta: Meta::List(MetaList {
path: Path {
leading_colon: None,
segments: create_punctuated(vec!["serde"]),
},
delimiter: MacroDelimiter::Paren(Paren::default()),
tokens: token_stream,
}),
});
}
}
}
}

if new_total_generics > 0 {
Expand Down

0 comments on commit abbfcce

Please sign in to comment.