38 changes: 32 additions & 6 deletions pytensor/link/jax/dispatch/subtensor.py
@@ -31,11 +31,18 @@
"""


@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
def advanced_subtensor1(x, ilist):
return x[ilist]

return advanced_subtensor1


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
@@ -47,10 +54,24 @@ def subtensor(x, *ilists):
return subtensor
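
Since these dispatchers now rebuild their index tuples from `op.idx_list`, a minimal sketch of what `indices_from_subtensor` does with such a template may help; the concrete values below are illustrative assumptions, not taken from this PR:

```python
from pytensor.scalar import int64
from pytensor.tensor.subtensor import indices_from_subtensor

# idx_list is the static template stored on the Op; Type entries mark the
# positions that are filled, in order, from the node's flattened inputs.
idx_list = (slice(None, int64, int64),)
indices_from_subtensor((7, 2), idx_list)
# -> (slice(None, 7, 2),)
```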


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, y, ilist):
return x.at[ilist].set(y)

else:

def jax_fn(x, y, ilist):
return x.at[ilist].add(y)

return jax_fn


@jax_funcify.register(IncSubtensor)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

@@ -77,6 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
@@ -87,8 +110,11 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, indices, y)

return advancedincsubtensor
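
For readers less familiar with JAX, a short sketch of the functional `.at[...]` update semantics that these `jax_fn` closures rely on (plain `jax.numpy`, independent of this PR):

```python
import jax.numpy as jnp

x = jnp.zeros(5)
ilist = jnp.array([0, 2, 2])

x.at[ilist].set(1.0)  # set_instead_of_inc=True path: out-of-place write
x.at[ilist].add(1.0)  # default path: out-of-place increment; duplicate indices accumulate
# set -> [1., 0., 1., 0., 0.]
# add -> [1., 0., 2., 0., 0.]
```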

49 changes: 26 additions & 23 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -20,7 +20,6 @@
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -29,7 +28,7 @@
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
from pytensor.tensor.type_other import MakeSlice


def slice_new(self, start, stop, step):
@@ -239,28 +238,32 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
_x, _y, idxs = node.inputs[0], None, node.inputs[1:]
tensor_inputs = node.inputs[1:]
else:
_x, _y, *idxs = node.inputs

basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]
tensor_inputs = node.inputs[2:]

# Reconstruct indexing information from idx_list and tensor inputs
basic_idxs = []
adv_idxs = []
input_idx = 0

for i, entry in enumerate(op.idx_list):
if isinstance(entry, slice):
# Basic slice index
basic_idxs.append(entry)
elif isinstance(entry, Type):
# Advanced tensor index
if input_idx < len(tensor_inputs):
idx_input = tensor_inputs[input_idx]
adv_idxs.append(
{
"axis": i,
"dtype": idx_input.type.dtype,
"bcast": idx_input.type.broadcastable,
"ndim": idx_input.type.ndim,
}
)
input_idx += 1

# Special implementation for consecutive integer vector indices
if (
25 changes: 18 additions & 7 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -9,7 +9,7 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
from pytensor.tensor.type_other import MakeSlice


def check_negative_steps(indices):
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
idx_list = op.idx_list

def advsubtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
# Check if we have slice indexing in idx_list
has_slice_indexing = (
any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
)
if has_slice_indexing:
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
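The `ignore_duplicates` flag matters here because PyTorch only accumulates repeated indices when explicitly asked to. A small sketch of the underlying torch behaviour follows; this is the standard `index_put_` API, not necessarily the exact call used in the truncated function body above:

```python
import torch

x = torch.zeros(5)
idx = torch.tensor([0, 2, 2])
y = torch.ones(3)

x.clone().index_put_((idx,), y, accumulate=True)   # duplicates add up: result[2] == 2
x.clone().index_put_((idx,), y, accumulate=False)  # last write wins:   result[2] == 1
```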
27 changes: 27 additions & 0 deletions pytensor/tensor/basic.py
@@ -1818,6 +1818,33 @@ def do_constant_folding(self, fgraph, node):
return True


@_vectorize_node.register(Alloc)
def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes):
# batch_shapes are usually not batched (they are scalars for the shape)
# batch_val is the value being allocated.

# If shapes are batched, we fall back (complex case)
if any(
b_shp.type.ndim > shp.type.ndim
for b_shp, shp in zip(batch_shapes, node.inputs[1:], strict=True)
):
return vectorize_node_fallback(op, node, batch_val, *batch_shapes)

# If value is batched, we need to prepend batch dims to the output shape
val = node.inputs[0]
batch_ndim = batch_val.type.ndim - val.type.ndim

if batch_ndim == 0:
return op.make_node(batch_val, *batch_shapes)

# We need the size of the batch dimensions
# batch_val has shape (B1, B2, ..., val_dims...)
batch_dims = [batch_val.shape[i] for i in range(batch_ndim)]

new_shapes = batch_dims + list(batch_shapes)
return op.make_node(batch_val, *new_shapes)
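
As a rough usage sketch of the new `Alloc` vectorization; this assumes the `vectorize_graph` helper from `pytensor.graph.replace`, and the shapes are purely illustrative:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

val = pt.scalar("val")
out = pt.alloc(val, 3, 5)           # static shape (3, 5)

batch_val = pt.vector("batch_val")  # one extra leading batch dimension
batch_out = vectorize_graph(out, {val: batch_val})
# With the dispatch above, batch_out should be a single Alloc whose shape is
# (batch_val.shape[0], 3, 5), rather than a Blockwise fallback.
```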


alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))

71 changes: 64 additions & 7 deletions pytensor/tensor/rewriting/subtensor.py
@@ -14,6 +14,7 @@
in2out,
node_rewriter,
)
from pytensor.graph.type import Type
from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import constant as scalar_constant
@@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices):
return axis


def reconstruct_indices(idx_list, tensor_inputs):
Review comment (Member): What does this helper do? Sounds like it returns a dummy slice (not with actual arguments) or a tensor variable. Shouldn't the slice variables be placed inside the slices? If so, there's already a helper that does that, IIRC.

"""Reconstruct indices from idx_list and tensor inputs."""
indices = []
input_idx = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(entry)
elif isinstance(entry, Type):
if input_idx < len(tensor_inputs):
indices.append(tensor_inputs[input_idx])
input_idx += 1
return indices
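
To make the helper's behaviour concrete (and partially answer the question above), a small illustration, assuming `idx_list` mixes plain `slice` entries with `Type` placeholders as the loop expects:

```python
import pytensor.tensor as pt

idx = pt.lvector("idx")
idx_list = [slice(None), idx.type]   # an x[:, idx]-style pattern
reconstruct_indices(idx_list, [idx])
# -> [slice(None, None, None), idx]
# Slice entries are returned as-is; each Type entry consumes the next
# index variable in order.
```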


@register_specialize
@node_rewriter([AdvancedSubtensor])
def local_replace_AdvancedSubtensor(fgraph, node):
@@ -228,7 +243,10 @@ def local_replace_AdvancedSubtensor(fgraph, node):
return

indexed_var = node.inputs[0]
indices = node.inputs[1:]
tensor_inputs = node.inputs[1:]

# Reconstruct indices from idx_list and tensor inputs
indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

axis = get_advsubtensor_axis(indices)

@@ -255,7 +273,10 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):

res = node.inputs[0]
val = node.inputs[1]
indices = node.inputs[2:]
tensor_inputs = node.inputs[2:]

# Reconstruct indices from idx_list and tensor inputs
indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

axis = get_advsubtensor_axis(indices)

@@ -1090,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)(
node.op.idx_list,
inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates,
@@ -1354,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
z_broad[k]
and not same_shape(xi, y, dim_x=k, dim_y=k)
and shape_of[y][k] != 1
and shape_of[xi][k] == 1
Review comment (Member): If this fixes a bug, it will need a specific regression test, and it should be in a separate commit.

)
]

@@ -1751,9 +1774,14 @@ def ravel_multidimensional_bool_idx(fgraph, node):
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
"""
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
x = node.inputs[0]
tensor_inputs = node.inputs[1:]
else:
x, y, *idxs = node.inputs
x, y = node.inputs[0], node.inputs[1]
tensor_inputs = node.inputs[2:]
Review comment (Member): I don't like the name tensor_inputs; x and y are also tensors and inputs. Use index_variables?

Review comment (Member): This applies elsewhere.


# Reconstruct indices from idx_list and tensor inputs
idxs = reconstruct_indices(node.op.idx_list, tensor_inputs)

if any(
(
@@ -1791,12 +1819,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
new_idxs[bool_idx_pos] = raveled_bool_idx

if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
# Create new AdvancedSubtensor with updated idx_list
new_idx_list = list(node.op.idx_list)
Review comment (Member): Can we use some helper to do this? Isn't there already something for when you do x.__getitem__?

new_tensor_inputs = list(tensor_inputs)

# Update the idx_list and tensor_inputs for the raveled boolean index
input_idx = 0
for i, entry in enumerate(node.op.idx_list):
if isinstance(entry, Type):
if input_idx == bool_idx_pos:
new_tensor_inputs[input_idx] = raveled_bool_idx
input_idx += 1

new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
else:
# Create new AdvancedIncSubtensor with updated idx_list
new_idx_list = list(node.op.idx_list)
new_tensor_inputs = list(tensor_inputs)

# Update the tensor_inputs for the raveled boolean index
input_idx = 0
for i, entry in enumerate(node.op.idx_list):
if isinstance(entry, Type):
if input_idx == bool_idx_pos:
new_tensor_inputs[input_idx] = raveled_bool_idx
input_idx += 1

# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = AdvancedIncSubtensor(
Review comment (Member): You should use type(op) so that subclasses are respected. It may also make sense to add a method to these indexing Ops, like op.with_new_indices(), that clones itself with a new idx_list. Maybe that would be the one responsible for creating the new idx_list, instead of it having to be done here in the rewrite.

new_idx_list,
inplace=node.op.inplace,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates,
)(raveled_x, y, *new_tensor_inputs)
# But we must reshape the output to match the original shape
new_out = new_out.reshape(x_shape)

return [copy_stack_trace(node.outputs[0], new_out)]
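
To sketch the reviewer's suggestion above; the helper name and its placement are hypothetical and not part of this PR:

```python
# Hypothetical helper following the review suggestion; not part of this PR.
def with_new_indices(op, idx_list):
    """Clone an AdvancedIncSubtensor-like Op with a new idx_list,
    preserving its flags and respecting subclasses via type(op)."""
    return type(op)(
        idx_list,
        inplace=op.inplace,
        set_instead_of_inc=op.set_instead_of_inc,
        ignore_duplicates=op.ignore_duplicates,
    )

# The rewrite above could then read:
#   new_out = with_new_indices(node.op, new_idx_list)(raveled_x, y, *new_tensor_inputs)
```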