07 — Training on MNIST with darnax (JAX-first, CPU)¶
This tutorial is for readers who want to understand how darnax is used in practice: how a network is assembled from modules, how the orchestrator runs recurrent dynamics, and how local plasticity integrates with Equinox/Optax to update parameters without backpropagation.
We’ll implement a compact, JAX-friendly training loop:
- Only the outer steps are jit-compiled: `train_step` and `eval_step`.
- Recurrent dynamics use `jax.lax.scan` (no Python loops inside compiled code).
- The dataset iterator is slice-based (no `array_split`), which is friendlier to accelerators and CPUs alike.
- We stay on CPU to keep the focus on design; the code is accelerator-ready (an optional way to pin the CPU backend is shown just below).
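If you do want to pin JAX to the CPU explicitly rather than relying on the default backend, set the standard `JAX_PLATFORMS` environment variable before JAX is imported. This is optional and not darnax-specific; a minimal sketch:
# Optional: force CPU execution. Must run before `import jax` (i.e., before the imports cell below).
import os
os.environ.setdefault("JAX_PLATFORMS", "cpu")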
# --- Imports ---------------------------------------------------------------
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from datasets import load_dataset
from darnax.layer_maps.sparse import LayerMap
from darnax.modules.fully_connected import FrozenFullyConnected, FullyConnected
from darnax.modules.input_output import OutputLayer
from darnax.modules.recurrent import RecurrentDiscrete
from darnax.orchestrators.sequential import SequentialOrchestrator
from darnax.states.sequential import SequentialState
if TYPE_CHECKING:
from collections.abc import Iterator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
1) What darnax gives you¶
darnax decomposes a model into modules connected by a LayerMap. At runtime, an orchestrator applies a message-passing schedule over a state vector.
- Modules (edges/diagonals) consume a sender buffer and emit a message to a receiver. Some modules are trainable; some are frozen; some are “diagonal” (operate on a buffer itself).
- The LayerMap defines the fixed topology: for each receiver row `i`, which senders `j` contribute, and with which module.
- The SequentialOrchestrator drives the update order (left→right, recurrent self, right→left as needed) and exposes the following entry points (a conceptual sketch of one sweep follows the list):
  - `step`: full dynamics (all messages allowed).
  - `step_inference`: inference dynamics (typically suppress “backward” messages).
  - `backward`: compute local parameter deltas from the current state (no backprop).
  - `predict`: produce output scores in the final buffer.
- The State is a fixed-shape tuple of buffers `(input, hidden, output)`. You clamp inputs (and possibly labels) by writing them into the state, then run dynamics to a fixed point.
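Before writing any real code, it helps to picture what one sweep of such a schedule does. The sketch below is purely conceptual and is not the darnax API: modules are drawn as plain callables, and real modules additionally carry thresholds, nonlinearities, and RNG handled by the orchestrator.
# Conceptual sketch only (not darnax internals): for each receiver row, sum the
# messages produced by its sender modules, then store the aggregate in that buffer.
def conceptual_sweep(layer_map: dict, buffers: list) -> list:
    new_buffers = list(buffers)
    for receiver, senders in layer_map.items():
        messages = [module(buffers[sender]) for sender, module in senders.items()]
        new_buffers[receiver] = sum(messages)  # real modules also threshold/activate
    return new_buffers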
2) A tiny metric helper¶
Labels are one-vs-all (OVA), encoded in the ±1 style with the scaling described in the next section. Since we decode predictions via argmax along the class dimension, the exact scaling does not affect accuracy.
def batch_accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jax.Array:
"""Accuracy with ±1 OVA labels (class = argmax along last dim)."""
y_true_idx = jnp.argmax(y_true, axis=-1)
y_pred_idx = jnp.argmax(y_pred, axis=-1)
return jnp.mean((y_true_idx == y_pred_idx).astype(jnp.float32))
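A quick sanity check on made-up three-class labels (toy values, not real data) shows the argmax decoding at work:
# Identical predictions and targets give accuracy 1.0 regardless of scaling.
y_toy = jnp.array([[1.1, -0.5, -0.5], [-0.5, 1.1, -0.5]])
print(batch_accuracy(y_toy, y_toy))  # -> 1.0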
3) Dataset object designed for JAX¶
We keep the data pipeline deliberately simple to highlight the model mechanics:
- Materialize once: the training subset and the full test split live as device arrays.
- Shared projection: an optional linear projection (the same matrix for train and test) reduces dimensionality and can be followed by a `sign` transform (Entangled-MNIST style).
- Slice-based batches: iterators yield contiguous slices; no list materialization, no `array_split`.
- Label scaling: the true class gets √C/2 − 0.5 (≈ 1.08 for C = 10); every other class gets −0.5.

This scaling biases the output field to favor the target class during clamped dynamics; a short worked example follows.
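To make the scaling concrete, here is the label vector for class 3 with C = 10, computed with the same formula the dataset class uses below (values rounded):
# True class 3 gets sqrt(10)/2 - 0.5 ≈ 1.081; every other class gets -0.5.
scaled = jax.nn.one_hot(3, 10, dtype=jnp.float32) * (10 ** 0.5 / 2.0) - 0.5
print(scaled)  # [-0.5 -0.5 -0.5  1.081 -0.5 ...]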
class MNISTData:
"""MNIST dataset with optional linear projection and sign; slice-based iterators.
Design:
- Single in-memory materialization (train subset, full test).
- Shared projection across splits.
- Deterministic batch slicing (precomputed ranges).
"""
TOTAL_SIZE_PER_CLASS = 5900 # train split
TEST_SIZE_PER_CLASS = 1000 # test split
NUM_CLASSES = 10
FLAT_DIM = 28 * 28
def __init__(
self,
key: jax.Array,
batch_size: int = 64,
linear_projection: int | None = 100,
apply_sign_transform: bool = True,
num_images_per_class: int = TOTAL_SIZE_PER_CLASS,
):
"""Initialize the dataset object."""
# Lightweight validation; fail fast on easy mistakes.
if not (linear_projection is None or isinstance(linear_projection, int)):
raise TypeError("`linear_projection` must be `None` or `int`.")
if batch_size <= 1:
raise ValueError(f"Invalid batch_size={batch_size!r}; must be > 1.")
if not (0 < num_images_per_class <= self.TOTAL_SIZE_PER_CLASS):
raise ValueError(f"`num_images_per_class` must be in [1, {self.TOTAL_SIZE_PER_CLASS}]")
self.linear_projection = linear_projection
self.apply_sign_transform = bool(apply_sign_transform)
self.num_data = int(num_images_per_class) * self.NUM_CLASSES
self.batch_size = int(batch_size)
self.num_batches = -(-self.num_data // self.batch_size) # ceil div
# Build arrays once.
self._create_dataset(key)
# Precompute slicing ranges for train/eval.
self._train_bounds = [
(i * self.batch_size, min((i + 1) * self.batch_size, self.num_data))
for i in range(self.num_batches)
]
self.num_eval_data = int(self.x_eval.shape[0])
self.num_eval_batches = -(-self.num_eval_data // self.batch_size)
self._eval_bounds = [
(i * self.batch_size, min((i + 1) * self.batch_size, self.num_eval_data))
for i in range(self.num_eval_batches)
]
# ------------------------------- Public API ------------------------------- #
def __iter__(self) -> Iterator[tuple[jax.Array, jax.Array]]:
"""Yield `(x, y)` training batches by contiguous slicing."""
for lo, hi in self._train_bounds:
yield self.x[lo:hi], self.y[lo:hi]
def iter_eval(self) -> Iterator[tuple[jax.Array, jax.Array]]:
"""Yield `(x_eval, y_eval)` validation batches (full test split)."""
for lo, hi in self._eval_bounds:
yield self.x_eval[lo:hi], self.y_eval[lo:hi]
def __len__(self) -> int:
"""Return the number of batches."""
return self.num_batches
# ------------------------------ Internals ------------------------------ #
@staticmethod
def _load_mnist_split(split: str) -> tuple[jax.Array, jax.Array]:
"""Load MNIST split and flatten to (N, 784)."""
assert split in ["train", "test"]
ds = load_dataset("mnist")
x = jnp.asarray([jnp.array(im) for im in ds[split]["image"]], dtype=jnp.float32)
y = jnp.asarray(ds[split]["label"], dtype=jnp.int32)
x = x.reshape(x.shape[0], -1) / 255.0
return x, y
@staticmethod
def _labels_to_pm1_scaled(y_scalar: jax.Array, num_classes: int) -> jax.Array:
"""Original scaling: +√C/2 at the true class, −0.5 elsewhere."""
one_hot = jax.nn.one_hot(y_scalar, num_classes, dtype=jnp.float32)
return one_hot * (num_classes**0.5 / 2.0) - 0.5
@staticmethod
def _random_projection_matrix(key: jax.Array, out_dim: int, in_dim: int) -> jax.Array:
"""Gaussian projection with variance 1/in_dim to keep outputs ~unit variance."""
return jax.random.normal(key, (out_dim, in_dim), dtype=jnp.float32) / jnp.sqrt(in_dim)
@staticmethod
def _take_per_class(
key: jax.Array, x: jax.Array, y: jax.Array, k_per_class: int
) -> tuple[jax.Array, jax.Array]:
"""Uniformly sample `k_per_class` examples for each class 0..9."""
xs, ys = [], []
for cls in range(MNISTData.NUM_CLASSES):
key, sub = jax.random.split(key)
idx = jnp.where(y == cls)[0]
if k_per_class > idx.shape[0]:
raise ValueError(
f"Requested {k_per_class} for class {cls}, but only {idx.shape[0]} available."
)
perm = jax.random.permutation(sub, idx.shape[0])
take = idx[perm[:k_per_class]]
xs.append(x[take])
ys.append(y[take])
return jnp.concatenate(xs, axis=0), jnp.concatenate(ys, axis=0)
def _maybe_project_and_sign(self, w: jax.Array | None, x: jax.Array) -> jax.Array:
"""Apply optional linear projection + optional sign nonlinearity (Entangled-MNIST)."""
if w is not None:
x = (x @ w.T).astype(jnp.float32)
if self.apply_sign_transform:
sgn = jnp.sign(x)
x = jnp.where(sgn == 0, jnp.array(-1.0, dtype=sgn.dtype), sgn)
return x
def _create_dataset(self, key: jax.Array) -> None:
"""Materialize train subset and full test split with consistent preprocessing."""
key_sample, key_proj, key_shuf_tr = jax.random.split(key, 3)
# Load raw splits.
x_tr_all, y_tr_all = self._load_mnist_split("train")
x_ev_all, y_ev_scalar = self._load_mnist_split("test")
# Uniform per-class sampling.
k_train = self.num_data // self.NUM_CLASSES
x_tr, y_tr_scalar = self._take_per_class(key_sample, x_tr_all, y_tr_all, k_train)
# Shared projection/sign across splits.
w = (
self._random_projection_matrix(key_proj, int(self.linear_projection), x_tr.shape[-1])
if self.linear_projection is not None
else None
)
x_tr = self._maybe_project_and_sign(w, x_tr)
x_ev = self._maybe_project_and_sign(w, x_ev_all)
# Labels with original scaling; shuffle train only.
y_tr = self._labels_to_pm1_scaled(y_tr_scalar, self.NUM_CLASSES)
perm_tr = jax.random.permutation(key_shuf_tr, x_tr.shape[0])
self.x, self.y = x_tr[perm_tr], y_tr[perm_tr]
self.x_eval = x_ev
self.y_eval = self._labels_to_pm1_scaled(y_ev_scalar, self.NUM_CLASSES)
# Convenience metadata.
self.input_dim = int(self.x.shape[1])
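If you want to sanity-check the data object in isolation, a small per-class subset keeps the download and preprocessing quick. This is optional and not part of the training script; the variable names are placeholders, and the expected shapes assume the default 100-dimensional projection:
# Optional smoke test: 10 images per class, projected to 100 dimensions.
demo_key = jax.random.key(0)
demo = MNISTData(demo_key, batch_size=32, linear_projection=100, num_images_per_class=10)
x0, y0 = next(iter(demo))
print(x0.shape, y0.shape)  # expected: (32, 100) (32, 10)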
4) Model topology as a LayerMap¶
We build a minimal network with one hidden layer and an output sink:
- Receiver row 1 (hidden) gets messages from:
  - 0 (input) via `FullyConnected` (forward path)
  - 1 (itself) via `RecurrentDiscrete` (internal recurrence)
  - 2 (labels) via `FrozenFullyConnected` (backward/clamping path)
- Receiver row 2 (output) gets messages from:
  - 1 (hidden) via `FullyConnected` (readout)
  - 2 (itself) via `OutputLayer` (diagonal sink/aggregator, returns zeros)
The SequentialState is (input, hidden, output) with fixed sizes.
The SequentialOrchestrator knows how to:
- aggregate edge messages for each receiver,
- apply diagonal modules,
- and run the chosen schedule (step, step_inference, predict, backward).
DIM_DATA = 100
NUM_LABELS = MNISTData.NUM_CLASSES
DIM_HIDDEN = 300
THRESHOLD_OUT = 1.0
THRESHOLD_IN = 1.0
THRESHOLD_J = 1.0
STRENGTH_BACK = 0.5
STRENGTH_FORTH = 5.0
J_D = 0.5
# Global state with three buffers: input (0), hidden (1), output/labels (2)
state = SequentialState((DIM_DATA, DIM_HIDDEN, NUM_LABELS))
# Independent keys for each module (avoid accidental correlations).
master_key = jax.random.key(seed=44)
keys = jax.random.split(master_key, num=5)
layer_map = {
1: { # Hidden row receives from input, itself, and labels
0: FullyConnected(
in_features=DIM_DATA,
out_features=DIM_HIDDEN,
strength=STRENGTH_FORTH,
threshold=THRESHOLD_IN,
key=keys[0],
),
1: RecurrentDiscrete(
features=DIM_HIDDEN,
j_d=J_D,
threshold=THRESHOLD_J,
key=keys[1],
),
2: FrozenFullyConnected( # clamping/teaching signal, not trainable
in_features=NUM_LABELS,
out_features=DIM_HIDDEN,
strength=STRENGTH_BACK,
threshold=0.0,
key=keys[2],
),
},
2: { # Output row receives from hidden and aggregates
1: FullyConnected(
in_features=DIM_HIDDEN,
out_features=NUM_LABELS,
strength=1.0,
threshold=THRESHOLD_OUT,
key=keys[3],
),
2: OutputLayer(), # diagonal sink: produces zeros; acts as aggregator anchor
},
}
layer_map = LayerMap.from_dict(layer_map)
# Trainable orchestrator built from the fixed topology.
orchestrator = SequentialOrchestrator(layers=layer_map)
logger.info("Model initialized with SequentialOrchestrator.")
5) Optimizer and the “no-backprop” update¶
darnax does not use backpropagation here. Instead:
- Run recurrent dynamics with the current batch clamped (inputs + labels in the state).
- Call `orchestrator.backward(state, rng)` to get local deltas for every trainable module.
- Apply those deltas using Optax; this gives you the familiar optimizer ergonomics.
Notes for JAX compilation:
- We pass the optimizer object as an argument to the jitted functions. Under `eqx.filter_jit`, non-array arguments are treated as static, so reusing the same instance prevents retracing.
- Only the optimizer state (arrays) flows through the jitted code.
optimizer = optax.adam(2e-3)
opt_state = optimizer.init(eqx.filter(orchestrator, eqx.is_inexact_array))
def _apply_update(
orch: SequentialOrchestrator,
s: SequentialState,
opt_state,
rng: jax.Array,
optimizer,
):
"""Compute local deltas via .backward, then apply Optax updates.
Why separate this helper?
- Clear separation of concerns (dynamics vs parameter updates).
- Easier to unit-test and profile independently.
"""
grads = orch.backward(s, rng=rng) # local deltas, tree-shaped like `orch`
params = eqx.filter(orch, eqx.is_inexact_array) # trainable leaves
grads = eqx.filter(grads, eqx.is_inexact_array) # drop non-arrays from grads
updates, opt_state = optimizer.update(grads, opt_state, params=params)
orch = eqx.apply_updates(orch, updates)
return orch, opt_state
6) Dynamics with lax.scan (not jitted directly)¶
The orchestrator exposes one-step transitions:
- step(s, rng) → (s’, rng’): full dynamics (includes backward/label messages).
- step_inference(s, rng) → (s’, rng’): inference-only dynamics (suppress backward messages).
We wrap those into scans. These helpers are not jitted on their own; they are traced as part of the outer jitted steps. That keeps the code modular and the compiled graph clean.
def _scan_steps(fn, s: SequentialState, rng: jax.Array, steps: int):
"""Scan `steps` times a (s, rng)->(s, rng) transition."""
def body(carry, _):
s, rng = carry
s, rng = fn(s, rng=rng)
return (s, rng), None
(s, rng), _ = jax.lax.scan(body, (s, rng), xs=None, length=steps)
return s, rng
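For intuition, the scan above is equivalent to the plain Python loop below. We do not use the loop because, when traced inside a jitted function, it would be unrolled step by step and bloat the compiled graph; it is shown here only as a reference:
def _loop_steps(fn, s: SequentialState, rng: jax.Array, steps: int):
    """Reference-only equivalent of `_scan_steps`; unrolls under jit, so unused."""
    for _ in range(steps):
        s, rng = fn(s, rng=rng)
    return s, rng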
def run_dynamics_training(
orch: SequentialOrchestrator,
s: SequentialState,
rng: jax.Array,
steps: int,
):
"""Clamped phase (full dynamics) followed by a short free relaxation (inference)."""
s, rng = _scan_steps(orch.step, s, rng, steps) # clamped
s, rng = _scan_steps(orch.step_inference, s, rng, steps) # free
return s, rng
def run_dynamics_inference(
orch: SequentialOrchestrator,
s: SequentialState,
rng: jax.Array,
steps: int,
):
"""Inference-only relaxation to a fixed point."""
s, rng = _scan_steps(orch.step_inference, s, rng, 2 * steps)
return s, rng
7) Outer steps (the only jit-compiled functions)¶
We jit only the functions that are called many times and represent the outer boundary of our computation:
- `train_step` (per batch): 1) write `(x, y)` into the state (clamp), 2) run clamped + free dynamics, 3) compute local deltas and apply the Optax update.
- `eval_step` (per batch): 1) write `x` only, 2) run free dynamics, 3) `predict` and compute accuracy.
JIT boundary discipline:
- Static arguments (the optimizer object, Python ints like `t_train`) trigger retraces only if they change; keep them fixed during a run (an optional compile-logging flag is shown below).
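One optional aid while iterating on this code is JAX's compile-logging flag, which prints a message on every (re)compilation so unexpected retraces are easy to spot. It is a standard JAX config option, unrelated to darnax:
# Optional debugging aid: log every compilation to catch accidental retraces.
jax.config.update("jax_log_compiles", True)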
@eqx.filter_jit
def train_step(
orch: SequentialOrchestrator,
s: SequentialState,
x: jnp.ndarray,
y: jnp.ndarray,
rng: jax.Array,
*,
opt_state,
optimizer,
t_train: int = 3,
):
"""Perform a train step in a single batch."""
# 1) Clamp inputs + labels into the global state.
s = s.init(x, y)
# 2) Recurrent dynamics: clamped phase then free relaxation.
s, rng = run_dynamics_training(orch, s, rng, steps=t_train)
# 3) Local deltas + Optax update.
rng, update_key = jax.random.split(rng)
orch, opt_state = _apply_update(orch, s, opt_state, update_key, optimizer)
return orch, rng, opt_state
@eqx.filter_jit
def eval_step(
orch: SequentialOrchestrator,
s: SequentialState,
x: jnp.ndarray,
y: jnp.ndarray,
rng: jax.Array,
*,
t_eval: int = 5,
) -> tuple[jax.Array, jax.Array]:
"""Perform a validation step on a single batch."""
# 1) Clamp inputs only (labels aren't used by dynamics here).
s = s.init(x, None)
# 2) Free relaxation to a fixed point.
s, rng = run_dynamics_inference(orch, s, rng, steps=t_eval)
# 3) Predict scores from the settled state and measure accuracy.
s, rng = orch.predict(s, rng)
y_pred = s[-1]
acc = batch_accuracy(y, y_pred)
return acc, rng
8) Training loop (CPU)¶
The Python epoch loop shepherds data and RNG. All heavy lifting happens inside the two jitted steps above. Practical guidance:
- Keep array shapes/dtypes and the pytrees’ structures stable across calls.
- Reuse the same optimizer instance; pass its state through the jitted code.
- If you change `t_train`/`t_eval` between calls, expect a retrace (they are static).
# Experiment knobs
NUM_IMAGES_PER_CLASS = 5400
APPLY_SIGN_TRANSFORM = True
BATCH_SIZE = 16
EPOCHS = 5
T_TRAIN = 10 # clamped + free steps per batch
T_EVAL = 10 # inference steps multiplier (2*T_EVAL iterations)
# RNGs
master_key = jax.random.key(59)
master_key, data_key = jax.random.split(master_key)
# Data
data = MNISTData(
key=data_key,
batch_size=BATCH_SIZE,
linear_projection=DIM_DATA,
apply_sign_transform=APPLY_SIGN_TRANSFORM,
num_images_per_class=NUM_IMAGES_PER_CLASS,
)
print(f"Dataset ready — x.shape={tuple(data.x.shape)}, y.shape={tuple(data.y.shape)}")
# Train & evaluate
for epoch in range(1, EPOCHS + 1):
t0 = time.time()
print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
# Training epoch
for x_batch, y_batch in data:
master_key, step_key = jax.random.split(master_key)
orchestrator, master_key, opt_state = train_step(
orchestrator,
state,
x_batch,
y_batch,
rng=step_key,
opt_state=opt_state,
optimizer=optimizer, # static in the JIT sense; same instance every call
t_train=T_TRAIN,
)
# Evaluation epoch (full test split)
accs = []
for x_b, y_b in data.iter_eval():
master_key, step_key = jax.random.split(master_key)
acc, master_key = eval_step(
orchestrator,
state,
x_b.astype(jnp.float32),
y_b.astype(jnp.float32),
rng=step_key,
t_eval=T_EVAL,
)
accs.append(acc)
acc_epoch = float(jnp.mean(jnp.array(accs)))
print(f"Eval Accuracy = {acc_epoch:.3f} | epoch time: {time.time() - t0:.2f}s")
9) One-line final report¶
This is just to have a single scalar you can grep from logs or compare across runs.
final_accs = []
for x_b, y_b in data.iter_eval():
master_key, step_key = jax.random.split(master_key)
acc, master_key = eval_step(
orchestrator,
state,
x_b.astype(jnp.float32),
y_b.astype(jnp.float32),
rng=step_key,
t_eval=T_EVAL,
)
final_accs.append(acc)
print("\n=== Final evaluation summary ===")
print(f"Accuracy = {float(jnp.mean(jnp.array(final_accs))):.3f}")
10) Recap & next steps¶
You just trained a recurrent, locally-plastic network on MNIST using darnax:
- You declared the topology with a `LayerMap`, not a layer stack.
- A state of fixed buffers `(input, hidden, output)` was clamped and then relaxed to a fixed point by the orchestrator.
- You updated parameters using local deltas (`orchestrator.backward`) funneled through Optax.
- You JIT-compiled the outer loop only, using `lax.scan` for the inner dynamics.
If you’re serious about scaling this:
- Parallel orchestrators: swap `SequentialOrchestrator` for a parallel flavor when your graphs grow (careful with data dependencies).
- Topology as data: generate the `LayerMap` programmatically (e.g., blocks, conv-like bands).
- Per-block scalings: match initialization and learning-rate magnitudes to each path's fan-in/out.
- Profiling: dump HLO for `train_step`/`eval_step`, sanity-check fusion and shape stability (a minimal trace sketch follows).
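As a concrete starting point for that profiling, the standard JAX profiler can capture a trace of one compiled step for TensorBoard. The directory below is just an example path, `x_batch`/`y_batch` reuse the last batch from the training loop above, and the returned values are discarded because we only care about the trace:
# Minimal profiling sketch (plain jax.profiler, nothing darnax-specific).
import jax.profiler
with jax.profiler.trace("/tmp/darnax-trace"):
    train_step(
        orchestrator, state, x_batch, y_batch,
        rng=master_key, opt_state=opt_state, optimizer=optimizer, t_train=T_TRAIN,
    )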
Don’t just accept the defaults—pressure-test the schedule and the rules. If a path doesn’t pull its weight (e.g., backward clamp too weak/strong), instrument it and fix it.