Skip to content

refactor: establish shared core package and restructure baselines#2

Merged
yohannestayez merged 7 commits into
mainfrom
refactor/project-structure
May 17, 2026
Merged

refactor: establish shared core package and restructure baselines#2
yohannestayez merged 7 commits into
mainfrom
refactor/project-structure

Conversation

@Macmilan24

Copy link
Copy Markdown
Collaborator

Summary

The previous codebase had a structural problem: shared infrastructure
(MLP architecture, MNIST pipeline, metrics) lived inside ewc/src/,
making it appear to be EWC-specific. Any future method family would
have had no clean path to consume those modules without duplicating
them or importing across a wrong boundary. This PR fixes that before
new work is built on top of it.

Changes

Create core/ — shared infrastructure package

  • core/model.py — MLP definition with he_init; adds
    forward_with_states() that returns per-layer activations alongside
    final logits, required by future method implementations
  • core/data.py — MNIST loading and Split-MNIST task splitting
  • core/metrics.py — cross-entropy, accuracy, BWT, average accuracy,
    plot helpers
  • core/base.pyContinualLearningMethod Protocol that every method
    must satisfy; typed EWCState and SIState dataclasses replacing
    the ad hoc plain dicts used before
  • core/runner.py — shared evaluate() function and
    run_experiment() loop extracted from the seven experiment scripts
    where they were duplicated

Rename ewc/baselines/

  • The folder name now accurately describes its contents: a collection
    of continual learning baselines (Naive, EWC, EWC-DR, Online EWC,
    EWC+EMA, Synaptic Intelligence)
  • Removes model.py, data.py, utils.py from baselines/src/
    those now live in core/
  • All imports in baselines/src/ and baselines/experiments/ updated
    to point at core/
  • NaiveMethod.train_task signature updated to match the Protocol;
    evaluate() removed from all three method classes

Update baselines/README.md

  • Corrects the folder name, project structure diagram, and run
    instructions to reflect the current layout

What did not change

No training logic, loss functions, importance metrics, Fisher
computation, SI omega accumulation, or hyperparameters were modified.
Numerical outputs of all seven experiment scripts are identical to
before this PR.

  Introduce a top-level `core/` package that will serve as the single
  source of truth for all shared infrastructure across every method
  family in this project (baselines, causal coding, and future work such
  as Bayesian CC).

  Previously, the MLP definition, MNIST data pipeline, and utility
  functions lived inside `ewc/src/`, which implicitly made them
  EWC-specific assets. Any future method that needed them would either
  have to duplicate them or import across a conceptually incorrect
  boundary. This commit corrects that structural problem before any new
  method is introduced.

  Changes in this commit:

  core/model.py
  - Relocates `he_init` from `ewc/src/utils.py` into this file, since
    weight initialization is a model concern, not a metrics concern.
  - Preserves the original `MLP.forward()` method without modification
    so that all existing baseline code continues to function unchanged.
  - Adds `MLP.forward_with_states()`, a richer forward pass that returns
    a tuple of (states, final_logits). `states` is a list collecting the
    post-ReLU activation of every hidden layer followed by the raw output
    logits. This is required by causal coding, which needs per-layer
    activations to compute interventional do-influence via the
    Schur-Fisher recursion and to perform layer-wise Jacobian estimation.
    Baselines never call this method and pay zero cost for its existence.

  core/data.py
  - Verbatim copy of `ewc/src/data.py`. The `Task` dataclass and the
    `load_mnist` / `split_into_tasks` functions are dataset utilities,
    not EWC-specific logic, and belong at the project level.

  core/metrics.py
  - Copied from `ewc/src/utils.py` with `he_init` removed (relocated to
    core/model.py). Retains `cross_entropy`, `accuracy`,
    `average_accuracy`, `backward_transfer`, and `plot_accuracy_matrix`.
    The confusion cost metric required by Phase III of the causal coding
    roadmap will be added to this file in a later commit.

  core/__init__.py
  - Empty marker file that makes `core` a proper Python package.
