diff --git a/examples/pure/pure.pyi b/examples/pure/pure.pyi index 0e22628..dfd37af 100644 --- a/examples/pure/pure.pyi +++ b/examples/pure/pure.pyi @@ -17,6 +17,9 @@ class A: ... +class B(A): + ... + class Number(Enum): FLOAT = auto() INTEGER = auto() diff --git a/examples/pure/src/lib.rs b/examples/pure/src/lib.rs index 743b5d2..6cfe73f 100644 --- a/examples/pure/src/lib.rs +++ b/examples/pure/src/lib.rs @@ -34,7 +34,7 @@ fn create_dict(n: usize) -> HashMap> { } #[gen_stub_pyclass] -#[pyclass] +#[pyclass(subclass)] #[derive(Debug)] struct A { #[pyo3(get, set)] @@ -64,6 +64,11 @@ fn create_a(x: usize) -> A { A { x } } +#[gen_stub_pyclass] +#[pyclass(extends=A)] +#[derive(Debug)] +struct B; + create_exception!(pure, MyError, PyRuntimeError); /// Returns the length of the string. @@ -107,6 +112,7 @@ fn pure(m: &Bound) -> PyResult<()> { m.add("MyError", m.py().get_type::())?; m.add("MY_CONSTANT", 19937)?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(sum, m)?)?; m.add_function(wrap_pyfunction!(create_dict, m)?)?; diff --git a/pyo3-stub-gen-derive/src/gen_stub.rs b/pyo3-stub-gen-derive/src/gen_stub.rs index 4e6436f..a6d6495 100644 --- a/pyo3-stub-gen-derive/src/gen_stub.rs +++ b/pyo3-stub-gen-derive/src/gen_stub.rs @@ -26,6 +26,7 @@ //! }, //! ], //! doc: "", +//! bases: &[], //! } //! } //! ``` diff --git a/pyo3-stub-gen-derive/src/gen_stub/attr.rs b/pyo3-stub-gen-derive/src/gen_stub/attr.rs index 046a6cf..5ccbd1b 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/attr.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/attr.rs @@ -54,6 +54,7 @@ pub enum Attr { GetAll, Module(String), Signature(Signature), + Extends(String), // Attributes appears in components within `#[pymethods]` // @@ -120,6 +121,11 @@ pub fn parse_pyo3_attr(attr: &Attribute) -> Result> { pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?)); } } + [Ident(ident), Punct(_), Ident(ident2)] => { + if ident == "extends" { + pyo3_attrs.push(Attr::Extends(ident2.to_string())); + } + } _ => {} } } diff --git a/pyo3-stub-gen-derive/src/gen_stub/pyclass.rs b/pyo3-stub-gen-derive/src/gen_stub/pyclass.rs index d2945c0..6722f07 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/pyclass.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/pyclass.rs @@ -10,6 +10,7 @@ pub struct PyClassInfo { module: Option, members: Vec, doc: String, + bases: Vec<(Option, String)>, } impl From<&PyClassInfo> for StubType { @@ -41,6 +42,7 @@ impl TryFrom for PyClassInfo { let mut pyclass_name = None; let mut module = None; let mut is_get_all = false; + let mut bases = Vec::new(); for attr in parse_pyo3_attrs(&attrs)? { match attr { Attr::Name(name) => pyclass_name = Some(name), @@ -48,6 +50,8 @@ impl TryFrom for PyClassInfo { module = Some(name); } Attr::GetAll => is_get_all = true, + // TODO: allow other modules + Attr::Extends(name) => bases.push((module.clone(), name)), _ => {} } } @@ -65,6 +69,7 @@ impl TryFrom for PyClassInfo { members, module, doc, + bases, }) } } @@ -77,8 +82,16 @@ impl ToTokens for PyClassInfo { members, doc, module, + bases, } = self; let module = quote_option(module); + let bases: Vec<_> = bases + .iter() + .map(|(mod_, name)| { + let mod_ = quote_option(mod_); + quote! { (#mod_, #name) } + }) + .collect(); tokens.append_all(quote! { ::pyo3_stub_gen::type_info::PyClassInfo { pyclass_name: #pyclass_name, @@ -86,6 +99,7 @@ impl ToTokens for PyClassInfo { members: &[ #( #members),* ], module: #module, doc: #doc, + bases: &[ #(#bases),* ], } }) } @@ -136,6 +150,7 @@ mod test { ], module: Some("my_module"), doc: "", + bases: &[], } "###); Ok(()) diff --git a/pyo3-stub-gen/src/generate/class.rs b/pyo3-stub-gen/src/generate/class.rs index e115ef8..cbfacf5 100644 --- a/pyo3-stub-gen/src/generate/class.rs +++ b/pyo3-stub-gen/src/generate/class.rs @@ -9,6 +9,7 @@ pub struct ClassDef { pub new: Option, pub members: Vec, pub methods: Vec, + pub bases: &'static [(Option<&'static str>, &'static str)], } impl Import for ClassDef { @@ -37,13 +38,24 @@ impl From<&PyClassInfo> for ClassDef { doc: info.doc, members: info.members.iter().map(MemberDef::from).collect(), methods: Vec::new(), + bases: info.bases, } } } impl fmt::Display for ClassDef { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "class {}:", self.name)?; + let bases = self + .bases + .iter() + .map(|(m, n)| { + m.map(|m| format!("{m}.{n}")) + .unwrap_or_else(|| n.to_string()) + }) + .reduce(|acc, path| format!("{acc}, {path}")) + .map(|bases| format!("({bases})")) + .unwrap_or_default(); + writeln!(f, "class {}{}:", self.name, bases)?; let indent = indent(); let doc = self.doc.trim(); if !doc.is_empty() { diff --git a/pyo3-stub-gen/src/lib.rs b/pyo3-stub-gen/src/lib.rs index 176e619..ac89d68 100644 --- a/pyo3-stub-gen/src/lib.rs +++ b/pyo3-stub-gen/src/lib.rs @@ -46,7 +46,11 @@ //! r#type: as ::pyo3_stub_gen::PyStubType>::type_output, //! }, //! ], +//! //! doc: "Docstring used in Python", +//! +//! // Base classes +//! bases: &[], //! } //! } //! ``` diff --git a/pyo3-stub-gen/src/type_info.rs b/pyo3-stub-gen/src/type_info.rs index 1ae983a..f140057 100644 --- a/pyo3-stub-gen/src/type_info.rs +++ b/pyo3-stub-gen/src/type_info.rs @@ -95,6 +95,8 @@ pub struct PyClassInfo { pub doc: &'static str, /// static members by `#[pyo3(get, set)]` pub members: &'static [MemberInfo], + /// Base classes specified by `#[pyclass(extends = Type)]` + pub bases: &'static [(Option<&'static str>, &'static str)], } inventory::collect!(PyClassInfo);