     in2out,
     node_rewriter,
 )
+from pytensor.graph.type import Type
 from pytensor.raise_op import Assert
 from pytensor.scalar import Add, ScalarConstant, ScalarType
 from pytensor.scalar import constant as scalar_constant
@@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices):
     return axis


+def reconstruct_indices(idx_list, tensor_inputs):
+    """Reconstruct indices from idx_list and tensor inputs."""
+    indices = []
+    input_idx = 0
+    for entry in idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
+    return indices
+
+
 @register_specialize
 @node_rewriter([AdvancedSubtensor])
 def local_replace_AdvancedSubtensor(fgraph, node):
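
A rough usage sketch of the new helper (hypothetical variable names; assumes PyTensor is installed and that reconstruct_indices from the hunk above is in scope):

import pytensor.tensor as pt

rows = pt.lvector("rows")            # symbolic integer index (hypothetical)
idx_list = [slice(None), rows.type]  # static slice plus one Type placeholder
tensor_inputs = [rows]               # symbolic inputs, in placeholder order

# Slices pass through unchanged; each Type entry consumes the next tensor input.
indices = reconstruct_indices(idx_list, tensor_inputs)
# indices == [slice(None, None, None), rows]
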
@@ -229,17 +244,9 @@ def local_replace_AdvancedSubtensor(fgraph, node):

     indexed_var = node.inputs[0]
     tensor_inputs = node.inputs[1:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
-    indices = []
-    input_idx = 0
-    for entry in node.op.idx_list:
-        if isinstance(entry, slice):
-            indices.append(entry)
-        elif isinstance(entry, Type):
-            if input_idx < len(tensor_inputs):
-                indices.append(tensor_inputs[input_idx])
-            input_idx += 1
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)

@@ -267,17 +274,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     res = node.inputs[0]
     val = node.inputs[1]
     tensor_inputs = node.inputs[2:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
-    indices = []
-    input_idx = 0
-    for entry in node.op.idx_list:
-        if isinstance(entry, slice):
-            indices.append(entry)
-        elif isinstance(entry, Type):
-            if input_idx < len(tensor_inputs):
-                indices.append(tensor_inputs[input_idx])
-            input_idx += 1
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)

@@ -1112,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
 def local_inplace_AdvancedIncSubtensor(fgraph, node):
     if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
         new_op = type(node.op)(
+            node.op.idx_list,
             inplace=True,
             set_instead_of_inc=node.op.set_instead_of_inc,
             ignore_duplicates=node.op.ignore_duplicates,
@@ -1376,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
                 z_broad[k]
                 and not same_shape(xi, y, dim_x=k, dim_y=k)
                 and shape_of[y][k] != 1
+                and shape_of[xi][k] == 1
             )
         ]

@@ -1778,17 +1779,9 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     else:
         x, y = node.inputs[0], node.inputs[1]
         tensor_inputs = node.inputs[2:]
-
+
         # Reconstruct indices from idx_list and tensor inputs
-        idxs = []
-        input_idx = 0
-        for entry in node.op.idx_list:
-            if isinstance(entry, slice):
-                idxs.append(entry)
-            elif isinstance(entry, Type):
-                if input_idx < len(tensor_inputs):
-                    idxs.append(tensor_inputs[input_idx])
-                input_idx += 1
+        idxs = reconstruct_indices(node.op.idx_list, tensor_inputs)

     if any(
         (
@@ -1829,36 +1822,36 @@ def ravel_multidimensional_bool_idx(fgraph, node):
         # Create new AdvancedSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the idx_list and tensor_inputs for the raveled boolean index
         input_idx = 0
         for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
         # Create new AdvancedIncSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the tensor_inputs for the raveled boolean index
         input_idx = 0
         for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
         new_out = AdvancedIncSubtensor(
             new_idx_list,
             inplace=node.op.inplace,
             set_instead_of_inc=node.op.set_instead_of_inc,
-            ignore_duplicates=node.op.ignore_duplicates
+            ignore_duplicates=node.op.ignore_duplicates,
         )(raveled_x, y, *new_tensor_inputs)
         # But we must reshape the output to match the original shape
         new_out = new_out.reshape(x_shape)
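
For intuition, a simplified NumPy analogue of the raveling idea this rewrite relies on (a sketch only; the actual rewrite also handles mixed indices and the set/inc case, and reshapes the result back to the original shape):

import numpy as np

x = np.arange(12).reshape(3, 4)
mask = (x % 2 == 0)  # multidimensional boolean index

# Indexing with a full-shape boolean mask is equivalent to indexing the
# raveled array with the raveled mask, which is what the rewrite exploits.
assert np.array_equal(x[mask], x.ravel()[mask.ravel()])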