From 70e1cc19b7295c1f7574e9d217314bf2301550a7 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 11 Apr 2020 14:12:20 +0930 Subject: [PATCH 1/7] Small cleaning up of the configuration module, to remove hardcoded API option dependence --- strawberryfields/__init__.py | 14 +- strawberryfields/cli/__init__.py | 4 +- strawberryfields/configuration.py | 214 ++++++++++++--------------- tests/frontend/test_configuration.py | 151 ++++--------------- tests/frontend/test_sf_cli.py | 9 +- 5 files changed, 148 insertions(+), 244 deletions(-) diff --git a/strawberryfields/__init__.py b/strawberryfields/__init__.py index 791047cfe..1ef4e40b3 100644 --- a/strawberryfields/__init__.py +++ b/strawberryfields/__init__.py @@ -30,7 +30,19 @@ from .parameters import par_funcs as math from .program import Program -__all__ = ["Engine", "RemoteEngine", "Program", "version", "save", "load", "about", "cite"] +__all__ = [ + "Engine", + "RemoteEngine", + "Program", + "version", + "save", + "load", + "about", + "cite", + "math", + "ping", + "store_account", +] #: float: numerical value of hbar for the frontend (in the implicit units of position * momentum) diff --git a/strawberryfields/cli/__init__.py b/strawberryfields/cli/__init__.py index f0d02c524..3603b4ec9 100755 --- a/strawberryfields/cli/__init__.py +++ b/strawberryfields/cli/__init__.py @@ -20,7 +20,7 @@ import sys from strawberryfields.api import Connection -from strawberryfields.configuration import ConfigurationError, create_config, store_account +from strawberryfields.configuration import ConfigurationError, DEFAULT_CONFIG, store_account from strawberryfields.engine import RemoteEngine from strawberryfields.io import load @@ -162,7 +162,7 @@ def configuration_wizard(): Returns: dict[str, Union[str, bool, int]]: the configuration options """ - default_config = create_config()["api"] + default_config = DEFAULT_CONFIG["api"] # Getting default values that can be used for as messages when getting inputs hostname_default = default_config["hostname"] diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index 892d07dcf..7e3d38fdb 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -15,6 +15,7 @@ This module contains functions used to load, store, save, and modify configuration options for Strawberry Fields. """ +import collections import os import toml @@ -22,6 +23,7 @@ from strawberryfields.logger import create_logger + DEFAULT_CONFIG_SPEC = { "api": { "authentication_token": (str, ""), @@ -36,6 +38,60 @@ class ConfigurationError(Exception): """Exception used for configuration errors""" +def _deep_update(source, overrides): + """Update a nested dictionary.""" + for key, value in overrides.items(): + if isinstance(value, collections.Mapping) and value: + returned = _deep_update(source.get(key, {}), value) + source[key] = returned + elif value != {}: + source[key] = overrides[key] + return source + + +def _generate_config(config_spec, **kwargs): + """Generates a configuration, given a configuration specification + and optional keyword arguments. + + Args: + config_spec (dict): Nested dictionary representing the + configuration specification. Keys in the dictionary + represent allowed configuration keys. Corresponding + values must be a tuple, with the first element representing + the type, and the second representing the default value. + + Keyword Args: + Provided keyword arguments may overwrite default values of + matching (nested) keys. + + Returns: + dict: the default configuration defined by the input config spec + """ + res = {} + for k, v in config_spec.items(): + if isinstance(v, tuple) and isinstance(v[0], type): + # config spec value v represents the allowed type and default value + + if k in kwargs: + # Key also exists as a keyword argument. + # Perform type validation. + if not isinstance(kwargs[k], v[0]): + raise ConfigurationError( + "Expected type {} for option {}, received {}".format( + v[0], k, type(kwargs[k]) + ) + ) + + res[k] = kwargs[k] + else: + res[k] = v[1] + + elif isinstance(v, dict): + # config spec value is a dictionary of more options + res[k] = _generate_config(v, **kwargs.get(k, {})) + return res + + def load_config(filename="config.toml", **kwargs): """Load configuration from keyword arguments, configuration file or environment variables. @@ -50,66 +106,46 @@ def load_config(filename="config.toml", **kwargs): 2. data contained in environmental variables (if any) 3. data contained in a configuration file (if exists) - Keyword Args: + Args: filename (str): the name of the configuration file to look for. - Additional configuration options are detailed in + + Keyword Args: + Additional configuration options are detailed in :doc:`/code/sf_configuration` Returns: dict[str, dict[str, Union[str, bool, int]]]: the configuration """ - config = create_config() - filepath = find_config_file(filename=filename) if filepath is not None: - loaded_config = load_config_file(filepath) - api_config = get_api_config(loaded_config, filepath) + # load the configuration file + with open(filepath, "r") as f: + config = toml.load(f) + + if "api" not in config: + # Raise a warning if the configuration doesn't contain + # an API section. + log = create_logger(__name__) + log.warning('The configuration from the %s file does not contain an "api" section.', filepath) - valid_api_options = keep_valid_options(api_config) - config["api"].update(valid_api_options) else: + config = {} log = create_logger(__name__) log.warning("No Strawberry Fields configuration file found.") + # update the configuration from environment variables update_from_environment_variables(config) - valid_kwargs_config = keep_valid_options(kwargs) - config["api"].update(valid_kwargs_config) - - return config - - -def create_config(authentication_token=None, **kwargs): - """Create a configuration object that stores configuration related data - organized into sections. - - The configuration object contains API-related configuration options. This - function takes into consideration only pre-defined options. - - If called without passing any keyword arguments, then a default - configuration object is created. - - Keyword Args: - Configuration options as detailed in :doc:`/code/sf_configuration` + # update the configuration from keyword arguments + # NOTE: currently the configuration keyword arguments are specific + # only to the API section. Once we have more configuration sections, + # they will likely need to be passed via separate keyword arguments. + _deep_update(config, {"api": kwargs}) - Returns: - dict[str, dict[str, Union[str, bool, int]]]: the configuration - object - """ - authentication_token = authentication_token or "" - hostname = kwargs.get("hostname", DEFAULT_CONFIG_SPEC["api"]["hostname"][1]) - use_ssl = kwargs.get("use_ssl", DEFAULT_CONFIG_SPEC["api"]["use_ssl"][1]) - port = kwargs.get("port", DEFAULT_CONFIG_SPEC["api"]["port"][1]) - - config = { - "api": { - "authentication_token": authentication_token, - "hostname": hostname, - "use_ssl": use_ssl, - "port": port, - } - } + # generate the configuration object by using the defined + # configuration specification at the top of the file + config = _generate_config(DEFAULT_CONFIG_SPEC, **config) return config @@ -164,11 +200,10 @@ def find_config_file(filename="config.toml"): Union[str, None]: the filepath to the configuration file or None, if no file was found """ - directories = directories_to_check() - for directory in directories: - filepath = os.path.join(directory, filename) - if os.path.exists(filepath): - return filepath + directories = get_available_config_paths(filename=filename) + + if directories: + return directories[0] return None @@ -194,65 +229,15 @@ def directories_to_check(): sf_user_config_dir = user_config_dir("strawberryfields", "Xanadu") directories.append(current_dir) - if sf_env_config_dir != "": + + if sf_env_config_dir: directories.append(sf_env_config_dir) + directories.append(sf_user_config_dir) return directories -def load_config_file(filepath): - """Load a configuration object from a TOML formatted file. - - Args: - filepath (str): path to the configuration file - - Returns: - dict[str, dict[str, Union[str, bool, int]]]: the configuration - object that was loaded - """ - with open(filepath, "r") as f: - config_from_file = toml.load(f) - return config_from_file - - -def get_api_config(loaded_config, filepath): - """Gets the API section from the loaded configuration. - - Args: - loaded_config (dict): the configuration that was loaded from the TOML config - file - filepath (str): path to the configuration file - - Returns: - dict[str, Union[str, bool, int]]: the api section of the configuration - - Raises: - ConfigurationError: if the api section was not defined in the - configuration - """ - try: - return loaded_config["api"] - except KeyError: - log = create_logger(__name__) - log.error('The configuration from the %s file does not contain an "api" section.', filepath) - raise ConfigurationError - - -def keep_valid_options(sectionconfig): - """Filters the valid options in a section of a configuration dictionary. - - Args: - sectionconfig (dict[str, Union[str, bool, int]]): the section of the - configuration to check - - Returns: - dict[str, Union[str, bool, int]]: the keep section of the - configuration - """ - return {k: v for k, v in sectionconfig.items() if k in VALID_KEYS} - - def update_from_environment_variables(config): """Updates the current configuration object from data stored in environment variables. @@ -271,13 +256,14 @@ def update_from_environment_variables(config): for key in sectionconfig: env = env_prefix + key.upper() if env in os.environ: - config[section][key] = parse_environment_variable(key, os.environ[env]) + config[section][key] = _parse_environment_variable(section, key, os.environ[env]) -def parse_environment_variable(key, value): +def _parse_environment_variable(section, key, value): """Parse a value stored in an environment variable. Args: + section (str): configuration section name key (str): the name of the environment variable value (Union[str, bool, int]): the value obtained from the environment variable @@ -288,7 +274,7 @@ def parse_environment_variable(key, value): trues = (True, "true", "True", "TRUE", "1", 1) falses = (False, "false", "False", "FALSE", "0", 0) - if DEFAULT_CONFIG_SPEC["api"][key][0] is bool: + if DEFAULT_CONFIG_SPEC[section][key][0] is bool: if value in trues: return True @@ -297,7 +283,7 @@ def parse_environment_variable(key, value): raise ValueError("Boolean could not be parsed") - if DEFAULT_CONFIG_SPEC["api"][key][0] is int: + if DEFAULT_CONFIG_SPEC[section][key][0] is int: return int(value) return value @@ -450,21 +436,13 @@ def store_account(authentication_token, filename="config.toml", location="user_c filepath = os.path.join(directory, filename) - config = create_config(authentication_token=authentication_token, **kwargs) - save_config_to_file(config, filepath) - + # generate the configuration object by using the defined + # configuration specification at the top of the file + kwargs.update({"authentication_token": authentication_token}) + config = _generate_config(DEFAULT_CONFIG_SPEC, api=kwargs) -def save_config_to_file(config, filepath): - """Saves a configuration to a TOML file. - - Args: - config (dict[str, dict[str, Union[str, bool, int]]]): the - configuration to be saved - filepath (str): path to the configuration file - """ with open(filepath, "w") as f: toml.dump(config, f) -VALID_KEYS = set(create_config()["api"].keys()) -DEFAULT_CONFIG = create_config() +DEFAULT_CONFIG = _generate_config(DEFAULT_CONFIG_SPEC) diff --git a/tests/frontend/test_configuration.py b/tests/frontend/test_configuration.py index 07993cc20..64c82a832 100644 --- a/tests/frontend/test_configuration.py +++ b/tests/frontend/test_configuration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the configuration module""" +import copy import os import logging import pytest @@ -142,9 +143,8 @@ def test_get_api_section_safely_error(self, monkeypatch, tmpdir, caplog): f.write(empty_file) with monkeypatch.context() as m: - with pytest.raises(conf.ConfigurationError, match=""): - m.setattr(os, "getcwd", lambda: tmpdir) - configuration = conf.load_config() + m.setattr(os, "getcwd", lambda: tmpdir) + configuration = conf.load_config() assert "does not contain an \"api\" section" in caplog.text @@ -290,32 +290,31 @@ def test_print_active_configs_no_configs(self, capsys, monkeypatch): general_message_2 + first_dir_msg + second_dir_msg + third_dir_msg -class TestCreateConfigObject: +class TestGenerateConfigObject: """Test the creation of a configuration object""" - def test_empty_config_object(self): - """Test that an empty configuration object can be created.""" - config = conf.create_config(authentication_token="", hostname="", use_ssl="", port="") - - assert all(value == "" for value in config["api"].values()) + def test_type_validation(self): + """Test that passing an incorrect type raises an exception""" + with pytest.raises(conf.ConfigurationError, match="Expected type"): + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"use_ssl":""} + ) def test_config_object_with_authentication_token(self): """Test that passing only the authentication token creates the expected configuration object.""" - assert ( - conf.create_config(authentication_token="071cdcce-9241-4965-93af-4a4dbc739135") - == EXPECTED_CONFIG + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"authentication_token":"071cdcce-9241-4965-93af-4a4dbc739135",} ) + assert config == EXPECTED_CONFIG def test_config_object_every_keyword_argument(self): """Test that passing every keyword argument creates the expected configuration object.""" - assert ( - conf.create_config( - authentication_token="SomeAuth", hostname="SomeHost", use_ssl=False, port=56 - ) - == OTHER_EXPECTED_CONFIG + config = conf._generate_config( + conf.DEFAULT_CONFIG_SPEC, api={"authentication_token":"SomeAuth", "hostname":"SomeHost", "use_ssl":False, "port":56} ) + assert config == OTHER_EXPECTED_CONFIG class TestRemoveConfigFile: """Test the removal of configuration files""" @@ -475,61 +474,6 @@ def raise_wrapper(ex): assert config_filepath is None -class TestLoadConfigFile: - """Tests the load_config_file function.""" - - def test_load_config_file(self, monkeypatch, tmpdir): - """Tests that configuration is loaded correctly from a TOML file.""" - filename = tmpdir.join("test_config.toml") - - with open(filename, "w") as f: - f.write(TEST_FILE) - - loaded_config = conf.load_config_file(filepath=filename) - - assert loaded_config == EXPECTED_CONFIG - - def test_loading_absolute_path(self, monkeypatch, tmpdir): - """Test that the default configuration file can be loaded - via an absolute path.""" - filename = tmpdir.join("test_config.toml") - - with open(filename, "w") as f: - f.write(TEST_FILE) - - with monkeypatch.context() as m: - m.setenv("SF_CONF", "") - loaded_config = conf.load_config_file(filepath=filename) - - assert loaded_config == EXPECTED_CONFIG - - -class TestKeepValidOptions: - def test_only_invalid_options(self): - section_config_with_invalid_options = {"NotValid1": 1, "NotValid2": 2, "NotValid3": 3} - assert conf.keep_valid_options(section_config_with_invalid_options) == {} - - def test_valid_and_invalid_options(self): - section_config_with_invalid_options = { - "authentication_token": "MyToken", - "NotValid1": 1, - "NotValid2": 2, - "NotValid3": 3, - } - assert conf.keep_valid_options(section_config_with_invalid_options) == { - "authentication_token": "MyToken" - } - - def test_only_valid_options(self): - section_config_only_valid = { - "authentication_token": "071cdcce-9241-4965-93af-4a4dbc739135", - "hostname": "platform.strawberryfields.ai", - "use_ssl": True, - "port": 443, - } - assert conf.keep_valid_options(section_config_only_valid) == EXPECTED_CONFIG["api"] - - value_mapping = [ ("SF_API_AUTHENTICATION_TOKEN", "SomeAuth"), ("SF_API_HOSTNAME", "SomeHost"), @@ -558,7 +502,7 @@ def test_all_environment_variables_defined(self, monkeypatch): for env_var, value in value_mapping: m.setenv(env_var, value) - config = conf.create_config() + config = copy.deepcopy(conf.DEFAULT_CONFIG) for v, parsed_value in zip(config["api"].values(), parsed_values_mapping.values()): assert v != parsed_value @@ -581,7 +525,7 @@ def test_one_environment_variable_defined(self, env_var, key, value, monkeypatch with monkeypatch.context() as m: m.setenv(env_var, value) - config = conf.create_config() + config = copy.deepcopy(conf.DEFAULT_CONFIG) for v, parsed_value in zip(config["api"].values(), parsed_values_mapping.values()): assert v != parsed_value @@ -598,24 +542,24 @@ def test_parse_environment_variable_boolean(self, monkeypatch): """Tests that boolean values can be parsed correctly from environment variables.""" monkeypatch.setattr(conf, "DEFAULT_CONFIG_SPEC", {"api": {"some_boolean": (bool, True)}}) - assert conf.parse_environment_variable("some_boolean", "true") is True - assert conf.parse_environment_variable("some_boolean", "True") is True - assert conf.parse_environment_variable("some_boolean", "TRUE") is True - assert conf.parse_environment_variable("some_boolean", "1") is True - assert conf.parse_environment_variable("some_boolean", 1) is True - - assert conf.parse_environment_variable("some_boolean", "false") is False - assert conf.parse_environment_variable("some_boolean", "False") is False - assert conf.parse_environment_variable("some_boolean", "FALSE") is False - assert conf.parse_environment_variable("some_boolean", "0") is False - assert conf.parse_environment_variable("some_boolean", 0) is False + assert conf._parse_environment_variable("api", "some_boolean", "true") is True + assert conf._parse_environment_variable("api", "some_boolean", "True") is True + assert conf._parse_environment_variable("api", "some_boolean", "TRUE") is True + assert conf._parse_environment_variable("api", "some_boolean", "1") is True + assert conf._parse_environment_variable("api", "some_boolean", 1) is True + + assert conf._parse_environment_variable("api", "some_boolean", "false") is False + assert conf._parse_environment_variable("api", "some_boolean", "False") is False + assert conf._parse_environment_variable("api", "some_boolean", "FALSE") is False + assert conf._parse_environment_variable("api", "some_boolean", "0") is False + assert conf._parse_environment_variable("api", "some_boolean", 0) is False def test_parse_environment_variable_integer(self, monkeypatch): """Tests that integer values can be parsed correctly from environment variables.""" monkeypatch.setattr(conf, "DEFAULT_CONFIG_SPEC", {"api": {"some_integer": (int, 123)}}) - assert conf.parse_environment_variable("some_integer", "123") == 123 + assert conf._parse_environment_variable("api", "some_integer", "123") == 123 DEFAULT_KWARGS = {"hostname": "platform.strawberryfields.ai", "use_ssl": True, "port": 443} @@ -656,8 +600,7 @@ def test_config_created_locally(self, monkeypatch, tmpdir): with monkeypatch.context() as m: m.setattr(os, "getcwd", lambda: tmpdir) m.setattr(conf, "user_config_dir", lambda *args: "NotTheCorrectDir") - m.setattr(conf, "create_config", mock_create_config) - m.setattr(conf, "save_config_to_file", lambda a, b: mock_save_config_file.update(a, b)) + m.setattr("toml.dump", lambda a, b: mock_save_config_file.update(a, b.name)) conf.store_account( authentication_token, filename="config.toml", location="local", **DEFAULT_KWARGS ) @@ -676,8 +619,7 @@ def test_global_config_created(self, monkeypatch, tmpdir): with monkeypatch.context() as m: m.setattr(os, "getcwd", lambda: "NotTheCorrectDir") m.setattr(conf, "user_config_dir", lambda *args: tmpdir) - m.setattr(conf, "create_config", mock_create_config) - m.setattr(conf, "save_config_to_file", lambda a, b: mock_save_config_file.update(a, b)) + m.setattr("toml.dump", lambda a, b: mock_save_config_file.update(a, b.name)) conf.store_account( authentication_token, filename="config.toml", @@ -800,32 +742,3 @@ def test_nested_directory_is_created(self, monkeypatch, tmpdir): filepath = os.path.join(recursive_dir, "config.toml") result = toml.load(filepath) assert result == EXPECTED_CONFIG - - -class TestSaveConfigToFile: - """Tests for the store_account function.""" - - def test_correct(self, tmpdir): - """Test saving a configuration file.""" - filepath = str(tmpdir.join("config.toml")) - - conf.save_config_to_file(OTHER_EXPECTED_CONFIG, filepath) - - result = toml.load(filepath) - assert result == OTHER_EXPECTED_CONFIG - - def test_file_already_existed(self, tmpdir): - """Test saving a configuration file even if the file already - existed.""" - filepath = str(tmpdir.join("config.toml")) - - with open(filepath, "w") as f: - f.write(TEST_FILE) - - result_for_existing_file = toml.load(filepath) - assert result_for_existing_file == EXPECTED_CONFIG - - conf.save_config_to_file(OTHER_EXPECTED_CONFIG, filepath) - - result_for_new_file = toml.load(filepath) - assert result_for_new_file == OTHER_EXPECTED_CONFIG diff --git a/tests/frontend/test_sf_cli.py b/tests/frontend/test_sf_cli.py index 43076ffd3..f8263c929 100644 --- a/tests/frontend/test_sf_cli.py +++ b/tests/frontend/test_sf_cli.py @@ -15,6 +15,7 @@ Unit tests for the Strawberry Fields command line interface. """ # pylint: disable=no-self-use,unused-argument +import copy import os import functools import argparse @@ -160,7 +161,7 @@ def test_configuration_wizard(self, monkeypatch): configuration takes place using the configuration_wizard function.""" with monkeypatch.context() as m: mock_store_account = MockStoreAccount() - m.setattr(cli, "configuration_wizard", lambda: cli.create_config()["api"]) + m.setattr(cli, "configuration_wizard", lambda: cli.DEFAULT_CONFIG["api"]) m.setattr(cli, "store_account", mock_store_account.store_account) args = MockArgs() @@ -190,7 +191,7 @@ def test_configuration_wizard_local(self, monkeypatch): the configuration_wizard function.""" with monkeypatch.context() as m: mock_store_account = MockStoreAccount() - m.setattr(cli, "configuration_wizard", lambda: cli.create_config()["api"]) + m.setattr(cli, "configuration_wizard", lambda: cli.DEFAULT_CONFIG["api"]) m.setattr(cli, "store_account", mock_store_account.store_account) args = MockArgs() @@ -278,7 +279,7 @@ def test_auth_correct(self, monkeypatch): correctly, once the authentication token is passed.""" with monkeypatch.context() as m: auth_prompt = "Please enter the authentication token" - default_config = cli.create_config()["api"] + default_config = copy.deepcopy(cli.DEFAULT_CONFIG["api"]) default_auth = "SomeAuth" default_config['authentication_token'] = default_auth @@ -291,7 +292,7 @@ def test_correct_inputs(self, monkeypatch): with monkeypatch.context() as m: auth_prompt = "Please enter the authentication token" - default_config = cli.create_config()["api"] + default_config = copy.deepcopy(cli.DEFAULT_CONFIG["api"]) default_auth = "SomeAuth" default_config['authentication_token'] = default_auth From 2b1586b3a269273742648f51241b5097a3228692 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 11 Apr 2020 14:37:46 +0930 Subject: [PATCH 2/7] fix docs building --- doc/code/sf_cli.rst | 2 +- doc/code/sf_configuration.rst | 2 +- strawberryfields/__init__.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/doc/code/sf_cli.rst b/doc/code/sf_cli.rst index e0fccb25c..cb569f44d 100644 --- a/doc/code/sf_cli.rst +++ b/doc/code/sf_cli.rst @@ -136,4 +136,4 @@ Code details .. automodapi:: strawberryfields.cli :no-heading: :no-inheritance-diagram: - :skip: store_account, create_config, load, RemoteEngine, Connection, ConfigurationError + :skip: store_account, DEFAULT_CONFIG, load, RemoteEngine, Connection, ConfigurationError, ping diff --git a/doc/code/sf_configuration.rst b/doc/code/sf_configuration.rst index 5d92dcad1..7507df645 100644 --- a/doc/code/sf_configuration.rst +++ b/doc/code/sf_configuration.rst @@ -67,5 +67,5 @@ Functions .. automodapi:: strawberryfields.configuration :no-heading: - :skip: user_config_dir + :skip: user_config_dir, store_account, active_configs, reset_config, create_logger :no-inheritance-diagram: diff --git a/strawberryfields/__init__.py b/strawberryfields/__init__.py index 1ef4e40b3..8e1bba861 100644 --- a/strawberryfields/__init__.py +++ b/strawberryfields/__init__.py @@ -24,7 +24,7 @@ from . import apps from ._version import __version__ from .cli import ping -from .configuration import store_account +from .configuration import store_account, active_configs, reset_config from .engine import Engine, LocalEngine, RemoteEngine from .io import load, save from .parameters import par_funcs as math @@ -39,9 +39,10 @@ "load", "about", "cite", - "math", "ping", "store_account", + "active_configs", + "reset_config" ] From 11fdb1cb5dabfc21831492fbd8fa612f641f3a9c Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 12 Apr 2020 15:56:51 +0930 Subject: [PATCH 3/7] added logger --- doc/code/sf_configuration.rst | 33 ++++++++ strawberryfields/__init__.py | 2 +- strawberryfields/configuration.py | 117 +++++++++++++++++++++------ strawberryfields/logger.py | 37 +++++++-- tests/frontend/test_configuration.py | 15 ++-- tests/frontend/test_logger.py | 4 +- 6 files changed, 171 insertions(+), 37 deletions(-) diff --git a/doc/code/sf_configuration.rst b/doc/code/sf_configuration.rst index 7507df645..8e2dea094 100644 --- a/doc/code/sf_configuration.rst +++ b/doc/code/sf_configuration.rst @@ -33,9 +33,19 @@ and has the following format: use_ssl = true port = 443 + [logging] + # Options for the logger + level = "warning" + logfile = "sf.log" + Configuration options --------------------- +``[api]`` +^^^^^^^^^ + +Settings for the Xanadu cloud platform. + **authentication_token (str)** (*required*) API token for authentication to the Xanadu cloud platform. This is required for submitting remote jobs using :class:`~.RemoteEngine`. @@ -57,6 +67,29 @@ Configuration options *Corresponding environment variable:* ``SF_API_PORT`` +``[logging]`` +^^^^^^^^^^^^^ + +Settings for the Strawberry Fields logger. + +**level (str)** (*optional*) + Specifies the level of information that should be printed to the standard + output. Defaults to ``"info"``, which indicates that all logged details + are displayed as output. + + Other options include ``"error"``, ``"warning"``, ``"info"``, ``"debug"``, + in decreasing levels of verbosity. + + *Corresponding environment variable:* ``SF_LOGGING_LEVEL`` + +**logfile (str)** (*optional*) + The filepath of an output logfile. This may be a relative or an + absolute path. If specified, all logging data is appended to this + file during Strawberry Fields execution. + + *Corresponding environment variable:* ``SF_LOGGING_LOGFILE`` + + Functions --------- diff --git a/strawberryfields/__init__.py b/strawberryfields/__init__.py index 8e1bba861..057ec9581 100644 --- a/strawberryfields/__init__.py +++ b/strawberryfields/__init__.py @@ -42,7 +42,7 @@ "ping", "store_account", "active_configs", - "reset_config" + "reset_config", ] diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index 7e3d38fdb..aafbaa393 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -30,8 +30,27 @@ "hostname": (str, "platform.strawberryfields.ai"), "use_ssl": (bool, True), "port": (int, 443), - } + }, + "logging": {"level": (str, "info"), "logfile": ((str, type(None)), None)}, } +"""dict: Nested dictionary representing the allowed configuration +sections, options, default values, and allowed types for Strawberry +Fields configurations. For each configuration option key, the +corresponding value is a length-2 tuple, containing: + +* A type or tuple of types, representing the allowed type + for that configuration option. + +* The default value for that configuration option. + +.. note:: + + By TOML convention, keys with a default value of ``None`` + will **not** be present in the generated/loaded configuration + file. This is because TOML has no concept of ``NoneType`` or ``Null``, + instead, the non-presence of a key indicates that the configuration + value is not set. +""" class ConfigurationError(Exception): @@ -39,37 +58,66 @@ class ConfigurationError(Exception): def _deep_update(source, overrides): - """Update a nested dictionary.""" + """Recursively update a nested dictionary. + + This function is a generalization of Python's built in + ``dict.update`` method, modified to recursively update + keys with nested dictionaries. + """ for key, value in overrides.items(): if isinstance(value, collections.Mapping) and value: + # Override value is a non-empty dictionary. + # Update the source key with the override dictionary. returned = _deep_update(source.get(key, {}), value) source[key] = returned elif value != {}: + # Override value is not an empty dictionary. source[key] = overrides[key] return source def _generate_config(config_spec, **kwargs): - """Generates a configuration, given a configuration specification - and optional keyword arguments. + """Generates a configuration, given a Strawberry Fields configuration + specification. + + See :attr:`~.DEFAULT_CONFIG_SPEC` for an example of a valid configuration + specification. + + Optional keyword arguments may be provided to override default values + in the cofiguration specification. If the provided override values + do not match the expected type defined in the configuration spec, + a ``ConfigurationError`` is raised. + + **Example** + + >>> _generate_config(DEFAULT_CONFIG_SPEC, api={"port": 54}) + { + "api": { + "authentication_token": "", + "hostname": "platform.strawberryfields.ai", + "use_ssl": True, + "port": 54, + } + } Args: - config_spec (dict): Nested dictionary representing the - configuration specification. Keys in the dictionary - represent allowed configuration keys. Corresponding - values must be a tuple, with the first element representing - the type, and the second representing the default value. + config_spec (dict): nested dictionary representing the + configuration specification Keyword Args: Provided keyword arguments may overwrite default values of - matching (nested) keys. + matching keys. Returns: dict: the default configuration defined by the input config spec + + Raises: + ConfigurationError: if provided keyword argument overrides do not + match the expected type defined in the configuration spec. """ res = {} for k, v in config_spec.items(): - if isinstance(v, tuple) and isinstance(v[0], type): + if isinstance(v, tuple): # config spec value v represents the allowed type and default value if k in kwargs: @@ -82,9 +130,15 @@ def _generate_config(config_spec, **kwargs): ) ) - res[k] = kwargs[k] + if kwargs[k] is not None: + # Only add the key to the configuration object + # if the provided override is not None. + res[k] = kwargs[k] else: - res[k] = v[1] + if v[1] is not None: + # Only add the key to the configuration object + # if the default value is not None. + res[k] = v[1] elif isinstance(v, dict): # config spec value is a dictionary of more options @@ -92,7 +146,7 @@ def _generate_config(config_spec, **kwargs): return res -def load_config(filename="config.toml", **kwargs): +def load_config(filename="config.toml", verbose=True, **kwargs): """Load configuration from keyword arguments, configuration file or environment variables. @@ -107,7 +161,8 @@ def load_config(filename="config.toml", **kwargs): 3. data contained in a configuration file (if exists) Args: - filename (str): the name of the configuration file to look for. + filename (str): the name of the configuration file to look for + verbose (bool): whether or not to log warnings and errors Keyword Args: Additional configuration options are detailed in @@ -123,13 +178,15 @@ def load_config(filename="config.toml", **kwargs): with open(filepath, "r") as f: config = toml.load(f) - if "api" not in config: + if "api" not in config and verbose: # Raise a warning if the configuration doesn't contain # an API section. log = create_logger(__name__) - log.warning('The configuration from the %s file does not contain an "api" section.', filepath) + log.warning( + 'The configuration from the %s file does not contain an "api" section.', filepath + ) - else: + elif verbose: config = {} log = create_logger(__name__) log.warning("No Strawberry Fields configuration file found.") @@ -138,10 +195,8 @@ def load_config(filename="config.toml", **kwargs): update_from_environment_variables(config) # update the configuration from keyword arguments - # NOTE: currently the configuration keyword arguments are specific - # only to the API section. Once we have more configuration sections, - # they will likely need to be passed via separate keyword arguments. - _deep_update(config, {"api": kwargs}) + for config_section, section_options in kwargs.items(): + _deep_update(config, {config_section: section_options}) # generate the configuration object by using the defined # configuration specification at the top of the file @@ -436,10 +491,24 @@ def store_account(authentication_token, filename="config.toml", location="user_c filepath = os.path.join(directory, filename) + config = {} + + # load the existing config if it already exists + if os.path.isfile(filepath): + with open(filepath, "r") as f: + config = toml.load(f) + + # update the loaded configuration file with the specified + # authentication token + kwargs.update({"authentication_token": authentication_token}) + + # update the loaded configuration with any + # provided API options passed as keyword arguments + _deep_update(config, {"api": kwargs}) + # generate the configuration object by using the defined # configuration specification at the top of the file - kwargs.update({"authentication_token": authentication_token}) - config = _generate_config(DEFAULT_CONFIG_SPEC, api=kwargs) + config = _generate_config(DEFAULT_CONFIG_SPEC, **config) with open(filepath, "w") as f: toml.dump(config, f) diff --git a/strawberryfields/logger.py b/strawberryfields/logger.py index 86381d5ca..80762c900 100644 --- a/strawberryfields/logger.py +++ b/strawberryfields/logger.py @@ -75,12 +75,12 @@ def logging_handler_defined(logger): return False -default_handler = logging.StreamHandler(sys.stderr) +output_handler = logging.StreamHandler(sys.stderr) formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -default_handler.setFormatter(formatter) +output_handler.setFormatter(formatter) -def create_logger(name, level=logging.INFO): +def create_logger(name, level=None): """Get the Strawberry Fields module specific logger and configure it if needed. Configuration only takes place if no user configuration was applied to the @@ -108,7 +108,34 @@ def create_logger(name, level=logging.INFO): no_handlers = not logging_handler_defined(logger) if effective_level_inherited and level_not_set and no_handlers: - logger.setLevel(level) - logger.addHandler(default_handler) + # The root logger should pass all log message levels + # to the handlers. + logger.setLevel(logging.DEBUG) + + # Import load_config here to avoid a cyclic import + # (since the configuration module imports the logger). + from strawberryfields.configuration import load_config + + # The load_config function by default logs information. + # We need to turn this off here, since the logger has not + # been created yet. + default_config = load_config(verbose=False) + level = level or getattr(logging, default_config["logging"]["level"].upper()) + + # Attach the standard output logger, + # with the user defined logging level (defaults to INFO) + output_handler.setLevel(level) + logger.addHandler(output_handler) + + if "logfile" in default_config["logging"]: + # Create the file logger + file_handler = logging.FileHandler( + default_config["logging"]["logfile"], disable_existing_loggers=False + ) + file_handler.setFormatter(formatter) + + # file logger should display all log message levels + file_handler.setLevel(logging.DEBUG) + logger.addHandler(file_handler) return logger diff --git a/tests/frontend/test_configuration.py b/tests/frontend/test_configuration.py index 64c82a832..fcd8c3d7a 100644 --- a/tests/frontend/test_configuration.py +++ b/tests/frontend/test_configuration.py @@ -47,7 +47,8 @@ "hostname": "platform.strawberryfields.ai", "use_ssl": True, "port": 443, - } + }, + 'logging': {'level': 'info'} } OTHER_EXPECTED_CONFIG = { @@ -56,7 +57,8 @@ "hostname": "SomeHost", "use_ssl": False, "port": 56, - } + }, + 'logging': {'level': 'info'} } environment_variables = [ @@ -92,9 +94,12 @@ def test_keywords_take_precedence_over_everything(self, monkeypatch, tmpdir): m.setenv("SF_API_PORT", "42") m.setattr(os, "getcwd", lambda: tmpdir) - configuration = conf.load_config( - authentication_token="SomeAuth", hostname="SomeHost", use_ssl=False, port=56 - ) + configuration = conf.load_config(api={ + "authentication_token": "SomeAuth", + "hostname": "SomeHost", + "use_ssl": False, + "port": 56 + }) assert configuration == OTHER_EXPECTED_CONFIG diff --git a/tests/frontend/test_logger.py b/tests/frontend/test_logger.py index 3d2df772c..d0d118c55 100644 --- a/tests/frontend/test_logger.py +++ b/tests/frontend/test_logger.py @@ -52,7 +52,7 @@ import strawberryfields.api.connection as connection import strawberryfields.engine as engine -from strawberryfields.logger import logging_handler_defined, default_handler, create_logger +from strawberryfields.logger import logging_handler_defined, output_handler, create_logger modules_contain_logging = [job, connection, engine] @@ -114,7 +114,7 @@ def test_create_logger(self, module): logger = create_logger(module.__name__) assert logger.level == logging.INFO assert logging_handler_defined(logger) - assert logger.handlers[0] == default_handler + assert logger.handlers[0] == output_handler class TestLoggerIntegration: """Tests that the SF logger integrates well with user defined logging From bee52ad939609e0700c35f7ea949775ba2114d6e Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 12 Apr 2020 21:19:06 +0930 Subject: [PATCH 4/7] logging --- strawberryfields/api/connection.py | 2 ++ strawberryfields/api/job.py | 1 + strawberryfields/configuration.py | 32 +++++++++++++++++++++--------- strawberryfields/engine.py | 5 +++++ strawberryfields/logger.py | 19 +++++------------- 5 files changed, 36 insertions(+), 23 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 1a0e2db90..b066acfb9 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -154,6 +154,8 @@ def create_job(self, target: str, program: Program, run_options: dict = None) -> circuit = bb.serialize() + self.log.debug("Submitting job\n%s", circuit) + path = "/jobs" response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit}) if response.status_code == 201: diff --git a/strawberryfields/api/job.py b/strawberryfields/api/job.py index ef1a063e9..4e3750b79 100644 --- a/strawberryfields/api/job.py +++ b/strawberryfields/api/job.py @@ -129,6 +129,7 @@ def refresh(self): self._status = JobStatus(self._connection.get_job_status(self.id)) if self._status == JobStatus.COMPLETED: self._result = self._connection.get_job_result(self.id) + self.log.info("Job %s is complete", self.id) def cancel(self): """Cancels an open or queued job. diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index aafbaa393..d4de50ff4 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -21,8 +21,6 @@ import toml from appdirs import user_config_dir -from strawberryfields.logger import create_logger - DEFAULT_CONFIG_SPEC = { "api": { @@ -146,7 +144,7 @@ def _generate_config(config_spec, **kwargs): return res -def load_config(filename="config.toml", verbose=True, **kwargs): +def load_config(filename="config.toml", logging=True, **kwargs): """Load configuration from keyword arguments, configuration file or environment variables. @@ -162,7 +160,7 @@ def load_config(filename="config.toml", verbose=True, **kwargs): Args: filename (str): the name of the configuration file to look for - verbose (bool): whether or not to log warnings and errors + logging (bool): whether or not to log details Keyword Args: Additional configuration options are detailed in @@ -173,23 +171,30 @@ def load_config(filename="config.toml", verbose=True, **kwargs): """ filepath = find_config_file(filename=filename) + if logging: + from strawberryfields.logger import create_logger + log = create_logger(__name__) + if filepath is not None: # load the configuration file with open(filepath, "r") as f: config = toml.load(f) - if "api" not in config and verbose: + if logging: + log.debug("Configuration file %s loaded", filepath) + + if "api" not in config and logging: # Raise a warning if the configuration doesn't contain # an API section. - log = create_logger(__name__) log.warning( 'The configuration from the %s file does not contain an "api" section.', filepath ) - elif verbose: + else: config = {} - log = create_logger(__name__) - log.warning("No Strawberry Fields configuration file found.") + + if logging: + log.warning("No Strawberry Fields configuration file found.") # update the configuration from environment variables update_from_environment_variables(config) @@ -201,6 +206,14 @@ def load_config(filename="config.toml", verbose=True, **kwargs): # generate the configuration object by using the defined # configuration specification at the top of the file config = _generate_config(DEFAULT_CONFIG_SPEC, **config) + + # Log the loaded configuration details, masking out the API key. + if logging: + config_details = "Loaded configuration: {}".format(config) + auth_token = config.get("api", {}).get("authentication_token", "") + config_details = config_details.replace(auth_token[5:], "*"*len(auth_token[5:])) + log.debug(config_details) + return config @@ -515,3 +528,4 @@ def store_account(authentication_token, filename="config.toml", location="user_c DEFAULT_CONFIG = _generate_config(DEFAULT_CONFIG_SPEC) +SESSION_CONFIG = load_config(logging=False) diff --git a/strawberryfields/engine.py b/strawberryfields/engine.py index 70f95d32f..e132142b3 100644 --- a/strawberryfields/engine.py +++ b/strawberryfields/engine.py @@ -589,6 +589,11 @@ def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job: # * compiled to a different chip family to the engine target # # In both cases, recompile the program to match the intended target. + self.log.debug( + "Compiling program for target %s with compile options %s", + self.target, + compile_options + ) program = program.compile(self.target, **compile_options) # update the run options if provided diff --git a/strawberryfields/logger.py b/strawberryfields/logger.py index 80762c900..7bc421537 100644 --- a/strawberryfields/logger.py +++ b/strawberryfields/logger.py @@ -49,6 +49,8 @@ import logging import sys +from strawberryfields.configuration import SESSION_CONFIG + def logging_handler_defined(logger): """Checks if the logger or any of its ancestors has a handler defined. @@ -111,27 +113,16 @@ def create_logger(name, level=None): # The root logger should pass all log message levels # to the handlers. logger.setLevel(logging.DEBUG) - - # Import load_config here to avoid a cyclic import - # (since the configuration module imports the logger). - from strawberryfields.configuration import load_config - - # The load_config function by default logs information. - # We need to turn this off here, since the logger has not - # been created yet. - default_config = load_config(verbose=False) - level = level or getattr(logging, default_config["logging"]["level"].upper()) + level = level or getattr(logging, SESSION_CONFIG["logging"]["level"].upper()) # Attach the standard output logger, # with the user defined logging level (defaults to INFO) output_handler.setLevel(level) logger.addHandler(output_handler) - if "logfile" in default_config["logging"]: + if "logfile" in SESSION_CONFIG["logging"]: # Create the file logger - file_handler = logging.FileHandler( - default_config["logging"]["logfile"], disable_existing_loggers=False - ) + file_handler = logging.FileHandler(SESSION_CONFIG["logging"]["logfile"]) file_handler.setFormatter(formatter) # file logger should display all log message levels From 86a35c520abdc99ca0f412a7992aadf52a0d1769 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 12 Apr 2020 23:24:41 +0930 Subject: [PATCH 5/7] blacking --- strawberryfields/configuration.py | 3 ++- strawberryfields/engine.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index d4de50ff4..14f91c29f 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -173,6 +173,7 @@ def load_config(filename="config.toml", logging=True, **kwargs): if logging: from strawberryfields.logger import create_logger + log = create_logger(__name__) if filepath is not None: @@ -211,7 +212,7 @@ def load_config(filename="config.toml", logging=True, **kwargs): if logging: config_details = "Loaded configuration: {}".format(config) auth_token = config.get("api", {}).get("authentication_token", "") - config_details = config_details.replace(auth_token[5:], "*"*len(auth_token[5:])) + config_details = config_details.replace(auth_token[5:], "*" * len(auth_token[5:])) log.debug(config_details) return config diff --git a/strawberryfields/engine.py b/strawberryfields/engine.py index e132142b3..67c59a17f 100644 --- a/strawberryfields/engine.py +++ b/strawberryfields/engine.py @@ -592,7 +592,7 @@ def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job: self.log.debug( "Compiling program for target %s with compile options %s", self.target, - compile_options + compile_options, ) program = program.compile(self.target, **compile_options) From 98c7a73f71adca4fdf93320c94c262da71f708ac Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 12 Apr 2020 23:27:06 +0930 Subject: [PATCH 6/7] add logging function --- strawberryfields/configuration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index 14f91c29f..44e10a673 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -172,7 +172,9 @@ def load_config(filename="config.toml", logging=True, **kwargs): filepath = find_config_file(filename=filename) if logging: - from strawberryfields.logger import create_logger + # We import the create_logger function only if logging + # has been requested, to avoid circular imports. + from strawberryfields.logger import create_logger #pylint: disable=import-outside-toplevel log = create_logger(__name__) From 84bc9d9d67f4641c6801c28ca4ee36331b8d48ff Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 14 Apr 2020 14:37:29 +0930 Subject: [PATCH 7/7] Apply suggestions from code review Co-Authored-By: antalszava Co-Authored-By: Theodor --- strawberryfields/configuration.py | 5 +++-- tests/frontend/test_configuration.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/strawberryfields/configuration.py b/strawberryfields/configuration.py index 44e10a673..371519c40 100644 --- a/strawberryfields/configuration.py +++ b/strawberryfields/configuration.py @@ -95,7 +95,8 @@ def _generate_config(config_spec, **kwargs): "hostname": "platform.strawberryfields.ai", "use_ssl": True, "port": 54, - } + }, + 'logging': {'level': 'info'} } Args: @@ -139,7 +140,7 @@ def _generate_config(config_spec, **kwargs): res[k] = v[1] elif isinstance(v, dict): - # config spec value is a dictionary of more options + # config spec value is a configuration section res[k] = _generate_config(v, **kwargs.get(k, {})) return res diff --git a/tests/frontend/test_configuration.py b/tests/frontend/test_configuration.py index fcd8c3d7a..ccc34555d 100644 --- a/tests/frontend/test_configuration.py +++ b/tests/frontend/test_configuration.py @@ -58,7 +58,7 @@ "use_ssl": False, "port": 56, }, - 'logging': {'level': 'info'} + "logging": {"level": "info"} } environment_variables = [