Skip to content

Commit f5db549

Browse files
committed
compiler: Relax intervals with upper from not mapped dimensions
1 parent 4a2b155 commit f5db549

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

devito/ir/clusters/cluster.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -379,22 +379,22 @@ def dspace(self):
379379
# Dimension-centric view of the data space
380380
intervals = IntervalGroup.generate('union', *parts.values())
381381

382+
# 'union' may consume intervals (values) from keys that have dimensions
383+
# not mapped to intervals e.g. issue #2235, resulting in reduced
384+
# iteration size. Here, we relax this mapped upper interval, by
385+
# intersecting intervals with matching only dimensions
386+
for f, v in parts.items():
387+
for i in v:
388+
# oobs check is not required but helps reduce
389+
# interval reconstruction
390+
if i.dim in oobs and i.dim in f.dimensions:
391+
ii = intervals[i.dim].intersection(v[i.dim])
392+
intervals = intervals.set_upper(i.dim, ii.upper)
393+
382394
# E.g., `db0 -> time`, but `xi NOT-> x`
383395
intervals = intervals.promote(lambda d: not d.is_Sub)
384396
intervals = intervals.zero(set(intervals.dimensions) - oobs)
385397

386-
# Upper bound of intervals including dimensions classified for
387-
# shifting should retain the "oobs" upper bound
388-
for f, v in parts.items():
389-
for i in v:
390-
if i.dim in oobs:
391-
try:
392-
if intervals[i.dim].upper > v[i.dim].upper and \
393-
bool(i.dim in f.dimensions):
394-
intervals = intervals.ceil(v[i.dim])
395-
except AttributeError:
396-
pass
397-
398398
return DataSpace(intervals, parts)
399399

400400
@cached_property

devito/ir/support/space.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,8 @@ def negate(self):
258258
def zero(self):
259259
return Interval(self.dim, 0, 0, self.stamp)
260260

261-
def ceil(self, o):
262-
if not self.is_compatible(o):
263-
return self._rebuild()
264-
return Interval(self.dim, self.lower, o.upper, self.stamp)
261+
def set_upper(self, v=0):
262+
return Interval(self.dim, self.lower, v, self.stamp)
265263

266264
def flip(self):
267265
return Interval(self.dim, self.upper, self.lower, self.stamp)
@@ -496,9 +494,9 @@ def zero(self, d=None):
496494

497495
return IntervalGroup(intervals, relations=self.relations, mode=self.mode)
498496

499-
def ceil(self, o=None):
500-
d = self.dimensions if o is None else as_tuple(o.dim)
501-
return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self],
497+
def set_upper(self, d, v=0):
498+
dims = as_tuple(d)
499+
return IntervalGroup([i.set_upper(v) if i.dim in dims else i for i in self],
502500
relations=self.relations, mode=self.mode)
503501

504502
def lift(self, d=None, v=None):

tests/test_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,7 @@ class TestInternals:
19921992

19931993
@pytest.mark.parametrize('nt, offset, epass',
19941994
([1, 1, True], [1, 2, False],
1995-
[5, 1, True], [3, 5, False],
1995+
[5, 3, True], [3, 5, False],
19961996
[4, 1, True], [5, 10, False]))
19971997
def test_indirection(self, nt, offset, epass):
19981998
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)