Skip to content

Commit 26f81d3

Browse files
committed
Numba Scan: zero out unwritten buffers
1 parent 36a788d commit 26f81d3

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def add_output_storage_post_proc_stmt(
254254
"""
255255
).strip()
256256
)
257+
else:
258+
# And regular loops should zero out unused entries of the output buffer
259+
# These show up with truncated gradients of while loops
260+
output_storage_post_proc_stmts.append(
261+
dedent(
262+
f"""
263+
elif {storage_size} > (i + {max_offset}):
264+
{outer_in_name}[i + {max_offset}:] = 0
265+
"""
266+
).strip()
267+
)
257268

258269
# Special in-loop statements that create (nit-sot) storage arrays after a
259270
# single iteration is performed. This is necessary because we don't know

tests/link/numba/test_scan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,26 @@ def test_higher_order_derivatives():
673673
[g, gg, ggg],
674674
[np.array(0.95)],
675675
)
676+
677+
678+
def test_grad_until_and_truncate_sequence_taps():
679+
# This is a case where we need special zero out behavior in Scan
680+
# Copied from tests.scan.basic.py::TestGradUntil::test_grad_until_and_truncate_sequence_taps
681+
x = pt.vector("x")
682+
threshold = pt.scalar(name="threshold", dtype="int64")
683+
684+
r = scan(
685+
lambda x, y, u: (x * y, until(y > u)),
686+
sequences=dict(input=x, taps=[-2, 0]),
687+
outputs_info=[None],
688+
non_sequences=[threshold],
689+
truncate_gradient=3,
690+
return_updates=False,
691+
)
692+
g = grad(r.sum(), x)
693+
694+
compare_numba_and_py(
695+
[x, threshold],
696+
[r, g],
697+
[np.arange(15, dtype=x.dtype), 6],
698+
)

0 commit comments

Comments
 (0)