Skip to content

Commit

Permalink
Fix signature of CaptureMeta objects (#5727)
Browse files Browse the repository at this point in the history
**Context:**

In PR #5511 , we introduced the meta type `qml.capture.MetaOperator`
that hijacked the class creation processes. The `__call__` signature for
these classes got altered to the completely general `(cls, *args,
**kwargs)` instead of the more specific version used by `__init__`.

**Description of the Change:**

Adds a `__signature__` property to `MetaOperator` that ensures that the
signature will always be the `__init__` signature.

**Benefits:**

Signatures continue to match the `__init__` definition.

**Possible Drawbacks:**

The metaprogramming black magic may still have other consequences we
still aren't aware of. This fixes the particular signature issue, but it
still shows the consequences of metaprogramming.

**Related GitHub Issues:**

Fixes #5724 [sc-63734]

<img width="487" alt="Screenshot 2024-05-22 at 9 42 10 AM"
src="https://github.com/PennyLaneAI/pennylane/assets/6364575/1abf41ce-b72a-4734-a069-3dd37e7f8c28">
  • Loading branch information
albi3ro authored May 22, 2024
1 parent 48f21d8 commit b7b4b75
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@
returned samples of observables containing the `qml.Identity` operator.
[(#5607)](https://github.com/PennyLaneAI/pennylane/pull/5607)

* The signature of `CaptureMeta` objects (like `Operator`) now match the signature of the `__init__` call.
[(#5727)](https://github.com/PennyLaneAI/pennylane/pull/5727)

* Use vanilla NumPy arrays in `test_projector_expectation` to avoid differentiating `qml.Projector` with respect to the state attribute.
[(#5683)](https://github.com/PennyLaneAI/pennylane/pull/5683)

Expand Down
7 changes: 7 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
See ``explanations.md`` for technical explanations of how this works.
"""
from inspect import Signature, signature

from .switches import enabled

Expand All @@ -29,6 +30,12 @@ class CaptureMeta(type):
works.
"""

@property
def __signature__(cls):
sig = signature(cls.__init__)
without_self = tuple(sig.parameters.values())[1:]
return Signature(without_self)

def _primitive_bind_call(cls, *args, **kwargs):
raise NotImplementedError(
"Types using CaptureMeta must implement cls._primitive_bind_call to"
Expand Down
12 changes: 6 additions & 6 deletions pennylane/ops/qubit/non_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ class PauliX(Observable, Operation):

_queue_category = "_ops"

def __init__(self, *params, wires=None, id=None):
super().__init__(*params, wires=wires, id=id)
def __init__(self, wires=None, id=None):
super().__init__(wires=wires, id=id)
self._pauli_rep = qml.pauli.PauliSentence({qml.pauli.PauliWord({self.wires[0]: "X"}): 1.0})

def label(self, decimals=None, base_label=None, cache=None):
Expand Down Expand Up @@ -391,8 +391,8 @@ class PauliY(Observable, Operation):

_queue_category = "_ops"

def __init__(self, *params, wires=None, id=None):
super().__init__(*params, wires=wires, id=id)
def __init__(self, wires=None, id=None):
super().__init__(wires=wires, id=id)
self._pauli_rep = qml.pauli.PauliSentence({qml.pauli.PauliWord({self.wires[0]: "Y"}): 1.0})

def __repr__(self):
Expand Down Expand Up @@ -575,8 +575,8 @@ class PauliZ(Observable, Operation):

_queue_category = "_ops"

def __init__(self, *params, wires=None, id=None):
super().__init__(*params, wires=wires, id=id)
def __init__(self, wires=None, id=None):
super().__init__(wires=wires, id=id)
self._pauli_rep = qml.pauli.PauliSentence({qml.pauli.PauliWord({self.wires[0]: "Z"}): 1.0})

def __repr__(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/capture/test_meta_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""
Unit tests for the CaptureMeta metaclass.
"""
from inspect import signature

# pylint: disable=protected-access, undefined-variable
import pytest

Expand Down Expand Up @@ -51,11 +53,21 @@ class MyObj(metaclass=CaptureMeta):
def _primitive_bind_call(cls, *args, **kwargs):
return p.bind(*args, **kwargs)

def __init__(self, a: int, b: bool):
self.a = a
self.b = b

def f(a: int, b: bool):
# similar signature to MyObj but without init
return a + b

jaxpr = jax.make_jaxpr(MyObj)(0.5)

assert len(jaxpr.eqns) == 1
assert jaxpr.eqns[0].primitive == p

assert signature(MyObj) == signature(f)


def test_custom_capture_meta_no_bind_primitive_call():
"""Test that an NotImplementedError is raised if the type does not define _primitive_bind_call."""
Expand Down

0 comments on commit b7b4b75

Please sign in to comment.