diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 7f09afd1c0..4223e789e7 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -121,6 +121,13 @@ class Schedule(Queue): Dimension in both Clusters. """ + FISSION_THRESHOLD = 2 + """ + The maximum number of iteration Dimensions such that we consider fissioning + a sequence of Clusters to increase parallelism. IOW, if there are more than + this number of iteration Dimensions, we do not even try to fission. + """ + @timed_pass(name='schedule') def process(self, clusters): return self._process_fatd(clusters, 1) @@ -134,7 +141,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # Take the innermost Dimension -- no other Clusters other than those in # `clusters` are supposed to share it - candidates = prefix[-1].dim._defines + dim = prefix[-1].dim + candidates = dim._defines scope = Scope(flatten(c.exprs for c in clusters)) @@ -157,7 +165,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # Schedule Clusters over different IterationSpaces if this increases # parallelism for i in range(1, len(clusters)): - if self._break_for_parallelism(scope, candidates, i): + if self._break_for_parallelism(scope, dim, i): return self.callback(clusters[:i], prefix, clusters[i:] + backlog, candidates | known_break) @@ -189,7 +197,19 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): return processed + self.callback(backlog, prefix) - def _break_for_parallelism(self, scope, candidates, timestamp): + def _break_for_parallelism(self, scope, dim, timestamp): + candidates = dim._defines + + # Do not fission for data locality reasons if there's enough potential + # parallelism in the inner Dimensions + try: + ispace, = {e.ispace for e in scope.exprs[:timestamp]} + _, ispace1 = ispace.split(dim) + if len(ispace1.itdims) > self.FISSION_THRESHOLD: + return False + except ValueError: + pass + # `test` will be True if there's at least one data-dependence that would # break parallelism test = False