Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix performance regression for JSON tagged union #1552

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 58 additions & 75 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,44 +117,6 @@ fn union_serialize<S>(
Ok(None)
}

fn tagged_union_serialize<S>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved down, to be a method on TaggedUnionSerializer.

discriminator_value: Option<Py<PyAny>>,
lookup: &HashMap<String, usize>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = discriminator_value {
let tag_str = tag.to_string();
if let Some(&serializer_index) = lookup.get(&tag_str) {
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
}

// if we haven't returned at this point, we should fall back to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
fn to_python(
&self,
Expand Down Expand Up @@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
self.tagged_union_serialize(
value,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
self.tagged_union_serialize(
key,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}
Expand All @@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match tagged_union_serialize(
None,
&self.lookup,
match self.tagged_union_serialize(
value,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Expand All @@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer {
}

impl TaggedUnionSerializer {
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
fn get_discriminator_value<'py>(&self, value: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
match &self.discriminator {
Discriminator::LookupKey(lookup_key) => {
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
// at this point. we could be more strict and only do this in lax mode...
let getattr_result = match value.is_instance_of::<PyDict>() {
true => {
let value_dict = value.downcast::<PyDict>().unwrap();
lookup_key.py_get_dict_item(value_dict).ok()
}
false => lookup_key.simple_py_get_attr(value).ok(),
};
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
if let Ok(value_dict) = value.downcast::<PyDict>() {
lookup_key.py_get_dict_item(value_dict).ok().flatten()
} else {
lookup_key.simple_py_get_attr(value).ok().flatten()
}
.map(|(_, tag)| tag)
}
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let value_str = truncate_safe_repr(value, None);
Discriminator::Function(func) => func.bind(py).call1((value,)).ok(),
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
if extra.check == SerCheck::None {
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
fn tagged_union_serialize<S>(
&self,
value: &Bound<'_, PyAny>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
) -> PyResult<Option<S>> {
if let Some(tag) = self.get_discriminator_value(value) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the method call here makes sure that we always try to get the tag 😉

let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let selected_serializer = &self.choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
} else if extra.check == SerCheck::None {
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise
// this warning
let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this warning here from get_discriminator_value; it seemed more correct, as the other warnings are raised in union_serialize.

format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
discriminator_value

// if we haven't returned at this point, we should fall back to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, &self.choices, self.retry_with_lax_check())
}
}