Skip to content

Commit

Permalink
feat, perf: Refactor the PoC to support multiple dtypes (#757)
Browse files Browse the repository at this point in the history
Co-authored-by: Mateusz Sokół <[email protected]>
  • Loading branch information
hameerabbasi and mtsokol authored Aug 27, 2024
1 parent 41159c0 commit 1c56a0b
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 94 deletions.
14 changes: 14 additions & 0 deletions sparse/mlir_backend/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import abc
import functools

from mlir import ir


class MlirType(abc.ABC):
@classmethod
@abc.abstractmethod
def get_mlir_type(cls) -> ir.Type: ...


def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
91 changes: 54 additions & 37 deletions sparse/mlir_backend/_constructors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ctypes
import ctypes.util
import functools
import weakref

import mlir.execution_engine
import mlir.passmanager
Expand All @@ -9,9 +11,26 @@
import numpy as np
import scipy.sparse as sps

from ._core import DEBUG, MLIR_C_RUNNER_UTILS, SCRIPT_PATH, ctx
from ._dtypes import DType, Float64, Index
from ._memref import MemrefF64_1D, MemrefIdx_1D
from ._common import fn_cache
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx
from ._dtypes import DType, Index, asdtype
from ._memref import make_memref_ctype, ranked_memref_from_np


def _hold_self_ref_in_ret(fn):
@functools.wraps(fn)
def wrapped(self, *a, **kw):
ptr = ctypes.py_object(self)
ctypes.pythonapi.Py_IncRef(ptr)
ret = fn(self, *a, **kw)

def finalizer(ptr):
ctypes.pythonapi.Py_DecRef(ptr)

weakref.finalize(ret, finalizer, ptr)
return ret

return wrapped


class Tensor:
Expand All @@ -26,21 +45,21 @@ def __init__(self, obj, module, tensor_type, disassemble_fn, values_dtype, index
def __del__(self):
self.module.invoke("free_tensor", ctypes.pointer(self.obj))

@_hold_self_ref_in_ret
def to_scipy_sparse(self):
"""
Returns scipy.sparse or ndarray
"""
return self.disassemble_fn(self.module, self.obj)
return self.disassemble_fn(self.module, self.obj, self.values_dtype)


class DenseFormat:
modules = {}

@fn_cache
def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
with ir.Location.unknown(ctx):
module = ir.Module.create()
values_dtype = values_dtype.get()
index_dtype = index_dtype.get()
values_dtype = values_dtype.get_mlir_type()
index_dtype = index_dtype.get_mlir_type()
index_width = getattr(index_dtype, "width", 0)
levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.dense)
ordering = ir.AffineMap.get_permutation([0, 1])
Expand Down Expand Up @@ -78,18 +97,19 @@ def free_tensor(tensor_shaped):
disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(SCRIPT_PATH / "dense_module.mlir").write_text(str(module))
(CWD / "dense_module.mlir").write_text(str(module))
pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
pm.run(module.operation)
if DEBUG:
(SCRIPT_PATH / "dense_module_opt.mlir").write_text(str(module))
(CWD / "dense_module_opt.mlir").write_text(str(module))

module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
return (module, dense_shaped)

@classmethod
def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
data = MemrefF64_1D.from_numpy(arr.flatten())
assert arr.ndim == 2
data = ranked_memref_from_np(arr.flatten())
out = ctypes.c_void_p()
module.invoke(
"assemble",
Expand All @@ -99,18 +119,18 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
return out

@classmethod
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> np.ndarray:
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> np.ndarray:
class Dense(ctypes.Structure):
_fields_ = [
("data", MemrefF64_1D),
("data", make_memref_ctype(dtype, 1)),
("data_len", np.ctypeslib.c_intp),
("shape_x", np.ctypeslib.c_intp),
("shape_y", np.ctypeslib.c_intp),
]

def to_np(self) -> np.ndarray:
data = self.data.to_numpy()[: self.data_len]
return data.copy().reshape((self.shape_x, self.shape_y))
return data.reshape((self.shape_x, self.shape_y))

arr = Dense()
module.invoke(
Expand All @@ -122,18 +142,17 @@ def to_np(self) -> np.ndarray:


class COOFormat:
modules = {}
# TODO: implement
...


class CSRFormat:
modules = {}

def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
@fn_cache
def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[DType]):
with ir.Location.unknown(ctx):
module = ir.Module.create()
values_dtype = values_dtype.get()
index_dtype = index_dtype.get()
values_dtype = values_dtype.get_mlir_type()
index_dtype = index_dtype.get_mlir_type()
index_width = getattr(index_dtype, "width", 0)
levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.compressed)
ordering = ir.AffineMap.get_permutation([0, 1])
Expand Down Expand Up @@ -175,11 +194,11 @@ def free_tensor(tensor_shaped):
disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(SCRIPT_PATH / "scr_module.mlir").write_text(str(module))
(CWD / "csr_module.mlir").write_text(str(module))
pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
pm.run(module.operation)
if DEBUG:
(SCRIPT_PATH / "csr_module_opt.mlir").write_text(str(module))
(CWD / "csr_module_opt.mlir").write_text(str(module))

module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
return (module, csr_shaped)
Expand All @@ -189,20 +208,20 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
out = ctypes.c_void_p()
module.invoke(
"assemble",
ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indptr))),
ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indices))),
ctypes.pointer(ctypes.pointer(MemrefF64_1D.from_numpy(arr.data))),
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indptr))),
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indices))),
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.data))),
ctypes.pointer(out),
)
return out

