From dfcceabf33b94e283189987d65c4965f152c0857 Mon Sep 17 00:00:00 2001 From: Anthony Hayes Date: Thu, 19 Sep 2024 14:02:23 -0400 Subject: [PATCH] fix queries --- pennylane/data/__init__.py | 4 +- pennylane/data/data_manager/__init__.py | 69 ++------- tests/data/data_manager/support.py | 106 +++++++++++--- .../data/data_manager/test_dataset_access.py | 131 +++++------------- 4 files changed, 129 insertions(+), 181 deletions(-) diff --git a/pennylane/data/__init__.py b/pennylane/data/__init__.py index 099f6cbfd16..b2e7b215475 100644 --- a/pennylane/data/__init__.py +++ b/pennylane/data/__init__.py @@ -216,7 +216,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi from .base import DatasetNotWriteableError from .base.attribute import AttributeInfo, DatasetAttribute, attribute from .base.dataset import Dataset, field -from .data_manager import DEFAULT, FULL, list_attributes, list_datasets, load, load_interactive +from .data_manager import DEFAULT, FULL, list_attributes, list_data_names, load, load_interactive __all__ = ( "AttributeInfo", @@ -240,7 +240,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi "load", "load_interactive", "list_attributes", - "list_datasets", + "list_data_names", "DEFAULT", "FULL", ) diff --git a/pennylane/data/data_manager/__init__.py b/pennylane/data/data_manager/__init__.py index e885cc24327..198da52c1fb 100644 --- a/pennylane/data/data_manager/__init__.py +++ b/pennylane/data/data_manager/__init__.py @@ -24,7 +24,7 @@ from time import sleep from typing import Any, Iterable, Mapping, Optional, Union -from requests import get, head +from requests import get, head, post from pennylane.data.base import Dataset from pennylane.data.base.hdf5 import open_hdf5_s3 @@ -122,12 +122,9 @@ def _get_graphql(url: str, query: str, variables: dict[str, Any] = None): if variables: json["variables"] = variables - response = get(url=url, json=json, timeout=10) + response = post(url=url, json=json, timeout=10) response.raise_for_status() - if response.json() is None: - raise GraphQLError("No Response") - if "errors" in response.json(): all_errors = ",".join(error["message"] for error in response.json()["errors"]) raise GraphQLError(f"Errors in request: {all_errors}") @@ -418,55 +415,22 @@ def load( # pylint: disable=too-many-arguments return [Dataset.open(path, "a") for path in download_paths] -def list_datasets() -> dict: - r"""Returns a dictionary of the available datasets. - - Return: - dict: Nested dictionary representing the directory structure of the hosted datasets. - - .. seealso:: :func:`~.load_interactive`, :func:`~.list_attributes`, :func:`~.load`. - - **Example:** - - Note that the results of calling this function may differ from this example as more datasets - are added. For updates on available data see the `datasets website `_. - - >>> available_data = qml.data.list_datasets() - >>> available_data.keys() - dict_keys(["qspin", "qchem"]) - >>> available_data["qchem"].keys() - dict_keys(["H2", "LiH", ...]) - >>> available_data['qchem']['H2'].keys() - dict_keys(["CC-PVDZ", "6-31G", "STO-3G"]) - >>> print(available_data['qchem']['H2']['STO-3G']) - ["0.5", "0.54", "0.62", "0.66", ...] - - Note that this example limits the results of the function calls for - clarity and that as more data becomes available, the results of these - function calls will change. - """ - +def list_data_names() -> list[str]: + """Get list of dataclass IDs.""" response = _get_graphql( GRAPHQL_URL, """ - query ListDatasets($datasetClassId: String!) { + query GetDatasetClasses { datasetClasses { id - datasets { - parameterValues{ - name - value - } - } } } """, ) - - return response["data"]["datasetClasses"] + return [dsc["id"] for dsc in response["data"]["datasetClasses"]] -def list_attributes(data_name): +def list_attributes(data_name) -> list[str]: r"""List the attributes that exist for a specific ``data_name``. Args: @@ -502,7 +466,7 @@ def list_attributes(data_name): {"input": {"datasetClassId": data_name}}, ) - return response["data"]["datasetClass"]["attributes"] + return [attribute["name"] for attribute in response["data"]["datasetClass"]["attributes"]] def _interactive_request_data_name(data_names): @@ -581,21 +545,6 @@ def _get_parameter_tree(class_id) -> tuple[list[str], list[str], dict]: return (parameters, attributes, response["data"]["datasetClass"]["parameterTree"]) -def _get_data_names() -> list[str]: - """Get dataclass IDs.""" - response = _get_graphql( - GRAPHQL_URL, - """ - query GetDatasetClasses { - datasetClasses { - id - } - } - """, - ) - return [dsc["id"] for dsc in response["data"]["datasetClasses"]] - - def load_interactive(): r"""Download a dataset using an interactive load prompt. @@ -635,7 +584,7 @@ def load_interactive(): Would you like to continue? (Default is yes) [Y/n]: """ - data_names = _get_data_names() + data_names = list_data_names() data_name = _interactive_request_data_name(data_names) parameters, attribute_options, parameter_tree = _get_parameter_tree(data_name) diff --git a/tests/data/data_manager/support.py b/tests/data/data_manager/support.py index d6d6453c598..16261858395 100644 --- a/tests/data/data_manager/support.py +++ b/tests/data/data_manager/support.py @@ -1,28 +1,90 @@ """Test support for mocking GraphQL queries""" _list_attrs_resp = { - "data": { - "datasetClass": { - "attributes": ["molecule", "hamiltonian", "sparse_hamiltonian", "hf_state", "full"] - } - } -} - -_list_datasets_resp = { - "data": { - "datasetClasses": { - "id": "qchem", - "datasets": [ - { - "parameterValues": [ - {"name": "molname", "value": "H2"}, - {"name": "bondlength", "value": "1.16"}, - {"name": "basis", "value": "STO-3G"}, - ] - } - ], + "data": { + "datasetClass": { + "attributes": [ + { + "name": "basis_rot_groupings" + }, + { + "name": "basis_rot_samples" + }, + { + "name": "dipole_op" + }, + { + "name": "fci_energy" + }, + { + "name": "fci_spectrum" + }, + { + "name": "hamiltonian" + }, + { + "name": "hf_state" + }, + { + "name": "molecule" + }, + { + "name": "number_op" + }, + { + "name": "optimal_sector" + }, + { + "name": "paulix_ops" + }, + { + "name": "qwc_groupings" + }, + { + "name": "qwc_samples" + }, + { + "name": "sparse_hamiltonian" + }, + { + "name": "spin2_op" + }, + { + "name": "spinz_op" + }, + { + "name": "symmetries" + }, + { + "name": "tapered_dipole_op" + }, + { + "name": "tapered_hamiltonian" + }, + { + "name": "tapered_hf_state" + }, + { + "name": "tapered_num_op" + }, + { + "name": "tapered_spin2_op" + }, + { + "name": "tapered_spinz_op" + }, + { + "name": "vqe_energy" + }, + { + "name": "vqe_gates" + }, + { + "name": "vqe_params" } + ] } + } } _get_urls_resp = { @@ -48,6 +110,8 @@ _dataclass_ids = {"data": {"datasetClasses": [{"id": "other"}, {"id": "qchem"}, {"id": "qspin"}]}} +_error_response = {"data": None, "errors": [{"message": "Mock error message."}]} + _parameter_tree = { "data": { "datasetClass": { @@ -504,5 +568,3 @@ } } } - -_error_response = {"data": None, "errors": [{"message": "Mock error message."}]} diff --git a/tests/data/data_manager/test_dataset_access.py b/tests/data/data_manager/test_dataset_access.py index f1ffa26939f..512e0f479c5 100644 --- a/tests/data/data_manager/test_dataset_access.py +++ b/tests/data/data_manager/test_dataset_access.py @@ -32,7 +32,6 @@ _error_response, _get_urls_resp, _list_attrs_resp, - _list_datasets_resp, _parameter_tree, ) @@ -50,48 +49,13 @@ pytestmark = pytest.mark.data -_folder_map = { - "__params": { - "qchem": ["molname", "basis", "bondlength"], - "qspin": ["sysname", "periodicity", "lattice", "layout"], - }, - "qchem": { - "H2": { - "6-31G": { - "0.46": PosixPath("qchem/H2/6-31G/0.46.h5"), - "1.16": PosixPath("qchem/H2/6-31G/1.16.h5"), - "1.0": PosixPath("qchem/H2/6-31G/1.0.h5"), - } - } - }, - "qspin": { - "Heisenberg": { - "closed": {"chain": {"1x4": PosixPath("qspin/Heisenberg/closed/chain/1x4/1.4.h5")}} - } - }, -} - -_data_struct = { - "qchem": { - "docstr": "Quantum chemistry dataset.", - "params": ["molname", "basis", "bondlength"], - "attributes": ["molecule", "hamiltonian", "sparse_hamiltonian", "hf_state", "full"], - }, - "qspin": { - "docstr": "Quantum many-body spin system dataset.", - "params": ["sysname", "periodicity", "lattice", "layout"], - "attributes": ["parameters", "hamiltonians", "ground_states", "full"], - }, -} - - @pytest.fixture(scope="session") def httpserver_listen_address(): return ("localhost", 8888) # pylint:disable=unused-argument -def get_mock(url, json, timeout=1.0): +def post_mock(url, json, timeout=1.0): """Return mocked get response depending on json content.""" resp = MagicMock(ok=True) if "ErrorQuery" in json["query"]: @@ -268,27 +232,37 @@ class TestMiscHelpers: def test_list_datasets(self): """Test list_datasets.""" - assert qml.data.list_datasets() == { - "id": "qchem", - "datasets": [ - { - "parameterValues": [ - {"name": "molname", "value": "H2"}, - {"name": "bondlength", "value": "1.16"}, - {"name": "basis", "value": "STO-3G"}, - ] - } - ], - } + assert qml.data.list_data_names() == ["other", "qchem", "qspin"] def test_list_attributes(self): """Test list_attributes.""" assert qml.data.list_attributes("qchem") == [ - "molecule", + "basis_rot_groupings", + "basis_rot_samples", + "dipole_op", + "fci_energy", + "fci_spectrum", "hamiltonian", - "sparse_hamiltonian", "hf_state", - "full", + "molecule", + "number_op", + "optimal_sector", + "paulix_ops", + "qwc_groupings", + "qwc_samples", + "sparse_hamiltonian", + "spin2_op", + "spinz_op", + "symmetries", + "tapered_dipole_op", + "tapered_hamiltonian", + "tapered_hf_state", + "tapered_num_op", + "tapered_spin2_op", + "tapered_spinz_op", + "vqe_energy", + "vqe_gates", + "vqe_params", ] @@ -660,12 +634,12 @@ def test_download_datasets_escapes_url_partial( "attributes,msg", [ ( - ["molecule", "hamiltonian", "sparse_hamiltonian", "hf_state", "full", "foo"], - r"'foo' is an invalid attribute for 'my_dataset'. Valid attributes are: \['molecule', 'hamiltonian', 'sparse_hamiltonian', 'hf_state', 'full'\]", + ["basis_rot_groupings", "basis_rot_samples", "dipole_op", "fci_energy", "foo"], + r"'foo' is an invalid attribute for 'my_dataset'. Valid attributes are: \['basis_rot_groupings', 'basis_rot_samples', 'dipole_op', 'fci_energy', 'fci_spectrum', 'hamiltonian', 'hf_state', 'molecule', 'number_op', 'optimal_sector', 'paulix_ops', 'qwc_groupings', 'qwc_samples', 'sparse_hamiltonian', 'spin2_op', 'spinz_op', 'symmetries', 'tapered_dipole_op', 'tapered_hamiltonian', 'tapered_hf_state', 'tapered_num_op', 'tapered_spin2_op', 'tapered_spinz_op', 'vqe_energy', 'vqe_gates', 'vqe_params'\]", ), ( - ["molecule", "hamiltonian", "sparse_hamiltonian", "hf_state", "full", "foo", "bar"], - r"\['foo', 'bar'\] are invalid attributes for 'my_dataset'. Valid attributes are: \['molecule', 'hamiltonian', 'sparse_hamiltonian', 'hf_state', 'full'\]", + ["basis_rot_groupings", "basis_rot_samples", "dipole_op", "fci_energy", "foo", "bar"], + r"\['foo', 'bar'\] are invalid attributes for 'my_dataset'. Valid attributes are: \['basis_rot_groupings', 'basis_rot_samples', 'dipole_op', 'fci_energy', 'fci_spectrum', 'hamiltonian', 'hf_state', 'molecule', 'number_op', 'optimal_sector', 'paulix_ops', 'qwc_groupings', 'qwc_samples', 'sparse_hamiltonian', 'spin2_op', 'spinz_op', 'symmetries', 'tapered_dipole_op', 'tapered_hamiltonian', 'tapered_hf_state', 'tapered_num_op', 'tapered_spin2_op', 'tapered_spinz_op', 'vqe_energy', 'vqe_gates', 'vqe_params'\]", ), ], ) @@ -691,9 +665,9 @@ class TestGetGraphql: } """, ) - inputs = {"input": {"datasetClassId": "qchem"}} + inputs = {"input": {"datasetClassId": "qspin"}} - @patch.object(pennylane.data.data_manager, "get", get_mock) + @patch.object(pennylane.data.data_manager, "post", post_mock) def test_return_json(self): """Tests that an expected json response is returned for a valid query and url.""" response = pennylane.data.data_manager._get_graphql( @@ -701,46 +675,9 @@ def test_return_json(self): self.query, self.inputs, ) - assert response == { - "data": { - "datasetClass": { - "attributes": [ - "molecule", - "hamiltonian", - "sparse_hamiltonian", - "hf_state", - "full", - ] - } - } - } - - def test_bad_url(self): - """Tests that a GraphQLError is raised when given a bad url""" - with pytest.raises(pennylane.data.data_manager.GraphQLError, match="No Response"): - pennylane.data.data_manager._get_graphql( - "https://bad/dataset/url", - self.query, - self.inputs, - ) - - def test_bad_query(self): - """Tests that GraphQLError is raised when given a bad query""" - bad_query = """ - query BadQuery { - badQuery { - badField - } - } - """ - - with pytest.raises(pennylane.data.data_manager.GraphQLError, match="No Response"): - pennylane.data.data_manager._get_graphql( - GRAPHQL_URL, - bad_query, - ) + assert response == _list_attrs_resp - @patch.object(pennylane.data.data_manager, "get", get_mock) + @patch.object(pennylane.data.data_manager, "post", post_mock) def test_error_response(self): """Tests that GraphQLError is raised with error messages when the returned json contains an error message.