Skip to content

Commit

Permalink
Merge branch 'Deprecate_hamiltonian_in_PauliSentence_and_PauliWord' of
Browse files Browse the repository at this point in the history
…https://github.com/PennyLaneAI/pennylane into Deprecate_hamiltonian_in_PauliSentence_and_PauliWord
  • Loading branch information
PietropaoloFrisoni committed May 14, 2024
2 parents a5082fc + eea3f30 commit 73d0b18
Show file tree
Hide file tree
Showing 24 changed files with 1,230 additions and 55 deletions.
15 changes: 13 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
* Sets up the framework for the development of an `assert_equal` function for testing operator comparison.
[(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)
Expand All @@ -95,12 +99,19 @@

<h3>Bug fixes 🐛</h3>

* Use vanilla NumPy arrays in `test_projector_expectation` to avoid differentiating `qml.Projector` with respect to the state attribute.
[(#5683)](https://github.com/PennyLaneAI/pennylane/pull/5683)

* `qml.Projector` is now compatible with jax-jit.
[(#5595)](https://github.com/PennyLaneAI/pennylane/pull/5595)

* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
[(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)

* `param_shift`, `finite_diff`, `compile`, `merge_rotations`, and `transpile` now all work
with circuits with non-commuting measurements.
* `param_shift`, `finite_diff`, `compile`, `insert`, `merge_rotations`, and `transpile` now
all work with circuits with non-commuting measurements.
[(#5424)](https://github.com/PennyLaneAI/pennylane/pull/5424)
[(#5681)](https://github.com/PennyLaneAI/pennylane/pull/5681)

* A correction is added to `bravyi_kitaev` to call the correct function for a FermiSentence input.
[(#5671)](https://github.com/PennyLaneAI/pennylane/pull/5671)
Expand Down
73 changes: 73 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,78 @@
>>> qml.capture.enabled()
False
**Custom Operator Behaviour**
Any operator that inherits from :class:`~.Operator` gains a default ability to be captured
in a Jaxpr. Any positional argument is bound as a tracer, wires are processed out into individual tracers,
and any keyword arguments are passed as keyword metadata.
.. code-block:: python
class MyOp1(qml.operation.Operator):
def __init__(self, arg1, wires, key=None):
super().__init__(arg1, wires=wires)
def qfunc(a):
MyOp1(a, wires=(0,1), key="a")
qml.capture.enable()
print(jax.make_jaxpr(qfunc)(0.1))
.. code-block::
{ lambda ; a:f32[]. let
_:AbstractOperator() = MyOp1[key=a n_wires=2] a 0 1
in () }
But an operator developer may need to override custom behavior for calling ``cls._primitive.bind``
(where ``cls`` indicates the class) if:
* The operator does not accept wires, like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
* The operator needs to enforce a data/ metadata distinction, like :class:`~.PauliRot`.
In such cases, the operator developer can override ``cls._primitive_bind_call``, which
will be called when constructing a new class instance instead of ``type.__call__``. For example,
.. code-block:: python
class JustMetadataOp(qml.operation.Operator):
def __init__(self, metadata):
super().__init__(wires=[])
self._metadata = metadata
@classmethod
def _primitive_bind_call(cls, metadata):
return cls._primitive.bind(metadata=metadata)
def qfunc():
JustMetadataOp("Y")
qml.capture.enable()
print(jax.make_jaxpr(qfunc)())
.. code-block::
{ lambda ; . let _:AbstractOperator() = JustMetadataOp[metadata=Y] in () }
As you can see, the input ``"Y"``, while being passed as a positional argument, is converted to
metadata within the custom ``_primitive_bind_call`` method.
If needed, developers can also override the implementation method of the primitive like was done with ``Controlled``.
``Controlled`` needs to do so to handle packing and unpacking the control wires.
.. code-block:: python
class MyCustomOp(qml.operation.Operator):
pass
@MyCustomOp._primitive.def_impl
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta
from .primitives import create_operator_primitive
46 changes: 46 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines a metaclass for automatic integration of any ``Operator`` with plxpr program capture.
See ``explanations.md`` for technical explanations of how this works.
"""

from .switches import enabled


# pylint: disable=no-self-argument, too-few-public-methods
class CaptureMeta(type):
"""A metatype that dispatches class creation to ``cls._primitve_bind_call`` instead
of normal class creation.
See ``pennylane/capture/explanations.md`` for more detailed information on how this technically
works.
"""

def _primitive_bind_call(cls, *args, **kwargs):
raise NotImplementedError(
"Types using CaptureMeta must implement cls._primitive_bind_call to"
" gain integration with plxpr program capture."
)

def __call__(cls, *args, **kwargs):
# this method is called everytime we want to create an instance of the class.
# default behavior uses __new__ then __init__

if enabled():
# when tracing is enabled, we want to
# use bind to construct the class if we want class construction to add it to the jaxpr
return cls._primitive_bind_call(*args, **kwargs)
return type.__call__(cls, *args, **kwargs)
237 changes: 237 additions & 0 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
This documentation explains the principles behind `qml.capture.CaptureMeta`.


```python
import jax
```

# Primitive basics


```python
my_func_prim = jax.core.Primitive("my_func")

@my_func_prim.def_impl
def _(x):
return x**2

@my_func_prim.def_abstract_eval
def _(x):
return jax.core.ShapedArray((1,), x.dtype)

def my_func(x):
return my_func_prim.bind(x)
```


```python
>>> jaxpr = jax.make_jaxpr(my_func)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:f32[1] = my_func a in (b,) }
>>> jaxpr.jaxpr.eqns
[a:f32[1] = my_func b]
```

## Metaprogramming


```python
class MyMetaClass(type):

def __init__(cls, *args, **kwargs):
print(f"Creating a new type {cls} with {args}, {kwargs}. ")

# giving every class a property
cls.a = "a"

def __call__(cls, *args, **kwargs):
print(f"creating an instance of type {cls} with {args}, {kwargs}. ")
inst = cls.__new__(cls, *args, **kwargs)
inst.__init__(*args, **kwargs)
return inst
```

Now let's define a class with this meta class.

You can see that when we *define* the class, we have called `MyMetaClass.__init__` to create the new type


```python
class MyClass(metaclass=MyMetaClass):

def __init__(self, *args, **kwargs):
print("now creating an instance in __init__")
self.args = args
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`


```python
>>> MyClass.a
'a'
```

But can we actually create instances of these classes?


```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```


So far, we've just added print statements around default behavior. Let's try something more radical


```python
class MetaClass2(type):

def __call__(cls, *args, **kwargs):
return 2.0

class MyClass2(metaclass=MetaClass2):

def __init__(self, *args, **kwargs):
print("Am I here?")
self.args = args
```

You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.


```python
>>> out = MyClass2(1.0)
>>> out, out == 2.0
(2.0, True)
```

## Putting Primitives and Metaprogramming together

We have two goals that we need to accomplish with our meta class.

1. Create an associated primitive every time we define a new class type
2. Hijack creating a new instance to use `primitive.bind` instead


```python
class PrimitiveMeta(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return jax.core.ShapedArray((1,), jax.numpy.float32)

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
```


```python
class PrimitiveClass(metaclass=PrimitiveMeta):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

What happens if we just create a class normally as is?


```python
>>> PrimitiveClass(1.0)
PrimitiveClass(1.0)
```

But now it can also be used in tracing as well


```python
>>> jax.make_jaxpr(PrimitiveClass)(1.0)
{ lambda ; a:f32[]. let b:f32[1] = PrimitiveClass a in (b,) }
```

Great!👍

Now you can see that the problem is that we lied in our definition of abstract evaluation. Jax thinks that `PrimitiveClass` returns something of shape `(1,)` and type `float32`.

But jax doesn't have an abstract type that really describes "PrimitiveClass". So we need to define an register our own.


```python
class AbstractPrimitiveClass(jax.core.AbstractValue):

def __eq__(self, other):
return isinstance(other, AbstractPrimitiveClass)

def __hash__(self):
return hash("AbstractPrimitiveClass")

jax.core.raise_to_shaped_mappings[AbstractPrimitiveClass] = lambda aval, _: aval
```

Now we can redefine our class to use this abstract class


```python
class PrimitiveMeta2(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return AbstractPrimitiveClass()

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

class PrimitiveClass2(metaclass=PrimitiveMeta2):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `AbstractPrimitiveClass`.


```python
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```
Loading

0 comments on commit 73d0b18

Please sign in to comment.