diff --git a/benchmarks/python/recurrent_snn_bench.py b/benchmarks/python/recurrent_snn_bench.py new file mode 100644 index 0000000000..7f41ecbf04 --- /dev/null +++ b/benchmarks/python/recurrent_snn_bench.py @@ -0,0 +1,214 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import csv +import json +import math +import time +from itertools import product +from pathlib import Path + +import mlx.core as mx + +DTYPE_MAP = { + "float16": mx.float16, + "float32": mx.float32, + "bfloat16": mx.bfloat16, +} + + +def parse_int_list(value): + return tuple(int(v.strip()) for v in value.split(",") if v.strip()) + + +def parse_dtype_list(value): + values = [v.strip() for v in value.split(",") if v.strip()] + invalid = [v for v in values if v not in DTYPE_MAP] + if invalid: + allowed = ", ".join(sorted(DTYPE_MAP)) + raise ValueError(f"Unsupported dtype(s): {invalid}. Allowed: {allowed}") + return tuple(values) + + +def benchmark_runtime(fn, warmup, iters): + for _ in range(warmup): + mx.eval(fn()) + + tic = time.perf_counter() + for _ in range(iters): + mx.eval(fn()) + toc = time.perf_counter() + return (toc - tic) * 1000.0 / iters + + +def make_tanh_rnn_workload(x, w_in, w_rec, bias): + def run(): + batch_size = x.shape[0] + hidden_size = w_rec.shape[0] + hidden = mx.zeros((batch_size, hidden_size), dtype=x.dtype) + outputs = [] + for t in range(x.shape[1]): + hidden = mx.tanh(x[:, t, :] @ w_in + hidden @ w_rec + bias) + outputs.append(hidden) + return mx.stack(outputs, axis=1) + + return run + + +def make_lif_workload(x, w_in, w_rec, leak, threshold): + leak_value = mx.array(leak, dtype=x.dtype) + threshold_value = mx.array(threshold, dtype=x.dtype) + one = mx.array(1.0, dtype=x.dtype) + + def run(): + batch_size = x.shape[0] + hidden_size = w_rec.shape[0] + membrane = mx.zeros((batch_size, hidden_size), dtype=x.dtype) + spikes = mx.zeros((batch_size, hidden_size), dtype=x.dtype) + outputs = [] + for t in range(x.shape[1]): + current = x[:, t, :] @ w_in + spikes @ w_rec 
+ membrane = leak_value * membrane + current + spikes = (membrane >= threshold_value).astype(x.dtype) + membrane = membrane * (one - spikes) + outputs.append(spikes) + return mx.stack(outputs, axis=1) + + return run + + +def run_case(batch_size, sequence_length, input_size, hidden_size, dtype_name, args): + dtype = DTYPE_MAP[dtype_name] + scale = 1.0 / math.sqrt(hidden_size) + + x = mx.random.normal((batch_size, sequence_length, input_size)).astype(dtype) + w_in = mx.random.uniform( + low=-scale, high=scale, shape=(input_size, hidden_size) + ).astype(dtype) + w_rec = mx.random.uniform( + low=-scale, high=scale, shape=(hidden_size, hidden_size) + ).astype(dtype) + bias = mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,)).astype(dtype) + mx.eval(x, w_in, w_rec, bias) + + workloads = [ + ("rnn_tanh_unrolled", make_tanh_rnn_workload(x, w_in, w_rec, bias)), + ( + "lif_hard_reset_unrolled", + make_lif_workload(x, w_in, w_rec, args.leak, args.threshold), + ), + ] + + rows = [] + for workload_name, workload_fn in workloads: + runtime_ms = benchmark_runtime(workload_fn, args.warmup, args.iters) + effective_steps_per_second = (batch_size * sequence_length) / ( + runtime_ms / 1000.0 + ) + rows.append( + { + "workload": workload_name, + "batch_size": batch_size, + "sequence_length": sequence_length, + "input_size": input_size, + "hidden_size": hidden_size, + "dtype": dtype_name, + "runtime_ms": runtime_ms, + "effective_steps_per_second": effective_steps_per_second, + } + ) + + return rows + + +def write_json(path, rows): + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(rows, f, indent=2) + + +def write_csv(path, rows): + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + field_names = [ + "workload", + "batch_size", + "sequence_length", + "input_size", + "hidden_size", + "dtype", + "runtime_ms", + "effective_steps_per_second", + ] + 
with output_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=field_names) + writer.writeheader() + writer.writerows(rows) + + +def print_summary(rows): + print( + "workload,dtype,batch_size,sequence_length,input_size,hidden_size,runtime_ms,steps_per_second" + ) + for row in rows: + print( + f"{row['workload']},{row['dtype']},{row['batch_size']},{row['sequence_length']},{row['input_size']},{row['hidden_size']},{row['runtime_ms']:.6f},{row['effective_steps_per_second']:.2f}" + ) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark unrolled recurrent and LIF-style workloads in MLX." + ) + parser.add_argument("--batch-sizes", default=(1, 8, 32), type=parse_int_list) + parser.add_argument( + "--sequence-lengths", default=(32, 128, 256), type=parse_int_list + ) + parser.add_argument("--hidden-sizes", default=(64, 256, 512), type=parse_int_list) + parser.add_argument("--input-size", default=40, type=int) + parser.add_argument( + "--dtypes", + default=("float16", "float32"), + type=parse_dtype_list, + ) + parser.add_argument("--warmup", default=10, type=int) + parser.add_argument("--iters", default=100, type=int) + parser.add_argument("--leak", default=0.95, type=float) + parser.add_argument("--threshold", default=1.0, type=float) + parser.add_argument("--json-output", default=None) + parser.add_argument("--csv-output", default=None) + return parser.parse_args() + + +def main(): + args = parse_args() + rows = [] + + for batch_size, sequence_length, hidden_size, dtype_name in product( + args.batch_sizes, + args.sequence_lengths, + args.hidden_sizes, + args.dtypes, + ): + rows.extend( + run_case( + batch_size=batch_size, + sequence_length=sequence_length, + input_size=args.input_size, + hidden_size=hidden_size, + dtype_name=dtype_name, + args=args, + ) + ) + + print_summary(rows) + + if args.json_output: + write_json(args.json_output, rows) + if args.csv_output: + write_csv(args.csv_output, rows) + + 
+if __name__ == "__main__": + main() diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst index 00f7f95456..286a95e734 100644 --- a/docs/src/python/nn.rst +++ b/docs/src/python/nn.rst @@ -64,6 +64,30 @@ Quick Start with Neural Networks # gradient with respect to `mlp.trainable_parameters()` loss_and_grad = nn.value_and_grad(mlp, l2_loss) +Recurrent Patterns for Temporal Inputs +-------------------------------------- + +For temporal and event-like inputs, recurrent layers can be run in fixed +windows while carrying hidden state between windows. This pattern is useful for +streaming and SNN-style unrolled training loops where each chunk should keep +context from the previous chunk. + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + + rnn = nn.RNN(input_size=40, hidden_size=128) + hidden = None + + # stream_chunks yields arrays with shape [batch, time, features] + for chunk in stream_chunks: + y = rnn(chunk, hidden=hidden) + hidden = y[:, -1, :] + +When using :class:`GRU` or :class:`LSTM`, the same chunked approach applies. +For :class:`LSTM`, carry both hidden and cell states between windows. + .. 
_module_class: The Module Class diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index ebdfe580ec..b993f4642c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1946,6 +1946,94 @@ def test_lstm(self): self.assertEqual(h_out.shape, (44, 12)) self.assertEqual(c_out.shape, (44, 12)) + def test_recurrent_dtype_propagation(self): + for dtype in (mx.float16, mx.bfloat16): + inp_batched = mx.random.normal((2, 32, 5)).astype(dtype) + inp_unbatched = mx.random.normal((32, 5)).astype(dtype) + + rnn = nn.RNN(5, 12) + gru = nn.GRU(5, 12) + lstm = nn.LSTM(5, 12) + rnn.set_dtype(dtype) + gru.set_dtype(dtype) + lstm.set_dtype(dtype) + + rnn_out = rnn(inp_batched) + self.assertEqual(rnn_out.dtype, dtype) + self.assertEqual(rnn(inp_unbatched).dtype, dtype) + + gru_out = gru(inp_batched) + self.assertEqual(gru_out.dtype, dtype) + self.assertEqual(gru(inp_unbatched).dtype, dtype) + + lstm_hidden, lstm_cell = lstm(inp_batched) + self.assertEqual(lstm_hidden.dtype, dtype) + self.assertEqual(lstm_cell.dtype, dtype) + + lstm_hidden, lstm_cell = lstm(inp_unbatched) + self.assertEqual(lstm_hidden.dtype, dtype) + self.assertEqual(lstm_cell.dtype, dtype) + + def test_recurrent_gradient_parity(self): + def assert_grads_close(layer, loss_fn, *loss_args): + _, module_grads = nn.value_and_grad(layer, loss_fn)(layer, *loss_args) + + def pure_loss(params, *args): + layer.update(params) + return loss_fn(layer, *args) + + pure_grads = mx.grad(pure_loss)(layer.trainable_parameters(), *loss_args) + + module_flat = dict(tree_flatten(module_grads)) + pure_flat = dict(tree_flatten(pure_grads)) + self.assertEqual(module_flat.keys(), pure_flat.keys()) + for key in module_flat: + self.assertTrue( + mx.allclose(module_flat[key], pure_flat[key], atol=1e-5, rtol=1e-5), + f"Gradient mismatch for {key}", + ) + + x = mx.random.normal((2, 24, 5)) + h0 = mx.random.normal((2, 12)) + c0 = mx.random.normal((2, 12)) + + def rnn_loss(model, x, hidden): + return 
model(x, hidden=hidden).sum() + + def gru_loss(model, x, hidden): + return model(x, hidden=hidden).sum() + + def lstm_loss(model, x, hidden, cell): + h_out, c_out = model(x, hidden=hidden, cell=cell) + return h_out.sum() + c_out.sum() + + assert_grads_close(nn.RNN(5, 12), rnn_loss, x, h0) + assert_grads_close(nn.GRU(5, 12), gru_loss, x, h0) + assert_grads_close(nn.LSTM(5, 12), lstm_loss, x, h0, c0) + + def test_recurrent_long_sequence_stability(self): + seq_len = 256 + inp = mx.random.normal((2, seq_len, 5)).astype(mx.float32) + h0 = mx.random.normal((2, 12)) + c0 = mx.random.normal((2, 12)) + + def assert_finite(name, arr): + arr_np = np.array(arr) + self.assertTrue(np.isfinite(arr_np).all(), f"{name} has non-finite values") + + rnn = nn.RNN(5, 12) + gru = nn.GRU(5, 12) + lstm = nn.LSTM(5, 12) + + rnn_out = rnn(inp, hidden=h0) + gru_out = gru(inp, hidden=h0) + lstm_hidden, lstm_cell = lstm(inp, hidden=h0, cell=c0) + + assert_finite("rnn_out", rnn_out) + assert_finite("gru_out", gru_out) + assert_finite("lstm_hidden", lstm_hidden) + assert_finite("lstm_cell", lstm_cell) + def test_quantized_embedding(self): emb = nn.Embedding(32, 256) qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)