diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index ff6a024a745..b4e808eb0cb 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -129,6 +129,9 @@
Bug fixes 🐛
+* Fixed a bug in `qml.SPSAOptimizer` that ignored keyword arguments in the objective function.
+ [(#6027)](https://github.com/PennyLaneAI/pennylane/pull/6027)
+
* `dynamic_one_shot` was broken for old-API devices since `override_shots` was deprecated.
[(#6024)](https://github.com/PennyLaneAI/pennylane/pull/6024)
diff --git a/pennylane/optimize/spsa.py b/pennylane/optimize/spsa.py
index e44784f7245..535f6f64675 100644
--- a/pennylane/optimize/spsa.py
+++ b/pennylane/optimize/spsa.py
@@ -194,6 +194,7 @@ def step_and_cost(self, objective_fn, *args, **kwargs):
objective function output prior to the step.
"""
g = self.compute_grad(objective_fn, args, kwargs)
+
new_args = self.apply_grad(g, args)
self.k += 1
@@ -270,7 +271,8 @@ def compute_grad(self, objective_fn, args, kwargs):
shots = Shots(objective_fn.device._raw_shot_sequence) # pragma: no cover
else:
shots = Shots(None)
- if np.prod(objective_fn.func(*args).shape(objective_fn.device, shots)) > 1:
+
+ if np.prod(objective_fn.func(*args, **kwargs).shape(objective_fn.device, shots)) > 1:
raise ValueError(
"The objective function must be a scalar function for the gradient "
"to be computed."
diff --git a/tests/optimize/test_spsa.py b/tests/optimize/test_spsa.py
index 01726f843f1..f0422602ffd 100644
--- a/tests/optimize/test_spsa.py
+++ b/tests/optimize/test_spsa.py
@@ -443,7 +443,7 @@ def cost(params):
@pytest.mark.usefixtures("use_legacy_opmath")
@pytest.mark.slow
- def test_lighting_device_legacy_opmath(self):
+ def test_lightning_device_legacy_opmath(self):
"""Test SPSAOptimizer implementation with lightning.qubit device."""
coeffs = [0.2, -0.543, 0.4514]
obs = [
@@ -479,7 +479,7 @@ def cost_fun(params, num_qubits=1):
assert energy < init_energy
@pytest.mark.slow
- def test_lighting_device(self):
+ def test_lightning_device(self):
"""Test SPSAOptimizer implementation with lightning.qubit device."""
coeffs = [0.2, -0.543, 0.4514]
obs = [
@@ -494,6 +494,9 @@ def test_lighting_device(self):
@qml.qnode(dev)
def cost_fun(params, num_qubits=1):
qml.BasisState([1, 1, 0, 0], wires=range(num_qubits))
+
+ assert num_qubits == 4
+
for i in range(num_qubits):
qml.Rot(*params[i], wires=0)
qml.CNOT(wires=[2, 3])