… core

  Sever the structural dependency that made shared infrastructure
  (MLP, data pipeline, metrics) appear to be EWC-specific assets.

  Previously the project had a single top-level folder named `ewc/`
  whose `src/` subdirectory housed both EWC-specific training logic and
  three general-purpose modules: model.py, data.py, and utils.py. Any
  future method family — causal coding, Bayesian CC, or otherwise —
  would have had no clean path to consume those shared modules without
  either duplicating them or importing across a semantically wrong
  boundary (e.g., causal_coding importing from ewc/src/).

  This commit establishes the correct ownership model:

  Renames ewc/ → baselines/
  - The folder now accurately describes its contents: a collection of
    continual learning baseline methods (Naive, EWC, EWC-DR, Online EWC,
    EWC+EMA, Synaptic Intelligence). The name "ewc" was a historical
    artifact of the first method implemented, not a description of the
    package's scope.

  Deletes three files from baselines/src/
  - baselines/src/model.py  — promoted to core/model.py in prior commit
  - baselines/src/data.py   — promoted to core/data.py in prior commit
  - baselines/src/utils.py  — promoted to core/metrics.py in prior commit
    These files no longer belong to any single method family.

  Updates all imports in baselines/src/ (4 files)
  - naive.py, ewc.py, si.py: replace `.model`, `.utils`, `.data`
    relative imports with absolute `core.model`, `core.metrics`,
    `core.data` imports.
  - ewc_dr.py: replace `.model`, `.data` relative imports with absolute
    `core.*` equivalents. The `.ewc` relative import is preserved because
    EWCDRMethod inherits from EWCMethod within the same package.

  Updates all experiment scripts in baselines/experiments/ (7 files)
  - Adds a second sys.path.insert that places the project root
    (ContinualLearning/) on the Python path, making the `core` package
    importable from within baselines/experiments/.
  - Replaces `from src.model`, `from src.data`, `from src.utils` with
    `from core.model`, `from core.data`, `from core.metrics`.
  - All `from src.<method> import <Class>` lines are unchanged; baselines/
    is still on the path and baselines/src/ is still a valid package.

  No logic is modified anywhere. Every behavioural change in this commit
  is purely mechanical import path surgery.

  BREAKING CHANGE: the ewc/ directory no longer exists. Any external
  script or notebook that imports from ewc/src/ must be updated to import
  from core/ (for model, data, metrics) or baselines/src/ (for method
  classes).
…classes

  Introduce core/base.py as the single contract file that every method
  family in this project — baselines, causal coding, and any future work
  — must satisfy. This eliminates the implicit, undocumented API that
  existed previously and replaces ad hoc state dicts with typed,
  inspectable dataclasses.

  core/base.py (new file)

  Defines two typed state dataclasses and one structural Protocol.

  EWCState holds old_params and cumulative_fisher. Previously this
  information was stored in an untyped plain dict whose keys were only
  discoverable by reading the method's train_task implementation. A typo
  in a key would fail silently at runtime with a KeyError, not at the
  point of incorrect construction.

  SIState holds old_params and cumulative_omega for the same reason.

  ContinualLearningMethod is a typing.Protocol that formalises the
  interface any continual learning method must expose:
    - train_task(model, params, state, task, task_idx) -> (params, state, loss)
    - evaluate(model, params, task, allowed_classes) -> float

  Any class that implements these two methods satisfies the Protocol
  structurally, without needing to inherit from it. This means existing
  baselines comply automatically once their signatures match, and new
  methods such as the causal coding implementation have a concrete
  template to follow rather than discovering the API by reading
  experiment scripts.

  baselines/src/ewc.py
  - Imports EWCState from core.base.
  - Replaces all state["old_params"] and state["cumulative_fisher"] dict
    accesses with state.old_params and state.cumulative_fisher attribute
    access — the type system can now catch field name errors.
  - Returns EWCState(...) instead of a plain dict at the end of
    train_task.

  baselines/src/si.py
  - Imports SIState from core.base.
  - Replaces state["old_params"] and state["cumulative_omega"] with
    attribute access.
  - Returns SIState(...) instead of a plain dict.

  baselines/experiments/ (6 files: run_ewc, run_ewc_dr, run_online_ewc,
  run_online_ewc_dr, run_ewc_with_ema, run_si)
  - Each script now imports the appropriate State class from core.base
    and constructs the initial state object using named fields instead
    of a dict literal. This ensures that the type passed into train_task
    at startup is identical in structure to the type returned by
    train_task at the end of each task, making the state threading loop
    consistent end-to-end.

  No training logic, hyperparameters, or evaluation behaviour is changed.
  This commit is a pure structural improvement with no effect on
  numerical outputs.
