Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Add extraction utility functions #2554

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions devito/symbolics/extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from functools import singledispatch

import sympy

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.derivative import Derivative
from devito.types import Eq, TimeFunction
from devito.operations.solve import eval_time_derivatives
from devito.symbolics import retrieve_functions


def separate_eqn(eqn, target):
"""
Separate the equation into two separate expressions,
where F(target) = b.
"""
zeroed_eqn = Eq(eqn.lhs - eqn.rhs, 0)
zeroed_eqn = eval_time_derivatives(zeroed_eqn.lhs)
target_funcs = set(generate_targets(zeroed_eqn, target))
b, F_target = remove_targets(zeroed_eqn, target_funcs)
return -b, F_target, target_funcs


def generate_targets(eq, target):
"""
Extract all the functions that share the same time index as the target
but may have different spatial indices.
"""
funcs = retrieve_functions(eq)
if isinstance(target, TimeFunction):
time_idx = target.indices[target.time_dim]
targets = [
f for f in funcs if f.function is target.function and time_idx
in f.indices
]
else:
targets = [f for f in funcs if f.function is target.function]
return targets


@singledispatch
def remove_targets(expr, targets):
return (0, expr) if expr in targets else (expr, 0)


@remove_targets.register(sympy.Add)
def _(expr, targets):
if not any(expr.has(t) for t in targets):
return (expr, 0)

args_b, args_F = zip(*(remove_targets(a, targets) for a in expr.args))
return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False))


@remove_targets.register(Mul)
def _(expr, targets):
if not any(expr.has(t) for t in targets):
return (expr, 0)

args_b, args_F = zip(*[remove_targets(a, targets) if any(a.has(t) for t in targets)
else (a, a) for a in expr.args])
return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False))


@remove_targets.register(Derivative)
def _(expr, targets):
return (0, expr) if any(expr.has(t) for t in targets) else (expr, 0)


@singledispatch
def centre_stencil(expr, target):
"""
Extract the centre stencil from an expression. Its coefficient is what
would appear on the diagonal of the matrix system if the matrix were
formed explicitly.
"""
return expr if expr == target else 0


@centre_stencil.register(sympy.Add)
def _(expr, target):
if not expr.has(target):
return 0

args = [centre_stencil(a, target) for a in expr.args]
return expr.func(*args, evaluate=False)


@centre_stencil.register(Mul)
def _(expr, target):
if not expr.has(target):
return 0

args = []
for a in expr.args:
if not a.has(target):
args.append(a)
else:
args.append(centre_stencil(a, target))

return expr.func(*args, evaluate=False)


@centre_stencil.register(Derivative)
def _(expr, target):
if not expr.has(target):
return 0
args = [centre_stencil(a, target) for a in expr.evaluate.args]
return expr.evaluate.func(*args)
239 changes: 239 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
ReservedWord, ListInitializer, ccode, uxreplace,
retrieve_derivatives)
from devito.symbolics.extraction import separate_eqn, centre_stencil
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
ComponentAccess, StencilDimension, Symbol as dSymbol)
Expand Down Expand Up @@ -805,3 +806,241 @@ def test_assumptions(self, op, expr, assumptions, expected):
assumptions = eval(assumptions)
expected = eval(expected)
assert evalrel(op, eqn, assumptions) == expected


