
08 — Training on MNIST with darnax using DynamicalTrainer (JAX-first, CPU)

We use the same network and topology as in the previous tutorial, but this time the training and validation logic is encapsulated in an object called DynamicalTrainer, in the style of PyTorch Lightning. Other trainers are available as well.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from datasets import load_dataset

from darnax.layer_maps.sparse import LayerMap
from darnax.modules.fully_connected import FrozenFullyConnected, FullyConnected
from darnax.modules.input_output import OutputLayer
from darnax.modules.recurrent import RecurrentDiscrete
from darnax.orchestrators.sequential import SequentialOrchestrator
from darnax.states.sequential import SequentialState
from darnax.trainers.dynamical import DynamicalTrainer

if TYPE_CHECKING:
    from collections.abc import Iterator

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Accuracy helper (±1 OVA labels)

def batch_accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float:
    """Compute the accuracy for a given batch."""
    y_true_idx = jnp.argmax(y_true, axis=-1)
    y_pred_idx = jnp.argmax(y_pred, axis=-1)
    return float(jnp.mean((y_true_idx == y_pred_idx).astype(jnp.float32)))
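
As a quick illustration (the numbers here are made up), the argmax-based comparison also works for the scaled one-hot targets built below, since scaling and shifting a one-hot vector does not change its argmax:

y_true = jnp.array([[1.08, -0.5, -0.5], [-0.5, 1.08, -0.5]])  # scaled one-hot targets, 3 classes
y_pred = jnp.array([[0.9, 0.1, -0.2], [0.3, -0.1, 0.8]])  # raw readout activities
print(batch_accuracy(y_true, y_pred))  # 0.5 -- only the first sample is classified correctly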

Minimal MNIST data (projection + optional sign)

class MNISTData:
    """Simple implementation of MNIST dataset."""

    NUM_CLASSES = 10
    FLAT_DIM = 28 * 28
    TOTAL_SIZE_PER_CLASS = 5900
    TEST_SIZE_PER_CLASS = 1000

    def __init__(
        self,
        key: jax.Array,
        batch_size: int = 64,
        linear_projection: int | None = 100,
        apply_sign_transform: bool = True,
        num_images_per_class: int = TOTAL_SIZE_PER_CLASS,
    ):
        """Initialize MNIST."""
        if batch_size <= 1:
            raise ValueError("batch_size must be > 1")
        if not (0 < num_images_per_class <= self.TOTAL_SIZE_PER_CLASS):
            raise ValueError("num_images_per_class out of range")
        self.linear_projection = linear_projection
        self.apply_sign_transform = bool(apply_sign_transform)
        self.num_data = int(num_images_per_class) * self.NUM_CLASSES
        self.batch_size = int(batch_size)
        self.num_batches = -(-self.num_data // self.batch_size)
        self._build(key)
        self._train_bounds = [
            (i * self.batch_size, min((i + 1) * self.batch_size, self.num_data))
            for i in range(self.num_batches)
        ]
        self.num_eval_data = int(self.x_eval.shape[0])
        self.num_eval_batches = -(-self.num_eval_data // self.batch_size)
        self._eval_bounds = [
            (i * self.batch_size, min((i + 1) * self.batch_size, self.num_eval_data))
            for i in range(self.num_eval_batches)
        ]

    # public API
    def __iter__(self) -> Iterator[tuple[jax.Array, jax.Array]]:
        for lo, hi in self._train_bounds:
            yield self.x[lo:hi], self.y[lo:hi]

    def iter_eval(self) -> Iterator[tuple[jax.Array, jax.Array]]:
        for lo, hi in self._eval_bounds:
            yield self.x_eval[lo:hi], self.y_eval[lo:hi]

    def __len__(self) -> int:
        return self.num_batches

    # internals
    @staticmethod
    def _load(split: str) -> tuple[jax.Array, jax.Array]:
        ds = load_dataset("mnist")
        x = jnp.asarray([jnp.array(im) for im in ds[split]["image"]], dtype=jnp.float32)
        y = jnp.asarray(ds[split]["label"], dtype=jnp.int32)
        x = x.reshape(x.shape[0], -1) / 255.0
        return x, y

    @staticmethod
    def _labels_to_pm1_scaled(y_scalar: jax.Array, C: int) -> jax.Array:
        one_hot = jax.nn.one_hot(y_scalar, C, dtype=jnp.float32)
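        # Scaled one-hot: the target class maps to sqrt(C)/2 - 1/2 (~1.08 for C=10), every other class to -1/2.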
        return one_hot * (C**0.5 / 2.0) - 0.5

    @staticmethod
    def _rand_proj(key: jax.Array, out_dim: int, in_dim: int) -> jax.Array:
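        # Gaussian random projection scaled by 1/sqrt(in_dim), so projected features keep O(1) magnitude.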
        return jax.random.normal(key, (out_dim, in_dim), dtype=jnp.float32) / jnp.sqrt(in_dim)

    @staticmethod
    def _take_per_class(key: jax.Array, x: jax.Array, y: jax.Array, k: int):
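        # Sample up to k images per class without replacement, then concatenate across classes.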
        xs, ys = [], []
        for cls in range(MNISTData.NUM_CLASSES):
            key, sub = jax.random.split(key)
            idx = jnp.where(y == cls)[0]
            perm = jax.random.permutation(sub, idx.shape[0])
            take = idx[perm[:k]]
            xs.append(x[take])
            ys.append(y[take])
        return jnp.concatenate(xs, 0), jnp.concatenate(ys, 0)

    def _maybe_proj_sign(self, w: jax.Array | None, x: jax.Array) -> jax.Array:
        if w is not None:
            x = (x @ w.T).astype(jnp.float32)
        if self.apply_sign_transform:
            sgn = jnp.sign(x)
            x = jnp.where(sgn == 0, jnp.array(-1.0, dtype=sgn.dtype), sgn)
        return x

    def _build(self, key: jax.Array) -> None:
        key_sample, key_proj, key_shuffle = jax.random.split(key, 3)
        x_tr_all, y_tr_all = self._load("train")
        x_ev_all, y_ev_scalar = self._load("test")
        k_train = self.num_data // self.NUM_CLASSES
        x_tr, y_tr_scalar = self._take_per_class(key_sample, x_tr_all, y_tr_all, k_train)
        w = (
            self._rand_proj(key_proj, int(self.linear_projection), x_tr.shape[-1])
            if self.linear_projection
            else None
        )
        x_tr = self._maybe_proj_sign(w, x_tr)
        x_ev = self._maybe_proj_sign(w, x_ev_all)
        y_tr = self._labels_to_pm1_scaled(y_tr_scalar, self.NUM_CLASSES)
        perm_tr = jax.random.permutation(key_shuffle, x_tr.shape[0])
        self.x, self.y = x_tr[perm_tr], y_tr[perm_tr]
        self.x_eval = x_ev
        self.y_eval = self._labels_to_pm1_scaled(y_ev_scalar, self.NUM_CLASSES)
        self.input_dim = int(self.x.shape[1])
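
As an optional sanity check (purely illustrative, with arbitrary small parameters), you can instantiate the class on a reduced subset and inspect the shapes it produces:

probe_key = jax.random.key(0)
probe_data = MNISTData(probe_key, batch_size=32, linear_projection=100, num_images_per_class=10)
xb, yb = next(iter(probe_data))
print(probe_data.input_dim)  # 100 features after the random projection
print(xb.shape, yb.shape)  # (32, 100) inputs and (32, 10) scaled one-hot targets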

Model topology (LayerMap + Orchestrator)

One hidden layer:

- 0→1 forward (FullyConnected)
- 1→1 recurrent (RecurrentDiscrete)
- 2→1 teacher/clamp (FrozenFullyConnected)
- 1→2 readout (FullyConnected)
- 2→2 aggregator (OutputLayer)

DIM_DATA = 100
NUM_LABELS = MNISTData.NUM_CLASSES
DIM_HIDDEN = 300

THRESHOLD_OUT = 1.0
THRESHOLD_IN = 1.0
THRESHOLD_J = 1.0
STRENGTH_BACK = 0.5
STRENGTH_FORTH = 5.0
J_D = 0.5

state = SequentialState((DIM_DATA, DIM_HIDDEN, NUM_LABELS))

master_key = jax.random.key(44)
keys = jax.random.split(master_key, 5)

layer_map = {
    1: {
        0: FullyConnected(
            in_features=DIM_DATA,
            out_features=DIM_HIDDEN,
            strength=STRENGTH_FORTH,
            threshold=THRESHOLD_IN,
            key=keys[0],
        ),
        1: RecurrentDiscrete(
            features=DIM_HIDDEN,
            j_d=J_D,
            threshold=THRESHOLD_J,
            key=keys[1],
        ),
        2: FrozenFullyConnected(
            in_features=NUM_LABELS,
            out_features=DIM_HIDDEN,
            strength=STRENGTH_BACK,
            threshold=0.0,
            key=keys[2],
        ),
    },
    2: {
        1: FullyConnected(
            in_features=DIM_HIDDEN,
            out_features=NUM_LABELS,
            strength=1.0,
            threshold=THRESHOLD_OUT,
            key=keys[3],
        ),
        2: OutputLayer(),
    },
}
layer_map = LayerMap.from_dict(layer_map)
orchestrator = SequentialOrchestrator(layers=layer_map)
logger.info("Model initialized with SequentialOrchestrator.")

Optimizer & DynamicalTrainer

The trainer owns the Optax optimizer state and exposes train_step and eval_step; the PRNG key is threaded through each call.

# Experiment knobs
NUM_IMAGES_PER_CLASS = 5400
APPLY_SIGN_TRANSFORM = True
BATCH_SIZE = 16
EPOCHS = 5
T_WARMUP = 1  # warmup phase
T_FREE = 7  # free phase iterations
T_CLAMPED = 7  # clamped phase iterations
T_EVAL = 14  # validation phase

# RNGs
run_key = jax.random.key(59)
run_key, data_key = jax.random.split(run_key)

# Data
data = MNISTData(
    key=data_key,
    batch_size=BATCH_SIZE,
    linear_projection=DIM_DATA,
    apply_sign_transform=APPLY_SIGN_TRANSFORM,
    num_images_per_class=NUM_IMAGES_PER_CLASS,
)
print(f"Dataset — x.shape={tuple(data.x.shape)}, y.shape={tuple(data.y.shape)}")

# Optimizer
optimizer = optax.adam(2e-3)
optimizer_state = optimizer.init(eqx.filter(orchestrator, eqx.is_inexact_array))

# Trainer
trainer = DynamicalTrainer(
    orchestrator=orchestrator,
    state=state,
    optimizer=optimizer,
    optimizer_state=optimizer_state,
    warmup_n_iter=T_WARMUP,
    train_clamped_n_iter=T_CLAMPED,
    train_free_n_iter=T_FREE,
    eval_n_iter=T_EVAL,
)
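
Optionally, you can probe a single evaluation step on the first batch before training starts; this is just a sanity check and assumes, as one would expect of evaluation, that eval_step does not modify the learned parameters:

xb, yb = next(iter(data))
_, probe_metrics = trainer.eval_step(xb, yb, jax.random.key(0))
print(f"Untrained accuracy on one batch: {float(probe_metrics['accuracy']):.3f}")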

Fit (compact epoch loop)

logs = []
run_key, rng = jax.random.split(run_key)
logger.info("Starting training...")
for epoch in range(1, EPOCHS + 1):
    logger.info(f"{epoch=}")
    for x, y in data:
        rng = trainer.train_step(x, y, rng)
    eval_accuracies = []
    for x, y in data.iter_eval():
        rng, metrics = trainer.eval_step(x, y, rng)
        eval_accuracies.append(metrics["accuracy"])
    print(f"Eval accuracy: {sum(eval_accuracies) / len(eval_accuracies)}")

Final evaluation

print("\n=== Final evaluation ===")
for x, y in data.iter_eval():
    rng, metrics = trainer.eval_step(x, y, rng)
    eval_accuracies.append(metrics["accuracy"])
print(f"Eval accuracy: {sum(eval_accuracies) / len(eval_accuracies)}")

Recap

You kept the LayerMap/State/Orchestrator design and the local-plasticity updates, but moved the outer-loop mechanics into a thin, reusable DynamicalTrainer with a stable API (train_step, eval_step).