Skip to content

Commit

Permalink
Plxpr can capture operations (#5511)
Browse files Browse the repository at this point in the history
**HIGHLY EXPERIMENTAL FEATURE**

**Context:**

To improve integration with catalyst, we want a way to be able to
natively capture the creation of quantum operations into an intermediate
representation. This is one of the early PR's for this experimental
push.

**Description of the Change:**

This PR adds a `PLXPR` metaclass that `Operator` uses. This allows for
the capture of all `Operator` classes into jaxpr.


```python
qml.capture.enable_plxpr()

def qfunc(a):
    qml.X(0)
    qml.IsingXX(a, wires=range(2))
    qml.ctrl(qml.adjoint(qml.X(0)), 1)

    0.5 * qml.X(0) @ qml.Y(1) + qml.Z(2)
    
jaxpr = jax.make_jaxpr(qfunc)(0)
jaxpr
```

```
{ lambda ; a:i32[]. let
    _:AbstractOperator() = PauliX[n_wires=1] 0
    _:AbstractOperator() = IsingXX[n_wires=2] a 0 1
    b:AbstractOperator() = PauliX[n_wires=1] 0
    c:AbstractOperator() = Adjoint b
    _:AbstractOperator() = Controlled[control_values=None work_wires=None] c 1
    d:AbstractOperator() = PauliX[n_wires=1] 0
    e:AbstractOperator() = SProd[id=None] 0.5 d
    f:AbstractOperator() = PauliY[n_wires=1] 1
    g:AbstractOperator() = Prod[id=None] e f
    h:AbstractOperator() = PauliZ[n_wires=1] 2
    _:AbstractOperator() = Sum[grouping_type=None id=None method=rlf] g h
  in () }
```

We can also return the jaxpr to normal pennylane qfunc behaviour via
`jax.core.eval_jaxpr`:

```python
with qml.queuing.AnnotatedQueue() as q:
    jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.1)
q.queue
```
```
[X(0),
 IsingXX(0.1, wires=[0, 1]),
 Controlled(Adjoint(X(0)), control_wires=[1]),
 (0.5 * X(0)) @ Y(1) + Z(2)]
```

**Benefits:**

**Possible Drawbacks:**

* Metaprogramming in python is an edge skill, and often not the best way
to solve a problem. Messing around with things like this can often have
unintended consequences down the line.

* With PLXPR, wires will be restricted to be jax-tracable friendly
labels.

**Related GitHub Issues:**

[sc-61199]

---------

Co-authored-by: dwierichs <david.wierichs@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
  • Loading branch information
4 people authored May 10, 2024
1 parent 762b337 commit 64de143
Show file tree
Hide file tree
Showing 14 changed files with 1,061 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
* Sets up the framework for the development of an `assert_equal` function for testing operator comparison.
[(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)
Expand Down
73 changes: 73 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,78 @@
>>> qml.capture.enabled()
False
**Custom Operator Behaviour**
Any operator that inherits from :class:`~.Operator` gains a default ability to be captured
in a Jaxpr. Any positional argument is bound as a tracer, wires are processed out into individual tracers,
and any keyword arguments are passed as keyword metadata.
.. code-block:: python
class MyOp1(qml.operation.Operator):
def __init__(self, arg1, wires, key=None):
super().__init__(arg1, wires=wires)
def qfunc(a):
MyOp1(a, wires=(0,1), key="a")
qml.capture.enable()
print(jax.make_jaxpr(qfunc)(0.1))
.. code-block::
{ lambda ; a:f32[]. let
_:AbstractOperator() = MyOp1[key=a n_wires=2] a 0 1
in () }
But an operator developer may need to override custom behavior for calling ``cls._primitive.bind``
(where ``cls`` indicates the class) if:
* The operator does not accept wires, like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
* The operator needs to enforce a data/ metadata distinction, like :class:`~.PauliRot`.
In such cases, the operator developer can override ``cls._primitive_bind_call``, which
will be called when constructing a new class instance instead of ``type.__call__``. For example,
.. code-block:: python
class JustMetadataOp(qml.operation.Operator):
def __init__(self, metadata):
super().__init__(wires=[])
self._metadata = metadata
@classmethod
def _primitive_bind_call(cls, metadata):
return cls._primitive.bind(metadata=metadata)
def qfunc():
JustMetadataOp("Y")
qml.capture.enable()
print(jax.make_jaxpr(qfunc)())
.. code-block::
{ lambda ; . let _:AbstractOperator() = JustMetadataOp[metadata=Y] in () }
As you can see, the input ``"Y"``, while being passed as a positional argument, is converted to
metadata within the custom ``_primitive_bind_call`` method.
If needed, developers can also override the implementation method of the primitive like was done with ``Controlled``.
``Controlled`` needs to do so to handle packing and unpacking the control wires.
.. code-block:: python
class MyCustomOp(qml.operation.Operator):
pass
@MyCustomOp._primitive.def_impl
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta
from .primitives import create_operator_primitive
46 changes: 46 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines a metaclass for automatic integration of any ``Operator`` with plxpr program capture.
See ``explanations.md`` for technical explanations of how this works.
"""

from .switches import enabled


# pylint: disable=no-self-argument, too-few-public-methods
class CaptureMeta(type):
"""A metatype that dispatches class creation to ``cls._primitve_bind_call`` instead
of normal class creation.
See ``pennylane/capture/explanations.md`` for more detailed information on how this technically
works.
"""

def _primitive_bind_call(cls, *args, **kwargs):
raise NotImplementedError(
"Types using CaptureMeta must implement cls._primitive_bind_call to"
" gain integration with plxpr program capture."
)

def __call__(cls, *args, **kwargs):
# this method is called everytime we want to create an instance of the class.
# default behavior uses __new__ then __init__

if enabled():
# when tracing is enabled, we want to
# use bind to construct the class if we want class construction to add it to the jaxpr
return cls._primitive_bind_call(*args, **kwargs)
return type.__call__(cls, *args, **kwargs)
237 changes: 237 additions & 0 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
This documentation explains the principles behind `qml.capture.CaptureMeta`.


```python
import jax
```

# Primitive basics


```python
my_func_prim = jax.core.Primitive("my_func")

@my_func_prim.def_impl
def _(x):
return x**2

@my_func_prim.def_abstract_eval
def _(x):
return jax.core.ShapedArray((1,), x.dtype)

def my_func(x):
return my_func_prim.bind(x)
```


```python
>>> jaxpr = jax.make_jaxpr(my_func)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:f32[1] = my_func a in (b,) }
>>> jaxpr.jaxpr.eqns
[a:f32[1] = my_func b]
```

## Metaprogramming


```python
class MyMetaClass(type):

def __init__(cls, *args, **kwargs):
print(f"Creating a new type {cls} with {args}, {kwargs}. ")

# giving every class a property
cls.a = "a"

def __call__(cls, *args, **kwargs):
print(f"creating an instance of type {cls} with {args}, {kwargs}. ")
inst = cls.__new__(cls, *args, **kwargs)
inst.__init__(*args, **kwargs)
return inst
```

Now let's define a class with this meta class.

You can see that when we *define* the class, we have called `MyMetaClass.__init__` to create the new type


```python
class MyClass(metaclass=MyMetaClass):

def __init__(self, *args, **kwargs):
print("now creating an instance in __init__")
self.args = args
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`


```python
>>> MyClass.a
'a'
```

But can we actually create instances of these classes?


```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```


So far, we've just added print statements around default behavior. Let's try something more radical


```python
class MetaClass2(type):

def __call__(cls, *args, **kwargs):
return 2.0

class MyClass2(metaclass=MetaClass2):

def __init__(self, *args, **kwargs):
print("Am I here?")
self.args = args
```

You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.


```python
>>> out = MyClass2(1.0)
>>> out, out == 2.0
(2.0, True)
```

## Putting Primitives and Metaprogramming together

We have two goals that we need to accomplish with our meta class.

1. Create an associated primitive every time we define a new class type
2. Hijack creating a new instance to use `primitive.bind` instead


```python
class PrimitiveMeta(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return jax.core.ShapedArray((1,), jax.numpy.float32)

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
```


```python
class PrimitiveClass(metaclass=PrimitiveMeta):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

What happens if we just create a class normally as is?


```python
>>> PrimitiveClass(1.0)
PrimitiveClass(1.0)
```

But now it can also be used in tracing as well


```python
>>> jax.make_jaxpr(PrimitiveClass)(1.0)
{ lambda ; a:f32[]. let b:f32[1] = PrimitiveClass a in (b,) }
```

Great!👍

Now you can see that the problem is that we lied in our definition of abstract evaluation. Jax thinks that `PrimitiveClass` returns something of shape `(1,)` and type `float32`.

But jax doesn't have an abstract type that really describes "PrimitiveClass". So we need to define an register our own.


```python
class AbstractPrimitiveClass(jax.core.AbstractValue):

def __eq__(self, other):
return isinstance(other, AbstractPrimitiveClass)

def __hash__(self):
return hash("AbstractPrimitiveClass")

jax.core.raise_to_shaped_mappings[AbstractPrimitiveClass] = lambda aval, _: aval
```

Now we can redefine our class to use this abstract class


```python
class PrimitiveMeta2(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return AbstractPrimitiveClass()

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

class PrimitiveClass2(metaclass=PrimitiveMeta2):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `AbstractPrimitiveClass`.


```python
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```
Loading

0 comments on commit 64de143

Please sign in to comment.