@classmethod
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> sps.csr_array:
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.csr_array:
class Csr(ctypes.Structure):
_fields_ = [
("data", MemrefF64_1D),
("pos", MemrefIdx_1D),
("crd", MemrefIdx_1D),
("data", make_memref_ctype(dtype, 1)),
("pos", make_memref_ctype(Index, 1)),
("crd", make_memref_ctype(Index, 1)),
("data_len", np.ctypeslib.c_intp),
("pos_len", np.ctypeslib.c_intp),
("crd_len", np.ctypeslib.c_intp),
Expand All @@ -214,7 +233,7 @@ def to_sps(self) -> sps.csr_array:
pos = self.pos.to_numpy()[: self.pos_len]
crd = self.crd.to_numpy()[: self.crd_len]
data = self.data.to_numpy()[: self.data_len]
return sps.csr_array((data.copy(), crd.copy(), pos.copy()), shape=(self.shape_x, self.shape_y))
return sps.csr_array((data, crd, pos), shape=(self.shape_x, self.shape_y))

arr = Csr()
module.invoke(
Expand All @@ -235,23 +254,21 @@ def _is_numpy_obj(x) -> bool:

def asarray(obj) -> Tensor:
# TODO: discover obj's dtype
values_dtype = Float64
index_dtype = Index
values_dtype = asdtype(obj.dtype)

# TODO: support other scipy formats
if _is_scipy_sparse_obj(obj):
format_class = CSRFormat
# This can be int32 or int64
index_dtype = asdtype(obj.indptr.dtype)
elif _is_numpy_obj(obj):
format_class = DenseFormat
index_dtype = Index
else:
raise Exception(f"{type(obj)} not supported.")

# TODO: support proper caching
if hash(obj.shape) in format_class.modules:
module, tensor_type = format_class.modules[hash(obj.shape)]
else:
module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)
format_class.modules[hash(obj.shape)] = module, tensor_type
module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)

assembled_obj = format_class.assemble(module, obj)
return Tensor(assembled_obj, module, tensor_type, format_class.disassemble, values_dtype, index_dtype)
2 changes: 1 addition & 1 deletion sparse/mlir_backend/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mlir.ir import Context

