Skip to content

Commit

Permalink
check insufficient predict data
Browse files Browse the repository at this point in the history
  • Loading branch information
juaduan committed Jan 26, 2024
1 parent 2b2f88e commit 42566bf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -402,5 +402,6 @@ mlartifacts/
model/
venv/
dist/
temp/

.idea/
11 changes: 7 additions & 4 deletions anomaly-detector/anomaly_detector/multivariate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self):
self.variables: Optional[List[str]] = None

def fit(self, data: pd.DataFrame, params: Dict[str, Any] = None) -> None:
- variables, values, config = self._verify_data_and_params(data, params)
+ variables, values, config = self._process_data_and_params(data, params)
self.config = config
self.variables = variables
_window = self.config.threshold_window + self.config.input_size
Expand Down Expand Up @@ -211,14 +211,17 @@ def evaluate_epoch(self, model, criterion, dataloader):
def predict(
self, context, data: pd.DataFrame, params: Optional[Dict[str, Any]] = None
):
- variables, values, _ = self._verify_data_and_params(data, params)
+ variables, values, _ = self._process_data_and_params(data, params)

if self.model is None:
try:
self.load_checkpoint(self.model_path)
except Exception as ex:
raise ValueError(f"Cannot load model. Please train model. {repr(ex)}")

+ if len(values) < self.config.threshold_window + self.config.input_size:
+ raise ValueError(
+ f"Not enough data. Minimum size is {self.config.threshold_window + self.config.input_size}"
+ )
hard_th_upper = max(MultiADConstants.ANOMALY_UPPER_THRESHOLD, self.threshold)
hard_th_lower = min(MultiADConstants.ANOMALY_LOWER_THRESHOLD, self.threshold)
torch.manual_seed(self.config.seed)
Expand Down Expand Up @@ -382,7 +385,7 @@ def compute_thresholds(self, data) -> (float, float, float):
train_score_max = np.max(train_scores)
return threshold, train_score_max, train_score_min

- def _verify_data_and_params(self, data: pd.DataFrame, params: Dict[str, Any] = None):
+ def _process_data_and_params(self, data: pd.DataFrame, params: Dict[str, Any] = None):
if params is None:
params = {}

Expand Down

0 comments on commit 42566bf

Please sign in to comment.