From 14f0aafc22125f65f538568d945e7986e23ca592 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 28 Jan 2026 12:41:09 +0000 Subject: [PATCH] tests: Add an additional test for loop scheduling --- tests/test_operator.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_operator.py b/tests/test_operator.py index 49e5614968..2e98a7603d 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1444,6 +1444,42 @@ def test_fission_for_parallelism(self, exprs, fissioned, shared): # Fission happened assert i[exp_depth].dim is exp_dim + def test_fission_for_parallelism_b(self): + so = 2 + grid = Grid(shape=(10, 10, 10)) + x, y, z = grid.dimensions + + f0 = TimeFunction(name='f0', grid=grid, space_order=so, staggered=(x,)) + f1 = TimeFunction(name='f1', grid=grid, space_order=so, staggered=(y,)) + + f2 = TimeFunction(name='f2', grid=grid, space_order=so, staggered=(x, z)) + f3 = TimeFunction(name='f3', grid=grid, space_order=so, staggered=(y, z)) + + f4 = TimeFunction(name='f4', grid=grid, space_order=so, staggered=NODE) + + eq0 = Eq(f2, f0.dz) + eq1 = Eq(f3, f1.dz) + eq2 = Eq(f4, f2 + f3) + + op = Operator([eq0, eq1, eq2]) + + trees = retrieve_iteration_tree(op) + + # First two equations should be fused for parallelism, but the third should be + # fissioned + assert len(trees) == 2 + assert len(trees[0][-1].nodes[0].exprs) == 2 + assert len(trees[1][-1].nodes[0].exprs) == 1 + + def check_expr_contents(expr, expected): + assert all(f.base in expr.expr_symbols for f in expected) + + # Check expressions match equations + check_expr_contents(trees[0][-1].nodes[0].exprs[0], (f2, f0)) + check_expr_contents(trees[0][-1].nodes[0].exprs[1], (f3, f1)) + + check_expr_contents(trees[1][-1].nodes[0].exprs[0], (f4, f2, f3)) + @pytest.mark.parametrize('exprs', [ # 0) Storage related dependence ('Eq(u.forward, v)', 'Eq(v, u.dxl)'),