<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Non-Equilibrium Dynamics | mcbal</title><link>https://mcbal.github.io/tags/non-equilibrium-dynamics/</link><atom:link href="https://mcbal.github.io/tags/non-equilibrium-dynamics/index.xml" rel="self" type="application/rss+xml"/><description>Non-Equilibrium Dynamics</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-gb</language><lastBuildDate>Mon, 02 Feb 2026 09:28:17 +0100</lastBuildDate><image><url>https://mcbal.github.io/media/icon.svg</url><title>Non-Equilibrium Dynamics</title><link>https://mcbal.github.io/tags/non-equilibrium-dynamics/</link></image><item><title>Entropy Production in Non-Equilibrium Neural Networks</title><link>https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/</link><pubDate>Mon, 02 Feb 2026 09:28:17 +0100</pubDate><guid>https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/</guid><description>&lt;p&gt;&lt;a title="Walter Baxter / A murmuration of starlings at Gretna" href="https://commons.wikimedia.org/wiki/File:Starling_murmuration.jpg"&gt;&lt;img width="512" alt="A murmuration of starlings at Gretna" src="https://upload.wikimedia.org/wikipedia/commons/8/8d/Starling_murmuration.jpg?20150218191823"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;&lt;p align="center"&gt;This project is a work in progress (open research)&lt;/p&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 id="introduction"&gt;Introduction&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Modern large-scale autoregressive language models are impressive systems-engineering artifacts. Yet they are frozen, with no apparent notion of dynamics unfolding over time. Surfacing in-context learning at inference time through prompt and environment engineering mitigates the fact that these models are temporal only insofar as the information inside their context windows matches patterns observed during consecutive offline training stages. Time, and its dynamic memory affordances, is in a sense amortized or compressed away. This incentivizes models to over-rely on storing relevant patterns in parametric memory instead of sculpting latent low-dimensional shapes that support stable dynamic computation. This has implications for online continual learning, adaptive model deployment, and real-time closed-loop interaction with live systems.&lt;/p&gt;
&lt;p&gt;In this post, we take the notion of treating neural networks as non-equilibrium thermodynamic systems seriously. We design a physics-inspired transformer module with adaptable couplings and memory parameters based on the naive mean-field dynamics of a class of vector-spin models introduced in
. The underlying mean-field spin-model interpretation enables us to write down an expression for the entropy production, a thermodynamic quantity measuring &amp;ldquo;instantaneous&amp;rdquo; irreversibility by quantifying the asymmetry between forward and backward time steps.&lt;/p&gt;
&lt;p&gt;Since every operation in our spin-transformer module is differentiable, entropy production can be made into a loss function. For example, maximizing entropy production incentivizes the system to &lt;em&gt;lean into the external drive&lt;/em&gt; by nudging its parameters to dump entropy as fast as possible in a way that maximizes uncertainty given constraints. Internally, we imagine the system reshaping itself into ordered structures to enable more efficient dissipation of the internal tension caused by the incoming data stream.&lt;/p&gt;
&lt;h2 id="background-and-intuitions"&gt;Background and intuitions&lt;/h2&gt;
&lt;p&gt;We
consider transformer modules as differentiable driven disordered vector-spin systems whose mean-field collective behavior we can control through training, and refer to
going back to
for earlier instantiations of this intuition. According to our correspondence, the forward pass of a transformer module implements a spin system&amp;rsquo;s response to getting probed, where &lt;em&gt;inputs&lt;/em&gt; map to time-varying applied external fields, &lt;em&gt;asymmetric, sparse attention matrices&lt;/em&gt; can be identified with fully-connected spin-spin interactions, and &lt;em&gt;outputs&lt;/em&gt; map to spin expectation values or magnetizations. Practically, the forward pass of a spin-transformer module can be designed to mimic that of a vanilla transformer module.&lt;/p&gt;
&lt;p&gt;In contrast to physics-oriented literature, we do not specify explicit probability distributions for the external fields and couplings of the disordered many-body system, nor are we interested in Nobel-prize-winning ways to average out the disorder. We instead focus on the very specific quenched disorder realizations induced by a dataset or environment of interest (encoded as sequences of vector embeddings), whose examples we use to drive the system. In this framing, training a transformer module corresponds to sculpting the underlying system&amp;rsquo;s collective response by tuning the parametrized distributions of its external fields and couplings.&lt;/p&gt;
&lt;img src="spin_transformer_module_fwd_bwd.png" alt="Forward and backward pass illustration" width="500px"/&gt;
&lt;p&gt;In
, we observed that these systems tend to settle into non-equilibrium steady states as dynamic sweet spots where the &amp;ldquo;continuous kicking&amp;rdquo; of the inputs (applied external fields) &amp;ldquo;sustains&amp;rdquo; the outputs (magnetizations). This negotiation process tends to converge within just a few iterations. The first iteration already gives a decent guess. This might explain (1) why transformers can get away with stacking modules whose forward passes take just one time step, and (2) why doing a few time steps can improve performance, as done in recursive reasoning approaches. Indeed, repeating the same module can be seen as allowing the underlying non-equilibrium system to settle more snugly into its steady state for that particular configuration of inputs and parameters. However, as soon as the input sequence changes or the parameters are updated, the system has to renegotiate a different steady state, one compatible with the response its current configuration dictates.&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="non-equilibrium-neural-networks"&gt;Non-equilibrium neural networks&lt;/h2&gt;
&lt;h3 id="example-model"&gt;Example model&lt;/h3&gt;
&lt;p&gt;When designing neural networks around mean-field vector-spin models, there is a lot of freedom. First, we must decide which mean-field approximation to use for our spin system. Projecting the dynamics onto different ansatz distributions leads to different mean-field equations. These take into account more or fewer correlations at different time steps. In this post, we choose the simplest option: a first-order &lt;code&gt;Plefka[t-1,t]&lt;/code&gt; approximation. From
, we recall&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \rVert^2 / R^2 }},
\end{equation}&lt;p&gt;where $\mathbf{m}_{i,t} \in \mathbb{R}^{D}$ denote the magnetizations (outputs) at time $t$, $\mathbf{x}_{i,t} \in \mathbb{R}^{D}$ denote the applied external fields (inputs) at time $t$, $J_{ij}$ are the couplings, $\beta$ is an inverse temperature, and $R=\sqrt{D/2 -1}$ is a natural length scale resulting from the large-$D$ approximation we used to avoid dealing with Bessel functions.&lt;/p&gt;
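&lt;p&gt;A minimal JAX sketch of this update, for a single synchronous time step of $N$ vector spins in $\mathbb{R}^{D}$, is given below. The function name &lt;code&gt;naive_mf_step&lt;/code&gt;, the dense Gaussian couplings, and the random inputs are illustrative assumptions rather than the implementation used in this post, and we assume $D$ is larger than 2 so that $R$ is positive. Writing $\boldsymbol{\theta}_{i,t} = \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}$, the denominator grows like $\beta \lVert \boldsymbol{\theta}_{i,t} \rVert / R$ for strong effective fields, so every magnetization norm stays strictly below $R$. For weak fields, the update reduces to $\mathbf{m}_{i,t} \approx \beta \boldsymbol{\theta}_{i,t} / 2$.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def naive_mf_step(x, J, m_prev, beta):
    """One naive mean-field (first-order Plefka[t-1,t]) update for N vector spins.

    x: (N, D) applied external fields at time t
    J: (N, N) couplings
    m_prev: (N, D) magnetizations at time t-1
    """
    D = x.shape[-1]
    R = jnp.sqrt(D / 2 - 1)  # large-D length scale; assumes D is larger than 2
    theta = x + J @ m_prev  # effective fields x_i + sum_j J_ij m_j
    theta_norm2 = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1 + jnp.sqrt(1 + beta**2 * theta_norm2 / R**2))


# Usage: a few synchronous time steps with fixed random inputs and couplings.
key_x, key_J = jax.random.split(jax.random.PRNGKey(0))
N, D, beta = 4, 16, 1.0
x = jax.random.normal(key_x, (N, D))
J = jax.random.normal(key_J, (N, N)) / jnp.sqrt(N)
m = jnp.zeros((N, D))
for _ in range(5):
    m = naive_mf_step(x, J, m, beta)
print(jnp.linalg.norm(m, axis=-1))  # each norm stays below R = sqrt(7) for D = 16
```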
&lt;p&gt;If we now consider &lt;em&gt;parametrized input-dependent couplings&lt;/em&gt;&lt;/p&gt;
\begin{equation}
\mathbf{J} (\mathbf{x}) = \mathrm{softmax}\left( \mathbf{x} \boldsymbol{Q} \boldsymbol{K}^{T} \mathbf{x}^{T} \right), \label{eq:softmax}
\end{equation}&lt;p&gt;and augment the applied external fields with a &lt;em&gt;parametrized input-dependent memory&lt;/em&gt;,&lt;/p&gt;
\begin{equation}
\mathbf{x}_{i,t} \to \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right),
\end{equation}&lt;p&gt;then our forward pass looks like&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right) + \sum_{j} J_{ij} (\mathbf{x}_{t}) \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right) + \sum_{j} J_{ij} (\mathbf{x}_{t}) \mathbf{m}_{j,t-1} \rVert^2 / R^2 }},
\end{equation}&lt;p&gt;which resembles a parallel transformer block as introduced in GPT-J and used in PaLM, with the notable difference that the &amp;ldquo;values&amp;rdquo; here correspond to the outputs (magnetizations) of the previous time step instead of some linear transformation applied to the inputs at the current time step. Making the applied external fields as well as the couplings input-dependent leads to a &lt;em&gt;highly adaptive system&lt;/em&gt; where the interaction landscape itself is dynamically shaped by the inputs.&lt;/p&gt;
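&lt;p&gt;The forward pass above can be sketched in JAX as follows. This is a hedged illustration that rests on several assumptions. It uses a single attention head and no causal mask. It applies no scaling inside the softmax, following Eq. \eqref{eq:softmax} as written. A two-layer GELU MLP stands in for $\mathrm{FFN}$. The names &lt;code&gt;init_params&lt;/code&gt;, &lt;code&gt;ffn&lt;/code&gt;, and &lt;code&gt;spin_transformer_step&lt;/code&gt; are hypothetical. Note how the couplings multiply the previous magnetizations &lt;code&gt;m_prev&lt;/code&gt; rather than a value projection of the current inputs.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def init_params(key, D, H):
    """Hypothetical parameters: query/key projections and a two-layer FFN."""
    kq, kk, k1, k2 = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(D)
    return {
        "Q": jax.random.normal(kq, (D, D)) * scale,
        "K": jax.random.normal(kk, (D, D)) * scale,
        "W1": jax.random.normal(k1, (D, H)) * scale,
        "W2": jax.random.normal(k2, (H, D)) / jnp.sqrt(H),
    }


def ffn(params, x):
    """Assumed memory term: a two-layer GELU MLP applied per position."""
    return jax.nn.gelu(x @ params["W1"]) @ params["W2"]


def spin_transformer_step(params, x, m_prev, beta=1.0):
    """One time step: x is (N, D) inputs at time t, m_prev is (N, D) magnetizations at t-1."""
    D = x.shape[-1]
    R = jnp.sqrt(D / 2 - 1)
    # Input-dependent couplings J(x) = softmax(x Q K^T x^T), normalized over each row.
    J = jax.nn.softmax(x @ params["Q"] @ params["K"].T @ x.T, axis=-1)
    # Applied fields augmented with an input-dependent memory term.
    field = x + ffn(params, x)
    # The "values" are the previous magnetizations, not a projection of x.
    theta = field + J @ m_prev
    theta_norm2 = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1 + jnp.sqrt(1 + beta**2 * theta_norm2 / R**2)), J


key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
N, D, H = 5, 32, 64
params = init_params(key_p, D, H)
x = jax.random.normal(key_x, (N, D))
m, J = spin_transformer_step(params, x, jnp.zeros((N, D)))
print(m.shape, J.sum(axis=-1))  # (5, 32) and rows of J summing to 1
```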
&lt;p&gt;We can choose to have our module keep track of the previous state so that one forward pass corresponds to taking a single time step. If we care more about the steady state, we can also immediately compute the fixed point of the time evolution using a differentiable fixed-point solver, in which case one forward pass corresponds to jumping to the time-evolution fixed point. The latter approach is reminiscent of deep equilibrium models and certain recursive reasoning approaches.&lt;/p&gt;
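&lt;p&gt;For the fixed-point variant, the simplest (and slowest) solver is plain forward iteration until the update stops changing the magnetizations. The sketch below assumes fixed inputs and couplings and uses &lt;code&gt;jax.lax.while_loop&lt;/code&gt;. The names &lt;code&gt;mf_update&lt;/code&gt; and &lt;code&gt;steady_state&lt;/code&gt; are hypothetical. JAX does not support reverse-mode differentiation through &lt;code&gt;jax.lax.while_loop&lt;/code&gt;. This is one reason deep equilibrium models pair a root finder with implicit differentiation for the backward pass, which is not shown here. Consider row-stochastic couplings such as softmax attention. There, the update shrinks the largest per-site distance between two states by at least a factor $\beta/2$. So for $\beta$ below 2 the iteration is a contraction and converges to a unique steady state.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def mf_update(m_prev, x, J, beta):
    """Naive mean-field vector-spin update for fixed inputs x (N, D) and couplings J (N, N)."""
    R = jnp.sqrt(x.shape[-1] / 2 - 1)
    theta = x + J @ m_prev
    theta_norm2 = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1 + jnp.sqrt(1 + beta**2 * theta_norm2 / R**2))


def steady_state(x, J, beta, tol=1e-5, max_iter=500):
    """Jump to the fixed point m* = f(m*) by plain forward iteration."""

    def cond(state):
        m, m_old, it = state
        residual = jnp.max(jnp.abs(m - m_old))
        # Keep iterating while the residual exceeds tol and the budget is not spent.
        return jnp.logical_and(jnp.greater(residual, tol), jnp.less(it, max_iter))

    def body(state):
        m, _, it = state
        return mf_update(m, x, J, beta), m, it + 1

    m0 = jnp.zeros_like(x)
    init = (mf_update(m0, x, J, beta), m0, jnp.array(1, dtype=jnp.int32))
    m_star, _, n_iter = jax.lax.while_loop(cond, body, init)
    return m_star, n_iter


# Usage: row-stochastic (softmax) couplings and beta = 1.
key_x, key_J = jax.random.split(jax.random.PRNGKey(1))
N, D = 6, 16
x = jax.random.normal(key_x, (N, D))
J = jax.nn.softmax(jax.random.normal(key_J, (N, N)), axis=-1)
m_star, n_iter = steady_state(x, J, beta=1.0)
print(n_iter, jnp.max(jnp.abs(mf_update(m_star, x, J, 1.0) - m_star)))
```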
&lt;h3 id="entropy-production"&gt;Entropy production&lt;/h3&gt;
&lt;p&gt;Following
, the entropy production for the kinetic Ising model, assuming a non-equilibrium steady state, is given by&lt;/p&gt;
\begin{equation}
\sigma_{t} = \sum_{ij} \left(J_{ij} - J_{ji}\right) D_{ij,t} \geq 0,
\end{equation}&lt;p&gt;where $J_{ij}$ corresponds to the couplings and $D_{ij,t}$ denotes the time-delayed correlations. If we write this down for the vector-spin case,&lt;/p&gt;
\begin{equation}
D_{ij,t} = \int \mathrm{d} \mathbf{s}_{t} \int \mathrm{d} \mathbf{s}_{t-1} \; \left( \mathbf{s}_{i,t} - \mathbf{m}_{i,t} \right) \cdot \left( \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1}\right) \; P( \mathbf{s}_{t}, \mathbf{s}_{t-1} ),
\end{equation}&lt;p&gt;we can compute a first-order &lt;code&gt;Plefka[t-1,t]&lt;/code&gt; mean-field approximation for the time-delayed correlations, similar to the computations we did previously for the magnetizations in
, leading to something like&lt;/p&gt;
\begin{align}
D_{ij,t} = &amp;\frac{\beta J_{ij}}{1+\gamma_{i,t}} \left(R^2 - \mathbf{m}_{j,t-1}^2 \right) \nonumber\\\\
&amp;- \frac{\beta J_{ij}}{R^2 \gamma_{i,t} \left( 1 + \gamma_{j,t-1} \right)} \mathbf{m}_{i,t}^2 \nonumber\\\\
&amp;+ \frac{\beta J_{ij}}{R^4 \gamma_{i,t} \gamma_{j,t-1}} \left( \mathbf{m}_{i,t} \cdot \mathbf{m}_{j,t-1} \right)^2,
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma_{i,t} &amp;= \sqrt{1 + \beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 } \\\\
\boldsymbol{\theta}_{i,t} &amp;= \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}
\end{align}&lt;h3 id="vibe-check"&gt;Vibe check&lt;/h3&gt;
&lt;p&gt;Let us try to get a feel for what the entropy production looks like for vector-spin models using some rough back-of-the-envelope estimates. If both vectors $\mathbf{m}_{i,t}$ and $\mathbf{m}_{j,t-1}$ have norms of order $\mathcal{O}(R)$, the time-delayed correlations behave approximately like&lt;/p&gt;
\begin{align}
D_{ij,t} \sim J_{ij} \cos^2 \alpha_{(i,t)(j,t-1)},
\end{align}&lt;p&gt;where $\alpha_{(i,t)(j,t-1)}$ denotes the angle between the magnetization vectors. So the entropy production looks approximately like&lt;/p&gt;
\begin{equation}
\sigma_{t} \sim \sum_{ij} \left(J_{ij}^2 - J_{ij} J_{ji}\right) \cos^2 \alpha_{(i,t)(j,t-1)},
\end{equation}&lt;p&gt;which, in general, is minimized for symmetric coupling matrices or orthogonal embeddings and maximized for fully-asymmetric couplings or (anti-)parallel embeddings.&lt;/p&gt;
&lt;p&gt;But for the softmax attention matrix Eq. \eqref{eq:softmax}, we have the additional constraints $J_{ij} \geq 0$ and $\sum_{j} J_{ij} = 1$. As a result, the Frobenius norm is bounded by $\sqrt{N}$, which prevents unbounded growth under maximization. Imposing a causal mask on the couplings to do autoregressive modeling adds even more constraints, since the strictly upper-triangular part of $J_{ij}$ is then fixed to zero. So it feels like maximizing entropy production for causal softmax couplings promotes some kind of compromise between &lt;em&gt;sparse attention&lt;/em&gt; and &lt;em&gt;clustering of embeddings&lt;/em&gt; (weighted maximization of cosine similarity). The intuition for sparse attention runs as follows. With the strictly upper-triangular part zero, the $J_{ij} J_{ji}$ terms vanish. Each row $i$ then contributes $\sum_{j&lt;i} J_{ij}^2 \cos^2 \alpha_{(i,t)(j,t-1)}$. Under the row-sum constraint, this is largest when the row concentrates its weight on a single earlier position.&lt;/p&gt;
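&lt;p&gt;These back-of-the-envelope estimates are cheap to check numerically. The JAX sketch below evaluates the rough estimate of $\sigma_{t}$ from the vibe check, not the full first-order expression for $D_{ij,t}$. It uses hypothetical helpers &lt;code&gt;cos2&lt;/code&gt; and &lt;code&gt;sigma_estimate&lt;/code&gt;. With all embeddings parallel, symmetric couplings give zero. For causal row-normalized couplings with $N=4$, uniform rows give $1/4 + 2/9 + 3/16 \approx 0.66$. One-hot rows reach the maximum of one per row after the first, i.e. 3.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def cos2(m_t, m_prev):
    """Squared cosines between m_{i,t} and m_{j,t-1} for all pairs (i, j)."""
    a = m_t / jnp.linalg.norm(m_t, axis=-1, keepdims=True)
    b = m_prev / jnp.linalg.norm(m_prev, axis=-1, keepdims=True)
    return (a @ b.T) ** 2


def sigma_estimate(J, m_t, m_prev):
    """Rough estimate sigma_t ~ sum_ij (J_ij^2 - J_ij J_ji) cos^2(alpha_ij)."""
    return jnp.sum((J**2 - J * J.T) * cos2(m_t, m_prev))


N, D = 4, 8
m = jnp.ones((N, D))  # parallel embeddings, so every cos^2 equals 1

# Symmetric couplings: the two terms cancel pairwise.
A = jax.random.uniform(jax.random.PRNGKey(0), (N, N))
J_sym = 0.5 * (A + A.T)
print(sigma_estimate(J_sym, m, m))  # 0 up to rounding

# Causal, row-normalized couplings: uniform rows versus one-hot rows.
mask = jnp.tril(jnp.ones((N, N)))
J_uniform = mask / mask.sum(axis=-1, keepdims=True)
J_onehot = jnp.zeros((N, N)).at[:, 0].set(1.0)  # every row attends to position 0
print(sigma_estimate(J_uniform, m, m))  # 1/4 + 2/9 + 3/16, roughly 0.66
print(sigma_estimate(J_onehot, m, m))  # roughly 3: one per row after the first
```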
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="experiments"&gt;Experiments&lt;/h2&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="model-behavior-in-a-noisy-environment"&gt;Model behavior in a noisy environment&lt;/h3&gt;
&lt;p&gt;Interfaces, sensors and effectors.&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="global-coherence-from-local-backpropagation"&gt;Global coherence from local backpropagation&lt;/h3&gt;
&lt;p&gt;We test a stack of spin-transformer modules in a toy femtoscale online learning setup and try to see if we can make
when maximizing per-layer entropy-production losses &lt;em&gt;independently&lt;/em&gt;. If we detach module outputs after applying each layer, we end up with systems communicating via their input/output interfaces but without gradients backpropagating through the whole stack. (It seems pretty unlikely, though, that the entropy-production losses on their own provide enough signal.)&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="growing-network-topologies"&gt;Growing network topologies&lt;/h3&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="discussion-and-related-work"&gt;Discussion and related work&lt;/h2&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="references"&gt;References&lt;/h2&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
by
Miguel Aguilera, S. Amin Moosavi, and Hideaki Shimazaki&lt;/li&gt;
&lt;li&gt;
by Jacob Mitchell Gold&lt;/li&gt;
&lt;li&gt;
by Giovanni Pezzulo and Michael Levin&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2026,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Entropy Production in Non-Equilibrium Neural Networks},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2026},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {?},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h1 id="footnotes"&gt;Footnotes&lt;/h1&gt;</description></item><item><title>Spin-Model Transformers</title><link>https://mcbal.github.io/post/spin-model-transformers/</link><pubDate>Sun, 19 Jun 2022 09:28:17 +0100</pubDate><guid>https://mcbal.github.io/post/spin-model-transformers/</guid><description>&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TL;DR:&lt;/strong&gt; &lt;em&gt;We interpret and implement transformer modules as driven, disordered vector-spin models whose response behavior can be shaped by learning parameterized interactions, gradually steering a cascade of near-equilibrium steady-state magnetizations towards solving a given objective. Using dynamical mean-field theory, we show that a first-order approximation of the update equations for the magnetizations reproduces residual and attention terms. Going to second-order adds explicit expressions for feed-forward-like correction terms that are fully determined by the mean-field structure of the underlying spin model. By blending ideas from deep learning and statistical mechanics, we hope our work can help open up broader interdisciplinary bridges to improve our understanding of learning and generalization in transformer neural networks.&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In a series of previous
, we have tried to connect the forward pass of a transformer neural-network module to computing mean magnetizations in disordered Ising-like vector-spin models with parameterized couplings and external magnetic fields. According to this perspective, the forward pass of a transformer module can be understood as computing statistical observables given a specific realization of quenched couplings and external magnetic fields while the backward pass nudges the parameterized couplings and external magnetic fields. Physically, the transformer module represents an interacting many-body system modulating its behavior by learning to respond to being probed and driven in all kinds of ways.&lt;/p&gt;
&lt;p&gt;However, both the mean-field message-passing approach of
and the saddle-point free-energy approach of
inherently rely on methods that are only well-defined for spin models with symmetric coupling matrices, whose stochastic dynamics obey detailed balance and converge to a steady-state equilibrium characterized by the Boltzmann distribution. The softmax attention matrix in transformers is famously asymmetric though, so we had better come up with a more convincing approach to establish a correspondence.&lt;/p&gt;
&lt;p&gt;To capture spin models with asymmetric coupling matrices, we turn to non-equilibrium spin systems, whose dynamics can be pretty wild yet gentle enough to support regimes where relaxation to a non-equilibrium or near-equilibrium steady state can occur. In the past few decades, dynamical mean-field approaches have been developed for the binary kinetic Ising model, which exhibits non-equilibrium behavior for asymmetric couplings or when parameters are subject to rapid changes.&lt;/p&gt;
&lt;p&gt;In this post, we generalize a particular dynamical mean-field approach from binary spins to vector spins and relate the resulting mean-field update equations for the magnetizations to the forward pass of a transformer module. We find that the spin-model structure is rich enough for the update equations to yield residual connections, attention terms, and feed-forward correction terms, motivating a family of physics-inspired transformers.&lt;/p&gt;
&lt;h1 id="mean-field-theory-of-asymmetric-ising-models-with-binary-spins"&gt;Mean-field theory of asymmetric Ising models with binary spins&lt;/h1&gt;
&lt;p&gt;In this preliminary section, we review known results on mean-field theory approaches capturing the stochastic dynamics of binary kinetic Ising models. Readers familiar with this framework can skip ahead to
where we develop a generalization to vector spins. We primarily follow the discussion outlined in
. At the end of the section, we implement the mean-field update equations for the mean magnetizations in JAX and run a few numerical experiments.&lt;/p&gt;
&lt;h2 id="setting-the-scene-the-kinetic-ising-model"&gt;Setting the scene: the kinetic Ising model&lt;/h2&gt;
&lt;img src="binary_spins.png" alt="Random Ising model configuration with binary spins" width="250px"/&gt;
&lt;p&gt;We consider a kinetic Ising model describing a system made up of $N$ interacting binary spins $s_{i,t} \in \{-1, 1\}$ that evolve in discrete time steps $t$ according to synchronous dynamics, i.e. all spins get updated at the same time in parallel. Given a configuration $\mathbf{s}_{t-1} = \{ s_{1,t-1}, s_{2,t-1}, \ldots, s_{N,t-1} \}$ at time $t-1$, we consider the spins $\mathbf{s}_{t}$ at time $t$ to be conditionally independent random variables captured by a discrete-time Markov chain transition probability&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{\sum_{s_{i,t}} \mathrm{e}^{s_{i,t} h_{i,t}}} = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{2 \cosh h_{i,t}}, \label{eq:pcond}
\end{equation}&lt;p&gt;where the effective external field is given by&lt;/p&gt;
\begin{equation}
h_{i,t} = x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1}.
\end{equation}&lt;p&gt;Here, the parameters $\mathbf{x}$ represent the (possibly time-dependent) local external fields at each site while the coupling parameters $\mathbf{J}$ are a specific realization of quenched disorder encoding the interactions between pairs of spins. Using the probability mass function of the previous state $P( \mathbf{s}_{t-1} )$ we can write the distribution of the current state as&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1}} P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P( \mathbf{s}_{t-1} ), \label{eq:marginal}
\end{equation}&lt;p&gt;which, when applied recursively, traces the evolution of the system starting from some initial distribution $P( \mathbf{s}_{0} )$. Unless we turn off the couplings by setting $\mathbf{J} = \mathbf{0}$, the marginal distribution $P( \mathbf{s}_{t} )$ is not factorized and tends to be quite complicated. Our goal is to compute statistical properties of the system, such as the mean magnetizations&lt;/p&gt;
\begin{equation}
m_{i,t} = \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t} ),
\end{equation}&lt;p&gt;as well as correlations&lt;/p&gt;
\begin{equation}
C_{ik,t} = \sum_{\mathbf{s}_{t}} s_{i,t} s_{k,t} P( \mathbf{s}_{t} ) - m_{i,t} m_{k,t},
\end{equation}&lt;p&gt;and delayed correlations&lt;/p&gt;
\begin{equation}
D_{il,t} = \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1}} s_{i,t} s_{l,t-1} P( \mathbf{s}_{t}, \mathbf{s}_{t-1} ) - m_{i,t} m_{l,t-1}.
\end{equation}&lt;p&gt;Since the above expressions involve summing over all $2^N$ possible spin configurations, they quickly become intractable as $N$ grows. So we will try to approximate the tricky marginal distribution $P( \mathbf{s}_{t} )$ defined in Eq. \eqref{eq:marginal} using a mean-field theory approach.&lt;/p&gt;
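&lt;p&gt;To make the cost concrete, here is a minimal brute-force JAX sketch with hypothetical helper names. It builds the $2^N \times 2^N$ transition matrix of Eq. \eqref{eq:pcond} and propagates $P( \mathbf{s}_{t-1} )$ one step via Eq. \eqref{eq:marginal}. From there it evaluates $\mathbf{m}_{t}$, $\mathbf{C}_{t}$, and $\mathbf{D}_{t}$ exactly. The sketch is only usable for a handful of spins, but it provides a reference for checking mean-field approximations on small systems.&lt;/p&gt;

```python
import itertools

import jax
import jax.numpy as jnp


def all_configs(N):
    """All 2^N binary spin configurations as a (2^N, N) array with entries -1 or +1."""
    return jnp.array(list(itertools.product([-1.0, 1.0], repeat=N)))


def transition_matrix(x, J, configs):
    """T[a, b] = P(s_t = configs[a] | s_{t-1} = configs[b]) for synchronous dynamics."""
    h = x[None, :] + configs @ J.T  # h[b, i] = x_i + sum_j J_ij s_{j,t-1}
    log_p = configs @ h.T - jnp.sum(jnp.log(2 * jnp.cosh(h)), axis=-1)[None, :]
    return jnp.exp(log_p)


def exact_statistics(x, J, p_prev, configs):
    """Propagate P(s_{t-1}) one step and return (P(s_t), m_t, C_t, D_t) by brute force."""
    T = transition_matrix(x, J, configs)
    p_t = T @ p_prev
    m_prev = configs.T @ p_prev
    m_t = configs.T @ p_t
    C_t = configs.T @ (p_t[:, None] * configs) - jnp.outer(m_t, m_t)
    joint = T * p_prev[None, :]  # P(s_t = a, s_{t-1} = b)
    D_t = configs.T @ joint @ configs - jnp.outer(m_t, m_prev)
    return p_t, m_t, C_t, D_t


# Usage: ten time steps of a small system with asymmetric couplings.
N = 3
configs = all_configs(N)
key_x, key_J = jax.random.split(jax.random.PRNGKey(0))
x = 0.5 * jax.random.normal(key_x, (N,))
J = jax.random.normal(key_J, (N, N)) / jnp.sqrt(N)
p = jnp.full((2**N,), 1.0 / 2**N)  # uniform initial distribution P(s_0)
for t in range(10):
    p, m, C, D = exact_statistics(x, J, p, configs)
```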
&lt;h2 id="mean-field-theory-and-kullback-leibler-divergence"&gt;Mean-field theory and Kullback-Leibler divergence&lt;/h2&gt;
&lt;p&gt;Mean-field theory tries to approximate a complicated object ${\color{red}P}$ by wiggling around the parameters of a simple, analytically tractable parameterized ansatz ${\color{green}Q_{\theta}}$ to get as close as possible to ${\color{red}P}$. At the risk of inducing headaches in mathematicians by calling everything a manifold, we can picture what is going on geometrically. We are trying to approximate a target probability distribution $P( \mathbf{s}_{t} \vert \mathbf{x}, \mathbf{J})$ and its statistical properties $\mathbf{m}_{t}$, $\mathbf{C}_{t}$, and $\mathbf{D}_{t}$ by restricting ourselves to a submanifold of tractable probability distributions. A particularly convenient submanifold is that of factorized models, where each point on the submanifold corresponds to a distribution parameterized by a vector $\boldsymbol{\theta}_{t}$,&lt;/p&gt;
\begin{equation}
Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} \theta_{i,t}}}{2 \cosh \theta_{i,t}}, \label{eq:q}
\end{equation}&lt;p&gt;so that the mean magnetizations are simply given by&lt;/p&gt;
\begin{equation}
m_{i,t} = \tanh \theta_{i,t} \label{eq:meanmagstanh}
\end{equation}&lt;p&gt;as there are no couplings between spins. The factorized model $Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}^{*}_{t} )$ that minimizes the Kullback-Leibler (KL) divergence&lt;/p&gt;
\begin{equation}
D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) = \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log \frac{P( \mathbf{s}_{t})}{Q_{\theta}( \mathbf{s}_{t})} \label{eq:kl}
\end{equation}&lt;p&gt;has mean magnetizations $\mathbf{m}_{t}$ identical to those of the target distribution $P( \mathbf{s}_{t})$ since, for all spins $i=1,2,\ldots,N$, we find that&lt;/p&gt;
\begin{align}
\frac{\partial D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} &amp;= - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \frac{\partial \log Q_{\theta}( \mathbf{s}_{t}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} \\\\
&amp;= - \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t}) + \tanh \theta^{*}_{i,t} \\\\
&amp;= -m^{{\color{red}P}}_{i,t} + m^{{\color{green}Q_{\theta^{*}}}}_{i,t} = 0, \label{eq:klm}
\end{align}&lt;p&gt;where $m^{{\color{red}P}}_{i,t}$ and $m^{{\color{green}Q_{\theta^{*}}}}_{i,t}$ respectively denote the expectation values of $s_{i,t}$ with respect to ${\color{red}P}$ and ${\color{green}Q_{\theta^{*}}}$. Indeed, minimizing $D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}})$ tries to cover the modes of ${\color{red}P}$ by moment matching since the expectation value in Eq. \eqref{eq:kl} is calculated with respect to ${\color{red}P}$.&lt;/p&gt;
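&lt;p&gt;The moment-matching argument can be verified with automatic differentiation on a small system. The sketch below assumes $N=3$ and an arbitrary random target distribution $P$ over all $2^N$ configurations (both illustrative choices). Since the entropy of $P$ does not depend on $\boldsymbol{\theta}_{t}$, the sketch differentiates only the cross-entropy term. It compares &lt;code&gt;jax.grad&lt;/code&gt; with the analytic gradient $\tanh \theta_{i,t} - m^{P}_{i,t}$. It then checks that $\theta^{*}_{i,t} = \operatorname{artanh} m^{P}_{i,t}$ is a stationary point.&lt;/p&gt;

```python
import itertools

import jax
import jax.numpy as jnp

N = 3
configs = jnp.array(list(itertools.product([-1.0, 1.0], repeat=N)))  # (2^N, N)

# An arbitrary (correlated) target distribution P over the 2^N configurations.
logits = jax.random.normal(jax.random.PRNGKey(0), (2**N,))
p = jax.nn.softmax(logits)
m_p = configs.T @ p  # target magnetizations m^P


def log_q(theta):
    """log Q(s | theta) of the factorized ansatz, evaluated for every configuration."""
    return configs @ theta - jnp.sum(jnp.log(2 * jnp.cosh(theta)))


def cross_entropy(theta):
    """KL(P || Q_theta) up to the theta-independent entropy of P."""
    return -jnp.sum(p * log_q(theta))


theta = jnp.zeros(N)
grad = jax.grad(cross_entropy)(theta)
# Analytic gradient from the derivation: tanh(theta) - m^P.
print(jnp.allclose(grad, jnp.tanh(theta) - m_p, atol=1e-6))

# Setting the gradient to zero gives theta* = artanh(m^P): moment matching.
theta_star = jnp.arctanh(m_p)
print(jax.grad(cross_entropy)(theta_star))  # approximately zero
```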
&lt;h2 id="the-plefka-expansion-interpolating-distributions"&gt;The Plefka expansion: interpolating distributions&lt;/h2&gt;
&lt;p&gt;Great, but is it even possible to find the parameters&lt;/p&gt;
\begin{equation}
\boldsymbol{\theta}^{*}_{t} = \operatorname*{arg\,min}_{\boldsymbol{\theta}_{t}} \left( - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log Q_{\theta}( \mathbf{s}_{t}) \right)
\end{equation}&lt;p&gt;that minimize the KL divergence? Well, that&amp;rsquo;s going to be hard, unless you already know the target distribution $P( \mathbf{s}_{t})$, or you have a clever way of approximately evaluating the expectation value of $\log {\color{green}Q_{\theta}}$ with respect to ${\color{red}P}$. So let us introduce some more distributions to get around this issue. To apply the Plefka expansion to our problem, we introduce the conditional distribution&lt;/p&gt;
\begin{equation}
P_{\alpha}( \mathbf{s}_{t}\vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}(\alpha) }}{2 \cosh h_{i,t}(\alpha)}, \label{eq:pcondalt}
\end{equation}\begin{equation}
h_{i,t}(\alpha) = (1-\alpha) \theta_{i,t} + \alpha \left( x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1} \right), \label{eq:pcondalth}
\end{equation}&lt;p&gt;parameterized by a scalar $\alpha$ interpolating between $P_{\alpha=0}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} )$ (Eq. \eqref{eq:q}) and $P_{\alpha=1}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} )$ (Eq. \eqref{eq:pcond}). Using Eq. \eqref{eq:pcondalt}, we can construct an approximate marginal distribution $P_{\alpha}( \mathbf{s}_{t})$, leading to $\alpha$-dependent statistical properties $\mathbf{m}_{t}(\alpha)$, $\mathbf{C}_{t}(\alpha)$, and $\mathbf{D}_{t}(\alpha)$ for the approximate system. The Plefka expansion then boils down to writing these properties as Taylor series expansions around the factorized model $\alpha=0$. For the mean magnetizations, the expansion up to $n$-th order looks like&lt;/p&gt;
\begin{equation}
\mathbf{m}_{t}(\alpha) = \mathbf{m}_{t}(\alpha=0) + \sum_{k=1}^{n} \frac{\alpha^k}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}), \label{eq:mtaylor}
\end{equation}&lt;p&gt;where all coefficients in the expansion are functions of $\boldsymbol{\theta}_{t}$ via Eq. \eqref{eq:pcondalth}. The mean-field approximation is computed by setting $\alpha=1$ so that the original marginal distribution is recovered and Eq. \eqref{eq:klm} holds, which implies that $\mathbf{m}_{t}(\alpha=1) = \mathbf{m}_{t}(\alpha=0)$ and thus&lt;/p&gt;
\begin{equation}
\sum_{k=1}^{n} \frac{1}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}) = 0. \label{eq:mftheta}
\end{equation}&lt;p&gt;Finally, we solve Eq. \eqref{eq:mftheta} for $\boldsymbol{\theta}_{t}$ to find the mean-field values $\boldsymbol{\theta}^{*}_{t}$ of the parameters of the distribution Eq. \eqref{eq:q}. Physically, we are tuning the effective external magnetic fields of the factorized ansatz to $\boldsymbol{\theta}^{*}_{t}$ so that its approximate mean magnetizations get as close as possible to the true ones.&lt;/p&gt;
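&lt;p&gt;Every coefficient in Eq. \eqref{eq:mtaylor} is a derivative with respect to the scalar $\alpha$, so automatic differentiation can evaluate them directly for a small system. The sketch below makes a simplifying assumption. It keeps only the $\alpha$-dependence of $P_{\alpha}( \mathbf{s}_{t}\vert \mathbf{s}_{t-1} )$ and takes a factorized previous state with magnetizations $\mathbf{m}_{t-1}$. It checks that the first coefficient equals $\left( 1-m^{2}_{i,t} \right) \left( -\theta_{i,t} + x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right)$. It also checks that, at $\theta_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1}$, the second coefficient equals $-2 m_{i,t} \left( 1-m^{2}_{i,t} \right) \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)$. Both match the expressions used in the naive mean-field and TAP derivations below.&lt;/p&gt;

```python
import itertools

import jax
import jax.numpy as jnp

N = 3
configs = jnp.array(list(itertools.product([-1.0, 1.0], repeat=N)))  # all s_{t-1}

keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (N,))
J = jax.random.normal(keys[1], (N, N)) / jnp.sqrt(N)
theta = jax.random.normal(keys[2], (N,))  # arbitrary fields of the factorized model
m_prev = jnp.tanh(jax.random.normal(keys[3], (N,)))  # magnetizations at t-1

# Assumption: factorized previous state, P(s_{t-1}) = prod_j (1 + s_j m_j) / 2.
p_prev = jnp.prod((1 + configs * m_prev) / 2, axis=-1)


def m_alpha(alpha, theta):
    """m_t(alpha) = sum over s_{t-1} of tanh(h_t(alpha)) P(s_{t-1})."""
    h = (1 - alpha) * theta + alpha * (x + configs @ J.T)  # (2^N, N)
    return jnp.tanh(h).T @ p_prev


# First Taylor coefficient at alpha = 0 versus the closed-form expression.
dm = jax.jacfwd(m_alpha)(0.0, theta)
closed_form = (1 - jnp.tanh(theta) ** 2) * (-theta + x + J @ m_prev)
print(jnp.allclose(dm, closed_form, atol=1e-5))

# At the naive solution theta* = x + J m_{t-1}, the first coefficient vanishes and the
# second one reduces to -2 m (1 - m^2) sum_j J_ij^2 (1 - m_{j,t-1}^2).
theta_star = x + J @ m_prev
d2m = jax.jacfwd(jax.jacfwd(m_alpha))(0.0, theta_star)
m_star = jnp.tanh(theta_star)
print(jnp.allclose(d2m, -2 * m_star * (1 - m_star**2) * (J**2 @ (1 - m_prev**2)), atol=1e-5))
```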
&lt;h2 id="naive-mean-field-and-thouless-anderson-palmer-approximations"&gt;Naive mean-field and Thouless-Anderson-Palmer approximations&lt;/h2&gt;
&lt;p&gt;We now consider first- and second-order approximations of the mean magnetizations Eq. \eqref{eq:mtaylor}. These recover the naive mean-field and Thouless-Anderson-Palmer (TAP) approximations for the binary kinetic Ising model, respectively. The starting point is a Plefka expansion around factorized models at times $t-1$ and $t$. From Eq. \eqref{eq:marginal} and Eq. \eqref{eq:pcondalt}, we construct a marginal probability distribution&lt;/p&gt;
\begin{equation}
P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ),
\end{equation}&lt;p&gt;interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. The corresponding mean magnetizations are&lt;/p&gt;
\begin{align}
m_{i,t}(\alpha) &amp;= \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1},\mathbf{s}_{t-2}} s_{i,t} \, P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \\\\
&amp;= \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} \tanh h_{i,t}(\alpha) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} )
\end{align}&lt;p&gt;Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} = \left( 1-m^{2}_{i,t} \right) \left( -\theta_{i,t} + x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right) = 0,
\end{equation}&lt;p&gt;so that $\theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1}$ and we end up with the naive mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right)} \label{eq:naivem}
\end{equation}&lt;p&gt;Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial\alpha^2} = 0,
\end{equation}&lt;p&gt;where the second-order derivative, neglecting terms higher than $\mathcal{O}(\alpha^2)$, is&lt;/p&gt;
\begin{equation}
\frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial\alpha^2} \approx -2 m_{i,t} \left( 1-m^{2}_{i,t} \right) \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{equation}
\theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)
\end{equation}&lt;p&gt;and we end up with the TAP mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right) \right)} \label{eq:tapm}
\end{equation}&lt;p&gt;which includes the so-called Onsager correction term $-m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)$. The mean-field equations obtained above can also be derived elegantly via a Legendre transformation of the generating functional of the model's trajectories, as outlined in e.g.
. We can also derive second-order TAP approximations of the correlations&lt;/p&gt;
\begin{equation}
C_{ik,t} = \begin{cases}
1 - m^{2}_{i,t} &amp; i = k \\\\
\left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{k,t} \right) \sum_{j} J_{ij} J_{kj} \left( 1-m^{2}_{j,t-1} \right) &amp; i \neq k \label{eq:tapc}
\end{cases}
\end{equation}&lt;p&gt;and delayed correlations&lt;/p&gt;
\begin{equation}
D_{il,t} = J_{il} \left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{l,t-1} \right) \left( 1 + 2 J_{il} m_{i,t} m_{l,t-1} \right). \label{eq:tapd}
\end{equation}&lt;p&gt;We refer to
for full derivations of the above mean-field results as well as variations based on different approximations of the marginal distribution $P( \mathbf{s}_{t} )$.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;In summary, given the mean magnetizations $\mathbf{m}_{t-1}$ of the system at time $t-1$, we can use Eqs. \eqref{eq:tapm}, \eqref{eq:tapc}, and \eqref{eq:tapd} to compute a tuple $(\mathbf{m}_{t},\mathbf{C}_{t},\mathbf{D}_{t})$ of approximate statistical properties of the system at time $t$. At the mean-field level, the time evolution of the system is then captured by iterating this update for $\mathbf{m}_{t}$, starting from an initial state $\mathbf{m}_{0}$ (with approximation errors likely accumulating over the course of the time evolution).&lt;/p&gt;
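&lt;p&gt;As a rough illustration of this summary, here is a minimal JAX sketch of a single TAP step that returns $(\mathbf{m}_{t},\mathbf{C}_{t},\mathbf{D}_{t})$ given $\mathbf{m}_{t-1}$. It assumes, as in the implementation further down, that $\beta$ has been absorbed into $\mathbf{x}$ and $\mathbf{J}$. Unlike that implementation, it solves Eq. \eqref{eq:tapm} by plain fixed-point iteration with a fixed, hypothetical iteration count &lt;code&gt;n_iter&lt;/code&gt; instead of Anderson acceleration. The function name &lt;code&gt;tap_statistics&lt;/code&gt; and the example parameters are illustrative choices, not taken from the original code.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def tap_statistics(m_prev, x, J, n_iter=50):
    """One TAP step: (m_t, C_t, D_t) from m_{t-1}, with beta absorbed into x and J."""
    h = x + J @ m_prev                      # x_i + sum_j J_ij m_{j,t-1}
    b = 1.0 - m_prev ** 2                   # 1 - m_{j,t-1}^2
    onsager = (J ** 2) @ b                  # sum_j J_ij^2 (1 - m_{j,t-1}^2)
    # Solve m = tanh(h - m * onsager) by plain fixed-point iteration from the naive solution.
    m = jax.lax.fori_loop(0, n_iter, lambda _, mi: jnp.tanh(h - mi * onsager), jnp.tanh(h))
    a = 1.0 - m ** 2                        # 1 - m_{i,t}^2
    # Equal-time correlations: a_i a_k sum_j J_ij J_kj b_j off the diagonal, a_i on it.
    C = jnp.outer(a, a) * ((J * b) @ J.T)
    idx = jnp.arange(m.shape[0])
    C = C.at[idx, idx].set(a)
    # Delayed correlations: J_il a_i b_l (1 + 2 J_il m_{i,t} m_{l,t-1}).
    D = J * jnp.outer(a, b) * (1.0 + 2.0 * J * jnp.outer(m, m_prev))
    return m, C, D


# Usage with hypothetical parameters (N = 64, beta = 1, X0 = 0.5, J_mu = 1, J_sigma = 0.1).
kx, kJ = jax.random.split(jax.random.PRNGKey(0))
N = 64
x = jax.random.uniform(kx, (N,), minval=-0.5, maxval=0.5)
J = 1.0 / N + 0.1 / jnp.sqrt(N) * jax.random.normal(kJ, (N, N))
m_prev = jnp.full(N, 0.5)
m, C, D = tap_statistics(m_prev, x, J)
```

&lt;p&gt;Plain iteration suffices in this sketch because the map $m \mapsto \tanh(h - c\, m)$ is a contraction whenever the Onsager coefficient $c = \sum_{j} J^{2}_{ij} ( 1-m^{2}_{j,t-1} )$ stays below one, and for these parameters $\sum_{j} J^{2}_{ij} \approx \beta^{2} J^{2}_{\sigma}$ is small.&lt;/p&gt;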
&lt;h2 id="a-simple-jax-implementation"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;To get more insight into what is going on, let us turn the mean-field update equations \eqref{eq:naivem} and \eqref{eq:tapm} for the mean magnetizations into code. Before showing a few plots, we need some background on the model we are about to simulate. In
, the authors derive a solution of the asymmetric version of the kinetic
using a generating functional or dynamical partition function approach to capture the distribution of trajectories. They consider the same kinetic Ising model as in Eq. \eqref{eq:pcond} but with an inverse temperature parameter $\beta$ in the exponentials:&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta s_{i,t} h_{i,t}}}{2 \cosh \beta h_{i,t}}. \label{eq:pcondwithbeta}
\end{equation}&lt;p&gt;For Gaussian couplings $J_{ij} \sim \mathcal{N}\left( J_{\mu} / N, J^{2}_{\sigma} / N\right)$ and uniformly distributed external magnetic fields $x_{i} \sim \mathcal{U}(-X_{0}, X_{0})$, they show the existence of a ferromagnetic phase transition. In particular for $X_{0}=0.5$, $J_{\mu}=1.0$, and $J_{\sigma}=0.1$, a phase transition happens when tuning $\beta$ to a critical value $\beta_{c} \approx 1.1108$.&lt;/p&gt;
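&lt;p&gt;As a back-of-the-envelope check of this value (a rough estimate, not a result from the cited work), neglect the coupling disorder ($J_{\sigma} \to 0$, so that $J_{ij} = J_{\mu}/N$) in the naive mean-field equation \eqref{eq:naivem}. A homogeneous stationary state then satisfies $M = \langle \tanh \beta ( x + J_{\mu} M ) \rangle_{x}$ with $x \sim \mathcal{U}(-X_{0}, X_{0})$, and the disordered solution $M=0$ becomes unstable once the slope $\beta J_{\mu} \langle 1 - \tanh^{2} \beta x \rangle_{x} = J_{\mu} \tanh(\beta X_{0}) / X_{0}$ exceeds one. This happens at $\beta = \operatorname{artanh}(X_{0}/J_{\mu}) / X_{0} = \ln 3 \approx 1.0986$, slightly below the quoted $\beta_{c} \approx 1.1108$; the gap presumably reflects the neglected disorder $J_{\sigma}$. The helper names below are hypothetical.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def naive_mf_slope(beta, X0=0.5, J_mu=1.0):
    # Slope at M = 0 of M -> E_x[tanh(beta (x + J_mu M))] for x ~ U(-X0, X0):
    # beta J_mu E_x[1 - tanh^2(beta x)] = J_mu tanh(beta X0) / X0.
    return J_mu * jnp.tanh(beta * X0) / X0


def naive_mf_critical_beta(X0=0.5, J_mu=1.0):
    # Slope equal to one: tanh(beta X0) = X0 / J_mu (needs X0 / J_mu below one).
    return jnp.arctanh(X0 / J_mu) / X0


beta_naive = naive_mf_critical_beta()  # ln(3) for X0 = 0.5, J_mu = 1
# Monte Carlo cross-check of the slope formula with sampled fields (J_mu = 1).
x = jax.random.uniform(jax.random.PRNGKey(0), (100_000,), minval=-0.5, maxval=0.5)
slope_mc = beta_naive * jnp.mean(1.0 - jnp.tanh(beta_naive * x) ** 2)
print(float(beta_naive), float(slope_mc))
```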
&lt;h3 id="simulating-magnetization-trajectories"&gt;Simulating magnetization trajectories&lt;/h3&gt;
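&lt;p&gt;We first present a JAX implementation of the mean-field time evolution of the magnetizations for the model described above. We use &lt;code&gt;jax.lax.scan&lt;/code&gt; to implement the time evolution and &lt;code&gt;jax.vmap&lt;/code&gt; to parallelize trajectories starting from a batch of initial magnetization configurations $\mathbf{m}_{0}$. For the second-order TAP equations, &lt;code&gt;jaxopt&lt;/code&gt;&amp;rsquo;s Anderson acceleration is used to find the fixed-point magnetizations $\mathbf{m}_{t}$ given $\mathbf{m}_{t-1}$. Note that &lt;code&gt;init_params&lt;/code&gt; absorbs $\beta$ into the sampled fields and couplings, so the update functions never see $\beta$ explicitly, and that &lt;code&gt;steps&lt;/code&gt; should be an array such as &lt;code&gt;jnp.arange(T)&lt;/code&gt;, since &lt;code&gt;jax.lax.scan&lt;/code&gt; iterates over the leading axis of &lt;code&gt;xs&lt;/code&gt;.&lt;/p&gt;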
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax.numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;jnp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;jaxopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_naive_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (22).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (26).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;_x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;init_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="n"&gt;minval&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxval&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;**-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;**-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;J_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;update_tap_mf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;init_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="naive-mean-field-vs-thouless-anderson-palmer-tap"&gt;Naive mean-field vs. Thouless-Anderson-Palmer (TAP)&lt;/h3&gt;
&lt;p&gt;We fix the seed and randomly initialize the model parameters $\mathbf{x}$ and $\mathbf{J}$ to simulate $N=512$ spins at the critical inverse temperature $\beta_{c}$ for $t=128$ time steps, starting from an all-ones initial state. We first consider the naive mean-field update step.&lt;/p&gt;
&lt;img src="binary_plot_1.png" width="600px"/&gt;
&lt;p&gt;The left axis shows the magnetization trajectory of each individual spin, plotted horizontally, while the red line (right axis) shows the magnetization averaged over all spins at each time step. We observe convergence to what looks like a &lt;em&gt;non-equilibrium (or near-equilibrium) steady state&lt;/em&gt; (NESS).&lt;/p&gt;
&lt;img src="binary_plot_2.png" width="550px"/&gt;
&lt;p&gt;Comparing the naive first-order mean-field update equations to the second-order Thouless-Anderson-Palmer (TAP) ones, we observe lower values for the mean magnetization across all spins, which
showed to be closer to ground truth values (not shown) obtained via sampling and averaging spin configurations.&lt;/p&gt;
&lt;h3 id="sampling-trajectories"&gt;Sampling trajectories&lt;/h3&gt;
&lt;p&gt;Let us consider 100 random initial states and simulate their trajectories in three different regimes: below ($\beta=\beta_c / 2$), at ($\beta=\beta_c$), and above ($\beta=2 \beta_c$) the critical inverse temperature.&lt;/p&gt;
&lt;img src="binary_plot_3.png" width="650px"/&gt;
&lt;p&gt;We observe that, in each regime, trajectories starting from different random initial states converge to the same final state. These final states map onto a simple ferromagnetic Ising phase diagram, where a high-temperature disordered phase $\langle m_{i,t} \rangle \to 0$ (left) is separated from a low-temperature locally ordered phase $\langle m_{i,t} \rangle \to \pm 1$ (right) by a critical point (center). The behavior around $\beta=\beta_{c}$ is pretty interesting: &lt;em&gt;the non-trivial steady state looks like an attractor implicitly encoded in the dynamics of the model&lt;/em&gt;. If we were to parameterize the couplings, we could in principle train the system to act as an associative memory.&lt;/p&gt;
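&lt;p&gt;A stripped-down version of this experiment can be sketched with the naive update alone. The sketch is self-contained and does not reuse the functions above: &lt;code&gt;sample_params&lt;/code&gt; and &lt;code&gt;final_naive_state&lt;/code&gt; are hypothetical helpers, and the seed and the uniform distribution of initial states on $[-1,1]^{N}$ are assumptions. For each $\beta$ it reports the largest deviation of any final state from that of the first trajectory.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def sample_params(key, N, beta, X0=0.5, J_mu=1.0, J_sigma=0.1):
    # Same model as above, with beta absorbed into the fields x and couplings J.
    kx, kJ = jax.random.split(key)
    x = beta * jax.random.uniform(kx, (N,), minval=-X0, maxval=X0)
    J = beta * (J_mu / N + J_sigma / jnp.sqrt(N) * jax.random.normal(kJ, (N, N)))
    return x, J


def final_naive_state(m0, x, J, T=128):
    # Apply the naive mean-field map m_t = tanh(x + J m_{t-1}) T times.
    step = lambda m, _: (jnp.tanh(x + J @ m), None)
    m_T, _ = jax.lax.scan(step, m0, xs=None, length=T)
    return m_T


N, beta_c = 512, 1.1108
key_params, key_init = jax.random.split(jax.random.PRNGKey(42))
m0_batch = jax.random.uniform(key_init, (100, N), minval=-1.0, maxval=1.0)  # 100 initial states
batched = jax.vmap(final_naive_state, in_axes=(0, None, None))  # map over initial states only
spread = {}
for beta in (beta_c / 2, beta_c, 2 * beta_c):
    x, J = sample_params(key_params, N, beta)
    finals = batched(m0_batch, x, J)
    spread[beta] = float(jnp.max(jnp.abs(finals - finals[0])))  # largest deviation from run 0
    print(f"beta={beta:.4f}  mean m={float(finals.mean()):.3f}  spread={spread[beta]:.1e}")
```

&lt;p&gt;For $\beta=\beta_c/2$ the naive update is a contraction (the spectral norm of $\mathbf{J}$ is roughly bounded by $\beta (J_{\mu} + 2 J_{\sigma})$, which is below one there, and $\tanh$ is 1-Lipschitz), so all trajectories must collapse onto the same state; at and above $\beta_c$ no such simple guarantee holds.&lt;/p&gt;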
&lt;h3 id="sampling-model-parameters"&gt;Sampling model parameters&lt;/h3&gt;
&lt;p&gt;We now return to a single trajectory per parameter set, since we just saw that trajectories seem to converge to the same steady-state magnetizations for fixed model parameters. To get a feel for how these values vary across realizations of the model parameters, we plot the absolute value&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; $\lvert \langle m_{i} \rangle \rvert$ of the final steady-state magnetization for 100 samples of model parameters over a range of inverse temperatures. Since we are using JAX, we can easily sample model parameters by &lt;code&gt;vmap&lt;/code&gt;&amp;rsquo;ing over the random key fed into the &lt;code&gt;simulate&lt;/code&gt; function, followed by another &lt;code&gt;vmap&lt;/code&gt; to sweep across $\beta$.&lt;/p&gt;
&lt;img src="binary_plot_4.png" width="500px"/&gt;
&lt;p&gt;Every curve in the above plot shows the final steady-state value of the &amp;ldquo;order parameter&amp;rdquo; $\lvert \langle m_{i} \rangle \rvert$ for a fixed set of model parameters as $\beta$ is swept. We observe a greater spread of values near the critical point, which suggests an improved capacity to map input external fields to a range of output magnetizations. If we were to let the number of spins $N \to \infty$ and average over a large number of model parameter samples, the finite-size results above would probably sharpen into a curve with zero magnetization below the critical point and non-zero magnetization setting in at the critical point.&lt;/p&gt;
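&lt;p&gt;The nested &lt;code&gt;vmap&lt;/code&gt; described above can be sketched as follows, again with the naive update only. The hypothetical helper &lt;code&gt;final_abs_magnetization&lt;/code&gt; samples $(\mathbf{x}, \mathbf{J})$ from a key, with $\beta$ absorbed as in &lt;code&gt;init_params&lt;/code&gt; above, and returns the absolute spin-averaged magnetization after $T$ steps from the all-ones state. Sizes are reduced compared to the plot ($N=128$, 32 parameter samples, 16 values of $\beta$), since vmapping over both axes materializes one $N \times N$ coupling matrix per (key, $\beta$) pair.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def final_abs_magnetization(key, beta, N=128, T=128, X0=0.5, J_mu=1.0, J_sigma=0.1):
    # Sample one realization of (x, J) with beta absorbed into both.
    kx, kJ = jax.random.split(key)
    x = beta * jax.random.uniform(kx, (N,), minval=-X0, maxval=X0)
    J = beta * (J_mu / N + J_sigma / jnp.sqrt(N) * jax.random.normal(kJ, (N, N)))
    # Run the naive mean-field map from the all-ones state and return |mean_i m_i|.
    step = lambda m, _: (jnp.tanh(x + J @ m), None)
    m_T, _ = jax.lax.scan(step, jnp.ones(N), xs=None, length=T)
    return jnp.abs(jnp.mean(m_T))


keys = jax.random.split(jax.random.PRNGKey(0), 32)  # 32 parameter samples
betas = jnp.linspace(0.5, 2.0, 16)                  # inverse temperatures to sweep
# Inner vmap over keys (fixed beta), outer vmap over beta.
sweep = jax.vmap(jax.vmap(final_abs_magnetization, in_axes=(0, None)), in_axes=(None, 0))
order_params = sweep(keys, betas)                   # shape (16, 32): (beta, sample)
```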
&lt;h1 id="mean-field-theory-of-asymmetric-ising-models-with-vector-spins"&gt;Mean-field theory of asymmetric Ising models with vector spins&lt;/h1&gt;
&lt;p&gt;We now transpose the binary-spin results of the previous section to a setting where local spin degrees of freedom are $D$-dimensional vector spins restricted to wiggle around on $(D-1)$-dimensional spheres. We start by generalizing the conditional distribution Eq. \eqref{eq:pcondalt} to vector spins. Next, we motivate the limit of large vector dimension and derive first-order and second-order mean-field update equations for the mean magnetizations. We finish this section with a JAX implementation and some toy numerical experiments.&lt;/p&gt;
&lt;h2 id="vector-spins-distributions-on-hyperspheres"&gt;Vector spins: distributions on hyperspheres&lt;/h2&gt;
&lt;img src="vector_spins.png" alt="Random Ising model configuration with vector spins" width="250px"/&gt;
&lt;p&gt;A vector-spin equivalent of Eq. \eqref{eq:pcondalt} looks something like&lt;/p&gt;
\begin{equation}
P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s}_{i,t} \; \mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)} }, \label{eq:pcondaltvector}
\end{equation}&lt;p&gt;where we immediately included an inverse temperature $\beta$ like in Eq. \eqref{eq:pcondwithbeta}. A vector-spin equivalent of Eq. \eqref{eq:pcondalth} is&lt;/p&gt;
\begin{equation}
\mathbf{h}_{i,t}(\alpha) = (1-\alpha) \boldsymbol{\theta}_{i,t} + \alpha \left( \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} \right) \equiv \boldsymbol{\theta}_{i,t} + \alpha \Delta \mathbf{h}_{i,t}, \label{eq:pcondalthvector}
\end{equation}&lt;p&gt;where $S_{D-1}(R) = \{ x \in \mathbb{R}^{D} : \lVert x \rVert = R \}$ denotes the $(D-1)$-dimensional sphere with radius $R$ embedded in $D$ dimensions. Let us focus on the distribution for a single site and drop all subscripts and dependencies for clarity:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \label{eq:pcondsinglesitevector}
\end{equation}&lt;p&gt;The normalization constant in the denominator can be shown to be (see
)&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } \equiv Z(\beta, R, \lVert \mathbf{h}\rVert) \label{eq:partfun}
\end{equation}&lt;p&gt;where $I_{\nu}(z)$ denotes the modified Bessel function of the first kind and $\lVert \mathbf{h} \rVert = \sqrt{\mathbf{h} \cdot \mathbf{h}}$. Physically, we can think of this single-site distribution as measuring dot-product alignment to an effective external magnetic field $\mathbf{h}$ at inverse temperature $\beta$.&lt;/p&gt;
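&lt;p&gt;Eq. \eqref{eq:partfun} is easy to check numerically. The NumPy/SciPy sketch below (function names are hypothetical) compares the closed form with a Monte Carlo estimate: the surface area $2 \pi^{D/2} R^{D-1} / \Gamma(D/2)$ of the sphere times the average of $\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}$ over uniformly sampled spins. For $D=3$ and $R=\beta=\lVert \mathbf{h} \rVert=1$, the closed form reduces to the familiar $4 \pi \sinh 1$, and for $D=1$ (binary spins $\pm R$) it reduces to $2 \cosh (\beta R \lVert \mathbf{h} \rVert)$.&lt;/p&gt;

```python
import numpy as np
from scipy.special import gamma, iv


def partition_function(beta, R, h_norm, D):
    # Closed-form normalization Z(beta, R, |h|); requires h_norm to be positive.
    nu = D / 2 - 1
    return (2 * np.pi * R) ** (D / 2) * iv(nu, beta * R * h_norm) / (beta * h_norm) ** nu


def partition_function_mc(beta, R, h, n_samples=400_000, seed=0):
    # Monte Carlo estimate: area of S_{D-1}(R) times the average of exp(beta s . h).
    D = h.shape[0]
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((n_samples, D))
    u /= np.linalg.norm(u, axis=1, keepdims=True)  # uniform directions on the unit sphere
    area = 2 * np.pi ** (D / 2) * R ** (D - 1) / gamma(D / 2)
    return area * np.mean(np.exp(beta * (R * u) @ h))


# D = 3, unit radius, unit field: closed form versus 4 pi sinh(1).
print(partition_function(1.0, 1.0, 1.0, 3), 4 * np.pi * np.sinh(1.0))
# Hypothetical higher-dimensional case: D = 8, R = 1.3, beta = 2, |h| = 1.5.
h = np.zeros(8)
h[0] = 1.5
print(partition_function(2.0, 1.3, 1.5, 8), partition_function_mc(2.0, 1.3, h))
```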
&lt;p&gt;If we consider spins living on the unit sphere $R=1$ as well as unit vectors $\mathbf{h}$, the distribution boils down to a
with mean direction $\boldsymbol{\mu} \equiv \mathbf{h}$ and
$\kappa \equiv \beta$. This distribution is unimodal for $\kappa &gt; 0$ and can be derived by restricting an isotropic multivariate Gaussian with non-zero mean to the unit hypersphere. The greater the value of $\kappa$ (the inverse temperature $\beta$), the more concentrated the distribution is around the mean direction $\boldsymbol{\mu}$ (the more the spin tends to align with the effective external field $\mathbf{h}$). Instead of a fixed parameter $\boldsymbol{\mu}$, though, we have a very funky parameter Eq. \eqref{eq:pcondalthvector} that depends on all other spins to spice things up.&lt;/p&gt;
&lt;h2 id="magnetizations-and-limit-of-large-vector-dimension"&gt;Magnetizations and limit of large vector dimension&lt;/h2&gt;
&lt;p&gt;Before we derive mean-field approximations for the mean magnetizations of our vector-spin system, let us first consider the decoupled $\alpha \to 0$ limit of the distribution Eq. \eqref{eq:pcondaltvector},&lt;/p&gt;
\begin{equation}
Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \boldsymbol{\theta}_{i,t}}}{Z_{i,t}\left(\beta, R, \lVert \boldsymbol{\theta}_{i,t} \rVert\right)},
\end{equation}&lt;p&gt;and find an expression for its mean magnetizations. For every decoupled site, the mean magnetization can be shown to be (see
)&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{I_{D/2}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)}{I_{D/2 - 1}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)} \frac{R \boldsymbol{\theta}_{i,t}}{\lVert \boldsymbol{\theta}_{i,t} \rVert} \equiv \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right), \label{eq:meanmagsbessels}
\end{equation}&lt;p&gt;which plays the role of $m_{i,t} = \tanh \theta_{i,t}$ in the binary setting, see Eq. \eqref{eq:meanmagstanh}. Looking ahead at turning the above equation into code, we note that there exist
to compute the ratio of modified Bessel functions of the first kind. We implement a fast JAX version in
and show numerically how the ratio flattens out quickly for large values of the order $\nu = D/2 -1$, motivating some kind of large-order expansion.&lt;/p&gt;
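&lt;p&gt;For reference, here is a minimal NumPy sketch of $\boldsymbol{\varphi}$ (not the JAX implementation mentioned above; all names are hypothetical). Evaluating the ratio as a quotient of &lt;code&gt;scipy.special.ive&lt;/code&gt; values can underflow to 0/0 for large orders and moderate arguments, so the sketch uses a backward recurrence instead: from $I_{\nu-1}(z) - I_{\nu+1}(z) = (2\nu/z) I_{\nu}(z)$ it follows that $r_{\nu-1} = 1/(2\nu/z + r_{\nu})$ for $r_{\nu} = I_{\nu+1}(z)/I_{\nu}(z)$, which is started at a high order and run downwards. As sanity checks, $D=1$ recovers $\tanh$ (binary spins) and $D=3$ recovers the Langevin function $\coth z - 1/z$.&lt;/p&gt;

```python
import numpy as np


def bessel_ratio(nu, z, n_terms=500):
    """Ratio I_{nu+1}(z) / I_nu(z) for positive z.

    Runs r_{mu-1} = 1 / (2 mu / z + r_mu) downwards from mu = nu + n_terms, where the
    ratio is initialized with the rough estimate z / (mu + 1 + sqrt((mu + 1)^2 + z^2)).
    """
    z = np.asarray(z, dtype=float)
    mu = nu + n_terms
    r = z / (mu + 1 + np.sqrt((mu + 1) ** 2 + z ** 2))
    for _ in range(n_terms):
        r = 1.0 / (2 * mu / z + r)  # r_{mu-1} from r_mu
        mu -= 1
    return r


def phi(theta, beta=1.0, R=1.0):
    """Mean magnetization of a decoupled vector spin on S_{D-1}(R) in the field theta."""
    D = theta.shape[-1]
    norm = np.linalg.norm(theta, axis=-1, keepdims=True)
    return bessel_ratio(D / 2 - 1, beta * R * norm) * R * theta / norm


print(phi(np.array([0.7])), np.tanh(0.7))  # D = 1: binary spins
print(np.linalg.norm(phi(np.array([1.0, 2.0, 2.0]))), 1 / np.tanh(3.0) - 1 / 3.0)  # D = 3
```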
&lt;p&gt;Remember that our goal is to make a connection to transformer neural networks. Since the vector dimension in dense transformer modules tends to be somewhere between $\mathcal{O}(10^2)$ and $\mathcal{O}(10^5)$, it is not nonsensical to focus on the large vector dimension limit. A relevant uniform asymptotic expansion of the ratio of modified Bessel functions of the first kind is
:&lt;/p&gt;
\begin{align}
\frac{I_{\nu+\alpha}(\nu x)}{I_{\nu}(\nu x)} = \left( \frac{x}{1+\sqrt{1+x^2}} \right)^{\alpha} \left( 1 - \frac{1+\alpha\sqrt{1+x^2}}{2(1+x^2)} \frac{\alpha}{\nu} + \mathcal{O}\left( \frac{1}{\nu^2} \right) \right)
\end{align}&lt;p&gt;Indeed, if we choose to tie the radius $R$ of our little spins to their vector dimension $D$ via&lt;/p&gt;
\begin{align}
\nu=D/2-1=R^2,
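\end{align}&lt;p&gt;we can apply the leading order of the asymptotic expansion with order shift $\alpha=1$ (not to be confused with the interpolation parameter $\alpha$) to Eq. \eqref{eq:meanmagsbessels}. The Bessel argument then reads $\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert = \nu x$ with $x = \beta \lVert \boldsymbol{\theta}_{i,t} \rVert / R$, and we find&lt;/p&gt;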
\begin{equation}
\mathbf{m}^{D \to \infty}_{i,t} \approx \frac{\beta}{1+\gamma( \lVert \boldsymbol{\theta}_{i,t} \rVert )} \boldsymbol{\theta}_{i,t} \equiv \boldsymbol{\varphi}^{D \to \infty} \left(\boldsymbol{\theta}_{i,t}\right), \label{eq:largedevmag}
\end{equation}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right) = \sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }.
\end{align}&lt;p&gt;From here on, we will default to using the large-$D$ approximation because keeping track of (derivatives of) Bessel functions gets boring real quick. We refer to
for some truly outrageous expressions pertaining to the general case valid for all $D&gt;1$.&lt;/p&gt;
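&lt;p&gt;To get a feel for the quality of the leading-order approximation, here is a small sketch that compares it against the exact Bessel ratio along the tied scaling $\nu = D/2 - 1 = R^2$. Under this scaling, the Bessel argument $\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert$ equals $\nu x$ with $x = \beta \lVert \boldsymbol{\theta}_{i,t} \rVert / R$. It leans on &lt;code&gt;scipy.special.ive&lt;/code&gt;, and the helper names are illustrative.&lt;/p&gt;

```python
import numpy as np
from scipy.special import ive


def bessel_ratio(nu, z):
    """Exact I_{nu+1}(z) / I_nu(z), using exponentially scaled Bessel functions."""
    return ive(nu + 1, z) / ive(nu, z)


def leading_order_ratio(x):
    """Leading term of the uniform expansion of I_{nu+1}(nu x) / I_nu(nu x)."""
    return x / (1 + np.sqrt(1 + x**2))


def relative_error(D, x):
    """Relative error of the large-D magnetization norm when nu = D/2 - 1 = R^2.

    With this choice the Bessel argument beta * R * ||theta|| equals nu * x,
    where x = beta * ||theta|| / R.
    """
    nu = D / 2 - 1
    exact = bessel_ratio(nu, nu * x)
    return abs(leading_order_ratio(x) - exact) / exact


for D in (10, 100, 1000):
    print(D, relative_error(D, x=1.0))
```

&lt;p&gt;Since both expressions point along $\boldsymbol{\theta}_{i,t}$, the relative error of the ratio is also the relative error of the magnetization norm. It should shrink roughly like $1/\nu$, in line with the $\mathcal{O}(1/\nu)$ correction in the expansion.&lt;/p&gt;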
&lt;h2 id="first-order-naive-mean-field-approximation"&gt;First-order naive mean-field approximation&lt;/h2&gt;
&lt;p&gt;All right, let&amp;rsquo;s go. Closely mimicking the binary case, we start from the following approximated marginal probability distribution&lt;/p&gt;
\begin{equation}
P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ),
\end{equation}&lt;p&gt;interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. Our lazy integral notation $\int \mathrm{d} \mathbf{s}_{t}$ should be understood as $\int \prod_{i=1}^{N} \mathrm{d}^{D} \mathbf{s}_{i, t}$, i.e. integrating over all the little spins at a fixed time $t$. The estimated mean magnetizations are&lt;/p&gt;
\begin{align}
\mathbf{m}_{i,t}(\alpha) &amp;= \int \mathrm{d} \mathbf{s}_{t} \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \mathbf{s}_{i,t} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;= \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ).
\end{align}&lt;p&gt;The first-order derivative with respect to $\alpha$ is then given by&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha)}{\partial\alpha} = \int &amp;\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \Biggr) P( \mathbf{s}_{t-2} ), \label{eq:mitfirstorderalpha}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\frac{\partial \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} = \frac{\beta}{1+\gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right)} \Delta \mathbf{h}_{i,t} - \frac{\beta}{R^2} \frac{ \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) }{ \gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right) } \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \label{eq:firstorderphialpha}
\end{align}&lt;p&gt;Evaluating \eqref{eq:mitfirstorderalpha} at $\alpha=0$, the second term drops out. Since $\mathbf{h}_{i,t}(0) = \boldsymbol{\theta}_{i,t}$ no longer depends on $\mathbf{s}_{t-1}$, the factor $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(0)\right)$ can be pulled out of the $\mathbf{s}_{t-1}$ integral. What remains vanishes because $\int \mathrm{d} \mathbf{s}_{t-1} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=1$ for all $\alpha$, so its $\alpha$-derivative is zero. We thus end up with&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = \frac{\beta}{1+\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right)}\boldsymbol{v}_{i,t} - \frac{\beta}{R^2}\frac{\left( \mathbf{m}_{i,t} \cdot \boldsymbol{v}_{i,t} \right)}{\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right)} \mathbf{m}_{i,t} \label{eq:mfirstorderalphazero}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\boldsymbol{v}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{m}_{j,t-1} \label{eq:vmf}
\end{align}&lt;p&gt;captures the result of integrating $\Delta \mathbf{h}_{i,t}$ over the spins $\mathbf{s}_{t-1}$. Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^2\right)\right]_{\alpha=1},\label{eq:firstorderapproxreqs}
\end{equation}&lt;p&gt;so that we are encouraged to set $\boldsymbol{v}_{i,t}=0$ and hence $\boldsymbol{\theta}^{*}_{i,t} = \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}$, leading to the naive mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{ \mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \rVert^2 / R^2 }} } \label{eq:naivemvector}
\end{equation}&lt;p&gt;Looking ahead at the transformer-module correspondence in
, we squint our eyes and recognize a scaled sum of a residual connection and an attention term. No feed-forward terms though.&lt;/p&gt;
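&lt;p&gt;The naive mean-field equations above rest on the closed-form derivative Eq. \eqref{eq:firstorderphialpha}. A quick way to sanity-check such expressions is to compare them with a central finite difference along a straight line $\mathbf{h} + \alpha \Delta \mathbf{h}$. The NumPy sketch below does this for the large-$D$ $\boldsymbol{\varphi}$ with random vectors. The sizes, parameter values, and function names are arbitrary choices.&lt;/p&gt;

```python
import numpy as np


def gamma(h, beta, R):
    """gamma(||h||) = sqrt(1 + beta^2 ||h||^2 / R^2), keeping a trailing axis."""
    return np.sqrt(1 + beta**2 * np.sum(h**2, axis=-1, keepdims=True) / R**2)


def phi(h, beta, R):
    """Large-D mean magnetization, Eq. (largedevmag)."""
    return beta / (1 + gamma(h, beta, R)) * h


def dphi_dalpha(h, dh, beta, R):
    """Closed form of d phi(h + alpha * dh) / d alpha at alpha = 0."""
    g = gamma(h, beta, R)
    m = phi(h, beta, R)
    m_dot_dh = np.sum(m * dh, axis=-1, keepdims=True)
    return beta / (1 + g) * dh - beta / R**2 * m_dot_dh / g * m


def dphi_dalpha_fd(h, dh, beta, R, eps=1e-5):
    """Central finite-difference estimate of the same directional derivative."""
    return (phi(h + eps * dh, beta, R) - phi(h - eps * dh, beta, R)) / (2 * eps)


rng = np.random.default_rng(0)
D = 16
R, beta = np.sqrt(D / 2 - 1), 1.3  # ties R^2 = D/2 - 1 as in the text
h, dh = rng.normal(size=D), rng.normal(size=D)
print(np.max(np.abs(dphi_dalpha(h, dh, beta, R) - dphi_dalpha_fd(h, dh, beta, R))))
```

&lt;p&gt;The printed maximum deviation should be tiny, at the level of finite-difference and floating-point error.&lt;/p&gt;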
&lt;p&gt;Before moving on to the second-order approximation, let us end this section with an interesting observation about Eq. \eqref{eq:mfirstorderalphazero}. In
, we show that the variance matrix of a single spin in the large-$D$ limit equals a rank-1 perturbation of a diagonal matrix&lt;/p&gt;
\begin{align}
\mathrm{Var} [ \mathbf{s}_{i,t} ] &amp;= \frac{\mathbb{1}}{1+\gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right)} - \frac{ \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \otimes \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) }{ R^2 \gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right) }. \label{eq:spinvariance}
\end{align}&lt;p&gt;Taking the $\alpha \to 0$ limit of the above expressions, we can reinterpret Eq. \eqref{eq:mfirstorderalphazero} as the matrix-vector multiplication of the decoupled spin&amp;rsquo;s variance matrix with $\boldsymbol{v}_{i,t}$,&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = \beta \mathrm{Var} [ \mathbf{s}_{i,t} ] \boldsymbol{v}_{i,t}.
\end{align}&lt;h2 id="second-order-thouless-anderson-palmer-approximation"&gt;Second-order Thouless-Anderson-Palmer approximation&lt;/h2&gt;
&lt;p&gt;Let us now find out whether going to the second-order approximation spits out additional Onsager-like correction terms in the update equations for the magnetizations, which could play the role of feed-forward terms.&lt;/p&gt;
&lt;p&gt;Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} + \left[ \frac{\alpha^2}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}\right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^3\right)\right]_{\alpha=1}, \label{eq:secondorderconstraint}
\end{equation}&lt;p&gt;where the second-order derivative is given by&lt;/p&gt;
\begin{align}
\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha)}{\partial\alpha^2} = \int &amp;\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial^{2}\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha^2} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;+ 2\frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \nonumber \\\\
&amp;+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial^{2} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha^2} \Biggr) P( \mathbf{s}_{t-2} ). \label{eq:mhasecordder}
\end{align}&lt;p&gt;Evaluated at $\alpha=0$, the third term in the expression above drops out for the same reason as before. The factor $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(0)\right) = \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right)$ can be pulled out of the $\mathbf{s}_{t-1}$ integral, and the second $\alpha$-derivative of $\int \mathrm{d} \mathbf{s}_{t-1} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=1$ vanishes.&lt;/p&gt;
&lt;p&gt;The first term in Eq. \eqref{eq:mhasecordder} can be shown to look something like&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = &amp; \frac{\beta^2}{R^4} \frac{ 1+\gamma_{i,t}(\alpha) }{ \gamma_{i,t}(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta}{R^2} \frac{1}{\gamma_{i,t}(\alpha)} \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta}{R^2} \frac{1}{\gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha) } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t},
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma_{i,t} (\alpha) \equiv \gamma\left( \lVert \mathbf{h}_{i,t}(\alpha) \rVert \right) = \sqrt{1+\beta^2 \lVert \mathbf{h}_{i,t}(\alpha) \rVert^2 / R^2 },
\end{align}&lt;p&gt;which, after substituting the first-order derivative Eq. \eqref{eq:firstorderphialpha}, simplifies to&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = &amp; \frac{\beta^2}{R^4} \frac{ 1+3\gamma_{i,t}(\alpha) }{ \gamma_{i,t}(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha)} \left( \Delta \mathbf{h}_{i,t} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{2}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t} . \label{eq:secondorderphialpha}
\end{align}&lt;p&gt;The second term in Eq. \eqref{eq:mhasecordder} contains non-vanishing contributions in the $\alpha \to 0$ limit coming from the $\sum_{j=1}^{N} J_{ij} \mathbf{s}_{j, t-1}$ terms in $\Delta \mathbf{h}_{i,t}$. One can show that the surviving terms in the integrand are proportional to&lt;/p&gt;
\begin{align}
\sum_{j} J_{ij} \Biggl( &amp;\frac{2 \beta^2}{1+\gamma_{i,t}(\alpha)} \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \nonumber \\\\
&amp;- \frac{2 \beta^2}{R^2 \gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \Biggr),
\end{align}&lt;p&gt;which we can ignore since they are $\mathcal{O}(\alpha)$ on their own, and thus $\mathcal{O}(\alpha^3)$ when multiplied with $\alpha^2$ in the second-order approximation.&lt;/p&gt;
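&lt;p&gt;The simplified expression Eq. \eqref{eq:secondorderphialpha} involves enough bookkeeping that a numerical cross-check is reassuring. The NumPy sketch below compares it with a central second difference of $\boldsymbol{\varphi}(\mathbf{h} + \alpha \Delta \mathbf{h})$ in $\alpha$. It assumes the large-$D$ $\boldsymbol{\varphi}$ and arbitrary random test vectors. The function names are illustrative.&lt;/p&gt;

```python
import numpy as np


def gamma(h, beta, R):
    """gamma(||h||) = sqrt(1 + beta^2 ||h||^2 / R^2), keeping a trailing axis."""
    return np.sqrt(1 + beta**2 * np.sum(h**2, axis=-1, keepdims=True) / R**2)


def phi(h, beta, R):
    """Large-D mean magnetization, Eq. (largedevmag)."""
    return beta / (1 + gamma(h, beta, R)) * h


def d2phi_dalpha2(h, dh, beta, R):
    """Closed form of d^2 phi(h + alpha * dh) / d alpha^2 at alpha = 0."""
    g = gamma(h, beta, R)
    m = phi(h, beta, R)
    m_dot_dh = np.sum(m * dh, axis=-1, keepdims=True)
    dh_sq = np.sum(dh * dh, axis=-1, keepdims=True)
    c = beta**2 / (R**2 * (g**2 + g))
    return (
        beta**2 / R**4 * (1 + 3 * g) / g**3 * m_dot_dh**2 * m  # (phi . dh)^2 phi term
        - c * dh_sq * m  # (dh . dh) phi term
        - 2 * c * m_dot_dh * dh  # (phi . dh) dh term
    )


def d2phi_dalpha2_fd(h, dh, beta, R, eps=1e-4):
    """Central second-difference estimate of the same quantity."""
    plus, minus = phi(h + eps * dh, beta, R), phi(h - eps * dh, beta, R)
    return (plus - 2 * phi(h, beta, R) + minus) / eps**2


rng = np.random.default_rng(1)
D = 16
R, beta = np.sqrt(D / 2 - 1), 1.3  # ties R^2 = D/2 - 1 as in the text
h, dh = rng.normal(size=D), rng.normal(size=D)
print(np.max(np.abs(d2phi_dalpha2(h, dh, beta, R) - d2phi_dalpha2_fd(h, dh, beta, R))))
```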
&lt;p&gt;Before taking the $\alpha \to 0$ limit of whatever is left in Eq. \eqref{eq:mhasecordder}, we list a few useful tricks to make the evaluation easier. First of all, we use Eq. \eqref{eq:vmf} to introduce the following sneaky substitution&lt;/p&gt;
\begin{align}
\Delta \mathbf{h}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} = \boldsymbol{v}_{i,t} + \sum_{j=1}^{N} J_{ij} \left( \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \right),
\end{align}&lt;p&gt;which conveniently separates terms with fluctuating spin variables from magnetizations that can be pulled out of the integrals. Secondly, all terms that contain only one spin variable with a dependence looking like $\mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1}$ drop out because, schematically,&lt;/p&gt;
\begin{align}
\mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \overset{\int \mathrm{d} \mathbf{s}_{t-1}}{\to} \boldsymbol{\varphi}(\mathbf{h}_{j,t-1}(\alpha)) - \mathbf{m}_{j,t-1} \overset{\alpha \to 0}{\to} \mathbf{0}.
\end{align}&lt;p&gt;Thirdly, since the $\alpha \to 0$ limit decouples all spins $\mathbf{s}_{t-1}$, any term containing dot products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \cdot (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$ of two spin variables is zero for $j \neq k$ and equal to $R^2 - \mathbf{m}^2_{j,t-1}$ for $j=k$. We will also encounter terms containing (tensor contractions with) outer products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \otimes (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$, which we can think of as projection operators. For $j \neq k$, these and similar terms again evaluate to zero, while, for $j=k$, we get the variance contributions we mentioned previously in Eq. \eqref{eq:spinvariance} at the end of the previous section.&lt;/p&gt;
&lt;p&gt;Finally, we take the $\alpha \to 0$ limit of Eq. \eqref{eq:mhasecordder} only to end up with the following mess:&lt;/p&gt;
\begin{align}
&amp;\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \label{eq:secondordercorrections} \\\\
&amp;\hspace{-1em}\frac{\beta^2}{R^4} \frac{1+3\gamma_{i,t}(0)}{\gamma_{i,t}(0)^3} \left( \left( \mathbf{m}_{i,t} \cdot \mathbf{v}_{i,t} \right)^2 + \sum_{j} J_{ij}^2 \left( \frac{\mathbf{m}_{i,t}^2}{1+\gamma_{j,t-1}(0)} - \frac{\left(\mathbf{m}_{i,t}\cdot\mathbf{m}_{j,t-1}\right)^2}{R^2 \gamma_{j,t-1}(0)} \right) \right) \mathbf{m}_{i,t} \nonumber \\\\
&amp;\hspace{-1em}- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}^2 (0) + \gamma_{i,t}(0)} \left( \mathbf{v}_{i,t}^2 + \sum_{j} J_{ij}^2 \left( R^2 - \mathbf{m}_{j,t-1}^2 \right) \right) \mathbf{m}_{i,t} \nonumber \\\\
&amp;\hspace{-1em}- \frac{\beta^2}{R^2} \frac{2}{\gamma_{i,t}^2 (0) + \gamma_{i,t}(0)} \Biggl( \mathbf{v}_{i,t} \otimes \mathbf{v}_{i,t} + \sum_{j} J_{ij}^2 \left( \frac{\mathbb{1}}{1+\gamma_{j,t-1}(0)} - \frac{\mathbf{m}_{j,t-1}\otimes\mathbf{m}_{j,t-1}}{R^2 \gamma_{j,t-1}(0)} \right) \Biggr) \mathbf{m}_{i,t} \nonumber
\end{align}&lt;p&gt;At this point, it is too late. We should have remembered that the second-order approximation lives in the neighborhood of the first-order approximation. We probably ended up doing too much work by taking terms into account that are of higher order in $\alpha$. We can always drop terms later on if it turns out they are negligible at $\mathcal{O}(\alpha^2)$.&lt;/p&gt;
&lt;p&gt;To get to the second-order mean-field equations for the magnetizations, we have to solve Eq. \eqref{eq:secondorderconstraint} for the optimal parameters $\boldsymbol{\theta}^{*}_{i,t}$, i.e.,&lt;/p&gt;
\begin{equation}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right).
\end{equation}&lt;p&gt;Let us substitute $\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha}$ from Eq. \eqref{eq:mfirstorderalphazero} but keep $\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}$ for generality,&lt;/p&gt;
\begin{align}
\beta \left( \frac{\mathbb{1}}{1+\gamma_{i,t}(0)} - \frac{\mathbf{m}_{i,t}\otimes\mathbf{m}_{i,t}}{R^2 \gamma_{i,t}(0)} \right) \mathbf{v}_{i,t} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right),
\end{align}&lt;p&gt;so that we can then isolate $\boldsymbol{\theta}_{i,t}$ in $\mathbf{v}_{i,t}$ to find&lt;/p&gt;
\begin{align}
\boldsymbol{\theta}_{i,t} = \mathbf{x}_{i,t} &amp;+ \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \nonumber \\\\
&amp;+ \frac{1+\gamma_{i,t}(0)}{2\beta} \left( \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} + \frac{\mathbf{m}_{i,t} \cdot \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}}{\frac{R^2 \gamma_{i,t}(0)}{1+\gamma_{i,t}(0)} - \mathbf{m}_{i,t}^2} \mathbf{m}_{i,t} \right),\label{eq:ftheta}
\end{align}&lt;p&gt;where we have used the
to compute the inverse of the variance matrix. Since the expression on the right-hand side &lt;em&gt;also&lt;/em&gt; depends on $\boldsymbol{\theta}_{i,t}$, we seem to have stumbled upon a set of fixed-point equations which we should solve for $\boldsymbol{\theta}^{*}_{i,t}$,&lt;/p&gt;
\begin{align}
\boldsymbol{\theta}_{i,t} = \mathbf{f} (\boldsymbol{\theta}_{i,t}, \mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1}), \label{eq:thetafp}
\end{align}&lt;p&gt;where the function $\mathbf{f}$ is given by the right-hand side of Eq. \eqref{eq:ftheta}. The second-order mean-field equations then become &lt;em&gt;yet another&lt;/em&gt; set of fixed-point equations&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left(\boldsymbol{\theta}^{*}_{i,t}(\mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1})\right)
\end{equation}&lt;p&gt;because of the dependence of $\boldsymbol{\theta}^{*}_{i,t}$ on $\mathbf{m}_{i,t}$. Similar to the binary TAP approximation Eq. \eqref{eq:tapm}, this dependency suggests that we should solve for fixed-point magnetizations $\mathbf{m}^{*}_{i,t}$. However, in contrast to the binary case, the dependence here is &lt;em&gt;implicit&lt;/em&gt; since $\boldsymbol{\theta}^{*}_{i,t}$ is itself obtained from solving fixed-point equations Eq. \eqref{eq:thetafp}, which, in turn, also depend on $\mathbf{m}_{i,t}$.&lt;/p&gt;
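&lt;p&gt;The inverse of the variance matrix that enters Eq. \eqref{eq:ftheta} follows from the Sherman-Morrison formula for a scaled identity minus a rank-1 term. The short NumPy sketch below checks it against a generic matrix inverse. It assumes the decoupled large-$D$ variance of Eq. \eqref{eq:spinvariance} with $\mathbf{m}_{i,t} = \boldsymbol{\varphi}(\boldsymbol{\theta}_{i,t})$. Sizes and parameter values are arbitrary.&lt;/p&gt;

```python
import numpy as np


def gamma(theta, beta, R):
    """gamma(||theta||) for a single field vector."""
    return np.sqrt(1 + beta**2 * theta @ theta / R**2)


def variance_matrix(theta, beta, R):
    """Decoupled large-D spin variance, Eq. (spinvariance) at alpha = 0."""
    g = gamma(theta, beta, R)
    m = beta / (1 + g) * theta
    return np.eye(theta.shape[0]) / (1 + g) - np.outer(m, m) / (R**2 * g)


def variance_inverse(theta, beta, R):
    """Sherman-Morrison inverse of the variance matrix, as it enters Eq. (ftheta)."""
    g = gamma(theta, beta, R)
    m = beta / (1 + g) * theta
    denom = R**2 * g / (1 + g) - m @ m
    return (1 + g) * (np.eye(theta.shape[0]) + np.outer(m, m) / denom)


rng = np.random.default_rng(0)
D = 32
R, beta = np.sqrt(D / 2 - 1), 0.9
theta = rng.normal(size=D)
V, V_inv = variance_matrix(theta, beta, R), variance_inverse(theta, beta, R)
print(np.max(np.abs(V_inv @ V - np.eye(D))))
```

&lt;p&gt;Multiplying the second-order condition by this inverse is what isolates $\mathbf{v}_{i,t}$, and hence $\boldsymbol{\theta}_{i,t}$, in Eq. \eqref{eq:ftheta}. The printed residual should vanish up to floating-point error.&lt;/p&gt;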
&lt;p&gt;The problem setup looks like a
, where the solutions to the inner-level fixed-point equations are fed as parameters to the outer-level fixed-point equations. Because of the hierarchical relationship and the implicit dependence of the outer solution on the inner problem&amp;rsquo;s parameters, bi-level optimization can be computationally demanding and potentially unstable. Let us try to sidestep this dreadfulness by writing all instances of $\boldsymbol{\theta}_{i,t}$ in Eq. \eqref{eq:ftheta} in terms of $\mathbf{m}_{i,t}$ by inverting Eq. \eqref{eq:largedevmag} so that, for $\mathbf{m}^2_{i,t} &lt; R^2$,&lt;/p&gt;
\begin{equation}
\boldsymbol{\theta}_{i,t} = \frac{2 R^2}{\beta \left( R^2 - \mathbf{m}^2_{i,t} \right)} \mathbf{m}_{i,t},\label{eq:invphi}
\end{equation}&lt;p&gt;leading to a set of fixed-point equations in terms of only $\mathbf{m}_{i,t}$,&lt;/p&gt;
\begin{equation}
\boxed{\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left( \mathbf{f} (\mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1})\right) } \label{eq:tapmvector}
\end{equation}&lt;p&gt;Looking ahead at the transformer-module correspondence in
, we recognize a scaled sum of a residual connection, an attention term, and a self-consistent expression in terms of magnetizations and couplings taking on the role of the feed-forward network. Interestingly, these second-order correction terms require &lt;em&gt;no additional free parameters&lt;/em&gt; since they are &lt;em&gt;fully determined by the mean-field structure&lt;/em&gt; of the underlying spin model.&lt;/p&gt;
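&lt;p&gt;Because Eq. \eqref{eq:tapmvector} hinges on replacing $\boldsymbol{\theta}_{i,t}$ by the inverse map of Eq. \eqref{eq:invphi}, it is worth checking that this formula really inverts the large-$D$ $\boldsymbol{\varphi}$. A minimal NumPy round-trip sketch, with arbitrary sizes and deliberately strong random fields:&lt;/p&gt;

```python
import numpy as np


def phi(theta, beta, R):
    """Large-D mean magnetization, Eq. (largedevmag)."""
    g = np.sqrt(1 + beta**2 * np.sum(theta**2, axis=-1, keepdims=True) / R**2)
    return beta / (1 + g) * theta


def inv_phi(m, beta, R):
    """Eq. (invphi); only meaningful while ||m|| stays strictly below R."""
    m_sq = np.sum(m**2, axis=-1, keepdims=True)
    return 2 * R**2 / (beta * (R**2 - m_sq)) * m


rng = np.random.default_rng(0)
N, D, beta = 4, 64, 1.5
R = np.sqrt(D / 2 - 1)
theta = 10.0 * rng.normal(size=(N, D))  # strong fields push ||m|| toward R
m = phi(theta, beta, R)
print(np.linalg.norm(m, axis=-1) / R)  # all entries stay below 1
print(np.max(np.abs(inv_phi(m, beta, R) - theta)))
```

&lt;p&gt;The norm ratio stays below one because $\lVert \boldsymbol{\varphi}(\boldsymbol{\theta}) \rVert = R \, y / (1 + \sqrt{1+y^2})$ with $y = \beta \lVert \boldsymbol{\theta} \rVert / R$. The precondition $\mathbf{m}^2_{i,t} &lt; R^2$ of Eq. \eqref{eq:invphi} therefore holds automatically for magnetizations produced by $\boldsymbol{\varphi}$. The &lt;code&gt;_inv_phi&lt;/code&gt; helper in the implementation below evaluates the same formula.&lt;/p&gt;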
&lt;h2 id="a-simple-jax-implementation-1"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We now turn to a JAX implementation of the mean-field time evolution of the magnetizations according to the vector-spin model introduced in the previous sections. Compared to the binary-spin simulations of
, we will not attempt to precisely tune the vector-spin model since computing its critical temperature and quirky phase-diagram properties is well beyond the scope of this work. We will instead take an empirical approach and play around with a numerical implementation to figure out what works. Along the way, we provide some physical intuition.&lt;/p&gt;
&lt;h3 id="simulating-magnetization-trajectories-1"&gt;Simulating magnetization trajectories&lt;/h3&gt;
&lt;p&gt;The JAX reference implementation looks very similar to the binary-spin case. Essentially, we have to keep track of an additional vector dimension and replace the update equations with the vector equivalents introduced in the previous sections. We deliberately do not tune the hyperparameters of the fixed-point solver &lt;code&gt;AndersonAcceleration&lt;/code&gt;, so that exploratory results do not hinge on carefully chosen solver settings.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax.numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;jnp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;jaxopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (39).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (38).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_naive_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (47).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (64).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (58).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d, j d, i e, j e -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d, i f -&amp;gt; i f&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, i d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d, j d, j f -&amp;gt; i f&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (61).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;ff&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (65).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="playing-with-parameter-scales-an-exploration"&gt;Playing with parameter scales: an exploration&lt;/h3&gt;
&lt;p&gt;To get a feel for the complexity, let us visualize an $N=64$ sample of a coupling matrix $\mathbf{J} \in \mathbb{R}^{N \times N}$ drawn from $\mathcal{N}\left( 0, 1/N \right)$ using a visually appealing yet utterly pointless ball-of-yarn plot:&lt;/p&gt;
&lt;img src="vector_plot_1.png" width="500px"/&gt;
&lt;p&gt;We randomly initialize the external magnetic fields $\mathbf{x} \in \mathbb{R}^{N \times D}$ and the coupling matrix $\mathbf{J} \in \mathbb{R}^{N \times N}$ by drawing from $\mathcal{N}\left( 0, 1\right)$ and $\mathcal{N}\left( 0, 1/N \right)$, respectively. We then simulate $N=1024$ vector spins of dimension $D=512$ at inverse temperature $\beta=1.0$ for $t=20$ time steps, starting from an initial state $\mathbf{m}_{0} \in \mathbb{R}^{N \times D}$ of all-ones vectors. We normalize all $\mathbf{x}$ vectors to lie on the spherical shell of radius $R$, so that $\mathbf{x}_{i} \to R \mathbf{x}_{i} / \lVert\mathbf{x}_{i}\rVert$. We apply the same external magnetic fields at all time steps ($\mathbf{x}_{t} \equiv \mathbf{x}$, $\forall t \geq 0$), so that the probing of the system is time-independent and relentless.&lt;/p&gt;
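&lt;p&gt;Here is a minimal sketch of this setup. It assumes that the naive first-order update takes the form $\mathbf{m}_{t} = \boldsymbol{\varphi}\left(\mathbf{x} + \mathbf{J}\mathbf{m}_{t-1}\right)$, with $\boldsymbol{\varphi}$ the large-dimension magnetization map of Eq. \eqref{eq:largedevmag} and $R=\sqrt{D/2-1}$. This is consistent with the initial guess passed to &lt;code&gt;vector_tap_fp&lt;/code&gt; further down. The helpers &lt;code&gt;phi&lt;/code&gt;, &lt;code&gt;init_system&lt;/code&gt;, and &lt;code&gt;naive_mean_field&lt;/code&gt; are hypothetical names, not the post's own code, and the second-order TAP corrections are left out.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    """Row-wise magnetization m = beta * theta / (1 + sqrt(1 + beta^2 ||theta||^2 / R^2))."""
    sq_norm = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * sq_norm / R**2))


def init_system(key, N, D, field_radius):
    """Fields x ~ N(0, 1) rescaled onto the shell of radius `field_radius`; couplings J ~ N(0, 1/N)."""
    key_x, key_J = jax.random.split(key)
    x = jax.random.normal(key_x, (N, D))
    x = field_radius * x / jnp.linalg.norm(x, axis=-1, keepdims=True)
    J = jax.random.normal(key_J, (N, N)) / jnp.sqrt(N)
    return x, J


def naive_mean_field(m0, x, J, beta, R, num_steps):
    """Iterate m_t = phi(x + J m_{t-1}), applying the same fields x at every step."""

    def step(m_prev, _):
        m = phi(x + J @ m_prev, beta, R)
        return m, m  # new carry, and the value recorded in the trajectory

    m_final, trajectory = jax.lax.scan(step, m0, None, length=num_steps)
    return m_final, trajectory  # trajectory: (num_steps, N, D)


N, D, beta, num_steps = 1024, 512, 1.0, 20
R = jnp.sqrt(D / 2 - 1)  # vector-spin radius, about 15.9687 for D = 512
x, J = init_system(jax.random.PRNGKey(0), N, D, field_radius=R)
m0 = jnp.ones((N, D))  # all-ones initial magnetizations
m_final, trajectory = naive_mean_field(m0, x, J, beta, R, num_steps)
```

&lt;p&gt;Changing &lt;code&gt;beta&lt;/code&gt;, &lt;code&gt;field_radius&lt;/code&gt;, or the scale of &lt;code&gt;J&lt;/code&gt; is enough to probe the other parameter regimes discussed in this section.&lt;/p&gt;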
&lt;p&gt;We first consider the naive first-order mean-field update equations. To visualize a set of vectors evolving in time, we track their directions relative to reference states using cosine similarities, and their magnitudes using Euclidean norms.&lt;/p&gt;
&lt;img src="vector_plot_2.png" width="300px"/&gt;
&lt;p&gt;The top plot shows the cosine similarities between individual magnetization trajectories $\mathbf{m}_{i,t}$ and three references: $\mathbf{m}_{i,t-1}$ (green, magnetizations at the previous time step, to track convergence), $\mathbf{m}_{0}$ (yellow, magnetizations at the initial time step, to track drift from initial conditions), and $\mathbf{x}_{i}$ (blue, time-independent external magnetic fields, to track alignment with the &amp;ldquo;residual stream&amp;rdquo;). The bottom plot tracks the norms of $\mathbf{m}_{i,t}$ during time evolution. From these metrics, we observe convergence to what looks like a &lt;em&gt;non-equilibrium / near-equilibrium steady state&lt;/em&gt; (NESS), with magnetizations remaining dynamically stable at the mean-field level.&lt;/p&gt;
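&lt;p&gt;The diagnostics in these plots can be computed from a stacked trajectory of shape $(T, N, D)$. The sketch below is one way to do it. The function names, and the choice of $\mathbf{m}_{0}$ as the predecessor of the first step, are our own assumptions. The helper &lt;code&gt;summarize_over_sites&lt;/code&gt; returns the across-site mean, minimum, and maximum needed for shaded trajectory plots.&lt;/p&gt;

```python
import jax.numpy as jnp


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity along the last axis; broadcasts over leading axes."""
    dot = jnp.sum(a * b, axis=-1)
    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    return dot / jnp.maximum(norms, eps)


def track_metrics(trajectory, m0, x):
    """Per-site diagnostics for a trajectory of shape (T, N, D); each entry has shape (T, N).

    prev:  cos(m_{i,t}, m_{i,t-1}), with m0 standing in for m_{i,t-1} at the first step
    init:  cos(m_{i,t}, m_{i,0}), drift from the initial conditions
    field: cos(m_{i,t}, x_i), alignment with the time-independent external fields
    norm:  ||m_{i,t}||
    """
    previous = jnp.concatenate([m0[None], trajectory[:-1]], axis=0)
    return {
        "prev": cosine_similarity(trajectory, previous),
        "init": cosine_similarity(trajectory, m0[None]),
        "field": cosine_similarity(trajectory, x[None]),
        "norm": jnp.linalg.norm(trajectory, axis=-1),
    }


def summarize_over_sites(metric):
    """Mean, minimum, and maximum over sites for a (T, N) metric, e.g. for shaded plots."""
    return metric.mean(axis=1), metric.min(axis=1), metric.max(axis=1)


# Tiny hand-made example: three sites in D = 4, two time steps.
m0 = jnp.ones((3, 4))
x = jnp.ones((3, 4))
trajectory = jnp.stack([2.0 * m0, 3.0 * m0])
metrics = track_metrics(trajectory, m0, x)
mean_norm, min_norm, max_norm = summarize_over_sites(metrics["norm"])
```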
&lt;p&gt;To compare the naive first-order mean-field update equations to the second-order Thouless-Anderson-Palmer (TAP) ones, we plot the mean magnetization trajectories across all sites and shade the range between the minimum and maximum values.&lt;/p&gt;
&lt;img src="vector_plot_3.png" width="500px"/&gt;
&lt;p&gt;We observe that, for our particular choice of parameters, the final TAP magnetizations differ slightly from the naive mean-field ones. The Onsager correction term seems to account for at least some correlations, lowering the local effective mean field and hence the magnitude of the magnetizations. If we lower the temperature by setting $\beta = 2.0$ while keeping all other parameters fixed, the difference becomes more pronounced:&lt;/p&gt;
&lt;img src="vector_plot_4.png" width="500px"/&gt;
&lt;p&gt;Lowering the temperature further while keeping all other parameters fixed leads to convergence issues for the TAP equations. Suppose we go back to $\beta=1.0$ but (1) increase the random interaction strengths by doubling the elements of the coupling matrix and (2) reduce the influence of the random external magnetic fields by normalizing all $\mathbf{x}$ vectors to lie on the unit sphere. We then end up in a regime where the naive mean-field equations have trouble converging, whereas the TAP magnetizations quickly settle into a small-norm fixed-point solution:&lt;/p&gt;
&lt;img src="vector_plot_5.png" width="500px"/&gt;
&lt;h3 id="playing-with-parameter-scales-an-explanation"&gt;Playing with parameter scales: an explanation&lt;/h3&gt;
&lt;p&gt;To better understand the behavior of the system, we focus on four quantities: the inverse temperature $\beta$, the magnitudes $\lVert\mathbf{x}_{i,t}\rVert$ of the external magnetic fields, the scale of the coupling matrix elements $J_{ij}$, and the vector-spin radius $R=\sqrt{D/2-1}$. The latter depends only on the dimension $D$ and provides a natural length scale. In spin-glass mean-field theory, the random coupling matrix is usually chosen to have variance $1/N$ to ensure a proper thermodynamic limit. The magnitudes of the external magnetic fields determine to what extent the vector spins will try to align with their imposed external environment or yield to the influence of their neighbours. The scales of the couplings and the fields should be balanced so that meaningful competition between the external magnetic fields and the intrinsic spin-spin interactions can occur. Finally, the system&amp;rsquo;s overall behavior is further governed by the thermal noise introduced via the inverse temperature $\beta$.&lt;/p&gt;
&lt;p&gt;Revisiting the magnetization equation Eq. \eqref{eq:largedevmag},&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right) = \frac{\beta}{1+\sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }} \boldsymbol{\theta}_{i,t},
\end{equation}&lt;p&gt;we observe that the infinite-temperature limit $\beta \to 0$ pushes the magnitude of the magnetization to $0$. In the zero-temperature limit $\beta \to \infty$, by contrast, the magnetization snaps to the spherical shell of radius $R$ for any nonzero $\boldsymbol{\theta}$. Below, we plot the norm of this expression as a function of $\lVert\boldsymbol{\theta}\rVert$ for different values of $\beta$ in $D=512$ dimensions. The dashed horizontal and vertical lines indicate the value $R=\sqrt{D/2-1}\approx 15.9687$.&lt;/p&gt;
&lt;img src="vector_plot_6.png" width="450px"/&gt;
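&lt;p&gt;Since $\boldsymbol{\varphi}(\boldsymbol{\theta})$ is parallel to $\boldsymbol{\theta}$, these curves follow directly from the norm of the expression above. The short sketch below evaluates it on a grid of $\lVert\boldsymbol{\theta}\rVert$ values for a few illustrative $\beta$ values (our choice, not necessarily those in the plot). It also checks two limits that follow from the formula. For small $\beta$, the magnetization is approximately $\beta\boldsymbol{\theta}/2$. For large $\beta\lVert\boldsymbol{\theta}\rVert/R$, it approaches $R\,\boldsymbol{\theta}/\lVert\boldsymbol{\theta}\rVert$, which is an RMSNorm of $\boldsymbol{\theta}$ without learnable gain, rescaled by $R/\sqrt{D}$. The function names are hypothetical.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    """Row-wise magnetization m = beta * theta / (1 + sqrt(1 + beta^2 ||theta||^2 / R^2))."""
    sq_norm = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * sq_norm / R**2))


def phi_norm(theta_norm, beta, R):
    """||phi(theta)|| as a function of ||theta|| alone."""
    return beta * theta_norm / (1.0 + jnp.sqrt(1.0 + beta**2 * theta_norm**2 / R**2))


def rms_norm(theta, eps=1e-8):
    """RMSNorm without a learnable gain: theta / sqrt(mean(theta^2))."""
    return theta / jnp.sqrt(jnp.mean(theta**2, axis=-1, keepdims=True) + eps)


D = 512
R = jnp.sqrt(D / 2 - 1)  # about 15.9687
theta_norms = jnp.linspace(0.0, 100.0, 201)
curves = {beta: phi_norm(theta_norms, beta, R) for beta in (0.1, 1.0, 2.0, 10.0, 100.0)}

# Large-beta limit: phi(theta) approaches (R / sqrt(D)) * rms_norm(theta).
theta = jax.random.normal(jax.random.PRNGKey(0), (4, D))
large_beta_m = phi(theta, 1e4, R)
rms_limit = (R / jnp.sqrt(D)) * rms_norm(theta)
```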
&lt;p&gt;This plot partly explains why the TAP equations start showing convergence issues at lower temperatures. Large values of $\beta$ push the norm of the magnetizations towards $R$, which in turn makes $\boldsymbol{\theta}$ blow up because of the $R^2-\lVert\mathbf{m}_{i,t}\rVert^2$ factors in the denominators of Eq. \eqref{eq:ftheta} and Eq. \eqref{eq:invphi}. This is no surprise, since the Plefka expansion is in fact a high-temperature expansion. Indeed, if we write out the mean-field update equations, we find that the first-order terms scale as $\beta$ and the second-order terms as $\beta^2$. Additionally, we know from the mean-field theory of binary spin glasses that the TAP equations break down when crossing the so-called de Almeida-Thouless line (AT line) in the $(\beta, x)$ phase diagram. Assuming an
, it might be worth rederiving the Onsager term, as was done for binary spins in
, so that its time indices are better suited to convergence. But even then we would still not be able to cross the AT line and find mean-field solutions at lower temperatures.&lt;/p&gt;
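&lt;p&gt;To see the blow-up explicitly, one can solve the magnetization equation for $\boldsymbol{\theta}$. This gives $\boldsymbol{\theta} = 2R^2\mathbf{m} / \left(\beta\left(R^2 - \lVert\mathbf{m}\rVert^2\right)\right)$, valid for $\lVert\mathbf{m}\rVert$ strictly below $R$. We derived this form directly from the magnetization equation above and assume it agrees with Eq. \eqref{eq:invphi} up to notation. The sketch checks the round trip and shows how fast $\lVert\boldsymbol{\theta}\rVert$ grows as $\lVert\mathbf{m}\rVert \to R$. The name &lt;code&gt;phi_inverse&lt;/code&gt; is hypothetical.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    """Row-wise magnetization m = beta * theta / (1 + sqrt(1 + beta^2 ||theta||^2 / R^2))."""
    sq_norm = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * sq_norm / R**2))


def phi_inverse(m, beta, R):
    """theta = 2 R^2 m / (beta (R^2 - ||m||^2)); only defined for ||m|| strictly below R."""
    sq_norm = jnp.sum(m**2, axis=-1, keepdims=True)
    return 2.0 * R**2 * m / (beta * (R**2 - sq_norm))


D, beta = 512, 2.0
R = jnp.sqrt(D / 2 - 1)

# Round trip: phi_inverse undoes phi.
theta = jax.random.normal(jax.random.PRNGKey(0), (8, D))
theta_back = phi_inverse(phi(theta, beta, R), beta, R)

# Blow-up: the ||theta|| needed to reach ||m|| = fraction * R.
fractions = jnp.array([0.9, 0.99, 0.999])
direction = jnp.zeros((3, D)).at[:, 0].set(1.0)  # unit vectors along the first axis
theta_needed = jnp.linalg.norm(phi_inverse(fractions[:, None] * R * direction, beta, R), axis=-1)
```

&lt;p&gt;With $\beta = 2$ in $D = 512$, pushing $\lVert\mathbf{m}\rVert$ from $0.9R$ to $0.999R$ raises the required $\lVert\boldsymbol{\theta}\rVert$ from roughly $76$ to roughly $8000$.&lt;/p&gt;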
&lt;p&gt;But we have to ask ourselves whether we actually care about this low-temperature failure mode for our purposes. Do we want a spin-transformer module to inhabit a complex spin-glass phase full of local minima containing frozen disordered spins that cannot respond to external magnetic fields? No. We would like our system to be able to fluidly and adaptively respond to its environment.&lt;/p&gt;
&lt;h1 id="spin-transformer-modules-a-family-of-transformer-like-modules"&gt;Spin-transformer modules: a family of transformer-like modules&lt;/h1&gt;
&lt;p&gt;In this final section, we propose a physics-inspired class of transformer modules based on the mean-field update equations for the vector-spin magnetizations derived in the previous section. We highlight conceptual similarities, physical interpretations, and potential benefits of exploiting spin-model structure to reduce parameter count.&lt;/p&gt;
&lt;h2 id="connecting-the-dots"&gt;Connecting the dots&lt;/h2&gt;
&lt;p&gt;Following
and
, we interpret a transformer module as a differentiable vector-spin system that is driven by data and whose collective behavior can be shaped through training. Intuitively, there is little difference here compared to the work mentioned above: we still probe a spin system and observe its response. But, technically and conceptually, the shift to dynamical mean-field expressions enables us to solidify the correspondence by moving past symmetric coupling matrices and equilibrium free energies.&lt;/p&gt;
&lt;p&gt;We define a &lt;em&gt;spin-transformer module&lt;/em&gt; as a wrapper around a vector-spin model where module inputs $\mathbf{x} \in \mathbb{R}^{N \times D}$ get routed to external magnetic fields. Inside the module, we evolve a set of initial magnetizations in time using either the first-order (Eq. \eqref{eq:naivemvector}) or the second-order (Eq. \eqref{eq:tapmvector}) mean-field update equations. Only the second-order update equations exhibit feed-forward-like corrections. We choose to relentlessly apply the same external magnetic fields at all time steps ($\mathbf{x}_{t} \equiv \mathbf{x}$, $\forall t \geq 0$) and construct input-dependent couplings using the row-stochastic attention matrix,&lt;/p&gt;
\begin{equation}
\mathbf{J}(\mathbf{x}) = \mathrm{softmax}\left( \frac{\mathbf{x} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \mathbf{x}^{T}}{\sqrt{D}} \right), \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;where $\boldsymbol{W}_{\boldsymbol{Q}}$ and $\boldsymbol{W}_{\boldsymbol{K}}$ denote linear query- and key-mappings. Adding bias terms to these linear transformations would introduce intrinsic interactions between the spins that persist even in the absence of the external magnetic fields. Essentially, we recognize the softmax attention matrix as a parametrized flavor of the (asymmetric) coupling matrix of a vector-spin model. The external magnetic fields thus not only affect the vector spins directly, but also indirectly by altering the interaction strengths between them. This setup leads to a highly adaptive system where the interaction landscape itself is dynamically shaped by the inputs.&lt;/p&gt;
&lt;img src="arch_comparison.png" alt="Comparison between vanilla transformer module and spin-transformer module" width="550px"/&gt;
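&lt;p&gt;As a rough sketch, the couplings of Eq. \eqref{eq:softmaxcouplings} and a fixed number of mean-field updates could be wired together as follows. Several simplifying assumptions are made for illustration. The helper names are hypothetical, and the query-key width equals $D$. Initial magnetizations are zero. The naive update $\mathbf{m}_{t} = \boldsymbol{\varphi}(\mathbf{x} + \mathbf{J}(\mathbf{x})\mathbf{m}_{t-1})$ is used in place of the TAP equations.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    """Row-wise magnetization m = beta * theta / (1 + sqrt(1 + beta^2 ||theta||^2 / R^2))."""
    sq_norm = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * sq_norm / R**2))


def softmax_couplings(x, W_Q, W_K):
    """J(x) = softmax(x W_Q W_K^T x^T / sqrt(D)), normalized along each row."""
    D = x.shape[-1]
    logits = (x @ W_Q) @ (x @ W_K).T / jnp.sqrt(D)
    return jax.nn.softmax(logits, axis=-1)


def spin_module_forward(x, W_Q, W_K, beta, num_steps):
    """Route inputs to the fields, build J(x), and run naive mean-field updates."""
    D = x.shape[-1]
    R = jnp.sqrt(D / 2 - 1)
    J = softmax_couplings(x, W_Q, W_K)

    def step(m, _):
        return phi(x + J @ m, beta, R), None

    m0 = jnp.zeros_like(x)  # hypothetical choice of initial magnetizations
    m, _ = jax.lax.scan(step, m0, None, length=num_steps)
    return m, J


N, D = 8, 16
key_x, key_q, key_k = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (N, D))
W_Q = jax.random.normal(key_q, (D, D)) / jnp.sqrt(D)
W_K = jax.random.normal(key_k, (D, D)) / jnp.sqrt(D)
m_out, J = spin_module_forward(x, W_Q, W_K, beta=1.0, num_steps=50)
```

&lt;p&gt;For this naive update, a short argument bounds convergence. $\boldsymbol{\varphi}$ has slope at most $\beta/2$, and each row of $\mathbf{J}(\mathbf{x})$ is a convex combination. So for $\beta$ below $2$ the iteration contracts in the largest per-site norm, whatever the inputs. This is one concrete instance of choosing couplings and inverse temperature sensibly.&lt;/p&gt;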
&lt;p&gt;What does the spin-transformer module return? The within-module time evolution is said to converge when the mean magnetizations collectively reach some kind of &lt;em&gt;non-equilibrium / near-equilibrium steady-state&lt;/em&gt; (NESS). Convergence is not guaranteed a priori and requires sensibly chosen couplings, inverse temperature, and normalizations. In fact, it might very well be the case that, for the parameter regimes we would want to consider, the behavior of the vector-spin model is quite equilibrium-like. This is probably what we want to aim for anyway, given that oscillations, instabilities, and divergences are always lurking close by in the perilous phase spaces of these systems. If the within-module time evolution converges, we return the magnetizations $\mathbf{m}^{\mathrm{NESS}} \in \mathbb{R}^{N \times D}$ as module outputs. Instead of evolving in time for a number of steps until convergence, we could also hunt for the NESS directly by assuming it exists and solving for it as if it were a fixed point of the time evolution.&lt;/p&gt;
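&lt;p&gt;For the naive first-order update, hunting for the NESS as a fixed point and differentiating through it could look like the sketch below, written in plain JAX. The forward pass iterates until the largest change drops below a tolerance. A &lt;code&gt;jax.custom_vjp&lt;/code&gt; rule then applies the implicit function theorem at the fixed point, following the fixed-point example in the JAX documentation on custom derivative rules, so no gradients flow through the loop itself. Names, tolerances, step caps, and the weak coupling scale in the usage are illustrative assumptions. A TAP variant would only change the map &lt;code&gt;f&lt;/code&gt;.&lt;/p&gt;

```python
from functools import partial

import jax
import jax.numpy as jnp

TOL, MAX_STEPS = 1e-6, 1000  # illustrative stopping criteria


def phi(theta, beta, R):
    """Row-wise magnetization m = beta * theta / (1 + sqrt(1 + beta^2 ||theta||^2 / R^2))."""
    sq_norm = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * sq_norm / R**2))


def _iterate(f, params, z0):
    """Fixed-point iteration z = f(params, z), stopped when the max change is at most TOL."""

    def cond(state):
        step, _, delta = state
        return jnp.logical_and(jnp.less(step, MAX_STEPS), jnp.greater(delta, TOL))

    def body(state):
        step, z, _ = state
        z_new = f(params, z)
        return step + 1, z_new, jnp.max(jnp.abs(z_new - z))

    init = (jnp.asarray(0), z0, jnp.array(jnp.inf, dtype=z0.dtype))
    return jax.lax.while_loop(cond, body, init)[1]


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, params, z0):
    return _iterate(f, params, z0)


def _fixed_point_fwd(f, params, z0):
    z_star = _iterate(f, params, z0)
    return z_star, (params, z_star)


def _fixed_point_bwd(f, residuals, z_bar):
    params, z_star = residuals
    # Adjoint equation at the fixed point: u = z_bar + (df/dz)^T u, also solved by iteration.
    _, vjp_z = jax.vjp(lambda z: f(params, z), z_star)
    u = _iterate(lambda _, v: z_bar + vjp_z(v)[0], None, z_bar)
    # Pull the adjoint back to the parameters: params_bar = (df/dparams)^T u.
    _, vjp_params = jax.vjp(lambda p: f(p, z_star), params)
    (params_bar,) = vjp_params(u)
    return params_bar, jnp.zeros_like(z_star)  # no gradient w.r.t. the initial guess


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def naive_map(beta, R):
    """f((x, J), m) = phi(x + J m): the naive first-order mean-field update."""
    return lambda params, m: phi(params[0] + params[1] @ m, beta, R)


N, D, beta = 16, 8, 1.0
R = jnp.sqrt(D / 2 - 1)
key_x, key_J = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (N, D))
x = R * x / jnp.linalg.norm(x, axis=-1, keepdims=True)
J = 0.5 * jax.random.normal(key_J, (N, N)) / jnp.sqrt(N)  # weak couplings keep the map contractive
f = naive_map(beta, R)


def loss(params):
    m_ness = fixed_point(f, params, jnp.zeros((N, D)))
    return jnp.sum(m_ness**2)


def unrolled_loss(params, num_steps=300):
    """Reference: backpropagate through explicit naive updates instead."""

    def step(m, _):
        return f(params, m), None

    m, _ = jax.lax.scan(step, jnp.zeros((N, D)), None, length=num_steps)
    return jnp.sum(m**2)


grads = jax.grad(loss)((x, J))
reference_grads = jax.grad(unrolled_loss)((x, J))
```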
&lt;p&gt;To wrap up this section, we list a few conceptual similarities and features below to close the gap between vector-spin models and transformer modules:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Attention heads:&lt;/strong&gt; Multiple attention heads can be implemented by embedding $N_{h}$ coupling matrices into a head-block-diagonal coupling tensor. Effectively, this operation stacks $N_{h}$ smaller-dimensional spin models where each submodel processes a disjoint $D_{h}-$dimensional piece of the full $D-$dimensional vector space. Mixing between subspaces can occur because (1) each individual coupling matrix is still constructed from query and key mappings $\mathbb{R}^{D} \to \mathbb{R}^{N_{h} \times D_{h}}$ acting on the full input space, and (2) the dot products in the second-order correction terms Eq. \eqref{eq:secondordercorrections} naturally mix channels.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Causal masks:&lt;/strong&gt; Since we identify the attention matrix with the spin model&amp;rsquo;s couplings, autoregressive modeling can be done by applying the appropriate triangular mask to the coupling matrix instead. The causal structure is preserved during the within-module time evolution. More generally, we expect any kind of masking that can be done on the level of the attention matrix to transfer to the coupling matrix.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Cross-attention:&lt;/strong&gt; The framework described above implements self-attention by constructing both queries and keys from the inputs $\mathbf{x}$ according to Eq. \eqref{eq:softmaxcouplings}. Decoder layers in encoder-decoder models, however, also rely on cross-attention, where keys (and values) are computed from the encoder output and passed to the decoder as context. We can accommodate this scenario by feeding the spin-transformer module an additional set of context vectors $\mathbf{c}$ to build the coupling matrix, i.e.,&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
\begin{equation}
\mathbf{J}(\mathbf{x}, \mathbf{c}) = \mathrm{softmax}\left( \frac{\mathbf{x} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \mathbf{c}^{T}}{\sqrt{D}} \right). \label{eq:crosssoftmaxcouplings}
\end{equation}&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Normalization:&lt;/strong&gt; A flavor of
(RMSNorm) naturally appears in the large-vector-dimension expression Eq. \eqref{eq:largedevmag} for the magnetization, as well as in all the mean-field update equations derived from it.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Queries, keys, and values:&lt;/strong&gt; The &lt;em&gt;queries&lt;/em&gt; and &lt;em&gt;keys&lt;/em&gt; are used to construct the spin-spin interactions from the external magnetic fields via Eq. \eqref{eq:softmaxcouplings}. In a sense, these linear transformations remain quite arbitrary, since our framework is agnostic to the nature of the coupling matrix. The &lt;em&gt;values&lt;/em&gt;, however, do have an interpretation: they are the magnetizations $\mathbf{m}_{t-1}$ at the previous time step or, in case of convergence, the steady-state magnetizations $\mathbf{m}^{\mathrm{NESS}}$.&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
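&lt;p&gt;The attention-head, causal-mask, and cross-attention points can be made concrete with a small sketch of how the coupling matrices might be built and applied. Here the head-block-diagonal coupling tensor is represented as a stack of per-head matrices acting on disjoint channel blocks, and per-head logits are scaled by $\sqrt{D_{h}}$ as in standard multi-head attention. Both choices are our assumptions, as are all function names. For cross-attention we only construct the rectangular matrix $\mathbf{J}(\mathbf{x}, \mathbf{c})$ of Eq. \eqref{eq:crosssoftmaxcouplings}, since how it enters the update equations is not spelled out here.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def softmax_couplings(q, k, causal=False):
    """Row-stochastic couplings softmax(q k^T / sqrt(d)) for q: (..., N, d), k: (..., M, d)."""
    logits = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    if causal:
        n, m = logits.shape[-2:]
        # Site i may only couple to sites j with j at most i (lower-triangular mask).
        mask = jnp.tril(jnp.ones((n, m), dtype=bool))
        logits = jnp.where(mask, logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1)


def multihead_couplings(x, W_Q, W_K, num_heads, causal=False):
    """One coupling matrix per head; queries and keys are computed from the full D-dim input."""
    N = x.shape[0]
    q = (x @ W_Q).reshape(N, num_heads, -1).transpose(1, 0, 2)  # (H, N, D_h)
    k = (x @ W_K).reshape(N, num_heads, -1).transpose(1, 0, 2)  # (H, N, D_h)
    return softmax_couplings(q, k, causal)  # (H, N, N)


def multihead_local_field(J_heads, m):
    """Head-block-diagonal action: head h only mixes the D_h channels of its own block."""
    num_heads = J_heads.shape[0]
    N, D = m.shape
    m_heads = m.reshape(N, num_heads, D // num_heads)  # (N, H, D_h)
    return jnp.einsum("hij,jhd->ihd", J_heads, m_heads).reshape(N, D)


N, M, D, H = 6, 4, 16, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (N, D))
c = jax.random.normal(keys[1], (M, D))  # context vectors, e.g. encoder outputs
m = jax.random.normal(keys[2], (N, D))
W_Q = jax.random.normal(keys[3], (D, D)) / jnp.sqrt(D)
W_K = jax.random.normal(keys[4], (D, D)) / jnp.sqrt(D)

J_causal = softmax_couplings(x @ W_Q, x @ W_K, causal=True)  # masked self-attention couplings
J_cross = softmax_couplings(x @ W_Q, c @ W_K)  # rectangular cross-attention couplings, (N, M)
J_heads = multihead_couplings(x, W_Q, W_K, num_heads=H, causal=True)
field = multihead_local_field(J_heads, m)
```

&lt;p&gt;Because the mask zeroes every coupling to a later site, the local field at site $i$ never sees the magnetizations of later sites. This is why the causal structure survives the within-module time evolution.&lt;/p&gt;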
&lt;h2 id="fast--and-slow-moving-parameters"&gt;Fast- and slow-moving parameters&lt;/h2&gt;
&lt;p&gt;We now provide some additional physical intuition. As mentioned ad nauseam in
and
, each example in a batch of sequential data can be thought of as probing a spin-transformer module in a particular way. The response of the many-body system depends on the context provided by the applied external fields. We can tune the collective response behavior by parametrizing the couplings and making sure the whole probe-response stack is differentiable.&lt;/p&gt;
&lt;img src="spin_model_transformer_module.png" alt="Focus on a spin-transformer module in a stack of layers" width="450px"/&gt;
&lt;p&gt;Physically, the &lt;em&gt;fast-moving&lt;/em&gt; parametrized couplings $\mathbf{J}(\mathbf{x})$ are determined by the &lt;em&gt;fast-moving&lt;/em&gt; parametrized external fields $\mathbf{x}$. In a stack of transformer modules, these fields depend on the magnetizations of the previous layer and ultimately on the input data. The external fields act as an environment of contextual patterns that is transformed instantly into the values of the coupling matrix, effectively inducing a state of quenched disorder. The &lt;em&gt;slow-moving&lt;/em&gt; parameters are those receiving gradient updates during training, e.g., the query-key matrices in the softmax couplings. On the level of a spin-transformer module, training can be understood as &lt;em&gt;shaping the input-dependent distribution of coupling parameters&lt;/em&gt; by amassing information from a huge number of quenched disorder realizations, sculpting a spin glass with data.&lt;/p&gt;
&lt;h2 id="a-simple-jax-implementation-2"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Let us wrap up this post with some code showing how one could implement a spin-transformer module based on the recipe described above. We choose to normalize input vectors to have norm $R$. Because of this choice, we set the softmax temperature in the couplings Eq. \eqref{eq:softmaxcouplings} to $1$ instead of $\sqrt{D}$, so that the scale of the matrix elements is similar to that in scaled dot-product attention. As we have seen in
, lowering the norm of the input vectors decreases the strength of the applied magnetic fields and increases the influence of the spin-spin interactions. Other normalization conventions might turn out to work better in actual training scenarios. Additionally, different flavors of mean-field approximations lead to different update equations for the magnetizations. We want to stress that the approach we took in this post is just one possible option, which might not be the most useful one in practice.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;We use
to implement our neural network modules. We could replace the fixed-step &lt;code&gt;lax.scan&lt;/code&gt; time evolution of
with an &lt;code&gt;equinox.internal.while_loop&lt;/code&gt; to stop early upon convergence in a way that supports reverse-mode automatic differentiation. But then we would have to stop gradients so that only the values of the final iteration, corresponding to the steady-state magnetizations $\mathbf{m}^{\mathrm{NESS}}$, contribute to the gradient computation. To keep the implementation below simple, we instead assume the NESS exists and solve for it as if it were a fixed point of the time evolution. Implicit differentiation of the fixed-point solver then takes care of the (near-)equilibrium gradients. So we only need the following function:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;jaxopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Find fixed-point vector magnetizations of second-order mean-field update equations.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_m_ness&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_m_ness&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We implement a spin-transformer module by wrapping a little boilerplate around the &lt;code&gt;vector_tap_fp&lt;/code&gt; function. We construct the spin-model couplings from the input vectors and mimic multi-head attention by &lt;code&gt;vmap&lt;/code&gt;&amp;lsquo;ing the magnetizations&amp;rsquo; fixed-point solving across &lt;code&gt;num_heads&lt;/code&gt; spin models where each one acts on an equal-size subspace of the full vector dimension.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Fix multi-head case (it&amp;rsquo;s not just &lt;code&gt;vmap&lt;/code&gt;&amp;lsquo;ing the full thing).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;typing&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Callable&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;equinox&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;eqx&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;einops&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Callable&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vector_tap_fp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... h n d -&amp;gt; ... n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... n (h d) -&amp;gt; ... h n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;... i d, ... j d -&amp;gt; ... i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;min&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... n (h d) -&amp;gt; ... h n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;in_axes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;... h n d -&amp;gt; ... n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let&amp;rsquo;s run a forward pass of the spin-transformer module&amp;hellip;&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2666&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mod_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mod_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[[&lt;/span&gt; &lt;span class="mf"&gt;0.46483648&lt;/span&gt; &lt;span class="mf"&gt;0.3805422&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.44913006&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.02650307&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.36570293&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.23443604&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.37061682&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.42315483&lt;/span&gt; &lt;span class="mf"&gt;0.1197958&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.6265602&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.61598897&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.5583689&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.21803643&lt;/span&gt; &lt;span class="mf"&gt;0.17418407&lt;/span&gt; &lt;span class="mf"&gt;0.22512378&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.82831764&lt;/span&gt; &lt;span class="mf"&gt;0.13957487&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.17361565&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.03738704&lt;/span&gt; &lt;span class="mf"&gt;0.10310851&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.12114237&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.17507279&lt;/span&gt; &lt;span class="mf"&gt;0.30361462&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.09653477&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.4211655&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.20545821&lt;/span&gt; &lt;span class="mf"&gt;0.12954816&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.74708706&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.35752055&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5818469&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;1.149747&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.6245326&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.28383803&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.31866318&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.13622926&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.52548647&lt;/span&gt;&lt;span class="p"&gt;]]]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&amp;hellip; and a backward pass.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nd"&gt;@eqx.filter_jit&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;filter_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;weight&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[&lt;/span&gt; &lt;span class="mf"&gt;6.84143470e-06&lt;/span&gt; &lt;span class="mf"&gt;1.26781670e-04&lt;/span&gt; &lt;span class="mf"&gt;3.00350985e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.42774186e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;6.56897682e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.09572255e-04&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;2.77053477e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.62737968e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;9.00395680e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.95370322e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.99462512e-05&lt;/span&gt; &lt;span class="mf"&gt;5.35702784e-05&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.52689070e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.44067290e-05&lt;/span&gt; &lt;span class="mf"&gt;1.77498405e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.35530383e-04&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;7.19401141e-05&lt;/span&gt; &lt;span class="mf"&gt;1.22722937e-04&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.90037055e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.04181963e-04&lt;/span&gt; &lt;span class="mf"&gt;4.73747787e-06&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.87275892e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;5.93782897e-06&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.02471051e-05&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.34355170e-05&lt;/span&gt; &lt;span class="mf"&gt;3.30054972e-05&lt;/span&gt; &lt;span class="mf"&gt;1.77152877e-04&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.20974844e-04&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.17946729e-04&lt;/span&gt; &lt;span class="mf"&gt;4.90189996e-06&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;3.79099110e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.06873820e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.71618904e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;4.89293416e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;8.51267905e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.46996666e-04&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
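&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Going beyond a single spin-transformer module, we can stack modules sequentially to create a spin-transformer model. The parameters of all &lt;code&gt;depth&lt;/code&gt; modules are created at once with &lt;code&gt;eqx.filter_vmap&lt;/code&gt; and then applied one after another with &lt;code&gt;jax.lax.scan&lt;/code&gt;, so the module body is traced and compiled once instead of once per layer.&lt;/p&gt;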
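&lt;p&gt;The stacking pattern is easier to see without Equinox in the way. Below is a minimal sketch in plain JAX, assuming hypothetical toy layers (&lt;code&gt;make_layer&lt;/code&gt; and &lt;code&gt;apply_layer&lt;/code&gt; are illustrative names, not part of the spin-transformer code): &lt;code&gt;jax.vmap&lt;/code&gt; over a batch of PRNG keys stacks the per-layer parameters along a leading depth axis, and &lt;code&gt;jax.lax.scan&lt;/code&gt; applies one layer per step while threading the activations through as the carry. The scanned result should agree with a plain Python loop over the layers.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;
```python
import jax
import jax.numpy as jnp


def make_layer(key, dim=4):
    # Hypothetical toy layer: a single random weight matrix.
    return {"w": jax.random.normal(key, (dim, dim)) / jnp.sqrt(dim)}


def apply_layer(params, x):
    # Toy layer body: linear map followed by tanh.
    return jnp.tanh(x @ params["w"])


def apply_stack_loop(stacked, x):
    # Reference: plain Python loop, slicing layer i out of every stacked leaf.
    depth = stacked["w"].shape[0]
    for i in range(depth):
        layer = jax.tree_util.tree_map(lambda p: p[i], stacked)
        x = apply_layer(layer, x)
    return x


def apply_stack_scan(stacked, x):
    # jax.lax.scan slices every leaf of `stacked` along axis 0 and feeds the
    # activations forward as the carry; the step function is traced only once.
    def step(carry, layer):
        return apply_layer(layer, carry), None

    out, _ = jax.lax.scan(step, x, stacked)
    return out


depth, dim = 6, 4
keys = jax.random.split(jax.random.PRNGKey(0), depth)
# vmap over the keys stacks `depth` sets of parameters along a new leading axis.
stacked = jax.vmap(make_layer)(keys)
x = jnp.ones((3, dim))

y_loop = apply_stack_loop(stacked, x)
y_scan = apply_stack_scan(stacked, x)
print(stacked["w"].shape)  # (6, 4, 4)
print(jnp.allclose(y_loop, y_scan, atol=1e-5))
```
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In &lt;code&gt;SpinTransformer&lt;/code&gt;, &lt;code&gt;eqx.filter_vmap&lt;/code&gt; plays the role of &lt;code&gt;jax.vmap&lt;/code&gt; when constructing the stacked modules, and &lt;code&gt;eqx.partition&lt;/code&gt;/&lt;code&gt;eqx.combine&lt;/code&gt; are needed because each module also holds non-array leaves (such as &lt;code&gt;num_heads&lt;/code&gt; and the &lt;code&gt;vector_tap_fp&lt;/code&gt; partial) that cannot be scanned over; in the toy sketch every leaf is already an array. The model code applies the same pattern to &lt;code&gt;SpinTransformerModule&lt;/code&gt;:&lt;/p&gt;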
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SpinTransformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;modules&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;make_modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;filter_vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;make_modules&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dynamic_modules&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;static_modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;partition&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;modules&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_array&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_dynamic_module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;module&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;combine&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_dynamic_module&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;static_modules&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;module&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dynamic_modules&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;SpinTransformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mod_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[[&lt;/span&gt; &lt;span class="mf"&gt;0.20396525&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.06002701&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.24426042&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.25347382&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.01503923&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.15146086&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3552067&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.4154298&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.2159235&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.68296695&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.18692644&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.20893992&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.03525298&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11836862&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.13671912&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.22646151&lt;/span&gt; &lt;span class="mf"&gt;0.18905625&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.05829766&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11216182&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.26305646&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.31211302&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.27817503&lt;/span&gt; &lt;span class="mf"&gt;0.25123474&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11120855&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.17170963&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.33360714&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.12762357&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.70538384&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.04229175&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5447842&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.5191558&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5662918&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.33646253&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.4568781&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.04439414&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.18843232&lt;/span&gt;&lt;span class="p"&gt;]]]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we have shown how
can be generalized to capture asymmetric coupling matrices like softmax attention. We observed that dynamical mean-field descriptions of vector-spin models exhibit structure capable of yielding residual connections, attention terms, and feed-forward-like correction terms, motivating a physics-inspired class of spin-transformer modules. By blending ideas from deep learning and statistical mechanics, we hope our work helps build interdisciplinary bridges that improve our understanding of learning and generalization in transformer neural networks.&lt;/p&gt;
&lt;p&gt;From a theoretical point of view, it would be interesting to further explore and develop connections to the physics of vector spin glasses and properly study transformers as statistical-mechanical systems. Computationally, we look forward to experiments at scale to get more insight into potential benefits and bottlenecks of spin-transformer models in terms of
, representational power, and scaling behavior. In any case, it is fun to think about transformers as a collective of driven, disordered vector-spin models whose response behavior can be shaped by learning parameterized interactions, gradually steering a cascade of near-equilibrium steady-state magnetizations towards solving a given objective.&lt;/p&gt;
&lt;h1 id="references"&gt;References&lt;/h1&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;F. Nicoletti, Low energy excitations of vector spin glasses, PhD thesis (2023)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;M. Aguilera, S.A. Moosavi, and H. Shimazaki, A unifying framework for mean-field theories of asymmetric kinetic Ising systems, &lt;em&gt;Nat. Commun.&lt;/em&gt; &lt;strong&gt;12&lt;/strong&gt;, 1197 (2021)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Y. Roudi and J. Hertz, Dynamical TAP equations for non-equilibrium Ising spin glasses, &lt;em&gt;J. Stat. Mech.&lt;/em&gt;, P03031 (2011)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;H.J. Kappen and J.J. Spanjers, Mean field theory for asymmetric neural networks, &lt;em&gt;Phys. Rev. E&lt;/em&gt; &lt;strong&gt;61&lt;/strong&gt;, 5658 (2000)&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;G. Parisi, Asymmetric neural networks and the process of learning, &lt;em&gt;J. Phys. A: Math. Gen.&lt;/em&gt; &lt;strong&gt;19&lt;/strong&gt;, L675 (1986)&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2023spinmodeltransformers,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Spin-Model Transformers},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2023},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {December},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/spin-model-transformers}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h1 id="appendices"&gt;Appendices&lt;/h1&gt;
&lt;h2 id="a1-vector-spin-distribution-normalization-constant"&gt;A.1. Vector-spin distribution: normalization constant&lt;/h2&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Let $Z(\beta, R, \mathbf{h})=\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}$. We switch to $D$-dimensional spherical coordinates to make our life easier and use rotational symmetry to choose the polar axis parallel to $\mathbf{h}$,&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = R^{D-1} \int_{\Omega} \int_{0}^{\pi} \mathrm{d}^{D-2} \Omega \;\mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta ,
\end{equation}&lt;p&gt;where $h=\lVert\mathbf{h}\rVert$ and where $\int_{\Omega} \mathrm{d}^{D-2} \Omega$ represents the integral over all other spherical angles, which coincides with the surface area of the unit sphere in $D-1$ dimensions,&lt;/p&gt;
\begin{equation}
S_{D-1} = \frac{2\pi^{\frac{D-1}{2}}}{\Gamma\left( \frac{D-1}{2} \right)},
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{0}^{\pi} \mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta .
\end{equation}&lt;p&gt;If we now let $u = \cos \theta$, then&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{-1}^{1} \mathrm{d}u \; \mathrm{e}^{\beta R h u } \left(1 - u^2\right)^{(D-3)/2} .
\end{equation}&lt;p&gt;Recognizing the integral representation of the modified Bessel function of the first kind,&lt;/p&gt;
\begin{equation}
I_{\nu}(z) = \frac{2^{-\nu}}{\sqrt{\pi}\, \Gamma\left(\nu+\frac{1}{2}\right)} z^{\nu} \int_{-1}^{1} \mathrm{d}t \; \mathrm{e}^{\pm zt} \left(1-t^2\right)^{\nu-\frac{1}{2}},
\end{equation}&lt;p&gt;we identify $\nu = D/2 - 1$ and $z = \beta R h$ to find&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R h) }{ \left(\beta h\right)^{D/2-1} }.
\end{equation}&lt;h2 id="a2-vector-spin-distribution-expected-value-first-moment"&gt;A.2. Vector-spin distribution: expected value (first moment)&lt;/h2&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Starting from the expression of the normalization constant Eq. \eqref{eq:partfun},&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) ,
\end{equation}&lt;p&gt;we write the expected value as&lt;/p&gt;
\begin{equation}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathbf{s} \, \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert\mathbf{h} \rVert) }{ \left(\beta \lVert\mathbf{h}\rVert \right)^{D/2-1} } \right)
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \left( \frac{I'_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} - \frac{ D/2-1 }{ \beta R \lVert\mathbf{h}\rVert} \right) \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}.
\end{align}&lt;p&gt;Using the recurrence relations of the modified Bessel functions of the first kind,&lt;/p&gt;
\begin{align}
I_{\nu-1}(z) - I_{\nu+1}(z) &amp;= \frac{2\nu}{z} I_{\nu}(z), \label{eq:irecurr}\\\\
I_{\nu-1}(z) + I_{\nu+1}(z) &amp;= 2 I'_{\nu}(z), \label{eq:irecurrderiv}
\end{align}&lt;p&gt;we end up with&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{I_{D/2}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}\equiv \boldsymbol{\varphi} (\mathbf{h}). \label{eq:app:expectedvalue}
\end{align}&lt;h2 id="a3-vector-spin-distribution-variance-second-moment"&gt;A.3. Vector-spin distribution: variance (second moment)&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Add variance for general case.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Using the expression of the normalization constant Eq. \eqref{eq:partfun},&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) ,
\end{equation}&lt;p&gt;we write the symmetric covariance matrix as&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} \, ( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])^{T} \\\\
&amp;= \frac{1}{\beta^2 Z} \frac{ \partial^2 }{ \partial \mathbf{h} \partial \mathbf{h}^{T} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T},
\end{align}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( Z \mathbb{E}_{p} [ \mathbf{s} ]^{T} \right) - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T}, \\\\
&amp;= \frac{1}{\beta} \frac{ \partial }{ \partial \mathbf{h} } \mathbb{E}_{p} [ \mathbf{s} ]^{T},
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \ldots \label{eq:app:var}
\end{align}&lt;p&gt;for the general case with the expected value given by Eq. \eqref{eq:app:expectedvalue} and to&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\beta^2\mathbf{h} \otimes \mathbf{h}}{R^2\gamma(\mathbf{h})\left(1+\gamma(\mathbf{h})\right)^2}\\\\
&amp;= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\boldsymbol{\varphi} (\mathbf{h}) \otimes \boldsymbol{\varphi}(\mathbf{h})}{R^2\gamma(\mathbf{h})}
\end{align}&lt;p&gt;for the large-$D$ limit with the expected value given by Eq. \eqref{eq:largedevmag}, where&lt;/p&gt;
\begin{align}
\gamma(\mathbf{h}) = \sqrt{1+\beta^{2}\lVert\mathbf{h}\rVert^{2}/R^2}
\end{align}&lt;h2 id="a4-ratio-of-modified-bessel-functions-of-the-first-kind"&gt;A.4. Ratio of modified Bessel functions of the first kind&lt;/h2&gt;
&lt;p&gt;To compute the ratio $I_{\nu+1}(x) / I_{\nu}(x)$ of modified Bessel functions of the first kind for $\nu \geq 0$ and $x \geq 0$, we implement a
of the algorithm described in
. A pseudocode implementation can be found in
. We compare our implementation against explicitly calculating the ratio using
across a range of orders $\nu$ for several different values of $x$ to get a feel for its behavior.&lt;/p&gt;
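&lt;p&gt;The following is a hedged, self-contained illustration and not necessarily the algorithm referenced above. The ratio $r_\nu(x) = I_{\nu+1}(x)/I_\nu(x)$ can also be evaluated with Gauss's continued fraction. Shifting $\nu \to \nu+1$ in the recurrence Eq. \eqref{eq:irecurr} and dividing by $I_{\nu+1}(x)$ gives $r_\nu(x) = 1/\left(2(\nu+1)/x + r_{\nu+1}(x)\right)$. This unrolls into $r_\nu = 1/(b_1 + 1/(b_2 + \cdots))$ with $b_k = 2(\nu+k)/x$. The pure-Python sketch below evaluates this fraction with the modified Lentz algorithm. As a stand-in for the explicit reference computation, it compares the result against a ratio of ascending power series. All function names are hypothetical.&lt;/p&gt;

```python
import math

# Hypothetical helper names; a sketch, not the implementation used for the plot above.


def bessel_ratio(nu: float, x: float, tol: float = 1e-14, max_iter: int = 1_000_000) -> float:
    """Sketch: r_nu(x) = I_{nu+1}(x) / I_nu(x) for nu >= 0 and x >= 0.

    Gauss's continued fraction r_nu = 1 / (b_1 + 1 / (b_2 + ...)) with
    b_k = 2 (nu + k) / x, evaluated by the modified Lentz algorithm.
    """
    if x == 0.0:
        return 0.0  # limit of the ratio as x -> 0, since r ~ x / (2 (nu + 1))
    f = 2.0 * (nu + 1.0) / x  # leading partial denominator b_1
    c, d = f, 0.0
    for k in range(2, max_iter):
        b = 2.0 * (nu + k) / x  # every b_k is positive, so no zero divisors appear
        d = 1.0 / (b + d)
        c = b + 1.0 / c
        delta = c * d
        f *= delta
        if tol > abs(delta - 1.0):  # successive convergents agree to within tol
            break
    return 1.0 / f


def log_bessel_i_series(nu: float, x: float) -> float:
    """log I_nu(x) from the series sum_k (x/2)^(2k+nu) / (k! Gamma(k+nu+1)), for x > 0.

    Only meant as a reference for moderate x (all terms are positive, so no cancellation).
    """
    log_half_x = math.log(x / 2.0)
    logs, peak, k = [], -math.inf, 0
    while True:
        t = (2 * k + nu) * log_half_x - math.lgamma(k + 1) - math.lgamma(k + nu + 1)
        logs.append(t)
        peak = max(peak, t)
        # Terms peak near k ~ x/2; stop once past x and negligible relative to the peak.
        if k > x and peak - t > 40.0:
            break
        k += 1
    return peak + math.log(sum(math.exp(t - peak) for t in logs))


def bessel_ratio_series(nu: float, x: float) -> float:
    """Explicit ratio exp(log I_{nu+1}(x) - log I_nu(x)), used only for comparison."""
    return math.exp(log_bessel_i_series(nu + 1, x) - log_bessel_i_series(nu, x))


if __name__ == "__main__":
    # Sweep the three regimes discussed below: x = sqrt(nu), x = nu, x = nu^2.
    for nu in (1.0, 16.0, 100.0):
        for label, x in (("sqrt(nu)", math.sqrt(nu)), ("nu", nu), ("nu^2", nu**2)):
            print(f"nu={nu:6.1f}  x={label:8s}  ratio={bessel_ratio(nu, x):.6f}")
```

&lt;p&gt;For $\nu = 1/2$ (that is, $D = 3$), the ratio reduces to the Langevin function $\coth x - 1/x$, which provides a cheap exact check. Because all $b_k$ are positive, the Lentz iteration never divides by zero. It does need more iterations as $x$ grows relative to $\nu$.&lt;/p&gt;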
&lt;img src="bessel_plot_1.png" width="500px"/&gt;
&lt;p&gt;The two approaches agree well. For $x=\sqrt{\nu}$, the ratio takes on very small values for large orders. For $x=\nu^2$, the opposite happens and the ratio saturates towards one. The case $x=\nu$ seems to sit in between, which suggests it might be worth fixing the radius of our little spins to $R=\sqrt{D}$. With $\lVert\mathbf{h}\rVert \sim \mathcal{O}(\sqrt{D})$ and $\beta \sim \mathcal{O}(1)$, the argument $\beta R \lVert\mathbf{h}\rVert$ is then of order $D$, comparable to the order $\nu = D/2 - 1$, which might maximize the &amp;ldquo;sensitivity&amp;rdquo; of the expected value. In this regime, we can get away with
for large $\nu$ given that the ratio flattens out quickly.&lt;/p&gt;
&lt;h2 id="a5-general-case-partial-derivatives-with-respect-to"&gt;A.5. General case: partial derivatives with respect to $\alpha$&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Clean up and verify (haha, no).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We are interested in computing the first-order and second-order derivatives with respect to $\alpha$ of the function&lt;/p&gt;
\begin{equation}
\boldsymbol{\varphi}(\mathbf{h}(\alpha)) = \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \mathbf{h}(\alpha)}{\lVert \mathbf{h}(\alpha) \rVert},
\end{equation}&lt;p&gt;where $\mathbf{h}(\alpha) = \boldsymbol{\theta} + \alpha \Delta \mathbf{h}$. Using&lt;/p&gt;
\begin{equation}
\frac{\partial \lVert \mathbf{h}(\alpha) \rVert}{\partial\alpha} = \frac{\mathbf{h}(\alpha) \cdot \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert}
\end{equation}&lt;p&gt;and Eqs. \eqref{eq:irecurr}-\eqref{eq:irecurrderiv}, we find&lt;/p&gt;
\begin{align}
\frac{\partial \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} = \beta &amp;\lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \label{eq:generalgradalphafirstorder}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{equation}
\lambda_{D} (x) = \frac{I^2_{D/2-1}(x)}{I^2_{D/2}(x)} - \frac{D}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} - 1. \label{eq:app:lambda}
\end{equation}&lt;p&gt;For the second-order derivative, we need to slog through even more tedious algebra,&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha^2}
= \beta &amp;\frac{\partial}{\partial\alpha}\biggl( \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \biggr) \nonumber \\\\
&amp;+ \frac{\partial}{\partial\alpha}\biggl( \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \biggr) ,
\end{align}&lt;p&gt;which eventually leads to something like&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha^2}
= -2\beta^2 &amp; \, \kappa_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right)^{2} \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \nonumber \\\\
&amp;- \frac{D}{\lVert \mathbf{h}(\alpha) \rVert^2} \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \Delta \mathbf{h} , \label{eq:generalgradalphasecondorder}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\kappa_{D} (x) = \lambda^2_{D} (x) + \left( 1 + \frac{D/2 + 1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} \right) \lambda_{D} (x) + \frac{1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)}.
\end{align}&lt;p&gt;Equation \eqref{eq:generalgradalphasecondorder} can be simplified further by substituting the first-order derivative Eq. \eqref{eq:generalgradalphafirstorder} and collecting terms. The derivation of the mean-field equations then proceeds as in the main text, but uses Eqs. \eqref{eq:generalgradalphafirstorder} and \eqref{eq:generalgradalphasecondorder} as expressions for the partial derivatives instead of their large-$D$ approximations.&lt;/p&gt;
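&lt;p&gt;Given the caveat above, one cheap sanity check is to compare the first-order expression Eq. \eqref{eq:generalgradalphafirstorder} against a central finite difference of $\boldsymbol{\varphi}(\mathbf{h}(\alpha))$, with $\lambda_D$ taken from Eq. \eqref{eq:app:lambda}. The pure-Python sketch below does this for a few values of $D$. Several choices are assumptions made for illustration: the helper names, the Bessel-ratio routine (Gauss's continued fraction with the modified Lentz algorithm), the choice $R=\sqrt{D}$, and the random test vectors.&lt;/p&gt;

```python
import math
import random

# Hypothetical helpers for a finite-difference check; a sketch, not the post's implementation.


def bessel_ratio(nu, x, tol=1e-14, max_iter=1_000_000):
    """I_{nu+1}(x) / I_nu(x) via Gauss's continued fraction (modified Lentz), nu >= 0, x > 0."""
    f = 2.0 * (nu + 1.0) / x
    c, d = f, 0.0
    for k in range(2, max_iter):
        b = 2.0 * (nu + k) / x
        d = 1.0 / (b + d)
        c = b + 1.0 / c
        delta = c * d
        f *= delta
        if tol > abs(delta - 1.0):
            break
    return 1.0 / f


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def phi(h, beta, R):
    """Expected spin r(x) R h / |h| with x = beta R |h| and r = I_{D/2} / I_{D/2-1}."""
    D, n = len(h), math.sqrt(dot(h, h))
    r = bessel_ratio(D / 2 - 1, beta * R * n)
    return [r * R * h_i / n for h_i in h]


def dphi_dalpha(theta, dh, alpha, beta, R):
    """Closed-form first derivative of phi(h(alpha)) with h(alpha) = theta + alpha * dh."""
    h = [t + alpha * d for t, d in zip(theta, dh)]
    D, n = len(h), math.sqrt(dot(h, h))
    x = beta * R * n
    r = bessel_ratio(D / 2 - 1, x)
    lam = 1.0 / r**2 - D / (x * r) - 1.0  # lambda_D(x)
    phi_h = phi(h, beta, R)
    proj = dot(phi_h, dh)  # phi(h(alpha)) . dh
    return [beta * lam * proj * p_i + r * R * d_i / n for p_i, d_i in zip(phi_h, dh)]


def dphi_dalpha_fd(theta, dh, alpha, beta, R, eps=1e-5):
    """Central finite-difference estimate of the same derivative."""
    plus = phi([t + (alpha + eps) * d for t, d in zip(theta, dh)], beta, R)
    minus = phi([t + (alpha - eps) * d for t, d in zip(theta, dh)], beta, R)
    return [(a - b) / (2.0 * eps) for a, b in zip(plus, minus)]


if __name__ == "__main__":
    random.seed(0)
    for D in (3, 8, 64):
        beta, R = 0.7, math.sqrt(D)  # R = sqrt(D), as suggested in A.4
        theta = [random.gauss(0.0, 1.0) for _ in range(D)]
        dh = [random.gauss(0.0, 1.0) for _ in range(D)]
        exact = dphi_dalpha(theta, dh, 0.3, beta, R)
        approx = dphi_dalpha_fd(theta, dh, 0.3, beta, R)
        print(D, max(abs(a - b) for a, b in zip(exact, approx)))
```

&lt;p&gt;The printed discrepancies should be at the level of finite-difference error. The same harness could be applied to Eq. \eqref{eq:generalgradalphasecondorder} by differencing the first-order expression once more.&lt;/p&gt;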
&lt;p&gt;Another useful derivative is that of the single-site probability distribution \eqref{eq:pcondsinglesitevector},&lt;/p&gt;
\begin{align}
\frac{\partial}{\partial\alpha} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) = \frac{\partial}{\partial\mathbf{h}(\alpha)} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) \cdot \Delta \mathbf{h},
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\beta \left( \mathbf{s} - \boldsymbol{\varphi}\left(\mathbf{h}(\alpha)\right) \right) \cdot \Delta \mathbf{h} \frac{ \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} }{ \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} }
\end{align}&lt;p&gt;and can be used to calculate derivatives of the conditional distribution \eqref{eq:pcondaltvector}.&lt;/p&gt;
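&lt;p&gt;This identity can be checked numerically with a minimal sketch restricted to $D=3$, an assumption made so that everything has a closed form. Using $I_{1/2}(x) = \sqrt{2/(\pi x)}\,\sinh x$, Eq. \eqref{eq:partfun} reduces to $Z = 4\pi R \sinh(\beta R h)/(\beta h)$, and the Bessel ratio in Eq. \eqref{eq:app:expectedvalue} reduces to the Langevin function $\coth x - 1/x$. The sketch checks two things. First, that the density Eq. \eqref{eq:pcondsinglesitevector} integrates to one over the sphere. Second, that a finite-difference derivative in $\alpha$ matches $\beta(\mathbf{s}-\boldsymbol{\varphi})\cdot\Delta\mathbf{h}\,p$. Function names are hypothetical.&lt;/p&gt;

```python
import math
import random

# Hypothetical helpers, restricted to D = 3 so that every quantity has a closed form.


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def norm(v):
    return math.sqrt(dot(v, v))


def log_Z_3d(beta, R, h_norm):
    """log Z for D = 3: (2 pi R)^(3/2) I_{1/2}(x) / (beta h)^(1/2) = 4 pi R sinh(x) / (beta h)."""
    x = beta * R * h_norm
    return math.log(4.0 * math.pi * R * math.sinh(x) / (beta * h_norm))


def p_3d(s, h, beta, R):
    """Single-site density exp(beta s . h) / Z for a spin s on the sphere of radius R."""
    return math.exp(beta * dot(s, h) - log_Z_3d(beta, R, norm(h)))


def phi_3d(h, beta, R):
    """Expected spin for D = 3, where I_{3/2}(x) / I_{1/2}(x) = coth(x) - 1/x (Langevin function)."""
    n = norm(h)
    x = beta * R * n
    ratio = 1.0 / math.tanh(x) - 1.0 / x
    return [ratio * R * h_i / n for h_i in h]


def total_mass_3d(h, beta, R, n_steps=2000):
    """Integral of p over the sphere, polar axis along h, composite Simpson rule in u = cos(theta)."""
    a = beta * R * norm(h)
    step = 2.0 / n_steps
    acc = 0.0
    for i in range(n_steps + 1):
        weight = 1.0 if i in (0, n_steps) else (4.0 if i % 2 else 2.0)
        acc += weight * math.exp(a * (-1.0 + i * step))
    integral = 2.0 * math.pi * R**2 * acc * step / 3.0  # R^(D-1) S_(D-1) = 2 pi R^2 for D = 3
    return integral / math.exp(log_Z_3d(beta, R, norm(h)))


def dp_dalpha(s, theta, dh, alpha, beta, R):
    """Closed form: beta (s - phi(h(alpha))) . dh * p(s; h(alpha))."""
    h = [t + alpha * d for t, d in zip(theta, dh)]
    diff = [s_i - f_i for s_i, f_i in zip(s, phi_3d(h, beta, R))]
    return beta * dot(diff, dh) * p_3d(s, h, beta, R)


def dp_dalpha_fd(s, theta, dh, alpha, beta, R, eps=1e-5):
    """Central finite difference of p(s; theta + alpha * dh) in alpha."""
    h_plus = [t + (alpha + eps) * d for t, d in zip(theta, dh)]
    h_minus = [t + (alpha - eps) * d for t, d in zip(theta, dh)]
    return (p_3d(s, h_plus, beta, R) - p_3d(s, h_minus, beta, R)) / (2.0 * eps)


if __name__ == "__main__":
    random.seed(0)
    beta, R = 0.8, 1.5
    theta = [random.gauss(0.0, 1.0) for _ in range(3)]
    dh = [random.gauss(0.0, 1.0) for _ in range(3)]
    u = [random.gauss(0.0, 1.0) for _ in range(3)]
    s = [R * u_i / norm(u) for u_i in u]  # a point on the sphere of radius R
    print(total_mass_3d(theta, beta, R))  # close to 1
    print(dp_dalpha(s, theta, dh, 0.2, beta, R), dp_dalpha_fd(s, theta, dh, 0.2, beta, R))
```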
&lt;h1 id="footnotes"&gt;Footnotes&lt;/h1&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;We plot the absolute value to get rid of artificial &amp;ldquo;jumps&amp;rdquo; between the two branches. These occur because all models are simulated independently when sweeping across $\beta$, and some combinations of initial state and model parameters might just happen to bounce to the other branch when $\beta$ changes in the $\beta &gt; \beta_c$ regime.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>