Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Introduce complex numbers support (np.complex64/128) #2375

Merged
merged 74 commits into from
Mar 19, 2025
Merged
Changes from 1 commit
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
9b0420e
api: add support for complex dtype
mloubout Aug 2, 2023
5398c68
api: fix printer for complex dtype
mloubout May 22, 2024
2b4f7b0
compiler: fix alias dtype with complex numbers
mloubout May 22, 2024
0cf3f78
api: move complex ctype to dtype lowering
mloubout May 22, 2024
386d5a2
compiler: generate std:complex for cpp compilers
mloubout May 28, 2024
b5bf79c
compiler: add std::complex arithmetic defs for unsupported types
mloubout May 30, 2024
ac28372
compiler: fix alias dtype with complex numbers
mloubout May 30, 2024
89b6fea
compiler: fix internal language specific types and cast
mloubout May 31, 2024
d5a2542
compiler: rework dtype lowering
mloubout Jun 20, 2024
64801cb
compiler: switch to c++14 for complex_literals
mloubout Jun 27, 2024
2f0cacf
compiler: subdtype numpy for dtype lowering
mloubout Jul 8, 2024
7ad523b
compiler: use structs to pass complex arguments
enwask Jul 9, 2024
e8e51f1
compiler: add Dereference scalar case
enwask Jul 11, 2024
d0c55f9
compiler: implement float16 support
enwask Jul 11, 2024
9081b2e
symbolics: fix printer for half precision
enwask Jul 11, 2024
89c8bf5
misc: fix formatting
enwask Jul 11, 2024
43b82db
compiler: refactor float16 and lower_dtypes
enwask Jul 11, 2024
91ab1e9
compiler: add dtype_alloc_ctype helper for allocation size
enwask Jul 11, 2024
d907e8d
misc: more float16 refactoring/formatting fixes
enwask Jul 15, 2024
425d568
Remove dtypes lowering from IET layer
enwask Jul 16, 2024
8db3f96
compiler: reimplement float16/complex lowering
enwask Jul 26, 2024
53462a3
misc: cleanup, docs and typing for half support
enwask Jul 29, 2024
8a8f24f
compiler: FindSymbols 'scalars' -> 'abstractsymbols'
enwask Jul 29, 2024
b0c9ee9
test: include scalar parameters in complex tests
enwask Jul 30, 2024
344c435
test: add test_dtypes with initial tests for float16 + complex
enwask Jul 30, 2024
4f5ee36
misc: more lower_dtypes cleanup + type hints
enwask Jul 30, 2024
49a9bec
api: use grid dtype for extent and origin, add test_grid
enwask Jul 31, 2024
e8c7fc0
test: clean up and add more half/complex tests
enwask Jul 31, 2024
12b936a
test: fix test_grid_objs, add test_grid_dtypes
enwask Jul 31, 2024
1981465
api: allow side for cross derivatives, fixes #2442
mloubout Aug 13, 2024
da2ea2a
compiler: process dtypes through printer
mloubout Jan 15, 2025
f76dbad
symbolics: specialize sizeof
mloubout Jan 16, 2025
d55931f
compiler: move dtype pass to top level operator iet pass
mloubout Jan 16, 2025
cb6fdc8
symbolics: fix SizeOf rebuild
mloubout Jan 16, 2025
11d6b32
symbolics: use std namespace for c++
mloubout Jan 16, 2025
ed5dbe6
compiler: fix std math func names
mloubout Jan 16, 2025
63e589d
symbolics: move printers rogether through registry
mloubout Jan 17, 2025
f781aa4
symbolics: rework Cast
mloubout Jan 17, 2025
05d4e3f
compiler: fix complex headers
mloubout Jan 17, 2025
6fc54e3
api: remove un-needed dtype reconstruction mode
mloubout Jan 17, 2025
0f17026
compiler: fix dtype for mpi routines
mloubout Jan 17, 2025
2ce9817
compiler: fix missing algorithm include for min/max
mloubout Jan 18, 2025
9307bf0
arch: switch sycl error to warning for no-compile codegen
mloubout Jan 18, 2025
2b2848b
symbolics: rework cast/sizeof for pickling
mloubout Jan 22, 2025
3e9e931
api: fix c_datatype hack
mloubout Jan 22, 2025
58e6310
compiler: make visitor language parametric
mloubout Jan 23, 2025
f1a082d
compiler: make sure complex ctype is handled properly for typedata
mloubout Jan 23, 2025
0c914f2
symbolics: cleaner repr of Cast
mloubout Jan 23, 2025
1c9ab2e
test: improve dtype tests log
mloubout Jan 24, 2025
461fd43
compiler: make sure cpp is used for c++ compilers
mloubout Jan 26, 2025
c515253
compiler: make printer part of the target and differentiate C and CXX
mloubout Jan 27, 2025
4899f11
compiler: add all cxx target to operator registry
mloubout Jan 27, 2025
f968a89
compiler: cleanup operator class names
mloubout Jan 28, 2025
d385957
compiler: switch cxx backend to static_cast
mloubout Jan 28, 2025
b419050
compiler: add switch for static_cast vs reinterpret_cast
mloubout Jan 28, 2025
ad7271f
compiler: handle plain text header
mloubout Jan 30, 2025
91f2018
compiler: convert all in visitors to f-string
mloubout Jan 30, 2025
049f17a
compiler: convert printer to f-string
mloubout Jan 31, 2025
8f4f221
arch: add intel gpu basic gpu_info support
mloubout Feb 13, 2025
d24e6e1
compiler: fix header order
mloubout Feb 20, 2025
8e0a2d3
compiler: add scalar type option
mloubout Feb 25, 2025
65f1b37
compiler: fix real dtype
mloubout Mar 2, 2025
1da795e
compiler: more robust safeinv
mloubout Mar 5, 2025
e9fa6ec
compiler: mssing substraction cxx def
mloubout Mar 6, 2025
83737b0
compiler: make dtype lowering more flexible
mloubout Mar 6, 2025
6a563a2
examples: add on the fly dft tutorial
mloubout Mar 6, 2025
601d99d
api: fix norm with complex numbers
mloubout Mar 6, 2025
c1a31a9
api: dix sympy assumptions for complex valued objects
mloubout Mar 6, 2025
43da5a2
misc: f-string formatting
mloubout Mar 11, 2025
8185a5e
compiler: cleanup default includes/header/namespaces
mloubout Mar 14, 2025
124b9ba
misc: fix typos and formatting
mloubout Mar 17, 2025
24faf28
compiler: rename lang option to langbb for clarity
mloubout Mar 17, 2025
5f89aed
api: enforce pow_to_mul to be un-evaluable
mloubout Mar 18, 2025
be87c17
compiler: rename lang to langbb throughout for clarity
mloubout Mar 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove dtypes lowering from IET layer
enwask authored and mloubout committed Mar 18, 2025
commit 425d56858a9c53a97e18a258a9d7823a6795e93c
10 changes: 5 additions & 5 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
FindNodes, FindSymbols, MapExprStmts, Transformer,
make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.dtypes import include_complex
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
@@ -466,8 +466,8 @@ def place_casts(self, iet, **kwargs):
return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
def include_complex(self, iet):
iet, metadata = include_complex(iet, self.lang, self.compiler)
return iet, metadata

