Skip to content

Commit d9dccb0

Browse files
committed
Add README visual of diffusion evolution on Rastrigin
- Add scripts/generate_readme_figure.py producing a 4-panel static snapshot plus animated GIF of a Rastrigin run (seeded, rerunnable). - Tune fitness mapping/temperature/sigma_m so the final population spreads across multiple Rastrigin peaks instead of collapsing to the global max alone, highlighting the algorithm's multi-modal behavior. - Embed the static figure and link the GIF from the README's "What Is This?" section. - Whitelist docs/images/*.png and *.gif in .gitignore so README assets are tracked while stray plots remain ignored. - Note the addition in CHANGELOG under [Unreleased].
1 parent 362fe85 commit d9dccb0

6 files changed

Lines changed: 216 additions & 0 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ __marimo__/
217217

218218
*.png
219219
*.gif
220+
# Documentation assets are tracked; everything else (stray plots, benchmark
221+
# output) stays ignored by the rules above.
222+
!docs/images/*.png
223+
!docs/images/*.gif
220224

221225
data/
222226
mnist_checkpoints/

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ the set of symbols exported from `devol.__all__`; anything else is internal.
2121
- `examples`, `benchmark`, `dev`, and `all` optional dependency groups.
2222
- `src/devol` is now strictly typed end-to-end (`mypy --strict` clean).
2323
- Installation section in the README explaining the new extras.
24+
- README hero visual: 4-panel static figure and animated GIF showing diffusion evolution collapsing noise onto the Rastrigin fitness landscape. Reproducible via `scripts/generate_readme_figure.py`.
2425

2526
### Changed
2627

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ This reframing gives us an algorithm that naturally transitions from broad explo
1212

1313
**The intuition**: Imagine you're in a foggy room full of people, each standing at a different elevation. You can only see your immediate neighbors through the fog. To find the highest point, you don't just copy the person next to you - you look at everyone nearby, weight them by height, and move toward the weighted average. As the fog clears (denoising), your steps become smaller and more precise.
1414

15+
![Diffusion Evolution on Rastrigin: pure noise collapses onto a constellation of fitness peaks](docs/images/denoising-trajectory.png)
16+
17+
The population starts as pure noise spread across the search space (left). As denoising proceeds, the cloud organizes around the Rastrigin landscape's fitness peaks (center). At the end (right), individuals cluster on the global maximum at the origin *and* on neighbouring high-fitness modes — no explicit niching required. A full animation of the trajectory is [here](docs/images/denoising-trajectory.gif).
18+
1519
## Installation
1620

1721
```bash
5.35 MB
Loading
739 KB
Loading

scripts/generate_readme_figure.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""Generate the Rastrigin denoising trajectory figures used in the README.
2+
3+
Produces:
4+
docs/images/denoising-trajectory.png – 4-panel static snapshot
5+
docs/images/denoising-trajectory.gif – animated version across every step
6+
7+
The script is deterministic (seeded) and depends only on devol + matplotlib.
8+
Rerun with: `uv run scripts/generate_readme_figure.py`.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from pathlib import Path
14+
from typing import Any
15+
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
from matplotlib.animation import FuncAnimation, PillowWriter
19+
from matplotlib.artist import Artist
20+
from matplotlib.axes import Axes
21+
from matplotlib.figure import Figure
22+
from numpy.typing import NDArray
23+
24+
from devol import DiffusionConfig, DiffusionEvolution
25+
from devol.config import FitnessConfig, FitnessMapping, NormalType
26+
from devol.distance import FloatArray
27+
28+
# --- Configuration knobs ---------------------------------------------------
29+
30+
SEED = 42
31+
POPULATION_SIZE = 1024
32+
NUM_STEPS = 120
33+
PARAM_DIM = 2
34+
SIGMA_M = 0.5
35+
36+
# How far to stretch the initial N(0,1) noise. Pushes the starting population
37+
# close to the plot edges so the "noise → clusters" collapse is visually strong.
38+
INIT_SCALE = 4.6
39+
40+
# Exponential fitness mapping with a moderate temperature keeps enough selection
41+
# pressure to find peaks without collapsing the whole population to the global
42+
# max, so the final population visibly spreads across several Rastrigin peaks.
43+
FITNESS_CONFIG = FitnessConfig(
44+
mapping=FitnessMapping.EXPONENTIAL,
45+
temperature=2.0,
46+
normalize=NormalType.IDENTITY,
47+
)
48+
49+
BOUNDS = (-5.12, 5.12) # standard Rastrigin search region
50+
GRID_RESOLUTION = 200
51+
52+
OUTPUT_DIR = Path(__file__).resolve().parent.parent / "docs" / "images"
53+
STATIC_PATH = OUTPUT_DIR / "denoising-trajectory.png"
54+
GIF_PATH = OUTPUT_DIR / "denoising-trajectory.gif"
55+
56+
57+
def rastrigin(x: FloatArray) -> float:
58+
"""Rastrigin in 2D, converted to a maximization problem.
59+
60+
Global maximum at the origin; many regular local maxima surround it.
61+
"""
62+
a = 10.0
63+
n = x.shape[0]
64+
return float(-(a * n + np.sum(x**2 - a * np.cos(2 * np.pi * x))))
65+
66+
67+
class RecordingEvolution(DiffusionEvolution):
68+
"""DiffusionEvolution that stores a copy of the population after every step."""
69+
70+
def __init__(self, *args: object, **kwargs: object) -> None:
71+
super().__init__(*args, **kwargs) # type: ignore[arg-type]
72+
self.trajectory: list[NDArray[np.float64]] = []
73+
74+
def step(self, timestamp: int, population: NDArray[np.float64]) -> NDArray[np.float64]:
75+
new_population = super().step(timestamp, population)
76+
self.trajectory.append(new_population.copy())
77+
return new_population
78+
79+
80+
def build_landscape_grid() -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
81+
"""Evaluate Rastrigin over a regular grid for contour plotting."""
82+
axis = np.linspace(BOUNDS[0], BOUNDS[1], GRID_RESOLUTION)
83+
xx, yy = np.meshgrid(axis, axis)
84+
stacked = np.stack([xx.ravel(), yy.ravel()], axis=1)
85+
zz = np.array([rastrigin(point) for point in stacked]).reshape(xx.shape)
86+
return xx, yy, zz
87+
88+
89+
def run_evolution() -> tuple[list[NDArray[np.float64]], NDArray[np.float64]]:
90+
"""Run the seeded evolution and return the trajectory (initial + every step)."""
91+
config = DiffusionConfig(
92+
population_size=POPULATION_SIZE,
93+
num_steps=NUM_STEPS,
94+
param_dim=PARAM_DIM,
95+
sigma_m=SIGMA_M,
96+
seed=SEED,
97+
fitness=FITNESS_CONFIG,
98+
)
99+
algo = RecordingEvolution(config, rastrigin)
100+
101+
# Scale initial noise to cover the landscape. devol's default init is N(0,1); we
102+
# rescale once so the starting cloud fills the Rastrigin bounds for a stronger
103+
# "noise → structure" visual.
104+
initial_population = algo.initialize_population() * INIT_SCALE
105+
106+
algo.run(initial_population)
107+
trajectory = [initial_population.copy(), *algo.trajectory]
108+
return trajectory, initial_population
109+
110+
111+
def draw_landscape(ax: Axes, xx: NDArray[np.float64], yy: NDArray[np.float64], zz: NDArray[np.float64]) -> None:
112+
ax.contourf(xx, yy, zz, levels=30, cmap="Greys_r", alpha=0.55)
113+
ax.set_xlim(BOUNDS)
114+
ax.set_ylim(BOUNDS)
115+
ax.set_xticks([])
116+
ax.set_yticks([])
117+
ax.set_aspect("equal")
118+
119+
120+
SCATTER_KW: dict[str, Any] = dict(s=18, c="#FF3366", edgecolors="white", linewidths=0.6, alpha=0.95)
121+
122+
123+
def make_static_figure(
124+
trajectory: list[NDArray[np.float64]],
125+
xx: NDArray[np.float64],
126+
yy: NDArray[np.float64],
127+
zz: NDArray[np.float64],
128+
) -> Figure:
129+
"""Four-panel snapshot showing noise → convergence.
130+
131+
The interesting part of the denoising happens early (by t ~= T/4 the
132+
population has collapsed onto the basin), so panels are front-loaded
133+
rather than evenly spaced.
134+
"""
135+
num_frames = len(trajectory)
136+
last = num_frames - 1
137+
panel_indices = [0, max(1, num_frames // 6), max(1, num_frames // 3), last]
138+
panel_titles = [
139+
f"t = {panel_indices[0]} (pure noise)",
140+
f"t = {panel_indices[1]}",
141+
f"t = {panel_indices[2]}",
142+
f"t = {panel_indices[3]} (converged)",
143+
]
144+
145+
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
146+
for ax, idx, title in zip(axes, panel_indices, panel_titles):
147+
draw_landscape(ax, xx, yy, zz)
148+
population = trajectory[idx]
149+
ax.scatter(population[:, 0], population[:, 1], **SCATTER_KW)
150+
ax.set_title(title, fontsize=12, pad=8)
151+
152+
fig.suptitle(
153+
"Diffusion Evolution on Rastrigin (2D): noise → convergence",
154+
fontsize=14,
155+
y=1.02,
156+
)
157+
fig.tight_layout()
158+
return fig
159+
160+
161+
def make_gif(
162+
trajectory: list[NDArray[np.float64]],
163+
xx: NDArray[np.float64],
164+
yy: NDArray[np.float64],
165+
zz: NDArray[np.float64],
166+
out_path: Path,
167+
) -> None:
168+
"""Animate every recorded step."""
169+
fig, ax = plt.subplots(figsize=(5.5, 5.5))
170+
draw_landscape(ax, xx, yy, zz)
171+
scatter = ax.scatter([], [], **SCATTER_KW)
172+
title = ax.set_title("t = 0", fontsize=12, pad=8)
173+
174+
def update(frame: int) -> list[Artist]:
175+
population = trajectory[frame]
176+
scatter.set_offsets(population)
177+
title.set_text(f"t = {frame}")
178+
return [scatter, title]
179+
180+
anim = FuncAnimation(fig, update, frames=len(trajectory), interval=120, blit=False)
181+
anim.save(out_path, writer=PillowWriter(fps=12))
182+
plt.close(fig)
183+
184+
185+
def main() -> None:
186+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
187+
188+
print(f"Running evolution (seed={SEED}, steps={NUM_STEPS}, population={POPULATION_SIZE})...")
189+
trajectory, _ = run_evolution()
190+
print(f"Captured {len(trajectory)} frames.")
191+
192+
print("Building landscape grid...")
193+
xx, yy, zz = build_landscape_grid()
194+
195+
print(f"Writing static figure to {STATIC_PATH}")
196+
fig = make_static_figure(trajectory, xx, yy, zz)
197+
fig.savefig(STATIC_PATH, dpi=160, bbox_inches="tight")
198+
plt.close(fig)
199+
200+
print(f"Writing animated figure to {GIF_PATH}")
201+
make_gif(trajectory, xx, yy, zz, GIF_PATH)
202+
203+
print("Done.")
204+
205+
206+
if __name__ == "__main__":
207+
main()

0 commit comments

Comments
 (0)