@pytest.mark.parametrize('eqn, target, expected', [
('Eq(f1.laplace, g1)',
'f1', ('g1(x, y)', 'Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')),
('Eq(g1, f1.laplace)',
'f1', ('-g1(x, y)', '-Derivative(f1(x, y), (x, 2)) - Derivative(f1(x, y), (y, 2))')),
('Eq(g1, f1.laplace)', 'g1',
('Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))', 'g1(x, y)')),
('Eq(f1 + f1.laplace, g1)', 'f1', ('g1(x, y)',
'f1(x, y) + Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')),
('Eq(g1.dx + f1.dx, g1)', 'f1',
('g1(x, y) - Derivative(g1(x, y), x)', 'Derivative(f1(x, y), x)')),
('Eq(g1.dx + f1.dx, g1)', 'g1',
('-Derivative(f1(x, y), x)', '-g1(x, y) + Derivative(g1(x, y), x)')),
('Eq(f1 * g1.dx, g1)', 'g1', ('0', 'f1(x, y)*Derivative(g1(x, y), x) - g1(x, y)')),
('Eq(f1 * g1.dx, g1)', 'f1', ('g1(x, y)', 'f1(x, y)*Derivative(g1(x, y), x)')),
('Eq((f1 * g1.dx).dy, f1)', 'f1',
('0', '-f1(x, y) + Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')),
('Eq((f1 * g1.dx).dy, f1)', 'g1',
('f1(x, y)', 'Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')),
('Eq(f2.laplace, g2)', 'g2',
('-Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))',
'-g2(t, x, y)')),
('Eq(f2.laplace, g2)', 'f2', ('g2(t, x, y)',
'Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')),
('Eq(f2.laplace, f2)', 'f2', ('0',
'-f2(t, x, y) + Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')),
('Eq(f2*g2, f2)', 'f2', ('0', 'f2(t, x, y)*g2(t, x, y) - f2(t, x, y)')),
('Eq(f2*g2, f2)', 'g2', ('f2(t, x, y)', 'f2(t, x, y)*g2(t, x, y)')),
('Eq(g2*f2.laplace, f2)', 'g2', ('f2(t, x, y)',
'(Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2)))*g2(t, x, y)')),
('Eq(f2.forward, f2)', 'f2.forward', ('f2(t, x, y)', 'f2(t + dt, x, y)')),
('Eq(f2.forward, f2)', 'f2', ('-f2(t + dt, x, y)', '-f2(t, x, y)')),
('Eq(f2.forward.laplace, f2)', 'f2.forward', ('f2(t, x, y)',
'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')),
('Eq(f2.forward.laplace, f2)', 'f2',
('-Derivative(f2(t + dt, x, y), (x, 2)) - Derivative(f2(t + dt, x, y), (y, 2))',
'-f2(t, x, y)')),
('Eq(f2.laplace + f2.forward.laplace, g2)', 'f2.forward',
('g2(t, x, y) - Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))',
'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')),
('Eq(g2.laplace, f2 + g2.forward)', 'g2.forward',
('f2(t, x, y) - Derivative(g2(t, x, y), (x, 2)) - Derivative(g2(t, x, y), (y, 2))',
'-g2(t + dt, x, y)'))
])
def test_separate_eqn(eqn, target, expected):
"""
Test the separate_eqn function.
"""
grid = Grid((2, 2))

so = 2

f1 = Function(name='f1', grid=grid, space_order=so) # noqa
g1 = Function(name='g1', grid=grid, space_order=so) # noqa

f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa
g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa

b, F, _ = separate_eqn(eval(eqn), eval(target))
expected_b, expected_F = expected

assert str(b) == expected_b
assert str(F) == expected_F


@pytest.mark.parametrize('eqn, target, expected', [
('Eq(f1.laplace, g1).evaluate', 'f1',
(
'g1(x, y)',
'-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 '
'- 2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2'
)),
('Eq(g1, f1.laplace).evaluate', 'f1',
(
'-g1(x, y)',
'-(-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2) '
'- (-2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2)'
)),
('Eq(g1, f1.laplace).evaluate', 'g1',
(
'-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 '
'- 2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2',
'g1(x, y)'
)),
('Eq(f1 + f1.laplace, g1).evaluate', 'f1',
(
'g1(x, y)',
'-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 - 2.0'
'*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2 + f1(x, y)'
)),
('Eq(g1.dx + f1.dx, g1).evaluate', 'f1',
(
'-(-g1(x, y)/h_x + g1(x + h_x, y)/h_x) + g1(x, y)',
'-f1(x, y)/h_x + f1(x + h_x, y)/h_x'
)),
('Eq(g1.dx + f1.dx, g1).evaluate', 'g1',
(
'-(-f1(x, y)/h_x + f1(x + h_x, y)/h_x)',
'-g1(x, y)/h_x + g1(x + h_x, y)/h_x - g1(x, y)'
)),
('Eq(f1 * g1.dx, g1).evaluate', 'g1',
(
'0', '(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) - g1(x, y)'
)),
('Eq(f1 * g1.dx, g1).evaluate', 'f1',
(
'g1(x, y)', '(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y)'
)),
('Eq((f1 * g1.dx).dy, f1).evaluate', 'f1',
(
'0', '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) '
'+ (-g1(x, y + h_y)/h_x + g1(x + h_x, y + h_y)/h_x)*f1(x, y + h_y)/h_y '
'- f1(x, y)'
)),
('Eq((f1 * g1.dx).dy, f1).evaluate', 'g1',
(
'f1(x, y)', '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) + '
'(-g1(x, y + h_y)/h_x + g1(x + h_x, y + h_y)/h_x)*f1(x, y + h_y)/h_y'
)),
('Eq(f2.laplace, g2).evaluate', 'g2',
(
'-(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + f2(t, x + h_x, y)'
'/h_x**2) - (-2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2 + '
'f2(t, x, y + h_y)/h_y**2)', '-g2(t, x, y)'
)),
('Eq(f2.laplace, g2).evaluate', 'f2',
(
'g2(t, x, y)', '-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + '
'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)'
'/h_y**2 + f2(t, x, y + h_y)/h_y**2'
)),
('Eq(f2.laplace, f2).evaluate', 'f2',
(
'0', '-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + '
'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2'
' + f2(t, x, y + h_y)/h_y**2 - f2(t, x, y)'
)),
('Eq(g2*f2.laplace, f2).evaluate', 'g2',
(
'f2(t, x, y)', '(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + '
'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2'
' + f2(t, x, y + h_y)/h_y**2)*g2(t, x, y)'
)),
('Eq(f2.forward.laplace, f2).evaluate', 'f2.forward',
(
'f2(t, x, y)', '-2.0*f2(t + dt, x, y)/h_x**2 + f2(t + dt, x - h_x, y)/h_x**2'
' + f2(t + dt, x + h_x, y)/h_x**2 - 2.0*f2(t + dt, x, y)/h_y**2 + '
'f2(t + dt, x, y - h_y)/h_y**2 + f2(t + dt, x, y + h_y)/h_y**2'
)),
('Eq(f2.forward.laplace, f2).evaluate', 'f2',
(
'-(-2.0*f2(t + dt, x, y)/h_x**2 + f2(t + dt, x - h_x, y)/h_x**2 + '
'f2(t + dt, x + h_x, y)/h_x**2) - (-2.0*f2(t + dt, x, y)/h_y**2 + '
'f2(t + dt, x, y - h_y)/h_y**2 + f2(t + dt, x, y + h_y)/h_y**2)',
'-f2(t, x, y)'
)),
('Eq(f2.laplace + f2.forward.laplace, g2).evaluate', 'f2.forward',
(
'-(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + f2(t, x + h_x, y)/'
'h_x**2) - (-2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2 + '
'f2(t, x, y + h_y)/h_y**2) + g2(t, x, y)', '-2.0*f2(t + dt, x, y)/h_x**2 + '
'f2(t + dt, x - h_x, y)/h_x**2 + f2(t + dt, x + h_x, y)/h_x**2 - 2.0*'
'f2(t + dt, x, y)/h_y**2 + f2(t + dt, x, y - h_y)/h_y**2 + '
'f2(t + dt, x, y + h_y)/h_y**2'
)),
('Eq(g2.laplace, f2 + g2.forward).evaluate', 'g2.forward',
(
'-(-2.0*g2(t, x, y)/h_x**2 + g2(t, x - h_x, y)/h_x**2 + '
'g2(t, x + h_x, y)/h_x**2) - (-2.0*g2(t, x, y)/h_y**2 + g2(t, x, y - h_y)'
'/h_y**2 + g2(t, x, y + h_y)/h_y**2) + f2(t, x, y)', '-g2(t + dt, x, y)'
))
])
def test_separate_eval_eqn(eqn, target, expected):
"""
Test the separate_eqn function on pre-evaluated equations.
"""
grid = Grid((2, 2))

