Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions benchmarks/python/recurrent_snn_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import csv
import json
import math
import time
from itertools import product
from pathlib import Path

import mlx.core as mx

# Maps the dtype names accepted by the --dtypes CLI flag to MLX dtypes.
# parse_dtype_list validates user input against these keys.
DTYPE_MAP = {
    "float16": mx.float16,
    "float32": mx.float32,
    "bfloat16": mx.bfloat16,
}


def parse_int_list(value):
    """Parse a comma-separated string (e.g. ``"1,8,32"``) into a tuple of ints.

    Surrounding whitespace is stripped and empty segments are skipped.
    """
    pieces = (piece.strip() for piece in value.split(","))
    return tuple(int(piece) for piece in pieces if piece)


def parse_dtype_list(value):
    """Parse a comma-separated list of dtype names, validated against DTYPE_MAP.

    Raises:
        ValueError: if any name is not a key of ``DTYPE_MAP``.
    """
    names = [piece.strip() for piece in value.split(",") if piece.strip()]
    unknown = [name for name in names if name not in DTYPE_MAP]
    if unknown:
        allowed = ", ".join(sorted(DTYPE_MAP))
        raise ValueError(f"Unsupported dtype(s): {unknown}. Allowed: {allowed}")
    return tuple(names)


def benchmark_runtime(fn, warmup, iters):
    """Return the mean wall-clock runtime of ``fn`` in milliseconds.

    Runs ``warmup`` unmeasured iterations first, then times ``iters``
    iterations. Each call is forced with ``mx.eval`` so lazy MLX graphs are
    fully executed inside the measured region.
    """
    for _ in range(warmup):
        mx.eval(fn())

    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    elapsed_seconds = time.perf_counter() - start
    return elapsed_seconds * 1000.0 / iters


def make_tanh_rnn_workload(x, w_in, w_rec, bias):
    """Build a closure that unrolls a vanilla tanh RNN over ``x``.

    Args:
        x: Input batch; the loop indexes axis 1 as time, so the expected
           layout is (batch, time, features).
        w_in: Input projection weights, (features, hidden).
        w_rec: Recurrent weights, (hidden, hidden).
        bias: Hidden bias, (hidden,).

    Returns:
        A zero-argument callable returning the stacked hidden states with
        shape (batch, time, hidden).
    """

    def run():
        batch, steps = x.shape[0], x.shape[1]
        state = mx.zeros((batch, w_rec.shape[0]), dtype=x.dtype)
        states = []
        for step in range(steps):
            state = mx.tanh(x[:, step, :] @ w_in + state @ w_rec + bias)
            states.append(state)
        return mx.stack(states, axis=1)

    return run


def make_lif_workload(x, w_in, w_rec, leak, threshold):
    """Build a closure that unrolls a leaky integrate-and-fire (LIF) layer.

    Each step the membrane potential decays by ``leak``, integrates the input
    and recurrent spike currents, emits a binary spike where it reaches
    ``threshold``, and is hard-reset to zero wherever a spike fired.

    Returns:
        A zero-argument callable returning the stacked spike trains with
        shape (batch, time, hidden).
    """
    decay = mx.array(leak, dtype=x.dtype)
    fire_level = mx.array(threshold, dtype=x.dtype)
    unit = mx.array(1.0, dtype=x.dtype)

    def run():
        batch, steps = x.shape[0], x.shape[1]
        width = w_rec.shape[0]
        potential = mx.zeros((batch, width), dtype=x.dtype)
        spikes = mx.zeros((batch, width), dtype=x.dtype)
        spike_history = []
        for step in range(steps):
            drive = x[:, step, :] @ w_in + spikes @ w_rec
            potential = decay * potential + drive
            spikes = (potential >= fire_level).astype(x.dtype)
            # Hard reset: zero the membrane wherever a spike was emitted.
            potential = potential * (unit - spikes)
            spike_history.append(spikes)
        return mx.stack(spike_history, axis=1)

    return run


