api: Introduce complex numbers support (np.complex64/128) #2375

Merged: mloubout merged 74 commits into main from complex on Mar 19, 2025

mloubout
Contributor

Support complex float data type.

@mloubout mloubout added API api (symbolics, types, ...) feature-request labels May 26, 2024

codecov bot commented May 26, 2024

Codecov Report

Attention: Patch coverage is 88.28981% with 139 lines in your changes missing coverage. Please review.

Project coverage is 91.99%. Comparing base (f67f2b0) to head (be87c17).
Report is 75 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| devito/arch/archinfo.py | 0.00% | 34 Missing ⚠️ |
| devito/ir/iet/visitors.py | 79.48% | 22 Missing and 2 partials ⚠️ |
| devito/ir/cgen/printer.py | 89.78% | 11 Missing and 3 partials ⚠️ |
| devito/tools/dtypes_lowering.py | 62.16% | 8 Missing and 6 partials ⚠️ |
| devito/operator/operator.py | 82.75% | 8 Missing and 2 partials ⚠️ |
| devito/symbolics/extended_sympy.py | 88.70% | 6 Missing and 1 partial ⚠️ |
| devito/arch/compiler.py | 83.33% | 6 Missing ⚠️ |
| devito/ir/iet/nodes.py | 60.00% | 4 Missing and 2 partials ⚠️ |
| devito/symbolics/inspection.py | 73.91% | 3 Missing and 3 partials ⚠️ |
| devito/passes/iet/misc.py | 88.23% | 4 Missing ⚠️ |

... and 7 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2375      +/-   ##
==========================================
- Coverage   92.07%   91.99%   -0.08%     
==========================================
  Files         238      244       +6     
  Lines       47038    47690     +652     
  Branches     4143     4196      +53     
==========================================
+ Hits        43308    43874     +566     
- Misses       3078     3144      +66     
- Partials      652      672      +20     
| Flag | Coverage Δ |
|---|---|
| pytest-gpu-aomp-amdgpuX | 72.64% <69.36%> (+0.04%) ⬆️ |
| pytest-gpu-nvc-nvidiaX | 73.71% <71.44%> (+0.07%) ⬆️ |

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 3195b9e to 6ac5b99 Compare May 27, 2024 16:52
"""
Add headers for complex arithmetic
"""
if configuration['language'] == 'cuda':
Contributor

@EdCaunt EdCaunt May 28, 2024

Could shorten this using:

headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
lib = headers.get(configuration['language'], 'complex.h')
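The lookup suggested above can be tried in isolation. This is a minimal sketch, assuming `configuration` behaves like a plain dict (here a stand-in, not Devito's actual configuration object); `complex_header` is a hypothetical helper name used only for illustration.

```python
# Header table taken from the suggestion above.
headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}


def complex_header(language):
    """Return the complex header for `language`, defaulting to C's complex.h."""
    return headers.get(language, 'complex.h')


# A plain dict standing in for Devito's configuration (assumption).
configuration = {'language': 'openmp'}
lib = complex_header(configuration['language'])
```

Any language without an entry in the table falls back to `complex.h`, which is what removes the need for an explicit `if`/`elif` chain.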

dtype = self.dtype
if np.issubdtype(dtype, np.complexfloating):
    func_name = 'c%s' % func_name
    dtype = self.dtype(0).real.dtype
Contributor

@EdCaunt EdCaunt May 28, 2024

Newline between if blocks would improve readability

@@ -640,6 +640,25 @@ def test_tensor(self, func1):
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
assert str(op1.ccode) == str(op2.ccode)

def test_complex(self):
Contributor

Duplicate test?

if exp not in parameters + boilerplate:
    error("Missing parameter: %s" % exp)
assert exp in parameters + boilerplate
for expi in expected:
Contributor

Maybe ex in expected?

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 0268781 to 2c80bf8 Compare May 28, 2024 17:16
@mloubout mloubout force-pushed the complex branch 3 times, most recently from e7a2791 to 05f8528 Compare May 30, 2024 18:19
@mloubout mloubout force-pushed the complex branch 3 times, most recently from a655632 to 7cee7fb Compare May 31, 2024 15:10
@@ -66,6 +67,23 @@ def test_maxpar_option(self):
assert trees[0][0] is trees[1][0]
assert trees[0][1] is not trees[1][1]

@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Contributor

What happens if you try to take derivatives of an expression containing the imaginary unit? Something like (sympy.I*u).dx?

Contributor Author

sympy.I is just a SymPy atom, so it's treated like any other symbol or number, such as S.One or S.NegativeOne.

Contributor

Sure, I'm just inclined to add tests to check that things work the way they 'should' since I've been tripped up in the past

@@ -270,3 +306,39 @@ def _rename_subdims(target, dimensions):
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}


_stdcomplex_defs = """
Contributor

imho this belongs to a complex.h

Contributor

actually, given all the other comments (this is the final one I'm writing), you may as well move the entire complex number lowering machinery to a separate python module such as complex.py within passes/iet/

Contributor Author

> imho this belongs to a complex.h

It's a lot more robust to generate the header in the same directory as the generated code, which avoids having to infer the path from the devito directory.

> complex.py

That's fine

@@ -192,6 +192,42 @@ def minimize_symbols(iet):
return iet, {}


@iet_pass
def complex_include(iet, language, compiler):
Contributor

include_complex

@iet_pass
def complex_include(iet, language, compiler):
    """
    Add headers for complex arithmetic
Contributor

full stop

@@ -243,6 +245,20 @@ def version(self):

return version

@property
def _complex_ctype(self):
Contributor

no, we definitely don't want code-generation-related machinery in our Compiler classes.

The right thing to do is, instead, single-dispatching the Compiler class within our own compilation pass, which is responsible for the lowering of complex

Contributor

as such, I don't think we need to add a custom name to Compiler?

Contributor Author

I don't quite agree here. The complex types are defined by the actual compiler and its standard, i.e. GNU has _Complex and C++ has std::complex. Adding complicated dispatch is overkill for something that is standardized at the language level.

Contributor

I disagree, singledispatch would achieve exactly the same objective by dispatching based on the type.

Instead, what you've done here violates a crucial principle of the OO paradigm, namely that classes should have a well-defined purpose. These classes are for jit-compiling a given string. They're not supposed to provide compiler-specific code generation (C- or C++-specific) information.

> adding complicated dispatch

I don't think it's complicated at all. An iet_pass receives the compiler, and all you have to do is create a series of single-dispatch functions doing exactly what is being done here.

Contributor Author

> They're not supposed to provide compiler-specific code generation (C- or C++ specific) information

But that's not what it is. This defines the standard associated with the compiler, which is c99 -> _Complex and c++11 -> std::complex.

Adding a pass that moves the standard out of the compiler doesn't really make sense.

Contributor

I didn't say (I hope) to add a pass. A single-dispatch function to retrieve some sort of type-specific information doesn't have to be a compiler pass.

But obviously our compiler pass would use it to get the code it needs
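For readers following the thread, the single-dispatch idea can be sketched with `functools.singledispatch` from the standard library. The classes below are hypothetical stand-ins rather than Devito's real compiler classes, and `complex_ctype` is an illustrative name; the point is only that the type-specific information lives in free functions instead of on the compiler classes.

```python
from functools import singledispatch


# Hypothetical stand-ins for the JIT compiler classes (not Devito's own).
class Compiler:
    pass


class GNUCompiler(Compiler):
    pass


class CXXCompiler(Compiler):
    pass


@singledispatch
def complex_ctype(compiler, base):
    """Fallback for compilers with no registered complex type."""
    raise NotImplementedError(f"No complex type for {type(compiler).__name__}")


@complex_ctype.register(GNUCompiler)
def _(compiler, base):
    # C99 spelling, e.g. "float _Complex"
    return f"{base} _Complex"


@complex_ctype.register(CXXCompiler)
def _(compiler, base):
    # C++11 spelling, e.g. "std::complex<float>"
    return f"std::complex<{base}>"
```

A compilation pass that receives the compiler would then call `complex_ctype(compiler, 'float')` and never touch global configuration.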

@@ -92,8 +92,12 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
    datasize = int(reduce(mul, shape))
    ctype = dtype_to_ctype(dtype)
    # For complex number, allocate double the size of its real/imaginary part
Contributor

potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering maybe?
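A rough sketch of what such a helper could look like, using only `ctypes`. The name `alloc_layout` and the string-keyed tables are assumptions for illustration; Devito's `dtype_to_ctype` operates on NumPy dtypes, which this sketch deliberately avoids.

```python
import ctypes

# Hypothetical name-to-ctype tables, standing in for NumPy dtypes.
_real_ctypes = {'float32': ctypes.c_float, 'float64': ctypes.c_double}
_complex_parts = {'complex64': ctypes.c_float, 'complex128': ctypes.c_double}


def alloc_layout(dtype_name, datasize):
    """Return (ctype, element count) for a buffer holding `datasize` values."""
    if dtype_name in _complex_parts:
        # One complex value is stored as two reals (real part, imaginary part)
        return _complex_parts[dtype_name], 2 * datasize
    return _real_ctypes[dtype_name], datasize


ctype, count = alloc_layout('complex64', 10)
buffer = (ctype * count)()  # array of 20 c_float elements
```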

@@ -460,6 +460,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
Contributor

for instance?

Contributor Author

FFTW requires complex.h to be loaded first so that it's the type used rather than fftw_complex

Contributor

by "loaded first" you mean that the header file should stay at the very top of the includes list?

Contributor Author

yes

Contributor Author

It doesn't really matter right now but might later
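To illustrate the ordering constraint discussed here: FFTW defines `fftw_complex` as the native C complex type only when `complex.h` has been included before `fftw3.h`. The sketch below shows one way to keep a given header at the top of an include list; `order_includes` is a hypothetical helper and not how Devito arranges its includes.

```python
def order_includes(includes, first='complex.h'):
    """Return `includes` with `first` moved to the front, if it is present."""
    rest = [inc for inc in includes if inc != first]
    if first in includes:
        return [first] + rest
    return rest


includes = ['stdlib.h', 'fftw3.h', 'complex.h', 'math.h']
ordered = order_includes(includes)
```

The relative order of the remaining headers is preserved, so only the complex header changes position.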


# For (cpp), need to define constant _Complex_I and missing mix-type
# std::complex arithmetic
if compiler._cpp:
Contributor

if not ... : return

if np.issubdtype(dtype, np.complexfloating):
    rtype = dtype(0).real.__class__
    from devito import configuration
    make = configuration['compiler']._complex_ctype
Contributor

you really can't use global information here

Contributor Author

It's not global because this is called within the switchconfig

Contributor

from that code path above yes, but from another one?

we shouldn't access configuration in these remote places


@property
def _base_typ(self):
    return configuration['compiler']._complex_ctype('float')
Contributor

we can't use global objects here

What we should do instead: leave CFLOAT generic

Extend the existing compiler pass to lower CFLOAT into something more specific such as CFLOAT_GCC

Contributor Author

Again, with the cgen visitor now always using the local config from the Operator, this is never called with a global config.

Contributor

it doesn't matter, it's conceptually wrong

you're assuming you only go through _base_typ via that visitor, but who/what imposes that?

this is basically just a workaround to avoid a more graceful lowering, which you can do as explained in my first message

@mloubout mloubout force-pushed the complex branch 2 times, most recently from 3f0b8e2 to fdf1b36 Compare June 21, 2024 18:21
mloubout added 13 commits March 18, 2025 08:50

Verified

This commit was signed with the committer’s verified signature.
mloubout Mathias Louboutin

@@ -67,6 +69,11 @@ class BasicOperator(Operator):
intensity of the generated kernel.
"""

SCALAR_MIN_TYPE = np.float16
"""
Minimum datatype for a scalar alias for a common sub-expression or CIRE temp.
Contributor

"alias" applies just to CIRE. So more correct would be "... for a scalar arising from ..."

try:
    eps = np.finfo(expr.base.dtype).resolution**2
except ValueError:
    warning(f"Warning: dtype not recognized in SafeInv for {expr.base}")
Contributor

you may drop "Warning:" inside

else:
return ()
return (), as_tuple(langbb.get('header-algorithm'))
Contributor

You can see here, for example, how this should really be called header-math, as "algorithm" is kind of misleading since it's all but obvious to think immediately about CXX's algorithm lib

Contributor

@FabioLuporini FabioLuporini left a comment

I just left a few more final comments, but it's mostly nitpicks, so this now looks GTG to me.

Amazing revamp!

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 7279c43 to 5052b62 Compare March 18, 2025 15:43
Contributor

@georgebisbas georgebisbas left a comment

Great! GTG

Comment on lines +464 to +465
header = c.Comment('Begin of %s+MPI setup' % self.langbb['name'])
footer = c.Comment('End of %s+MPI setup' % self.langbb['name'])
Contributor

f-string

Comment on lines +472 to +473
header = c.Comment('Begin of %s setup' % self.langbb['name'])
footer = c.Comment('End of %s setup' % self.langbb['name'])
Contributor

f-string
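For reference, the f-string form of the two commented snippets would look roughly like this. The `langbb` dict is a stand-in for `self.langbb`, and the `c.Comment(...)` wrapper is left out so the example runs on its own.

```python
# Stand-in for self.langbb in the snippets above (assumption).
langbb = {'name': 'OpenMP'}

# %-formatting, as in the diff
old_header = 'Begin of %s+MPI setup' % langbb['name']

# Equivalent f-strings
new_header = f"Begin of {langbb['name']}+MPI setup"
new_footer = f"End of {langbb['name']}+MPI setup"
```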

Contributor

@EdCaunt EdCaunt left a comment

Looks like my comments have all been addressed. G2g for me

@mloubout mloubout merged commit 2247707 into main Mar 19, 2025
31 checks passed
@mloubout mloubout deleted the complex branch March 19, 2025 12:21
Labels
API api (symbolics, types, ...) compiler feature-request

6 participants