def process(self, graph):
@@ -476,7 +476,7 @@ def process(self, graph):
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)
self.include_complex(graph)


class DeviceAwareDataManager(DataManager):
@@ -618,7 +618,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)
self.include_complex(graph)


def make_zero_init(obj, rcompile, sregistry):
54 changes: 6 additions & 48 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,21 @@
import numpy as np
import ctypes

from devito.ir import FindSymbols, Uxreplace
from devito.ir.iet.nodes import Dereference
from devito.tools.utils import as_list
from devito.types.basic import Symbol
from devito.ir import FindSymbols

__all__ = ['lower_dtypes']
__all__ = ['include_complex']


def lower_dtypes(iet, lang, compiler, sregistry):
def include_complex(iet, lang, compiler):
"""
Lower language-specific dtypes and add headers for complex arithmetic
"""
# Include complex headers if needed (before we replace complex dtypes)
metadata = _complex_includes(iet, lang, compiler)

body_prefix = [] # Derefs to prepend to the body
body_mapper = {}
params_mapper = {}

# Lower scalar float16s to pointers and dereference them
if lang.get('half_types') is not None:
half, half_p = lang['half_types'] # dtype mappings for half float

for s in FindSymbols('scalars').visit(iet):
if s.dtype != np.float16 or s not in iet.parameters:
continue

ptr = s._rebuild(dtype=half_p, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half,
is_const=s.is_const)

params_mapper[s], body_mapper[s] = ptr, val
body_prefix.append(Dereference(val, ptr)) # val = *ptr

# Lower remaining language-specific dtypes
for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types'] and s not in params_mapper:
body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

# Apply the dtype replacements
body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body))
params = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=body, parameters=params)
return iet, metadata


def _complex_includes(iet, lang, compiler):
"""
Add headers for complex arithmetic
Include complex arithmetic headers for the given language, if needed.
"""
# Check if there is complex numbers that always take dtype precedence
types = {f.dtype for f in FindSymbols().visit(iet)
if not issubclass(f.dtype, ctypes._Pointer)}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for this blank line

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return {}
return iet, {}

metadata = {}
lib = (lang['header-complex'],)
@@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler):

metadata['includes'] = lib

return metadata
return iet, metadata