so = 2

f1 = Function(name='f1', grid=grid, space_order=so) # noqa
g1 = Function(name='g1', grid=grid, space_order=so) # noqa

f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa
g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa

b, F, _ = separate_eqn(eval(eqn), eval(target))
expected_b, expected_F = expected

assert str(b) == expected_b
assert str(F) == expected_F


@pytest.mark.parametrize('expr, so, target, expected', [
('f1.laplace', 2, 'f1', '-2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'),
('f1 + f1.laplace', 2, 'f1',
'f1(x, y) - 2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'),
('g1.dx + f1.dx', 2, 'f1', '-f1(x, y)/h_x'),
('10 + f1.dx2', 2, 'g1', '0'),
('(f1 * g1.dx).dy', 2, 'f1',
'(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y)'),
('(f1 * g1.dx).dy', 2, 'g1', '-(-1/h_y)*f1(x, y)*g1(x, y)/h_x'),
('f2.laplace', 2, 'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'),
('f2*g2', 2, 'f2', 'f2(t, x, y)*g2(t, x, y)'),
('g2*f2.laplace', 2, 'f2',
'(-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2)*g2(t, x, y)'),
('f2.forward', 2, 'f2.forward', 'f2(t + dt, x, y)'),
('f2.forward.laplace', 2, 'f2.forward',
'-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'),
('f2.laplace + f2.forward.laplace', 2, 'f2.forward',
'-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'),
('f2.laplace + f2.forward.laplace', 2,
'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'),
('f2.laplace', 4, 'f2', '-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'),
('f2.laplace + f2.forward.laplace', 4, 'f2.forward',
'-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2'),
('f2.laplace + f2.forward.laplace', 4, 'f2',
'-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'),
('f2.forward*f2.forward.laplace', 4, 'f2.forward',
'(-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2)*f2(t + dt, x, y)')
])
def test_centre_stencil(expr, so, target, expected):
"""
Test extraction of centre stencil from an equation.
"""
grid = Grid((2, 2))

f1 = Function(name='f1', grid=grid, space_order=so) # noqa
g1 = Function(name='g1', grid=grid, space_order=so) # noqa

f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa
g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa

centre = centre_stencil(eval(expr), eval(target))

assert str(centre) == expected
Loading