04 - LayerMaps
A LayerMap is a PyTree wrapper that organizes the full set of layers and adapters in your network. It provides a consistent way to index them, guarantees immutability of the structure, and integrates seamlessly with JAX/Equinox/Optax.
1. Matrix view
A LayerMap is conceptually a square matrix indexed by layer IDs:
- Diagonal (i, i): the i-th layer. Each layer is stateful and sends a message to itself (its recurrent/self term).
- Off-diagonal (i, j) with i ≠ j: an adapter. It converts the j-th state into a message for the i-th layer.
This gives the following interpretation:
- Lower triangle (i > j): forward adapters, messages flowing left → right.
- Upper triangle (i < j): backward adapters, messages flowing right → left.
- Row j: everything that contributes to layer j.
- Column i: everything that originates from the state of layer i.
- Row 0: all connections going from layers to the input (unused for now).
- Row L: all connections going from layers to the output, meaning every layer whose state is used for prediction.
- Column 0: forward skip connections from the input to each layer.
- Column L: backward skip connections from the output to each layer.
This matches the convention in SequentialState: s[0] is the input state, and s[L] is the output state.
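The triangle convention above can be sketched in a few lines of plain Python. This is only an illustration of the indexing rule: `edge_kind` and the three-slot `edges` list are hypothetical names, not part of the library.

```python
def edge_kind(i: int, j: int) -> str:
    """Classify position (i, j), where i is the receiver and j the sender."""
    if i == j:
        return "layer"  # diagonal: the layer itself (self/recurrent term)
    # lower triangle (i > j) is forward, upper triangle (i < j) is backward
    return "forward" if i > j else "backward"


# Hypothetical 3-slot network: 0 = input, 1 = hidden, 2 = output.
edges = [(0, 0), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)]
kinds = {edge: edge_kind(*edge) for edge in edges}

for (i, j), kind in kinds.items():
    print(f"({i}, {j}) -> {kind}")
```

Here (1, 0) is a forward adapter from the input slot, while (1, 2) is a backward adapter carrying the output state back to the hidden layer.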
Because this structure is ultimately a dictionary of dictionaries, it is well suited to sparsity. For example, if layer j is not connected to layer i, there is no adapter at position (i, j): the entry is simply absent, and only the relevant modules are stored.
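As a rough illustration of this sparsity, the sketch below stores a small graph as a plain dict-of-dicts, with strings standing in for real modules. The helper `has_edge` and the variable names are assumptions made for the example; they do not come from the library.

```python
# Plain strings stand in for real modules; only existing edges are stored.
raw = {
    0: {0: "layer0"},
    1: {0: "fwd_10", 1: "layer1", 2: "bwd_12"},
    2: {1: "fwd_21", 2: "layer2"},
}


def has_edge(graph: dict, i: int, j: int) -> bool:
    """True if position (i, j) holds a module, i.e. j sends a message to i."""
    return j in graph.get(i, {})


n_slots = len(raw)
stored = sum(len(row) for row in raw.values())  # entries actually present
dense = n_slots * n_slots                       # entries a full matrix would need

print(f"stored {stored} of {dense} possible entries")
print("edge (2, 0) present:", has_edge(raw, 2, 0))
```

The missing (2, 0) entry costs nothing: there is no placeholder module and no zero-filled block, just an absent key.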
2. API overview
lm = LayerMap.from_dict({...})
lm[i]: dict # read-only mapping of neighbors for row i
lm[i, j]: Module # single module at (i, j)
(i, j) in lm: bool # check if edge exists
lm.rows(): tuple[int, ...] # all row indices
lm.cols_of(i): tuple[int, ...] # all column indices in row i
lm.neighbors(i): dict[int, Module] # read-only mapping {j: module} for row i
lm.row_items(): Iterator[tuple[int, dict[int, Module]]] # iterate (row, neighbors)
lm.edge_items(): Iterator[tuple[tuple[int, int], Module]] # iterate ((i, j), module)
lm.to_dict(): dict[int, dict[int, Module]] # copy as a dict-of-dicts
The API is intentionally dict-like but read-only: once built, the structure cannot be mutated. You cannot add layers or edges later.
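To make the "dict-like but read-only" idea concrete, here is a minimal stand-in built on `types.MappingProxyType` from the standard library. `FrozenMap` is a hypothetical class written for this example; it mimics a few of the accessors listed above and says nothing about how LayerMap is actually implemented.

```python
from types import MappingProxyType


class FrozenMap:
    """Toy read-only dict-of-dicts (illustration only, not the real LayerMap)."""

    def __init__(self, raw):
        # Copy each row and wrap it in a read-only proxy.
        self._rows = MappingProxyType(
            {i: MappingProxyType(dict(row)) for i, row in raw.items()}
        )

    def __getitem__(self, key):
        if isinstance(key, tuple):  # fm[i, j] -> single module
            i, j = key
            return self._rows[i][j]
        return self._rows[key]      # fm[i] -> read-only row mapping

    def __contains__(self, key):    # (i, j) in fm
        i, j = key
        return i in self._rows and j in self._rows[i]

    def rows(self):
        return tuple(sorted(self._rows))

    def edge_items(self):
        for i in self.rows():
            for j in sorted(self._rows[i]):
                yield (i, j), self._rows[i][j]


fm = FrozenMap({0: {0: "layer0", 1: "bwd_01"}, 1: {0: "fwd_10", 1: "layer1"}})

try:
    fm[1][2] = "new_adapter"  # attempt to add an edge after construction
    mutated = True
except TypeError:
    mutated = False

print("mutation allowed:", mutated)
```

Reads work as with a normal dict, while any attempt to add an edge raises `TypeError`.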
3. Immutability of structure
A LayerMap is frozen once created:
- Row/column indices (the “shape” of the map) are part of the static treedef.
- Modules on each edge can change their parameters (via updates), but the adjacency cannot change.
This immutability is not arbitrary. It is a basic requirement in JAX:
- The shape and structure of PyTrees must be static across JIT-compiled functions.
- If you could add or remove layers after creation, the PyTree structure would change, and each change would force JAX to retrace and recompile the function instead of reusing its cached compilation.
- By freezing the structure, we ensure stability of compiled functions and allow the optimizer (Optax) to work on the entire network consistently.
Thus, in JAX, data changes are dynamic, but structure is static.
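The effect of a structural change on JIT can be observed directly. The sketch below uses a nested dict of arrays as a stand-in for a LayerMap (an assumption for illustration) and counts how often JAX traces the function; `total` and `trace_count` are names invented for the example.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def total(tree):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


# Nested dict of arrays standing in for {row: {col: parameters}}.
a = {0: {0: jnp.ones(3)}, 1: {0: jnp.ones(3), 1: jnp.ones(3)}}
b = jax.tree_util.tree_map(lambda x: 2.0 * x, a)  # new values, same structure
c = {**a, 2: {1: jnp.ones(3)}}                    # extra row: new structure

total(a)
total(b)
traces_same_structure = trace_count  # the second call reuses the compiled code

total(c)
traces_new_structure = trace_count   # the structural change triggers a retrace

same_treedef = jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)
changed_treedef = jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(c)
print(traces_same_structure, traces_new_structure, same_treedef, changed_treedef)
```

Changing values leaves the tree definition untouched, so the compiled function is reused; adding a row produces a different tree definition and a fresh trace.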
4. Example: building a simple LayerMap
import jax
from darnax.modules.recurrent import RecurrentDiscrete
from darnax.modules.adapters import Ferromagnetic
from darnax.layer_maps.sparse import LayerMap
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
F = 8
# Define two layers
layer0 = RecurrentDiscrete(features=F, j_d=0.0, threshold=0.0, key=key0)
layer1 = RecurrentDiscrete(features=F, j_d=0.0, threshold=0.0, key=key1)
# Define adapters
fwd_10 = Ferromagnetic(features=F, strength=0.5) # forward (0 -> 1)
bwd_01 = Ferromagnetic(features=F, strength=0.2) # backward (1 -> 0)
# Build dict-of-dicts
raw = {
0: {0: layer0, 1: bwd_01},
1: {0: fwd_10, 1: layer1},
}
lm = LayerMap.from_dict(raw, require_diagonal=True)
# Access
print(lm[1, 1]) # layer1
print(lm[1, 0]) # forward adapter
print(lm[1].keys()) # neighbors of row 1: {0, 1}
5. A LayerMap as a PyTree
Because LayerMap is registered as a PyTree:
- The keys (rows, columns) are static.
- The modules (layers/adapters) are the tree's children (each module is itself a PyTree).
- Arrays inside those modules are visible to JAX and Optax.
This means you can treat the entire network as a single object:
import equinox as eqx
import optax
opt = optax.adam(1e-2)
opt_state = opt.init(eqx.filter(lm, eqx.is_inexact_array))
# Later in training (grads: a PyTree of gradients matching lm's arrays)
params = eqx.filter(lm, eqx.is_inexact_array)
updates, opt_state = opt.update(grads, opt_state, params=params)
lm = eqx.apply_updates(lm, updates)
All parameters inside all layers/adapters are updated in one go.
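For readers curious how static keys and dynamic children can coexist, the following sketch registers a toy container with `jax.tree_util.register_pytree_node_class`. `TinyMap` is a hypothetical class written for this example; the real LayerMap may be registered differently, and plain arrays stand in for modules.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyMap:
    """Toy dict-of-dicts PyTree: (i, j) keys are static, values are children."""

    def __init__(self, data):
        self._data = {i: dict(row) for i, row in data.items()}

    def __getitem__(self, key):
        i, j = key
        return self._data[i][j]

    def tree_flatten(self):
        keys = tuple((i, j) for i in sorted(self._data) for j in sorted(self._data[i]))
        children = tuple(self._data[i][j] for i, j in keys)
        return children, keys  # keys go into the static treedef

    @classmethod
    def tree_unflatten(cls, keys, children):
        data = {}
        for (i, j), child in zip(keys, children):
            data.setdefault(i, {})[j] = child
        return cls(data)


tm = TinyMap({0: {0: jnp.zeros(2)}, 1: {0: jnp.ones(2), 1: jnp.ones(2)}})

# One tree_map call touches every array while the (i, j) layout is preserved.
scaled = jax.tree_util.tree_map(lambda x: 0.5 * x, tm)

leaves, treedef = jax.tree_util.tree_flatten(tm)
same_layout = treedef == jax.tree_util.tree_structure(scaled)
print(len(leaves), same_layout)
```

The same mechanism is what lets a single optimizer update reach all parameters at once: the update is mapped over the leaves, and the container is rebuilt with an identical layout.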
6. Summary
- LayerMap = a collection of layers (diagonal) and adapters (off-diagonal) with integer keys.
- Matrix view: rows = inputs to a layer, columns = outputs from a layer.
- Row and column 0 belong to the input; row and column L belong to the output.
- Immutable structure: you cannot add or remove layers once built. This ensures JAX stability (PyTree structure must be static under JIT).
- PyTree integration: treat the whole network as one object, pass it to Equinox/Optax, and every parameter is handled correctly.
This design makes LayerMap a central abstraction: a static graph of modules whose parameters evolve dynamically during training, while its topology remains fixed.
7. ASCII diagram
LayerMap (rows = receivers, columns = senders)

      columns (senders: states/messages from j) →
      0         1         2         …         L-1       L
    ┌────────────────────────────────────────────────────────────┐
  0 │ L00       A01↑      A02↑      …         A0,L-1↑   A0L↑     │
    │ (input)   back      back                back      back     │
    ├────────────────────────────────────────────────────────────┤
  1 │ A10↓      L11       A12↑      …         A1,L-1↑   A1L↑     │
    │ fwd       layer1    back                back      back     │
    ├────────────────────────────────────────────────────────────┤
  2 │ A20↓      A21↓      L22       …         A2,L-1↑   A2L↑     │
    │ fwd       fwd       layer2              back      back     │
    ├────────────────────────────────────────────────────────────┤
  … │ …         …         …         …         …         …        │
    ├────────────────────────────────────────────────────────────┤
  L │ AL0↓      AL1↓      AL2↓      …         AL,L-1↓   LL       │
    │ fwd       fwd       fwd                 fwd       (output) │
    └────────────────────────────────────────────────────────────┘
  ↑
  rows (receivers: layer i to be updated)
Legend:
- Lii : layer on the diagonal (stateful). L00 is the input-layer slot; LL is the output-layer slot.
- Aij↓ : adapter at (i,j) with i > j (lower triangle) — forward message (from j → i).
- Aij↑ : adapter at (i,j) with i < j (upper triangle) — backward message (from j → i).
Row/Column intuition:
- Row i collects everything needed to update layer i: the diagonal Lii (self-message) plus all Aij that transform state j into a message for i.
- Column j lists everything that uses state j as a source: the diagonal Ljj plus all Aij that send j’s state to other layers.
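The row/column reading can be checked against a plain dict-of-dicts. In this sketch, `incoming` and `outgoing` are hypothetical helpers and strings stand in for modules; a row lookup is a direct dictionary access, while a column lookup scans all rows.

```python
raw = {
    0: {0: "layer0"},
    1: {0: "fwd_10", 1: "layer1", 2: "bwd_12"},
    2: {0: "fwd_20", 1: "fwd_21", 2: "layer2"},
}


def incoming(graph: dict, i: int) -> dict:
    """Row i: {sender j: module} for everything that feeds layer i."""
    return dict(graph.get(i, {}))


def outgoing(graph: dict, j: int) -> dict:
    """Column j: {receiver i: module} for everything that reads state j."""
    return {i: row[j] for i, row in graph.items() if j in row}


print("row 1:", incoming(raw, 1))
print("column 0:", outgoing(raw, 0))
```

Row 1 lists the three modules needed to update layer 1, and column 0 lists every module that consumes the input state, including the forward skip connection to the output.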
Input/Output:
- First column (·,0): forward skip connections from the input state to every layer.
- Last column (·,L): backward skip connections from the output state to earlier layers.
- Last row (L,·): all contributors that feed directly into the output layer (final prediction).
Structure:
- Diagonal = layers; off-diagonal = adapters.
- The LayerMap’s structure (rows/cols and which edges exist) is immutable after creation.