Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,79 @@ on:
pull_request:

jobs:
tests:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Set up Python
run: uv python install 3.13

- name: Install dev dependencies
run: uv sync --extra dev

- name: Ruff check
run: uv run ruff check

- name: Ruff format check
run: uv run ruff format --check

typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Set up Python
run: uv python install 3.13

- name: Install dev dependencies
run: uv sync --extra dev

- name: Mypy
run: uv run mypy src/devol

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]

python-version: ["3.11", "3.12", "3.13"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Set up Python
run: uv python install ${{ matrix.python-version }}

- name: Install dependencies
- name: Install dev dependencies
run: uv sync --extra dev

- name: Run tests
env:
MPLBACKEND: Agg
run: uv run pytest
run: uv run pytest

build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Set up Python
run: uv python install 3.13

- name: Build sdist and wheel
run: uv build

- name: Validate package metadata
run: uv run --with twine twine check dist/*
8 changes: 2 additions & 6 deletions benchmark/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ def display_results(results: list[BenchmarkMetrics]) -> None:
display_parameter_sensitivity(results, console)


def display_schedule_summary(
results: list[BenchmarkMetrics], console: Console
) -> None:
def display_schedule_summary(results: list[BenchmarkMetrics], console: Console) -> None:
"""Display summary statistics for each schedule type."""
table = Table(title="Schedule Performance Summary")

Expand Down Expand Up @@ -92,9 +90,7 @@ def display_best_configs(results: list[BenchmarkMetrics], console: Console) -> N
console.print(table)


def display_parameter_sensitivity(
results: list[BenchmarkMetrics], console: Console
) -> None:
def display_parameter_sensitivity(results: list[BenchmarkMetrics], console: Console) -> None:
"""Display how performance varies with each parameter."""
params = ["population_size", "num_steps", "param_dim", "sigma_m"]

Expand Down
4 changes: 1 addition & 3 deletions benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def calculate_population_diversity(fitness: NDArray) -> float:
return float(np.std(fitness))


def evaluate_population_fitness(
population: NDArray, fitness_fn: FitnessFunction
) -> NDArray:
def evaluate_population_fitness(population: NDArray, fitness_fn: FitnessFunction) -> NDArray:
"""Evaluate fitness for entire population.

Args:
Expand Down
23 changes: 5 additions & 18 deletions benchmark/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,8 @@ def create_heatmaps(results: list[BenchmarkMetrics], output_dir: str) -> None:
global_max = max(global_max, valid_data.max())

# Second pass: plot with consistent scale
for ax, sched, (heatmap_data, pop_sizes, step_counts) in zip(
axes, schedules, all_heatmaps
):
im = ax.imshow(
heatmap_data, cmap="RdYlGn_r", aspect="auto", vmin=global_min, vmax=global_max
)
for ax, sched, (heatmap_data, pop_sizes, step_counts) in zip(axes, schedules, all_heatmaps):
im = ax.imshow(heatmap_data, cmap="RdYlGn_r", aspect="auto", vmin=global_min, vmax=global_max)
ax.set_title(f"{sched} Schedule")
ax.set_xlabel("Population Size")
ax.set_ylabel("Num Steps")
Expand Down Expand Up @@ -111,11 +107,7 @@ def create_line_plots(results: list[BenchmarkMetrics], output_dir: str) -> None:
stds = []

for val in param_vals:
distances = [
r.distance_from_origin
for r in sched_results
if getattr(r, param_name) == val
]
distances = [r.distance_from_origin for r in sched_results if getattr(r, param_name) == val]
means.append(np.mean(distances))
stds.append(np.std(distances))

Expand All @@ -139,9 +131,7 @@ def create_line_plots(results: list[BenchmarkMetrics], output_dir: str) -> None:
ax.set_xscale("log")

plt.tight_layout()
plt.savefig(
f"{output_dir}/lineplot_{param_name}.png", dpi=150, bbox_inches="tight"
)
plt.savefig(f"{output_dir}/lineplot_{param_name}.png", dpi=150, bbox_inches="tight")
plt.close()


Expand All @@ -155,10 +145,7 @@ def create_boxplots(results: list[BenchmarkMetrics], output_dir: str) -> None:
fig, ax = plt.subplots(figsize=(10, 6))

schedules = sorted(set(r.schedule_type for r in results))
data = [
[r.distance_from_origin for r in results if r.schedule_type == sched]
for sched in schedules
]
data = [[r.distance_from_origin for r in results if r.schedule_type == sched] for sched in schedules]

bp = ax.boxplot(data, labels=schedules, patch_artist=True)

Expand Down
4 changes: 1 addition & 3 deletions benchmark/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ def __init__(
n_workers: Number of parallel workers (default: CPU count)
"""
self.fitness_fn = fitness_fn
self.schedule_types = [
ScheduleType(s) if isinstance(s, str) else s for s in schedule_types
]
self.schedule_types = [ScheduleType(s) if isinstance(s, str) else s for s in schedule_types]
self.population_sizes = population_sizes
self.num_steps_list = num_steps_list
self.param_dims = param_dims
Expand Down
2 changes: 1 addition & 1 deletion examples/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def run_cartpole() -> None:
if test_reward >= 475:
print("✓ Task solved! (reward >= 475)")
else:
print(f"Task not yet solved. Try more steps or larger population.")
print("Task not yet solved. Try more steps or larger population.")


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/cartpole_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def run_cartpole(show_progress: bool = False, compare: bool = True) -> None:
seed=42,
)

print(f"\nRunning evolution:")
print("\nRunning evolution:")
print(f" Population size: {config.population_size}")
print(f" Evolution steps: {config.num_steps}")
# print(f" Distance metric: latent space (dim={config.distance.latent_dim})")
Expand All @@ -170,7 +170,7 @@ def run_cartpole(show_progress: bool = False, compare: bool = True) -> None:
final_population = algo.run()
best_individual, best_fitness = algo.get_best_individual()

print(f"\n" + "=" * 60)
print("\n" + "=" * 60)
print("EVOLUTION COMPLETE")
print("=" * 60)
print(f"Best fitness achieved: {best_fitness:.2f}")
Expand All @@ -182,7 +182,7 @@ def run_cartpole(show_progress: bool = False, compare: bool = True) -> None:
if test_reward >= 475:
print("✓ Task SOLVED! (reward >= 475)")
else:
print(f"✗ Task not solved yet (need >= 475)")
print("✗ Task not solved yet (need >= 475)")

print("\nDemonstrating best evolved controller:")
demonstrate_controller(best_individual, num_demos=3)
Expand Down
5 changes: 2 additions & 3 deletions examples/mnist/fitness.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""GPU-batched fitness evaluation for MNIST."""

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from numpy.typing import NDArray

from examples.mnist.serialization import deserialize_model

Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Parameter serialization and deserialization for LeNet models."""

import numpy as np
import torch
import torch.nn as nn
import numpy as np
from numpy.typing import NDArray

from examples.mnist.lenet5 import LeNet5, LeNetMini, create_lenet5, create_lenet_mini
Expand Down
5 changes: 1 addition & 4 deletions examples/mnist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand All @@ -12,8 +11,6 @@
from devol.algorithm import DiffusionEvolution
from examples.mnist.config import MNISTConfig
from examples.mnist.fitness import MNISTFitnessEvaluator
from examples.mnist.lenet5 import count_parameters, create_lenet5, create_lenet_mini
from examples.mnist.serialization import create_random_individual, serialize_model


class MNISTEvolution(DiffusionEvolution):
Expand Down Expand Up @@ -99,7 +96,7 @@ def run(self) -> NDArray:
break

if self.patience_counter >= self.mnist_config.early_stopping_patience:
print(f"\n✗ Early stopping triggered (patience exhausted)")
print("\n✗ Early stopping triggered (patience exhausted)")
break

self.population = population
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ packages = ["src/devol", "examples"]
[tool.ruff]
line-length = 120
target-version = "py311"
# Demo scripts are linted separately; they use argparse globals, numpy meshgrid
# conventions, and half-finished code that doesn't belong in the published package.
exclude = ["examples"]

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
Expand All @@ -84,3 +87,4 @@ strict = true
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
files = ["src/devol"]
24 changes: 10 additions & 14 deletions src/devol/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
from collections.abc import Callable

import numpy as np
from numpy.typing import NDArray

from devol.config import DiffusionConfig
from devol.distance import create_distance_computer
from devol.distance import FloatArray, create_distance_computer
from devol.evolution import compute_epsilon_hat, estimate_x0, evolution_step
from devol.fitness import create_fitness_mapper, create_fitness_normalizer
from devol.schedules import create_alpha_schedule, create_sigma_schedule


class DiffusionEvolution:
def __init__(self, config: DiffusionConfig, fitness_fn: Callable[[NDArray], float]) -> None:
def __init__(self, config: DiffusionConfig, fitness_fn: Callable[[FloatArray], float]) -> None:
self.config = config
self.fitness_fn = fitness_fn
self.rng = np.random.default_rng(config.seed)
Expand All @@ -34,19 +33,16 @@ def __init__(self, config: DiffusionConfig, fitness_fn: Callable[[NDArray], floa
config.fitness.temperature,
)

self.population: NDArray | None = None
self.population: FloatArray | None = None

# TODO: Is this init optimal? do we want to abstract it?
# Make it a docstring
# Explain how the noising op is shifting the original pdf to a ~N(0, 1)
def initialize_population(self) -> NDArray: # TODO: maybe make it of type Population
def initialize_population(self) -> FloatArray:
self.population = self.rng.standard_normal((self.config.population_size, self.config.param_dim))
return self.population

def evaluate_fitness(self, population: NDArray) -> NDArray:
def evaluate_fitness(self, population: FloatArray) -> FloatArray:
return np.array([self.fitness_fn(ind) for ind in population])

def step(self, timestamp: int, population: NDArray) -> NDArray:
def step(self, timestamp: int, population: FloatArray) -> FloatArray:
fitness = self.evaluate_fitness(population)
normalized_fitness = self.fitness_normalizer(fitness)

Expand All @@ -66,7 +62,7 @@ def step(self, timestamp: int, population: NDArray) -> NDArray:

return new_population

def run(self, initial_population: NDArray | None) -> NDArray:
def run(self, initial_population: FloatArray | None = None) -> FloatArray:
population = initial_population
if population is None:
population = self.initialize_population()
Expand All @@ -77,9 +73,9 @@ def run(self, initial_population: NDArray | None) -> NDArray:
self.population = population
return population

def get_best_individual(self) -> tuple[NDArray, float]:
def get_best_individual(self) -> tuple[FloatArray, float]:
if self.population is None:
raise ValueError("Algorithm has not been run yet")
fitness = self.evaluate_fitness(self.population)
best_idx = np.argmax(fitness)
return self.population[best_idx], fitness[best_idx]
best_idx = int(np.argmax(fitness))
return self.population[best_idx], float(fitness[best_idx])
11 changes: 6 additions & 5 deletions src/devol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,36 @@

import argparse
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any

import numpy as np
from pydantic_yaml import parse_yaml_file_as

from devol.algorithm import DiffusionEvolution
from devol.config import DiffusionConfig
from devol.distance import FloatArray


def load_config(config_path: str) -> DiffusionConfig:
"""Load configuration from YAML file."""
return parse_yaml_file_as(DiffusionConfig, Path(config_path))


def sphere_function(x: np.ndarray) -> float:
def sphere_function(x: FloatArray) -> float:
"""Sphere function: maximize -(x^2)."""
return -np.sum(x**2)
return float(-np.sum(x**2))


def rosenbrock_function(x: np.ndarray) -> float:
def rosenbrock_function(x: FloatArray) -> float:
"""Rosenbrock function (minimization converted to maximization)."""
result = 0.0
for i in range(len(x) - 1):
result += 100 * (x[i + 1] - x[i] ** 2) ** 2 + (1 - x[i]) ** 2
return -result


BUILTIN_FUNCTIONS: dict[str, Any] = {
BUILTIN_FUNCTIONS: dict[str, Callable[[FloatArray], float]] = {
"sphere": sphere_function,
"rosenbrock": rosenbrock_function,
}
Expand Down
Loading
Loading