Spin-Model Transformers

A non-equilibrium statistical mechanics perspective on transformers


  1. Introduction
  2. Mean-field theory of asymmetric Ising models with binary spins
    1. Setting the scene: the kinetic Ising model
    2. Mean-field theory and Kullback-Leibler divergence
    3. The Plefka expansion: interpolating distributions
    4. Naive mean-field and Thouless-Anderson-Palmer approximations
    5. A simple JAX implementation
  3. Mean-field theory of asymmetric Ising models with vector spins
    1. Vector spins: distributions on hyperspheres
    2. Magnetizations and limit of large vector dimension
    3. First-order naive mean-field approximation
    4. Second-order Thouless-Anderson-Palmer approximation
    5. A simple JAX implementation
  4. A family of transformer-like modules
    1. Connecting the dots
    2. Fast- and slow-moving parameters
    3. A simple JAX implementation
    4. Hunting for non-equilibrium steady states
  5. Conclusion
  6. Appendices
    1. Vector-spin distribution: normalization constant
    2. Vector-spin distribution: expected value (first moment)
    3. Vector-spin distribution: variance (second moment)
    4. Ratio of modified Bessel functions of the first kind
    5. General case: partial derivatives with respect to $\alpha$
  7. References & footnotes

1. Introduction

This blog post is a work in progress exploring connections between transformer neural networks and mean-field dynamics of (asymmetric) vector-spin models. JAX and PyTorch implementations of the content outlined in this post should eventually become available at mcbal/spin-model-transformers. Come join the fun.

In a series of previous blog posts, we have tried to connect the forward pass of a transformer neural-network module to computing mean magnetizations in disordered Ising-like vector-spin models with parameterized couplings and external magnetic fields. In this framework, the forward pass of a transformer module computes statistical observables given a specific realization of quenched couplings and external magnetic fields while the backward pass nudges the parameterized couplings and external magnetic fields. Physically, the transformer module represents an interacting many-body system modulating its behavior by learning to respond to being driven in all kinds of funny ways.

However, both the mean-field message-passing approach of Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms (2021) and the saddle-point free-energy approach of Transformers from Spin Models: Approximate Free Energy Minimization (2021) inherently rely on methods that are only well-defined for spin models with symmetric coupling matrices, whose stochastic dynamics obey detailed balance and converge to a steady-state equilibrium characterized by the Boltzmann distribution. Since the softmax attention matrix in transformer modules is famously asymmetric, we had better come up with a more convincing approach to establish a correspondence.

To capture spin models with asymmetric coupling matrices, we turn to non-equilibrium spin systems, whose dynamics can be pretty wild yet gentle enough to support regimes where relaxation to a non-equilibrium steady state can occur. In the past few decades, dynamical mean-field approaches have been developed for the binary kinetic Ising model, which exhibits non-equilibrium behavior when couplings are asymmetric or when parameters are subject to rapid changes. In this post, we generalize one particular dynamical mean-field approach from binary spins to vector spins with the aim of relating the resulting mean-field update equations for the magnetizations to the forward pass of a transformer module. We find that the spin-system structure is rich enough for the update equations to yield residual connections, attention terms, and feed-forward terms, motivating a physics-inspired class of transformer modules.

2. Mean-field theory of asymmetric Ising models with binary spins

In this section, we review known results on mean-field theory approaches to capturing the stochastic dynamics of binary kinetic Ising models. Readers familiar with this framework can skip to Section 3 where we develop a generalization to vector spins. We primarily follow the discussion outlined in A unifying framework for mean-field theories of asymmetric kinetic Ising systems (Aguilera et al., 2021). We implement the mean-field update equations for the mean magnetizations in JAX and run a few numerical experiments.

2.1. Setting the scene: the kinetic Ising model

Random Ising model configuration with binary spins

We consider a kinetic Ising model describing a system made up of $N$ interacting binary spins $s_{i,t} \in \{-1, 1\}$ that evolve in discrete time steps $t$ according to synchronous dynamics, i.e. all spins get updated at the same time in parallel. Given a spin configuration $\mathbf{s}_{t-1} = \{ s_{1,t-1}, s_{2,t-1}, \ldots, s_{N,t-1} \}$ at time $t-1$, we consider the spins $\mathbf{s}_{t}$ at time $t$ to be conditionally independent random variables captured by a discrete-time Markov chain transition probability

\begin{equation} P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{\sum_{s_{i,t}} \mathrm{e}^{s_{i,t} h_{i,t}}} = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{2 \cosh h_{i,t}}, \label{eq:pcond} \end{equation}

where the effective external field is given by

\begin{equation} h_{i,t} = x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1}. \end{equation}

Here, the parameters $\mathbf{x}$ represent the (possibly time-dependent) local external fields at each site while the coupling parameters $\mathbf{J}$ are a specific realization of quenched disorder encoding the interactions between pairs of spins. Using the probability mass function of the previous state $P( \mathbf{s}_{t-1} )$ we can write the distribution of the current state as

\begin{equation} P( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1}} P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P( \mathbf{s}_{t-1} ), \label{eq:marginal} \end{equation}

which, when applied recursively, traces the evolution of the system starting from some initial distribution $P( \mathbf{s}_{0} )$. Unless we turn off the couplings by setting $\mathbf{J} = \mathbf{0}$, the marginal distribution $P( \mathbf{s}_{t} )$ is not factorized and tends to be quite complicated. Our goal is to compute statistical properties of the system, such as the mean magnetizations

\begin{equation} m_{i,t} = \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t} ), \end{equation}

as well as correlations

\begin{equation} C_{ik,t} = \sum_{\mathbf{s}_{t}} s_{i,t} s_{k,t} P( \mathbf{s}_{t} ) - m_{i,t} m_{k,t}, \end{equation}

and delayed correlations

\begin{equation} D_{il,t} = \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1}} s_{i,t} s_{l,t-1} P( \mathbf{s}_{t}, \mathbf{s}_{t-1} ) - m_{i,t} m_{l,t-1}. \end{equation}

Since the above expressions involve summing over an exponentially large number of possible spin configurations, they are not very useful in practice. So we will try to approximate the tricky marginal distribution $P( \mathbf{s}_{t} )$ defined in Eq. \eqref{eq:marginal} using a mean-field theory approach.
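As an aside, the stochastic dynamics themselves are cheap to sample, which is how ground-truth estimates of these observables are typically obtained. A minimal sketch of one synchronous update in JAX (our own illustrative code, using $P(s_{i,t}=1 \vert \mathbf{s}_{t-1}) = \mathrm{e}^{h_{i,t}} / (2 \cosh h_{i,t}) = \sigma(2 h_{i,t})$ with $\sigma$ the logistic sigmoid):

import jax
import jax.numpy as jnp


def sample_step(s, key, x, J):
    # One synchronous update of all spins: P(s_i = +1 | s_prev) = sigmoid(2 h_i).
    h = x + J @ s
    p_up = jax.nn.sigmoid(2.0 * h)
    return jnp.where(jax.random.uniform(key, s.shape) < p_up, 1.0, -1.0)


# Toy example: time-average a long trajectory as a crude estimate of the magnetizations.
key = jax.random.PRNGKey(0)
N = 64
x = jnp.zeros(N)
J = jax.random.normal(key, (N, N)) / jnp.sqrt(N)
_, trajectory = jax.lax.scan(
    lambda s, k: (sample_step(s, k, x, J), s), jnp.ones(N), jax.random.split(key, 1000)
)
m_estimate = trajectory.mean(axis=0)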

2.2. Mean-field theory and Kullback-Leibler divergence

Mean-field theory tries to approximate a complicated object ${\color{red}P}$ by wiggling around the parameters of a simple, analytically tractable parameterized ansatz ${\color{green}Q_{\theta}}$ to get as close as possible to ${\color{red}P}$. At the risk of inducing headaches in mathematicians by calling everything a manifold, we can picture what is going on geometrically as trying to approximate a target probability distribution $P( \mathbf{s}_{t} \vert \mathbf{x}, \mathbf{J})$ and its statistical properties $\mathbf{m}_{t}$, $\mathbf{C}_{t}$, and $\mathbf{D}_{t}$ by restricting ourselves to a submanifold of tractable probability distributions. A particularly convenient submanifold is that of factorized models, where each point on the submanifold corresponds to a distribution parameterized by a vector $\boldsymbol{\theta}_{t}$,

\begin{equation} Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} \theta_{i,t}}}{2 \cosh \theta_{i,t}}, \label{eq:q} \end{equation}

so that the mean magnetizations are simply given by

\begin{equation} m_{i,t} = \tanh \theta_{i,t} \label{eq:meanmagstanh} \end{equation}

as there are no couplings between spins. The factorized model $Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}^{*}_{t} )$ that minimizes the Kullback-Leibler (KL) divergence

\begin{equation} D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) = \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log \frac{P( \mathbf{s}_{t})}{Q_{\theta}( \mathbf{s}_{t})} \label{eq:kl} \end{equation}

has mean magnetizations $\mathbf{m}_{t}$ identical to those of the target distribution $P( \mathbf{s}_{t})$ since, for all spins $i=1,2,\ldots,N$, we find that

\begin{align} \frac{\partial D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} &= - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \frac{\partial \log Q_{\theta}( \mathbf{s}_{t}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} \\ &= - \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t}) + \tanh \theta^{*}_{i,t} \\ &= -m^{{\color{red}P}}_{i,t} + m^{{\color{green}Q_{\theta^{*}}}}_{i,t} = 0, \label{eq:klm} \end{align}

where $m^{{\color{red}P}}_{i,t}$ and $m^{{\color{green}Q_{\theta^{*}}}}_{i,t}$ respectively denote the expectation values of $s_{i,t}$ with respect to ${\color{red}P}$ and ${\color{green}Q_{\theta^{*}}}$. Indeed, minimizing $D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}})$ tries to cover the modes of ${\color{red}P}$ by moment matching since the expectation value in Eq. \eqref{eq:kl} is calculated with respect to ${\color{red}P}$.

2.3. The Plefka expansion: interpolating distributions

Great, but is it even possible to find the parameters

\begin{equation} \DeclareMathOperator*{\argmin}{arg\,min} \boldsymbol{\theta}^{*}_{t} = \argmin_{\boldsymbol{\theta}_{t}} \left( - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log Q_{\theta}( \mathbf{s}_{t}) \right) \end{equation}

that minimize the KL divergence? Well, that’s going to be hard, unless you already know the target distribution $P( \mathbf{s}_{t})$, or you have a clever way of approximately evaluating the expectation value of $\log {\color{green}Q_{\theta}}$ with respect to ${\color{red}P}$. So let’s introduce some more distributions to get around this issue. To apply the Plefka expansion to our problem, we introduce the conditional distribution

\begin{equation} P_{\alpha}( \mathbf{s}_{t}\vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}(\alpha) }}{2 \cosh h_{i,t}(\alpha)}, \label{eq:pcondalt} \end{equation}

\begin{equation} h_{i,t}(\alpha) = (1-\alpha) \theta_{i,t} + \alpha \left( x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1} \right), \label{eq:pcondalth} \end{equation}

parameterized by a scalar $\alpha$ interpolating between $P_{\alpha=0}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} )$ (Eq. \eqref{eq:q}) and $P_{\alpha=1}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} )$ (Eq. \eqref{eq:pcond}). Using Eq. \eqref{eq:pcondalt}, we can construct an approximate marginal distribution $P_{\alpha}( \mathbf{s}_{t})$, leading to $\alpha$-dependent statistical properties $\mathbf{m}_{t}(\alpha)$, $\mathbf{C}_{t}(\alpha)$, and $\mathbf{D}_{t}(\alpha)$ for the approximate system. The Plefka expansion then boils down to writing these properties as Taylor series expansions around the factorized model $\alpha=0$. For the mean magnetizations, the expansion up to $n$-th order looks like

\begin{equation} \mathbf{m}_{t}(\alpha) = \mathbf{m}_{t}(\alpha=0) + \sum_{k=1}^{n} \frac{\alpha^k}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}), \label{eq:mtaylor} \end{equation}

where all coefficients in the expansion are functions of $\boldsymbol{\theta}_{t}$ via Eq. \eqref{eq:pcondalth}. The mean-field approximation is computed by setting $\alpha=1$ so that the original marginal distribution is recovered and Eq. \eqref{eq:klm} holds, which implies that $\mathbf{m}_{t}(\alpha=1) = \mathbf{m}_{t}(\alpha=0)$ and thus

\begin{equation} \sum_{k=1}^{n} \frac{1}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}) = 0. \label{eq:mftheta} \end{equation}

Finally, we solve Eq. \eqref{eq:mftheta} for $\boldsymbol{\theta}_{t}$ to find the mean-field values $\boldsymbol{\theta}^{*}_{t}$ of the parameters of the distribution Eq. \eqref{eq:q}. Physically, we are tuning the effective external magnetic fields of the factorized ansatz to $\boldsymbol{\theta}^{*}_{t}$ so that its approximate mean magnetizations get as close as possible to the true ones.

2.4. Naive mean-field and Thouless-Anderson-Palmer approximations

We now consider first and second order approximations of the mean magnetizations Eq. \eqref{eq:mtaylor} to recover respectively the naive mean-field and Thouless-Anderson-Palmer (TAP) approximations for the binary kinetic Ising model. The starting point is a Plefka expansion around factorized models at times $t-1$ and $t$. From Eq. \eqref{eq:marginal} and Eq. \eqref{eq:pcondalt}, we construct a marginal probability distribution

\begin{equation} P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ), \end{equation}

interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. The corresponding mean magnetizations are

\begin{align} m_{i,t}(\alpha) &= \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1},\mathbf{s}_{t-2}} s_{i,t} \, P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \\ &= \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} \tanh h_{i,t}(\alpha) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \end{align}

Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy

\begin{equation} \frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} = \left( 1-m^{2}_{i,t} \right) \left( -\theta_{i,t} + x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right) = 0, \end{equation}

so that $\theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1}$ and we end up with the naive mean-field equations:

\begin{equation} \boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right)} \label{eq:naivem} \end{equation}

Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy

\begin{equation} \frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial^{2}\alpha} = 0, \end{equation}

where the second-order derivative, neglecting terms higher than $\mathcal{O}(\alpha^2)$, is

\begin{equation} \frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial^{2}\alpha} \approx -2 m_{i,t} \left( 1-m^{2}_{i,t} \right) \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right) \end{equation}

so that

\begin{equation} \theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right) \end{equation}

and we end up with the TAP mean-field equations:

\begin{equation} \boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right) \right)} \label{eq:tapm} \end{equation}

The mean-field equations obtained above can also be elegantly derived using a Legendre transformation of the generating functional of the set of trajectories of the model, as outlined in e.g. Dynamical TAP equations for non-equilibrium Ising spin glasses (2011). We can also derive second-order TAP approximations of the correlations

\begin{equation} C_{ik,t} = \begin{cases} 1 - m^{2}_{i,t} & i = k \\ \left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{k,t} \right) \sum_{j} J_{ij} J_{kj} \left( 1-m^{2}_{j,t-1} \right) & i \neq k \label{eq:tapc} \end{cases} \end{equation}

and delayed correlations

\begin{equation} D_{il,t} = J_{il} \left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{l,t-1} \right) \left( 1 + 2 J_{il} m_{i,t} m_{l,t-1} \right). \label{eq:tapd} \end{equation}

We refer to (Aguilera et al., 2021) for full derivations of the above mean-field results as well as variations based on different approximations of the marginal distribution $P( \mathbf{s}_{t} )$.

In summary, given the mean magnetizations $\mathbf{m}_{t-1}$ of the system at time $t-1$, we can use equations \eqref{eq:tapm} \eqref{eq:tapc} \eqref{eq:tapd} to compute a tuple $(\mathbf{m}_{t},\mathbf{C}_{t},\mathbf{D}_{t})$ of approximate statistical properties of the system at time $t$. The time evolution of the system can thus be captured at the mean-field level by recursively computing $\mathbf{m}_{t}$ starting from an initial state $\mathbf{m}_{0}$, although approximation errors accumulate over the iterations.
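As a minimal sketch (our own helper, written in the same jax.numpy style as the implementation below), the correlation formulas translate into code along these lines:

import jax.numpy as jnp


def tap_correlations(m1, m0, J):
    # Equal-time correlations: diagonal 1 - m_i^2; off-diagonal
    # (1 - m_i^2)(1 - m_k^2) sum_j J_ij J_kj (1 - m_{j,t-1}^2).
    a1, a0 = 1.0 - m1**2, 1.0 - m0**2
    C = jnp.outer(a1, a1) * jnp.einsum("ij,kj,j->ik", J, J, a0)
    C = jnp.where(jnp.eye(len(m1), dtype=bool), a1, C)
    # Delayed correlations: D_il = J_il (1 - m_i^2)(1 - m_l^2)(1 + 2 J_il m_i m_l).
    D = J * jnp.outer(a1, a0) * (1.0 + 2.0 * J * jnp.outer(m1, m0))
    return C, D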

2.5. A simple JAX implementation

To get more insight into what is going on, let us turn the mean-field update equations \eqref{eq:naivem} and \eqref{eq:tapm} for the mean magnetizations into code. But before we show a few plots, we need a bit more background about the model we are about to simulate. In (Aguilera et al., 2021), the authors derive a solution of the asymmetric version of the kinetic Sherrington-Kirkpatrick mean-field spin-glass model using a generating functional or dynamical partition function approach to capture the distribution of trajectories. They consider the same kinetic Ising model as in Eq. \eqref{eq:pcond} but with an inverse temperature parameter $\beta$ in the exponentials:

\begin{equation} P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta s_{i,t} h_{i,t}}}{2 \cosh \beta h_{i,t}}. \label{eq:pcondwithbeta} \end{equation}

For Gaussian couplings $J_{ij} \sim \mathcal{N}\left( J_{\mu} / N, J^{2}_{\sigma} / N\right)$ and uniformly distributed external magnetic fields $x_{i} \sim \mathcal{U}(-X_{0}, X_{0})$, they show the existence of a ferromagnetic phase transition. In particular for $X_{0}=0.5$, $J_{\mu}=1.0$, and $J_{\sigma}=0.1$, a phase transition happens when tuning $\beta$ to a critical value $\beta_{c} \approx 1.1108$.

Simulating magnetization trajectories

We now turn to a JAX implementation of the mean-field time evolution of the magnetizations according to the model described above. We use lax.scan to implement the update steps and vmap to parallelize trajectories starting from a batch of initial magnetization configurations $\mathbf{m}_{0}$. For the second-order TAP equations, jaxopt’s Anderson acceleration is used to find the fixed point magnetizations $\mathbf{m}_{t}$ given $\mathbf{m}_{t-1}$.

from functools import partial

import jax
import jax.numpy as jnp

from jaxopt import AndersonAcceleration


def update_naive_mf(m0, _, x, J):
    # Naive mean-field step: m_t = tanh(x + J m_{t-1}).
    m1 = jnp.tanh(x + jnp.einsum("... i j, ... j -> ... i", J, m0))
    return m1, m0


def update_tap_mf(m0, _, x, J):
    # TAP step: solve the self-consistent equation for m_t given m_{t-1},
    # using Anderson acceleration with m_{t-1} as the initial guess.
    def tap(m):
        return jnp.tanh(
            x
            + jnp.einsum("... i j, ... j -> ... i", J, m0)
            - m * jnp.einsum("... i j, ... j -> ... i", J**2, (1.0 - m0**2))
        )

    m1 = AndersonAcceleration(fixed_point_fun=tap, tol=1e-3, maxiter=10).run(m0).params
    return m1, m0


def time_evolution(m0, steps, update_fun):
    # Scan the update over time; the stacked outputs record the trajectory.
    final_carry, stacked_outputs = jax.lax.scan(update_fun, init=m0, xs=steps)
    return final_carry, stacked_outputs


def init_params(key, N, beta, X0, J_mu, J_sigma):
    # Sample x_i ~ U(-X0, X0) and J_ij ~ N(J_mu / N, J_sigma^2 / N),
    # absorbing the inverse temperature beta into both parameters.
    x_key, J_key = jax.random.split(key)
    x = jax.random.uniform(x_key, shape=(N,), minval=-beta * X0, maxval=beta * X0)
    J = beta * J_mu * N**-1 + beta * J_sigma * N**-0.5 * jax.random.normal(
        J_key, shape=(N, N)
    )
    return x, J


def simulate(
    key, m0, steps, beta, X0=0.5, J_mu=1.0, J_sigma=0.1, update_fun=update_tap_mf
):
    # Evolve a batch of initial magnetization configurations in parallel with vmap.
    x, J = init_params(key, m0.shape[-1], beta, X0, J_mu, J_sigma)
    wrapped_time_evolution = partial(
        time_evolution,
        steps=steps,
        update_fun=partial(update_fun, x=x, J=J),
    )
    final_carry, stacked_outputs = jax.vmap(wrapped_time_evolution)(m0)
    return final_carry, stacked_outputs
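A quick usage sketch matching the experiment described next (the toy invocation and shapes are our own):

key = jax.random.PRNGKey(0)
m0 = jnp.ones((1, 512))  # batch containing one all-ones initial state
final_m, trajectory = simulate(key, m0, steps=jnp.arange(128), beta=1.1108)
# trajectory.shape == (1, 128, 512): (batch, time, spins)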

Naive mean-field vs. Thouless-Anderson-Palmer (TAP)

We fix the seed for the randomly initialized model parameters $\mathbf{x}$ and $\mathbf{J}$ and simulate $N=512$ spins at the critical inverse temperature $\beta_{c}$ for $t=128$ time steps starting from an all-ones initial state. We first consider the naive mean-field update step.

The left axis shows the individual magnetization trajectories for each spin plotted horizontally, while the red line associated with the right axis shows the average of the magnetizations across all spins at each time step. We observe convergence to what looks like a steady state.

Comparing the naive first-order mean-field update equations to the second-order Thouless-Anderson-Palmer (TAP) ones, we observe lower values for the mean magnetization across all spins, which (Aguilera et al., 2021) showed to be closer to the ground-truth values (not shown) obtained by sampling and averaging spin configurations.

Sampling trajectories

Let’s now consider 100 random initial states and simulate their associated trajectories in three different model regimes: far below the critical point ($\beta=\beta_c / 2$), at the critical point ($\beta=\beta_c$), and far above the critical point ($\beta=2 \beta_c$).

We observe that trajectories starting from random initial states converge to identical final states in each regime. These final states map to a simple ferromagnetic Ising phase diagram, where a high-temperature disordered phase $\langle m_{i,t} \rangle \to 0$ (left) is separated from a low-temperature locally-ordered phase $\langle m_{i,t} \rangle \to \pm 1$ (right) by a critical point (center). The behavior around $\beta=\beta_{c}$ is pretty interesting: the non-trivial non-equilibrium steady state looks like an attractor implicitly encoded in the dynamics of the model. If we were to parametrize the couplings, we could train the model to act as an associative memory.

Sampling model parameters

Let us now go back to considering just a single trajectory, since we just saw that trajectories seem to converge to the same final steady-state magnetizations for fixed model parameters. To get a feel for the variation of these values across different realizations of model parameters, we plot the absolute value¹ $| \langle m_{i} \rangle |$ of the final steady-state magnetizations across 100 samples of model parameters and a range of inverse temperatures. We are using JAX, so we can easily sample model parameters by vmap-ing the random key fed into the simulate function, followed by another vmap to sweep across $\beta$.

Every curve in the above plot describes the final steady-state value of the “order parameter” $| \langle m_{i} \rangle |$ for a fixed set of model parameters sweeping across $\beta$. We observe a greater spread of values near the critical point and hence an improved capacity to map input external fields to a range of output magnetizations. If we were to let the number of spins $N \to \infty$ and average over a large number of model parameter samples, the finite-size results above would probably transform into a sharp curve with zero magnetization below the critical point and a sudden non-zero magnetization emerging at the critical point.
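Schematically, the double vmap might look as follows (a sketch with our own toy sizes, using the naive update to keep it light):

def order_parameter(key, beta):
    # Steady-state |<m_i>| for one parameter realization at one inverse temperature.
    m0 = jnp.ones((1, 128))
    final_m, _ = simulate(
        key, m0, steps=jnp.arange(64), beta=beta, update_fun=update_naive_mf
    )
    return jnp.abs(final_m.mean())


keys = jax.random.split(jax.random.PRNGKey(42), 100)
betas = jnp.linspace(0.5, 2.0, 32)
# Inner vmap sweeps beta for fixed parameters; outer vmap samples parameter realizations.
sweep = jax.vmap(lambda k: jax.vmap(lambda b: order_parameter(k, b))(betas))(keys)
# sweep.shape == (100, 32)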

3. Mean-field theory of asymmetric Ising models with vector spins

We now transpose the binary-spin results of the previous section to a setting where local spin degrees of freedom are $D$-dimensional vector spins restricted to wiggle around on $(D-1)$-dimensional spheres. We start by generalizing the conditional distribution Eq. \eqref{eq:pcondalt} to vector spins. Next, we motivate the limit of large vector dimension and derive first-order and second-order mean-field update equations for the mean magnetizations. We finish this section with a JAX implementation and some numerical experiments.

3.1. Vector spins: distributions on hyperspheres

Random Ising model configuration with vector spins

A vector-spin equivalent of Eq. \eqref{eq:pcondalt} looks like

\begin{equation} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s}_{i,t} \; \mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)} }, \label{eq:pcondaltvector} \end{equation}

where we immediately included an inverse temperature scaling $\beta$ like in Eq. \eqref{eq:pcondwithbeta}. A vector-spin equivalent of Eq. \eqref{eq:pcondalth} is

\begin{equation} \mathbf{h}_{i,t}(\alpha) = (1-\alpha) \boldsymbol{\theta}_{i,t} + \alpha \left( \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} \right) \equiv \boldsymbol{\theta}_{i,t} + \alpha \Delta \mathbf{h}_{i,t}, \label{eq:pcondalthvector} \end{equation}

where $S_{D-1}(R) = \{ x \in \mathbb{R}^{D} : \lVert x \rVert = R \}$ denotes the $(D-1)$-dimensional sphere with radius $R$ embedded in $D$ dimensions. Let us focus on the distribution for a single site and drop all subscripts and dependencies for clarity:

\begin{equation} p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \label{eq:pcondsinglesitevector} \end{equation}

The normalization constant in the denominator can be shown to be (see Appendix A.1)

\begin{equation} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } \equiv Z(\beta, R, \lVert \mathbf{h}\rVert) \label{eq:partfun} \end{equation}

where $I_{\nu}(z)$ denotes the modified Bessel function of the first kind and $\lVert \mathbf{h}\rVert = \sqrt{\mathbf{h} \cdot \mathbf{h}}$. Physically, we can think of this single-site distribution as measuring dot-product alignment to an effective external magnetic field $\mathbf{h}$ at inverse temperature $\beta$.

If we consider spins living on the unit sphere $R=1$ as well as unit vectors $\mathbf{h}$, the distribution boils down to a von Mises–Fisher distribution with mean direction $\boldsymbol{\mu} \equiv \mathbf{h}$ and concentration parameter $\kappa \equiv \beta$. This distribution is unimodal for $\kappa > 0$ and can be derived from restricting an isotropic multivariate Gaussian to the unit hypersphere. The greater the value of $\kappa$ (the inverse temperature $\beta$), the higher the concentration of the distribution around the mean direction $\boldsymbol{\mu}$ (the more the spin tends to align to the effective external field $\mathbf{h}$). Instead of a fixed parameter $\boldsymbol{\mu}$, we have a very funky parameter Eq. \eqref{eq:pcondalthvector} that depends on all other spins to spice things up.
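For intuition, this single-site distribution is easy to play with numerically. A small sketch using SciPy's von Mises–Fisher implementation (available since SciPy 1.11; the toy parameters are our own):

import numpy as np
from scipy.stats import vonmises_fisher

mu = np.array([1.0, 0.0, 0.0])  # mean direction, playing the role of a unit h
kappa = 5.0  # concentration, playing the role of beta
samples = vonmises_fisher(mu, kappa).rvs(10_000, random_state=0)
# The empirical mean points along mu and sharpens as kappa (beta) grows.
print(samples.mean(axis=0))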

3.2. Magnetizations and limit of large vector dimension

Before we derive mean-field approximations for the mean magnetizations of our vector-spin system, let us first consider the decoupled $\alpha \to 0$ limit of the distribution Eq. \eqref{eq:pcondaltvector},

\begin{equation} Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \boldsymbol{\theta}_{i,t}}}{Z_{i,t}\left(\beta, R, \lVert \boldsymbol{\theta}_{i,t} \rVert\right)}, \end{equation}

and find an expression for its mean magnetizations. For every decoupled site, the mean magnetization can be shown to be (see Appendix A.2)

\begin{equation} \mathbf{m}_{i,t} = \frac{I_{D/2}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)}{I_{D/2 - 1}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)} \frac{R \boldsymbol{\theta}_{i,t}}{\lVert \boldsymbol{\theta}_{i,t} \rVert} \equiv \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right), \label{eq:meanmagsbessels} \end{equation}

which plays the role of $m_{i,t} = \tanh \theta_{i,t}$ in the binary setting, see Eq. \eqref{eq:meanmagstanh}. Looking ahead at turning the above equation into code, we note that there exist efficient algorithms to compute the ratio of modified Bessel functions of the first kind. We implement a fast JAX version in Appendix A.4 and show numerically how the ratio flattens out quickly for large values of the order $\nu = D/2 -1$, motivating some kind of large-order expansion.

As the vector dimension in practical transformer modules tends to be somewhere between $\mathcal{O}(10^2)$ and $\mathcal{O}(10^5)$, it makes sense to focus on the limit of large vector dimension. A relevant uniform asymptotic expansion of the ratio of modified Bessel functions of the first kind can be found in (Kiefer & Weiss, 1972):

\begin{align} \frac{I_{\nu+\alpha}(\nu x)}{I_{\nu}(\nu x)} = \left( \frac{x}{1+\sqrt{1+x^2}} \right)^{\alpha} \left( 1 - \frac{1+\alpha\sqrt{1+x^2}}{2(1+x^2)} \frac{\alpha}{\nu} + \mathcal{O}\left( \frac{1}{\nu^2} \right) \right) \end{align}

Indeed, if we choose to tie the radius $R$ of our little spins to their vector dimension $D$ via

\begin{align} \nu=D/2-1=R^2, \end{align}

we can apply the leading order of the asymptotic expansion to \eqref{eq:meanmagsbessels} to find

\begin{equation} \mathbf{m}^{D \to \infty}_{i,t} \approx \frac{\beta}{1+\sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }} \boldsymbol{\theta}_{i,t} \equiv \boldsymbol{\varphi}^{D \to \infty} \left(\boldsymbol{\theta}_{i,t}\right). \label{eq:largedevmag} \end{equation}

From here on, we will default to using the large-$D$ approximation because keeping track of (derivatives of) Bessel functions gets boring real quick. We refer to Appendix A.5 for some truly outrageous expressions pertaining to the general case valid for all $D>1$.

3.3. First-order naive mean-field approximation

Closely mimicking the binary case, we start from the following approximate marginal probability distribution

\begin{equation} P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ), \end{equation}

interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. Our lazy integral notation $\int \mathrm{d} \mathbf{s}_{t}$ should be understood as $\int \prod_{i=1}^{N} \mathrm{d}^{D} \mathbf{s}_{i, t}$, i.e. integrating over all the little spins at a fixed time $t$. The estimated mean magnetizations are

\begin{align} \mathbf{m}_{i,t}(\alpha) &= \int \mathrm{d} \mathbf{s}_{t} \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \mathbf{s}_{i,t} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \nonumber\\ &= \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ). \end{align}

The first-order derivative with respect to $\alpha$ is given by

\begin{align} \frac{\partial \mathbf{m}_{i,t}(\alpha)}{\partial\alpha} = \int &\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\ &+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \Biggr) P( \mathbf{s}_{t-2} ), \label{eq:mitfirstorderalpha} \end{align}

where

\begin{align} \frac{\partial \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} = - & \frac{\beta}{R^2} \frac{ \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) }{ \sqrt{1+\beta^2 \lVert \mathbf{h}_{i,t}(\alpha) \rVert^2 / R^2 } } \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\ &+ \frac{\beta}{1+\sqrt{1+\beta^2 \lVert \mathbf{h}_{i,t}(\alpha) \rVert^2 / R^2 }} \Delta \mathbf{h}_{i,t} \label{eq:firstorderphialpha} \end{align}

Evaluating \eqref{eq:mitfirstorderalpha} at $\alpha=0$, the second term drops out: $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(0)\right) = \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right)$ no longer depends on $\mathbf{s}_{t-1}$ and can be pulled out of the integral, which then vanishes since the normalization $\int \mathrm{d} \mathbf{s}_{t-1} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=1$ implies $\int \mathrm{d} \mathbf{s}_{t-1} \, \partial_{\alpha} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=0$. We thus end up with

\begin{align} \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = - &\frac{\beta}{R^2}\frac{\left( \mathbf{m}_{i,t} \cdot \boldsymbol{v}_{i,t} \right)}{\sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }} \mathbf{m}_{i,t} \nonumber \\ &+ \frac{\beta}{1+\sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }}\boldsymbol{v}_{i,t} \label{eq:mfirstorderalphazero} \end{align}

where

\begin{align} \boldsymbol{v}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{m}_{j,t-1} \label{eq:vmf} \end{align}

captures the result of integrating $\Delta \mathbf{h}_{i,t}$ over the spins $\mathbf{s}_{t-1}$. Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy

\begin{equation} \left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^2\right)\right]_{\alpha=1},\label{eq:firstorderapproxreqs} \end{equation}

so that we are encouraged to set $\boldsymbol{v}_{i,t}=0$ and hence $\boldsymbol{\theta}^{*}_{i,t} = \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}$, leading to the naive mean-field equations:

\begin{equation} \boxed{ \mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \rVert^2 / R^2 }} } \label{eq:naivemvector} \end{equation}

Looking ahead at the transformer-module correspondence in Section 4, we squint our eyes and recognize a scaled sum of a residual connection and an attention term. No feed-forward terms though.

Before moving on to the second-order approximation, let us end this section with an interesting observation about Eq. \eqref{eq:mfirstorderalphazero}. In Appendix A.3, we show that the variance matrix of a single spin in the large-$D$ limit equals a rank-1 perturbation of a diagonal matrix

\begin{align} \mathrm{Var} [ \mathbf{s}_{i,t} ] &= \frac{\mathbb{1}}{1+\gamma(\mathbf{h}_{i,t})} - \frac{ \boldsymbol{\varphi}(\mathbf{h}_{i,t}) \otimes \boldsymbol{\varphi}(\mathbf{h}_{i,t}) }{ R^2 \gamma(\mathbf{h}_{i,t}) }, \label{eq:spinvariance} \end{align}

where

\begin{align} \gamma(\mathbf{h}_{i,t}) = \sqrt{1+\beta^{2}\lVert\mathbf{h}_{i,t}\rVert^{2}/R^2}. \end{align}

Taking the $\alpha \to 0$ limit of the above expressions, we can reinterpret Eq. \eqref{eq:mfirstorderalphazero} as the matrix-vector multiplication of the decoupled spin’s variance matrix with $\boldsymbol{v}_{i,t}$,

\begin{align} \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = \beta \mathrm{Var} [ \mathbf{s}_{i,t} ] \boldsymbol{v}_{i,t}. \end{align}

Let us now try to find out whether going to the second-order approximation generates additional feed-forward-like terms in the update equations for the magnetizations.

3.4. Second-order Thouless-Anderson-Palmer approximation

Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy

\begin{equation} \left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} + \left[ \frac{\alpha^2}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha}\right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^3\right)\right]_{\alpha=1}, \label{eq:secondorderconstraint} \end{equation}

where the second-order derivative is given by

\begin{align} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha)}{\partial^{2}\alpha} = \int &\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial^{2}\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial^{2}\alpha} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\ &+ 2\frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \nonumber \\ &+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial^{2} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial^{2}\alpha} \Biggr) P( \mathbf{s}_{t-2} ). \label{eq:mhasecordder} \end{align}

Evaluated at $\alpha=0$, the third term in the expression above drops out for the same reason as before: $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(0)\right)$ no longer depends on $\mathbf{s}_{t-1}$, and $\int \mathrm{d} \mathbf{s}_{t-1} \, \partial^{2}_{\alpha} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=0$ by normalization. Letting

\begin{align} \gamma (\alpha) \equiv \gamma(\mathbf{h}_{i,t}(\alpha)) = \sqrt{1+\beta^2 \lVert \mathbf{h}_{i,t}(\alpha) \rVert^2 / R^2 }, \end{align}

the first term in Eq. \eqref{eq:mhasecordder} can be shown to look something like

\begin{align} \frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = & \frac{\beta^2}{R^4} \frac{ 1+\gamma(\alpha) }{ \gamma(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\ &- \frac{\beta}{R^2} \frac{1}{\gamma(\alpha)} \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\ &- \frac{\beta}{R^2} \frac{1}{\gamma(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \nonumber \\ &- \frac{\beta^2}{R^2} \frac{1}{\gamma(\alpha)^2 + \gamma(\alpha) } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t}, \end{align}

which, after substituting the first-order derivative Eq. \eqref{eq:firstorderphialpha}, simplifies to

\begin{align} \frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = & \frac{\beta^2}{R^4} \frac{ 1+3\gamma(\alpha) }{ \gamma(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\ &- \frac{\beta^2}{R^2} \frac{1}{\gamma(\alpha)^2 + \gamma(\alpha)} \left( \Delta \mathbf{h}_{i,t} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\ &- \frac{\beta^2}{R^2} \frac{2}{\gamma(\alpha)^2 + \gamma(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t} . \label{eq:secondorderphialpha} \end{align}

The second term in Eq. \eqref{eq:mhasecordder} contains non-vanishing contributions in the $\alpha \to 0$ limit coming from the $\sum_{j=1}^{N} J_{ij} \mathbf{s}_{j, t-1}$ terms in $\Delta \mathbf{h}_{i,t}$. The surviving terms in the integrand are proportional to

\begin{align} \sum_{j} J_{ij} \Biggl( &\frac{2 \beta^2}{1+\gamma(\alpha)} \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \nonumber \\ &- \frac{2 \beta^2}{R^2 \gamma(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \Biggr), \end{align}

which we can ignore since they are $\mathcal{O}(\alpha)$ on their own, and thus $\mathcal{O}(\alpha^3)$ when multiplied by $\alpha^2$ in the second-order approximation.

Before taking the $\alpha \to 0$ limit of whatever is left in Eq. \eqref{eq:mhasecordder}, we list a few useful tricks to make the evaluation easier. First of all, we use Eq. \eqref{eq:vmf} to introduce the following sneaky substitution

\begin{align} \Delta \mathbf{h}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} = \boldsymbol{v}_{i,t} + \sum_{j=1}^{N} J_{ij} \left( \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \right), \end{align}

which conveniently separates terms with fluctuating spin variables from magnetizations that can be pulled out of the integrals. Secondly, all terms that contain only one spin variable with a dependence looking like $\mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1}$ drop out because, schematically,

\begin{align} \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \overset{\int \mathrm{d} \mathbf{s}_{t-1}}{\to} \boldsymbol{\varphi}(\mathbf{h}_{j,t-1}(\alpha)) - \mathbf{m}_{j,t-1} \overset{\alpha \to 0}{\to} \mathbf{0}. \end{align}

Thirdly, since the $\alpha \to 0$ limit decouples all spins $\mathbf{s}_{t-1}$, any term containing dot products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \cdot (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$ of two spin variables is zero for $j \neq k$ and equal to $R^2 - \mathbf{m}^2_{j,t-1}$ for $j=k$. We will also encounter terms containing (tensor contractions with) outer products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \otimes (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$, which we can think of as projection operators. For $j \neq k$, these and similar terms again evaluate to zero, while, for $j=k$, we get the variance contributions we mentioned previously in Eq. \eqref{eq:spinvariance}.

Finally, we take the $\alpha \to 0$ limit of Eq. \eqref{eq:mhasecordder} only to end up with the following mess:

\begin{align} &\qquad \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha} = \nonumber \\ &\frac{\beta^2}{R^4} \frac{1+3\gamma(0)}{\gamma(0)^3} \left( \left( \mathbf{m}_{i,t} \cdot \mathbf{v}_{i,t} \right)^2 + \sum_{j} J^{2}_{ij} \left( \frac{\mathbf{m}_{i,t}^2}{1+\gamma(0)} - \frac{\left(\mathbf{m}_{i,t}\cdot\mathbf{m}_{j,t-1}\right)^2}{R^2 \gamma(0)} \right) \right) \mathbf{m}_{i,t} \nonumber \\ &- \frac{\beta^2}{R^2} \frac{1}{\gamma^2 (0) + \gamma(0)} \left( \mathbf{v}_{i,t}^2 + \sum_{j} J^{2}_{ij} \left( R^2 - \mathbf{m}_{j,t-1}^2 \right) \right) \mathbf{m}_{i,t} \nonumber \\ &- \frac{\beta^2}{R^2} \frac{2}{\gamma^2 (0) + \gamma(0)} \left( \mathbf{v}_{i,t} \otimes \mathbf{v}_{i,t} + \sum_{j} J^{2}_{ij} \left( \frac{\mathbb{1}}{1+\gamma(0)} - \frac{\mathbf{m}_{j,t-1}\otimes\mathbf{m}_{j,t-1}}{R^2 \gamma(0)} \right) \right) \mathbf{m}_{i,t} \nonumber \end{align}

At this point, it is already too late. We should have remembered that the second-order approximation lives in the neighborhood of the first-order approximation. We probably ended up doing too much work by taking terms into account that are of higher order than $\mathcal{O}(\alpha^2)$. To arrive at the second-order mean-field equations for the magnetizations, we now only have to solve Eq. \eqref{eq:secondorderconstraint} for $\boldsymbol{\theta}^{*}_{i,t}$,

\begin{equation} \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right). \end{equation}

Let us substitute $\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha}$ from Eq. \eqref{eq:mfirstorderalphazero} but keep $\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha}$ for generality,

\begin{align} \beta \left( \frac{\mathbb{1}}{1+\gamma(0)} - \frac{\mathbf{m}_{i,t}\otimes\mathbf{m}_{i,t}}{R^2 \gamma(0)} \right) \mathbf{v}_{i,t} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right), \end{align}

so that we can then isolate $\boldsymbol{\theta}_{i,t}$ in $\mathbf{v}_{i,t}$ to find

\begin{align} \boldsymbol{\theta}_{i,t} = \mathbf{x}_{i,t} &+ \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \nonumber \\ &+ \frac{1+\gamma(0)}{2\beta} \left( \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha} + \frac{\mathbf{m}_{i,t} \cdot \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial^{2}\alpha}}{\frac{R^2 \gamma(0)}{1+\gamma(0)} - \mathbf{m}_{i,t}^2} \mathbf{m}_{i,t} \right), \end{align}

where we have used the Sherman–Morrison formula, $\left(A + \mathbf{u}\mathbf{v}^{T}\right)^{-1} = A^{-1} - \frac{A^{-1}\mathbf{u}\mathbf{v}^{T}A^{-1}}{1+\mathbf{v}^{T}A^{-1}\mathbf{u}}$, to compute the inverse of the variance matrix. Since the expression on the right-hand side also depends on $\boldsymbol{\theta}_{i,t}$, we seem to have stumbled upon a set of fixed-point equations

\begin{align} \boldsymbol{\theta}_{i,t} = \mathbf{f} (\boldsymbol{\theta}_{i,t}, \mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1}), \label{eq:thetafp} \end{align}

which we should solve to find $\boldsymbol{\theta}^{*}_{i,t}$. The second-order mean-field equations then become yet another set of fixed-point equations

\begin{equation} \boxed{\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left(\boldsymbol{\theta}^{*}_{i,t}(\mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1})\right) } \label{eq:tapmvector} \end{equation}

Like in the TAP approximation Eq. \eqref{eq:tapm} for the binary case, the dependence of $\boldsymbol{\theta}^{*}_{i,t}$ on $\mathbf{m}_{i,t}$ indicates that we should solve for fixed-point magnetizations $\mathbf{m}^{*}_{i,t}$. However, in contrast to the binary case, the dependence here turns out to be implicit since $\boldsymbol{\theta}^{*}_{i,t}$ is obtained from solving the fixed-point equations Eq. \eqref{eq:thetafp} which depend on $\mathbf{m}_{i,t}$ as well. Numerically, we should thus try to jointly optimize for the fixed-point tuples $(\mathbf{m}^{*}_{i,t}, \boldsymbol{\theta}^{*}_{i,t})$ such that the following pair of equations holds for every site:

\begin{align} \boldsymbol{\theta}^{*}_{i,t} &= \mathbf{f} (\boldsymbol{\theta}^{*}_{i,t}, \mathbf{x}_{i,t}, \mathbf{m}^{*}_{i,t}, \mathbf{m}_{t-1}) \\ \mathbf{m}^{*}_{i,t} &= \boldsymbol{\varphi} \left( \boldsymbol{\theta}^{*}_{i,t}(\mathbf{x}_{i,t}, \mathbf{m}^{*}_{i,t}, \mathbf{m}_{t-1}) \right) \end{align}

Looking ahead at the transformer-module correspondence in Section 4, we recognize a scaled sum of a residual connection, an attention term, and a self-consistent expression in terms of magnetizations and couplings playing the role of the feed-forward network. These feed-forward-like correction terms are determined by the structure of the vector-spin model’s mean-field approximation and require no additional free parameters.

3.5. A simple JAX implementation

We now turn to a JAX implementation of the mean-field time evolution of the magnetizations according to the vector-spin model introduced in the previous sections.

$\ldots$

import jax

$\ldots$
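In the meantime, here is a minimal sketch of what the first-order update Eq. \eqref{eq:naivemvector} might look like, mirroring the binary-spin code of Section 2.5 (function names and toy shapes are our own, not the repository's API):

import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    # Large-D magnetization map: beta * theta / (1 + sqrt(1 + beta^2 |theta|^2 / R^2)).
    norm_sq = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * norm_sq / R**2))


def update_naive_mf_vector(m0, _, x, J, beta, R):
    # Naive mean-field step: theta_i = x_i + sum_j J_ij m_{j,t-1}, then m_i = phi(theta_i).
    theta = x + jnp.einsum("i j, j d -> i d", J, m0)
    return phi(theta, beta, R), m0


# Toy example: N = 8 spins of dimension D = 64 evolved for 32 steps.
key = jax.random.PRNGKey(0)
N, D = 8, 64
R = jnp.sqrt(D / 2 - 1)  # tie the spin radius to the vector dimension as in Section 3.2
x_key, J_key = jax.random.split(key)
x = jax.random.normal(x_key, (N, D))
J = jax.random.normal(J_key, (N, N)) / N
m_final, _ = jax.lax.scan(
    lambda m, t: update_naive_mf_vector(m, t, x, J, beta=1.0, R=R),
    init=jnp.zeros((N, D)),
    xs=jnp.arange(32),
)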

4. A family of transformer-like modules

$\ldots$

4.1. Connecting the dots

$\ldots$

4.2. Fast- and slow-moving parameters

As mentioned previously in Transformers Are Secretly Collectives of Spin Systems (2021), each example in a batch of sequential data can be thought of as probing the spin system in a particular way. Physically, the fast-moving parameterized couplings $\mathbf{J}(\mathbf{x})$ are determined by the fast-moving parameterized external fields $\mathbf{x}$, which in turn depend on the magnetizations of the previous layer and ultimately on the input data. The external fields act as an environment of patterns that gets transformed instantly into the values of the coupling matrix, effectively inducing some kind of state of quenched disorder. The slow-moving parameters are those receiving gradient updates during training, e.g. the query-key matrices in the softmax couplings. On the level of the spin-model transformer module, training can be understood as shaping the input-dependent distribution of coupling parameters by amassing information from a huge number of quenched-disorder realizations, sculpting a spin glass with data.

4.3. A simple JAX implementation

$\ldots$

4.4. Hunting for non-equilibrium steady states

$\ldots$

5. Conclusion

$\ldots$

Appendices

A.1. Vector-spin distribution: normalization constant

We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:

\begin{equation} p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \end{equation}

Let $Z(\beta, R, \mathbf{h})=\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}$. We switch to $D$-dimensional spherical coordinates to make our life easier and use rotational symmetry to choose the polar axis parallel to $\mathbf{h}$,

\begin{equation} Z(\beta, R, h) = R^{D-1} \int_{\Omega} \mathrm{d}^{D-2} \Omega \int_{0}^{\pi} \mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta , \end{equation}

where $h=\lVert\mathbf{h}\rVert$ and where $\int_{\Omega} \mathrm{d}^{D-2} \Omega$ represents the integral over all other spherical angles, which coincides with the surface area of the unit sphere in $D-1$ dimensions,

\begin{equation} S_{D-1} = \frac{2\pi^{\frac{D-1}{2}}}{\Gamma\left( \frac{D-1}{2} \right)}, \end{equation}

so that

\begin{equation} Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{0}^{\pi} \mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta . \end{equation}

If we now let $u = \cos \theta$, then

\begin{equation} Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{-1}^{1} \mathrm{d}u \; \mathrm{e}^{\beta R h u } \left(1 - u^2\right)^{(D-3)/2} . \end{equation}

Recognizing an integral representation of the modified Bessel function of the first kind,

\begin{equation} I_{\nu}(z) = \frac{2^{-\nu}}{\sqrt{\pi}\, \Gamma\left(\nu+\frac{1}{2}\right)} z^{\nu} \int_{-1}^{1} \mathrm{d}t \; \mathrm{e}^{\pm zt} \left(1-t^2\right)^{\nu-\frac{1}{2}}, \end{equation}

we identify $\nu = D/2 - 1$ and $z = \beta R h$ to find

\begin{equation} Z(\beta, R, h) = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R h) }{ \left(\beta h\right)^{D/2-1} }. \end{equation}
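As a quick sanity check (our own, not part of the derivation), the closed form can be compared against a brute-force Monte Carlo estimate of the surface integral:

from math import gamma

import numpy as np
from scipy.special import iv


def Z_exact(beta, R, h, D):
    # Closed form with h = ||h||.
    return (2 * np.pi * R) ** (D / 2) * iv(D / 2 - 1, beta * R * h) / (beta * h) ** (D / 2 - 1)


def Z_monte_carlo(beta, R, h_vec, n=200_000, seed=0):
    D = h_vec.shape[0]
    rng = np.random.default_rng(seed)
    s = rng.normal(size=(n, D))
    s = R * s / np.linalg.norm(s, axis=1, keepdims=True)  # uniform samples on the radius-R sphere
    area = 2 * np.pi ** (D / 2) / gamma(D / 2) * R ** (D - 1)  # its surface area
    return area * np.exp(beta * s @ h_vec).mean()


h_vec = np.array([0.3, -0.5, 0.2, 0.7])
print(Z_exact(1.2, 1.0, np.linalg.norm(h_vec), D=4), Z_monte_carlo(1.2, 1.0, h_vec))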

A.2. Vector-spin distribution: expected value (first moment)

We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:

\begin{equation} p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \end{equation}

Starting from the expression of the normalization constant Eq. \eqref{eq:partfun},

\begin{equation} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) , \end{equation}

we write the expected value as

\begin{equation} \mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathbf{s} \, \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} \end{equation}

so that

\begin{align} \mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert\mathbf{h} \rVert) }{ \left(\beta \lVert\mathbf{h}\rVert \right)^{D/2-1} } \right) \end{align}

which evaluates to

\begin{align} \mathbb{E}_{p} [ \mathbf{s} ] = \left( \frac{I'_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} - \frac{ D/2-1 }{ \beta R \lVert\mathbf{h}\rVert} \right) \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}. \end{align}

Using the modified Bessel function recurrence relations,

\begin{align} I_{\nu-1}(z) - I_{\nu+1}(z) &= \frac{2\nu}{z} I_{\nu}(z), \label{eq:irecurr}\\ I_{\nu-1}(z) + I_{\nu+1}(z) &= 2 I'_{\nu}(z), \label{eq:irecurrderiv} \end{align}

we end up with

\begin{align} \mathbb{E}_{p} [ \mathbf{s} ] = \frac{I_{D/2}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}\equiv \boldsymbol{\varphi} (\mathbf{h}). \label{eq:app:expectedvalue} \end{align}

A.3. Vector-spin distribution: variance (second moment)

TODO: Add variance for general case.

We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:

\begin{equation} p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \end{equation}

Using the expression of the normalization constant Eq. \eqref{eq:partfun},

\begin{equation} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) , \end{equation}

we write the symmetric outer-product variance matrix as

\begin{align} \mathrm{Var}_{p} [ \mathbf{s} ] &= \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} \, ( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])^{T} \\ &= \frac{1}{\beta^2 Z} \frac{ \partial^2 }{ \partial \mathbf{h} \partial \mathbf{h}^{T} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T}, \end{align}

so that

\begin{align} \mathrm{Var}_{p} [ \mathbf{s} ] &= \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( Z \mathbb{E}_{p} [ \mathbf{s} ]^{T} \right) - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T}, \\ &= \frac{1}{\beta} \frac{ \partial }{ \partial \mathbf{h} } \mathbb{E}_{p} [ \mathbf{s} ]^{T}, \end{align}

which evaluates to

\begin{align} \mathrm{Var}_{p} [ \mathbf{s} ] &= \ldots \label{eq:app:var} \end{align}

for the general case with the expected value given by Eq. \eqref{eq:app:expectedvalue} and to

\begin{align} \mathrm{Var}_{p} [ \mathbf{s} ] &= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\beta^2\mathbf{h} \otimes \mathbf{h}}{R^2\gamma(\mathbf{h})\left(1+\gamma(\mathbf{h})\right)^2}\\ &= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\boldsymbol{\varphi} (\mathbf{h}) \otimes \boldsymbol{\varphi}(\mathbf{h})}{R^2\gamma(\mathbf{h})} \end{align}

for the large-$D$ limit with the expected value given by Eq. \eqref{eq:largedevmag}, where

\begin{align} \gamma(\mathbf{h}) = \sqrt{1+\beta^{2}\lVert\mathbf{h}\rVert^{2}/R^2}. \end{align}

A.4. Ratio of modified Bessel functions of the first kind

To compute the ratio $I_{\nu+1}(x) / I_{\nu}(x)$ of modified Bessel functions of the first kind for $\nu \geq 0$ and $x \geq 0$, we implement a JAX version of the algorithm described in (Amos, 1974). A pseudocode implementation can be found in (Kurz et al., 2013). We compare our implementation against explicitly calculating the ratio using scipy.special.ive across a range of orders $\nu$ for several different values of $x$ to get a feel for its behavior.

We observe a satisfying agreement between the two approaches. For $x=\sqrt{\nu}$, the ratio takes on very small values for large orders. For $x=\nu^2$, the opposite happens and we see saturation. The case $x=\nu$ seems to sit in between, which suggests it might be opportune to fix the radius of our little spins to $R=\sqrt{D}$ so that, with $\lVert\mathbf{h}\rVert \sim \mathcal{O}(\sqrt{D})$, we might maximize the “sensitivity” of the expected value. In this regime, we can get away with known asymptotic expansions for large $\nu$ given that the ratio flattens out quickly.
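For reference, the scipy baseline mentioned above boils down to a ratio of exponentially scaled Bessel functions, since the exponential factors cancel; something along these lines (our own sketch; note that for very large orders even ive underflows, which is what motivates the dedicated ratio algorithm):

import numpy as np
from scipy.special import ive


def bessel_ratio(nu, x):
    # I_{nu+1}(x) / I_nu(x); ive = iv * exp(-|x|) avoids overflow at large arguments.
    return ive(nu + 1, x) / ive(nu, x)


nus = np.logspace(0, 2, 50)
for label, xs in [("x = sqrt(nu)", np.sqrt(nus)), ("x = nu", nus), ("x = nu^2", nus**2)]:
    print(label, bessel_ratio(nus, xs)[:3])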

A.5. General case: partial derivatives with respect to $\alpha$

TODO: Clean up and verify.

We are interested in computing the first-order and second-order derivative with respect to $\alpha$ of the function

\begin{equation} \boldsymbol{\varphi}(\mathbf{h}(\alpha)) = \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \mathbf{h}(\alpha)}{\lVert \mathbf{h}(\alpha) \rVert}, \end{equation}

where $\mathbf{h}(\alpha) = \boldsymbol{\theta} + \alpha \Delta \mathbf{h}$. Using

\begin{equation} \frac{\partial \lVert \mathbf{h}(\alpha) \rVert}{\partial\alpha} = \frac{\mathbf{h}(\alpha) \cdot \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \end{equation}

and Eqs. \eqref{eq:irecurr}-\eqref{eq:irecurrderiv}, we find

\begin{align} \frac{\partial \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} = \beta &\lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\ &+ \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \label{eq:generalgradalphafirstorder} \end{align}

where

\begin{equation} \lambda_{D} (x) = \frac{I^2_{D/2-1}(x)}{I^2_{D/2}(x)} - \frac{D}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} - 1. \label{eq:app:lambda} \end{equation}

For the second-order derivative, we need to slog through even more tedious algebra,

\begin{align} \frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial^2\alpha} = \beta &\frac{\partial}{\partial\alpha}\biggl( \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \biggr) \nonumber \\ &+ \frac{\partial}{\partial\alpha}\biggl( \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \biggr) , \end{align}

which eventually leads to something like

\begin{align} \frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial^2\alpha} = -2\beta^2 & \, \kappa_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right)^{2} \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\ &+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\ &+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \nonumber \\ &- \frac{D}{\lVert \mathbf{h}(\alpha) \rVert^2} \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \Delta \mathbf{h} , \label{eq:generalgradalphasecondorder} \end{align}

where

\begin{align} \kappa_{D} (x) = \lambda^2_{D} (x) + \left( 1 + \frac{D/2 + 1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} \right) \lambda_{D} (x) + \frac{1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)}. \end{align}

Equation \eqref{eq:generalgradalphasecondorder} can be further simplified by substituting the first-order derivative Eq. \eqref{eq:generalgradalphafirstorder} and further simplifying the resulting expression. The derivation of the mean-field equations proceeds in a similar fashion as in the main text, but uses \eqref{eq:generalgradalphafirstorder} and \eqref{eq:generalgradalphasecondorder} as expressions for the partial derivatives instead of their large-$D$ approximations.

Another useful derivative is that of the single-site probability distribution \eqref{eq:pcondsinglesitevector},

\begin{align} \frac{\partial}{\partial\alpha} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) = \frac{\partial}{\partial\mathbf{h}(\alpha)} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) \cdot \Delta \mathbf{h}, \end{align}

which evaluates to

\begin{align} \beta \left( \mathbf{s} - \boldsymbol{\varphi}\left(\mathbf{h}(\alpha)\right) \right) \cdot \Delta \mathbf{h} \frac{ \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} }{ \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \end{align}

and can be used to calculate derivatives of the conditional distribution \eqref{eq:pcondaltvector}.

References & footnotes

A non-exhaustive list of references and inspiration includes:

  • M. Aguilera, S.A. Moosavi, and H. Shimazaki, A unifying framework for mean-field theories of asymmetric kinetic Ising systems, Nat Commun 12, 1197 (2021) https://arxiv.org/abs/2002.04309

  • Y. Roudi and J. Hertz, Dynamical TAP equations for non-equilibrium Ising spin glasses, J. Stat. Mech., P03031 (2011) https://arxiv.org/abs/1103.1044

  • H.J. Kappen and J.J. Spanjers, Mean field theory for asymmetric neural networks, Phys. Rev. E 61, 5658 (2000)

  • G. Parisi, Asymmetric neural networks and the process of learning, J. Phys. A: Math. Gen. 19 L675 (1986)


If you happen to find this work useful, please consider citing it as:

@article{bal2023spinmodeltransformers,
  title   = {Spin-Model Transformers},
  author  = {Bal, Matthias},
  year    = {2023},
  month   = {?},
  url     = {https://mcbal.github.io/post/spin-model-transformers}
}

1. We plot the absolute value to get rid of artificial “jumps” between the two branches. These occur because all models are simulated independently when sweeping across $\beta$, and some combinations of initial state and model parameters might just happen to bounce to the other branch when $\beta$ changes in the $\beta > \beta_c$ regime. ↩︎
