Skip to content

Commit 36a788d

Browse files
committed
Numba Scan: correct handling of signed mitmot taps
Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
1 parent 42e8490 commit 36a788d

File tree

5 files changed

+108
-46
lines changed

5 files changed

+108
-46
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def jax_args_to_inner_func_args(carry, x):
9090
chain.from_iterable(
9191
buffer[(i + np.array(taps))]
9292
for buffer, taps in zip(
93-
inner_mit_mot, info.mit_mot_in_slices, strict=True
93+
inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
9494
)
9595
)
9696
)
@@ -140,7 +140,10 @@ def inner_func_outs_to_jax_outs(
140140
new_mit_mot = [
141141
buffer.at[i + np.array(taps)].set(new_vals)
142142
for buffer, new_vals, taps in zip(
143-
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
143+
old_mit_mot,
144+
new_mit_mot_vals,
145+
info.normalized_mit_mot_out_slices,
146+
strict=True,
144147
)
145148
]
146149
# Discard oldest MIT-SOT and append newest value

pytensor/link/numba/dispatch/scan.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def idx_to_str(
2727
idx_symbol: str = "i",
2828
allow_scalar=False,
2929
) -> str:
30-
if offset < 0:
31-
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
32-
elif offset > 0:
30+
assert offset >= 0
31+
if offset > 0:
3332
indices = f"{idx_symbol} + {offset}"
3433
else:
3534
indices = idx_symbol
@@ -226,33 +225,16 @@ def add_inner_in_expr(
226225
# storage array like a circular buffer, and that's why we need to track the
227226
# storage size along with the taps length/indexing offset.
228227
def add_output_storage_post_proc_stmt(
229-
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
228+
outer_in_name: str, max_offset: int, storage_size: str
230229
):
231-
tap_size = max(tap_sizes)
232-
233-
if op.info.as_while:
234-
# While loops need to truncate the output storage to a length given
235-
# by the number of iterations performed.
236-
output_storage_post_proc_stmts.append(
237-
dedent(
238-
f"""
239-
if i + {tap_size} < {storage_size}:
240-
{storage_size} = i + {tap_size}
241-
{outer_in_name} = {outer_in_name}[:{storage_size}]
242-
"""
243-
).strip()
244-
)
245-
246-
# Rotate the storage so that the last computed value is at the end of
247-
# the storage array.
230+
# Rotate the storage so that the last computed value is at the end of the storage array.
248231
# This is needed when the output storage array does not have a length
249232
# equal to the number of taps plus `n_steps`.
250-
# If the storage size only allows one entry, there's nothing to rotate
251233
output_storage_post_proc_stmts.append(
252234
dedent(
253235
f"""
254-
if 1 < {storage_size} < (i + {tap_size}):
255-
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
236+
if 1 < {storage_size} < (i + {max_offset}):
237+
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
256238
if {outer_in_name}_shift > 0:
257239
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
258240
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
@@ -261,6 +243,18 @@ def add_output_storage_post_proc_stmt(
261243
).strip()
262244
)
263245

246+
if op.info.as_while:
247+
# While loops need to truncate the output storage to a length given
248+
# by the number of iterations performed.
249+
output_storage_post_proc_stmts.append(
250+
dedent(
251+
f"""
252+
elif {storage_size} > (i + {max_offset}):
253+
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
254+
"""
255+
).strip()
256+
)
257+
264258
# Special in-loop statements that create (nit-sot) storage arrays after a
265259
# single iteration is performed. This is necessary because we don't know
266260
# the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +282,11 @@ def add_output_storage_post_proc_stmt(
288282
storage_size_name = f"{outer_in_name}_len"
289283
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
290284
input_taps = inner_in_names_to_input_taps[outer_in_name]
291-
tap_storage_size = -min(input_taps)
292-
assert tap_storage_size >= 0
285+
max_lookback_inp_tap = -min(0, min(input_taps))
286+
assert max_lookback_inp_tap >= 0
293287

294288
for in_tap in input_taps:
295-
tap_offset = in_tap + tap_storage_size
296-
assert tap_offset >= 0
289+
tap_offset = max_lookback_inp_tap + in_tap
297290
is_vector = outer_in_var.ndim == 1
298291
add_inner_in_expr(
299292
outer_in_name,
@@ -302,22 +295,25 @@ def add_output_storage_post_proc_stmt(
302295
vector_slice_opt=is_vector,
303296
)
304297

305-
output_taps = inner_in_names_to_output_taps.get(
306-
outer_in_name, [tap_storage_size]
307-
)
308-
inner_out_to_outer_in_stmts.extend(
309-
idx_to_str(
310-
storage_name,
311-
out_tap,
312-
size=storage_size_name,
313-
allow_scalar=True,
298+
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
299+
for out_tap in output_taps:
300+
tap_offset = max_lookback_inp_tap + out_tap
301+
assert tap_offset >= 0
302+
inner_out_to_outer_in_stmts.append(
303+
idx_to_str(
304+
storage_name,
305+
tap_offset,
306+
size=storage_size_name,
307+
allow_scalar=True,
308+
)
314309
)
315-
for out_tap in output_taps
316-
)
317310

318-
add_output_storage_post_proc_stmt(
319-
storage_name, output_taps, storage_size_name
320-
)
311+
if outer_in_name not in outer_in_mit_mot_names:
312+
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
313+
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
314+
add_output_storage_post_proc_stmt(
315+
storage_name, max_offset_out_tap, storage_size_name
316+
)
321317

322318
else:
323319
storage_size_stmt = ""
@@ -351,7 +347,7 @@ def add_output_storage_post_proc_stmt(
351347
inner_out_to_outer_in_stmts.append(
352348
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
353349
)
354-
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
350+
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
355351

356352
# In case of nit-sots we are provided the length of the array in
357353
# the iteration dimension instead of actual arrays, hence we

pytensor/scan/op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ def n_outer_outputs(self):
287287
+ self.n_untraced_sit_sot_outs
288288
)
289289

290+
@property
291+
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
292+
"""Return mit_mot_in slices normalized as an offset from the oldest tap"""
293+
# TODO: Make this the canonical representation
294+
res = []
295+
for in_slice in self.mit_mot_in_slices:
296+
min_tap = -(min(0, min(in_slice)))
297+
res.append(tuple(tap + min_tap for tap in in_slice))
298+
return tuple(res)
299+
300+
@property
301+
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
302+
"""Return mit_mot_out slices normalized as an offset from the oldest tap"""
303+
# TODO: Make this the canonical representation
304+
res = []
305+
for out_slice in self.mit_mot_out_slices:
306+
min_tap = -(min(0, min(out_slice)))
307+
res.append(tuple(tap + min_tap for tap in out_slice))
308+
return tuple(res)
309+
290310

291311
TensorConstructorType = Callable[
292312
[Iterable[bool | int | None], str | np.generic], TensorType

tests/link/jax/test_scan.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import pytensor.tensor as pt
7-
from pytensor import function, ifelse, shared
7+
from pytensor import function, grad, ifelse, shared
88
from pytensor.compile import get_mode
99
from pytensor.configdefaults import config
1010
from pytensor.graph import Apply, Op
@@ -626,3 +626,25 @@ def block_until_ready(*inputs, jax_fn=jax_fn):
626626
block_until_ready(*test_input_vals) # Warmup
627627

628628
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
629+
630+
631+
def test_higher_order_derivatives():
632+
"""This tests different mit-mot taps signs"""
633+
x = pt.scalar("x")
634+
635+
xs = scan(
636+
fn=lambda xtm1: xtm1**2,
637+
outputs_info=[x],
638+
n_steps=5,
639+
return_updates=False,
640+
)
641+
g = grad(xs[-1], x)
642+
gg = grad(g, x)
643+
ggg = grad(gg, x)
644+
645+
compare_jax_and_py(
646+
[x],
647+
[g, gg, ggg],
648+
[np.array(0.95)],
649+
jax_mode="JAX", # Needs full pipeline to compile
650+
)

tests/link/numba/test_scan.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,24 @@ def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
652652

653653
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
654654
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
655+
656+
657+
def test_higher_order_derivatives():
658+
"""This tests different mit-mot taps signs"""
659+
x = pt.scalar("x")
660+
661+
xs = scan(
662+
fn=lambda xtm1: xtm1**2,
663+
outputs_info=[x],
664+
n_steps=5,
665+
return_updates=False,
666+
)
667+
g = grad(xs[-1], x)
668+
gg = grad(g, x)
669+
ggg = grad(gg, x)
670+
671+
compare_numba_and_py(
672+
[x],
673+
[g, gg, ggg],
674+
[np.array(0.95)],
675+
)

0 commit comments

Comments
 (0)