Commit 5055262
authored
Add
* Add JAX-based `find_MAP`
* add `better_optimize` to CI envs
* Fix relative import
* Remove `find_MAP` import from module-level `__init__.py`
* Update docstring
* Allow calling `find_MAP` inside model context without model argument
* Required patched better_optimize
* in-progress refactor
* More refactor
* Generalize code to use any pytensor backend
* Reconcile the two laplace approximation functions
* Use absolute import in doctest
* Fix imports
* Fix unrelated statespace test
* - Rename argument `use_jax_gradients` -> `gradient_backend`
- Rename function `laplace` -> `sample_laplace_posterior`
* Fix typo introduced by rename refactor
* use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization
* Rename `test_jax_find_map.py` -> `test_find_map.py`
* Improve docstring for `fit_laplace`
* Update tests to match new signature
* Update docstringfind_MAP with close JAX integration and fix bug with Laplace fit (#385)1 parent 40714de commit 5055262
File tree
8 files changed
+1178
-163
lines changed- conda-envs
- pymc_experimental/inference
- tests
- statespace
8 files changed
+1178
-163
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
0 commit comments