diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 60926055..42e4e9dc 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,8 @@ Release Notes Upcoming Version ---------------- +* Add documentation about `LinearExpression.where` with `drop=True`. Add `BaseExpression.variable_names` property. + Version 0.6.3 -------------- @@ -11,7 +13,6 @@ Version 0.6.3 * Reinsert broadcasting logic of mask object to be fully compatible with performance improvements in version 0.6.2 using `np.where` instead of `xr.where`. - Version 0.6.2 -------------- diff --git a/examples/creating-expressions.ipynb b/examples/creating-expressions.ipynb index aafd8a09..4067018b 100644 --- a/examples/creating-expressions.ipynb +++ b/examples/creating-expressions.ipynb @@ -160,7 +160,11 @@ "cell_type": "markdown", "id": "f7578221", "metadata": {}, - "source": ".. important::\n\n\tWhen combining variables or expression with dimensions of the same name and size, the first object will determine the coordinates of the resulting expression. For example:" + "source": [ + ".. important::\n", + "\n", + "\tWhen combining variables or expression with dimensions of the same name and size, the first object will determine the coordinates of the resulting expression. For example:" + ] }, { "cell_type": "code", @@ -308,6 +312,102 @@ "(x + y).where(mask) + xr.DataArray(5, coords=[time]).where(~mask, 0)" ] }, + { + "cell_type": "markdown", + "id": "6741e69e", + "metadata": {}, + "source": [ + "Sometimes `.where` may lead to a situation where some of the variables are completely masked" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc32bdca", + "metadata": {}, + "outputs": [], + "source": [ + "mask_a = xr.DataArray(False, coords=[time])\n", + "mask_b = xr.DataArray(time > 2, coords=[time])\n", + "\n", + "z = (x.where(mask_a) + y).where(mask_b)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "25bf798c", + "metadata": {}, + "source": [ + "In this example you can see that many of the elements of the LinearExpression are None. If you want to remove all the None terms, you can use `.where(.., drop=True)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72c6b51b", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.where(mask_b, drop=True)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "1c1e0b85", + "metadata": {}, + "source": [ + "That looks nicer!
" + ] + }, + { + "cell_type": "markdown", + "id": "d8530a08", + "metadata": {}, + "source": [ + "You may notice that the variable `x` is not used at all. The expression still contains two terms (one of them is unused) but it only has one variable `y`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c577863", + "metadata": {}, + "outputs": [], + "source": [ + "z.nterm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe43d47d", + "metadata": {}, + "outputs": [], + "source": [ + "z.variable_names" + ] + }, + { + "cell_type": "markdown", + "id": "a76d40b1", + "metadata": {}, + "source": [ + "You can get rid of the unused term with `.simplify()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc27341c", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.simplify()\n", + "z.nterm" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/linopy/expressions.py b/linopy/expressions.py index 649989f7..dc919a72 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1078,6 +1078,30 @@ def nterm(self) -> int: """ return len(self.data._term) + @property + def variable_names(self) -> set[str]: + """ + The names of the unique variables present in the expression + """ + if self.nterm == 0: + return set() + + # Collect all unique labels from the expression (excluding -1) while preserving order + all_labels = self.vars.values.ravel() + valid_labels = all_labels[all_labels != -1] + + if len(valid_labels) == 0: + return set() + + # Get unique labels while preserving first occurrence order + unique_labels, first_indices = np.unique(valid_labels, return_index=True) + ordered_labels = unique_labels[np.argsort(first_indices)] + + # Batch lookup variable names for all labels + positions = self.model.variables.get_label_position(ordered_labels) + + return {p[0] for p in positions if p[0] is not None} + @property def shape(self) -> tuple[int, ...]: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 0da9ec7f..e94a96b9 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1399,3 +1399,50 @@ def test_constant_only_expression_mul_linexpr_with_vars_and_const( assert not result_rev.is_constant assert (result_rev.coeffs == expected_coeffs).all() assert (result_rev.const == expected_const).all() + + +def test_variable_names() -> None: + m = Model() + time = pd.Index(range(3), name="time") + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + + expr = a + b + assert expr.nterm == 2 + assert expr.variable_names == {"a", "b"} + + mask = xr.DataArray(False, coords=[time]) + expr = a + (b * 1).where(mask) + assert expr.nterm == 2 + assert expr.variable_names == {"a"} + + expr = (b * 1).where(mask) + assert expr.nterm == 1 + assert expr.variable_names == set() + + expr = LinearExpression.from_constant(model=m, constant=5) + assert expr.nterm == 0 + assert expr.variable_names == set() + + +def test_nterm() -> None: + m = Model() + time = pd.Index(range(3), name="time") + all_false = xr.DataArray(False, coords=[time]) + not_0 = xr.DataArray([False, True, True], coords=[time]) + not_1 = xr.DataArray([True, False, True], coords=[time]) + not_2 = xr.DataArray([True, True, False], coords=[time]) + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + c = m.add_variables(name="c", coords=[time]) + + expr = (a.where(not_0) + b.where(not_1) + c.where(not_2)).densify_terms() + assert expr.nterm == 3 + + expr = a + b.where(all_false) + assert expr.nterm == 2 + + expr = expr.simplify() + assert expr.nterm == 1 diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fc1bb25f..b5d35865 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -360,3 +360,11 @@ def test_power_of_three(x: Variable) -> None: x**3 with pytest.raises(TypeError): (x * x) * (x * x) + + +def test_variable_names(x: Variable, y: Variable) -> None: + expr = 2 * (x * x) + 3 * y + 1 + assert expr.variable_names == {"x", "y"} + + expr = 2 * (x * x) + 1 + assert expr.variable_names == {"x"}