diff --git a/h2o-algos/src/main/java/hex/generic/Generic.java b/h2o-algos/src/main/java/hex/generic/Generic.java index 3d819acc1737..58ab25a081be 100644 --- a/h2o-algos/src/main/java/hex/generic/Generic.java +++ b/h2o-algos/src/main/java/hex/generic/Generic.java @@ -107,8 +107,13 @@ public void computeImpl() { if (ZipUtil.isCompressed(modelBytes)) { genericModel = importMojo(modelBytes, dataKey); } else { - warn("_path", "Trying to import a POJO model - this is currently an experimental feature."); - genericModel = importPojo(modelBytes, dataKey, _result.toString()); + if (H2O.getSysBoolProperty("pojo.import.enabled", false)) { + warn("_path", "Trying to import a POJO model - this is currently an experimental feature."); + genericModel = importPojo(modelBytes, dataKey, _result.toString()); + } else { + throw new SecurityException("POJO import is disabled since it brings a security risk. " + + "To enable the feature, set the java property `sys.ai.h2o.pojo.import.enabled` to true."); + } } genericModel.write_lock(_job); genericModel.unlock(_job); diff --git a/h2o-py/tests/testdir_algos/gbm/pyunit_gbm_pojo_import.py b/h2o-py/tests/testdir_algos/gbm/pyunit_gbm_pojo_import.py index 20b257653c3b..9a33263d0641 100644 --- a/h2o-py/tests/testdir_algos/gbm/pyunit_gbm_pojo_import.py +++ b/h2o-py/tests/testdir_algos/gbm/pyunit_gbm_pojo_import.py @@ -35,7 +35,4 @@ def prostate_pojo_import(): assert_frame_equal(pdp_original[0].as_data_frame(), pdp_imported[0].as_data_frame()) -if __name__ == "__main__": - pyunit_utils.standalone_test(prostate_pojo_import) -else: - prostate_pojo_import() +pyunit_utils.standalone_test(prostate_pojo_import, {"jvm_custom_args": ["-Dsys.ai.h2o.pojo.import.enabled=true", ]}) diff --git a/h2o-py/tests/testdir_generic_model/pyunit_pojo_import.py b/h2o-py/tests/testdir_generic_model/pyunit_combined_pojo_import.py similarity index 98% rename from h2o-py/tests/testdir_generic_model/pyunit_pojo_import.py rename to h2o-py/tests/testdir_generic_model/pyunit_combined_pojo_import.py index fe012e630a25..2f6f5dc028c4 100644 --- a/h2o-py/tests/testdir_generic_model/pyunit_pojo_import.py +++ b/h2o-py/tests/testdir_generic_model/pyunit_combined_pojo_import.py @@ -341,7 +341,8 @@ def generate_and_import_combined_pojo(): assert_frame_equal(pojo_weather_cwd_preds.as_data_frame(), expected.as_data_frame()) -if __name__ == "__main__": - pyunit_utils.standalone_test(generate_and_import_combined_pojo) -else: - generate_and_import_combined_pojo() +pyunit_utils.standalone_test( + generate_and_import_combined_pojo, + {"jvm_custom_args": ["-Dsys.ai.h2o.pojo.import.enabled=true", ]} +) + diff --git a/h2o-py/tests/testdir_generic_model/pyunit_pojo_import_disabled.py b/h2o-py/tests/testdir_generic_model/pyunit_pojo_import_disabled.py new file mode 100644 index 000000000000..bf3b2227ef0f --- /dev/null +++ b/h2o-py/tests/testdir_generic_model/pyunit_pojo_import_disabled.py @@ -0,0 +1,33 @@ +import os +import sys +import unittest +import h2o +from h2o.backend import H2OLocalServer +from h2o.estimators import H2OGradientBoostingEstimator +from tests import pyunit_utils + +sys.path.insert(1,"../../") + +class TestJavaImportDisabled(unittest.TestCase): + def test(self): + try: + h2o.init(strict_version_check=False) + airlines = h2o.import_file(path=pyunit_utils.locate("smalldata/testng/airlines_train.csv")) + gbm = H2OGradientBoostingEstimator(ntrees=1, nfolds=3) + gbm.train(x=["Origin", "Dest"], y="IsDepDelayed", training_frame=airlines, validation_frame=airlines) + + pojo_path = gbm.download_pojo(path=os.path.join(pyunit_utils.locate("results"), gbm.model_id + ".java")) + + with self.assertRaises(OSError) as err: + h2o.import_mojo(pojo_path) + assert "POJO import is disabled since it brings a security risk." in str(err.exception) + + with self.assertRaises(OSError) as err: + h2o.upload_mojo(pojo_path) + assert "POJO import is disabled since it brings a security risk." in str(err.exception) + finally: + h2o.cluster().shutdown() + + +suite = unittest.TestLoader().loadTestsFromTestCase(TestJavaImportDisabled) +unittest.TextTestRunner().run(suite)