Skip to content

Commit 2449f53

Browse files
clean up types a bit
1 parent 7b0d862 commit 2449f53

File tree

4 files changed

+14
-11
lines changed

4 files changed

+14
-11
lines changed

loopy/kernel/function_interface.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,7 @@ def __init__(self,
348348
try:
349349
hash(arg_id_to_dtype)
350350
except TypeError:
351-
if arg_id_to_dtype is None:
352-
arg_id_to_dtype = {}
351+
assert arg_id_to_dtype is not None
353352
arg_id_to_dtype = constantdict(arg_id_to_dtype)
354353
warn("arg_id_to_dtype passed to InKernelCallable was not hashable. "
355354
"This usage is deprecated and will stop working in 2026.",
@@ -358,8 +357,7 @@ def __init__(self,
358357
try:
359358
hash(arg_id_to_descr)
360359
except TypeError:
361-
if arg_id_to_descr is None:
362-
arg_id_to_descr = {}
360+
assert arg_id_to_descr is not None
363361
arg_id_to_descr = constantdict(arg_id_to_descr)
364362
warn("arg_id_to_descr passed to InKernelCallable was not hashable. "
365363
"This usage is deprecated and will stop working in 2026.",

loopy/target/c/c_execution.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence
3232

3333
import numpy as np
34+
35+
from constantdict import constantdict
36+
3437
from codepy.jit import compile_from_string
3538
from codepy.toolchain import GCCToolchain, ToolchainGuessError, guess_toolchain
3639

@@ -500,7 +503,7 @@ def get_wrapper_generator(self):
500503

501504
@memoize_method
502505
def translation_unit_info(self,
503-
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> _KernelInfo:
506+
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
504507
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
505508

506509
from loopy.codegen import generate_code_v2

loopy/target/execution.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def check_for_required_array_arguments(self, input_args):
817817
"your argument.")
818818

819819
def get_typed_and_scheduled_translation_unit_uncached(
820-
self, arg_to_dtype: Mapping[str, LoopyType] | None
820+
self, arg_to_dtype: constantdict[str, LoopyType] | None
821821
) -> TranslationUnit:
822822
t_unit = self.t_unit
823823

@@ -827,15 +827,15 @@ def get_typed_and_scheduled_translation_unit_uncached(
827827
# FIXME: This is not so nice. This transfers types from the
828828
# subarrays of sep-tagged arrays to the 'main' array, because
829829
# type inference fails otherwise.
830-
mm = dict(arg_to_dtype)
830+
mm = arg_to_dtype.mutate()
831831
for name, sep_info in self.sep_info.items():
832832
if entry_knl.arg_dict[name].dtype is None:
833833
for sep_name in sep_info.subarray_names.values():
834834
if sep_name in arg_to_dtype:
835835
mm[name] = arg_to_dtype[sep_name]
836836
del mm[sep_name]
837837

838-
arg_to_dtype = constantdict(mm)
838+
arg_to_dtype = mm.finish()
839839

840840
from loopy.kernel.tools import add_dtypes
841841
t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype))
@@ -854,7 +854,7 @@ def get_typed_and_scheduled_translation_unit_uncached(
854854
return t_unit
855855

856856
def get_typed_and_scheduled_translation_unit(
857-
self, arg_to_dtype: Mapping[str, LoopyType] | None
857+
self, arg_to_dtype: constantdict[str, LoopyType] | None
858858
) -> TranslationUnit:
859859
from loopy import CACHING_ENABLED
860860

@@ -904,7 +904,7 @@ def get_highlighted_code(self, entrypoint, arg_to_dtype=None, code=None):
904904

905905
def get_code(
906906
self, entrypoint: str,
907-
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> str:
907+
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> str:
908908
kernel = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
909909

910910
from loopy.codegen import generate_code_v2

loopy/target/pyopencl_execution.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import numpy as np
3232

33+
from constantdict import constantdict
34+
3335
from pytools import memoize_method
3436
from pytools.codegen import CodeGenerator, Indentation
3537

@@ -311,7 +313,7 @@ def get_wrapper_generator(self):
311313
@memoize_method
312314
def translation_unit_info(
313315
self,
314-
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> _KernelInfo:
316+
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
315317
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
316318

317319
# FIXME: now just need to add the types to the arguments

0 commit comments

Comments
 (0)