-
-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat, perf: Refactor the PoC to support multiple dtypes (#757)
Co-authored-by: Mateusz Sokół <[email protected]>
- Loading branch information
1 parent
41159c0
commit 1c56a0b
Showing
8 changed files
with
242 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import abc | ||
import functools | ||
|
||
from mlir import ir | ||
|
||
|
||
class MlirType(abc.ABC): | ||
@classmethod | ||
@abc.abstractmethod | ||
def get_mlir_type(cls) -> ir.Type: ... | ||
|
||
|
||
def fn_cache(f, maxsize: int | None = None): | ||
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,119 @@ | ||
import inspect | ||
import math | ||
import sys | ||
import typing | ||
|
||
from mlir import ir | ||
|
||
import numpy as np | ||
|
||
from ._common import MlirType | ||
|
||
|
||
def _get_pointer_width() -> int: | ||
return round(math.log2(sys.maxsize + 1.0)) + 1 | ||
|
||
|
||
_PTR_WIDTH = _get_pointer_width() | ||
|
||
|
||
def _make_int_classes(namespace: dict[str, object], bit_widths: typing.Iterable[int]) -> None: | ||
for bw in bit_widths: | ||
|
||
class SignedBW(SignedIntegerDType): | ||
np_dtype = getattr(np, f"int{bw}") | ||
bit_width = bw | ||
|
||
@classmethod | ||
def get_mlir_type(cls): | ||
return ir.IntegerType.get_signless(cls.bit_width) | ||
|
||
SignedBW.__name__ = f"Int{bw}" | ||
SignedBW.__module__ = __name__ | ||
|
||
class UnsignedBW(UnsignedIntegerDType): | ||
np_dtype = getattr(np, f"uint{bw}") | ||
bit_width = bw | ||
|
||
@classmethod | ||
def get_mlir_type(cls): | ||
return ir.IntegerType.get_signless(cls.bit_width) | ||
|
||
UnsignedBW.__name__ = f"UInt{bw}" | ||
UnsignedBW.__module__ = __name__ | ||
|
||
namespace[SignedBW.__name__] = SignedBW | ||
namespace[UnsignedBW.__name__] = UnsignedBW | ||
|
||
class DType: | ||
pass | ||
|
||
class DType(MlirType): | ||
np_dtype: np.dtype | ||
bit_width: int | ||
|
||
class Float64(DType): | ||
|
||
class FloatingDType(DType): ... | ||
|
||
|
||
class Float64(FloatingDType): | ||
np_dtype = np.float64 | ||
bit_width = 64 | ||
|
||
@classmethod | ||
def get(cls): | ||
def get_mlir_type(cls): | ||
return ir.F64Type.get() | ||
|
||
|
||
class Float32(DType): | ||
class Float32(FloatingDType): | ||
np_dtype = np.float32 | ||
bit_width = 32 | ||
|
||
@classmethod | ||
def get(cls): | ||
def get_mlir_type(cls): | ||
return ir.F32Type.get() | ||
|
||
|
||
class Int64(DType): | ||
np_dtype = np.int64 | ||
class Float16(FloatingDType): | ||
np_dtype = np.float16 | ||
bit_width = 16 | ||
|
||
@classmethod | ||
def get(cls): | ||
return ir.IntegerType.get_signed(64) | ||
def get_mlir_type(cls): | ||
return ir.F16Type.get() | ||
|
||
|
||
class UInt64(DType): | ||
np_dtype = np.uint64 | ||
class IntegerDType(DType): ... | ||
|
||
@classmethod | ||
def get(cls): | ||
return ir.IntegerType.get_unsigned(64) | ||
|
||
class UnsignedIntegerDType(IntegerDType): ... | ||
|
||
class Int32(DType): | ||
np_dtype = np.int32 | ||
|
||
@classmethod | ||
def get(cls): | ||
return ir.IntegerType.get_signed(32) | ||
|
||
class SignedIntegerDType(IntegerDType): ... | ||
|
||
class UInt32(DType): | ||
np_dtype = np.uint32 | ||
|
||
@classmethod | ||
def get(cls): | ||
return ir.IntegerType.get_unsigned(32) | ||
_make_int_classes(locals(), [8, 16, 32, 64]) | ||
|
||
|
||
class Index(DType): | ||
np_dtype = np.intp | ||
|
||
@classmethod | ||
def get(cls): | ||
def get_mlir_type(cls): | ||
return ir.IndexType.get() | ||
|
||
|
||
class SignlessInt64(DType): | ||
np_dtype = np.int64 | ||
IntP: type[SignedIntegerDType] = locals()[f"Int{_PTR_WIDTH}"] | ||
UIntP: type[UnsignedIntegerDType] = locals()[f"UInt{_PTR_WIDTH}"] | ||
|
||
@classmethod | ||
def get(cls): | ||
return ir.IntegerType.get_signless(64) | ||
|
||
def isdtype(dt, /) -> bool: | ||
return isinstance(dt, type) and issubclass(dt, DType) and not inspect.isabstract(dt) | ||
|
||
|
||
NUMPY_DTYPE_MAP = {np.dtype(dt.np_dtype): dt for dt in locals().values() if isdtype(dt)} | ||
|
||
|
||
def asdtype(dt, /) -> type[DType]: | ||
if isdtype(dt): | ||
return dt | ||
|
||
return NUMPY_DTYPE_MAP[np.dtype(dt)] |
Oops, something went wrong.