@@ -3181,40 +3181,9 @@ def onestep(x, x_tm4):
         f = function([seq], results[1])
         assert np.all(exp_out == f(inp))

-    def test_shared_borrow(self):
-        """
-        This tests two things. The first is a bug occurring when scan wrongly
-        used the borrow flag. The second thing it that Scan's infer_shape()
-        method will be able to remove the Scan node from the graph in this
-        case.
-        """
-
-        inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
-        exp_out = np.zeros((10, 1)).astype(config.floatX)
-        exp_out[4:] = inp[:-4]
-
-        def onestep(x, x_tm4):
-            return x, x_tm4
-
-        seq = matrix()
-        initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
-        outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
-        results = scan(
-            fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
-        )
-        sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
-        updates = {sharedvar: results[0][-1:]}
-
-        f = function([seq], results[1], updates=updates)
-
-        # This fails if scan uses wrongly the borrow flag
-        assert np.all(exp_out == f(inp))
-
-        # This fails if Scan's infer_shape() is unable to remove the Scan
-        # node from the graph.
-        f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
-        scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
-        assert len(scan_nodes_infershape) == 0
+    @pytest.mark.parametrize("static_shape", (True, False))
+    def test_aliased_inner_outputs(self, static_shape):
+        ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode=None)

     def test_memory_reuse_with_outputs_as_inputs(self):
         """
@@ -4417,7 +4386,6 @@ def check_higher_order_derivative(mode):
         # FIXME: All implementations of Scan seem to get this one wrong!
         # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)

-
     @staticmethod
     def check_grad_until_and_truncate_sequence_taps(mode):
         """Test case where we need special behavior of zeroing out sequences in Scan"""
@@ -4439,3 +4407,66 @@ def check_grad_until_and_truncate_sequence_taps(mode):
         grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
         grad_expected = grad_expected.astype(config.floatX)
         np.testing.assert_allclose(grad_res, grad_expected)
+
+    @staticmethod
+    def check_aliased_inner_outputs(static_shape, mode):
+        """
+        This tests two things. The first is a bug that occurred when Scan
+        wrongly used the borrow flag. The second is that Scan's infer_shape()
+        method is able to remove the Scan node from the graph in this case.
+
+        Here is a pure Python equivalent of the problem we want to avoid:
+        ```python
+        def scan(seq, initval):
+            # Due to memory optimization we overwrite values of mitsot as we iterate.
+            # That's why mitsot has shape (4, 1) and not (14, 1).
+            mitsot = np.zeros((4, 1))
+            mitsot[:4] = initval
+            nitsot = np.zeros((10, 1))
+            for i, s in enumerate(seq):
+                # Incorrect results
+                mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4]
+                # Correct results
+                # mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
+
+            return mitsot[(i + 4) % 4 : (i + 4 + 1) % 4], nitsot
+
+        scan(np.arange(10), np.zeros((4, 1)))
+        ```
+        """
+
+        def onestep(seq, seq_tm4):
+            # The recurring output is just each value of seq,
+            # and we further map the -4 tap as a new output
+            return seq, seq_tm4
+
+        # Outer tensors must be at least matrices, so that we have vectors in the inner loop.
+        # Otherwise we would be working with scalars and memory aliasing wouldn't be a concern.
+        seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
+        init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
+        outputs_info = [{"initial": init, "taps": [-4]}, None]
+        [out_seq, out_seq_tm4] = scan(
+            fn=onestep,
+            sequences=seq,
+            outputs_info=outputs_info,
+            return_updates=False,
+        )
+
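+        # Ravel both outputs to 1d so the expected values below can be written as flat arrays:
+        # out_seq[-1] is the last step of the recurring output, out_seq_tm4 the full -4 tap output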
+        f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()], mode=mode)
+
+        seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
+        init_test_val = np.zeros((4, 1), dtype=config.floatX)
+
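+        # out_seq mirrors seq, so its last entry is 9; out_seq_tm4 lags by the -4 tap,
+        # giving the 4 zeros from init followed by seq[:6]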
+        res0, res1 = f(seq_test_val, init_test_val)
+        expected_res0 = np.array([9], dtype=config.floatX)
+        expected_res1 = np.zeros(10, dtype=config.floatX)
+        expected_res1[4:] = np.arange(6)
+        np.testing.assert_array_equal(res0, expected_res0)
+        np.testing.assert_array_equal(res1, expected_res1)
+
+        # This fails if Scan's infer_shape() is unable to remove the Scan
+        # node from the graph.
+        f_infershape = function([seq, init], out_seq_tm4[1].shape)
+        scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
+        assert len(scan_nodes_infershape) == 0