diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 083eb36354..7f09afd1c0 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -189,19 +189,27 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): return processed + self.callback(backlog, prefix) - def _break_for_parallelism(self, scope, candidates, i): + def _break_for_parallelism(self, scope, candidates, timestamp): # `test` will be True if there's at least one data-dependence that would # break parallelism test = False - for d in scope.d_from_access_gen(scope.a_query(i)): - if d.is_local or d.is_storage_related(candidates): + for dep in scope.d_all_gen(): + if dep.timestamp > timestamp: + continue + + if dep.is_local or dep.is_storage_related(candidates): # Would break a dependence on storage return False - if any(d.is_carried(i) for i in candidates): # noqa: SIM102 - if (d.is_flow and d.is_lex_negative) or (d.is_anti and d.is_lex_positive): + + if any(dep.is_carried(i) for i in candidates): + test0 = dep.is_flow and dep.is_lex_negative + test1 = dep.is_anti and dep.is_lex_positive + if test0 or test1: # Would break a data dependence return False - test = test or (bool(d.cause & candidates) and not d.is_lex_equal) + + test = test or (bool(dep.cause & candidates) and not dep.is_lex_equal) + return test diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index abc817ff73..3bad983426 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -558,6 +558,10 @@ def function(self): def findices(self): return self.source.findices + @property + def timestamp(self): + return max(self.source.timestamp, self.sink.timestamp) + @cached_property def distance(self): return self.source.distance(self.sink) diff --git a/tests/test_operator.py b/tests/test_operator.py index 3e417d74b8..49e5614968 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1485,7 +1485,7 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(ti0[x,y,z], ti0[x,y,z] + ti1[x,y,z])', 'Eq(ti1[x,y,z], ti3[x,y,z])', 'Eq(ti3[x,y,z], ti1[x,y,z+1] + 1.)'), - '+++++', ['xyz', 'xyz', 'xyz'], 'xyzzz'), + '++++', ['xyz', 'xyz'], 'xyzz'), # 1) WAR 1->2, 2->3 (('Eq(ti0[x,y,z], ti0[x,y,z] + ti1[x,y,z])', 'Eq(ti1[x,y,z], ti0[x,y,z+1])', @@ -1533,7 +1533,7 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z-2])', 'Eq(tw[t,x,y,z], tv[t,x,y+1,z] + 1.)'), - '++++++++', ['txyz', 'txyz', 'txyz'], 'txyzyzyz'), + '+++++++', ['txyz', 'txyz', 'txyz'], 'txyzzyz'), # 10) WAR 1->2; WAW 1->3 (('Eq(tu[t-1,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])', @@ -1593,6 +1593,15 @@ def test_no_fission_as_illegal(self, exprs): 'Eq(tu[t+1,xi,yi,zi], tv[t+1,xi,yi,zi] + tv[t+1,xi+1,yi,zi])', 'Eq(tw[t+1,x,y,z], tv[t+1,x,y,z] + tv[t+1,x+1,y,z])'), '++++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzxyz'), + # 20) RAW 1->3, WAR 2->3; expected=2 + # It's important the split occurs after the second equation, since the + # first two can safely be fused together (previously, instead, + # due to an issue in `break_for_parallelism`, the eqns were split over + # three loop nests) + (('Eq(tu[t+1,x,y,z], tu[t,x,y,z] + tu[t,x+1,y,z])', + 'Eq(tv[t+1,x,y,z], tv[t,x,y,z] + 1)', + 'Eq(tw[t+1,x,y,z], tu[t+1,x+1,y,z] + tw[t,x+1,y,z] + tv[t+1,x+1,y,z])'), + '+++++++', ['txyz', 'txyz'], 'txyzxyz'), ]) def test_consistency_anti_dependences(self, exprs, directions, expected, visit): """