Skip to content

Commit 16cf29d

Browse files
committed
compiler: Relax intervals with upper from not mapped dimensions
1 parent dd31ca2 commit 16cf29d

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

devito/ir/clusters/cluster.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,22 @@ def dspace(self):
358358
# Dimension-centric view of the data space
359359
intervals = IntervalGroup.generate('union', *parts.values())
360360

361+
# 'union' may consume intervals (values) from keys that have dimensions
362+
# not mapped to intervals e.g. issue #2235, resulting in reduced
363+
# iteration size. Here, we relax this mapped upper interval, by
364+
# intersecting intervals with matching only dimensions
365+
for f, v in parts.items():
366+
for i in v:
367+
# oobs check is not required but helps reduce
368+
# interval reconstruction
369+
if i.dim in oobs and i.dim in f.dimensions:
370+
ii = intervals[i.dim].intersection(v[i.dim])
371+
intervals = intervals.set_upper(i.dim, ii.upper)
372+
361373
# E.g., `db0 -> time`, but `xi NOT-> x`
362374
intervals = intervals.promote(lambda d: not d.is_Sub)
363375
intervals = intervals.zero(set(intervals.dimensions) - oobs)
364376

365-
# Upper bound of intervals including dimensions classified for
366-
# shifting should retain the "oobs" upper bound
367-
for f, v in parts.items():
368-
for i in v:
369-
if i.dim in oobs:
370-
try:
371-
if intervals[i.dim].upper > v[i.dim].upper and \
372-
bool(i.dim in f.dimensions):
373-
intervals = intervals.ceil(v[i.dim])
374-
except AttributeError:
375-
pass
376-
377377
return DataSpace(intervals, parts)
378378

379379
@cached_property

devito/ir/support/space.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,8 @@ def negate(self):
259259
def zero(self):
260260
return Interval(self.dim, 0, 0, self.stamp)
261261

262-
def ceil(self, o):
263-
if not self.is_compatible(o):
264-
return self._rebuild()
265-
return Interval(self.dim, self.lower, o.upper, self.stamp)
262+
def set_upper(self, v=0):
263+
return Interval(self.dim, self.lower, v, self.stamp)
266264

267265
def flip(self):
268266
return Interval(self.dim, self.upper, self.lower, self.stamp)
@@ -497,9 +495,9 @@ def zero(self, d=None):
497495

498496
return IntervalGroup(intervals, relations=self.relations, mode=self.mode)
499497

500-
def ceil(self, o=None):
501-
d = self.dimensions if o is None else as_tuple(o.dim)
502-
return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self],
498+
def set_upper(self, d, v=0):
499+
dims = as_tuple(d)
500+
return IntervalGroup([i.set_upper(v) if i.dim in dims else i for i in self],
503501
relations=self.relations, mode=self.mode)
504502

505503
def lift(self, d=None, v=None):

tests/test_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1993,7 +1993,7 @@ class TestInternals(object):
19931993

19941994
@pytest.mark.parametrize('nt, offset, epass',
19951995
([1, 1, True], [1, 2, False],
1996-
[5, 1, True], [3, 5, False],
1996+
[5, 3, True], [3, 5, False],
19971997
[4, 1, True], [5, 10, False]))
19981998
def test_indirection(self, nt, offset, epass):
19991999
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)