diff --git a/app/Dockerfile b/app/Dockerfile index 9ec44fbb..19b1e1d9 100644 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -28,7 +28,7 @@ COPY . /srv # Set the host to 0.0.0.0 to make the server available external # to the Docker container that it's running in. -ENV HOST=0.0.0.0 +ENV APP_HOST=0.0.0.0 # Install application dependencies. # https://python-poetry.org/docs/basic-usage/#installing-dependencies diff --git a/app/Makefile b/app/Makefile index 40e328d8..bb016627 100644 --- a/app/Makefile +++ b/app/Makefile @@ -16,9 +16,9 @@ DECODE_LOG := 2>&1 | python3 -u src/logging/util/decodelog.py # TODO - when CI gets hooked up, actually test this. ifdef CI DOCKER_EXEC_ARGS := -T -e CI -e PYTEST_ADDOPTS="--color=yes" - FLAKE8_FORMAT := '::warning file=src/%(path)s,line=%(row)d,col=%(col)d::%(path)s:%(row)d:%(col)d: %(code)s %(text)s' + FLAKE8_FORMAT := '::warning file=app/%(path)s,line=%(row)d,col=%(col)d::%(path)s:%(row)d:%(col)d: %(code)s %(text)s' MYPY_FLAGS := --no-pretty - MYPY_POSTPROC := | perl -pe "s/^(.+):(\d+):(\d+): error: (.*)/::warning file=src\/\1,line=\2,col=\3::\4/" + MYPY_POSTPROC := | perl -pe "s/^(.+):(\d+):(\d+): error: (.*)/::warning file=app\/\1,line=\2,col=\3::\4/" else FLAKE8_FORMAT := default endif @@ -33,7 +33,7 @@ else PY_RUN_CMD := docker-compose run $(DOCKER_EXEC_ARGS) --rm $(APP_NAME) poetry run endif -FLASK_CMD := $(PY_RUN_CMD) flask --env-file local.env +FLASK_CMD := $(PY_RUN_CMD) flask --app=src.__main__:main ################################################## # Local Development Environment Setup diff --git a/app/local.env b/app/local.env deleted file mode 100644 index 04c30a3a..00000000 --- a/app/local.env +++ /dev/null @@ -1,76 +0,0 @@ -# Local environment variables -# Used by docker-compose and it can be loaded -# by calling load_local_env_vars() from app/src/util/local.py - -ENVIRONMENT=local -PORT=8080 - -# Python path needs to be specified -# for pytest to find the implementation code -PYTHONPATH=/app/ - -# PY_RUN_APPROACH=python OR docker -# Set this in your environment -# to modify how the Makefile runs -# commands that can run in or out -# of the Docker container - defaults to outside - -FLASK_APP=src.app:create_app - -############################ -# Logging -############################ - -# Can be "human-readable" OR "json" -LOG_FORMAT=human-readable - -# Set log level. Valid values are DEBUG, INFO, WARNING, CRITICAL -# LOG_LEVEL=INFO - -# Enable/disable audit logging. Valid values are TRUE, FALSE -LOG_ENABLE_AUDIT=FALSE - -# Change the message length for the human readable formatter -# LOG_HUMAN_READABLE_FORMATTER__MESSAGE_WIDTH=50 - -############################ -# Authentication -############################ -# The auth token used by the local endpoints -API_AUTH_TOKEN=LOCAL_AUTH_12345678 - -############################ -# DB Environment Variables -############################ -POSTGRES_DB=main-db -POSTGRES_USER=local_db_user -POSTGRES_PASSWORD=secret123 - -# Note that this is only used when running -# commands outside of the Docker container -# and is overriden when running inside by the -# value specified in the docker-compose file -DB_HOST=localhost - -# When an error occurs with a SQL query, -# whether or not to hide the parameters which -# could contain sensitive information. -HIDE_SQL_PARAMETER_LOGS=TRUE - -############################ -# AWS Defaults -############################ -# For these secret access keys, don't -# add them to this file to avoid mistakenly -# committing them. Set these in your shell -# by doing `export AWS_ACCESS_KEY_ID=whatever` -AWS_ACCESS_KEY_ID=DO_NOT_SET_HERE -AWS_SECRET_ACCESS_KEY=DO_NOT_SET_HERE -# These next two are commented out as we -# don't have configuration for individuals -# to use these at the moment and boto3 -# tries to use them first before the keys above. -#AWS_SECURITY_TOKEN=DO_NOT_SET_HERE -#AWS_SESSION_TOKEN=DO_NOT_SET_HERE - -AWS_DEFAULT_REGION=us-east-1 diff --git a/app/src/__main__.py b/app/src/__main__.py index 725a6f2d..8e21e13d 100644 --- a/app/src/__main__.py +++ b/app/src/__main__.py @@ -6,39 +6,59 @@ # https://docs.python.org/3/library/__main__.html import logging +import os + +from flask import Flask import src.app +import src.config +import src.config.load import src.logging -from src.app_config import AppConfig -from src.util.local import load_local_env_vars logger = logging.getLogger(__package__) -def main() -> None: - load_local_env_vars() - app_config = AppConfig() +def load_config() -> src.config.RootConfig: + return src.config.load.load( + environment_name=os.getenv("ENVIRONMENT", "local"), environ=os.environ + ) + - app = src.app.create_app() +def main() -> Flask: + config = load_config() + app = src.app.create_app(config) + logger.info("loaded configuration", extra={"config": config}) - environment = app_config.environment + environment = config.app.environment # When running in a container, the host needs to be set to 0.0.0.0 so that the app can be # accessed from outside the container. See Dockerfile - host = app_config.host - port = app_config.port + host = config.app.host + port = config.app.port + + if __name__ != "__main__": + return app logger.info( "Running API Application", extra={"environment": environment, "host": host, "port": port} ) - if app_config.environment == "local": + if config.app.environment == "local": # If python files are changed, the app will auto-reload # Note this doesn't have the OpenAPI yaml file configured at the moment - app.run(host=host, port=port, use_reloader=True, reloader_type="stat") + app.run( + host=host, + port=port, + debug=True, # nosec B201 + load_dotenv=False, + use_reloader=True, + reloader_type="stat", + ) else: # Don't enable the reloader if non-local - app.run(host=host, port=port) + app.run(host=host, port=port, load_dotenv=False) + + return app main() diff --git a/app/src/adapters/db/clients/postgres_client.py b/app/src/adapters/db/clients/postgres_client.py index d31cdea3..fd5152f5 100644 --- a/app/src/adapters/db/clients/postgres_client.py +++ b/app/src/adapters/db/clients/postgres_client.py @@ -1,5 +1,4 @@ import logging -import os import urllib.parse from typing import Any @@ -8,7 +7,7 @@ import sqlalchemy.pool as pool from src.adapters.db.client import DBClient -from src.adapters.db.clients.postgres_config import PostgresDBConfig, get_db_config +from src.adapters.db.clients.postgres_config import PostgresDBConfig logger = logging.getLogger(__name__) @@ -19,9 +18,7 @@ class PostgresDBClient(DBClient): as configured by parameters passed in from the db_config """ - def __init__(self, db_config: PostgresDBConfig | None = None) -> None: - if not db_config: - db_config = get_db_config() + def __init__(self, db_config: PostgresDBConfig) -> None: self._engine = self._configure_engine(db_config) if db_config.check_connection_on_init: @@ -80,23 +77,15 @@ def check_db_connection(self) -> None: def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]: - connect_args = {} - environment = os.getenv("ENVIRONMENT") - if not environment: - raise Exception("ENVIRONMENT is not set") - - if environment != "local": - connect_args["sslmode"] = "require" - return dict( host=db_config.host, dbname=db_config.name, user=db_config.username, - password=db_config.password, + password=db_config.password.get_secret_value() if db_config.password else None, port=db_config.port, options=f"-c search_path={db_config.db_schema}", connect_timeout=3, - **connect_args, + sslmode=db_config.sslmode, ) @@ -109,7 +98,7 @@ def make_connection_uri(config: PostgresDBConfig) -> str: host = config.host db_name = config.name username = config.username - password = urllib.parse.quote(config.password) if config.password else None + password = urllib.parse.quote(config.password.get_secret_value()) if config.password else None schema = config.db_schema port = config.port @@ -122,10 +111,7 @@ def make_connection_uri(config: PostgresDBConfig) -> str: elif password: netloc_parts.append(f":{password}@") - netloc_parts.append(host) - - if port: - netloc_parts.append(f":{port}") + netloc_parts.append(f"{host}:{port}") netloc = "".join(netloc_parts) diff --git a/app/src/adapters/db/clients/postgres_config.py b/app/src/adapters/db/clients/postgres_config.py index d78b2a10..8fca8816 100644 --- a/app/src/adapters/db/clients/postgres_config.py +++ b/app/src/adapters/db/clients/postgres_config.py @@ -1,6 +1,7 @@ import logging from typing import Optional +import pydantic from pydantic import Field from src.util.env_config import PydanticBaseEnvConfig @@ -13,26 +14,8 @@ class PostgresDBConfig(PydanticBaseEnvConfig): host: str = Field("localhost", env="DB_HOST") name: str = Field("main-db", env="POSTGRES_DB") username: str = Field("local_db_user", env="POSTGRES_USER") - password: Optional[str] = Field(..., env="POSTGRES_PASSWORD") + password: Optional[pydantic.types.SecretStr] = Field(None, env="POSTGRES_PASSWORD") db_schema: str = Field("public", env="DB_SCHEMA") - port: str = Field("5432", env="DB_PORT") + port: int = Field(5432, env="DB_PORT") hide_sql_parameter_logs: bool = Field(True, env="HIDE_SQL_PARAMETER_LOGS") - - -def get_db_config() -> PostgresDBConfig: - db_config = PostgresDBConfig() - - logger.info( - "Constructed database configuration", - extra={ - "host": db_config.host, - "dbname": db_config.name, - "username": db_config.username, - "password": "***" if db_config.password is not None else None, - "db_schema": db_config.db_schema, - "port": db_config.port, - "hide_sql_parameter_logs": db_config.hide_sql_parameter_logs, - }, - ) - - return db_config + sslmode: str = "require" diff --git a/app/src/app.py b/app/src/app.py index e8857183..6aab7e1c 100644 --- a/app/src/app.py +++ b/app/src/app.py @@ -14,17 +14,18 @@ from src.api.schemas import response_schema from src.api.users import user_blueprint from src.auth.api_key_auth import User, get_app_security_scheme +from src.config import RootConfig logger = logging.getLogger(__name__) -def create_app() -> APIFlask: +def create_app(config: RootConfig) -> APIFlask: app = APIFlask(__name__) - src.logging.init(__package__) + src.logging.init(__package__, config.logging) flask_logger.init_app(logging.root, app) - db_client = db.PostgresDBClient() + db_client = db.PostgresDBClient(config.database) flask_db.register_db_client(db_client, app) configure_app(app) diff --git a/app/src/app_config.py b/app/src/app_config.py index 44159955..f9e1ec8a 100644 --- a/app/src/app_config.py +++ b/app/src/app_config.py @@ -2,7 +2,7 @@ class AppConfig(PydanticBaseEnvConfig): - environment: str + environment: str = "unknown" # Set HOST to 127.0.0.1 by default to avoid other machines on the network # from accessing the application. This is especially important if you are diff --git a/app/src/config/__init__.py b/app/src/config/__init__.py new file mode 100644 index 00000000..c4138c07 --- /dev/null +++ b/app/src/config/__init__.py @@ -0,0 +1,14 @@ +# +# Multi-environment configuration expressed in Python. +# + +from src.adapters.db.clients.postgres_config import PostgresDBConfig +from src.app_config import AppConfig +from src.logging.config import LoggingConfig +from src.util.env_config import PydanticBaseEnvConfig + + +class RootConfig(PydanticBaseEnvConfig): + app: AppConfig + database: PostgresDBConfig + logging: LoggingConfig diff --git a/app/src/config/default.py b/app/src/config/default.py new file mode 100644 index 00000000..3878b8d2 --- /dev/null +++ b/app/src/config/default.py @@ -0,0 +1,26 @@ +# +# Default configuration. +# +# This is the base layer of configuration. It is used if an environment does not override a value. +# Each environment may override individual values (see local.py, dev.py, prod.py, etc.). +# +# This configuration is also used when running tests (individual test cases may have code to use +# different configuration). +# + +from src.adapters.db.clients.postgres_config import PostgresDBConfig +from src.app_config import AppConfig +from src.config import RootConfig +from src.logging.config import LoggingConfig, LoggingFormat + + +def default_config() -> RootConfig: + return RootConfig( + app=AppConfig(), + database=PostgresDBConfig(), + logging=LoggingConfig( + format=LoggingFormat.json, + level="INFO", + enable_audit=True, + ), + ) diff --git a/app/src/config/env/dev.py b/app/src/config/env/dev.py new file mode 100644 index 00000000..1bece4be --- /dev/null +++ b/app/src/config/env/dev.py @@ -0,0 +1,9 @@ +# +# Configuration for dev environments. +# +# This file only contains overrides (differences) from the defaults in default.py. +# + +from .. import default + +config = default.default_config() diff --git a/app/src/config/env/local.py b/app/src/config/env/local.py new file mode 100644 index 00000000..ce3c82fb --- /dev/null +++ b/app/src/config/env/local.py @@ -0,0 +1,19 @@ +# +# Configuration for local development environments. +# +# This file only contains overrides (differences) from the defaults in default.py. +# + +import pydantic.types + +from src.logging.config import LoggingFormat + +from .. import default + +config = default.default_config() + +config.database.password = pydantic.types.SecretStr("secret123") +config.database.hide_sql_parameter_logs = False +config.database.sslmode = "prefer" +config.logging.format = LoggingFormat.human_readable +config.logging.enable_audit = False diff --git a/app/src/config/env/local_override_example.py b/app/src/config/env/local_override_example.py new file mode 100644 index 00000000..256b2ba1 --- /dev/null +++ b/app/src/config/env/local_override_example.py @@ -0,0 +1,12 @@ +# +# Local overrides for local development environments. +# +# This file allows overrides to be set that you never want to be committed. +# +# To use this, copy to `local_override.py` and edit below. +# + +from .local import config + +# Example override: +config.logging.enable_audit = True diff --git a/app/src/config/env/prod.py b/app/src/config/env/prod.py new file mode 100644 index 00000000..6dd8e827 --- /dev/null +++ b/app/src/config/env/prod.py @@ -0,0 +1,9 @@ +# +# Configuration for prod environment. +# +# This file only contains overrides (differences) from the defaults in default.py. +# + +from .. import default + +config = default.default_config() diff --git a/app/src/config/env/staging.py b/app/src/config/env/staging.py new file mode 100644 index 00000000..44db8eb4 --- /dev/null +++ b/app/src/config/env/staging.py @@ -0,0 +1,9 @@ +# +# Configuration for staging environment. +# +# This file only contains overrides (differences) from the defaults in default.py. +# + +from .. import default + +config = default.default_config() diff --git a/app/src/config/load.py b/app/src/config/load.py new file mode 100644 index 00000000..e0b1e2af --- /dev/null +++ b/app/src/config/load.py @@ -0,0 +1,43 @@ +# +# Multi-environment configuration expressed in Python. +# + +import importlib +import logging +import pathlib +from typing import Mapping, Optional + +from src.config import RootConfig + +logger = logging.getLogger(__name__) + + +def load(environment_name: str, environ: Optional[Mapping[str, str]] = None) -> RootConfig: + """Load the configuration for the given environment name.""" + logger.debug("loading configuration", extra={"environment": environment_name}) + module = importlib.import_module(name=".env." + environment_name, package=__package__) + config = module.config.copy(deep=True) + + if environment_name == "local": + # Load overrides from local_override.py in the same directory, if it exists. + try: + module = importlib.import_module(name=".env.local_override", package=__package__) + config = module.config.copy(deep=True) + except ImportError: + pass + + if environ: + config.override_from_environment(environ) + config.app.environment = environment_name + + return config + + +def load_all() -> dict[str, RootConfig]: + """Load all environment configurations, to ensure they are valid. Used in tests.""" + directory = pathlib.Path(__file__).parent / "env" + return { + item.stem: load(str(item.stem)) + for item in directory.glob("*.py") + if "override" not in item.stem + } diff --git a/app/src/db/migrations/env.py b/app/src/db/migrations/env.py index a45a584f..0e53e37f 100644 --- a/app/src/db/migrations/env.py +++ b/app/src/db/migrations/env.py @@ -1,4 +1,5 @@ import logging +import os import sys from typing import Any @@ -10,14 +11,9 @@ # See database migrations section in `./database/database-migrations.md` for details about running migrations. sys.path.insert(0, ".") # noqa: E402 -# Load env vars before anything further -from src.util.local import load_local_env_vars # noqa: E402 isort:skip - -load_local_env_vars() - from src.adapters.db.clients.postgres_client import make_connection_uri # noqa: E402 isort:skip -from src.adapters.db.clients.postgres_config import get_db_config # noqa: E402 isort:skip from src.db.models import metadata # noqa: E402 isort:skip +import src.config.load # noqa: E402 isort:skip import src.logging # noqa: E402 isort:skip # this is the Alembic Config object, which provides @@ -26,11 +22,15 @@ logger = logging.getLogger("migrations") +root_config = src.config.load.load( + environment_name=os.getenv("ENVIRONMENT", "local"), environ=os.environ +) + # Initialize logging -with src.logging.init("migrations"): +with src.logging.init("migrations", root_config.logging): if not config.get_main_option("sqlalchemy.url"): - uri = make_connection_uri(get_db_config()) + uri = make_connection_uri(root_config.database) # Escape percentage signs in the URI. # https://alembic.sqlalchemy.org/en/latest/api/config.html#alembic.config.Config.set_main_option diff --git a/app/src/logging/__init__.py b/app/src/logging/__init__.py index efe17b42..5f8353a8 100644 --- a/app/src/logging/__init__.py +++ b/app/src/logging/__init__.py @@ -28,5 +28,5 @@ import src.logging.config as config -def init(program_name: str) -> config.LoggingContext: - return config.LoggingContext(program_name) +def init(program_name: str, logging_config: config.LoggingConfig) -> config.LoggingContext: + return config.LoggingContext(program_name, logging_config) diff --git a/app/src/logging/config.py b/app/src/logging/config.py index 7108dc93..36dc502c 100644 --- a/app/src/logging/config.py +++ b/app/src/logging/config.py @@ -1,3 +1,4 @@ +import enum import logging import os import platform @@ -5,6 +6,8 @@ import sys from typing import Any, ContextManager, cast +import pydantic + import src.logging.audit import src.logging.formatters as formatters import src.logging.pii as pii @@ -15,15 +18,27 @@ _original_argv = tuple(sys.argv) +class LoggingFormat(str, enum.Enum): + json = "json" + human_readable = "human_readable" + + class HumanReadableFormatterConfig(PydanticBaseEnvConfig): message_width: int = formatters.HUMAN_READABLE_FORMATTER_DEFAULT_MESSAGE_WIDTH class LoggingConfig(PydanticBaseEnvConfig): - format = "json" - level = "INFO" - enable_audit = True - human_readable_formatter = HumanReadableFormatterConfig() + format: LoggingFormat = LoggingFormat.json + level: str = "INFO" + enable_audit: bool = True + human_readable_formatter: HumanReadableFormatterConfig = HumanReadableFormatterConfig() + + @pydantic.validator("level") + def valid_level(cls, v: str) -> str: # noqa: B902 + value = logging.getLevelName(v) + if not isinstance(value, int): + raise ValueError("invalid logging level %s" % v) + return v class Config: env_prefix = "log_" @@ -58,8 +73,8 @@ class LoggingContext(ContextManager[None]): and calling this multiple times before exit would result in duplicate logs. """ - def __init__(self, program_name: str) -> None: - self._configure_logging() + def __init__(self, program_name: str, config: LoggingConfig) -> None: + self._configure_logging(config) log_program_info(program_name) def __enter__(self) -> None: @@ -72,14 +87,13 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # of those tests. logging.root.removeHandler(self.console_handler) - def _configure_logging(self) -> None: + def _configure_logging(self, config: LoggingConfig) -> None: """Configure logging for the application. Configures the root module logger to log to stdout. Adds a PII mask filter to the root logger. Also configures log levels third party packages. """ - config = LoggingConfig() # Loggers can be configured using config functions defined # in logging.config or by directly making calls to the main API @@ -112,7 +126,7 @@ def get_formatter(config: LoggingConfig) -> logging.Formatter: The formatter is determined by the environment variable LOG_FORMAT. If the environment variable is not set, the JSON formatter is used by default. """ - if config.format == "human-readable": + if config.format == LoggingFormat.human_readable: return get_human_readable_formatter(config.human_readable_formatter) return formatters.JsonFormatter() diff --git a/app/src/util/env_config.py b/app/src/util/env_config.py index cd9c4dc3..f225effc 100644 --- a/app/src/util/env_config.py +++ b/app/src/util/env_config.py @@ -1,16 +1,37 @@ -import os +import logging +from typing import Mapping -from pydantic import BaseSettings +from pydantic import BaseModel -import src +logger = logging.getLogger(__name__) -env_file = os.path.join( - os.path.dirname(os.path.dirname(src.__file__)), - "config", - "%s.env" % os.getenv("ENVIRONMENT", "local"), -) +class PydanticBaseEnvConfig(BaseModel): + """Base class for application configuration. + + Similar to Pydantic's BaseSettings class, but we implement our own method to override from the + environment so that it can be run later, after an instance was constructed.""" + + meta_overridden: list = [] -class PydanticBaseEnvConfig(BaseSettings): class Config: - env_file = env_file + validate_assignment = True + + def override_from_environment(self, environ: Mapping[str, str], prefix: str = "") -> None: + """Recursively override field values from the given environment variable mapping.""" + for name, field in self.__fields__.items(): + value = getattr(self, name) + if isinstance(value, BaseModel): + # Nested models must be instances of this class too. + if not isinstance(value, PydanticBaseEnvConfig): + raise TypeError("nested models must be instances of PydanticBaseEnvConfig") + value.override_from_environment(environ, prefix=name + "_") + continue + + env_var_name = field.field_info.extra.get("env", prefix + name) + for key in (env_var_name, env_var_name.lower(), env_var_name.upper()): + if key in environ: + # logging.debug("override from environment", extra={"key": key}) + setattr(self, field.name, environ[key]) + self.meta_overridden.append((field.name, key)) + break diff --git a/app/src/util/local.py b/app/src/util/local.py deleted file mode 100644 index 0f1e0a29..00000000 --- a/app/src/util/local.py +++ /dev/null @@ -1,22 +0,0 @@ -import os - -from dotenv import load_dotenv - - -def load_local_env_vars(env_file: str = "local.env") -> None: - """ - Load environment variables from the local.env so - that they can be fetched with `os.getenv()` or with - other utils that pull env vars. - - https://pypi.org/project/python-dotenv/ - - NOTE: any existing env vars will not be overriden by this - """ - environment = os.getenv("ENVIRONMENT", None) - - # If the environment is explicitly local or undefined - # we'll use the dotenv file, otherwise we'll skip - # Should never run if not local development - if environment is None or environment == "local": - load_dotenv(env_file) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 9a9ed6dc..2078353d 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,47 +1,25 @@ import logging +import os +import uuid import _pytest.monkeypatch import boto3 -import flask import flask.testing import moto import pytest import src.adapters.db as db import src.app as app_entry +import src.config +import src.config.load import tests.src.db.models.factories as factories +from src.adapters.db.clients.postgres_config import PostgresDBConfig from src.db import models -from src.util.local import load_local_env_vars from tests.lib import db_testing logger = logging.getLogger(__name__) -@pytest.fixture(scope="session", autouse=True) -def env_vars(): - """ - Default environment variables for tests to be - based on the local.env file. These get set once - before all tests run. As "session" is the highest - scope, this will run before any other explicit fixtures - in a test. - - See: https://docs.pytest.org/en/6.2.x/fixture.html#autouse-order - - To set a different environment variable for a test, - use the monkeypatch fixture, for example: - - ```py - def test_example(monkeypatch): - monkeypatch.setenv("LOG_LEVEL", "debug") - ``` - - Several monkeypatch fixtures exists below for different - scope levels. - """ - load_local_env_vars() - - #################### # Test DB session #################### @@ -69,7 +47,26 @@ def monkeypatch_module(): @pytest.fixture(scope="session") -def db_client(monkeypatch_session) -> db.DBClient: +def config() -> src.config.RootConfig: + schema_name = f"test_schema_{uuid.uuid4().int}" + + config = src.config.load.load("local") + config.database.db_schema = schema_name + + # Allow host to be overridden when running in docker-compose. + if "DB_HOST" in os.environ: + config.database.host = os.environ["DB_HOST"] + + return config + + +@pytest.fixture(scope="session") +def db_config(config) -> PostgresDBConfig: + return config.database + + +@pytest.fixture(scope="session") +def db_client(monkeypatch_session, db_config) -> db.DBClient: """ Creates an isolated database for the test session. @@ -79,7 +76,7 @@ def db_client(monkeypatch_session) -> db.DBClient: after the test suite session completes. """ - with db_testing.create_isolated_db(monkeypatch_session) as db_client: + with db_testing.create_isolated_db(db_config) as db_client: models.metadata.create_all(bind=db_client.get_connection()) yield db_client @@ -114,8 +111,8 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session: # Make app session scoped so the database connection pool is only created once # for the test session. This speeds up the tests. @pytest.fixture(scope="session") -def app(db_client) -> flask.Flask: - return app_entry.create_app() +def app(db_client, config) -> flask.Flask: + return app_entry.create_app(config) @pytest.fixture diff --git a/app/tests/lib/db_testing.py b/app/tests/lib/db_testing.py index 13bd25c6..b102623e 100644 --- a/app/tests/lib/db_testing.py +++ b/app/tests/lib/db_testing.py @@ -1,44 +1,34 @@ """Helper functions for testing database code.""" import contextlib import logging -import uuid import src.adapters.db as db -from src.adapters.db.clients.postgres_config import get_db_config +from src.adapters.db.clients.postgres_config import PostgresDBConfig logger = logging.getLogger(__name__) @contextlib.contextmanager -def create_isolated_db(monkeypatch) -> db.DBClient: +def create_isolated_db(db_config: PostgresDBConfig) -> db.DBClient: """ Creates a temporary PostgreSQL schema and creates a database engine that connects to that schema. Drops the schema after the context manager exits. """ - schema_name = f"test_schema_{uuid.uuid4().int}" - monkeypatch.setenv("DB_SCHEMA", schema_name) - monkeypatch.setenv("POSTGRES_DB", "main-db") - monkeypatch.setenv("POSTGRES_USER", "local_db_user") - monkeypatch.setenv("POSTGRES_PASSWORD", "secret123") - monkeypatch.setenv("ENVIRONMENT", "local") - monkeypatch.setenv("DB_CHECK_CONNECTION_ON_INIT", "False") # To improve test performance, don't check the database connection # when initializing the DB client. - db_client = db.PostgresDBClient() + db_client = db.PostgresDBClient(db_config) with db_client.get_connection() as conn: - _create_schema(conn, schema_name) + _create_schema(conn, db_config.db_schema, db_config.username) try: yield db_client finally: - _drop_schema(conn, schema_name) + _drop_schema(conn, db_config.db_schema) -def _create_schema(conn: db.Connection, schema_name: str): +def _create_schema(conn: db.Connection, schema_name: str, db_test_user: str): """Create a database schema.""" - db_test_user = get_db_config().username - conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name} AUTHORIZATION {db_test_user};") logger.info("create schema %s", schema_name) diff --git a/app/tests/src/adapters/db/clients/test_postgres_client.py b/app/tests/src/adapters/db/clients/test_postgres_client.py index f04c0e97..5735aa2a 100644 --- a/app/tests/src/adapters/db/clients/test_postgres_client.py +++ b/app/tests/src/adapters/db/clients/test_postgres_client.py @@ -2,13 +2,14 @@ from itertools import product import pytest +from pydantic.types import SecretStr from src.adapters.db.clients.postgres_client import ( get_connection_parameters, make_connection_uri, verify_ssl, ) -from src.adapters.db.clients.postgres_config import PostgresDBConfig, get_db_config +from src.adapters.db.clients.postgres_config import PostgresDBConfig class DummyConnectionInfo: @@ -47,16 +48,16 @@ def test_verify_ssl_not_in_use(caplog): "username_password_port,expected", zip( # Test all combinations of username, password, and port - product(["testuser", ""], ["testpass", None], ["5432", ""]), + product(["testuser", ""], ["testpass", None], [5432, 5433]), [ "postgresql://testuser:testpass@localhost:5432/dbname?options=-csearch_path=public", - "postgresql://testuser:testpass@localhost/dbname?options=-csearch_path=public", + "postgresql://testuser:testpass@localhost:5433/dbname?options=-csearch_path=public", "postgresql://testuser@localhost:5432/dbname?options=-csearch_path=public", - "postgresql://testuser@localhost/dbname?options=-csearch_path=public", + "postgresql://testuser@localhost:5433/dbname?options=-csearch_path=public", "postgresql://:testpass@localhost:5432/dbname?options=-csearch_path=public", - "postgresql://:testpass@localhost/dbname?options=-csearch_path=public", + "postgresql://:testpass@localhost:5433/dbname?options=-csearch_path=public", "postgresql://localhost:5432/dbname?options=-csearch_path=public", - "postgresql://localhost/dbname?options=-csearch_path=public", + "postgresql://localhost:5433/dbname?options=-csearch_path=public", ], ), ) @@ -77,23 +78,15 @@ def test_make_connection_uri(username_password_port, expected): ) -def test_get_connection_parameters_require_environment(monkeypatch: pytest.MonkeyPatch): - monkeypatch.delenv("ENVIRONMENT") - db_config = get_db_config() - with pytest.raises(Exception, match="ENVIRONMENT is not set"): - get_connection_parameters(db_config) - - -def test_get_connection_parameters(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("ENVIRONMENT", "production") - db_config = get_db_config() +def test_get_connection_parameters(): + db_config = PostgresDBConfig(host="test1", password=SecretStr("test_password_123")) conn_params = get_connection_parameters(db_config) assert conn_params == dict( - host=db_config.host, + host="test1", dbname=db_config.name, user=db_config.username, - password=db_config.password, + password="test_password_123", port=db_config.port, options=f"-c search_path={db_config.db_schema}", connect_timeout=3, diff --git a/app/tests/src/adapters/db/test_db.py b/app/tests/src/adapters/db/test_db.py index f3ddc99a..9b602727 100644 --- a/app/tests/src/adapters/db/test_db.py +++ b/app/tests/src/adapters/db/test_db.py @@ -5,19 +5,20 @@ def test_db_connection(db_client): - db_client = db.PostgresDBClient() + # db_client = db.PostgresDBClient() with db_client.get_connection() as conn: assert conn.scalar(text("SELECT 1")) == 1 -def test_check_db_connection(caplog, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("DB_CHECK_CONNECTION_ON_INIT", "True") - db.PostgresDBClient() +def test_check_db_connection(caplog, monkeypatch: pytest.MonkeyPatch, db_config): + db_config = db_config.copy() + db_config.check_connection_on_init = True + db.PostgresDBClient(db_config) assert "database connection is not using SSL" in caplog.messages -def test_get_session(): - db_client = db.PostgresDBClient() +def test_get_session(db_client): + # db_client = db.PostgresDBClient() with db_client.get_session() as session: with session.begin(): assert session.scalar(text("SELECT 1")) == 1 diff --git a/app/tests/src/adapters/db/test_flask_db.py b/app/tests/src/adapters/db/test_flask_db.py index c95ee12d..5810ada7 100644 --- a/app/tests/src/adapters/db/test_flask_db.py +++ b/app/tests/src/adapters/db/test_flask_db.py @@ -9,9 +9,9 @@ # Define an isolated example Flask app fixture specific to this test module # to avoid dependencies on any project-specific fixtures in conftest.py @pytest.fixture -def example_app() -> Flask: +def example_app(config) -> Flask: app = Flask(__name__) - db_client = db.PostgresDBClient() + db_client = db.PostgresDBClient(config.database) flask_db.register_db_client(db_client, app) return app @@ -37,8 +37,8 @@ def hello(db_session: db.Session): assert response.get_json() == {"data": "hello, world"} -def test_with_db_session_not_default_name(example_app: Flask): - db_client = db.PostgresDBClient() +def test_with_db_session_not_default_name(example_app: Flask, db_config): + db_client = db.PostgresDBClient(db_config) flask_db.register_db_client(db_client, example_app, client_name="something_else") @example_app.route("/hello") diff --git a/app/tests/src/config/__init__.py b/app/tests/src/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/tests/src/config/test_config.py b/app/tests/src/config/test_config.py new file mode 100644 index 00000000..a77d973f --- /dev/null +++ b/app/tests/src/config/test_config.py @@ -0,0 +1,54 @@ +# +# Tests for src.config. +# + +import pydantic.error_wrappers +import pytest + +from src import config +from src.config import load + + +def test_load_with_override(): + conf = load.load( + environment_name="local", environ={"app_host": "test.host", "app_port": 999, "port": 888} + ) + + assert isinstance(conf, config.RootConfig) + assert conf.app.host == "test.host" + assert conf.app.port == 999 + + +def test_load_with_invalid_override(): + with pytest.raises(pydantic.error_wrappers.ValidationError): + load.load(environment_name="local", environ={"app_port": "not_a_number"}) + + +def test_load(): + conf = load.load(environment_name="local") + + assert isinstance(conf, config.RootConfig) + assert conf.app.host == "127.0.0.1" + + +def test_load_invalid_environment_name(): + with pytest.raises(ModuleNotFoundError): + load.load(environment_name="does_not_exist") + + +def test_load_all(): + """This test is important to confirm that all configurations are valid - otherwise we would + not know until runtime in the appropriate environment.""" + + all_configs = load.load_all() + + # We expect at least these configs to exist - there may be others too. + assert all_configs.keys() >= {"local", "dev", "prod"} + + for value in all_configs.values(): + assert isinstance(value, config.RootConfig) + + # Make sure they are all different objects (prevent bugs where they overwrite by accidental + # sharing). + ids = {id(value) for value in all_configs.values()} + assert len(ids) == len(all_configs) diff --git a/app/tests/src/db/test_migrations.py b/app/tests/src/db/test_migrations.py index 160a8766..6bb56df1 100644 --- a/app/tests/src/db/test_migrations.py +++ b/app/tests/src/db/test_migrations.py @@ -1,4 +1,5 @@ import logging # noqa: B1 +import uuid import pytest from alembic import command @@ -7,12 +8,20 @@ from alembic.util.exc import CommandError import src.adapters.db as db +from src.adapters.db.clients.postgres_client import make_connection_uri from src.db.migrations.run import alembic_cfg from tests.lib import db_testing @pytest.fixture -def empty_schema(monkeypatch) -> db.DBClient: +def empty_db_config(db_config): + empty_db_config = db_config.copy() + empty_db_config.db_schema = f"test_schema_{uuid.uuid4().int}" + return empty_db_config + + +@pytest.fixture +def empty_schema(empty_db_config) -> db.DBClient: """ Create a test schema, if it doesn't already exist, and drop it after the test completes. @@ -20,7 +29,7 @@ def empty_schema(monkeypatch) -> db.DBClient: This is similar to what the db_client fixture does but does not create any tables in the schema. """ - with db_testing.create_isolated_db(monkeypatch) as db_client: + with db_testing.create_isolated_db(empty_db_config) as db_client: yield db_client @@ -47,8 +56,14 @@ def test_only_single_head_revision_in_migrations(): ) -def test_db_setup_via_alembic_migration(empty_schema, caplog: pytest.LogCaptureFixture): +def test_db_setup_via_alembic_migration( + empty_db_config, empty_schema, caplog: pytest.LogCaptureFixture +): caplog.set_level(logging.INFO) # noqa: B1 + + uri = make_connection_uri(empty_db_config) + alembic_cfg.set_main_option("sqlalchemy.url", uri.replace("%", "%%")) + command.upgrade(alembic_cfg, "head") # Verify the migration ran by checking the logs assert "Running upgrade" in caplog.text diff --git a/app/tests/src/logging/test_logging.py b/app/tests/src/logging/test_logging.py index 6991ce37..acf89b50 100644 --- a/app/tests/src/logging/test_logging.py +++ b/app/tests/src/logging/test_logging.py @@ -5,29 +5,30 @@ import src.logging import src.logging.formatters as formatters +from src.logging.config import LoggingConfig, LoggingFormat from tests.lib.assertions import assert_dict_contains @pytest.fixture def init_test_logger(caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch): caplog.set_level(logging.DEBUG) - monkeypatch.setenv("LOG_FORMAT", "human-readable") - with src.logging.init("test_logging"): + logging_config = LoggingConfig(format=LoggingFormat.human_readable) + with src.logging.init("test_logging", logging_config): yield @pytest.mark.parametrize( "log_format,expected_formatter", [ - ("human-readable", formatters.HumanReadableFormatter), + ("human_readable", formatters.HumanReadableFormatter), ("json", formatters.JsonFormatter), ], ) def test_init(caplog: pytest.LogCaptureFixture, monkeypatch, log_format, expected_formatter): caplog.set_level(logging.DEBUG) - monkeypatch.setenv("LOG_FORMAT", log_format) + logging_config = LoggingConfig(format=log_format, enable_audit=False) - with src.logging.init("test_logging"): + with src.logging.init("test_logging", logging_config): records = caplog.records assert len(records) == 2 diff --git a/docker-compose.yml b/docker-compose.yml index 13bcc57e..be115831 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,8 +11,10 @@ services: container_name: main-db command: postgres -c "log_lock_waits=on" -N 1000 -c "fsync=off" - # Load environment variables for local development. - env_file: ./app/local.env + environment: + POSTGRES_DB: main-db + POSTGRES_USER: local_db_user + POSTGRES_PASSWORD: secret123 ports: - "5432:5432" volumes: @@ -27,11 +29,8 @@ services: container_name: main-app - # Load environment variables for local development - env_file: ./app/local.env - # NOTE: These values take precedence if the same value is specified in the env_file. environment: - # The env_file defines DB_HOST=localhost for accessing a non-dockerized database. + # The code defines DB_HOST=localhost for accessing a non-dockerized database. # In the docker-compose, we tell the app to use the dockerized database service - DB_HOST=main-db ports: