diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 17d88b8a3..a1db24037 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,7 +34,17 @@ jobs: - name: Test with pytest id: test run: | - coverage run --source=merlion/ -L -m pytest -v + # A BLAS bug causes the high-dim multivariate Bayesian LR test to segfault in 3.6. Run that test on its own first to avoid the issue. + if [[ $PYTHON_VERSION == 3.6 ]]; then + python -m pytest -v tests/change_point/test_conj_prior.py + coverage run --source=merlion/ -L -m pytest -v --ignore tests/change_point/test_conj_prior.py + # The MoE test seems to hang in 3.7. Run that test on its own first to avoid the issue. + elif [[ $PYTHON_VERSION == 3.7 ]]; then + python -m pytest -v tests/forecast/test_MoE_forecast_ensemble.py + coverage run --source=merlion/ -L -m pytest -v --ignore tests/forecast/test_MoE_forecast_ensemble.py + else + coverage run --source=merlion/ -L -m pytest -v + fi # Obtain code coverage from coverage report coverage report @@ -56,6 +66,8 @@ jobs: COLOR=red fi echo "##[set-output name=color;]${COLOR}" + env: + PYTHON_VERSION: ${{ matrix.python-version }} - name: Create coverage badge if: ${{ github.ref == 'refs/heads/main' && matrix.python-version == '3.8' }} diff --git a/README.md b/README.md index ccd0f2794..af52e8e25 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,10 @@ ## Introduction Merlion is a Python library for time series intelligence. It provides an end-to-end machine learning framework that includes loading and transforming data, building and training models, post-processing model outputs, and evaluating -model performance. It supports various time series learning tasks, including forecasting and anomaly detection for both -univariate and multivariate time series. This library aims to provide engineers and researchers a one-stop solution to -rapidly develop models for their specific time series needs, and benchmark them across multiple time series datasets. +model performance. It supports various time series learning tasks, including forecasting, anomaly detection, +and change point detection for both univariate and multivariate time series. This library aims to provide engineers and +researchers a one-stop solution to rapidly develop models for their specific time series needs, and benchmark them +across multiple time series datasets. Merlion's key features are - Standardized and easily extensible data loading & benchmarking for a wide range of forecasting and anomaly diff --git a/docs/source/index.rst b/docs/source/index.rst index 58239b18f..c87ddabc9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -7,8 +7,8 @@ Welcome to Merlion's documentation! =================================== Merlion is a Python library for time series intelligence. It features a unified interface for many commonly used -:doc:`models ` and :doc:`datasets ` for anomaly detection and forecasting -on both univariate and multivariate time series, along with standard +:doc:`models ` and :doc:`datasets ` for forecasting, anomaly detection, and change +point detection on both univariate and multivariate time series, along with standard :doc:`pre-processing ` and :doc:`post-processing ` layers.
It has several modules to improve ease-of-use, including :ref:`visualization `, diff --git a/docs/source/merlion.models.anomaly.change_point.rst b/docs/source/merlion.models.anomaly.change_point.rst new file mode 100644 index 000000000..8b38178e7 --- /dev/null +++ b/docs/source/merlion.models.anomaly.change_point.rst @@ -0,0 +1,21 @@ +merlion.models.anomaly.change\_point package +============================================ + +.. automodule:: merlion.models.anomaly.change_point + :members: + :undoc-members: + :show-inheritance: + +.. autosummary:: + bocpd + +Submodules +---------- + +merlion.models.anomaly.change\_point.bocpd module +------------------------------------------------- + +.. automodule:: merlion.models.anomaly.change_point.bocpd + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/merlion.models.anomaly.rst b/docs/source/merlion.models.anomaly.rst index a88dc1feb..1392ae210 100644 --- a/docs/source/merlion.models.anomaly.rst +++ b/docs/source/merlion.models.anomaly.rst @@ -28,6 +28,7 @@ Subpackages :maxdepth: 4 merlion.models.anomaly.forecast_based + merlion.models.anomaly.change_point Submodules ---------- diff --git a/docs/source/merlion.models.rst b/docs/source/merlion.models.rst index 8e42e6788..64570c229 100644 --- a/docs/source/merlion.models.rst +++ b/docs/source/merlion.models.rst @@ -62,6 +62,7 @@ Finally, we support ensembles of models in :py:mod:`merlion.models.ensemble`. factory defaults anomaly + anomaly.change_point anomaly.forecast_based forecast ensemble @@ -75,6 +76,7 @@ Subpackages :maxdepth: 2 merlion.models.anomaly + merlion.models.anomaly.change_point merlion.models.anomaly.forecast_based merlion.models.forecast merlion.models.ensemble diff --git a/docs/source/merlion.rst b/docs/source/merlion.rst index 0d5a2f289..70ef5aa1e 100644 --- a/docs/source/merlion.rst +++ b/docs/source/merlion.rst @@ -7,6 +7,7 @@ each associated with its own sub-package: for anomaly detection and forecasting. More specifically, we have - :py:mod:`merlion.models.anomaly`: Anomaly detection models + - :py:mod:`merlion.models.anomaly.change_point`: Change point detection models - :py:mod:`merlion.models.forecast`: Forecasting models - :py:mod:`merlion.models.anomaly.forecast_based`: Forecasting models adapted for anomaly detection. Anomaly scores are based on the residual between the predicted and true value at each timestamp. diff --git a/docs/source/merlion.utils.rst b/docs/source/merlion.utils.rst index e9e85b61c..d574e411c 100644 --- a/docs/source/merlion.utils.rst +++ b/docs/source/merlion.utils.rst @@ -11,6 +11,13 @@ utilities for resampling time series. Submodules ---------- +merlion.utils.conj_priors module +-------------------------------- +.. automodule:: merlion.utils.conj_priors + :members: + :undoc-members: + :show-inheritance: + merlion.utils.istat module -------------------------- diff --git a/merlion/models/anomaly/__init__.py b/merlion/models/anomaly/__init__.py index 8f48cb9fa..1ed1758a3 100644 --- a/merlion/models/anomaly/__init__.py +++ b/merlion/models/anomaly/__init__.py @@ -6,7 +6,8 @@ # """ Contains all anomaly detection models. Forecaster-based anomaly detection models -may be found in :py:mod:`merlion.models.anomaly.forecast_based`. +may be found in :py:mod:`merlion.models.anomaly.forecast_based`. Change-point detection models may be +found in :py:mod:`merlion.models.anomaly.change_point`. 
For anomaly detection, we define an abstract `DetectorBase` class which inherits from `ModelBase` and supports the following interface, in addition to ``model.save`` and ``DetectorClass.load`` defined for `ModelBase`: diff --git a/merlion/models/anomaly/base.py b/merlion/models/anomaly/base.py index bbbbd5a82..e7ec83e90 100644 --- a/merlion/models/anomaly/base.py +++ b/merlion/models/anomaly/base.py @@ -98,6 +98,18 @@ class NoCalibrationDetectorConfig(DetectorConfig): def __init__(self, enable_calibrator=False, **kwargs): super().__init__(enable_calibrator=enable_calibrator, **kwargs) + @property + def calibrator(self): + """ + :return: ``None`` + """ + return None + + @calibrator.setter + def calibrator(self, calibrator): + # no-op + pass + @property def enable_calibrator(self): """ @@ -132,7 +144,14 @@ def _default_post_rule_train_config(self): from merlion.evaluate.anomaly import TSADMetric t = self.config._default_threshold.alm_threshold - q = None if self.config.enable_calibrator or t == 0 else 2 * norm.cdf(t) - 1 + # self.calibrator is only None if calibration has been manually disabled + # and the anomaly scores are expected to be calibrated by get_anomaly_score(). If + # self.config.enable_calibrator, the model will return a calibrated score. + if self.calibrator is None or self.config.enable_calibrator or t == 0: + q = None + # otherwise, choose the quantile corresponding to the given threshold + else: + q = 2 * norm.cdf(t) - 1 return dict(metric=TSADMetric.F1, unsup_quantile=q) @property diff --git a/merlion/models/anomaly/change_point/__init__.py b/merlion/models/anomaly/change_point/__init__.py new file mode 100644 index 000000000..6e28f3dcd --- /dev/null +++ b/merlion/models/anomaly/change_point/__init__.py @@ -0,0 +1,10 @@ +# +# Copyright (c) 2021 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Contains all change point detection algorithms. These models implement the anomaly detector interface, but +they are specialized for detecting change points in time series. +""" diff --git a/merlion/models/anomaly/change_point/bocpd.py b/merlion/models/anomaly/change_point/bocpd.py new file mode 100644 index 000000000..ba92f084b --- /dev/null +++ b/merlion/models/anomaly/change_point/bocpd.py @@ -0,0 +1,486 @@ +# +# Copyright (c) 2021 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Bayesian online change point detection algorithm. 
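+ +A minimal usage sketch (here ``train`` and ``test`` are illustrative `TimeSeries` objects, not part of this module):: + + from merlion.models.anomaly.change_point.bocpd import BOCPD, BOCPDConfig, ChangeKind + + model = BOCPD(BOCPDConfig(change_kind=ChangeKind.Auto, cp_prior=1e-2)) + train_scores = model.train(train) # change-point z-scores on the training data + test_scores = model.get_anomaly_score(test, time_series_prev=train)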
+""" +import bisect +import copy +from enum import Enum +import logging +from typing import List, Tuple, Union +import warnings + +import numpy as np +import pandas as pd +import scipy.sparse +from scipy.special import logsumexp +from scipy.stats import norm +from tqdm import tqdm + +from merlion.models.anomaly.base import NoCalibrationDetectorConfig +from merlion.models.anomaly.forecast_based.base import ForecastingDetectorBase +from merlion.models.forecast.base import ForecasterConfig +from merlion.plot import Figure +from merlion.post_process.threshold import AggregateAlarms +from merlion.utils.conj_priors import ConjPrior, MVNormInvWishart, BayesianMVLinReg +from merlion.utils.time_series import TimeSeries, UnivariateTimeSeries, to_pd_datetime + +logger = logging.getLogger(__name__) + + +class ChangeKind(Enum): + """ + Enum representing the kinds of changes points we would like to detect. + Enum values correspond to the Bayesian `ConjPrior` class used to detect each sort of change point. + """ + + Auto = None + """ + Automatically choose the Bayesian conjugate prior we would like to use. + """ + + LevelShift = MVNormInvWishart + """ + Model data points with a normal distribution, to detect level shifts. + """ + + TrendChange = BayesianMVLinReg + """ + Model data points as a linear function of time, to detect trend changes. + """ + + +class _PosteriorBeam: + """ + Utility class to track the posterior beam in the dynamic programming for BOCPD. + """ + + def __init__(self, run_length: int, posterior: ConjPrior, cp_prior: float, logp: float): + self.run_length: int = run_length + self.posterior: ConjPrior = posterior + self.cp_prior = cp_prior + self.logp = logp # joint probability P(r_t = self.run_length, x_{1:t}) + + def update(self, x): + # self.logp starts as log P(r_{t-1} = self.run_length, x_{1:t-1}) + n = 1 if isinstance(x, tuple) and len(x) == 2 else len(x) + # logp_x is log P(x_t) + if n == 1: + method = getattr(self.posterior, "posterior_explicit", self.posterior.posterior) + else: + method = self.posterior.posterior + logp_x, updated = method(x, log=True, return_updated=True) + self.posterior = updated + self.run_length += n + # P(r_t = self.run_length + 1, x_{1:t}) = P(r_{t-1} = self.run_length, x_{1:t-1}) * P(x_t) * (1 - self.cp_prior) + self.logp += sum(logp_x) + n * np.log1p(-self.cp_prior) + + +class BOCPDConfig(ForecasterConfig, NoCalibrationDetectorConfig): + """ + Config class for `BOCPD` (Bayesian Online Change Point Detection). + """ + + _default_threshold = AggregateAlarms(alm_threshold=norm.ppf((1 + 0.5) / 2), min_alm_in_window=1) + """ + Default threshold is for a >=50% probability that a point is a change point. + """ + + def __init__( + self, + change_kind: Union[str, ChangeKind] = ChangeKind.Auto, + cp_prior=1e-2, + lag=None, + min_likelihood=1e-12, + max_forecast_steps=None, + **kwargs, + ): + """ + :param change_kind: the kind of change points we would like to detect + :param cp_prior: prior belief probability of how frequently changepoints occur + :param lag: the maximum amount of delay/lookback (in number of steps) allowed for detecting change points. + If ``lag`` is ``None``, we will consider the entire history. Note: we do not recommend ``lag = 0``. + :param min_likelihood: we will discard any hypotheses whose probability of being a change point is + lower than this threshold. Lower values improve accuracy at the cost of time and space complexity. + :param max_forecast_steps: the maximum number of steps the model is allowed to forecast. Ignored. 
+ """ + self.change_kind = change_kind + self.min_likelihood = min_likelihood + self.cp_prior = cp_prior # Kats checks [0.001, 0.002, 0.005, 0.01, 0.02] + self.lag = lag + super().__init__(max_forecast_steps=max_forecast_steps, **kwargs) + + def to_dict(self, _skipped_keys=None): + _skipped_keys = _skipped_keys if _skipped_keys is not None else set() + config_dict = super().to_dict(_skipped_keys.union({"change_kind"})) + config_dict["change_kind"] = self.change_kind.name + return config_dict + + @property + def change_kind(self) -> ChangeKind: + return self._change_kind + + @change_kind.setter + def change_kind(self, change_kind: Union[str, ChangeKind]): + if isinstance(change_kind, str): + valid = set(ChangeKind.__members__.keys()) + if change_kind not in valid: + raise KeyError(f"{change_kind} is not a valid change kind. Valid change kinds are: {valid}") + change_kind = ChangeKind[change_kind] + self._change_kind = change_kind + + +class BOCPD(ForecastingDetectorBase): + """ + Bayesian online change point detection algorithm described by + `Adams & MacKay (2007) `__. + At a high level, this algorithm models the observed data using Bayesian conjugate priors. If an observed value + deviates too much from the current posterior distribution, it is likely a change point, and we should start + modeling the time series from that point forwards with a freshly initialized Bayesian conjugate prior. + + The ``get_anomaly_score()`` method returns a z-score corresponding to the probability of each point being + a change point. The ``forecast()`` method returns the predicted values (and standard error) of the underlying + piecewise model on the relevant data. + """ + + config_class = BOCPDConfig + + def __init__(self, config: BOCPDConfig = None): + config = BOCPDConfig() if config is None else config + super().__init__(config) + self.posterior_beam: List[_PosteriorBeam] = [] + self.train_timestamps: List[float] = [] + self.full_run_length_posterior = scipy.sparse.dok_matrix((0, 0), dtype=float) + self.pw_model: List[Tuple[pd.Timestamp, ConjPrior]] = [] + + @property + def last_train_time(self): + return None if len(self.train_timestamps) == 0 else to_pd_datetime(self.train_timestamps[-1]) + + @last_train_time.setter + def last_train_time(self, t): + pass + + @property + def n_seen(self): + """ + :return: the number of data points seen so far + """ + return self.full_run_length_posterior.get_shape()[0] + + @property + def change_kind(self) -> ChangeKind: + """ + :return: the kind of change points we would like to detect + """ + return self.config.change_kind + + @property + def cp_prior(self) -> float: + """ + :return: prior belief probability of how frequently changepoints occur + """ + return self.config.cp_prior + + @property + def lag(self) -> int: + """ + :return: the maximum amount of delay allowed for detecting change points. A higher lag can increase + recall, but it may decrease precision. 
+ """ + return self.config.lag + + @property + def min_likelihood(self) -> float: + """ + :return: we will not consider any hypotheses (about whether a particular point is a change point) + with likelihood lower than this threshold + """ + return self.config.min_likelihood + + def _create_posterior(self, logp: float) -> _PosteriorBeam: + posterior = self.change_kind.value() + return _PosteriorBeam(run_length=0, posterior=posterior, cp_prior=self.cp_prior, logp=logp) + + def _get_anom_scores(self, time_stamps: List[Union[int, float]]) -> TimeSeries: + # Convert sparse posterior matrix to a form where it's fast to access its diagonals + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + posterior = scipy.sparse.dia_matrix(self.full_run_length_posterior) + + # Compute the MAP probability that each point is a change point. + # full_run_length_posterior[i, r] = P[run length = r at time t_i] + i_0 = bisect.bisect_left(self.train_timestamps, time_stamps[0]) + i_f = bisect.bisect_right(self.train_timestamps, time_stamps[-1]) + probs = np.zeros(i_f - i_0) + n_lag = None if self.lag is None else self.lag + 1 + for i_prob, i_posterior in enumerate(range(max(i_0, 1), i_f)): + probs[i_prob] = posterior.diagonal(-i_posterior)[:n_lag].max() + + # Convert P[changepoint] to z-score units, and align it to the right time stamps + scores = norm.ppf((1 + probs) / 2) + ts = UnivariateTimeSeries(time_stamps=self.train_timestamps[i_0:i_f], values=scores, name="anom_score").to_ts() + return ts.align(reference=time_stamps) + + def _update_model(self, timestamps): + # Figure out where the changepoints are in the data + changepoints = self.threshold.to_simple_threshold()(self._get_anom_scores(timestamps)) + changepoints = changepoints.to_pd().iloc[:, 0] + cp_times = changepoints[changepoints != 0].index + + # Remove every sub-model that takes effect after the first timestamp provided. + self.pw_model = [(t0, model) for t0, model in self.pw_model if t0 < changepoints.index[0]] + + # Update the final piece of the existing model (if there is one) + t0 = changepoints.index[0] if len(self.pw_model) == 0 else self.pw_model[-1][0] + tf = changepoints.index[-1] if len(cp_times) == 0 else cp_times[0] + train_data = self.transform(self.train_data) + data = train_data.window(t0, tf, include_tf=len(cp_times) == 0) + if len(data) > 0: + if len(self.pw_model) == 0: + self.pw_model.append((t0, self.change_kind.value(data))) + else: + self.pw_model[-1] = (t0, self.change_kind.value(data)) + + # Build a piecewise model by using the data between each subsequent change point + t0 = tf + for tf in cp_times[1:]: + data = train_data.window(t0, tf) + if len(data) > 0: + self.pw_model.append((t0, self.change_kind.value(data))) + t0 = tf + if t0 < changepoints.index[-1]: + _, data = train_data.bisect(t0, t_in_left=False) + self.pw_model.append((t0, self.change_kind.value(data))) + + def train_pre_process( + self, train_data: TimeSeries, require_even_sampling: bool, require_univariate: bool + ) -> TimeSeries: + # BOCPD doesn't _require_ target_seq_index to be specified, but train_pre_process() does. + if self.target_seq_index is None and train_data.dim > 1: + self.config.target_seq_index = 0 + logger.warning( + f"Received a {train_data.dim}-variate time series, but `target_seq_index` was not " + f"specified. Setting `target_seq_index = 0` so the `forecast()` method will work." 
+ ) + train_data = super().train_pre_process(train_data, require_even_sampling, require_univariate) + # We manually update self.train_data in update(), so do nothing here + self.train_data = None + return train_data + + def forecast( + self, + time_stamps: Union[int, List[int]], + time_series_prev: TimeSeries = None, + return_iqr: bool = False, + return_prev: bool = False, + ) -> Union[Tuple[TimeSeries, TimeSeries], Tuple[TimeSeries, TimeSeries, TimeSeries]]: + if time_series_prev is not None: + self.update(time_series_prev) + if isinstance(time_stamps, (int, float)): + time_stamps = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps))[1:] + else: + time_stamps = to_pd_datetime(time_stamps) + if return_prev and time_series_prev is not None: + time_stamps = to_pd_datetime(time_series_prev.time_stamps).union(time_stamps) + + # Initialize output accumulators + pred_full, err_full = None, None + + # Split the time stamps based on which model piece should be used + j = 0 + i = bisect.bisect_left([t0 for t0, model in self.pw_model], time_stamps[j], hi=len(self.pw_model) - 1) + for i, (t0, posterior) in enumerate(self.pw_model[i:], i): + # Stop forecasting if we've finished with all the input timestamps + if j >= len(time_stamps): + break + + # If this is the last piece, use it to forecast the rest of the timestamps + if i == len(self.pw_model) - 1: + pred, err = posterior.forecast(time_stamps[j:]) + + # Otherwise, predict until the next piece takes over + else: + t_next = self.pw_model[i + 1][0] + j_next = bisect.bisect_left(time_stamps, t_next) + pred, err = posterior.forecast(time_stamps[j:j_next]) + j = j_next + + # Accumulate results + pred_full = pred if pred_full is None else pred_full + pred + err_full = err if err_full is None else err_full + err + + pred = pred_full.univariates[pred_full.names[self.target_seq_index]].to_pd() + err = err_full.univariates[err_full.names[self.target_seq_index]].to_pd() + pred[pred.isna() | np.isinf(pred)] = 0 + err[err.isna() | np.isinf(err)] = 0 + + if return_iqr: + name = pred.name + lb = UnivariateTimeSeries.from_pd(pred + norm.ppf(0.25) * err, name=f"{name}_lower") + ub = UnivariateTimeSeries.from_pd(pred + norm.ppf(0.75) * err, name=f"{name}_upper") + return TimeSeries.from_pd(pred), lb.to_ts(), ub.to_ts() + return TimeSeries.from_pd(pred), TimeSeries.from_pd(err) + + def get_figure( + self, + *, + time_series: TimeSeries = None, + time_stamps: List[int] = None, + time_series_prev: TimeSeries = None, + plot_anomaly=True, + filter_scores=True, + plot_forecast=False, + plot_forecast_uncertainty=False, + plot_time_series_prev=False, + ) -> Figure: + if time_series is not None: + self.update(time_series) + return super().get_figure( + time_series=time_series, + time_stamps=time_stamps, + time_series_prev=time_series_prev, + plot_anomaly=plot_anomaly, + filter_scores=filter_scores, + plot_forecast=plot_forecast, + plot_forecast_uncertainty=plot_forecast_uncertainty, + plot_time_series_prev=plot_time_series_prev, + ) + + def update(self, time_series: TimeSeries): + """ + Updates the BOCPD model's internal state using the time series values provided. 
+ + :param time_series: time series whose values we are using to update the internal state of the model + :return: anomaly score associated with each point (based on the probability of it being a change point) + """ + # Only update on the portion of the time series after the last training timestamp + time_stamps = time_series.time_stamps + if self.last_train_time is not None: + time_series_prev, time_series = time_series.bisect(self.last_train_time, t_in_left=True) + else: + time_series_prev = None + + # Update the training data accumulated so far + if self.train_data is None: + self.train_data = time_series + else: + self.train_data = self.train_data + time_series + + # Apply any pre-processing transforms to the time series + time_series, time_series_prev = self.transform_time_series(time_series, time_series_prev) + + # Align the time series & expand the array storing the full posterior distribution of run lengths + time_series = time_series.align() + n_seen, T = self.n_seen, len(time_series) + self.full_run_length_posterior = scipy.sparse.block_diag( + (self.full_run_length_posterior, scipy.sparse.dok_matrix((T, T), dtype=float)), format="dok" + ) + + # Compute the minimum log likelihood threshold that we consider. + min_ll = -np.inf if self.min_likelihood is None or self.min_likelihood <= 0 else np.log(self.min_likelihood) + if self.change_kind is ChangeKind.TrendChange: + min_ll = min_ll * time_series.dim + min_ll = min_ll + np.log(self.cp_prior) + + # Iterate over the time series + for i, (t, x) in enumerate(tqdm(time_series, desc="BOCPD Update", disable=(T == 0))): + # Update posterior beams + for post in self.posterior_beam: + post.update((t, x)) + + # Calculate posterior probability that this is a change point with + # P_changepoint = \sum_{r_{t-1}} P(r_{t-1}, x_{1:t-1}) * P(x_t) * cp_prior + # After the updates, post.logp = log P(r_t, x_{1:t}) + # = log P(r_{t-1}, x_{1:t-1}) + log P(x_t) + log(1 - cp_prior) + # So we can just add log(cp_prior) - log(1 - cp_prior) to each of the logp's + if len(self.posterior_beam) == 0: + cp_logp = 0 + else: + cp_delta = np.log(self.cp_prior) - np.log1p(-self.cp_prior) + cp_logp = logsumexp([post.logp + cp_delta for post in self.posterior_beam]) + self.posterior_beam.append(self._create_posterior(logp=cp_logp)) + + # P(x_{1:t}) = \sum_{r_t} P(r_t, x_{1:t}) + evidence = logsumexp([post.logp for post in self.posterior_beam]) + + # P(r_t) = P(r_t, x_{1:t}) / P(x_{1:t}) + run_length_dist_0 = {post.run_length: post.logp - evidence for post in self.posterior_beam} + + # Remove posterior beam candidates whose run length probability is too low + run_length_dist, to_remove = {}, {} + for r, logp in run_length_dist_0.items(): + if logp < min_ll and r > 2: # allow at least 2 updates for each change point hypothesis + to_remove[r] = logp + else: + run_length_dist[r] = logp + + # Re-normalize all remaining probabilities to sum to 1 + self.posterior_beam = [post for post in self.posterior_beam if post.run_length not in to_remove] + if len(to_remove) > 0: + excess_p = np.exp(logsumexp(list(to_remove.values()))) # P[to_remove] + for post in self.posterior_beam: + post.logp -= np.log1p(-excess_p) + run_length_dist[post.run_length] -= np.log1p(-excess_p) + + # Update the full posterior distribution of run-length at each time, up to the desired lag + run_length_dist = [(r, logp) for r, logp in run_length_dist.items()] + if len(run_length_dist) > 0: + all_r, all_logp_r = zip(*run_length_dist) + self.full_run_length_posterior[n_seen + i, all_r] = np.exp(all_logp_r)
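+ # (row n_seen + i of the sparse posterior matrix now holds P[run length = r] for this timestamp)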
+ + # Add this timestamp to the list of timestamps we've trained on + self.train_timestamps.append(t) + + # Update the predictive model if there is any new data + if len(time_series) > 0: + if self.lag is None: + n = len(self.train_timestamps) + else: + n = T + self.lag + self._update_model(self.train_timestamps[-n:]) + + # Return the anomaly scores + return self._get_anom_scores(time_stamps) + + def train( + self, train_data: TimeSeries, anomaly_labels: TimeSeries = None, train_config=None, post_rule_train_config=None + ) -> TimeSeries: + + # If automatically detecting the change kind, train candidate models with each change kind + # TODO: abstract this logic into a merlion.models.automl.GridSearch object? + if self.change_kind is ChangeKind.Auto: + candidates = [] + for change_kind in ChangeKind: + if change_kind is ChangeKind.Auto: + continue + candidate = copy.deepcopy(self) + candidate.config.change_kind = change_kind + train_scores = candidate.train(train_data, anomaly_labels, train_config, post_rule_train_config) + log_likelihood = logsumexp([p.logp for p in candidate.posterior_beam]) + candidates.append((candidate, train_scores, log_likelihood)) + logger.info(f"Change kind {change_kind.name} has log likelihood {log_likelihood:.3f}.") + + # Choose the model with the best log likelihood + i_best = np.argmax([candidate[2] for candidate in candidates]) + best, train_scores, _ = candidates[i_best] + self.__setstate__(best.__getstate__()) + logger.info(f"Using change kind {self.change_kind.name} because it has the best log likelihood.") + + # Otherwise, just train as normal + else: + self.train_pre_process(train_data, require_even_sampling=False, require_univariate=False) + train_scores = self.update(time_series=train_data) + self.train_post_rule(train_scores, anomaly_labels, post_rule_train_config) + + # Return the anomaly scores on the training data + return train_scores + + def get_anomaly_score(self, time_series: TimeSeries, time_series_prev: TimeSeries = None) -> TimeSeries: + if time_series_prev is not None: + self.update(time_series_prev) + return self.update(time_series) diff --git a/merlion/models/anomaly/isolation_forest.py b/merlion/models/anomaly/isolation_forest.py index eee9fb511..41c6bf03f 100644 --- a/merlion/models/anomaly/isolation_forest.py +++ b/merlion/models/anomaly/isolation_forest.py @@ -23,12 +23,14 @@ class IsolationForestConfig(DetectorConfig): + """ + Configuration class for `IsolationForest`. + """ + _default_transform = TransformSequence([DifferenceTransform(), Shingle(size=2, stride=1)]) def __init__(self, max_n_samples: int = None, n_estimators: int = 100, **kwargs): """ - Configuration class for isolation forest. - :param max_n_samples: Maximum number of samples to allow the isolation forest to train on. Specify ``None`` to use all samples in the training data. diff --git a/merlion/models/anomaly/random_cut_forest.py b/merlion/models/anomaly/random_cut_forest.py index 74790120a..f8b33ed02 100644 --- a/merlion/models/anomaly/random_cut_forest.py +++ b/merlion/models/anomaly/random_cut_forest.py @@ -40,6 +40,12 @@ def __init__(self): class RandomCutForestConfig(DetectorConfig): + """ + Configuration class for `RandomCutForest`. Refer to + https://github.com/aws/random-cut-forest-by-aws/tree/main/Java for + further documentation and defaults of the Java class. 
+ """ + _default_transform = TransformSequence([DifferenceTransform(), Shingle(size=5, stride=1)]) def __init__( @@ -53,10 +59,6 @@ def __init__( **kwargs ): """ - Configuration class for random cut forest. Refer to - https://github.com/aws/random-cut-forest-by-aws/tree/main/Java for - further documentation and defaults of the Java class. - :param n_estimators: The number of trees in this forest. :param parallel: If true, then the forest will create an internal thread pool. Forest updates and traversals will be submitted to this thread @@ -102,8 +104,7 @@ def java_params(self): class RandomCutForest(DetectorBase): """ The random cut forest is a refinement of the classic isolation forest - algorithm. It was proposed in - `Guha et al. 2016 `_. + algorithm. It was proposed in `Guha et al. 2016 `__. """ config_class = RandomCutForestConfig diff --git a/merlion/models/anomaly/spectral_residual.py b/merlion/models/anomaly/spectral_residual.py index 1e031ac01..0bac97872 100644 --- a/merlion/models/anomaly/spectral_residual.py +++ b/merlion/models/anomaly/spectral_residual.py @@ -18,6 +18,10 @@ class SpectralResidualConfig(DetectorConfig): + """ + Config class for `SpectralResidual` anomaly detector. + """ + _default_transform = TemporalResample(granularity=None) def __init__(self, local_wind_sz=21, q=3, estimated_points=5, predicting_points=5, target_seq_index=None, **kwargs): @@ -61,8 +65,8 @@ class SpectralResidual(DetectorBase): """ Spectral Residual Algorithm for Anomaly Detection. - Spectral Residual Anomaly Detection algorithm based on the algorithm described in this - `paper `_. After taking the frequency spectrum, compute the + Spectral Residual Anomaly Detection algorithm based on the algorithm described by + `Ren et al. (2019) `__. After taking the frequency spectrum, compute the log deviation from the mean. Use inverse fourier transform to obtain the saliency map. Anomaly scores for a point in the time series are obtained by comparing the saliency score of the point to the average of the previous points. @@ -140,7 +144,7 @@ def get_anomaly_score(self, time_series: TimeSeries, time_series_prev: TimeSerie def train( self, train_data: TimeSeries, anomaly_labels: TimeSeries = None, train_config=None, post_rule_train_config=None ) -> TimeSeries: - self.train_pre_process(train_data, require_even_sampling=True, require_univariate=False) + train_data = self.train_pre_process(train_data, require_even_sampling=True, require_univariate=False) if train_data.dim == 1: self.config.target_seq_index = 0 @@ -159,5 +163,4 @@ def train( self.train_post_rule( anomaly_scores=train_scores, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config ) - self.train_data = train_data return train_scores diff --git a/merlion/models/anomaly/stat_threshold.py b/merlion/models/anomaly/stat_threshold.py index f02e83b84..7162ac95a 100644 --- a/merlion/models/anomaly/stat_threshold.py +++ b/merlion/models/anomaly/stat_threshold.py @@ -14,10 +14,18 @@ class StatThresholdConfig(DetectorConfig, NormalizingConfig): + """ + Config class for `StatThreshold`. + """ + _default_transform = DifferenceTransform() class StatThreshold(DetectorBase): + """ + Anomaly detection based on a static threshold. 
+ """ + config_class = StatThresholdConfig def train( diff --git a/merlion/models/anomaly/windstats.py b/merlion/models/anomaly/windstats.py index 8cfc069cd..e75bd0fc9 100644 --- a/merlion/models/anomaly/windstats.py +++ b/merlion/models/anomaly/windstats.py @@ -21,6 +21,10 @@ class WindStatsConfig(DetectorConfig): + """ + Config class for `WindStats`. + """ + _default_transform = DifferenceTransform() @property diff --git a/merlion/models/anomaly/zms.py b/merlion/models/anomaly/zms.py index fe669a65c..69efaa030 100644 --- a/merlion/models/anomaly/zms.py +++ b/merlion/models/anomaly/zms.py @@ -24,7 +24,7 @@ class ZMSConfig(DetectorConfig, NormalizingConfig): """ - Configuration class for an ZMS anomaly detection model. + Configuration class for `ZMS` anomaly detection model. """ _default_transform = TemporalResample(trainable_granularity=True) diff --git a/merlion/models/automl/autosarima.py b/merlion/models/automl/autosarima.py index 8b2487092..a73be8d0f 100644 --- a/merlion/models/automl/autosarima.py +++ b/merlion/models/automl/autosarima.py @@ -22,6 +22,10 @@ class AutoSarimaConfig(SarimaConfig): + """ + Configuration class for `AutoSarima`. + """ + _default_transform = TemporalResample() def __init__( @@ -39,7 +43,6 @@ def __init__( **kwargs, ): """ - Configuration class for AutoSarima. For order and seasonal_order, 'auto' indicates automatically select the parameter. Now autosarima support automatically select differencing order, length of the seasonality cycle, seasonal differencing order, and the rest of AR, MA, seasonal AR @@ -140,8 +143,9 @@ def _generate_sarima_parameters(self, train_data: TimeSeries) -> dict: m = seasonal_order[-1] if not isinstance(m, (int, float)): m = 1 - warnings.warn("Set periodicity to 1, use the SeasonalityLayer()" - "wrapper to automatically detect seasonality.") + warnings.warn( + "Set periodicity to 1, use the SeasonalityLayer()" "wrapper to automatically detect seasonality." + ) # adjust max p,q,P,Q start p,q,P,Q max_p = int(min(max_p, np.floor(n_samples / 3))) diff --git a/merlion/models/base.py b/merlion/models/base.py index 707e034d7..a365dbd03 100644 --- a/merlion/models/base.py +++ b/merlion/models/base.py @@ -166,6 +166,7 @@ def __init__(self, config: Config): self.config = deepcopy(config) self.last_train_time = None self.timedelta = None + self.train_data = None def reset(self): """ @@ -237,6 +238,7 @@ def train_pre_process( :return: the training data, after any necessary pre-processing has been applied """ + self.train_data = train_data self.transform.train(train_data) train_data = self.transform(train_data) @@ -272,10 +274,12 @@ def transform_time_series( :return: The transformed ``time_series``. 
""" - t0 = time_series.t0 - if time_series_prev is not None: + if time_series_prev is not None and not time_series.is_empty(): + t0 = time_series.t0 time_series = time_series_prev + time_series time_series_prev, time_series = self.transform(time_series).bisect(t0, t_in_left=False) + elif time_series_prev is not None: + time_series_prev = self.transform(time_series_prev) else: time_series = self.transform(time_series) return time_series, time_series_prev diff --git a/merlion/models/ensemble/anomaly.py b/merlion/models/ensemble/anomaly.py index b255a3928..1743c39ec 100644 --- a/merlion/models/ensemble/anomaly.py +++ b/merlion/models/ensemble/anomaly.py @@ -55,6 +55,10 @@ def __init__(self, enable_calibrator=False, **kwargs): class DetectorEnsemble(EnsembleBase, DetectorBase): + """ + Class representing an ensemble of multiple anomaly detection models. + """ + models: List[DetectorBase] config_class = DetectorEnsembleConfig _default_train_config = EnsembleTrainConfig(valid_frac=0.0) diff --git a/merlion/models/ensemble/base.py b/merlion/models/ensemble/base.py index a273fadcb..61327983d 100644 --- a/merlion/models/ensemble/base.py +++ b/merlion/models/ensemble/base.py @@ -94,6 +94,10 @@ def __init__(self, valid_frac, per_model_train_configs=None): class EnsembleBase(ModelBase, ABC): + """ + An abstract class representing an ensemble of multiple models. + """ + models: List[ModelBase] config_class = EnsembleConfig diff --git a/merlion/models/ensemble/forecast.py b/merlion/models/ensemble/forecast.py index 0f4881597..8a98bd1f0 100644 --- a/merlion/models/ensemble/forecast.py +++ b/merlion/models/ensemble/forecast.py @@ -32,6 +32,10 @@ def __init__(self, max_forecast_steps=None, **kwargs): class ForecasterEnsemble(EnsembleBase, ForecasterBase): + """ + Class representing an ensemble of multiple forecasting models. + """ + models: List[ForecasterBase] config_class = ForecasterEnsembleConfig diff --git a/merlion/models/factory.py b/merlion/models/factory.py index 3158cf3a0..7218dfc30 100644 --- a/merlion/models/factory.py +++ b/merlion/models/factory.py @@ -22,6 +22,7 @@ ArimaDetector="merlion.models.anomaly.forecast_based.arima:ArimaDetector", DynamicBaseline="merlion.models.anomaly.dbl:DynamicBaseline", IsolationForest="merlion.models.anomaly.isolation_forest:IsolationForest", + # Forecast-based anomaly detection models ETSDetector="merlion.models.anomaly.forecast_based.ets:ETSDetector", LSTMDetector="merlion.models.anomaly.forecast_based.lstm:LSTMDetector", MSESDetector="merlion.models.anomaly.forecast_based.mses:MSESDetector", @@ -37,6 +38,8 @@ VAE="merlion.models.anomaly.vae:VAE", DAGMM="merlion.models.anomaly.dagmm:DAGMM", LSTMED="merlion.models.anomaly.lstm_ed:LSTMED", + # Change point detection models + BOCPD="merlion.models.anomaly.change_point.bocpd", # Forecasting models Arima="merlion.models.forecast.arima:Arima", ETS="merlion.models.forecast.ets:ETS", diff --git a/merlion/models/forecast/arima.py b/merlion/models/forecast/arima.py index a13678920..0dc5574d8 100644 --- a/merlion/models/forecast/arima.py +++ b/merlion/models/forecast/arima.py @@ -18,13 +18,13 @@ class ArimaConfig(SarimaConfig): + """ + Configuration class for `Arima`. Just a `Sarima` model with seasonal order ``(0, 0, 0, 0)``. + """ + _default_transform = TemporalResample(granularity=None, trainable_granularity=True) def __init__(self, max_forecast_steps=None, target_seq_index=None, order=(4, 1, 2), **kwargs): - """ - Configuration class for Arima. Just a Sarima model with seasonal order - (0, 0, 0, 0). 
- """ if "seasonal_order" in kwargs: raise ValueError("cannot specify seasonal_order for ARIMA") super().__init__( diff --git a/merlion/models/forecast/baggingtrees.py b/merlion/models/forecast/baggingtrees.py index 026676522..229b1b87c 100644 --- a/merlion/models/forecast/baggingtrees.py +++ b/merlion/models/forecast/baggingtrees.py @@ -237,6 +237,10 @@ def reset_data_already_transformed(self): class RandomForestForecasterConfig(BaggingTreeForecasterConfig): + """ + Config class for `RandomForestForecaster`. + """ + pass @@ -262,6 +266,10 @@ def __init__(self, config: RandomForestForecasterConfig): class ExtraTreesForecasterConfig(BaggingTreeForecasterConfig): + """ + Config cass for `ExtraTreesForecaster`. + """ + pass diff --git a/merlion/models/forecast/base.py b/merlion/models/forecast/base.py index 283d38762..bad028852 100644 --- a/merlion/models/forecast/base.py +++ b/merlion/models/forecast/base.py @@ -79,10 +79,7 @@ class ForecasterBase(ModelBase): def __init__(self, config: ForecasterConfig): super().__init__(config) - self.timedelta = None - self.last_train_time = None self.target_name = None - self.train_data = None @property def max_forecast_steps(self): @@ -147,7 +144,6 @@ def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_p def train_pre_process( self, train_data: TimeSeries, require_even_sampling: bool, require_univariate: bool ) -> TimeSeries: - self.train_data = train_data self.config.dim = train_data.dim train_data = super().train_pre_process(train_data, require_even_sampling, require_univariate) if self.dim == 1: diff --git a/merlion/models/forecast/boostingtrees.py b/merlion/models/forecast/boostingtrees.py index 7b9d3a266..cfc896cd9 100644 --- a/merlion/models/forecast/boostingtrees.py +++ b/merlion/models/forecast/boostingtrees.py @@ -240,6 +240,10 @@ def reset_data_already_transformed(self): class LGBMForecasterConfig(BoostingTreeForecasterConfig): + """ + Config class for `LGBMForecaster`. + """ + pass diff --git a/merlion/models/forecast/ets.py b/merlion/models/forecast/ets.py index f1429a0d7..d58f4bcce 100644 --- a/merlion/models/forecast/ets.py +++ b/merlion/models/forecast/ets.py @@ -25,6 +25,15 @@ class ETSConfig(ForecasterConfig): + """ + Configuration class for :py:class:`ETS` model. ETS model is an underlying state space + model consisting of an error term (E), a trend component (T), a seasonal + component (S), and a level component. Each component is flexible with + different traits with additive ('add') or multiplicative ('mul') formulation. + Refer to https://otexts.com/fpp2/taxonomy.html for more information + about ETS model. + """ + _default_transform = TemporalResample(granularity=None) def __init__( @@ -39,13 +48,6 @@ def __init__( **kwargs, ): """ - Configuration class for ETS model. ETS model is an underlying state space - model consisting of an error term (E), a trend component (T), a seasonal - component (S), and a level component. Each component is flexible with - different traits with additive ('add') or multiplicative ('mul') formulation. - Refer to https://otexts.com/fpp2/taxonomy.html for more information - about ETS model. - :param max_forecast_steps: Number of steps we would like to forecast for. 
:param target_seq_index: The index of the univariate (amongst all univariates in a general multivariate time series) whose value we diff --git a/merlion/models/forecast/lstm.py b/merlion/models/forecast/lstm.py index d85817be4..acace3f5c 100644 --- a/merlion/models/forecast/lstm.py +++ b/merlion/models/forecast/lstm.py @@ -32,6 +32,10 @@ class LSTMConfig(ForecasterConfig): + """ + Configuration class for `LSTM`. + """ + _default_transform = TransformSequence( [ TemporalResample(granularity=None, trainable_granularity=True), @@ -42,8 +46,6 @@ class LSTMConfig(ForecasterConfig): def __init__(self, max_forecast_steps: int, target_seq_index: int = None, nhid=1024, model_strides=(1,), **kwargs): """ - Configuration class for `LSTM`. - :param max_forecast_steps: Max # of steps we would like to forecast for. :param target_seq_index: The index of the univariate (amongst all univariates in a general multivariate time series) whose value we diff --git a/merlion/models/forecast/prophet.py b/merlion/models/forecast/prophet.py index 30a48cce3..5ed84ef9e 100644 --- a/merlion/models/forecast/prophet.py +++ b/merlion/models/forecast/prophet.py @@ -21,6 +21,11 @@ class ProphetConfig(ForecasterConfig): + """ + Configuration class for Facebook's `Prophet` model, as described by + `Taylor & Letham, 2017 `__. + """ + def __init__( self, max_forecast_steps: int = None, @@ -35,9 +40,6 @@ def __init__( **kwargs, ): """ - Configuration class for Facebook's Prophet model, as described in this - `paper `_. - :param max_forecast_steps: Max # of steps we would like to forecast for. :param target_seq_index: The index of the univariate (amongst all univariates in a general multivariate time series) whose value we @@ -79,7 +81,7 @@ def __init__( class Prophet(ForecasterBase): """ Facebook's model for time series forecasting. See docs for `ProphetConfig` - and the `paper `_ for more details. + and `Taylor & Letham, 2017 `__ for more details. """ config_class = ProphetConfig diff --git a/merlion/models/forecast/sarima.py b/merlion/models/forecast/sarima.py index ff14c11ec..b36785cb5 100644 --- a/merlion/models/forecast/sarima.py +++ b/merlion/models/forecast/sarima.py @@ -26,14 +26,16 @@ class SarimaConfig(ForecasterConfig): + """ + Config class for `Sarima` (Seasonal AutoRegressive Integrated Moving Average). + """ + _default_transform = TemporalResample(granularity=None) def __init__( self, max_forecast_steps=None, target_seq_index=None, order=(4, 1, 2), seasonal_order=(2, 0, 1, 24), **kwargs ): """ - Configuration class for Sarima. - :param max_forecast_steps: Number of steps we would like to forecast for. :param target_seq_index: The index of the univariate (amongst all univariates in a general multivariate time series) whose value we diff --git a/merlion/models/forecast/vector_ar.py b/merlion/models/forecast/vector_ar.py index 9b88c1cc8..3da1638ee 100644 --- a/merlion/models/forecast/vector_ar.py +++ b/merlion/models/forecast/vector_ar.py @@ -23,6 +23,10 @@ class VectorARConfig(ForecasterConfig): + """ + Config object for `VectorAR` forecaster. 
+ """ + _default_transform = TemporalResample() """ diff --git a/merlion/plot.py b/merlion/plot.py index 07466ed99..c4ca63251 100644 --- a/merlion/plot.py +++ b/merlion/plot.py @@ -208,7 +208,7 @@ def plot(self, title=None, metric_name=None, figsize=(1000, 600), ax=None, label y = self.get_y() if y is not None: metric_name = y.name if metric_name is None else metric_name - ln = ax.plot(y.index, y.np_values, c="k", lw=1, zorder=1, label=metric_name) + ln = ax.plot(y.index, y.np_values, c="k", alpha=0.8, lw=1, zorder=1, label=metric_name) lines.extend(ln) # Dotted line to cordon off previous times from current ones diff --git a/merlion/utils/__init__.py b/merlion/utils/__init__.py index 4fe2d8884..6cff27e38 100644 --- a/merlion/utils/__init__.py +++ b/merlion/utils/__init__.py @@ -5,5 +5,5 @@ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # from .misc import dynamic_import -from .resample import to_pd_datetime +from .resample import to_pd_datetime, to_timestamp from .time_series import UnivariateTimeSeries, TimeSeries diff --git a/merlion/utils/conj_priors.py b/merlion/utils/conj_priors.py new file mode 100644 index 000000000..53ab3b7a2 --- /dev/null +++ b/merlion/utils/conj_priors.py @@ -0,0 +1,870 @@ +# +# Copyright (c) 2021 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementations of Bayesian conjugate priors & their online update rules. + +.. autosummary:: + ConjPrior + BetaBernoulli + NormInvGamma + MVNormInvWishart + BayesianLinReg + BayesianMVLinReg +""" +from abc import ABC, abstractmethod +import copy +import logging +from typing import Tuple + +import numpy as np +import pandas as pd +import scipy +from scipy.special import gammaln, multigammaln +from scipy.linalg import pinv, pinvh +from scipy.stats import bernoulli, beta, invgamma, invwishart, norm, multivariate_normal as mvnorm, t as student_t + +from merlion.utils import TimeSeries, UnivariateTimeSeries, to_timestamp, to_pd_datetime + +logger = logging.getLogger(__name__) + +try: + from scipy.stats import multivariate_t as mvt +except ImportError: + logger.warning("Scipy version <1.6.0 installed. No support for multivariate t density.") + mvt = None + sp_pinv = pinv + + # Redefine pinv to implement an optimization from more recent scipy + # Specifically, if the matrix is tall enough, it's easier to compute pinv with the transpose + def pinv(a): + return sp_pinv(a.T).T if a.shape[0] / a.shape[1] >= 1.1 else sp_pinv(a) + + +_epsilon = 1e-8 + + +def _log_pdet(a): + """ + Log pseudo-determinant of a (possibly singular) matrix A. + """ + eigval, eigvec = np.linalg.eigh(a) + return np.sum(np.log(eigval[eigval > 0])) + + +def _mvt_pdf(x, mu, Sigma, nu, log=True): + """ + (log) PDF of multivariate t distribution. Use as a fallback when scipy >= 1.6.0 isn't available. + """ + # Compute the spectrum of Sigma + eigval, eigvec = np.linalg.eigh(Sigma) + + # Determine a lower bound for eigenvalues s.t. 
lmbda < eps implies that Sigma is singular + t = eigval.dtype.char.lower() + factor = {"f": 1e3, "d": 1e6} + eps = factor[t] * np.finfo(t).eps * np.max(eigval) + + # Compute the log pseudo-determinant of Sigma + positive_eigval = eigval[eigval > eps] + log_pdet = np.sum(np.log(positive_eigval)) + dim, rank = len(eigval), len(positive_eigval) + + # Compute the square root of the pseudo-inverse of Sigma + inv_eigval = np.array([0 if lmbda < eps else 1 / lmbda for lmbda in eigval]) + pinv_sqrt = np.multiply(eigvec, np.sqrt(inv_eigval)) + + # compute (x - \mu)^T \Sigma^{-1} (x - \mu) + # To do this in batch with D = (x - mu) having shape [n, d], + # we just need the diagonal of D @ pinv(Sigma) @ D.T. Since pinv_sqrt @ pinv_sqrt.T = pinv(Sigma), + # this is just the row-wise squared norm of D @ pinv_sqrt. + delta = x - mu # [n, d] + quad_form = np.square(delta @ pinv_sqrt).sum(axis=-1) + + # Multivariate-t log PDF + a = gammaln(0.5 * (nu + dim)) - gammaln(0.5 * nu) + b = -0.5 * (dim * np.log(nu * np.pi) + log_pdet) + c = -0.5 * (nu + dim) * np.log1p(quad_form / nu) + return a + b + c if log else np.exp(a + b + c) + + class ConjPrior(ABC): + """ + Abstract base class for a Bayesian conjugate prior. + Can be used with either `TimeSeries` or ``numpy`` arrays directly. + """ + + def __init__(self, sample=None): + """ + :param sample: a sample used to initialize the prior. + """ + self.n = 0 + self.dim = None + self.t0 = None + self.dt = None + self.names = None + if sample is not None: + self.update(sample) + + def to_dict(self): + return {k: v.tolist() if hasattr(v, "tolist") else copy.deepcopy(v) for k, v in self.__dict__.items()} + + @classmethod + def from_dict(cls, state_dict): + ret = cls() + for k, v in state_dict.items(): + setattr(ret, k, np.asarray(v)) + return ret + + def __copy__(self): + ret = self.__class__() + for k, v in self.__dict__.items(): + setattr(ret, k, copy.deepcopy(v)) + return ret + + def __deepcopy__(self, memodict={}): + return self.__copy__() + + @staticmethod + def get_time_series_values(x) -> np.ndarray: + """ + :return: numpy array representing the input ``x`` + """ + if x is None: + return None + if isinstance(x, TimeSeries): + x = x.align().to_pd().values + elif isinstance(x, tuple) and len(x) == 2: + t, x = x + x = np.asarray(x).reshape(1, -1) + else: + x = np.asarray(x) + x = x.reshape((1, 1) if x.ndim < 1 else (len(x), -1)) + return x + + def process_time_series(self, x) -> Tuple[np.ndarray, np.ndarray]: + """ + :return: ``(t, x)``, where ``t`` is a normalized list of timestamps, and ``x`` is a ``numpy`` array + representing the input + """ + if x is None: + return None, None + + # Initialize t0 if needed + if self.t0 is None: + if isinstance(x, TimeSeries): + t0, tf = x.t0, x.tf + self.t0 = t0 + self.dt = tf - t0 if tf != t0 else None + elif isinstance(x, tuple) and len(x) == 2: + self.t0 = x[0] + self.dt = None + else: + x = np.asarray(x) + self.t0 = 0 + self.dt = 1 if x.ndim < 1 else max(1, len(x) - 1) + + # Initialize dt if needed; this only happens for cases 1 and 2 above + if self.dt is None: + if isinstance(x, TimeSeries): + tf = x.tf + else: + tf = x[0] + if tf != self.t0: + self.dt = tf - self.t0 + + # Convert time series to numpy, or convert numpy array to pseudo time series + if isinstance(x, TimeSeries): + self.names = x.names + t = x.np_time_stamps + x = x.align().to_pd().values + elif isinstance(x, tuple) and len(x) == 2: + t, x = x + t = np.asarray(t).reshape(1) + x = np.asarray(x).reshape(1, -1) + self.names = [0] + else: + x = np.asarray(x) + x = x.reshape((1, 1) if x.ndim < 1 else (len(x), -1))
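+ # plain arrays carry no timestamps or names, so synthesize integer indices for both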
+ t = np.arange(self.n, self.n + len(x)) + self.names = list(range(x.shape[-1])) + t = (t - self.t0) / (self.dt or 1) + + if self.dim is None: + self.dim = x.shape[-1] + else: + assert x.shape[-1] == self.dim, f"Expected input with dimension {self.dim} but got {x.shape[-1]}" + + return t, x + + @staticmethod + def _process_return(x, rv, return_rv, log): + if x is None or return_rv: + return rv + try: + ret = rv.logpdf(x) if log else rv.pdf(x) + except AttributeError: + ret = rv.logpmf(x) if log else rv.pmf(x) + return ret.reshape(len(x)) + + @abstractmethod + def posterior(self, x, return_rv=False, log=True, return_updated=False): + """ + Predictive posterior (log) PDF for new observations, or the ``scipy.stats`` random variable where applicable. + + :param x: value(s) to evaluate posterior at (``None`` implies that we want to return the random variable) + :param return_rv: whether to return the random variable directly + :param log: whether to return the log PDF (instead of the PDF) + :param return_updated: whether to return an updated version of the conjugate prior as well + """ + raise NotImplementedError + + @abstractmethod + def update(self, x): + """ + Update the conjugate prior based on new observations x. + """ + raise NotImplementedError + + @abstractmethod + def forecast(self, time_stamps) -> Tuple[TimeSeries, TimeSeries]: + """ + Return a posterior predictive interval for the time stamps given. + + :param time_stamps: a list of time stamps + :return: ``(forecast, stderr)``, where ``forecast`` is the expected posterior value and ``stderr`` is the + standard error of that forecast. + """ + raise NotImplementedError + + + class ScalarConjPrior(ConjPrior, ABC): + """ + Abstract base class for a Bayesian conjugate prior for a scalar random variable. + """ + + def __init__(self, sample=None): + super().__init__(sample=sample) + self.dim = 1 + + def process_time_series(self, x): + t, x = super().process_time_series(x) + x = x.flatten() if x is not None else x + return t, x + + @staticmethod + def get_time_series_values(x) -> np.ndarray: + # zero-argument super() is unavailable inside a staticmethod, so call the parent class directly + x = ConjPrior.get_time_series_values(x) + return x.flatten() if x is not None else x + + + class BetaBernoulli(ScalarConjPrior): + r""" + Beta-Bernoulli conjugate prior for binary data. We assume the model + + .. math:: + + \begin{align*} + X &\sim \mathrm{Bernoulli}(\theta) \\ + \theta &\sim \mathrm{Beta}(\alpha, \beta) + \end{align*} + + The update rule for data :math:`x_1, \ldots, x_n` is + + .. math:: + \begin{align*} + \alpha &= \alpha + \sum_{i=1}^{n} \mathbb{I}[x_i = 1] \\ + \beta &= \beta + \sum_{i=1}^{n} \mathbb{I}[x_i = 0] + \end{align*} + + """ + + def __init__(self, sample=None): + self.alpha = 1 + self.beta = 1 + super().__init__(sample=sample) + + def posterior(self, x, return_rv=False, log=True, return_updated=False): + r""" + The posterior distribution of x is :math:`\mathrm{Bernoulli}(\alpha / (\alpha + \beta))`. + """ + t, x_np = self.process_time_series(x) + rv = bernoulli(self.alpha / (self.alpha + self.beta)) + ret = self._process_return(x=x_np, rv=rv, return_rv=return_rv, log=log) + if return_updated: + updated = copy.deepcopy(self) + updated.update(x) + return ret, updated + return ret + + def theta_posterior(self, theta, return_rv=False, log=True): + r""" + The posterior distribution of :math:`\theta` is :math:`\mathrm{Beta}(\alpha, \beta)`.
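+ + Passing ``theta=None`` or ``return_rv=True`` returns the underlying ``scipy.stats`` random variable + itself, e.g. ``rv = prior.theta_posterior(None)`` for an illustrative `BetaBernoulli` instance ``prior``.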
+ """ + rv = beta(self.alpha, self.beta) + return self._process_return(x=theta, rv=rv, return_rv=return_rv, log=log) + + def update(self, x): + t, x = self.process_time_series(x) + self.n += len(x) + self.alpha += x.sum() + self.beta += (1 - x).sum() + + def forecast(self, time_stamps) -> Tuple[TimeSeries, TimeSeries]: + n = len(time_stamps) + name = self.names[0] + rv = self.theta_posterior(None) + mu = UnivariateTimeSeries(time_stamps=time_stamps, values=[rv.mean()] * n, name=name) + sigma = UnivariateTimeSeries(time_stamps=time_stamps, values=[rv.std()] * n, name=f"{name}_stderr") + return mu.to_ts(), sigma.to_ts() + + +class NormInvGamma(ScalarConjPrior): + r""" + Normal-InverseGamma conjugate prior. Following + `Wikipedia `__ and + `Murphy (2007) `__, we assume the model + + .. math:: + + \begin{align*} + X &\sim \mathcal{N}(\mu, \sigma^2) \\ + \mu &\sim \mathcal{N}(\mu_0, \sigma^2 / n) \\ + \sigma^2 &\sim \mathrm{InvGamma}(\alpha, \beta) + \end{align*} + + The update rule for data :math:`x_1, \ldots, x_n` is + + .. math:: + \begin{align*} + \bar{x} &= \frac{1}{n} \sum_{i = 1}^{n} x_i \\ + \alpha &= \alpha + n/2 \\ + \beta &= \beta + \frac{1}{2} \sum_{i = 1}^{n} (x_i - \bar{x})^2 + \frac{1}{2} (\mu_0 - \bar{x})^2 \\ + \mu_0 &= \frac{n_0}{n_0 + n} \mu_0 + \frac{n}{n_0 + n} \bar{x} \\ + n_0 &= n_0 + n + \end{align*} + + """ + + def __init__(self, sample=None): + self.mu_0 = 0 + self.alpha = 1 / 2 + _epsilon + self.beta = _epsilon + super().__init__(sample=sample) + + def update(self, x): + t, x = self.process_time_series(x) + n0, n = self.n, len(x) + self.alpha = self.alpha + n / 2 + self.n = n0 + n + + xbar = np.mean(x) + sample_comp = np.sum((x - xbar) ** 2) + prior_comp = n0 * n / (n0 + n) * (self.mu_0 - xbar) ** 2 + self.beta = self.beta + sample_comp / 2 + prior_comp / 2 + self.mu_0 = self.mu_0 * n0 / (n0 + n) + xbar * n / (n0 + n) + + def mu_posterior(self, mu, return_rv=False, log=True): + r""" + The posterior for :math:`\mu` is :math:`\text{Student-t}_{2\alpha}(\mu_0, \beta / (n \alpha))` + """ + scale = self.beta / (2 * self.alpha ** 2) + rv = student_t(loc=self.mu_0, scale=np.sqrt(scale), df=2 * self.alpha) + return self._process_return(x=mu, rv=rv, return_rv=return_rv, log=log) + + def sigma2_posterior(self, sigma2, return_rv=False, log=True): + r""" + The posterior for :math:`\sigma^2` is :math:`\text{InvGamma}(\alpha, \beta)`. + """ + rv = invgamma(a=self.alpha, scale=self.beta) + return self._process_return(x=sigma2, rv=rv, return_rv=return_rv, log=log) + + def posterior(self, x, log=True, return_rv=False, return_updated=False): + r""" + The posterior for :math:`x` is :math:`\text{Student-t}_{2\alpha}(\mu_0, (n+1) \beta / (n \alpha))` + """ + t, x_np = self.process_time_series(x) + scale = (self.beta * (2 * self.alpha + 1)) / (2 * self.alpha ** 2) + rv = student_t(loc=self.mu_0, scale=np.sqrt(scale), df=2 * self.alpha) + ret = self._process_return(x=x_np, rv=rv, return_rv=return_rv, log=log) + if return_updated: + updated = copy.deepcopy(self) + updated.update(x) + return ret, updated + return ret + + def forecast(self, time_stamps) -> Tuple[TimeSeries, TimeSeries]: + n = len(time_stamps) + name = self.names[0] + rv = self.posterior(None) + mu = UnivariateTimeSeries(time_stamps=time_stamps, values=[rv.mean()] * n, name=name) + sigma = UnivariateTimeSeries(time_stamps=time_stamps, values=[rv.std()] * n, name=f"{name}_stderr") + return mu.to_ts(), sigma.to_ts() + + +class MVNormInvWishart(ConjPrior): + r""" + Multivariate Normal-InverseWishart conjugate prior. 
+
+
+class MVNormInvWishart(ConjPrior):
+    r"""
+    Multivariate Normal-InverseWishart conjugate prior. Multivariate equivalent of Normal-InverseGamma.
+    Following `Murphy (2007) `__, we assume the model
+
+    .. math::
+
+        \begin{align*}
+            X &\sim \mathcal{N}_d(\mu, \Sigma) \\
+            \mu &\sim \mathcal{N}_d(\mu_0, \Sigma / n) \\
+            \Sigma &\sim \mathrm{InvWishart}_{\nu}(\Lambda)
+        \end{align*}
+
+    The update rule for data :math:`x_1, \ldots, x_n` is
+
+    .. math::
+
+        \begin{align*}
+            \bar{x} &= \frac{1}{n} \sum_{i = 1}^{n} x_i \\
+            \nu &= \nu + n \\
+            \Lambda &= \Lambda + \frac{n_0 n}{n_0 + n} (\mu_0 - \bar{x}) (\mu_0 - \bar{x})^T
+                + \sum_{i = 1}^{n} (x_i - \bar{x}) (x_i - \bar{x})^T \\
+            \mu_0 &= \frac{n_0}{n_0 + n} \mu_0 + \frac{n}{n_0 + n} \bar{x} \\
+            n_0 &= n_0 + n
+        \end{align*}
+    """
+
+    def __init__(self, sample=None):
+        self.nu = 0
+        self.mu_0 = None
+        self.Lambda = None
+        super().__init__(sample=sample)
+
+    def process_time_series(self, x):
+        if x is None:
+            return None, None
+        t, x = super().process_time_series(x)
+        n, d = x.shape
+        if self.nu == 0:
+            self.nu = d + 2 * _epsilon
+        if self.Lambda is None:
+            self.Lambda = 2 * np.eye(d) * _epsilon
+        if self.mu_0 is None:
+            self.mu_0 = np.zeros(d)
+        return t, x
+
+    def update(self, x):
+        t, x = self.process_time_series(x)
+
+        n0 = self.n
+        n, d = x.shape
+        self.nu = self.nu + n
+
+        sample_mean = np.mean(x, axis=0)
+        sample_cov = (x - sample_mean).T @ (x - sample_mean)
+        delta = sample_mean - self.mu_0
+        # (mu_0 - xbar)(mu_0 - xbar)^T is an outer product; delta is 1-D, so delta.T @ delta would be a scalar
+        self.Lambda = self.Lambda + sample_cov + n * n0 / (n + n0) * np.outer(delta, delta)
+        self.mu_0 = self.mu_0 * n0 / (n0 + n) + sample_mean * n / (n0 + n)
+        self.n = n0 + n
+
+    def mu_posterior(self, mu, return_rv=False, log=True):
+        r"""
+        The posterior for :math:`\mu` is :math:`\text{Student-t}_{\nu-d+1}(\mu_0, \Lambda / (\nu (\nu - d + 1)))`.
+        """
+        dof = self.nu - self.dim + 1
+        shape = self.Lambda / (self.nu * dof)
+        if mvt is not None:
+            rv = mvt(shape=shape, loc=self.mu_0, df=dof, allow_singular=True)
+            return self._process_return(x=mu, rv=rv, return_rv=return_rv, log=log)
+        else:
+            if mu is None or return_rv:
+                raise ValueError(
+                    f"The scipy version you have installed ({scipy.__version__}) does not support a multivariate-t "
+                    f"random variable. Please specify a non-``None`` value of ``mu`` and set ``return_rv = False``."
+                )
+            return _mvt_pdf(x=mu, mu=self.mu_0, Sigma=shape, nu=dof, log=log)
+
+    def Sigma_posterior(self, sigma2, return_rv=False, log=True):
+        r"""
+        The posterior for :math:`\Sigma` is :math:`\text{InvWishart}_{\nu}(\Lambda)`.
+        """
+        rv = invwishart(df=self.nu, scale=self.Lambda)
+        return self._process_return(x=sigma2, rv=rv, return_rv=return_rv, log=log)
+
+    def posterior(self, x, return_rv=False, log=True, return_updated=False):
+        r"""
+        The posterior for :math:`x` is :math:`\text{Student-t}_{\nu-d+1}(\mu_0, (\nu + 1) \Lambda / (\nu (\nu - d + 1)))`.
+        """
+        t, x_np = self.process_time_series(x)
+        dof = self.nu - self.dim + 1
+        shape = self.Lambda * (self.nu + 1) / (self.nu * dof)
+        if mvt is not None:
+            rv = mvt(shape=shape, loc=self.mu_0, df=dof, allow_singular=True)
+            ret = self._process_return(x=x_np, rv=rv, return_rv=return_rv, log=log)
+        else:
+            if x is None or return_rv:
+                raise ValueError(
+                    f"The scipy version you have installed ({scipy.__version__}) does not support a multivariate-t "
+                    f"random variable. Please specify a non-``None`` value of ``x`` and set ``return_rv = False``."
+                )
+            ret = _mvt_pdf(x=x_np, mu=self.mu_0, Sigma=shape, nu=dof, log=log)
+
+        if return_updated:
+            updated = copy.deepcopy(self)
+            updated.update(x)
+            return ret, updated
+        return ret
+
+    def forecast(self, time_stamps, name="forecast") -> Tuple[TimeSeries, TimeSeries]:
+        t = to_pd_datetime(time_stamps)
+        n = len(t)
+        mu = pd.DataFrame(np.ones((n, self.dim)) * self.mu_0, index=t, columns=self.names)
+
+        dof = self.nu - self.dim + 1
+        Sigma = self.Lambda * (self.nu + 1) / (self.nu * dof)
+        if dof > 2:
+            cov = dof / (dof - 2) * Sigma
+            std = np.sqrt(cov.diagonal())
+        else:
+            std = np.zeros(self.dim)
+        sigma = pd.DataFrame(np.ones((n, self.dim)) * std, index=t, columns=[f"{col}_stderr" for col in self.names])
+
+        return TimeSeries.from_pd(mu), TimeSeries.from_pd(sigma)
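+
+
+# Example for MVNormInvWishart (a minimal sketch; shapes and variable names are illustrative):
+#
+#     prior = MVNormInvWishart(x[:5])   # initialize on the first few observations
+#     prior.update(x[5:])               # x is an (n, d) array or a d-variate TimeSeries
+#     logp = prior.posterior(x_new)     # posterior predictive log PDF of each new row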
+
+
+class BayesianLinReg(ConjPrior):
+    r"""
+    Bayesian Ordinary Linear Regression conjugate prior, which models a univariate input as a function of time.
+    Following `Wikipedia `__, we assume the model
+
+    .. math::
+
+        \begin{align*}
+            x(t) &\sim \mathcal{N}(m t + b, \sigma^2) \\
+            w &\sim \mathcal{N}((m_0, b_0), \sigma^2 \Lambda_0^{-1}) \\
+            \sigma^2 &\sim \mathrm{InvGamma}(\alpha, \beta)
+        \end{align*}
+
+    Consider new data :math:`(t_1, x_1), \ldots, (t_n, x_n)`. Let :math:`T \in \mathbb{R}^{n \times 2}` be
+    the matrix obtained by stacking the row vector of times with an all-ones row vector. Let
+    :math:`w = (m, b) \in \mathbb{R}^{2}` be the full weight vector. Let :math:`x \in \mathbb{R}^{n}` denote
+    all observed values. Then we have the update rule
+
+    .. math::
+
+        \begin{align*}
+            w_{OLS} &= (T^T T)^{-1} T^T x \\
+            \Lambda_n &= \Lambda_0 + T^T T \\
+            w_n &= (\Lambda_0 + T^T T)^{-1} (\Lambda_0 w_0 + T^T T w_{OLS}) \\
+            \alpha_n &= \alpha_0 + n / 2 \\
+            \beta_n &= \beta_0 + \frac{1}{2}(x^T x + w_0^T \Lambda_0 w_0 - w_n^T \Lambda_n w_n)
+        \end{align*}
+    """
+
+    def __init__(self, sample=None):
+        self.w_0 = np.zeros(2)
+        self.Lambda_0 = np.array([[0, 0], [0, 1]]) + _epsilon
+        self.alpha = 1 + _epsilon
+        self.beta = _epsilon
+        super().__init__(sample=sample)
+
+    def update(self, x):
+        t, x = self.process_time_series(x)
+        t_full = np.stack((t, np.ones_like(t)), axis=-1)  # [t, 2]
+
+        # Initial prediction
+        self.w_0 = self.w_0.reshape((2, 1))
+        pred0 = self.w_0.T @ self.Lambda_0 @ self.w_0
+
+        # Update predictive coefficients & uncertainty
+        design = t_full.T @ t_full
+        ols = pinv(t_full) @ x
+        self.w_0 = pinvh(self.Lambda_0 + design) @ (self.Lambda_0 @ self.w_0 + design @ ols)
+        self.Lambda_0 = self.Lambda_0 + design
+
+        # Updated prediction
+        pred = self.w_0.T @ self.Lambda_0 @ self.w_0
+        self.w_0 = self.w_0.flatten()
+
+        # Update accumulators
+        self.n = self.n + len(x)
+        self.alpha = self.alpha + len(x) / 2
+        self.beta = self.beta + (x.T @ x + pred0 - pred).item() / 2
+
+    def posterior_explicit(self, x, return_rv=False, log=True, return_updated=False):
+        r"""
+        Let :math:`\Lambda_n, \alpha_n, \beta_n` be the posterior values obtained by updating
+        the model on data :math:`(t_1, x_1), \ldots, (t_n, x_n)`. The predictive posterior has PDF
+
+        .. math::
+
+            \begin{align*}
+                P((t, x)) &= \frac{1}{(2 \pi)^{n/2}} \sqrt{\frac{\det \Lambda_0}{\det \Lambda_n}}
+                \frac{\beta_0^{\alpha_0}}{\beta_n^{\alpha_n}} \frac{\Gamma(\alpha_n)}{\Gamma(\alpha_0)}
+            \end{align*}
+        """
+        if x is None or return_rv:
+            raise ValueError(
+                "Bayesian linear regression doesn't have a scipy.stats random variable posterior. "
+                "Please specify a non-``None`` value of ``x`` and set ``return_rv = False``."
+            )
+        updated = copy.deepcopy(self)
+        updated.update(x)
+        t, x_np = self.process_time_series(x)
+        a = -len(x_np) / 2 * np.log(2 * np.pi)
+        b = (np.linalg.slogdet(self.Lambda_0)[1] - np.linalg.slogdet(updated.Lambda_0)[1]) / 2
+        c = self.alpha * np.log(self.beta) - updated.alpha * np.log(updated.beta)
+        d = gammaln(updated.alpha) - gammaln(self.alpha)
+        ret = (a + b + c + d if log else np.exp(a + b + c + d)).reshape(1)
+        return (ret, updated) if return_updated else ret
+
+    def posterior(self, x, return_rv=False, log=True, return_updated=False):
+        r"""
+        Naive computation of the posterior using Bayes' rule, i.e.
+
+        .. math::
+
+            \begin{align*}
+                \hat{\sigma}^2 &= \mathbb{E}[\sigma^2] \\
+                \hat{w} &= \mathbb{E}[w \mid \sigma^2 = \hat{\sigma}^2] \\
+                p(x \mid t) &= \frac{
+                    p(w = \hat{w}, \sigma^2 = \hat{\sigma}^2)
+                    p(x \mid t, w = \hat{w}, \sigma^2 = \hat{\sigma}^2)}{
+                    p(w = \hat{w}, \sigma^2 = \hat{\sigma}^2 \mid x, t)}
+            \end{align*}
+        """
+        if x is None or return_rv:
+            raise ValueError(
+                "Bayesian linear regression doesn't have a scipy.stats random variable posterior. "
+                "Please specify a non-``None`` value of ``x`` and set ``return_rv = False``."
+            )
+        t, x_np = self.process_time_series(x)
+
+        # Get priors & MAP estimates for sigma^2 and w; get the MAP estimate for x(t)
+        prior_sigma2 = invgamma(a=self.alpha, scale=self.beta)
+        sigma2_hat = prior_sigma2.mean()
+        prior_w = mvnorm(self.w_0, sigma2_hat * pinvh(self.Lambda_0), allow_singular=True)
+        w_hat = self.w_0
+        xhat = np.stack((t, np.ones_like(t)), axis=-1) @ w_hat
+
+        # Get posteriors
+        updated = copy.deepcopy(self)
+        updated.update(x)
+        post_sigma2 = invgamma(a=updated.alpha, scale=updated.beta)
+        post_w = mvnorm(updated.w_0, sigma2_hat * pinvh(updated.Lambda_0), allow_singular=True)
+
+        # Apply Bayes' rule
+        evidence = norm(xhat, np.sqrt(sigma2_hat)).logpdf(x_np.flatten()).reshape(len(x_np))
+        prior = prior_sigma2.logpdf(sigma2_hat) + prior_w.logpdf(w_hat)
+        post = post_sigma2.logpdf(sigma2_hat) + post_w.logpdf(w_hat)
+        logp = evidence + prior.item() - post.item()
+        ret = logp if log else np.exp(logp)
+        return (ret, updated) if return_updated else ret
+
+    def forecast(self, time_stamps) -> Tuple[TimeSeries, TimeSeries]:
+        name = self.names[0]
+        t = to_timestamp(time_stamps)
+        if self.t0 is None:
+            self.t0 = t[0]
+        if self.dt is None:
+            self.dt = t[-1] - t[0] if len(t) > 1 else 1
+        t = (t - self.t0) / self.dt
+        t_full = np.stack((t, np.ones_like(t)), axis=-1)  # [t, 2]
+        sigma2_hat = invgamma(a=self.alpha, scale=self.beta).mean()
+        w_cov = sigma2_hat * pinvh(self.Lambda_0)  # cov of [m, b]
+
+        # x = m t + b = [t, 1] @ [m, b]
+        xhat = t_full @ self.w_0
+        xhat = UnivariateTimeSeries(time_stamps=time_stamps, values=xhat, name=name)
+
+        # var(x) = [[t, 1]] @ cov([m, b]) @ [[t], [1]]
+        # diagonal of t_full @ w_cov @ t_full.T, since (A @ B)_ii = sum_j A_ij B_ji
+        sigma2 = np.sum((t_full @ w_cov) * t_full, axis=-1)
+
+        # Add sigma2_hat from the error model of the observations x, and square-root to get sigma
+        sigma = np.sqrt(sigma2 + sigma2_hat)
+        sigma = UnivariateTimeSeries(time_stamps=time_stamps, values=sigma, name=f"{name}_stderr")
+
+        return xhat.to_ts(), sigma.to_ts()
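+
+
+# Example for BayesianLinReg (a minimal sketch; variable names like x_train/x_test are illustrative):
+#
+#     prior = BayesianLinReg()
+#     prior.update(x_train)                     # x_train: a univariate TimeSeries
+#     logp = prior.posterior_explicit(x_test)   # marginal log likelihood of the new segment
+#     forecast, stderr = prior.forecast(x_test.time_stamps)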
+
+
+class BayesianMVLinReg(ConjPrior):
+    r"""
+    Bayesian multivariate linear regression conjugate prior, which models a multivariate input as a function of time.
+    Following `Wikipedia `__ and
+    `Geisser (1965) `__, we assume the model
+
+    .. math::
+
+        \begin{align*}
+            X(t) &\sim \mathcal{N}_{d}(m t + b, \Sigma) \\
+            (m, b) &\sim \mathcal{N}_{2d}((m_0, b_0), \Sigma \otimes \Lambda_0^{-1}) \\
+            \Sigma &\sim \mathrm{InvWishart}_{\nu}(V_0)
+        \end{align*}
+
+    where :math:`(m, b)` is the concatenation of the vectors :math:`m` and :math:`b`,
+    :math:`\Lambda_0 \in \mathbb{R}^{2 \times 2}`, and :math:`\otimes` is the Kronecker product.
+    Consider new data :math:`(t_1, x_1), \ldots, (t_n, x_n)`. Let :math:`T \in \mathbb{R}^{n \times 2}` be
+    the matrix obtained by stacking the row vector of times with an all-ones row vector. Let
+    :math:`W = [m, b]^T \in \mathbb{R}^{2 \times d}` be the full weight matrix. Let
+    :math:`X \in \mathbb{R}^{n \times d}` be the matrix of observed :math:`x` values. Then we have the update rule
+
+    .. math::
+
+        \begin{align*}
+            \nu_n &= \nu_0 + n \\
+            W_n &= (\Lambda_0 + T^T T)^{-1}(\Lambda_0 W_0 + T^T X) \\
+            V_n &= V_0 + (X - T W_n)^T (X - T W_n) + (W_n - W_0)^T \Lambda_0 (W_n - W_0) \\
+            \Lambda_n &= \Lambda_0 + T^T T
+        \end{align*}
+    """
+
+    def __init__(self, sample=None):
+        self.nu = 0
+        self.w_0 = None
+        self.Lambda_0 = np.array([[0, 0], [0, 1]]) + _epsilon
+        self.V_0 = None
+        super().__init__(sample=sample)
+
+    def process_time_series(self, x):
+        t, x = super().process_time_series(x)
+        n, d = x.shape
+        if self.nu == 0:
+            self.nu = 2 * (d + _epsilon)
+        if self.V_0 is None:
+            self.V_0 = 2 * np.eye(d) * _epsilon
+        if self.w_0 is None:
+            self.w_0 = np.zeros((2, d))
+        return t, x
+
+    def update(self, x):
+        t, x = self.process_time_series(x)
+        n, d = x.shape
+
+        t_full = np.stack((t, np.ones_like(t)), axis=-1)  # [n, 2]
+        design = t_full.T @ t_full
+        new_Lambda = design + self.Lambda_0
+        new_w = pinvh(new_Lambda) @ (t_full.T @ x + self.Lambda_0 @ self.w_0)
+
+        self.n = self.n + len(x)
+        self.nu = self.nu + len(x)
+        residual = x - t_full @ new_w  # [n, d]
+        delta_w = new_w - self.w_0  # [2, d]
+        residual_squared = residual.T @ residual
+        delta_w_quad_form = (delta_w.T @ self.Lambda_0) @ delta_w
+        self.V_0 = self.V_0 + residual_squared + delta_w_quad_form
+        self.w_0 = new_w
+        self.Lambda_0 = new_Lambda
+
+    def posterior_explicit(self, x, return_rv=False, log=True, return_updated=False):
+        r"""
+        Let :math:`\Lambda_n, \nu_n, V_n` be the posterior values obtained by updating
+        the model on data :math:`(t_1, x_1), \ldots, (t_n, x_n)`. The predictive posterior has PDF
+
+        .. math::
+
+            \begin{align*}
+                P((t, x)) &= \frac{1}{(2 \pi)^{nd/2}} \sqrt{\frac{\det \Lambda_0}{\det \Lambda_n}}
+                \frac{\det(V_0/2)^{\nu_0/2}}{\det(V_n/2)^{\nu_n/2}} \frac{\Gamma_d(\nu_n/2)}{\Gamma_d(\nu_0 / 2)}
+            \end{align*}
+        """
+        if x is None or return_rv:
+            raise ValueError(
+                "Bayesian multivariate linear regression doesn't have a scipy.stats random variable posterior. "
+                "Please specify a non-``None`` value of ``x`` and set ``return_rv = False``."
+            )
+        updated = copy.deepcopy(self)
+        updated.update(x)
+        t, x_np = self.process_time_series(x)
+
+        # Compute log pseudo-determinant of V_0 / 2 (for both current and updated values)
+        logdet_V = np.linalg.slogdet(self.V_0 / 2)[1]
+        logdet_V = _log_pdet(self.V_0 / 2) if np.isinf(logdet_V) else logdet_V
+        logdet_V_new = np.linalg.slogdet(updated.V_0 / 2)[1]
+        logdet_V_new = _log_pdet(updated.V_0 / 2) if np.isinf(logdet_V_new) else logdet_V_new
+
+        a = -len(x_np) / 2 * self.dim * np.log(2 * np.pi)
+        b = (np.linalg.slogdet(self.Lambda_0)[1] - np.linalg.slogdet(updated.Lambda_0)[1]) / 2
+        c = (self.nu * logdet_V - updated.nu * logdet_V_new) / 2
+        d = multigammaln(updated.nu / 2, self.dim) - multigammaln(self.nu / 2, self.dim)
+        ret = (a + b + c + d if log else np.exp(a + b + c + d)).reshape(1)
+        return (ret, updated) if return_updated else ret
+
+    def posterior(self, x, return_rv=False, log=True, return_updated=False):
+        r"""
+        Naive computation of the posterior using Bayes' rule, i.e.
+
+        .. math::
+
+            \begin{align*}
+                \hat{\Sigma} &= \mathbb{E}[\Sigma] \\
+                \hat{W} &= \mathbb{E}[W \mid \Sigma = \hat{\Sigma}] \\
+                p(X \mid t) &= \frac{
+                    p(W = \hat{W}, \Sigma = \hat{\Sigma})
+                    p(X \mid t, W = \hat{W}, \Sigma = \hat{\Sigma})}{
+                    p(W = \hat{W}, \Sigma = \hat{\Sigma} \mid x, t)}
+            \end{align*}
+        """
+        if x is None or return_rv:
+            raise ValueError(
+                "Bayesian multivariate linear regression doesn't have a scipy.stats random variable posterior. "
+                "Please specify a non-``None`` value of ``x`` and set ``return_rv = False``."
+            )
+        t, x_np = self.process_time_series(x)
+
+        # Get priors & MAP estimates for Sigma and W; get the MAP estimate for x(t)
+        prior_Sigma = invwishart(df=self.nu, scale=self.V_0)
+        Sigma_hat = prior_Sigma.mean()
+        w_hat = self.w_0.flatten()
+        prior_w = mvnorm(w_hat, np.kron(Sigma_hat, pinvh(self.Lambda_0)), allow_singular=True)
+        xhat = np.stack((t, np.ones_like(t)), axis=-1) @ w_hat.reshape(2, -1)
+
+        # Get posteriors
+        updated = copy.deepcopy(self)
+        updated.update(x)
+        post_Sigma = invwishart(df=updated.nu, scale=updated.V_0)
+        post_w = mvnorm(updated.w_0.flatten(), np.kron(Sigma_hat, pinvh(updated.Lambda_0)), allow_singular=True)
+
+        # Apply Bayes' rule
+        evidence = mvnorm(cov=Sigma_hat, allow_singular=True).logpdf(x_np - xhat).reshape(len(x_np))
+        prior = prior_Sigma.logpdf(Sigma_hat) + prior_w.logpdf(w_hat)
+        post = post_Sigma.logpdf(Sigma_hat) + post_w.logpdf(w_hat)
+        logp = evidence + prior - post
+
+        ret = logp if log else np.exp(logp)
+        return (ret, updated) if return_updated else ret
+
+    def forecast(self, time_stamps) -> Tuple[TimeSeries, TimeSeries]:
+        names = self.names
+        t = to_timestamp(time_stamps)
+        if self.t0 is None:
+            self.t0 = t[0]
+        if self.dt is None:
+            self.dt = t[-1] - t[0] if len(t) > 1 else 1
+        t = (t - self.t0) / self.dt
+        t_full = np.stack((t, np.ones_like(t)), axis=-1)  # [t, 2]
+
+        Sigma_hat = invwishart(df=self.nu, scale=self.V_0).mean().reshape((self.dim, self.dim))
+
+        # x = m t + b = [t, 1] @ [m, b]
+        xhat = t_full @ self.w_0
+
+        # W ~ MatrixNormal(W_0, \Lambda^{-1}, \Sigma)
+        # W is 2xd, \Lambda is 2x2, \Sigma is dxd
+        # Let V be a tx2 matrix representing time.
+        # Then, X = V @ W --> X is t x d
+        # V @ W ~ MatrixNormal(V @ W, V @ \Lambda^{-1} @ V^T, \Sigma)
+        # vec(V @ W) ~ N(vec(V @ W), \Sigma \otimes (V @ \Lambda^{-1} @ V^T))
+        #
+        # Note: (V @ \Lambda^{-1} @ V^T) has shape t x t, but we only want
+        # its diagonal. This is because we only care about the diagonal of
+        # np.kron(Sigma_hat, (V @ \Lambda^{-1} @ V^T)), which is just the outer
+        # product of the two matrices' diagonals.
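+        # For instance (illustrative numbers only): if diag(Sigma_hat) = (a, b) and
+        # diag(V @ \Lambda^{-1} @ V^T) = (p, q), then the diagonal of their Kronecker
+        # product is (a*p, a*q, b*p, b*q), i.e. the flattened outer product of the
+        # two diagonals.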
+        #
+        # Therefore, we first compute the diagonal of (V @ \Lambda^{-1} @ V^T)
+        # using the trick (A @ B)_ii = sum_j A_ij B_ji:
+        x_Lambda_diag = np.sum((t_full @ pinvh(self.Lambda_0)) * t_full, axis=-1)
+
+        # Now we can compute the full variances of the prediction ([t, d], matching xhat)
+        sigma2 = np.outer(x_Lambda_diag, Sigma_hat.diagonal())
+        sigma = np.sqrt(sigma2 + Sigma_hat.diagonal())
+
+        # Create data frames & return the appropriate time series
+        t = to_pd_datetime(time_stamps)
+        xhat_df = pd.DataFrame(xhat, index=t, columns=names)
+        sigma_df = pd.DataFrame(sigma, index=t, columns=[f"{n}_stderr" for n in names])
+        return TimeSeries.from_pd(xhat_df), TimeSeries.from_pd(sigma_df)
diff --git a/merlion/utils/resample.py b/merlion/utils/resample.py
index 9bccf1855..f159c2186 100644
--- a/merlion/utils/resample.py
+++ b/merlion/utils/resample.py
@@ -67,6 +67,8 @@ def to_pd_datetime(timestamp):
         return pd.to_datetime(int(timestamp * 1000), unit="ms")
     elif isinstance(timestamp, Iterable) and all(isinstance(t, (int, float)) for t in timestamp):
         timestamp = pd.to_datetime(np.asarray(timestamp).astype(float) * 1000, unit="ms")
+    elif isinstance(timestamp, np.ndarray) and timestamp.dtype in [int, np.float32, np.float64]:
+        timestamp = pd.to_datetime(np.asarray(timestamp).astype(float) * 1000, unit="ms")
     return pd.to_datetime(timestamp)

@@ -76,6 +78,8 @@ def to_timestamp(t):
     """
     if isinstance(t, (int, float)) or isinstance(t, Iterable) and all(isinstance(ti, (int, float)) for ti in t):
         return t
+    elif isinstance(t, np.ndarray) and t.dtype in [int, np.float32, np.float64]:
+        return t
     return np.asarray(t).astype("datetime64[ms]").astype(float) / 1000

diff --git a/merlion/utils/time_series.py b/merlion/utils/time_series.py
index 9be6df244..a780704cf 100644
--- a/merlion/utils/time_series.py
+++ b/merlion/utils/time_series.py
@@ -715,7 +715,7 @@ def to_pd(self) -> pd.DataFrame:
         return df

     @classmethod
-    def from_pd(cls, df: Union[pd.Series, pd.DataFrame], check_times=True, freq="1h"):
+    def from_pd(cls, df: Union[pd.Series, pd.DataFrame, np.ndarray], check_times=True, freq="1h"):
         """
         :param df: A pandas DataFrame with a DatetimeIndex. Each column corresponds
             to a different variable of the time series, and the
@@ -725,12 +725,21 @@ def from_pd(cls, df: Union[pd.Series, pd.DataFrame], check_times=True, freq="1h"
             time series.
         :param check_times: whether to check that all times in the index are
             unique (up to the millisecond) and sorted.
+        :param freq: if ``df`` is not indexed by time, this is the frequency
+            at which we will assume it is sampled.
         :rtype: TimeSeries
         :return: the `TimeSeries` object corresponding to ``df``.
         """
         if isinstance(df, pd.Series):
             return cls({df.name: UnivariateTimeSeries.from_pd(df[~df.isna()])})
+        elif isinstance(df, np.ndarray):
+            arr = df.reshape(len(df), -1).T
+            ret = cls([UnivariateTimeSeries(time_stamps=None, values=v, freq=freq) for v in arr], check_aligned=False)
+            ret._is_aligned = True
+            return ret
+        elif not isinstance(df, pd.DataFrame):
+            df = pd.DataFrame(df)

         # Time series is not aligned iff there are missing values
         aligned = df.shape[1] == 1 or not df.isna().any().any()
@@ -835,6 +844,14 @@ def align(
         :rtype: TimeSeries
         :return: The resampled multivariate time series.
         """
+        if self.is_empty():
+            if reference is not None or granularity is not None:
+                logger.warning(
+                    "Attempting to align an empty time series to a set of reference time stamps or a "
+                    "fixed granularity. Doing nothing."
+                )
+            return self.__class__.from_pd(self.to_pd())
+
         if reference is not None or alignment_policy is AlignPolicy.FixedReference:
             if reference is None:
                 raise RuntimeError("`reference` is required when using `alignment_policy` FixedReference.")
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 000000000..81d76e520
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,5 @@
+[pytest]
+log_format = %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s
+log_date_format = %Y-%m-%d %H:%M:%S
+log_cli=true
+log_cli_level=INFO
diff --git a/setup.py b/setup.py
index 3c4fa8287..b8736811b 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@ def read_file(fname):
 setup(
     name="salesforce-merlion",
-    version="1.0.1",
+    version="1.0.2",
     author=", ".join(read_file("AUTHORS.md").split("\n")),
     author_email="abhatnagar@salesforce.com",
     description="Merlion: A Machine Learning Framework for Time Series Intelligence",
@@ -41,10 +41,12 @@ def read_file(fname):
         "JPype1==1.0.2",
         "matplotlib",
         "numpy!=1.18.*",  # 1.18 causes a bug with scipy
+        "packaging",
         "pandas>=1.1.0",  # >=1.1.0 for origin kwarg to df.resample()
         "pystan<3.0",  # >=3.0 fails with prophet
         "scikit-learn>=0.22",  # >=0.22 for changes to isolation forest algorithm
-        "scipy>=1.5.0",
+        "scipy>=1.6.0; python_version >= '3.7'",  # 1.6.0 adds multivariate_t density to scipy.stats
+        "scipy>=1.5.0; python_version < '3.7'",  # however, scipy 1.6.0 requires python 3.7+
         "statsmodels>=0.12.2",
         "torch>=1.1.0",
         "lightgbm",  # if running on MacOS, need OpenMP: "brew install libomp"
diff --git a/tests/change_point/__init__.py b/tests/change_point/__init__.py
new file mode 100644
index 000000000..f34b13dbf
--- /dev/null
+++ b/tests/change_point/__init__.py
@@ -0,0 +1,6 @@
+#
+# Copyright (c) 2021 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
diff --git a/tests/change_point/test_bocpd.py b/tests/change_point/test_bocpd.py
new file mode 100644
index 000000000..7f2f81fdf
--- /dev/null
+++ b/tests/change_point/test_bocpd.py
@@ -0,0 +1,124 @@
+#
+# Copyright (c) 2021 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import logging +import os +from os.path import abspath, dirname, join +import sys +import unittest + +import numpy as np +import pandas as pd + +from merlion.models.anomaly.change_point.bocpd import BOCPD, BOCPDConfig, ChangeKind +from merlion.utils.time_series import TimeSeries + +rootdir = dirname(dirname(dirname(abspath(__file__)))) +logger = logging.getLogger(__name__) + + +class TestBOCPD(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + np.random.seed(12345) + + def standard_check(self, bocpd: BOCPD, test_data: TimeSeries, n: int, name: str): + # Evaluate trained BOCPD model on the test data & make sure it has perfect precision & recall + scores = bocpd.get_anomaly_score(test_data) + alarms = bocpd.get_anomaly_label(test_data).to_pd().iloc[:, 0].abs() + n_alarms = (alarms != 0).sum() + logger.info(f"# Alarms fired: {n_alarms}") + logger.info(f"Alarms fired at:\n{alarms[alarms != 0]}") + self.assertNotEqual(alarms.iloc[0], 0) + self.assertNotEqual(alarms.iloc[n], 0) + self.assertNotEqual(alarms.iloc[2 * n], 0) + self.assertEqual(n_alarms, 3) + + # Make sure we get the same results after saving & loading + bocpd.save(join(rootdir, "tmp", "bocpd", name)) + loaded = BOCPD.load(join(rootdir, "tmp", "bocpd", name)) + loaded_scores = loaded.get_anomaly_score(test_data) + self.assertSequenceEqual(list(scores), list(loaded_scores)) + + def test_level_shift(self): + print() + logger.info("test_level_shift\n" + "-" * 80 + "\n") + + # Create a multivariate time series with random level shifts & split it into train & test + n, d = 300, 5 + mus = [500, -90, 5, -50, 3] + Sigmas_basis = [np.random.randn(d, d) + np.eye(d) for _ in mus] + vals = [np.ones((n, d)) * mu + np.random.randn(n, d) @ U for i, (mu, U) in enumerate(zip(mus, Sigmas_basis))] + vals = np.concatenate([np.ones((1, d)) * mus[0], *vals], axis=0) + ts = TimeSeries.from_pd(vals, freq="1min") + train, test = ts.bisect(ts.time_stamps[2 * n]) + + # Initialize & train BOCPD with automatic change kind detection. + # Make sure we choose level shift & correctly detect the level shift in the training data + # Also make sure that we can make predictions on training data. 
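+        # (Illustrative aside, not asserted by this test: the change kind can also be pinned
+        # explicitly instead of letting ChangeKind.Auto select it from the training data, e.g.
+        # BOCPDConfig(change_kind=ChangeKind.LevelShift, cp_prior=1e-2).)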
+        bocpd = BOCPD(BOCPDConfig(change_kind=ChangeKind.Auto, cp_prior=1e-2, lag=1, min_likelihood=1e-12))
+        train_scores = bocpd.train(train + test[:n]).to_pd().iloc[:, 0].abs()
+        self.assertEqual(bocpd.change_kind, ChangeKind.LevelShift)
+        self.assertGreater(train_scores.iloc[n - 1], 2)
+
+        self.standard_check(bocpd=bocpd, test_data=test, n=n, name="level_shift")
+
+    def test_trend_change(self):
+        print()
+        logger.info("test_trend_change\n" + "-" * 80 + "\n")
+
+        # Create a multivariate time series with some trend changes and split it into train & test
+        n, d = 300, 4
+        ms = np.array([[10, -8, 12, 50], [-10, 3, 0, 9], [-3, 2, -10, 0], [-2, -3, 5, -3], [6, -1, 1, 15]])
+        bs = np.array([[0, 5, 2, 3], [0, 0, 0, 0], [10, -2, -9, 8], [-3, 66, 2, 0], [85, -9, 21, 3]])
+        sigma_basis = [U / np.trace(U) + np.eye(d) for U in np.random.randn(len(ms), d, d)]
+        t = np.arange(n * len(ms)).reshape(-1, 1)
+        x = np.concatenate(
+            [
+                m * t[i * n : (i + 1) * n] + b + np.random.randn(n, d) @ U
+                for i, (m, b, U) in enumerate(zip(ms, bs, sigma_basis))
+            ]
+        )
+        x = np.concatenate((bs[0].reshape(1, -1), x))
+        ts = TimeSeries.from_pd(x, freq="1min")
+        train, test = ts.bisect(ts.time_stamps[2 * n])
+
+        # Initialize & train BOCPD with automatic change kind detection.
+        # Make sure we choose trend change & correctly detect the trend changes in the training data
+        bocpd = BOCPD(BOCPDConfig(change_kind=ChangeKind.Auto, cp_prior=1e-2, lag=1, min_likelihood=1e-12))
+        train_scores = bocpd.train(train).to_pd().iloc[:, 0].abs()
+        self.assertEqual(bocpd.change_kind, ChangeKind.TrendChange)
+        self.assertGreater(train_scores.iloc[n - 1], 2)
+
+        # Evaluate trained BOCPD model on the test data & make sure it has perfect precision & recall
+        self.standard_check(bocpd=bocpd, test_data=test, n=n, name="trend_change")
+
+    def test_vis(self):
+        print()
+        logger.info("test_vis\n" + "-" * 80 + "\n")
+        for fname, change in [("horizontal_level_anomaly", "LevelShift"), ("seasonal_trend_anomaly", "TrendChange")]:
+            df = pd.read_csv(join(rootdir, "data", "synthetic_anomaly", f"{fname}.csv"))
+            df.index = pd.to_datetime(df["timestamp"], unit="s")
+            ts = TimeSeries.from_pd(df.iloc[:, 1])
+            model = BOCPD(BOCPDConfig(change_kind=change, cp_prior=1e-2, min_likelihood=1e-10))
+            train, test = ts[:500], ts[500:5000]
+            model.train(train)
+            fig, ax = model.plot_anomaly(
+                time_series=test,
+                time_series_prev=train,
+                plot_time_series_prev=True,
+                plot_forecast=True,
+                plot_forecast_uncertainty=True,
+            )
+            os.makedirs(join(rootdir, "tmp", "bocpd"), exist_ok=True)
+            fig.savefig(join(rootdir, "tmp", "bocpd", f"{change}.png"))
+
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.INFO
+    )
+    unittest.main()
diff --git a/tests/change_point/test_conj_prior.py b/tests/change_point/test_conj_prior.py
new file mode 100644
index 000000000..923ff5f66
--- /dev/null
+++ b/tests/change_point/test_conj_prior.py
@@ -0,0 +1,186 @@
+#
+# Copyright (c) 2021 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+import logging
+from packaging import version
+import sys
+import unittest
+
+import numpy as np
+import scipy
+
+from merlion.utils.conj_priors import BetaBernoulli, NormInvGamma, MVNormInvWishart, BayesianLinReg, BayesianMVLinReg
+from merlion.utils.time_series import TimeSeries, UnivariateTimeSeries
+
+logger = logging.getLogger(__name__)
+
+
+class TestConjugatePriors(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        np.random.seed(12345)
+
+    def test_beta_bernoulli(self):
+        print()
+        logger.info("test_beta_bernoulli\n" + "-" * 80 + "\n")
+        for theta in [0.21, 0.5, 0.93]:
+            data = np.random.rand(1000) < theta
+            theta_hat = (1 + sum(data)) / (len(data) + 2)
+            dist_np = BetaBernoulli()
+            dist_np.update(data)
+            self.assertEqual(dist_np.alpha, 1 + sum(data))
+            self.assertEqual(dist_np.beta, 1 + sum(1 - data))
+            pred = dist_np.posterior([0, 1], log=False)
+            expected = np.asarray([1 - theta_hat, theta_hat])
+            self.assertAlmostEqual(np.max(np.abs(pred - expected)), 0, places=6)
+
+            ts = TimeSeries.from_pd(data, freq="MS")
+            dist_ts = BetaBernoulli(ts[:30])
+            dist_ts.update(ts[30:])
+            self.assertEqual(dist_ts.alpha, 1 + sum(data))
+            self.assertEqual(dist_ts.beta, 1 + sum(1 - data))
+            pred = dist_ts.posterior(TimeSeries.from_pd([0, 1]), log=False)
+            self.assertAlmostEqual(np.max(np.abs(pred - expected)), 0, places=6)
+
+    def test_normal(self):
+        print()
+        logger.info("test_normal\n" + "-" * 80 + "\n")
+        mu, sigma = 5, 2
+        for n in [10, 100, 1000, 100000]:
+            # Generate data
+            data = np.random.randn(n) * sigma + mu
+
+            # Univariate model
+            dist_uni = NormInvGamma()
+            pred_uni, dist_uni = dist_uni.posterior(data[: n // 2], return_updated=True)
+
+            # Multivariate model
+            dist_multi = MVNormInvWishart()
+            pred_multi, dist_multi = dist_multi.posterior(data[: n // 2], return_updated=True)
+
+            # Make sure univariate & multivariate posteriors agree
+            self.assertAlmostEqual(np.max(np.abs(pred_uni - pred_multi)), 0, places=6)
+
+            # Make sure univariate & multivariate posteriors agree after an additional update
+            pred_uni = dist_uni.posterior(data[n // 2 :], log=False)
+            pred_multi = dist_multi.posterior(data[n // 2 :], log=False)
+            self.assertAlmostEqual(np.max(np.abs(pred_uni - pred_multi)), 0, places=6)
+
+            # Make sure we converge to the right model after enough data
+            if n > 5000:
+                t = [0, 1, 2, 3, 4, 5]
+                xhat_u, sigma_u = dist_uni.forecast(t)
+                self.assertAlmostEqual(np.max([np.abs(np.array(x) - mu) for t, x in xhat_u]), 0, delta=0.05)
+                self.assertAlmostEqual(np.max([np.abs(np.array(s) - sigma) for t, s in sigma_u]), 0, delta=0.05)
+                xhat_m, sigma_m = dist_multi.forecast(t)
+                self.assertAlmostEqual(np.max([np.abs(np.array(x) - mu) for t, x in xhat_m]), 0, delta=0.05)
+                self.assertAlmostEqual(np.max([np.abs(np.array(s) - sigma) for t, s in sigma_m]), 0, delta=0.05)
+
+    def test_mv_normal(self):
+        print()
+        logger.info("test_mv_normal\n" + "-" * 80 + "\n")
+        n, d = 300000, 20
+        mu = np.random.randn(d)
+        u = np.random.randn(d, d)
+        cov = u.T @ u
+        data = TimeSeries.from_pd(np.random.randn(n, d) @ u + mu, freq="1h")
+        dist = MVNormInvWishart(data[:5])
+        dist.update(data[5:-5])
+        dist.posterior(data[-5:])  # make sure we can compute a posterior
+
+        # require low L1 distance between expected mean/cov and true mean/cov
+        if version.parse(scipy.__version__) >= version.parse("1.6.0"):
+
self.assertAlmostEqual(np.abs(mu - dist.mu_posterior(None).loc).mean(), 0, delta=0.05) + self.assertAlmostEqual(np.abs(cov - dist.Sigma_posterior(None).mean()).mean(), 0, delta=0.05) + + # Make sure the forecast is also accurate, i.e. the stderr-normalized MSE is close to 1 + xhat, stderr = dist.forecast(data.time_stamps[-50000:]) + zscores = (xhat.to_pd() - data[-50000:].to_pd()) / stderr.to_pd().values + self.assertAlmostEqual(zscores.pow(2).mean().max(), 1, delta=0.02) + + def test_bayesian_linreg(self): + print() + logger.info("test_bayesian_linreg\n" + "-" * 80 + "\n") + n, sigma = 100000, 1.5 + m, b = np.random.randn(2) + t = np.linspace(0, 2, 2 * n + 1) + x = UnivariateTimeSeries.from_pd(m * t + b + np.random.randn(len(t)) * sigma, name="test").to_ts() + x_train = x[: n + 1] + x_test = x[n + 1 :] + + # Make sure univariate & multivariate agree when initialized from nothing + uni = BayesianLinReg() + uni_posterior, uni = uni.posterior(x_train, return_updated=True) + multi = BayesianMVLinReg() + multi_posterior, multi = multi.posterior(x_train, return_updated=True) + self.assertAlmostEqual(np.abs(uni_posterior - multi_posterior).max(), 0, places=6) + + # Get forecasts for the test split. Make sure that the stderr-normalized MSE is close to 1. + xhat_u, sigma_u = uni.forecast(x_test.time_stamps) + zscore_u = (xhat_u.to_pd() - x_test.to_pd()) / sigma_u.to_pd().values + self.assertAlmostEqual(zscore_u.pow(2).mean().item(), 1, delta=0.01) + + # Validate the multivariate forecasting capability as well. + xhat_m, sigma_m = multi.forecast(x_test.time_stamps) + zscore_m = (xhat_m.to_pd() - x_test.to_pd()) / sigma_m.to_pd().values + self.assertAlmostEqual(zscore_m.pow(2).mean().item(), 1, delta=0.01) + + # Make sure univariate & multivariate agree after an additional update + uni_posterior = uni.posterior(x_test) + multi_posterior = multi.posterior(x_test) + self.assertAlmostEqual(np.abs(uni_posterior - multi_posterior).max(), 0, places=6) + + # Make sure explicit version agrees with naive version (univariate) + naive_uni = np.concatenate([uni.posterior(x_test[i : i + 1]) for i in range(100)]) + explicit_uni = np.concatenate([uni.posterior_explicit(x_test[i : i + 1]) for i in range(100)]) + self.assertAlmostEqual(np.abs(naive_uni - explicit_uni).max(), 0, places=6) + + # Make sure explicit version agrees with naive version (multivariate) + naive_multi = np.concatenate([multi.posterior(x_test[i : i + 1]) for i in range(100)]) + explicit_multi = np.concatenate([multi.posterior_explicit(x_test[i : i + 1]) for i in range(100)]) + self.assertAlmostEqual(np.abs(naive_multi - explicit_multi).max(), 0, places=6) + + # Make sure we're accurately estimating the slope & intercept + mhat, bhat = uni.w_0 + self.assertAlmostEqual(mhat, m, delta=0.02) + self.assertAlmostEqual(bhat, b, delta=0.01) + + def test_mv_bayesian_linreg(self): + print() + logger.info("test_mv_bayesian_linreg\n" + "-" * 80 + "\n") + n, sigma = 200000, 2 + for d in [2, 3, 4, 5, 10, 20]: + m, b = np.random.randn(2, d) + t = np.linspace(0, 2, 2 * n + 1) + x = m.reshape(1, d) * t.reshape(-1, 1) + b.reshape(1, d) + np.random.randn(len(t), d) * sigma + x_train = x[: n + 1] + x_test = x[n + 1 :] + + dist = BayesianMVLinReg() + dist.update(x_train) + post = dist.posterior(x_test) # make sure we can compute a multivariate posterior PDF + self.assertEqual(post.shape, (n,)) + + naive = np.concatenate([dist.posterior(x_test[i : i + 1]) for i in range(100)]) + explicit = np.concatenate([dist.posterior_explicit(x_test[i : i + 1]) for i in 
range(100)]) + self.assertAlmostEqual(np.abs(naive - explicit).max(), 0, delta=0.01) + + # Make sure we're accurately estimating the slope & intercept after all this data + mhat, bhat = dist.w_0 + self.assertAlmostEqual(np.abs(mhat - m).max(), 0, delta=0.05) + self.assertAlmostEqual(np.abs(bhat - b).max(), 0, delta=0.05) + + # Make sure the forecast is also accurate, i.e. the stderr-normalized MSE is close to 1 + xhat, stderr = dist.forecast(np.arange(n + 1, 2 * n + 1)) + zscores = (xhat.to_pd() - x_test) / stderr.to_pd().values + self.assertAlmostEqual(zscores.pow(2).mean().max(), 1, delta=0.02) + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.INFO + ) + unittest.main() diff --git a/tests/forecast/test_MoE_forecast_ensemble.py b/tests/forecast/test_MoE_forecast_ensemble.py index 53800fb13..4af8867f7 100644 --- a/tests/forecast/test_MoE_forecast_ensemble.py +++ b/tests/forecast/test_MoE_forecast_ensemble.py @@ -9,7 +9,9 @@ import sys import unittest +import numpy as np import pandas as pd +import torch from merlion.models.ensemble.MoE_forecast import * from merlion.models.forecast.arima import Arima @@ -840,6 +842,8 @@ def test_full(self): print("-" * 80) logger.info("test_full\n" + "-" * 80 + "\n") logger.info("Training model...") + np.random.seed(42) + torch.random.manual_seed(42) self.ensemble.train(self.train_data, train_config=self.train_config_ensemble) # extract a chunk of test data for unit tests