Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ class Schedule(Queue):
Dimension in both Clusters.
"""

FISSION_THRESHOLD = 2
"""
The maximum number of iteration Dimensions such that we consider fissioning
a sequence of Clusters to increase parallelism. In other words, if there are
more than this number of iteration Dimensions, we do not even try to fission.
"""

@timed_pass(name='schedule')
def process(self, clusters):
    """
    Run this pass over ``clusters`` by delegating to ``_process_fatd``
    and return whatever that call produces.
    """
    # NOTE(review): the meaning of the literal `1` is defined by
    # `_process_fatd`, which is not visible here -- confirm before changing
    scheduled = self._process_fatd(clusters, 1)
    return scheduled
Expand All @@ -134,7 +141,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):

# Take the innermost Dimension -- no Clusters other than those in
# `clusters` are supposed to share it
candidates = prefix[-1].dim._defines
dim = prefix[-1].dim
candidates = dim._defines

scope = Scope(flatten(c.exprs for c in clusters))

Expand All @@ -157,7 +165,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
# Schedule Clusters over different IterationSpaces if this increases
# parallelism
for i in range(1, len(clusters)):
if self._break_for_parallelism(scope, candidates, i):
if self._break_for_parallelism(scope, dim, i):
return self.callback(clusters[:i], prefix, clusters[i:] + backlog,
candidates | known_break)

Expand Down Expand Up @@ -189,7 +197,19 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):

return processed + self.callback(backlog, prefix)

def _break_for_parallelism(self, scope, candidates, timestamp):
def _break_for_parallelism(self, scope, dim, timestamp):
candidates = dim._defines

# Do not fission for data locality reasons if there's enough potential
# parallelism in the inner Dimensions
try:
ispace, = {e.ispace for e in scope.exprs[:timestamp]}
_, ispace1 = ispace.split(dim)
if len(ispace1.itdims) > self.FISSION_THRESHOLD:
return False
except ValueError:
pass

# `test` will be True if there's at least one data-dependence that would
# break parallelism
test = False
Expand Down
Loading