Skip to content

Commit

Permalink
fix model load
Browse files Browse the repository at this point in the history
  • Loading branch information
Juanyong Duan committed Jan 30, 2025
1 parent fa1c09b commit fc8438a
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import pickle

import torch
import pytest
import pandas as pd
import yaml
Expand Down Expand Up @@ -42,17 +41,15 @@ def test_normal_predict(self):
eval_data[-1, :] += 100
columns = [f"variable_{i}" for i in range(20)]
eval_data = pd.DataFrame(eval_data, columns=columns)
with open(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), 'rb') as f:
loaded_model = pickle.load(f)
loaded_model.predict(eval_data)
loaded_model = torch.load(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), weights_only=True)
loaded_model.predict(eval_data)

def test_inference_data_smaller_than_window(self):
with pytest.raises(ValueError):
eval_data = pd.read_csv(os.path.join(WORKING_DIR, TEST_FILE_ROOT, "inference_data_smaller_than_window.csv"))
eval_data = eval_data.set_index("timestamp", drop=True)
with open(os.path.join(WORKING_DIR, TEST_FILE_ROOT,'model.pkl'), 'rb') as f:
loaded_model = pickle.load(f)
loaded_model.predict(eval_data)
loaded_model = torch.load(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), weights_only=True)
loaded_model.predict(eval_data)

def test_invalid_fillna_config(self):
with pytest.raises(InvalidParameterError):
Expand Down

0 comments on commit fc8438a

Please sign in to comment.