From fc8438aa3b47005e5cd1b26abb0be17dea132f57 Mon Sep 17 00:00:00 2001 From: Juanyong Duan Date: Thu, 30 Jan 2025 21:56:29 +0800 Subject: [PATCH] fix model load --- tests/test_demo.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/test_demo.py b/tests/test_demo.py index 2972789..f6fe626 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -2,8 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -import pickle - +import torch import pytest import pandas as pd import yaml @@ -42,17 +41,15 @@ def test_normal_predict(self): eval_data[-1, :] += 100 columns = [f"variable_{i}" for i in range(20)] eval_data = pd.DataFrame(eval_data, columns=columns) - with open(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), 'rb') as f: - loaded_model = pickle.load(f) - loaded_model.predict(eval_data) + loaded_model = torch.load(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), weights_only=True) + loaded_model.predict(eval_data) def test_inference_data_smaller_than_window(self): with pytest.raises(ValueError): eval_data = pd.read_csv(os.path.join(WORKING_DIR, TEST_FILE_ROOT, "inference_data_smaller_than_window.csv")) eval_data = eval_data.set_index("timestamp", drop=True) - with open(os.path.join(WORKING_DIR, TEST_FILE_ROOT,'model.pkl'), 'rb') as f: - loaded_model = pickle.load(f) - loaded_model.predict(eval_data) + loaded_model = torch.load(os.path.join(WORKING_DIR, TEST_FILE_ROOT, 'model.pkl'), weights_only=True) + loaded_model.predict(eval_data) def test_invalid_fillna_config(self): with pytest.raises(InvalidParameterError):