[Do not merge] Implement non-jax versions of IQP models #6

Open · wants to merge 100 commits into main

Conversation

@mariaschuld (Collaborator) commented on Feb 28, 2024

This is a quick-and-dirty implementation that adds the hyperparameters `use_jax` and `vmap` to `IQPVariationalClassifier` and `IQPKernelClassifier`. Together with the existing `jit` hyperparameter, we can toggle between these options:

- `use_jax` (bool): Whether to use jax. If `False`, no jitting or vmapping is performed either.
- `jit` (bool): Whether to use just-in-time compilation.
- `vmap` (bool): Whether to use `jax.vmap`.

This is useful for testing PennyLane's lightning backends, as well as Catalyst, with two examples from the QML benchmarking code base.

In `IQPKernelClassifier`, jax, jitting and vmapping are only meaningfully used in the computation of the entries of the kernel matrix, not in optimisation (which is done by scikit-learn and is very quick). In `IQPVariationalClassifier`, jax, jitting and vmapping define the optimisation procedure, so an option for optimisation based on PennyLane's autograd interface had to be added.
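To make the interplay of the three flags concrete, here is a minimal sketch of how a per-sample function could be dispatched according to them. This is not the actual implementation; `evaluate` and `kernel_fn` are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np

def evaluate(kernel_fn, Z, use_jax=True, jit=True, vmap=True):
    """Apply kernel_fn to every row of Z, honouring the three flags.

    Hypothetical sketch: kernel_fn is assumed to be a jax-compatible
    function of a single input row when use_jax=True.
    """
    if not use_jax:
        # Plain Python loop: jit and vmap are ignored, matching the
        # flag description above.
        return np.array([kernel_fn(z) for z in Z])
    # Batch over rows either with jax.vmap or with an explicit loop.
    fn = jax.vmap(kernel_fn) if vmap else (lambda zs: jnp.stack([kernel_fn(z) for z in zs]))
    if jit:
        fn = jax.jit(fn)
    return fn(Z)
```

The key point is the first branch: with `use_jax=False` the function never touches jax, so `jit` and `vmap` are silently ignored.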

I tested the new logic in the following two non-exhaustive ways:

  1. Make sure that functions imported from `jax` or `jnp` are only ever used within `if self.use_jax` blocks.
  2. Check that a simple example works and that runtimes make sense using the different options.

This was the example:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from qml_benchmarks.models import IQPKernelClassifier, IQPVariationalClassifier
from itertools import product
from time import time

# load data and use labels -1, 1
X, y = make_classification(n_samples=100, n_features=2,
                           n_informative=2, n_redundant=0, random_state=42)
y = np.array([-1 if y_ == 0 else 1 for y_ in y])

# split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

for cls in [IQPKernelClassifier, IQPVariationalClassifier]:
    print(f"\n timing {cls.__name__}")

    # try every combination of the three flags
    for use_jax, jit, vmap in product([True, False], repeat=3):
        print(f"use_jax {use_jax}, jit {jit}, vmap {vmap}")
        model = cls(use_jax=use_jax, jit=jit, vmap=vmap, random_state=42)
        start = time()
        model.fit(X_train, y_train)
        print(f"Score: {model.score(X_test, y_test)}")
        print(f"Time: {time()-start}\n")
```

Results:

`IQPKernelClassifier`:

| use_jax | jit | vmap | Score | Time (s) |
|---------|-----|------|-------|----------|
| True | True | True | 0.95 | 0.84 |
| True | True | False | 0.95 | 0.58 |
| True | False | True | 0.95 | 3.02 |
| True | False | False | 0.95 | 12.55 |
| False | True | True | 0.95 | 12.06 |
| False | True | False | 0.95 | 11.99 |
| False | False | True | 0.95 | 11.92 |
| False | False | False | 0.95 | 12.15 |

`IQPVariationalClassifier`:

| use_jax | jit | vmap | Score | Time (s) |
|---------|-----|------|-------|----------|
| True | True | True | 0.95 | 3.52 |
| True | True | False | 0.95 | 56.76 |
| True | False | True | 0.95 | 221.65 |
| True | False | False | 0.95 | 6350.07 |
| False | True | True | 0.95 | 923.23 |
| False | True | False | 0.95 | 702.68 |
| False | False | True | 0.95 | 673.92 |
| False | False | False | 0.95 | 923.77 |

Note that when `use_jax=False` the `jit` and `vmap` flags have no effect (see the flag description above), so the differences among those rows reflect run-to-run variation.

```diff
@@ -132,15 +139,19 @@ def precompute_kernel(self, X1, X2):
     dim2 = len(X2)

     # concatenate all pairs of vectors
-    Z = jnp.array(
+    Z = np.array(
```
@mariaschuld (Collaborator, Author) commented:
Using pure numpy because we're not differentiating through the construction of the kernel matrix...
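As a hypothetical illustration of the pattern the comment describes (the function name and signature below are illustrative, not taken from the repository):

```python
import numpy as np

def concatenate_pairs(X1, X2):
    """Stack every (x1, x2) pair into rows of shape (len(X1) * len(X2), 2 * n_features).

    Plain NumPy is sufficient here: no gradients ever flow through the
    construction of the kernel matrix inputs, only through the kernel
    circuit that later consumes them.
    """
    return np.array([np.concatenate([x1, x2]) for x1 in X1 for x2 in X2])
```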

@mariaschuld changed the title from [WIP] Implement non-jax versions of IQP models to [Do not merge] Implement non-jax versions of IQP models on Feb 29, 2024