-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix performance regression for JSON tagged union #1552
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -117,44 +117,6 @@ fn union_serialize<S>( | |
Ok(None) | ||
} | ||
|
||
fn tagged_union_serialize<S>( | ||
discriminator_value: Option<Py<PyAny>>, | ||
lookup: &HashMap<String, usize>, | ||
// if this returns `Ok(v)`, we picked a union variant to serialize, where | ||
// `S` is intermediate state which can be passed on to the finalizer | ||
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
extra: &Extra, | ||
choices: &[CombinedSerializer], | ||
retry_with_lax_check: bool, | ||
) -> PyResult<Option<S>> { | ||
let mut new_extra = extra.clone(); | ||
new_extra.check = SerCheck::Strict; | ||
|
||
if let Some(tag) = discriminator_value { | ||
let tag_str = tag.to_string(); | ||
if let Some(&serializer_index) = lookup.get(&tag_str) { | ||
let selected_serializer = &choices[serializer_index]; | ||
|
||
match selector(selected_serializer, &new_extra) { | ||
Ok(v) => return Ok(Some(v)), | ||
Err(_) => { | ||
if retry_with_lax_check { | ||
new_extra.check = SerCheck::Lax; | ||
if let Ok(v) = selector(selected_serializer, &new_extra) { | ||
return Ok(Some(v)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// if we haven't returned at this point, we should fallback to the union serializer | ||
// which preserves the historical expectation that we do our best with serialization | ||
// even if that means we resort to inference | ||
union_serialize(selector, extra, choices, retry_with_lax_check) | ||
} | ||
|
||
impl TypeSerializer for UnionSerializer { | ||
fn to_python( | ||
&self, | ||
|
@@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer { | |
exclude: Option<&Bound<'_, PyAny>>, | ||
extra: &Extra, | ||
) -> PyResult<PyObject> { | ||
tagged_union_serialize( | ||
self.get_discriminator_value(value, extra), | ||
&self.lookup, | ||
self.tagged_union_serialize( | ||
value, | ||
|comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
comb_serializer.to_python(value, include, exclude, new_extra) | ||
}, | ||
extra, | ||
&self.choices, | ||
self.retry_with_lax_check(), | ||
)? | ||
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) | ||
} | ||
|
||
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | ||
tagged_union_serialize( | ||
self.get_discriminator_value(key, extra), | ||
&self.lookup, | ||
self.tagged_union_serialize( | ||
key, | ||
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra), | ||
extra, | ||
&self.choices, | ||
self.retry_with_lax_check(), | ||
)? | ||
.map_or_else(|| infer_json_key(key, extra), Ok) | ||
} | ||
|
@@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer { | |
exclude: Option<&Bound<'_, PyAny>>, | ||
extra: &Extra, | ||
) -> Result<S::Ok, S::Error> { | ||
match tagged_union_serialize( | ||
None, | ||
&self.lookup, | ||
match self.tagged_union_serialize( | ||
value, | ||
|comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
comb_serializer.to_python(value, include, exclude, new_extra) | ||
}, | ||
extra, | ||
&self.choices, | ||
self.retry_with_lax_check(), | ||
) { | ||
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), | ||
Ok(None) => infer_serialize(value, serializer, include, exclude, extra), | ||
|
@@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer { | |
} | ||
|
||
impl TaggedUnionSerializer { | ||
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> { | ||
fn get_discriminator_value<'py>(&self, value: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> { | ||
let py = value.py(); | ||
let discriminator_value = match &self.discriminator { | ||
match &self.discriminator { | ||
Discriminator::LookupKey(lookup_key) => { | ||
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could | ||
// be doing a discriminator lookup on a typed dict, and there's no good way to check that | ||
// at this point. we could be more strict and only do this in lax mode... | ||
let getattr_result = match value.is_instance_of::<PyDict>() { | ||
true => { | ||
let value_dict = value.downcast::<PyDict>().unwrap(); | ||
lookup_key.py_get_dict_item(value_dict).ok() | ||
} | ||
false => lookup_key.simple_py_get_attr(value).ok(), | ||
}; | ||
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))) | ||
if let Ok(value_dict) = value.downcast::<PyDict>() { | ||
lookup_key.py_get_dict_item(value_dict).ok().flatten() | ||
} else { | ||
lookup_key.simple_py_get_attr(value).ok().flatten() | ||
} | ||
.map(|(_, tag)| tag) | ||
} | ||
Discriminator::Function(func) => func.call1(py, (value,)).ok(), | ||
}; | ||
if discriminator_value.is_none() { | ||
let value_str = truncate_safe_repr(value, None); | ||
Discriminator::Function(func) => func.bind(py).call1((value,)).ok(), | ||
} | ||
} | ||
|
||
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning | ||
if extra.check == SerCheck::None { | ||
extra.warnings.custom_warning( | ||
format!( | ||
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." | ||
) | ||
); | ||
fn tagged_union_serialize<S>( | ||
&self, | ||
value: &Bound<'_, PyAny>, | ||
// if this returns `Ok(v)`, we picked a union variant to serialize, where | ||
// `S` is intermediate state which can be passed on to the finalizer | ||
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
extra: &Extra, | ||
) -> PyResult<Option<S>> { | ||
if let Some(tag) = self.get_discriminator_value(value) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sure that we always try to get the tag by having the method call here 😉 |
||
let mut new_extra = extra.clone(); | ||
new_extra.check = SerCheck::Strict; | ||
|
||
let tag_str = tag.to_string(); | ||
if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
let selected_serializer = &self.choices[serializer_index]; | ||
|
||
match selector(selected_serializer, &new_extra) { | ||
Ok(v) => return Ok(Some(v)), | ||
Err(_) => { | ||
if self.retry_with_lax_check() { | ||
new_extra.check = SerCheck::Lax; | ||
if let Ok(v) = selector(selected_serializer, &new_extra) { | ||
return Ok(Some(v)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} else if extra.check == SerCheck::None { | ||
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise | ||
// this warning | ||
let value_str = truncate_safe_repr(value, None); | ||
extra.warnings.custom_warning( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved this warning here from |
||
format!( | ||
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." | ||
) | ||
); | ||
} | ||
discriminator_value | ||
|
||
// if we haven't returned at this point, we should fallback to the union serializer | ||
// which preserves the historical expectation that we do our best with serialization | ||
// even if that means we resort to inference | ||
union_serialize(selector, extra, &self.choices, self.retry_with_lax_check()) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved down, to be a method on
TaggedUnionSerializer
.