Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: apache/beam
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 7d224c0deb1f23e129653d53d2434f5bac0b6587
Choose a base ref
..
head repository: apache/beam
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 4cb5c1fb7bfd5029d4eaf9e8e8ba02146489e620
Choose a head ref
25 changes: 23 additions & 2 deletions sdks/python/apache_beam/ml/anomaly/specifiable.py
Original file line number Diff line number Diff line change
@@ -391,8 +391,17 @@ def new_getattr(self, name):
Specifiable.register(cls)

class_name = cls.__name__
original_init = cls.__init__
cls._original_init = cls.__init__ # type: ignore[misc]

# if a class has `<class_name>__original_init` function, it means it has
# already been registered as specifiable. In this case, we skip the
# following modification.
if hasattr(cls, cls.__name__ + '__original_init'):
return cls

original_init = cls.__init__ # type: ignore[misc]
# saved the class init function in a name that won't conflict with its
# parents.
setattr(cls, cls.__name__ + '__original_init', cls.__init__) # type: ignore[misc]
cls.__init__ = new_init # type: ignore[misc]
if just_in_time_init:
cls.__getattr__ = new_getattr
@@ -416,3 +425,15 @@ def new_getattr(self, name):
# When this decorator is called without an argument, i.e. "@specifiable",
# we return the augmented class.
return _wrapper(my_cls)


def unspecifiable(cls) -> None:
cls.__init__ = getattr(cls, cls.__name__ + '__original_init')
cls.__getattr__ = None
delattr(cls, cls.__name__ + '__original_init')
delattr(cls, 'spec_type')
delattr(cls, 'run_original_init')
delattr(cls, 'to_spec')
delattr(cls, '_to_spec_helper')
delattr(cls, 'from_spec')
delattr(cls, '_from_spec_helper')
13 changes: 13 additions & 0 deletions sdks/python/apache_beam/ml/anomaly/specifiable_test.py
Original file line number Diff line number Diff line change
@@ -585,6 +585,19 @@ def apply(self, x, y):
self.assertEqual(w_2.run_func_in_class(5, 3), 150)


class TestUncommonUsages(unittest.TestCase):
def test_double_specifiable(self):
@specifiable
@specifiable
class ZZ():
def __init__(self, a):
self.a = a

c = ZZ("b")
c.run_original_init()
self.assertEqual(c.a, "b")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
13 changes: 6 additions & 7 deletions sdks/python/apache_beam/ml/anomaly/transforms_test.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,6 @@
# limitations under the License.
#

import copy
import logging
import math
import os
@@ -46,6 +45,7 @@
from apache_beam.ml.anomaly.specifiable import Specifiable
from apache_beam.ml.anomaly.specifiable import _spec_type_to_subspace
from apache_beam.ml.anomaly.specifiable import specifiable
from apache_beam.ml.anomaly.specifiable import unspecifiable
from apache_beam.ml.anomaly.thresholds import FixedThreshold
from apache_beam.ml.anomaly.thresholds import QuantileThreshold
from apache_beam.ml.anomaly.transforms import AnomalyDetection
@@ -318,12 +318,11 @@ def setUp(self):
self.tmpdir = tempfile.mkdtemp()

def tearDown(self):
global SklearnModelHandlerNumpy, KeyedModelHandler
global _PreProcessingModelHandler, _PostProcessingModelHandler
SklearnModelHandlerNumpy.__init__ = SklearnModelHandlerNumpy._original_init
KeyedModelHandler.__init__ = KeyedModelHandler._original_init
_PreProcessingModelHandler.__init__ = _PreProcessingModelHandler._original_init
_PostProcessingModelHandler.__init__ = _PostProcessingModelHandler._original_init
# Make the model handlers back to normal
unspecifiable(SklearnModelHandlerNumpy)
unspecifiable(KeyedModelHandler)
unspecifiable(_PreProcessingModelHandler)
unspecifiable(_PostProcessingModelHandler)
shutil.rmtree(self.tmpdir)

def test_default_inference_fn(self):