From 5b66569fd5722098335543c468ac15533469ede6 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 22 Jan 2024 19:24:03 +0100 Subject: [PATCH] update mojo import --- .../pyunit_mojo_import.py | 114 +++++++++--------- 1 file changed, 59 insertions(+), 55 deletions(-) diff --git a/h2o-py/tests/testdir_generic_model/pyunit_mojo_import.py b/h2o-py/tests/testdir_generic_model/pyunit_mojo_import.py index 4e95f0d6b2cf..a59aa8630740 100644 --- a/h2o-py/tests/testdir_generic_model/pyunit_mojo_import.py +++ b/h2o-py/tests/testdir_generic_model/pyunit_mojo_import.py @@ -1,4 +1,5 @@ import h2o +import unittest import tempfile from h2o.estimators import H2OGradientBoostingEstimator, H2OGenericEstimator from tests import pyunit_utils @@ -6,60 +7,63 @@ from pandas.testing import assert_frame_equal -# Test of MOJO convenience methods -def mojo_convenience(): - # Train a model - airlines = h2o.import_file(path=pyunit_utils.locate("smalldata/testng/airlines_train.csv")) - model = H2OGradientBoostingEstimator(ntrees = 1) - model.train(x = ["Origin", "Dest"], y = "IsDepDelayed", training_frame=airlines) - - #Save the previously created model into a temporary file - original_model_filename = tempfile.mkdtemp() - original_model_filename = model.save_mojo(original_model_filename) - - # Load the model from the temporary file - mojo_model = h2o.import_mojo(original_model_filename) - assert isinstance(mojo_model, H2OGenericEstimator) - - # Test scoring is available on the model - predictions = mojo_model.predict(airlines) - assert predictions is not None - assert predictions.nrows == 24421 +class TestMojoImport(unittest.TestCase): + # Test of MOJO convenience methods + def mojo_convenience(self): + try: + h2o.init(strict_version_check=False, jvm_custom_args=["-Dsys.ai.h2o.pojo.import.enabled=true", ]) + # Train a model + airlines = h2o.import_file(path=pyunit_utils.locate("smalldata/testng/airlines_train.csv")) + model = H2OGradientBoostingEstimator(ntrees = 1) + model.train(x = ["Origin", "Dest"], y = "IsDepDelayed", training_frame=airlines) + + #Save the previously created model into a temporary file + original_model_filename = tempfile.mkdtemp() + original_model_filename = model.save_mojo(original_model_filename) + + # Load the model from the temporary file + mojo_model = h2o.import_mojo(original_model_filename) + assert isinstance(mojo_model, H2OGenericEstimator) + + # Test scoring is available on the model + predictions = mojo_model.predict(airlines) + assert predictions is not None + assert predictions.nrows == 24421 + + ##### + # MOJO UPLOAD TEST + ##### + + try: + pyunit_utils.set_forbidden_paths([original_model_filename]) + # Download the MOJO + original_model_filename = model.download_mojo(original_model_filename) + # Load the model from the temporary file + mojo_model = h2o.upload_mojo(original_model_filename) + assert isinstance(mojo_model, H2OGenericEstimator) + + # Test scoring is available on the model + predictions = mojo_model.predict(airlines) + assert predictions is not None + assert predictions.nrows == 24421 + finally: + pyunit_utils.clear_forbidden_paths() + + ##### + # MOJO to POJO Conversion test with POJO re-import + ##### + + pojo_directory = os.path.join(pyunit_utils.locate("results"), model.model_id + ".java") + pojo_path = model.download_pojo(path = pojo_directory) + mojo2_model = h2o.import_mojo(pojo_path) + + predictions2 = mojo2_model.predict(airlines) + assert predictions2 is not None + assert predictions2.nrows == 24421 + assert_frame_equal(predictions.as_data_frame(), predictions2.as_data_frame()) + finally: + h2o.cluster().shutdown() - ##### - # MOJO UPLOAD TEST - ##### - try: - pyunit_utils.set_forbidden_paths([original_model_filename]) - # Download the MOJO - original_model_filename = model.download_mojo(original_model_filename) - # Load the model from the temporary file - mojo_model = h2o.upload_mojo(original_model_filename) - assert isinstance(mojo_model, H2OGenericEstimator) - - # Test scoring is available on the model - predictions = mojo_model.predict(airlines) - assert predictions is not None - assert predictions.nrows == 24421 - finally: - pyunit_utils.clear_forbidden_paths() - - ##### - # MOJO to POJO Conversion test with POJO re-import - ##### - - pojo_directory = os.path.join(pyunit_utils.locate("results"), model.model_id + ".java") - pojo_path = model.download_pojo(path = pojo_directory) - mojo2_model = h2o.import_mojo(pojo_path) - - predictions2 = mojo2_model.predict(airlines) - assert predictions2 is not None - assert predictions2.nrows == 24421 - assert_frame_equal(predictions.as_data_frame(), predictions2.as_data_frame()) - - -if __name__ == "__main__": - pyunit_utils.standalone_test(mojo_convenience) -else: - mojo_convenience() +suite = unittest.TestLoader().loadTestsFromTestCase(TestMojoImport) +unittest.TextTestRunner().run(suite)