Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subclassing support #120

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/pure/pure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class A:
...


class B(A):
...

class Number(Enum):
FLOAT = auto()
INTEGER = auto()
Expand Down
8 changes: 7 additions & 1 deletion examples/pure/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn create_dict(n: usize) -> HashMap<usize, Vec<usize>> {
}

#[gen_stub_pyclass]
#[pyclass]
#[pyclass(subclass)]
#[derive(Debug)]
struct A {
#[pyo3(get, set)]
Expand Down Expand Up @@ -64,6 +64,11 @@ fn create_a(x: usize) -> A {
A { x }
}

#[gen_stub_pyclass]
#[pyclass(extends=A)]
#[derive(Debug)]
struct B;

create_exception!(pure, MyError, PyRuntimeError);

/// Returns the length of the string.
Expand Down Expand Up @@ -107,6 +112,7 @@ fn pure(m: &Bound<PyModule>) -> PyResult<()> {
m.add("MyError", m.py().get_type::<MyError>())?;
m.add("MY_CONSTANT", 19937)?;
m.add_class::<A>()?;
m.add_class::<B>()?;
m.add_class::<Number>()?;
m.add_function(wrap_pyfunction!(sum, m)?)?;
m.add_function(wrap_pyfunction!(create_dict, m)?)?;
Expand Down
1 change: 1 addition & 0 deletions pyo3-stub-gen-derive/src/gen_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
//! },
//! ],
//! doc: "",
//! bases: &[],
//! }
//! }
//! ```
Expand Down
6 changes: 6 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub enum Attr {
GetAll,
Module(String),
Signature(Signature),
Extends(String),

// Attributes appears in components within `#[pymethods]`
// <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
Expand Down Expand Up @@ -120,6 +121,11 @@ pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
}
}
[Ident(ident), Punct(_), Ident(ident2)] => {
if ident == "extends" {
pyo3_attrs.push(Attr::Extends(ident2.to_string()));
}
}
_ => {}
}
}
Expand Down
15 changes: 15 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub struct PyClassInfo {
module: Option<String>,
members: Vec<MemberInfo>,
doc: String,
bases: Vec<(Option<String>, String)>,
}

impl From<&PyClassInfo> for StubType {
Expand Down Expand Up @@ -41,13 +42,16 @@ impl TryFrom<ItemStruct> for PyClassInfo {
let mut pyclass_name = None;
let mut module = None;
let mut is_get_all = false;
let mut bases = Vec::new();
for attr in parse_pyo3_attrs(&attrs)? {
match attr {
Attr::Name(name) => pyclass_name = Some(name),
Attr::Module(name) => {
module = Some(name);
}
Attr::GetAll => is_get_all = true,
// TODO: allow other modules
Attr::Extends(name) => bases.push((module.clone(), name)),
_ => {}
}
}
Expand All @@ -65,6 +69,7 @@ impl TryFrom<ItemStruct> for PyClassInfo {
members,
module,
doc,
bases,
})
}
}
Expand All @@ -77,15 +82,24 @@ impl ToTokens for PyClassInfo {
members,
doc,
module,
bases,
} = self;
let module = quote_option(module);
let bases: Vec<_> = bases
.iter()
.map(|(mod_, name)| {
let mod_ = quote_option(mod_);
quote! { (#mod_, #name) }
})
.collect();
tokens.append_all(quote! {
::pyo3_stub_gen::type_info::PyClassInfo {
pyclass_name: #pyclass_name,
struct_id: std::any::TypeId::of::<#struct_type>,
members: &[ #( #members),* ],
module: #module,
doc: #doc,
bases: &[ #(#bases),* ],
}
})
}
Expand Down Expand Up @@ -136,6 +150,7 @@ mod test {
],
module: Some("my_module"),
doc: "",
bases: &[],
}
"###);
Ok(())
Expand Down
14 changes: 13 additions & 1 deletion pyo3-stub-gen/src/generate/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct ClassDef {
pub new: Option<NewDef>,
pub members: Vec<MemberDef>,
pub methods: Vec<MethodDef>,
pub bases: &'static [(Option<&'static str>, &'static str)],
}

impl Import for ClassDef {
Expand Down Expand Up @@ -37,13 +38,24 @@ impl From<&PyClassInfo> for ClassDef {
doc: info.doc,
members: info.members.iter().map(MemberDef::from).collect(),
methods: Vec::new(),
bases: info.bases,
}
}
}

impl fmt::Display for ClassDef {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "class {}:", self.name)?;
let bases = self
.bases
.iter()
.map(|(m, n)| {
m.map(|m| format!("{m}.{n}"))
.unwrap_or_else(|| n.to_string())
})
.reduce(|acc, path| format!("{acc}, {path}"))
.map(|bases| format!("({bases})"))
.unwrap_or_default();
writeln!(f, "class {}{}:", self.name, bases)?;
let indent = indent();
let doc = self.doc.trim();
if !doc.is_empty() {
Expand Down
4 changes: 4 additions & 0 deletions pyo3-stub-gen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
//! r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
//! },
//! ],
//!
//! doc: "Docstring used in Python",
//!
//! // Base classes
//! bases: &[],
//! }
//! }
//! ```
Expand Down
2 changes: 2 additions & 0 deletions pyo3-stub-gen/src/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub struct PyClassInfo {
pub doc: &'static str,
/// static members by `#[pyo3(get, set)]`
pub members: &'static [MemberInfo],
/// Base classes specified by `#[pyclass(extends = Type)]`
pub bases: &'static [(Option<&'static str>, &'static str)],
}

inventory::collect!(PyClassInfo);
Expand Down
Loading