Skip to content

Commit

Permalink
Implement
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Dec 12, 2024
1 parent 356c96a commit 120963b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 1 deletion.
6 changes: 6 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub enum Attr {
GetAll,
Module(String),
Signature(Signature),
Extends(String),

// Attributes appears in components within `#[pymethods]`
// <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
Expand Down Expand Up @@ -120,6 +121,11 @@ pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
}
}
[Ident(ident), Punct(_), Ident(ident2)] => {
if ident == "extends" {
pyo3_attrs.push(Attr::Extends(ident2.to_string()));
}
}
_ => {}
}
}
Expand Down
11 changes: 11 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub struct PyClassInfo {
module: Option<String>,
members: Vec<MemberInfo>,
doc: String,
bases: Vec<(Option<String>, String)>,
}

impl From<&PyClassInfo> for StubType {
Expand Down Expand Up @@ -41,13 +42,16 @@ impl TryFrom<ItemStruct> for PyClassInfo {
let mut pyclass_name = None;
let mut module = None;
let mut is_get_all = false;
let mut bases = Vec::new();
for attr in parse_pyo3_attrs(&attrs)? {
match attr {
Attr::Name(name) => pyclass_name = Some(name),
Attr::Module(name) => {
module = Some(name);
}
Attr::GetAll => is_get_all = true,
// TODO: allow other modules
Attr::Extends(name) => bases.push((module.clone(), name)),
_ => {}
}
}
Expand All @@ -65,6 +69,7 @@ impl TryFrom<ItemStruct> for PyClassInfo {
members,
module,
doc,
bases,
})
}
}
Expand All @@ -77,15 +82,21 @@ impl ToTokens for PyClassInfo {
members,
doc,
module,
bases,
} = self;
let module = quote_option(module);
let bases: Vec<_> = bases.into_iter().map(|(mod_, name)| {
let mod_ = quote_option(mod_);
quote! { (#mod_, #name) }
}).collect();
tokens.append_all(quote! {
::pyo3_stub_gen::type_info::PyClassInfo {
pyclass_name: #pyclass_name,
struct_id: std::any::TypeId::of::<#struct_type>,
members: &[ #( #members),* ],
module: #module,
doc: #doc,
bases: &[ #(#bases),* ],
}
})
}
Expand Down
10 changes: 9 additions & 1 deletion pyo3-stub-gen/src/generate/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct ClassDef {
pub new: Option<NewDef>,
pub members: Vec<MemberDef>,
pub methods: Vec<MethodDef>,
pub bases: &'static [(Option<&'static str>, &'static str)],
}

impl Import for ClassDef {
Expand Down Expand Up @@ -37,13 +38,20 @@ impl From<&PyClassInfo> for ClassDef {
doc: info.doc,
members: info.members.iter().map(MemberDef::from).collect(),
methods: Vec::new(),
bases: info.bases,
}
}
}

impl fmt::Display for ClassDef {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "class {}:", self.name)?;
let bases = self.bases
.into_iter()
.map(|(m, n)| m.map(|m| format!("{m}.{n}")).unwrap_or_else(|| n.to_string()))
.reduce(|acc, path| format!("{acc}, {path}"))
.map(|bases| format!("({bases})"))
.unwrap_or_default();
writeln!(f, "class {}{}:", self.name, bases)?;
let indent = indent();
let doc = self.doc.trim();
if !doc.is_empty() {
Expand Down
2 changes: 2 additions & 0 deletions pyo3-stub-gen/src/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub struct PyClassInfo {
pub doc: &'static str,
/// static members by `#[pyo3(get, set)]`
pub members: &'static [MemberInfo],
/// Base classes specified by `#[pyclass(extends = Type)]`
pub bases: &'static [(Option<&'static str>, &'static str)],
}

inventory::collect!(PyClassInfo);
Expand Down

0 comments on commit 120963b

Please sign in to comment.