…into core/runner.py

  Eliminate the three most pervasive forms of duplication in this
  codebase in a single commit: the identical evaluate() method that was
  copy-pasted into every method class, the manual task-loop that was
  copy-pasted across all seven experiment scripts, and the implicit
  constraint that Naive's train_task signature differed from every other
  method.

  core/runner.py (new file)

  evaluate(model, params, task, allowed_classes)
    Standalone function extracted from NaiveMethod, EWCMethod, and
    SIMethod, where it was byte-for-byte identical. It now lives in
    exactly one place. Any future method — causal coding or otherwise —
    gets evaluation for free by calling this function; there is nothing
    to implement or copy.

  run_experiment(method, model, params, state, tasks)
    Owns the full sequential training loop that previously appeared in
    every experiment script. For each task it calls method.train_task,
    then evaluates on every task under both Class-IL (all logits
    visible) and Task-IL (logits masked to the ground-truth class pair),
    and accumulates both accuracy matrices. Returns
    (params, state, class_il_matrix, task_il_matrix).
    Because the runner calls method.train_task uniformly, adding a new
    method family requires zero changes to the runner.

  baselines/src/naive.py
    Updated train_task signature from (model, params, task) to the
    Protocol-compliant (model, params, state, task, task_idx). State and
    task_idx are intentionally ignored — Naive is stateless by design.
    Return value updated from (params, loss) to (params, None, loss) so
    the runner can unpack it identically to EWC and SI. Removed the
    evaluate() method. Removed the now-unused accuracy and jnp imports.

  baselines/src/ewc.py
    Removed evaluate(). Removed the now-unused accuracy import.

  baselines/src/si.py
    Removed evaluate(). Removed the now-unused accuracy import.

  baselines/experiments/ (all 7 scripts)
    Replaced the manual for-loop, evaluate calls, and matrix assembly
    with a single call to run_experiment. Each script is now responsible
    only for what makes it unique: method instantiation, hyperparameters,
    summary printing, and plot titles. Scripts that previously tracked
    only a single accuracy matrix (run_online_ewc, run_online_ewc_dr,
    run_ewc_with_ema) now report both Class-IL and Task-IL consistently,
    bringing them in line with the rest of the suite. The
    run_online_ewc_dr and run_si scripts now correctly report
    backward_transfer on the Class-IL matrix rather than on an
    unspecified single matrix.

  No training logic, loss functions, importance metrics, or
  hyperparameters are changed. Numerical outputs are identical to before
  this commit.
…d BWT

  Vanilla EWC was using a single overwritten anchor for all tasks, causing
  earlier Fisher diagonals to be applied around the wrong optimum from task
  3
  onward. Add EWCVanillaState dataclass to core/base.py and refactor
  train_task to store a growing list of per-task {fisher, params} anchors
  for the vanilla path (decay == 1.0), summing the penalty over all of them.
  Online variants (decay < 1.0) retain the single cumulative-Fisher anchor
  backed by EWCState unchanged.

  EWCDRMethod defaulted decay to 0.0, causing the non-online baseline to
  discard all but the most recent task's Fisher before the third task.
  Change the default to 1.0 so full accumulation is the baseline and online
  decay must be opted into explicitly.

  The EMA experiment delegated to run_experiment which evaluated raw
  gradient-updated params, making reported accuracies mislabeled. Replace
  the run_experiment call with a custom loop that maintains a separate
  ema_params variable updated after each task and evaluates with it.

  backward_transfer in core/metrics.py looped over range(T-2) but divided
  by T-1, silently dropping the second-to-last task from the numerator.
  Change to range(T-1) so all T-1 previous tasks contribute to the sum.
@Macmilan24 Macmilan24 force-pushed the refactor/project-structure branch from 94da379 to 8aaddf9 Compare May 12, 2026 08:14
Comment thread core/base.py Outdated
Comment thread baselines/experiments/run_online_ewc.py
Comment thread baselines/experiments/run_naive.py
Add Class-IL backward transfer reporting to baseline experiment scripts so all runs expose the same final continual-learning metrics. Each updated experiment now imports backward_transfer and prints the final Class-IL BWT alongside average Class-IL and Task-IL accuracy.

Also remove the stale ContinualLearningMethod protocol from core.base after evaluation was centralized in The protocol was unused and no longer matched the concrete method interface, so core.base
now only contains the shared state dataclasses.
@yohannestayez yohannestayez merged commit ad1e55c into main May 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants