diff --git a/doc/code/sf_cli.rst b/doc/code/sf_cli.rst index 2a561b740..cb569f44d 100644 --- a/doc/code/sf_cli.rst +++ b/doc/code/sf_cli.rst @@ -136,4 +136,4 @@ Code details .. automodapi:: strawberryfields.cli :no-heading: :no-inheritance-diagram: - :skip: store_account, create_config, load, RemoteEngine, Connection, ConfigurationError, ping + :skip: store_account, DEFAULT_CONFIG, load, RemoteEngine, Connection, ConfigurationError, ping diff --git a/doc/code/sf_configuration.rst b/doc/code/sf_configuration.rst index fae532314..986d94b82 100644 --- a/doc/code/sf_configuration.rst +++ b/doc/code/sf_configuration.rst @@ -33,6 +33,11 @@ and has the following format: use_ssl = true port = 443 + [logging] + # Options for the logger + level = "warning" + logfile = "sf.log" + Configuration options --------------------- @@ -41,7 +46,6 @@ Configuration options Settings for the Xanadu cloud platform. - **authentication_token (str)** (*required*) API token for authentication to the Xanadu cloud platform. This is required for submitting remote jobs using :class:`~.RemoteEngine`. @@ -63,6 +67,29 @@ Settings for the Xanadu cloud platform. *Corresponding environment variable:* ``SF_API_PORT`` +``[logging]`` +^^^^^^^^^^^^^ + +Settings for the Strawberry Fields logger. + +**level (str)** (*optional*) + Specifies the level of information that should be printed to the standard + output. Defaults to ``"info"``, which indicates that all logged details + are displayed as output. + + Other options include ``"error"``, ``"warning"``, ``"info"``, ``"debug"``, + in decreasing levels of verbosity. + + *Corresponding environment variable:* ``SF_LOGGING_LEVEL`` + +**logfile (str)** (*optional*) + The filepath of an output logfile. This may be a relative or an + absolute path. If specified, all logging data is appended to this + file during Strawberry Fields execution. + + *Corresponding environment variable:* ``SF_LOGGING_LOGFILE`` + + Functions --------- @@ -73,5 +100,5 @@ Functions .. automodapi:: strawberryfields.configuration :no-heading: - :skip: user_config_dir, store_account, active_configs, reset_config, create_logger + :skip: user_config_dir, store_account, active_configs, reset_config, create_logger, delete_config :no-inheritance-diagram: diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index e2a3ded07..d2beaf686 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -154,6 +154,8 @@ def create_job(self, target: str, program: Program, run_options: dict = None) -> circuit = bb.serialize() + self.log.debug("Submitting job\n%s", circuit) + path = "/jobs" response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit}) if response.status_code == 201: diff --git a/strawberryfields/api/job.py b/strawberryfields/api/job.py index 672130697..69ae72e07 100644 --- a/strawberryfields/api/job.py +++ b/strawberryfields/api/job.py @@ -144,6 +144,7 @@ def refresh(self): self.log.debug("Job %s metadata: %s", self.id, job_info.meta) if self._status == JobStatus.COMPLETED: self._result = self._connection.get_job_result(self.id) + self.log.info("Job %s is complete", self.id) def cancel(self): """Cancels an open or queued job. diff --git a/strawberryfields/cli/__init__.py b/strawberryfields/cli/__init__.py index f0d02c524..3603b4ec9 100755 --- a/strawberryfields/cli/__init__.py +++ b/strawberryfields/cli/__init__.py @@ -20,7 +20,7 @@ import sys from strawberryfields.api import Connection -from strawberryfields.configuration import ConfigurationError, create_config, store_account +from strawberryfields.configuration import ConfigurationError, DEFAULT_CONFIG, store_account from strawberryfields.engine import RemoteEngine from strawberryfields.io import load @@ -162,7 +162,7 @@ def configuration_wizard(): Returns: dict[str, Union[str, bool, int]]: the configuration options """ - default_config = create_config()["api"] + default_config = DEFAULT_CONFIG["api"] # Getting default values that can be used for as messages when getting inputs hostname_default = default_config["hostname"] diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index 892d07dcf..371519c40 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -15,12 +15,12 @@ This module contains functions used to load, store, save, and modify configuration options for Strawberry Fields. """ +import collections import os import toml from appdirs import user_config_dir -from strawberryfields.logger import create_logger DEFAULT_CONFIG_SPEC = { "api": { @@ -28,15 +28,124 @@ "hostname": (str, "platform.strawberryfields.ai"), "use_ssl": (bool, True), "port": (int, 443), - } + }, + "logging": {"level": (str, "info"), "logfile": ((str, type(None)), None)}, } +"""dict: Nested dictionary representing the allowed configuration +sections, options, default values, and allowed types for Strawberry +Fields configurations. For each configuration option key, the +corresponding value is a length-2 tuple, containing: + +* A type or tuple of types, representing the allowed type + for that configuration option. + +* The default value for that configuration option. + +.. note:: + + By TOML convention, keys with a default value of ``None`` + will **not** be present in the generated/loaded configuration + file. This is because TOML has no concept of ``NoneType`` or ``Null``, + instead, the non-presence of a key indicates that the configuration + value is not set. +""" class ConfigurationError(Exception): """Exception used for configuration errors""" -def load_config(filename="config.toml", **kwargs): +def _deep_update(source, overrides): + """Recursively update a nested dictionary. + + This function is a generalization of Python's built in + ``dict.update`` method, modified to recursively update + keys with nested dictionaries. + """ + for key, value in overrides.items(): + if isinstance(value, collections.Mapping) and value: + # Override value is a non-empty dictionary. + # Update the source key with the override dictionary. + returned = _deep_update(source.get(key, {}), value) + source[key] = returned + elif value != {}: + # Override value is not an empty dictionary. + source[key] = overrides[key] + return source + + +def _generate_config(config_spec, **kwargs): + """Generates a configuration, given a Strawberry Fields configuration + specification. + + See :attr:`~.DEFAULT_CONFIG_SPEC` for an example of a valid configuration + specification. + + Optional keyword arguments may be provided to override default values + in the cofiguration specification. If the provided override values + do not match the expected type defined in the configuration spec, + a ``ConfigurationError`` is raised. + + **Example** + + >>> _generate_config(DEFAULT_CONFIG_SPEC, api={"port": 54}) + { + "api": { + "authentication_token": "", + "hostname": "platform.strawberryfields.ai", + "use_ssl": True, + "port": 54, + }, + 'logging': {'level': 'info'} + } + + Args: + config_spec (dict): nested dictionary representing the + configuration specification + + Keyword Args: + Provided keyword arguments may overwrite default values of + matching keys. + + Returns: + dict: the default configuration defined by the input config spec + + Raises: + ConfigurationError: if provided keyword argument overrides do not + match the expected type defined in the configuration spec. + """ + res = {} + for k, v in config_spec.items(): + if isinstance(v, tuple): + # config spec value v represents the allowed type and default value + + if k in kwargs: + # Key also exists as a keyword argument. + # Perform type validation. + if not isinstance(kwargs[k], v[0]): + raise ConfigurationError( + "Expected type {} for option {}, received {}".format( + v[0], k, type(kwargs[k]) + ) + ) + + if kwargs[k] is not None: + # Only add the key to the configuration object + # if the provided override is not None. + res[k] = kwargs[k] + else: + if v[1] is not None: + # Only add the key to the configuration object + # if the default value is not None. + res[k] = v[1] + + elif isinstance(v, dict): + # config spec value is a configuration section + res[k] = _generate_config(v, **kwargs.get(k, {})) + return res + + +def load_config(filename="config.toml", logging=True, **kwargs): """Load configuration from keyword arguments, configuration file or environment variables. @@ -50,66 +159,65 @@ def load_config(filename="config.toml", **kwargs): 2. data contained in environmental variables (if any) 3. data contained in a configuration file (if exists) + Args: + filename (str): the name of the configuration file to look for + logging (bool): whether or not to log details + Keyword Args: - filename (str): the name of the configuration file to look for. - Additional configuration options are detailed in + Additional configuration options are detailed in :doc:`/code/sf_configuration` Returns: dict[str, dict[str, Union[str, bool, int]]]: the configuration """ - config = create_config() - filepath = find_config_file(filename=filename) - if filepath is not None: - loaded_config = load_config_file(filepath) - api_config = get_api_config(loaded_config, filepath) + if logging: + # We import the create_logger function only if logging + # has been requested, to avoid circular imports. + from strawberryfields.logger import create_logger #pylint: disable=import-outside-toplevel - valid_api_options = keep_valid_options(api_config) - config["api"].update(valid_api_options) - else: log = create_logger(__name__) - log.warning("No Strawberry Fields configuration file found.") - update_from_environment_variables(config) + if filepath is not None: + # load the configuration file + with open(filepath, "r") as f: + config = toml.load(f) - valid_kwargs_config = keep_valid_options(kwargs) - config["api"].update(valid_kwargs_config) + if logging: + log.debug("Configuration file %s loaded", filepath) - return config + if "api" not in config and logging: + # Raise a warning if the configuration doesn't contain + # an API section. + log.warning( + 'The configuration from the %s file does not contain an "api" section.', filepath + ) + else: + config = {} -def create_config(authentication_token=None, **kwargs): - """Create a configuration object that stores configuration related data - organized into sections. + if logging: + log.warning("No Strawberry Fields configuration file found.") - The configuration object contains API-related configuration options. This - function takes into consideration only pre-defined options. + # update the configuration from environment variables + update_from_environment_variables(config) - If called without passing any keyword arguments, then a default - configuration object is created. + # update the configuration from keyword arguments + for config_section, section_options in kwargs.items(): + _deep_update(config, {config_section: section_options}) - Keyword Args: - Configuration options as detailed in :doc:`/code/sf_configuration` + # generate the configuration object by using the defined + # configuration specification at the top of the file + config = _generate_config(DEFAULT_CONFIG_SPEC, **config) - Returns: - dict[str, dict[str, Union[str, bool, int]]]: the configuration - object - """ - authentication_token = authentication_token or "" - hostname = kwargs.get("hostname", DEFAULT_CONFIG_SPEC["api"]["hostname"][1]) - use_ssl = kwargs.get("use_ssl", DEFAULT_CONFIG_SPEC["api"]["use_ssl"][1]) - port = kwargs.get("port", DEFAULT_CONFIG_SPEC["api"]["port"][1]) + # Log the loaded configuration details, masking out the API key. + if logging: + config_details = "Loaded configuration: {}".format(config) + auth_token = config.get("api", {}).get("authentication_token", "") + config_details = config_details.replace(auth_token[5:], "*" * len(auth_token[5:])) + log.debug(config_details) - config = { - "api": { - "authentication_token": authentication_token, - "hostname": hostname, - "use_ssl": use_ssl, - "port": port, - } - } return config @@ -164,11 +272,10 @@ def find_config_file(filename="config.toml"): Union[str, None]: the filepath to the configuration file or None, if no file was found """ - directories = directories_to_check() - for directory in directories: - filepath = os.path.join(directory, filename) - if os.path.exists(filepath): - return filepath + directories = get_available_config_paths(filename=filename) + + if directories: + return directories[0] return None @@ -194,65 +301,15 @@ def directories_to_check(): sf_user_config_dir = user_config_dir("strawberryfields", "Xanadu") directories.append(current_dir) - if sf_env_config_dir != "": + + if sf_env_config_dir: directories.append(sf_env_config_dir) + directories.append(sf_user_config_dir) return directories -def load_config_file(filepath): - """Load a configuration object from a TOML formatted file. - - Args: - filepath (str): path to the configuration file - - Returns: - dict[str, dict[str, Union[str, bool, int]]]: the configuration - object that was loaded - """ - with open(filepath, "r") as f: - config_from_file = toml.load(f) - return config_from_file - - -def get_api_config(loaded_config, filepath): - """Gets the API section from the loaded configuration. - - Args: - loaded_config (dict): the configuration that was loaded from the TOML config - file - filepath (str): path to the configuration file - - Returns: - dict[str, Union[str, bool, int]]: the api section of the configuration - - Raises: - ConfigurationError: if the api section was not defined in the - configuration - """ - try: - return loaded_config["api"] - except KeyError: - log = create_logger(__name__) - log.error('The configuration from the %s file does not contain an "api" section.', filepath) - raise ConfigurationError - - -def keep_valid_options(sectionconfig): - """Filters the valid options in a section of a configuration dictionary. - - Args: - sectionconfig (dict[str, Union[str, bool, int]]): the section of the - configuration to check - - Returns: - dict[str, Union[str, bool, int]]: the keep section of the - configuration - """ - return {k: v for k, v in sectionconfig.items() if k in VALID_KEYS} - - def update_from_environment_variables(config): """Updates the current configuration object from data stored in environment variables. @@ -271,13 +328,14 @@ def update_from_environment_variables(config): for key in sectionconfig: env = env_prefix + key.upper() if env in os.environ: - config[section][key] = parse_environment_variable(key, os.environ[env]) + config[section][key] = _parse_environment_variable(section, key, os.environ[env]) -def parse_environment_variable(key, value): +def _parse_environment_variable(section, key, value): """Parse a value stored in an environment variable. Args: + section (str): configuration section name key (str): the name of the environment variable value (Union[str, bool, int]): the value obtained from the environment variable @@ -288,7 +346,7 @@ def parse_environment_variable(key, value): trues = (True, "true", "True", "TRUE", "1", 1) falses = (False, "false", "False", "FALSE", "0", 0) - if DEFAULT_CONFIG_SPEC["api"][key][0] is bool: + if DEFAULT_CONFIG_SPEC[section][key][0] is bool: if value in trues: return True @@ -297,7 +355,7 @@ def parse_environment_variable(key, value): raise ValueError("Boolean could not be parsed") - if DEFAULT_CONFIG_SPEC["api"][key][0] is int: + if DEFAULT_CONFIG_SPEC[section][key][0] is int: return int(value) return value @@ -450,21 +508,28 @@ def store_account(authentication_token, filename="config.toml", location="user_c filepath = os.path.join(directory, filename) - config = create_config(authentication_token=authentication_token, **kwargs) - save_config_to_file(config, filepath) + config = {} + # load the existing config if it already exists + if os.path.isfile(filepath): + with open(filepath, "r") as f: + config = toml.load(f) -def save_config_to_file(config, filepath): - """Saves a configuration to a TOML file. + # update the loaded configuration file with the specified + # authentication token + kwargs.update({"authentication_token": authentication_token}) + + # update the loaded configuration with any + # provided API options passed as keyword arguments + _deep_update(config, {"api": kwargs}) + + # generate the configuration object by using the defined + # configuration specification at the top of the file + config = _generate_config(DEFAULT_CONFIG_SPEC, **config) - Args: - config (dict[str, dict[str, Union[str, bool, int]]]): the - configuration to be saved - filepath (str): path to the configuration file - """ with open(filepath, "w") as f: toml.dump(config, f) -VALID_KEYS = set(create_config()["api"].keys()) -DEFAULT_CONFIG = create_config() +DEFAULT_CONFIG = _generate_config(DEFAULT_CONFIG_SPEC) +SESSION_CONFIG = load_config(logging=False) diff --git a/strawberryfields/engine.py b/strawberryfields/engine.py index ed155cbbf..b12683129 100644 --- a/strawberryfields/engine.py +++ b/strawberryfields/engine.py @@ -611,6 +611,11 @@ def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job: # * compiled to a different chip family to the engine target # # In both cases, recompile the program to match the intended target. + self.log.debug( + "Compiling program for target %s with compile options %s", + self.target, + compile_options, + ) program = program.compile(self.target, **compile_options) # update the run options if provided diff --git a/strawberryfields/logger.py b/strawberryfields/logger.py index 86381d5ca..7bc421537 100644 --- a/strawberryfields/logger.py +++ b/strawberryfields/logger.py @@ -49,6 +49,8 @@ import logging import sys +from strawberryfields.configuration import SESSION_CONFIG + def logging_handler_defined(logger): """Checks if the logger or any of its ancestors has a handler defined. @@ -75,12 +77,12 @@ def logging_handler_defined(logger): return False -default_handler = logging.StreamHandler(sys.stderr) +output_handler = logging.StreamHandler(sys.stderr) formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -default_handler.setFormatter(formatter) +output_handler.setFormatter(formatter) -def create_logger(name, level=logging.INFO): +def create_logger(name, level=None): """Get the Strawberry Fields module specific logger and configure it if needed. Configuration only takes place if no user configuration was applied to the @@ -108,7 +110,23 @@ def create_logger(name, level=logging.INFO): no_handlers = not logging_handler_defined(logger) if effective_level_inherited and level_not_set and no_handlers: - logger.setLevel(level) - logger.addHandler(default_handler) + # The root logger should pass all log message levels + # to the handlers. + logger.setLevel(logging.DEBUG) + level = level or getattr(logging, SESSION_CONFIG["logging"]["level"].upper()) + + # Attach the standard output logger, + # with the user defined logging level (defaults to INFO) + output_handler.setLevel(level) + logger.addHandler(output_handler) + + if "logfile" in SESSION_CONFIG["logging"]: + # Create the file logger + file_handler = logging.FileHandler(SESSION_CONFIG["logging"]["logfile"]) + file_handler.setFormatter(formatter) + + # file logger should display all log message levels + file_handler.setLevel(logging.DEBUG) + logger.addHandler(file_handler) return logger diff --git a/tests/frontend/test_configuration.py b/tests/frontend/test_configuration.py index 07993cc20..ccc34555d 100644 --- a/tests/frontend/test_configuration.py +++ b/tests/frontend/test_configuration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the configuration module""" +import copy import os import logging import pytest @@ -46,7 +47,8 @@ "hostname": "platform.strawberryfields.ai", "use_ssl": True, "port": 443, - } + }, + 'logging': {'level': 'info'} } OTHER_EXPECTED_CONFIG = { @@ -55,7 +57,8 @@ "hostname": "SomeHost", "use_ssl": False, "port": 56, - } + }, + "logging": {"level": "info"} } environment_variables = [ @@ -91,9 +94,12 @@ def test_keywords_take_precedence_over_everything(self, monkeypatch, tmpdir): m.setenv("SF_API_PORT", "42") m.setattr(os, "getcwd", lambda: tmpdir) - configuration = conf.load_config( - authentication_token="SomeAuth", hostname="SomeHost", use_ssl=False, port=56 - ) + configuration = conf.load_config(api={ + "authentication_token": "SomeAuth", + "hostname": "SomeHost", + "use_ssl": False, + "port": 56 + }) assert configuration == OTHER_EXPECTED_CONFIG @@ -142,9 +148,8 @@ def test_get_api_section_safely_error(self, monkeypatch, tmpdir, caplog): f.write(empty_file) with monkeypatch.context() as m: - with pytest.raises(conf.ConfigurationError, match=""): - m.setattr(os, "getcwd", lambda: tmpdir) - configuration = conf.load_config() + m.setattr(os, "getcwd", lambda: tmpdir) + configuration = conf.load_config() assert "does not contain an \"api\" section" in caplog.text @@ -290,32 +295,31 @@ def test_print_active_configs_no_configs(self, capsys, monkeypatch): general_message_2 + first_dir_msg + second_dir_msg + third_dir_msg -class TestCreateConfigObject: +class TestGenerateConfigObject: """Test the creation of a configuration object""" - def test_empty_config_object(self): - """Test that an empty configuration object can be created.""" - config = conf.create_config(authentication_token="", hostname="", use_ssl="", port="") - - assert all(value == "" for value in config["api"].values()) + def test_type_validation(self): + """Test that passing an incorrect type raises an exception""" + with pytest.raises(conf.ConfigurationError, match="Expected type"): + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"use_ssl":""} + ) def test_config_object_with_authentication_token(self): """Test that passing only the authentication token creates the expected configuration object.""" - assert ( - conf.create_config(authentication_token="071cdcce-9241-4965-93af-4a4dbc739135") - == EXPECTED_CONFIG + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"authentication_token":"071cdcce-9241-4965-93af-4a4dbc739135",} ) + assert config == EXPECTED_CONFIG def test_config_object_every_keyword_argument(self): """Test that passing every keyword argument creates the expected configuration object.""" - assert ( - conf.create_config( - authentication_token="SomeAuth", hostname="SomeHost", use_ssl=False, port=56 - ) - == OTHER_EXPECTED_CONFIG + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"authentication_token":"SomeAuth", "hostname":"SomeHost", "use_ssl":False, "port":56} ) + assert config == OTHER_EXPECTED_CONFIG class TestRemoveConfigFile: """Test the removal of configuration files""" @@ -475,61 +479,6 @@ def raise_wrapper(ex): assert config_filepath is None -class TestLoadConfigFile: - """Tests the load_config_file function.""" - - def test_load_config_file(self, monkeypatch, tmpdir): - """Tests that configuration is loaded correctly from a TOML file.""" - filename = tmpdir.join("test_config.toml") - - with open(filename, "w") as f: - f.write(TEST_FILE) - - loaded_config = conf.load_config_file(filepath=filename) - - assert loaded_config == EXPECTED_CONFIG - - def test_loading_absolute_path(self, monkeypatch, tmpdir): - """Test that the default configuration file can be loaded - via an absolute path.""" - filename = tmpdir.join("test_config.toml") - - with open(filename, "w") as f: - f.write(TEST_FILE) - - with monkeypatch.context() as m: - m.setenv("SF_CONF", "") - loaded_config = conf.load_config_file(filepath=filename) - - assert loaded_config == EXPECTED_CONFIG - - -class TestKeepValidOptions: - def test_only_invalid_options(self): - section_config_with_invalid_options = {"NotValid1": 1, "NotValid2": 2, "NotValid3": 3} - assert conf.keep_valid_options(section_config_with_invalid_options) == {} - - def test_valid_and_invalid_options(self): - section_config_with_invalid_options = { - "authentication_token": "MyToken", - "NotValid1": 1, - "NotValid2": 2, - "NotValid3": 3, - } - assert conf.keep_valid_options(section_config_with_invalid_options) == { - "authentication_token": "MyToken" - } - - def test_only_valid_options(self): - section_config_only_valid = { - "authentication_token": "071cdcce-9241-4965-93af-4a4dbc739135", - "hostname": "platform.strawberryfields.ai", - "use_ssl": True, - "port": 443, - } - assert conf.keep_valid_options(section_config_only_valid) == EXPECTED_CONFIG["api"] - - value_mapping = [ ("SF_API_AUTHENTICATION_TOKEN", "SomeAuth"), ("SF_API_HOSTNAME", "SomeHost"), @@ -558,7 +507,7 @@ def test_all_environment_variables_defined(self, monkeypatch): for env_var, value in value_mapping: m.setenv(env_var, value) - config = conf.create_config() + config = copy.deepcopy(conf.DEFAULT_CONFIG) for v, parsed_value in zip(config["api"].values(), parsed_values_mapping.values()): assert v != parsed_value @@ -581,7 +530,7 @@ def test_one_environment_variable_defined(self, env_var, key, value, monkeypatch with monkeypatch.context() as m: m.setenv(env_var, value) - config = conf.create_config() + config = copy.deepcopy(conf.DEFAULT_CONFIG) for v, parsed_value in zip(config["api"].values(), parsed_values_mapping.values()): assert v != parsed_value @@ -598,24 +547,24 @@ def test_parse_environment_variable_boolean(self, monkeypatch): """Tests that boolean values can be parsed correctly from environment variables.""" monkeypatch.setattr(conf, "DEFAULT_CONFIG_SPEC", {"api": {"some_boolean": (bool, True)}}) - assert conf.parse_environment_variable("some_boolean", "true") is True - assert conf.parse_environment_variable("some_boolean", "True") is True - assert conf.parse_environment_variable("some_boolean", "TRUE") is True - assert conf.parse_environment_variable("some_boolean", "1") is True - assert conf.parse_environment_variable("some_boolean", 1) is True - - assert conf.parse_environment_variable("some_boolean", "false") is False - assert conf.parse_environment_variable("some_boolean", "False") is False - assert conf.parse_environment_variable("some_boolean", "FALSE") is False - assert conf.parse_environment_variable("some_boolean", "0") is False - assert conf.parse_environment_variable("some_boolean", 0) is False + assert conf._parse_environment_variable("api", "some_boolean", "true") is True + assert conf._parse_environment_variable("api", "some_boolean", "True") is True + assert conf._parse_environment_variable("api", "some_boolean", "TRUE") is True + assert conf._parse_environment_variable("api", "some_boolean", "1") is True + assert conf._parse_environment_variable("api", "some_boolean", 1) is True + + assert conf._parse_environment_variable("api", "some_boolean", "false") is False + assert conf._parse_environment_variable("api", "some_boolean", "False") is False + assert conf._parse_environment_variable("api", "some_boolean", "FALSE") is False + assert conf._parse_environment_variable("api", "some_boolean", "0") is False + assert conf._parse_environment_variable("api", "some_boolean", 0) is False def test_parse_environment_variable_integer(self, monkeypatch): """Tests that integer values can be parsed correctly from environment variables.""" monkeypatch.setattr(conf, "DEFAULT_CONFIG_SPEC", {"api": {"some_integer": (int, 123)}}) - assert conf.parse_environment_variable("some_integer", "123") == 123 + assert conf._parse_environment_variable("api", "some_integer", "123") == 123 DEFAULT_KWARGS = {"hostname": "platform.strawberryfields.ai", "use_ssl": True, "port": 443} @@ -656,8 +605,7 @@ def test_config_created_locally(self, monkeypatch, tmpdir): with monkeypatch.context() as m: m.setattr(os, "getcwd", lambda: tmpdir) m.setattr(conf, "user_config_dir", lambda *args: "NotTheCorrectDir") - m.setattr(conf, "create_config", mock_create_config) - m.setattr(conf, "save_config_to_file", lambda a, b: mock_save_config_file.update(a, b)) + m.setattr("toml.dump", lambda a, b: mock_save_config_file.update(a, b.name)) conf.store_account( authentication_token, filename="config.toml", location="local", **DEFAULT_KWARGS ) @@ -676,8 +624,7 @@ def test_global_config_created(self, monkeypatch, tmpdir): with monkeypatch.context() as m: m.setattr(os, "getcwd", lambda: "NotTheCorrectDir") m.setattr(conf, "user_config_dir", lambda *args: tmpdir) - m.setattr(conf, "create_config", mock_create_config) - m.setattr(conf, "save_config_to_file", lambda a, b: mock_save_config_file.update(a, b)) + m.setattr("toml.dump", lambda a, b: mock_save_config_file.update(a, b.name)) conf.store_account( authentication_token, filename="config.toml", @@ -800,32 +747,3 @@ def test_nested_directory_is_created(self, monkeypatch, tmpdir): filepath = os.path.join(recursive_dir, "config.toml") result = toml.load(filepath) assert result == EXPECTED_CONFIG - - -class TestSaveConfigToFile: - """Tests for the store_account function.""" - - def test_correct(self, tmpdir): - """Test saving a configuration file.""" - filepath = str(tmpdir.join("config.toml")) - - conf.save_config_to_file(OTHER_EXPECTED_CONFIG, filepath) - - result = toml.load(filepath) - assert result == OTHER_EXPECTED_CONFIG - - def test_file_already_existed(self, tmpdir): - """Test saving a configuration file even if the file already - existed.""" - filepath = str(tmpdir.join("config.toml")) - - with open(filepath, "w") as f: - f.write(TEST_FILE) - - result_for_existing_file = toml.load(filepath) - assert result_for_existing_file == EXPECTED_CONFIG - - conf.save_config_to_file(OTHER_EXPECTED_CONFIG, filepath) - - result_for_new_file = toml.load(filepath) - assert result_for_new_file == OTHER_EXPECTED_CONFIG diff --git a/tests/frontend/test_logger.py b/tests/frontend/test_logger.py index 3d2df772c..d0d118c55 100644 --- a/tests/frontend/test_logger.py +++ b/tests/frontend/test_logger.py @@ -52,7 +52,7 @@ import strawberryfields.api.connection as connection import strawberryfields.engine as engine -from strawberryfields.logger import logging_handler_defined, default_handler, create_logger +from strawberryfields.logger import logging_handler_defined, output_handler, create_logger modules_contain_logging = [job, connection, engine] @@ -114,7 +114,7 @@ def test_create_logger(self, module): logger = create_logger(module.__name__) assert logger.level == logging.INFO assert logging_handler_defined(logger) - assert logger.handlers[0] == default_handler + assert logger.handlers[0] == output_handler class TestLoggerIntegration: """Tests that the SF logger integrates well with user defined logging diff --git a/tests/frontend/test_sf_cli.py b/tests/frontend/test_sf_cli.py index 43076ffd3..f8263c929 100644 --- a/tests/frontend/test_sf_cli.py +++ b/tests/frontend/test_sf_cli.py @@ -15,6 +15,7 @@ Unit tests for the Strawberry Fields command line interface. """ # pylint: disable=no-self-use,unused-argument +import copy import os import functools import argparse @@ -160,7 +161,7 @@ def test_configuration_wizard(self, monkeypatch): configuration takes place using the configuration_wizard function.""" with monkeypatch.context() as m: mock_store_account = MockStoreAccount() - m.setattr(cli, "configuration_wizard", lambda: cli.create_config()["api"]) + m.setattr(cli, "configuration_wizard", lambda: cli.DEFAULT_CONFIG["api"]) m.setattr(cli, "store_account", mock_store_account.store_account) args = MockArgs() @@ -190,7 +191,7 @@ def test_configuration_wizard_local(self, monkeypatch): the configuration_wizard function.""" with monkeypatch.context() as m: mock_store_account = MockStoreAccount() - m.setattr(cli, "configuration_wizard", lambda: cli.create_config()["api"]) + m.setattr(cli, "configuration_wizard", lambda: cli.DEFAULT_CONFIG["api"]) m.setattr(cli, "store_account", mock_store_account.store_account) args = MockArgs() @@ -278,7 +279,7 @@ def test_auth_correct(self, monkeypatch): correctly, once the authentication token is passed.""" with monkeypatch.context() as m: auth_prompt = "Please enter the authentication token" - default_config = cli.create_config()["api"] + default_config = copy.deepcopy(cli.DEFAULT_CONFIG["api"]) default_auth = "SomeAuth" default_config['authentication_token'] = default_auth @@ -291,7 +292,7 @@ def test_correct_inputs(self, monkeypatch): with monkeypatch.context() as m: auth_prompt = "Please enter the authentication token" - default_config = cli.create_config()["api"] + default_config = copy.deepcopy(cli.DEFAULT_CONFIG["api"]) default_auth = "SomeAuth" default_config['authentication_token'] = default_auth