Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add qml.capture.to_catalsyt conversion function #5771

Closed
wants to merge 120 commits into from
Closed

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented May 30, 2024

Context:

Previous PR's have integrated operators, measurements, and qnodes with jax primitives, allowing a full workflow to be captured with the jax.make_jaxpr function.

To start to integrate

Description of the Change:

qml.capture.enable()

@qml.qnode(qml.device('lightning.qubit', wires=2))
def circuit(x):
    qml.RX(x,0)
    return qml.probs(wires=(0,1))

def f(x):
    return circuit(2* x) ** 2

jaxpr = jax.make_jaxpr(circuit)(0.5)
jaxpr
{ lambda ; a:f32[]. let
    b:f32[4] = qnode[
      device=<lightning.qubit device (wires=2) at 0x3008fd490>
      qfunc_jaxpr={ lambda ; c:f32[]. let
          _:AbstractOperator() = RX[n_wires=1] c 0
          d:AbstractMeasurement(n_wires=2) = probs_wires 0 1
        in (d,) }
      qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'max_expansion': 10, 'device_vjp': False, 'postselect_mode': None, 'mcm_method': None}
      shots=Shots(total=None)
    ] a
  in (b,) }
qml.capture.to_catalyst(jaxpr)
{ lambda ; a:f32[]. let
    b:f32[4] = func[
      call_jaxpr={ lambda ; c:f32[]. let
          qdevice[
            rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
            rtd_lib=/Users/christina/Prog/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.dylib
            rtd_name=LightningSimulator
          ] 
          d:AbstractQreg() = qalloc 2
          e:AbstractQbit() = qextract d 0
          f:AbstractQbit() = qinst[ctrl_len=0 op=RX params_len=1 qubits_len=1] e
            c
          g:AbstractQbit() = qextract d 1
          h:AbstractObs(num_qubits=None,primitive=None) = compbasis f g
          i:f64[4] = probs[shots=0] h
          j:AbstractQreg() = qinsert d 0 f
          qdealloc j
        in (i,) }
      fn=TODO
    ] a
  in (b,) }

Benefits:

Improved integration of pennylane and catalyst.

Possible Drawbacks:

This route is the naive route of first converting to plxpr and then converting catalxpr. Another route would involve natively capturing the catalxpr itself from the original user code. But that would involve deeper coupling with jax internals and extending it with custom CatalystTracer and CatalystTrace objects.

This current version also only supports a limited gateset that is already natively supported by catalyst, and will not be able to handle generic templates/ more complicated operations.

We are also introducing a complicated dependency tree in this PR. So far we have the core pl objects (operations, measurements, qnode) depending on the capture module. Then catalyst depends on our pennylane infrastructure.

With this new PR, we are introducing a catalyst dependency to the capture module. This makes imports difficult if we want to avoid circular dependencies. I would like to propose either:

  1. Moving to_catalyst to the catalyst package
  2. Moving to_catalyst to the qml.compiler module

Related GitHub Issues:

[sc-61537]

albi3ro and others added 30 commits April 12, 2024 14:33
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

Copy link

codecov bot commented May 30, 2024

Codecov Report

Attention: Patch coverage is 0.67114% with 148 lines in your changes missing coverage. Please review.

Project coverage is 99.29%. Comparing base (3f0dec4) to head (6f51ed6).
Report is 1 commits behind head on master.

Files Patch % Lines
pennylane/capture/to_catalyst.py 0.00% 146 Missing ⚠️
pennylane/capture/__init__.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5771      +/-   ##
==========================================
- Coverage   99.67%   99.29%   -0.39%     
==========================================
  Files         414      415       +1     
  Lines       39463    39316     -147     
==========================================
- Hits        39336    39039     -297     
- Misses        127      277     +150     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Base automatically changed from capture-qnode to master June 3, 2024 14:02
@dwierichs dwierichs self-requested a review June 13, 2024 20:01
Copy link
Contributor

@dwierichs dwierichs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @albi3ro ! 🎉
Had some minor initial comments, will go through tests in a second round.

pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved


def to_catalyst(jaxpr: jax.core.Jaxpr) -> jax.core.Jaxpr:
"""Convert pennylane variant jaxpr to catalyst variant jaxpr.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're using "plxpr" in so many places, we could go for that name and change it once a public name has been decided.

pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved
call_jaxpr={ lambda ; c:f64[]. let
qdevice[
rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
rtd_lib=/Users/christina/Prog/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.dylib
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we care to redact this kind of information?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh... good point.

pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved
pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved
pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved
)
n_wires = eqn.params["n_wires"]

orig_wires = eqn.invars[-n_wires:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we support zero-wires operators already? 🤔

pennylane/capture/to_catalyst.py Outdated Show resolved Hide resolved
Should be the first method called when populating the catalyst xpr.
"""
self._num_device_wires = len(device.wires)
self._shots = device.shots
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this capture the scenario where a QNode received the shots kwarg? I suppose so, because the device temporarily has changed shots?

@albi3ro
Copy link
Contributor Author

albi3ro commented Jul 10, 2024

Closed in favor of PennyLaneAI/catalyst#837

@albi3ro albi3ro closed this Jul 10, 2024
@albi3ro albi3ro deleted the to-catalyst branch July 10, 2024 20:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants