@@ -2447,8 +2447,6 @@ def test_ftemps_option(self):
2447
2447
t = grid .stepping_dim
2448
2448
2449
2449
nthreads = 2
2450
- x0_blk0_size = 8
2451
- y0_blk0_size = 8
2452
2450
2453
2451
u = TimeFunction (name = 'u' , grid = grid , space_order = 3 )
2454
2452
u1 = TimeFunction (name = "u" , grid = grid , space_order = 3 )
@@ -2475,6 +2473,12 @@ def test_ftemps_option(self):
2475
2473
with pytest .raises (InvalidArgument ):
2476
2474
op1 (time_M = 1 , u = u1 )
2477
2475
2476
+ block_dims = [i for i in op1 .dimensions if i .is_Block and i ._depth == 1 ]
2477
+ assert len (block_dims ) == 2
2478
+ mapper = {d .root : d for d in block_dims }
2479
+ x0_blk0_size = mapper [x ]._arg_defaults ()[mapper [x ].step .name ]
2480
+ y0_blk0_size = mapper [y ]._arg_defaults ()[mapper [y ].step .name ]
2481
+
2478
2482
# Prepare to run op1
2479
2483
shape = [nthreads , x0_blk0_size , y0_blk0_size , grid .shape [- 1 ]]
2480
2484
ofuncs = [i .make (shape ) for i in op1 .temporaries ]
@@ -2845,7 +2849,7 @@ def test_fullopt(self):
2845
2849
2846
2850
# Check expected opcount/oi
2847
2851
assert summary [('section1' , None )].ops == 92
2848
- assert np .isclose (summary [('section1' , None )].oi , 2.072 , atol = 0.001 )
2852
+ assert np .isclose (summary [('section1' , None )].oi , 1.99 , atol = 0.001 )
2849
2853
2850
2854
# With optimizations enabled, there should be exactly four BlockDimensions
2851
2855
op = wavesolver .op_fwd ()
0 commit comments