Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AnomalyDetection] Better type hinting in specifiable. #34310

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions sdks/python/apache_beam/ml/anomaly/specifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,22 @@

from __future__ import annotations

import abc
import collections
import dataclasses
import inspect
import logging
import os
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Type
from typing import TypeVar
from typing import Union
from typing import runtime_checkable
from typing import overload

from typing_extensions import Self

Expand All @@ -59,7 +60,7 @@
#: `spec_type` when applying the `specifiable` decorator to an existing class.
_KNOWN_SPECIFIABLE = collections.defaultdict(dict)

SpecT = TypeVar('SpecT', bound='Specifiable')
T = TypeVar('T', bound=type)


def _class_to_subspace(cls: Type) -> str:
Expand Down Expand Up @@ -104,8 +105,7 @@ class Spec():
config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)


@runtime_checkable
class Specifiable(Protocol):
class Specifiable(abc.ABC):
"""Protocol that a specifiable class needs to implement."""
#: The value of the `type` field in the object's spec for this class.
spec_type: ClassVar[str]
Expand All @@ -130,7 +130,9 @@ def _from_spec_helper(v, _run_init):
return v

@classmethod
def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
def from_spec(cls,
spec: Spec,
_run_init: bool = True) -> Union[Self, type[Self]]:
"""Generate a `Specifiable` subclass object based on a spec.

Args:
Expand Down Expand Up @@ -250,13 +252,35 @@ def _get_init_kwargs(inst, init_method, *args, **kwargs):
return params


@overload
def specifiable(
my_cls=None,
my_cls: None = None,
/,
*,
spec_type=None,
on_demand_init=True,
just_in_time_init=True):
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> Callable[[T], T]:
...


@overload
def specifiable(
my_cls: T,
/,
*,
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> T:
...


def specifiable(
my_cls: Optional[T] = None,
/,
*,
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> Union[T, Callable[[T], T]]:
"""A decorator that turns a class into a `Specifiable` subclass by
implementing the `Specifiable` protocol.

Expand Down Expand Up @@ -285,8 +309,8 @@ class Bar():
original `__init__` method will be called when the first time an attribute
is accessed.
"""
def _wrapper(cls):
def new_init(self: Specifiable, *args, **kwargs):
def _wrapper(cls: T) -> T:
def new_init(self, *args, **kwargs):
self._initialized = False
self._in_init = False

Expand Down Expand Up @@ -361,9 +385,14 @@ def new_getattr(self, name):
# start of the function body of _wrapper
_register(cls, spec_type)

# register the original class as a virtual subclass of Specifiable
# so issubclass(cls, Specifiable) and isinstance(cls(), Specifiable) are
# true
Specifiable.register(cls)

class_name = cls.__name__
original_init = cls.__init__
cls.__init__ = new_init
original_init = cls.__init__ # type: ignore[misc]
cls.__init__ = new_init # type: ignore[misc]
if just_in_time_init:
cls.__getattr__ = new_getattr

Expand Down
Loading