DEBUG = bool(int(os.environ.get("DEBUG", "0")))
SCRIPT_PATH = pathlib.Path(__file__).parent
CWD = pathlib.Path(".")

MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")
libc = ctypes.CDLL(ctypes.util.find_library("c")) if os.name != "nt" else ctypes.cdll.msvcrt
Expand Down
112 changes: 80 additions & 32 deletions sparse/mlir_backend/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,119 @@
import inspect
import math
import sys
import typing

from mlir import ir

import numpy as np

from ._common import MlirType


def _get_pointer_width() -> int:
return round(math.log2(sys.maxsize + 1.0)) + 1


_PTR_WIDTH = _get_pointer_width()


def _make_int_classes(namespace: dict[str, object], bit_widths: typing.Iterable[int]) -> None:
for bw in bit_widths:

class SignedBW(SignedIntegerDType):
np_dtype = getattr(np, f"int{bw}")
bit_width = bw

@classmethod
def get_mlir_type(cls):
return ir.IntegerType.get_signless(cls.bit_width)

SignedBW.__name__ = f"Int{bw}"
SignedBW.__module__ = __name__

class UnsignedBW(UnsignedIntegerDType):
np_dtype = getattr(np, f"uint{bw}")
bit_width = bw

@classmethod
def get_mlir_type(cls):
return ir.IntegerType.get_signless(cls.bit_width)

UnsignedBW.__name__ = f"UInt{bw}"
UnsignedBW.__module__ = __name__

namespace[SignedBW.__name__] = SignedBW
namespace[UnsignedBW.__name__] = UnsignedBW

class DType:
pass

class DType(MlirType):
np_dtype: np.dtype
bit_width: int

class Float64(DType):

class FloatingDType(DType): ...


class Float64(FloatingDType):
np_dtype = np.float64
bit_width = 64

@classmethod
def get(cls):
def get_mlir_type(cls):
return ir.F64Type.get()


class Float32(DType):
class Float32(FloatingDType):
np_dtype = np.float32
bit_width = 32

@classmethod
def get(cls):
def get_mlir_type(cls):
return ir.F32Type.get()


class Int64(DType):
np_dtype = np.int64
class Float16(FloatingDType):
np_dtype = np.float16
bit_width = 16

@classmethod
def get(cls):
return ir.IntegerType.get_signed(64)
def get_mlir_type(cls):
return ir.F16Type.get()


class UInt64(DType):
np_dtype = np.uint64
class IntegerDType(DType): ...

@classmethod
def get(cls):
return ir.IntegerType.get_unsigned(64)

class UnsignedIntegerDType(IntegerDType): ...

class Int32(DType):
np_dtype = np.int32

@classmethod
def get(cls):
return ir.IntegerType.get_signed(32)

class SignedIntegerDType(IntegerDType): ...

class UInt32(DType):
np_dtype = np.uint32

@classmethod
def get(cls):
return ir.IntegerType.get_unsigned(32)
_make_int_classes(locals(), [8, 16, 32, 64])


class Index(DType):
np_dtype = np.intp

@classmethod
def get(cls):
def get_mlir_type(cls):
return ir.IndexType.get()


class SignlessInt64(DType):
np_dtype = np.int64
IntP: type[SignedIntegerDType] = locals()[f"Int{_PTR_WIDTH}"]
UIntP: type[UnsignedIntegerDType] = locals()[f"UInt{_PTR_WIDTH}"]

@classmethod
def get(cls):
return ir.IntegerType.get_signless(64)

def isdtype(dt, /) -> bool:
return isinstance(dt, type) and issubclass(dt, DType) and not inspect.isabstract(dt)


NUMPY_DTYPE_MAP = {np.dtype(dt.np_dtype): dt for dt in locals().values() if isdtype(dt)}


def asdtype(dt, /) -> type[DType]:
if isdtype(dt):
return dt

return NUMPY_DTYPE_MAP[np.dtype(dt)]
Loading

0 comments on commit 1c56a0b

Please sign in to comment.