Skip to content

Bug Fix: weight=0 for _broadcast_input='Node' fails#2221

Open
singh-jaydeep wants to merge 3 commits into
PlasmaControl:masterfrom
singh-jaydeep:bug-fix-pytree-coils
Open

Bug Fix: weight=0 for _broadcast_input='Node' fails#2221
singh-jaydeep wants to merge 3 commits into
PlasmaControl:masterfrom
singh-jaydeep:bug-fix-pytree-coils

Conversation

@singh-jaydeep
Copy link
Copy Markdown
Collaborator

@singh-jaydeep singh-jaydeep commented May 19, 2026

PR #1921 added support for coil optimization with Pytree inputs for weight, target, and bounds. It relies on broadcasting and flattening the input Pytrees to have dimensions consistent with the objective's output shape. When weight=0 for a given coil or coilset, the corresponding coils (or grid nodes) are dropped during the broadcasting process. However, this wasn't being done correctly for objectives like CoilCurvature that output a single value per grid node.

Fix was suggested in original issue, #2219. We keep track of which coils are masked (weight=0) in the attribute _coilset_tree["coilset_mask"]. This always has length num_coils. The mask _coilset_tree["objective_mask"] matches the output shape of the objective, and in the case of per-node objectives, tracks which individual grid nodes are masked. Keeping both masks makes it easier to index things correctly when building the objective.

Also added some tests to catch bugs of this form.

Addresses Issue #2219.

-Handles case when objective is computed per-grid node and at least one weight is zero.
@singh-jaydeep singh-jaydeep force-pushed the bug-fix-pytree-coils branch from efdf6f0 to 0534979 Compare May 19, 2026 23:42
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 20, 2026

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |   -2.94 %    |     4.008e+03      |     3.890e+03      |   -117.91    |       31.03        |       28.13        |
  test_proximal_jac_w7x_with_eq_update   |   -0.35 %    |     6.583e+03      |     6.560e+03      |    -23.01    |       148.30       |       148.52       |
  test_proximal_freeb_jac                |    0.31 %    |     1.331e+04      |     1.335e+04      |    41.03     |       80.11        |       77.80        |
  test_proximal_freeb_jac_blocked        |    0.12 %    |     7.665e+03      |     7.674e+03      |     9.45     |       67.17        |       67.36        |
  test_proximal_freeb_jac_batched        |   -0.25 %    |     7.657e+03      |     7.638e+03      |    -19.36    |       66.87        |       67.06        |
  test_proximal_jac_ripple               |    0.23 %    |     3.556e+03      |     3.564e+03      |     8.24     |       52.52        |       52.63        |
  test_proximal_jac_ripple_bounce1d      |    0.47 %    |     3.781e+03      |     3.798e+03      |    17.89     |       65.78        |       66.00        |
  test_eq_solve                          |    0.74 %    |     2.039e+03      |     2.054e+03      |    15.16     |       85.32        |       85.29        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 20, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 94.33%. Comparing base (c119da0) to head (7475c7e).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2221      +/-   ##
==========================================
- Coverage   94.34%   94.33%   -0.01%     
==========================================
  Files         101      101              
  Lines       28845    28847       +2     
==========================================
  Hits        27213    27213              
- Misses       1632     1634       +2     
Files with missing lines Coverage Δ
desc/objectives/_coils.py 99.47% <100.00%> (+<0.01%) ⬆️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@singh-jaydeep singh-jaydeep marked this pull request as ready for review May 20, 2026 01:29
Copy link
Copy Markdown
Collaborator

@YigitElma YigitElma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will test it myself too, but it looks good. I am personally fine with the breaking change; this was a pretty new addition, and going forward, I like the new name better.

Comment thread desc/objectives/_coils.py
Comment thread tests/test_objective_funs.py
Comment thread tests/test_objective_funs.py
Comment thread CHANGELOG.md
- Fixes a bug in `OmnigenousField.change_resolution` when changing `L_B`.
- Scaling a `ScaledProfile` or taking power of a `PowerProfile` now only updates the `scale`/`power` attributes instead of nesting the `ScaledProfile`/`PowerProfile`s.
- `jax.Array`s in `_static_attrs` will be automatically converted to `np.ndarray` to prevent stalling code. In general, jax arrays should be omitted in `_static_attrs`.
- Fixes a bug in `_CoilObjective` for objectives which are computed per-grid node when at least one entry of `weight` is zero.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is also a breaking change here. If someone had a custom coil objective, they need to update the [self._coilset_tree["coil_mask"] to [self._coilset_tree["objective_mask"]. Can you also document that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. If the wording isn't clear, can change.

Comment thread desc/objectives/_coils.py
if it returns a scalar at every grid point. To be compatible with
masking, compute function should apply the mask
self._coilset_tree["coilset_mask"] before returning data.
self._coilset_tree["objective_mask"] before returning data.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change, and needs to be documented in the changelog. But I like the new naming, which is easier to distinguish.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, documented.

Comment thread desc/objectives/_coils.py Outdated
individual coils and the coilsets to which they belong. Similarly,
params_tree["nodes"] lists the grid nodes associated with each coil.
params_tree["coilset_mask"] contains the indices in [0,self._dim_f-1]
params_tree["coil_mask"] contains the indices in [0,self._num_coils]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep the dict key name the same as before coilset_mask? The meaning is different but at least prevents breaking some code.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched back to coilset_mask.

Comment thread tests/test_objective_funs.py
@YigitElma YigitElma requested review from a team, YigitElma, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team May 20, 2026 02:54
Comment thread CHANGELOG.md
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants