Bug Fix: weight=0 for _broadcast_input='Node' fails#2221
Conversation
-Handles case when objective is computed per-grid node and at least one weight is zero.
efdf6f0 to
0534979
Compare
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | -2.94 % | 4.008e+03 | 3.890e+03 | -117.91 | 31.03 | 28.13 |
test_proximal_jac_w7x_with_eq_update | -0.35 % | 6.583e+03 | 6.560e+03 | -23.01 | 148.30 | 148.52 |
test_proximal_freeb_jac | 0.31 % | 1.331e+04 | 1.335e+04 | 41.03 | 80.11 | 77.80 |
test_proximal_freeb_jac_blocked | 0.12 % | 7.665e+03 | 7.674e+03 | 9.45 | 67.17 | 67.36 |
test_proximal_freeb_jac_batched | -0.25 % | 7.657e+03 | 7.638e+03 | -19.36 | 66.87 | 67.06 |
test_proximal_jac_ripple | 0.23 % | 3.556e+03 | 3.564e+03 | 8.24 | 52.52 | 52.63 |
test_proximal_jac_ripple_bounce1d | 0.47 % | 3.781e+03 | 3.798e+03 | 17.89 | 65.78 | 66.00 |
test_eq_solve | 0.74 % | 2.039e+03 | 2.054e+03 | 15.16 | 85.32 | 85.29 |For the memory plots, go to the summary of |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #2221 +/- ##
==========================================
- Coverage 94.34% 94.33% -0.01%
==========================================
Files 101 101
Lines 28845 28847 +2
==========================================
Hits 27213 27213
- Misses 1632 1634 +2
🚀 New features to boost your workflow:
|
YigitElma
left a comment
There was a problem hiding this comment.
I will test it myself too, but it looks good. I am personally fine with the breaking change; this was a pretty new addition, and going forward, I like the new name better.
| - Fixes a bug in `OmnigenousField.change_resolution` when changing `L_B`. | ||
| - Scaling a `ScaledProfile` or taking power of a `PowerProfile` now only updates the `scale`/`power` attributes instead of nesting the `ScaledProfile`/`PowerProfile`s. | ||
| - `jax.Array`s in `_static_attrs` will be automatically converted to `np.ndarray` to prevent stalling code. In general, jax arrays should be omitted in `_static_attrs`. | ||
| - Fixes a bug in `_CoilObjective` for objectives which are computed per-grid node when at least one entry of `weight` is zero. |
There was a problem hiding this comment.
I think there is also a breaking change here. If someone had a custom coil objective, they need to update the [self._coilset_tree["coil_mask"] to [self._coilset_tree["objective_mask"]. Can you also document that?
There was a problem hiding this comment.
Updated. If the wording isn't clear, can change.
| if it returns a scalar at every grid point. To be compatible with | ||
| masking, compute function should apply the mask | ||
| self._coilset_tree["coilset_mask"] before returning data. | ||
| self._coilset_tree["objective_mask"] before returning data. |
There was a problem hiding this comment.
This is a breaking change, and needs to be documented in the changelog. But I like the new naming, which is easier to distinguish.
There was a problem hiding this comment.
Sounds good, documented.
| individual coils and the coilsets to which they belong. Similarly, | ||
| params_tree["nodes"] lists the grid nodes associated with each coil. | ||
| params_tree["coilset_mask"] contains the indices in [0,self._dim_f-1] | ||
| params_tree["coil_mask"] contains the indices in [0,self._num_coils] |
There was a problem hiding this comment.
Should we keep the dict key name the same as before coilset_mask? The meaning is different but at least prevents breaking some code.
There was a problem hiding this comment.
I switched back to coilset_mask.
PR #1921 added support for coil optimization with Pytree inputs for
weight,target, andbounds. It relies on broadcasting and flattening the input Pytrees to have dimensions consistent with the objective's output shape. Whenweight=0for a given coil or coilset, the corresponding coils (or grid nodes) are dropped during the broadcasting process. However, this wasn't being done correctly for objectives likeCoilCurvaturethat output a single value per grid node.Fix was suggested in original issue, #2219. We keep track of which coils are masked (
weight=0) in the attribute_coilset_tree["coilset_mask"]. This always has lengthnum_coils. The mask_coilset_tree["objective_mask"]matches the output shape of the objective, and in the case of per-node objectives, tracks which individual grid nodes are masked. Keeping both masks makes it easier to index things correctly when building the objective.Also added some tests to catch bugs of this form.
Addresses Issue #2219.