
Add gradient-based optimization (CG) and multi-start support #65

Open
Marius1311 wants to merge 1 commit into main from feature/gradient-optimization

Conversation

@Marius1311 (Collaborator)

Summary

This adds gradient-based optimizers (CG, L-BFGS-B, BFGS) as alternatives to Nelder-Mead for the GPCCA rotation matrix optimization, along with a multi-start mechanism to improve solution quality.

Motivation

The existing Nelder-Mead optimizer works well for small numbers of macrostates but becomes very slow for m > 15 — it optimizes over O(m²) parameters without gradient information. In CellRank, users increasingly want to compute 20+ macrostates, which motivated exploring gradient-based alternatives.

What's in this PR

Analytical Jacobian (_jacobian): Computes the exact gradient of the crispness objective by backpropagating through _fill_matrix. The three transformations in _fill_matrix — row-sum constraint, max condition, rescaling — introduce dependencies between matrix entries that a naive gradient (differentiating the objective w.r.t. individual entries) would miss. The implementation backpropagates through all three steps correctly. Verified against scipy.optimize.approx_fprime in tests.
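The finite-difference check mentioned above can be illustrated with a minimal, self-contained sketch. The objective below is a toy stand-in (sum of squares over the free parameter block), not the crispness objective, but the verification pattern against `scipy.optimize.approx_fprime` is the same one the tests use:

```python
import numpy as np
from scipy.optimize import approx_fprime

def objective(alpha, m):
    # Toy stand-in: sum of squares over the free (m-1) x (m-1) block.
    A = alpha.reshape(m - 1, m - 1)
    return np.sum(A ** 2)

def analytical_grad(alpha, m):
    # Exact gradient of the toy objective above.
    A = alpha.reshape(m - 1, m - 1)
    return (2 * A).ravel()

m = 3
rng = np.random.default_rng(0)
alpha = rng.standard_normal((m - 1) ** 2)

numerical = approx_fprime(alpha, objective, 1e-7, m)
analytical = analytical_grad(alpha, m)
assert np.allclose(numerical, analytical, atol=1e-5)
```

For the real objective the analytical gradient must account for the cross-entry dependencies introduced by `_fill_matrix`, which is exactly what this check catches if the backward pass misses a step.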

Gradient-based optimizers: _opt_soft now accepts a method parameter. For "CG", "L-BFGS-B", and "BFGS", it uses scipy.optimize.minimize with the analytical Jacobian. Nelder-Mead remains the default and uses the existing fmin path — behavior is unchanged when no method is specified.
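The dispatch can be sketched as follows. The helper name `_optimize` is hypothetical, and the objective/Jacobian here are placeholder quadratics rather than the crispness objective; the point is the split between `scipy.optimize.minimize` with `jac=` and the legacy `fmin` path:

```python
import numpy as np
from scipy.optimize import fmin, minimize

def _optimize(alpha0, objective, jacobian, method="Nelder-Mead"):
    if method in ("CG", "L-BFGS-B", "BFGS"):
        # Gradient-based path: pass the analytical Jacobian explicitly.
        res = minimize(objective, alpha0, jac=jacobian, method=method)
        return res.x
    # Default path: unchanged legacy Nelder-Mead via fmin.
    return fmin(objective, alpha0, disp=False)

f = lambda a: np.sum((a - 1.0) ** 2)   # placeholder objective
g = lambda a: 2.0 * (a - 1.0)          # its exact gradient
x_cg = _optimize(np.zeros(4), f, g, method="CG")
x_nm = _optimize(np.zeros(4), f, g)    # default: Nelder-Mead
```

Because Nelder-Mead stays on the `fmin` path, existing callers that never pass `method` see byte-identical behavior.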

Multi-start optimization (_perturb_rotation, _gpcca_core): For n_starts > 1, the first run uses the deterministic ISA initialization, and subsequent runs perturb the initial rotation matrix on the SO(k) manifold via expm(epsilon * S) with random skew-symmetric S. The result with the best crispness is kept. Degenerate solutions (where a cluster is never the argmax for any cell) are filtered out. Default is n_starts=1 (fully backward compatible).
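The scheme above can be sketched as follows. All names (`perturb_rotation`, `multi_start`, `is_degenerate`, `optimize_fn`) are hypothetical, and the "optimizer" at the bottom is a toy scoring function, not the real `_opt_soft`; the structure mirrors what the PR describes: deterministic first run, SO(k) perturbations afterwards, degeneracy filtering, best score wins:

```python
import numpy as np
from scipy.linalg import expm

def perturb_rotation(rot, eps, rng):
    A = rng.standard_normal(rot.shape)
    S = (A - A.T) / 2.0            # random skew-symmetric: S.T == -S
    Q = expm(eps * S)              # proper rotation in SO(k), det(Q) = +1
    return Q @ rot

def is_degenerate(chi):
    # Degenerate: some cluster is never the argmax for any cell.
    return len(np.unique(chi.argmax(axis=1))) < chi.shape[1]

def multi_start(rot0, optimize_fn, n_starts=10, eps=0.1, seed=0):
    rng = np.random.default_rng(seed)
    best_score, best = -np.inf, None
    for i in range(n_starts):
        # Run 1 keeps the deterministic (ISA-style) initialization;
        # later runs perturb it on the SO(k) manifold.
        init = rot0 if i == 0 else perturb_rotation(rot0, eps, rng)
        chi, score = optimize_fn(init)
        if is_degenerate(chi):
            continue               # filter out degenerate solutions
        if score > best_score:
            best_score, best = score, chi
    return best, best_score

# Toy "optimizer": normalize |init| rows, score by mean row maximum.
def fake_optimize(init):
    chi = np.abs(init) / np.abs(init).sum(axis=1, keepdims=True)
    return chi, chi.max(axis=1).mean()

best, best_score = multi_start(np.eye(3), fake_optimize, n_starts=5)
```

Because `expm` of a skew-symmetric matrix is always a proper rotation, the perturbed initial matrix stays orthogonal for any `eps`, so perturbation size only controls how far from the deterministic start each run begins.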

_fill_matrix refactoring: Added an optional _return_intermediates parameter that returns the pre-scaling matrix, argmax indices, and scale factor needed by the Jacobian backward pass. This avoids duplicating the forward-pass logic between _fill_matrix and _jacobian.
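The pattern is a standard one; a toy sketch (softmax forward/backward, not the actual `_fill_matrix` logic) shows how an optional `_return_intermediates` flag lets a backward pass reuse the forward computation instead of duplicating it:

```python
import numpy as np

def forward(x, _return_intermediates=False):
    shifted = x - x.max()        # stabilized inputs
    pre = np.exp(shifted)        # pre-scaling vector (analogue of the
    scale = pre.sum()            # pre-scaling matrix and scale factor)
    out = pre / scale
    if _return_intermediates:
        return out, (pre, scale)
    return out

def backward(grad_out, intermediates):
    pre, scale = intermediates
    p = pre / scale              # reuse forward intermediates directly
    # softmax vector-Jacobian product
    return p * (grad_out - np.dot(grad_out, p))

x = np.array([0.5, 1.5, -1.0])
y, inter = forward(x, _return_intermediates=True)
g = backward(np.array([1.0, 0.0, 0.0]), inter)  # gradient of y[0] w.r.t. x
```

Keeping the flag private (leading underscore) signals that the intermediates are an implementation detail of the Jacobian, not part of the public return contract.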

Benchmarks

We benchmarked on the CellRank pancreas dataset (2,531 cells, PseudotimeKernel) and bone marrow dataset (5,780 cells) across m = 3..50:

  • CG with 10 random restarts matches or exceeds Nelder-Mead crispness at m ≥ 8, while being significantly faster
  • CG scales to m = 50 (where Nelder-Mead is infeasible) in seconds
  • Multi-start (n_starts=10, perturbation_scale=0.1) reliably improves over single-start CG
  • High agreement between CG and NM solutions where both converge (mean Pearson r ≥ 0.94 on membership vectors)
  • All CG solutions were non-degenerate on both datasets

Benchmark scripts are in a separate analysis repo.

Tests

14 new tests covering:

  • Jacobian correctness vs finite differences (m=3, m=5)
  • CG vs Nelder-Mead crispness on the standard test matrices
  • Valid memberships from all 4 optimizer methods
  • Full GPCCA.optimize() pipeline with CG
  • _perturb_rotation preserves orthogonality (det of applied rotation = ±1)
  • Multi-start crispness ≥ single-start
  • Seed determinism (same seed → identical results)

API

All new parameters on GPCCA.optimize():

  • method: "Nelder-Mead" (default), "CG", "L-BFGS-B", "BFGS"
  • n_starts: number of optimization runs (default 1)
  • perturbation_scale: angular scale for SO(k) perturbation (default 0.1)
  • seed: random seed for reproducibility

Defaults reproduce the existing behavior exactly.

Commit message

Add analytical Jacobian for the GPCCA rotation matrix objective,
enabling gradient-based optimizers (CG, L-BFGS-B, BFGS) as alternatives
to Nelder-Mead. Add multi-start optimization via SO(k) rotation
perturbation to escape local optima.

Key changes:
- _jacobian(): backpropagates through _fill_matrix (row-sum constraint,
  max condition, rescaling) for the correct gradient
- _perturb_rotation(): perturbs rotation matrix on the SO(k) manifold
  via expm(epsilon * S) with random skew-symmetric S
- _opt_soft(): dispatches to scipy.optimize.minimize for gradient-based
  methods, keeps fmin for Nelder-Mead
- _gpcca_core(): multi-start loop with degeneracy filtering
- _fill_matrix(): optional _return_intermediates for Jacobian backward
  pass, avoiding forward-pass duplication
- GPCCA.optimize(): new parameters method, n_starts, perturbation_scale,
  seed -- all backward compatible (defaults reproduce old behavior)

Tests:
- Jacobian vs finite differences (m=3, m=5)
- CG vs Nelder-Mead crispness comparison
- All 4 methods produce valid memberships
- Full GPCCA pipeline with CG
- _perturb_rotation preserves orthogonality
- Multi-start >= single-start crispness
- Seed determinism
@Marius1311 Marius1311 requested review from msmdev February 18, 2026 15:16