-
Notifications
You must be signed in to change notification settings - Fork 49
Bug Fix: weight=0 for _broadcast_input='Node' fails
#2221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ Bug Fixes | |
| - Fixes a bug in `OmnigenousField.change_resolution` when changing `L_B`. | ||
| - Scaling a `ScaledProfile` or taking power of a `PowerProfile` now only updates the `scale`/`power` attributes instead of nesting the `ScaledProfile`/`PowerProfile`s. | ||
| - `jax.Array`s in `_static_attrs` will be automatically converted to `np.ndarray` to prevent stalling code. In general, jax arrays should be omitted in `_static_attrs`. | ||
| - Fixes a bug in `_CoilObjective` for objectives which are computed per-grid node when at least one entry of `weight` is zero. | ||
|
YigitElma marked this conversation as resolved.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We made a new release recently. Can you move this above and add a title below it |
||
|
|
||
| Performance Improvements | ||
|
|
||
|
|
@@ -35,6 +36,11 @@ Deprecations | |
|
|
||
| - `constants` argument of `compute`, `jvp`, `jac`, `grad` and `hess` methods (including all of their variants) to all objective classes (including `ObjectiveFunction` and wrappers) is deprecated and will be removed in a future release. This argument was not necessary, and the code will still work if user doesn't pass it. Users should update their custom objectives for this change. In addition, `constants` property of the `ObjectiveFunction` and all sub-classes of `_Objective` is deprecated. | ||
|
|
||
| Breaking Changes | ||
|
|
||
| - Name change in `_CoilObjective` replacing `coilset_mask` with `objective_mask`. Custom | ||
| subclasses with `_broadcast_input="Node"` that previously used `coilset_mask` should | ||
| switch to `objective_mask`. | ||
|
|
||
| v0.17.1 | ||
| ------- | ||
|
|
@@ -73,6 +79,10 @@ or if multiple things are being optimized, `x_scale` can be a list of dict, one | |
| - Changes the import paths for ``desc.external`` to require reference to the sub-modules. | ||
| - Adds a differentiable utility for finding constant offset toroidal surfaces inside of optimizations. See [PR](https://github.com/PlasmaControl/DESC/pull/2016) for more details. | ||
| - Add support for Python 3.14 | ||
| - Adds support for optimization targeting individual coils in a coilset. | ||
| - Coil objectives accept pytree inputs for `target`, `bounds`, and `weight`. | ||
| - Able to set weights to zero, excluding certain coils from the objective. | ||
|
|
||
|
|
||
| Bug Fixes | ||
|
|
||
|
|
@@ -108,10 +118,6 @@ New Features | |
| - `field_line_integrate` function doesn't accept additional keyword-arguments related to `diffrax`, if it is necessary, they must be given through `options` dictionary. | ||
| - ``poincare_plot`` and ``plot_field_lines`` functions can now plot partial results if the integration failed. Previously, user had to pass ``throw=False`` or change the integration parameters. Users can ignore the warnings that are caused by hitting the bounds (i.e. `Terminating differential equation solve because an event occurred.`). | ||
| - `chunk_size` argument is now used for chunking the number of field lines. For the chunking of Biot-Savart integration for the magnetic field, users can use `bs_chunk_size` instead. | ||
| - Adds support for optimization targeting individual coils in a coilset. | ||
| - Coil objectives accept pytree inputs for `target`, `bounds`, and `weight`. | ||
| - Able to set weights to zero, excluding certain coils from the objective. | ||
|
|
||
|
|
||
|
|
||
| Bug Fixes | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,7 +59,7 @@ | |
| "Coil" if the objective returns a single scalar per coil, and "Node" | ||
| if it returns a scalar at every grid point. To be compatible with | ||
| masking, compute function should apply the mask | ||
| self._coilset_tree["coilset_mask"] before returning data. | ||
| self._coilset_tree["objective_mask"] before returning data. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a breaking change, and needs to be documented in the changelog. But I like the new naming, which is easier to distinguish.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, documented. |
||
| """ | ||
|
|
||
| __doc__ = __doc__.rstrip() + collect_docs(coil=True) | ||
|
|
@@ -133,7 +133,9 @@ | |
| params_tree["coils"] contains a nested list of 0s representing | ||
| individual coils and the coilsets to which they belong. Similarly, | ||
| params_tree["nodes"] lists the grid nodes associated with each coil. | ||
| params_tree["coilset_mask"] contains the indices in [0,self._dim_f-1] | ||
| params_tree["coilset_mask"] contains the indices in | ||
| [0,self._num_coils-1] for which the corresponding weight is positive. | ||
| params_tree["objective_mask"] contains the indices in [0,self._dim_f-1] | ||
| for which the corresponding weight is positive. If all weights are | ||
| positive (i.e. no masking needed), contains default slice(None). | ||
| """ | ||
|
|
@@ -166,12 +168,16 @@ | |
| self._coilset_tree = { | ||
| "coils": tree[0], | ||
| "nodes": tree[1], | ||
| "coilset_mask": slice(None), | ||
| "coilset_mask": np.arange(self._num_coils), | ||
| "objective_mask": slice(None), | ||
| } | ||
| if np.any([w == 0 for w in tree_leaves(self._weight)]): | ||
| mask = self._coilset_broadcast(self._weight) | ||
| mask = np.nonzero(mask)[0] | ||
| self._coilset_tree["coilset_mask"] = mask | ||
| coilset_mask = self._coilset_broadcast(self._weight) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: can you pass |
||
| objective_mask = self._coilset_broadcast( | ||
| self._weight, self._broadcast_input | ||
| ) | ||
| self._coilset_tree["coilset_mask"] = np.nonzero(coilset_mask)[0] | ||
| self._coilset_tree["objective_mask"] = np.nonzero(objective_mask)[0] | ||
|
|
||
| coil = self.things[0] | ||
| grid = self._grid | ||
|
|
@@ -211,12 +217,12 @@ | |
|
|
||
| _build_coilset_tree() | ||
| quad_weights = np.concatenate([g.spacing[:, 2] for g in grid])[ | ||
| self._coilset_tree["coilset_mask"] | ||
| self._coilset_tree["objective_mask"] | ||
| ] | ||
|
|
||
| if self._broadcast_input == "Node": | ||
| grid_nodes_unmasked = [ | ||
| g.num_nodes for g in grid[self._coilset_tree["coilset_mask"]] | ||
| grid[i].num_nodes for i in self._coilset_tree["coilset_mask"] | ||
| ] | ||
| self._dim_f = np.sum(grid_nodes_unmasked) | ||
| else: | ||
|
|
@@ -230,14 +236,14 @@ | |
| grid = _prune_coilset_tree(grid) | ||
| coil = _prune_coilset_tree(coil) | ||
|
|
||
| self._weight = self._coilset_broadcast(self._weight) | ||
| self._weight = self._coilset_broadcast(self._weight, self._broadcast_input) | ||
| if self._bounds: | ||
| self._bounds = ( | ||
| self._coilset_broadcast(self._bounds[0]), | ||
| self._coilset_broadcast(self._bounds[1]), | ||
| self._coilset_broadcast(self._bounds[0], self._broadcast_input), | ||
| self._coilset_broadcast(self._bounds[1], self._broadcast_input), | ||
| ) | ||
| elif self._target: | ||
| self._target = self._coilset_broadcast(self._target) | ||
| self._target = self._coilset_broadcast(self._target, self._broadcast_input) | ||
|
|
||
| timer = Timer() | ||
| if verbose > 0: | ||
|
|
@@ -294,14 +300,18 @@ | |
| assert (bounds is None) or (isinstance(bounds, tuple) and len(bounds) == 2) | ||
| if bounds: | ||
| self._bounds = ( | ||
| self._coilset_broadcast(bounds[0]), | ||
| self._coilset_broadcast(bounds[1]), | ||
| self._coilset_broadcast(bounds[0], self._broadcast_input), | ||
| self._coilset_broadcast(bounds[1], self._broadcast_input), | ||
| ) | ||
| self._check_dimensions() | ||
|
|
||
| @_Objective.target.setter | ||
| def target(self, target): | ||
| self._target = self._coilset_broadcast(target) if target is not None else target | ||
| self._target = ( | ||
| self._coilset_broadcast(target, self._broadcast_input) | ||
| if target is not None | ||
| else target | ||
| ) | ||
| self._check_dimensions() | ||
|
|
||
| @_Objective.weight.setter | ||
|
|
@@ -311,30 +321,34 @@ | |
| # objective should be rebuilt to account for masking | ||
| self._built = False | ||
|
|
||
| def _coilset_broadcast(self, x): | ||
| """Expand an array in accordance with the attribute _broadcast_input. | ||
| def _coilset_broadcast(self, x, target="Coil"): | ||
| """Broadcast an array to dimensions consistent with "target". | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x : float or list[float] | ||
| Must be broadcastable to the structure of self._things[0]. | ||
| target: str, optional | ||
| Optional string taking values "Coil" or "Node". Defaults to "Coil". | ||
|
YigitElma marked this conversation as resolved.
|
||
|
|
||
| Returns | ||
| ------- | ||
| arr: float or list[float] | ||
| Float inputs are returned unchanged, and list inputs are | ||
| expanded to size self._dim_f. | ||
| """ | ||
| assert target in ["Node", "Coil"] | ||
|
|
||
| # No need to broadcast if input is a scalar | ||
| arr_flat = tree_leaves(x) | ||
| if len(arr_flat) == 1: | ||
| return np.atleast_1d(arr_flat[0]) | ||
|
|
||
| arr = jax_tree_broadcast(x, self._coilset_tree["coils"]) | ||
| if self._broadcast_input == "Node": | ||
| if target == "Node": | ||
| arr = tree_map(lambda a, b: [a] * b, arr, self._coilset_tree["nodes"]) | ||
| arr, _ = tree_flatten(arr) | ||
| return np.asarray(arr)[self._coilset_tree["coilset_mask"]] | ||
| return np.asarray(arr)[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class CoilLength(_CoilObjective): | ||
|
|
@@ -433,7 +447,7 @@ | |
| data = super().compute(params, constants=constants) | ||
| data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict)) | ||
| out = jnp.array([dat["length"] for dat in data]) | ||
| return out[self._coilset_tree["coilset_mask"]] | ||
| return out[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class CoilCurvature(_CoilObjective): | ||
|
|
@@ -535,7 +549,7 @@ | |
| data = super().compute(params, constants=constants) | ||
| data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict)) | ||
| out = jnp.concatenate([dat["curvature"] for dat in data]) | ||
| return out[self._coilset_tree["coilset_mask"]] | ||
| return out[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class CoilTorsion(_CoilObjective): | ||
|
|
@@ -635,7 +649,7 @@ | |
| data = super().compute(params, constants=constants) | ||
| data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict)) | ||
| out = jnp.concatenate([dat["torsion"] for dat in data]) | ||
| return out[self._coilset_tree["coilset_mask"]] | ||
| return out[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class CoilCurrentLength(CoilLength): | ||
|
|
@@ -741,7 +755,7 @@ | |
| lengths = super().compute(params, constants=constants) | ||
| params = tree_leaves(params, is_leaf=lambda x: isinstance(x, dict)) | ||
| currents = jnp.concatenate([param["current"] for param in params]) | ||
| out = jnp.atleast_1d(lengths * currents[self._coilset_tree["coilset_mask"]]) | ||
| out = jnp.atleast_1d(lengths * currents[self._coilset_tree["objective_mask"]]) | ||
| return out | ||
|
|
||
|
|
||
|
|
@@ -848,7 +862,7 @@ | |
| for dat in data | ||
| ] | ||
| ) | ||
| return out[self._coilset_tree["coilset_mask"]] | ||
| return out[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class CoilSetMinDistance(_Objective): | ||
|
|
@@ -1566,7 +1580,7 @@ | |
| constants = self._get_deprecated_constants(constants) | ||
| data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict)) | ||
| out = jnp.array([jnp.var(jnp.linalg.norm(dat["x_s"], axis=1)) for dat in data]) | ||
| return (out * constants["mask"])[self._coilset_tree["coilset_mask"]] | ||
| return (out * constants["mask"])[self._coilset_tree["objective_mask"]] | ||
|
|
||
|
|
||
| class QuadraticFlux(_Objective): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.