Skip to content

Commit 06244f3

Browse files
committed
Numba Scan: zero out unwritten buffers
1 parent f3a7d94 commit 06244f3

File tree

3 files changed

+39
-16
lines changed

3 files changed

+39
-16
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def add_output_storage_post_proc_stmt(
254254
"""
255255
).strip()
256256
)
257+
else:
258+
# And regular loops should zero out unused entries of the output buffer
259+
# These show up with truncated gradients of while loops
260+
output_storage_post_proc_stmts.append(
261+
dedent(
262+
f"""
263+
elif {storage_size} > (i + {max_offset}):
264+
{outer_in_name}[i + {max_offset}:] = 0
265+
"""
266+
).strip()
267+
)
257268

258269
# Special in-loop statements that create (nit-sot) storage arrays after a
259270
# single iteration is performed. This is necessary because we don't know

tests/link/numba/test_scan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,7 @@ def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark
657657

658658
def test_higher_order_derivatives():
659659
ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
660+
661+
662+
def test_grad_until_and_truncate_sequence_taps():
663+
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,22 +2621,7 @@ def test_grad_until_and_truncate(self):
26212621
utt.assert_allclose(pytensor_gradient, self.numpy_gradient)
26222622

26232623
def test_grad_until_and_truncate_sequence_taps(self):
2624-
n = 3
2625-
r = scan(
2626-
lambda x, y, u: (x * y, until(y > u)),
2627-
sequences=dict(input=self.x, taps=[-2, 0]),
2628-
non_sequences=[self.threshold],
2629-
truncate_gradient=n,
2630-
return_updates=False,
2631-
)
2632-
g = grad(r.sum(), self.x)
2633-
f = function([self.x, self.threshold], [r, g])
2634-
_pytensor_output, pytensor_gradient = f(self.seq, 6)
2635-
2636-
# Gradient computed by hand:
2637-
numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
2638-
numpy_grad = numpy_grad.astype(config.floatX)
2639-
utt.assert_allclose(pytensor_gradient, numpy_grad)
2624+
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode=None)
26402625

26412626

26422627
def test_mintap_onestep():
@@ -4431,3 +4416,26 @@ def check_higher_order_derivative(mode):
44314416
np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
44324417
# FIXME: All implementations of Scan seem to get this one wrong!
44334418
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
4419+
4420+
4421+
@staticmethod
4422+
def check_grad_until_and_truncate_sequence_taps(mode):
4423+
"""Test case that requires Scan to zero out unwritten entries of the output buffer"""
4424+
x = pt.vector("x")
4425+
threshold = pt.scalar(name="threshold", dtype="int64")
4426+
4427+
r = scan(
4428+
lambda x, y, u: (x * y, until(y > u)),
4429+
sequences=dict(input=x, taps=[-2, 0]),
4430+
non_sequences=[threshold],
4431+
truncate_gradient=3,
4432+
return_updates=False,
4433+
)
4434+
g = grad(r.sum(), x)
4435+
f = function([x, threshold], [r, g], mode=mode)
4436+
_, grad_res = f(np.arange(15, dtype=x.dtype), 6)
4437+
4438+
# Gradient computed by hand:
4439+
grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
4440+
grad_expected = grad_expected.astype(config.floatX)
4441+
np.testing.assert_allclose(grad_res, grad_expected)

0 commit comments

Comments (0)