Refactor advanced subtensor #1756
base: main
Changes from all commits:

- 9ff8eea
- 3cfbd0d
- c18b322
- 737b8cb
- a3634dd
- 53adf9a
- 4b02064
```diff
@@ -14,6 +14,7 @@
     in2out,
     node_rewriter,
 )
+from pytensor.graph.type import Type
 from pytensor.raise_op import Assert
 from pytensor.scalar import Add, ScalarConstant, ScalarType
 from pytensor.scalar import constant as scalar_constant
```
```diff
@@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices):
     return axis


+def reconstruct_indices(idx_list, tensor_inputs):
+    """Reconstruct indices from idx_list and tensor inputs."""
+    indices = []
+    input_idx = 0
+    for entry in idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+                input_idx += 1
+    return indices
+
+
 @register_specialize
 @node_rewriter([AdvancedSubtensor])
 def local_replace_AdvancedSubtensor(fgraph, node):
```
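To make the helper concrete, here is a minimal standalone sketch of the same interleaving logic. The `Type` class and the string index names below are stand-ins for `pytensor.graph.type.Type` and the symbolic tensor inputs, not the real pytensor objects:

```python
class Type:  # stand-in for pytensor.graph.type.Type
    pass


def reconstruct_indices(idx_list, tensor_inputs):
    """Reconstruct indices from idx_list and tensor inputs."""
    indices = []
    input_idx = 0
    for entry in idx_list:
        if isinstance(entry, slice):
            # Static slices are stored directly in idx_list.
            indices.append(entry)
        elif isinstance(entry, Type):
            # Type entries are placeholders, consumed left to right
            # from the op's runtime tensor inputs.
            if input_idx < len(tensor_inputs):
                indices.append(tensor_inputs[input_idx])
                input_idx += 1
    return indices


# An indexing pattern like x[:, i, j] is stored as one static slice
# plus two Type placeholders for the tensor indices i and j.
idx_list = [slice(None), Type(), Type()]
print(reconstruct_indices(idx_list, ["i", "j"]))
# [slice(None, None, None), 'i', 'j']
```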
```diff
@@ -228,7 +243,10 @@ def local_replace_AdvancedSubtensor(fgraph, node):
         return

     indexed_var = node.inputs[0]
-    indices = node.inputs[1:]
+    tensor_inputs = node.inputs[1:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)
```
```diff
@@ -255,7 +273,10 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     res = node.inputs[0]
     val = node.inputs[1]
-    indices = node.inputs[2:]
+    tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)
```
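Both rewrites above assume a fixed input layout on the node: `AdvancedSubtensor` nodes carry `[x, *tensor_indices]`, while `AdvancedIncSubtensor` nodes carry `[x, y, *tensor_indices]`. A small sketch of the unpacking convention, with plain strings standing in for symbolic variables:

```python
# Sketch of the node.inputs layouts assumed above; plain strings stand
# in for symbolic variables.
subtensor_inputs = ["x", "i", "j"]           # AdvancedSubtensor: x[i, j]
inc_subtensor_inputs = ["x", "y", "i", "j"]  # AdvancedIncSubtensor: x[i, j] incremented by y

# local_replace_AdvancedSubtensor unpacks:
indexed_var = subtensor_inputs[0]
tensor_inputs = subtensor_inputs[1:]

# local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 unpacks:
res, val = inc_subtensor_inputs[0], inc_subtensor_inputs[1]
tensor_inputs = inc_subtensor_inputs[2:]

print(indexed_var, tensor_inputs)  # x ['i', 'j']
```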
```diff
@@ -1090,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
 def local_inplace_AdvancedIncSubtensor(fgraph, node):
     if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
         new_op = type(node.op)(
+            node.op.idx_list,
             inplace=True,
             set_instead_of_inc=node.op.set_instead_of_inc,
             ignore_duplicates=node.op.ignore_duplicates,
```
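Since `idx_list` is now constructor state, rebuilding the op for the in-place rewrite has to forward it explicitly. A minimal sketch with a hypothetical dataclass standing in for the real op (only the fields visible in the diff are assumed):

```python
from dataclasses import dataclass


# Hypothetical stand-in mirroring the pattern in the rewrite above: all
# constructor state, idx_list included, must be forwarded when the op
# is rebuilt with inplace=True.
@dataclass(frozen=True)
class FakeIncSubtensorOp:
    idx_list: tuple
    inplace: bool = False
    set_instead_of_inc: bool = False
    ignore_duplicates: bool = False


op = FakeIncSubtensorOp(idx_list=(slice(None), "placeholder"), set_instead_of_inc=True)
new_op = type(op)(
    op.idx_list,  # newly required first argument; dropping it loses the indices
    inplace=True,
    set_instead_of_inc=op.set_instead_of_inc,
    ignore_duplicates=op.ignore_duplicates,
)
print(new_op.inplace, new_op.idx_list == op.idx_list)  # True True
```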
```diff
@@ -1354,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
                 z_broad[k]
                 and not same_shape(xi, y, dim_x=k, dim_y=k)
                 and shape_of[y][k] != 1
+                and shape_of[xi][k] == 1
             )
         ]
```

Member (on `and shape_of[xi][k] == 1`): If this fixes a bug, it will need a specific regression test, and be in a separate commit
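The added conjunct restricts the rewrite to dimensions where the indexed buffer's length is 1. A toy evaluation of the guard, with hypothetical static shapes standing in for the `shape_of` entries and direct comparison standing in for `same_shape`:

```python
# Toy evaluation of the guard above. Hypothetical static shapes stand in
# for shape_of[...] entries; same_shape is approximated by comparing the
# known dimension lengths directly.
z_broad = (True, False)  # broadcastable pattern of z
xi_shape = (1, 5)        # shape_of[xi]
y_shape = (3, 5)         # shape_of[y]

for k in range(2):
    selected = (
        z_broad[k]
        and not xi_shape[k] == y_shape[k]  # not same_shape(xi, y, dim_x=k, dim_y=k)
        and y_shape[k] != 1
        and xi_shape[k] == 1               # condition added in this diff
    )
    print(k, selected)
# 0 True  -> dim 0 passes all four tests
# 1 False -> dim 1 fails z_broad[k]
```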
```diff
@@ -1751,9 +1774,14 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
     if isinstance(node.op, AdvancedSubtensor):
-        x, *idxs = node.inputs
+        x = node.inputs[0]
+        tensor_inputs = node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
+        x, y = node.inputs[0], node.inputs[1]
+        tensor_inputs = node.inputs[2:]
```

Member (on `tensor_inputs = node.inputs[2:]`): I don't like the name

Member: This applies elsewhere

```diff
+
+    # Reconstruct indices from idx_list and tensor inputs
+    idxs = reconstruct_indices(node.op.idx_list, tensor_inputs)

     if any(
         (
```
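The identity in the docstring can be sanity-checked with plain numpy, using in-place assignment in place of pytensor's `.set`:

```python
import numpy as np

# Numpy sketch of the identity used by ravel_multidimensional_bool_idx:
# a multidimensional boolean index equals indexing the raveled array with
# the raveled mask, followed by a reshape back to the original shape.
x = np.arange(9.0).reshape(3, 3)
mask = np.eye(3, dtype=bool)
y = np.array([10.0, 20.0, 30.0])

direct = x.copy()
direct[mask] = y                    # x[eye(3, dtype=bool)].set(y)

raveled = x.ravel().copy()
raveled[mask.ravel()] = y           # x.ravel()[eye(3).ravel()].set(y)
raveled = raveled.reshape(x.shape)  # ...reshape(x.shape)

assert np.array_equal(direct, raveled)
```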
```diff
@@ -1791,12 +1819,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
         new_idxs[bool_idx_pos] = raveled_bool_idx

         if isinstance(node.op, AdvancedSubtensor):
-            new_out = node.op(raveled_x, *new_idxs)
+            # Create new AdvancedSubtensor with updated idx_list
+            new_idx_list = list(node.op.idx_list)
```

Member (on `new_idx_list = list(node.op.idx_list)`): Can we use some helper to do this? Isn't there already something for when you do

```diff
+            new_tensor_inputs = list(tensor_inputs)
+
+            # Update the idx_list and tensor_inputs for the raveled boolean index
+            input_idx = 0
+            for i, entry in enumerate(node.op.idx_list):
+                if isinstance(entry, Type):
+                    if input_idx == bool_idx_pos:
+                        new_tensor_inputs[input_idx] = raveled_bool_idx
+                    input_idx += 1
+
+            new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
         else:
+            # Create new AdvancedIncSubtensor with updated idx_list
+            new_idx_list = list(node.op.idx_list)
+            new_tensor_inputs = list(tensor_inputs)
+
+            # Update the tensor_inputs for the raveled boolean index
+            input_idx = 0
+            for i, entry in enumerate(node.op.idx_list):
+                if isinstance(entry, Type):
+                    if input_idx == bool_idx_pos:
+                        new_tensor_inputs[input_idx] = raveled_bool_idx
+                    input_idx += 1
+
             # The dimensions of y that correspond to the boolean indices
             # must already be raveled in the original graph, so we don't need to do anything to it
-            new_out = node.op(raveled_x, y, *new_idxs)
-            # But we must reshape the output to math the original shape
+            new_out = AdvancedIncSubtensor(
```

Member (on `new_out = AdvancedIncSubtensor(`): You should use

```diff
+                new_idx_list,
+                inplace=node.op.inplace,
+                set_instead_of_inc=node.op.set_instead_of_inc,
+                ignore_duplicates=node.op.ignore_duplicates,
+            )(raveled_x, y, *new_tensor_inputs)
+            # But we must reshape the output to match the original shape
             new_out = new_out.reshape(x_shape)

         return [copy_stack_trace(node.outputs[0], new_out)]
```
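Both branches above use the same counting pattern as `reconstruct_indices`: walk `idx_list`, advance a counter only on `Type` placeholders, and swap in the raveled mask at `bool_idx_pos`. A standalone sketch with the same stand-ins as earlier:

```python
class Type:  # stand-in for pytensor.graph.type.Type, as above
    pass


idx_list = [slice(None), Type(), Type()]
tensor_inputs = ["int_idx", "bool_idx"]
bool_idx_pos = 1  # position of the boolean index among the tensor inputs
raveled_bool_idx = "raveled_bool_idx"

new_tensor_inputs = list(tensor_inputs)
input_idx = 0
for entry in idx_list:
    if isinstance(entry, Type):  # only Type entries consume a tensor input
        if input_idx == bool_idx_pos:
            new_tensor_inputs[input_idx] = raveled_bool_idx
        input_idx += 1

print(new_tensor_inputs)  # ['int_idx', 'raveled_bool_idx']
```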
Member: What does this helper do? It sounds like it returns a dummy slice (not with actual arguments) or a tensor variable. Shouldn't the slice variables be placed inside the slices? If so, there's already a helper that does that, IIRC.