Skip to content

Commit dca7f5d

Browse files
committed
Numba Scan: prevent alias of outputs
Also simplified test. Shared variables aren't needed for the test and clobber it
1 parent 06244f3 commit dca7f5d

File tree

3 files changed

+76
-41
lines changed

3 files changed

+76
-41
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
from numba import types
66
from numba.extending import overload
77

8-
from pytensor import In
9-
from pytensor.compile.function.types import add_supervisor_to_fgraph
8+
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
9+
from pytensor.compile.io import In, Out
1010
from pytensor.compile.mode import NUMBA, get_mode
1111
from pytensor.link.numba.cache import compile_numba_function_src
1212
from pytensor.link.numba.dispatch import basic as numba_basic
1313
from pytensor.link.numba.dispatch.basic import (
14-
create_arg_string,
1514
create_tuple_string,
1615
numba_funcify_and_cache_key,
1716
register_funcify_and_cache_key,
@@ -89,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
8988
if outer_mitsot.type.shape[0] == abs(min(taps))
9089
]
9190
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
91+
input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs]
9292
add_supervisor_to_fgraph(
9393
fgraph=fgraph,
94-
input_specs=[
95-
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
96-
],
94+
input_specs=input_specs,
9795
accept_inplace=True,
9896
)
9997
rewriter(fgraph)
98+
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
99+
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
100100

101101
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
102102

tests/link/numba/test_scan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,7 @@ def test_higher_order_derivatives():
661661

662662
def test_grad_until_and_truncate_sequence_taps():
663663
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA")
664+
665+
666+
def test_aliased_inner_outputs():
667+
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3181,40 +3181,9 @@ def onestep(x, x_tm4):
31813181
f = function([seq], results[1])
31823182
assert np.all(exp_out == f(inp))
31833183

3184-
def test_shared_borrow(self):
3185-
"""
3186-
This tests two things. The first is a bug occurring when scan wrongly
3187-
used the borrow flag. The second thing it that Scan's infer_shape()
3188-
method will be able to remove the Scan node from the graph in this
3189-
case.
3190-
"""
3191-
3192-
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
3193-
exp_out = np.zeros((10, 1)).astype(config.floatX)
3194-
exp_out[4:] = inp[:-4]
3195-
3196-
def onestep(x, x_tm4):
3197-
return x, x_tm4
3198-
3199-
seq = matrix()
3200-
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
3201-
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
3202-
results = scan(
3203-
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
3204-
)
3205-
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
3206-
updates = {sharedvar: results[0][-1:]}
3207-
3208-
f = function([seq], results[1], updates=updates)
3209-
3210-
# This fails if scan uses wrongly the borrow flag
3211-
assert np.all(exp_out == f(inp))
3212-
3213-
# This fails if Scan's infer_shape() is unable to remove the Scan
3214-
# node from the graph.
3215-
f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
3216-
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
3217-
assert len(scan_nodes_infershape) == 0
3184+
@pytest.mark.parametrize("static_shape", (True, False))
3185+
def test_aliased_inner_outputs(self, static_shape):
3186+
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode=None)
32183187

32193188
def test_memory_reuse_with_outputs_as_inputs(self):
32203189
"""
@@ -4417,7 +4386,6 @@ def check_higher_order_derivative(mode):
44174386
# FIXME: All implementations of Scan seem to get this one wrong!
44184387
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
44194388

4420-
44214389
@staticmethod
44224390
def check_grad_until_and_truncate_sequence_taps(mode):
44234391
"""Test case where we need special behavior of zeroing out sequences in Scan"""
@@ -4439,3 +4407,66 @@ def check_grad_until_and_truncate_sequence_taps(mode):
44394407
grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
44404408
grad_expected = grad_expected.astype(config.floatX)
44414409
np.testing.assert_allclose(grad_res, grad_expected)
4410+
4411+
@staticmethod
4412+
def check_aliased_inner_outputs(static_shape, mode):
4413+
"""
4414+
This tests two things. The first is a bug occurring when scan wrongly
4415+
used the borrow flag. The second thing is that Scan's infer_shape()
4416+
method will be able to remove the Scan node from the graph in this
4417+
case.
4418+
4419+
Here is a pure Python equivalent of the problem we want to avoid:
4420+
```python
4421+
def scan(seq, initval):
4422+
# Due to memory optimization we overwrite values of mitsot as we iterate
4423+
# That's why mitsot has shape (4, 1) and not (14, 1)
4424+
mitsot = np.zeros((4, 1))
4425+
mitsot[:4] = initval
4426+
nitsot = np.zeros((10, 1))
4427+
for i, s in enumerate(seq):
4428+
# Incorrect results
4429+
mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
4430+
# Correct results
4431+
# mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
4432+
4433+
return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
4434+
4435+
scan(np.arange(10), np.zeros((4, 1)))
4436+
```
4437+
"""
4438+
4439+
def onestep(seq, seq_tm4):
4440+
# Recurring output is just each value of seq
4441+
# And we further map the tap -4 as a new output
4442+
return seq, seq_tm4
4443+
4444+
# Outer tensors must be at least matrices, so that we have vectors in the inner loop
4445+
# Otherwise we would be working with scalars and memory aliasing wouldn't be a concern
4446+
seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
4447+
init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
4448+
outputs_info = [{"initial": init, "taps": [-4]}, None]
4449+
[out_seq, out_seq_tm4] = scan(
4450+
fn=onestep,
4451+
sequences=seq,
4452+
outputs_info=outputs_info,
4453+
return_updates=False,
4454+
)
4455+
4456+
f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()], mode=mode)
4457+
4458+
seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
4459+
init_test_val = np.zeros((4, 1), dtype=config.floatX)
4460+
4461+
res0, res1 = f(seq_test_val, init_test_val)
4462+
expected_res0 = np.array([9], dtype=config.floatX)
4463+
expected_res1 = np.zeros(10, dtype=config.floatX)
4464+
expected_res1[4:] = np.arange(6)
4465+
np.testing.assert_array_equal(res0, expected_res0)
4466+
np.testing.assert_array_equal(res1, expected_res1)
4467+
4468+
# This fails if Scan's infer_shape() is unable to remove the Scan
4469+
# node from the graph.
4470+
f_infershape = function([seq, init], out_seq_tm4[1].shape)
4471+
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
4472+
assert len(scan_nodes_infershape) == 0

0 commit comments

Comments
 (0)