def run_case(batch_size, sequence_length, input_size, hidden_size, dtype_name, args):
    """Benchmark every workload for one (shape, dtype) configuration.

    Inputs and weights are drawn once (uniform weights scaled by
    ``1/sqrt(hidden_size)``), materialized with ``mx.eval`` so setup cost is
    excluded from timing, then each workload is timed with
    ``benchmark_runtime``.

    Returns:
        A list of result-row dicts, one per workload.
    """
    dtype = DTYPE_MAP[dtype_name]
    scale = 1.0 / math.sqrt(hidden_size)

    x = mx.random.normal((batch_size, sequence_length, input_size)).astype(dtype)
    w_in = mx.random.uniform(
        low=-scale, high=scale, shape=(input_size, hidden_size)
    ).astype(dtype)
    w_rec = mx.random.uniform(
        low=-scale, high=scale, shape=(hidden_size, hidden_size)
    ).astype(dtype)
    bias = mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,)).astype(dtype)
    mx.eval(x, w_in, w_rec, bias)

    workloads = [
        ("rnn_tanh_unrolled", make_tanh_rnn_workload(x, w_in, w_rec, bias)),
        (
            "lif_hard_reset_unrolled",
            make_lif_workload(x, w_in, w_rec, args.leak, args.threshold),
        ),
    ]

    results = []
    for name, workload in workloads:
        runtime_ms = benchmark_runtime(workload, args.warmup, args.iters)
        # Timesteps processed per wall-clock second, summed over the batch.
        steps_per_second = (batch_size * sequence_length) / (runtime_ms / 1000.0)
        results.append(
            {
                "workload": name,
                "batch_size": batch_size,
                "sequence_length": sequence_length,
                "input_size": input_size,
                "hidden_size": hidden_size,
                "dtype": dtype_name,
                "runtime_ms": runtime_ms,
                "effective_steps_per_second": steps_per_second,
            }
        )

    return results


def write_json(path, rows):
    """Serialize ``rows`` to ``path`` as indented JSON, creating parent dirs."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(rows, indent=2), encoding="utf-8")


def write_csv(path, rows):
    """Write ``rows`` to ``path`` as CSV with a fixed column order,
    creating parent directories as needed."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    columns = [
        "workload",
        "batch_size",
        "sequence_length",
        "input_size",
        "hidden_size",
        "dtype",
        "runtime_ms",
        "effective_steps_per_second",
    ]
    with destination.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)


def print_summary(rows):
    """Print one CSV-formatted line per result row, preceded by a header."""
    print(
        "workload,dtype,batch_size,sequence_length,input_size,hidden_size,runtime_ms,steps_per_second"
    )
    for row in rows:
        fields = (
            str(row["workload"]),
            str(row["dtype"]),
            str(row["batch_size"]),
            str(row["sequence_length"]),
            str(row["input_size"]),
            str(row["hidden_size"]),
            format(row["runtime_ms"], ".6f"),
            format(row["effective_steps_per_second"], ".2f"),
        )
        print(",".join(fields))


