Skip to content

Commit

Permalink
fix queries
Browse files Browse the repository at this point in the history
  • Loading branch information
anthayes92 committed Sep 19, 2024
1 parent 2026898 commit dfcceab
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 181 deletions.
4 changes: 2 additions & 2 deletions pennylane/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi
from .base import DatasetNotWriteableError
from .base.attribute import AttributeInfo, DatasetAttribute, attribute
from .base.dataset import Dataset, field
from .data_manager import DEFAULT, FULL, list_attributes, list_datasets, load, load_interactive
from .data_manager import DEFAULT, FULL, list_attributes, list_data_names, load, load_interactive

__all__ = (
"AttributeInfo",
Expand All @@ -240,7 +240,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi
"load",
"load_interactive",
"list_attributes",
"list_datasets",
"list_data_names",
"DEFAULT",
"FULL",
)
69 changes: 9 additions & 60 deletions pennylane/data/data_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from time import sleep
from typing import Any, Iterable, Mapping, Optional, Union

from requests import get, head
from requests import get, head, post

from pennylane.data.base import Dataset
from pennylane.data.base.hdf5 import open_hdf5_s3
Expand Down Expand Up @@ -122,12 +122,9 @@ def _get_graphql(url: str, query: str, variables: dict[str, Any] = None):
if variables:
json["variables"] = variables

response = get(url=url, json=json, timeout=10)
response = post(url=url, json=json, timeout=10)
response.raise_for_status()

if response.json() is None:
raise GraphQLError("No Response")

if "errors" in response.json():
all_errors = ",".join(error["message"] for error in response.json()["errors"])
raise GraphQLError(f"Errors in request: {all_errors}")
Expand Down Expand Up @@ -418,55 +415,22 @@ def load( # pylint: disable=too-many-arguments
return [Dataset.open(path, "a") for path in download_paths]


def list_datasets() -> dict:
r"""Returns a dictionary of the available datasets.
Return:
dict: Nested dictionary representing the directory structure of the hosted datasets.
.. seealso:: :func:`~.load_interactive`, :func:`~.list_attributes`, :func:`~.load`.
**Example:**
Note that the results of calling this function may differ from this example as more datasets
are added. For updates on available data see the `datasets website <https://pennylane.ai/datasets>`_.
>>> available_data = qml.data.list_datasets()
>>> available_data.keys()
dict_keys(["qspin", "qchem"])
>>> available_data["qchem"].keys()
dict_keys(["H2", "LiH", ...])
>>> available_data['qchem']['H2'].keys()
dict_keys(["CC-PVDZ", "6-31G", "STO-3G"])
>>> print(available_data['qchem']['H2']['STO-3G'])
["0.5", "0.54", "0.62", "0.66", ...]
Note that this example limits the results of the function calls for
clarity and that as more data becomes available, the results of these
function calls will change.
"""

def list_data_names() -> list[str]:
"""Get list of dataclass IDs."""
response = _get_graphql(
GRAPHQL_URL,
"""
query ListDatasets($datasetClassId: String!) {
query GetDatasetClasses {
datasetClasses {
id
datasets {
parameterValues{
name
value
}
}
}
}
""",
)

return response["data"]["datasetClasses"]
return [dsc["id"] for dsc in response["data"]["datasetClasses"]]


def list_attributes(data_name):
def list_attributes(data_name) -> list[str]:
r"""List the attributes that exist for a specific ``data_name``.
Args:
Expand Down Expand Up @@ -502,7 +466,7 @@ def list_attributes(data_name):
{"input": {"datasetClassId": data_name}},
)

return response["data"]["datasetClass"]["attributes"]
return [attribute["name"] for attribute in response["data"]["datasetClass"]["attributes"]]


def _interactive_request_data_name(data_names):
Expand Down Expand Up @@ -581,21 +545,6 @@ def _get_parameter_tree(class_id) -> tuple[list[str], list[str], dict]:
return (parameters, attributes, response["data"]["datasetClass"]["parameterTree"])


def _get_data_names() -> list[str]:
"""Get dataclass IDs."""
response = _get_graphql(
GRAPHQL_URL,
"""
query GetDatasetClasses {
datasetClasses {
id
}
}
""",
)
return [dsc["id"] for dsc in response["data"]["datasetClasses"]]


def load_interactive():
r"""Download a dataset using an interactive load prompt.
Expand Down Expand Up @@ -635,7 +584,7 @@ def load_interactive():
Would you like to continue? (Default is yes) [Y/n]:
"""

data_names = _get_data_names()
data_names = list_data_names()
data_name = _interactive_request_data_name(data_names)

parameters, attribute_options, parameter_tree = _get_parameter_tree(data_name)
Expand Down
106 changes: 84 additions & 22 deletions tests/data/data_manager/support.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,90 @@
"""Test support for mocking GraphQL queries"""

_list_attrs_resp = {
"data": {
"datasetClass": {
"attributes": ["molecule", "hamiltonian", "sparse_hamiltonian", "hf_state", "full"]
}
}
}

_list_datasets_resp = {
"data": {
"datasetClasses": {
"id": "qchem",
"datasets": [
{
"parameterValues": [
{"name": "molname", "value": "H2"},
{"name": "bondlength", "value": "1.16"},
{"name": "basis", "value": "STO-3G"},
]
}
],
"data": {
"datasetClass": {
"attributes": [
{
"name": "basis_rot_groupings"
},
{
"name": "basis_rot_samples"
},
{
"name": "dipole_op"
},
{
"name": "fci_energy"
},
{
"name": "fci_spectrum"
},
{
"name": "hamiltonian"
},
{
"name": "hf_state"
},
{
"name": "molecule"
},
{
"name": "number_op"
},
{
"name": "optimal_sector"
},
{
"name": "paulix_ops"
},
{
"name": "qwc_groupings"
},
{
"name": "qwc_samples"
},
{
"name": "sparse_hamiltonian"
},
{
"name": "spin2_op"
},
{
"name": "spinz_op"
},
{
"name": "symmetries"
},
{
"name": "tapered_dipole_op"
},
{
"name": "tapered_hamiltonian"
},
{
"name": "tapered_hf_state"
},
{
"name": "tapered_num_op"
},
{
"name": "tapered_spin2_op"
},
{
"name": "tapered_spinz_op"
},
{
"name": "vqe_energy"
},
{
"name": "vqe_gates"
},
{
"name": "vqe_params"
}
]
}
}
}

_get_urls_resp = {
Expand All @@ -48,6 +110,8 @@

_dataclass_ids = {"data": {"datasetClasses": [{"id": "other"}, {"id": "qchem"}, {"id": "qspin"}]}}

_error_response = {"data": None, "errors": [{"message": "Mock error message."}]}

_parameter_tree = {
"data": {
"datasetClass": {
Expand Down Expand Up @@ -504,5 +568,3 @@
}
}
}

_error_response = {"data": None, "errors": [{"message": "Mock error message."}]}
Loading

0 comments on commit dfcceab

Please sign in to comment.