@@ -2621,22 +2621,7 @@ def test_grad_until_and_truncate(self):
26212621 utt .assert_allclose (pytensor_gradient , self .numpy_gradient )
26222622
26232623 def test_grad_until_and_truncate_sequence_taps (self ):
2624- n = 3
2625- r = scan (
2626- lambda x , y , u : (x * y , until (y > u )),
2627- sequences = dict (input = self .x , taps = [- 2 , 0 ]),
2628- non_sequences = [self .threshold ],
2629- truncate_gradient = n ,
2630- return_updates = False ,
2631- )
2632- g = grad (r .sum (), self .x )
2633- f = function ([self .x , self .threshold ], [r , g ])
2634- _pytensor_output , pytensor_gradient = f (self .seq , 6 )
2635-
2636- # Gradient computed by hand:
2637- numpy_grad = np .array ([0 , 0 , 0 , 5 , 6 , 10 , 4 , 5 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
2638- numpy_grad = numpy_grad .astype (config .floatX )
2639- utt .assert_allclose (pytensor_gradient , numpy_grad )
2624+ ScanCompatibilityTests .check_grad_until_and_truncate_sequence_taps (mode = None )
26402625
26412626
26422627def test_mintap_onestep ():
@@ -4431,3 +4416,26 @@ def check_higher_order_derivative(mode):
44314416 np .testing .assert_allclose (gg_res , (16 * 15 ) * x_test ** 14 )
44324417 # FIXME: All implementations of Scan seem to get this one wrong!
44334418 # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
4419+
4420+
4421+ @staticmethod
4422+ def check_grad_until_and_truncate_sequence_taps (mode ):
4423+ """Test case where we need special behavior of zeroing out sequences in Scan"""
4424+ x = pt .vector ("x" )
4425+ threshold = pt .scalar (name = "threshold" , dtype = "int64" )
4426+
4427+ r = scan (
4428+ lambda x , y , u : (x * y , until (y > u )),
4429+ sequences = dict (input = x , taps = [- 2 , 0 ]),
4430+ non_sequences = [threshold ],
4431+ truncate_gradient = 3 ,
4432+ return_updates = False ,
4433+ )
4434+ g = grad (r .sum (), x )
4435+ f = function ([x , threshold ], [r , g ], mode = mode )
4436+ _ , grad_res = f (np .arange (15 , dtype = x .dtype ), 6 )
4437+
4438+ # Gradient computed by hand:
4439+ grad_expected = np .array ([0 , 0 , 0 , 5 , 6 , 10 , 4 , 5 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
4440+ grad_expected = grad_expected .astype (config .floatX )
4441+ np .testing .assert_allclose (grad_res , grad_expected )
0 commit comments