def parse_args(argv=None):
    """Parse benchmark CLI options.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse falls back to ``sys.argv[1:]``,
            so existing callers are unaffected. Passing an explicit list
            makes the parser usable programmatically (and testable).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark unrolled recurrent and LIF-style workloads in MLX."
    )
    # Sweep axes: each flag takes a comma-separated list, parsed to a tuple.
    parser.add_argument("--batch-sizes", default=(1, 8, 32), type=parse_int_list)
    parser.add_argument(
        "--sequence-lengths", default=(32, 128, 256), type=parse_int_list
    )
    parser.add_argument("--hidden-sizes", default=(64, 256, 512), type=parse_int_list)
    parser.add_argument("--input-size", default=40, type=int)
    parser.add_argument(
        "--dtypes",
        default=("float16", "float32"),
        type=parse_dtype_list,
    )
    # Timing controls passed through to benchmark_runtime.
    parser.add_argument("--warmup", default=10, type=int)
    parser.add_argument("--iters", default=100, type=int)
    # LIF dynamics: membrane decay factor and spike threshold.
    parser.add_argument("--leak", default=0.95, type=float)
    parser.add_argument("--threshold", default=1.0, type=float)
    # Optional result sinks; stdout summary is always printed.
    parser.add_argument("--json-output", default=None)
    parser.add_argument("--csv-output", default=None)
    return parser.parse_args(argv)


def main():
    """Run the full benchmark sweep and emit results to stdout/JSON/CSV."""
    args = parse_args()

    # Cartesian product over all sweep axes; input_size is fixed per run.
    sweep = product(
        args.batch_sizes,
        args.sequence_lengths,
        args.hidden_sizes,
        args.dtypes,
    )

    rows = []
    for batch_size, sequence_length, hidden_size, dtype_name in sweep:
        rows.extend(
            run_case(
                batch_size=batch_size,
                sequence_length=sequence_length,
                input_size=args.input_size,
                hidden_size=hidden_size,
                dtype_name=dtype_name,
                args=args,
            )
        )

    print_summary(rows)

    if args.json_output:
        write_json(args.json_output, rows)
    if args.csv_output:
        write_csv(args.csv_output, rows)


# Script entry point: run the benchmark sweep when invoked directly.
if __name__ == "__main__":
    main()
24 changes: 24 additions & 0 deletions docs/src/python/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@ Quick Start with Neural Networks
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)

Recurrent Patterns for Temporal Inputs
--------------------------------------

For temporal and event-like inputs, recurrent layers can be run in fixed
windows while carrying hidden state between windows. This pattern is useful for
streaming and SNN-style unrolled training loops where each chunk should keep
context from the previous chunk.

.. code-block:: python

import mlx.core as mx
import mlx.nn as nn

rnn = nn.RNN(input_size=40, hidden_size=128)
hidden = None

    # ``stream_chunks`` is any user-provided iterable that yields
    # arrays of shape [batch, time, features]
for chunk in stream_chunks:
y = rnn(chunk, hidden=hidden)
hidden = y[:, -1, :]

When using :class:`GRU` or :class:`LSTM`, the same chunked approach applies.
For :class:`LSTM`, carry both hidden and cell states between windows.

.. _module_class:

The Module Class
Expand Down
88 changes: 88 additions & 0 deletions python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,94 @@ def test_lstm(self):
self.assertEqual(h_out.shape, (44, 12))
self.assertEqual(c_out.shape, (44, 12))

def test_recurrent_dtype_propagation(self):
    # Outputs of RNN/GRU/LSTM must match the dtype the module was cast to,
    # for both batched [B, T, F] and unbatched [T, F] inputs.
    for dtype in (mx.float16, mx.bfloat16):
        batched = mx.random.normal((2, 32, 5)).astype(dtype)
        unbatched = mx.random.normal((32, 5)).astype(dtype)

        rnn = nn.RNN(5, 12)
        gru = nn.GRU(5, 12)
        lstm = nn.LSTM(5, 12)
        for layer in (rnn, gru, lstm):
            layer.set_dtype(dtype)

        # Single-output layers.
        for layer in (rnn, gru):
            self.assertEqual(layer(batched).dtype, dtype)
            self.assertEqual(layer(unbatched).dtype, dtype)

        # LSTM returns both hidden and cell state sequences.
        for inp in (batched, unbatched):
            hidden, cell = lstm(inp)
            self.assertEqual(hidden.dtype, dtype)
            self.assertEqual(cell.dtype, dtype)

def test_recurrent_gradient_parity(self):
    # Gradients produced through nn.value_and_grad must match those from a
    # plain mx.grad over the layer's trainable parameters.
    def check(layer, loss_fn, *loss_args):
        _, wrapped_grads = nn.value_and_grad(layer, loss_fn)(layer, *loss_args)

        def functional_loss(params, *args):
            layer.update(params)
            return loss_fn(layer, *args)

        direct_grads = mx.grad(functional_loss)(
            layer.trainable_parameters(), *loss_args
        )

        wrapped_flat = tree_flatten(wrapped_grads, destination={})
        direct_flat = tree_flatten(direct_grads, destination={})
        self.assertEqual(wrapped_flat.keys(), direct_flat.keys())
        for key, grad in wrapped_flat.items():
            self.assertTrue(
                mx.allclose(grad, direct_flat[key], atol=1e-5, rtol=1e-5),
                f"Gradient mismatch for {key}",
            )

    x = mx.random.normal((2, 24, 5))
    h0 = mx.random.normal((2, 12))
    c0 = mx.random.normal((2, 12))

    # RNN and GRU share the same single-output loss shape.
    def state_loss(model, x, hidden):
        return model(x, hidden=hidden).sum()

    def lstm_loss(model, x, hidden, cell):
        hidden_seq, cell_seq = model(x, hidden=hidden, cell=cell)
        return hidden_seq.sum() + cell_seq.sum()

    check(nn.RNN(5, 12), state_loss, x, h0)
    check(nn.GRU(5, 12), state_loss, x, h0)
    check(nn.LSTM(5, 12), lstm_loss, x, h0, c0)

def test_recurrent_long_sequence_stability(self):
    # Unrolling recurrent layers over a long (256-step) sequence should not
    # produce NaNs or infinities in any output.
    sequence = mx.random.normal((2, 256, 5)).astype(mx.float32)
    h0 = mx.random.normal((2, 12))
    c0 = mx.random.normal((2, 12))

    def check_finite(name, value):
        as_np = np.array(value)
        self.assertTrue(np.isfinite(as_np).all(), f"{name} has non-finite values")

    check_finite("rnn_out", nn.RNN(5, 12)(sequence, hidden=h0))
    check_finite("gru_out", nn.GRU(5, 12)(sequence, hidden=h0))

    lstm_hidden, lstm_cell = nn.LSTM(5, 12)(sequence, hidden=h0, cell=c0)
    check_finite("lstm_hidden", lstm_hidden)
    check_finite("lstm_cell", lstm_cell)

def test_quantized_embedding(self):
emb = nn.Embedding(32, 256)
qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
Expand Down