<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Transformers | mcbal</title><link>https://mcbal.github.io/tags/transformers/</link><atom:link href="https://mcbal.github.io/tags/transformers/index.xml" rel="self" type="application/rss+xml"/><description>Transformers</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-gb</language><lastBuildDate>Mon, 02 Feb 2026 09:28:17 +0100</lastBuildDate><image><url>https://mcbal.github.io/media/icon.svg</url><title>Transformers</title><link>https://mcbal.github.io/tags/transformers/</link></image><item><title>Entropy Production in Non-Equilibrium Neural Networks</title><link>https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/</link><pubDate>Mon, 02 Feb 2026 09:28:17 +0100</pubDate><guid>https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/</guid><description>&lt;p&gt;&lt;a title="Walter Baxter / A murmuration of starlings at Gretna" href="https://commons.wikimedia.org/wiki/File:Starling_murmuration.jpg"&gt;&lt;img width="512" alt="A murmuration of starlings at Gretna" src="https://upload.wikimedia.org/wikipedia/commons/8/8d/Starling_murmuration.jpg?20150218191823"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;&lt;p align="center"&gt;This project is a work in progress (open research)&lt;/p&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 id="introduction"&gt;Introduction&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Modern large-scale autoregressive language models are impressive systems engineering artifacts. Yet they are frozen, with no apparent notion of dynamics unfolding over time. Surfacing in-context learning at inference time through prompt and environment engineering mitigates the fact that these models are temporal only insofar as information inside their context windows matches patterns observed during consecutive offline training stages. Time, and its dynamic memory affordances, is in a sense amortized or compressed away, incentivizing models to over-rely on storing relevant patterns in parametric memory instead of sculpting latent low-dimensional shapes supporting stable dynamic computation. This has implications for online continual learning, adaptive model deployment, and real-time closed-loop interaction with live systems.&lt;/p&gt;
&lt;p&gt;In this post, we take the notion of treating neural networks as non-equilibrium thermodynamic systems seriously. We design a physics-inspired transformer module with adaptable couplings and memory parameters based on the naive mean-field dynamics of a class of vector-spin models introduced in
. The underlying mean-field spin-model interpretation enables us to write down an expression for
, a thermodynamic quantity measuring &amp;ldquo;instantaneous&amp;rdquo; irreversibility by quantifying the asymmetry between forward and backward time steps.&lt;/p&gt;
&lt;p&gt;Since every operation in our spin-transformer module is differentiable, entropy production can be made into a loss function. For example, maximizing entropy production incentivizes the system to &lt;em&gt;lean into the external drive&lt;/em&gt; by nudging its parameters to dump entropy as fast as possible in a way that maximizes uncertainty given constraints. Internally, we imagine the system reshaping itself into ordered structures to enable more efficient dissipation of the internal tension caused by the incoming data stream.&lt;/p&gt;
&lt;h2 id="background-and-intuitions"&gt;Background and intuitions&lt;/h2&gt;
&lt;p&gt;We
consider transformer modules as differentiable driven disordered vector-spin systems whose mean-field collective behavior we can control through training, and refer to
going back to
for earlier instantiations of this intuition. According to our correspondence, the forward pass of a transformer module implements a spin system&amp;rsquo;s response to getting probed, where &lt;em&gt;inputs&lt;/em&gt; map to time-varying applied external fields, &lt;em&gt;asymmetric, sparse attention matrices&lt;/em&gt; can be identified with fully-connected spin-spin interactions, and &lt;em&gt;outputs&lt;/em&gt; map to spin expectation values or magnetizations. Practically, the forward pass of a spin-transformer module can be designed to mimic that of a vanilla transformer module.&lt;/p&gt;
&lt;p&gt;In contrast to physics-oriented literature, we do not specify explicit probability distributions for the external fields and couplings of the disordered many-body system, nor are we interested in Nobel-prize-winning ways to average out the disorder. We instead focus on the very specific quenched disorder realizations induced by a dataset or environment of interest (encoded as sequences of vector embeddings), whose examples we use to drive the system. In this framing, training a transformer module corresponds to sculpting the underlying system&amp;rsquo;s collective response by tuning the parametrized distributions of its external fields and couplings.&lt;/p&gt;
&lt;img src="spin_transformer_module_fwd_bwd.png" alt="Forward and backward pass illustration" width="500px"/&gt;
&lt;p&gt;In
, we observed that these systems tend to settle into non-equilibrium steady states as dynamic sweet spots where the &amp;ldquo;continuous kicking&amp;rdquo; of the inputs (applied external fields) &amp;ldquo;sustains&amp;rdquo; the outputs (magnetizations). This negotiation process tends to happen after just a few iterations. The first iteration already gives a decent guess, which might explain (1) why transformers can get away with just stacking modules whose forward passes take just one time step, and (2) why doing a few time steps can improve performance, as done in recursive reasoning approaches. Indeed, repeating the same module can be seen as allowing the underlying non-equilibrium system to settle more snugly into its steady state for that particular input/parameter configuration. However, as soon as the input sequence changes or the parameters are updated, the system has to renegotiate a different steady state compatible with what its current configuration dictates the response should be.&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="non-equilibrium-neural-networks"&gt;Non-equilibrium neural networks&lt;/h2&gt;
&lt;h3 id="example-model"&gt;Example model&lt;/h3&gt;
&lt;p&gt;When designing neural networks around mean-field vector-spin models, there is a lot of freedom. First of all, we must decide on what mean-field approximation to use for our spin system. Projecting the dynamics onto different ansatz distributions leads to different mean-field equations, which take into account more or fewer correlations at different time steps. In this post, we choose the simplest option: a first-order &lt;code&gt;Plefka[t-1,t]&lt;/code&gt; approximation. From
, we remember&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \rVert^2 / R^2 }},
\end{equation}&lt;p&gt;where $\mathbf{m}_{i,t} \in \mathbb{R}^{D}$ denote the magnetizations (outputs) at time $t$, $\mathbf{x}_{i,t} \in \mathbb{R}^{D}$ denote the applied external fields (inputs) at time $t$, $J_{ij}$ are the couplings, $\beta$ is an inverse temperature, and $R=\sqrt{D/2 -1}$ is a natural length scale resulting from the large-$D$ approximation we used to avoid dealing with Bessel functions.&lt;/p&gt;
&lt;p&gt;If we now consider &lt;em&gt;parametrized input-dependent couplings&lt;/em&gt;&lt;/p&gt;
\begin{equation}
\mathbf{J} (\mathbf{x}) = \mathrm{softmax}\left( \mathbf{x} \boldsymbol{Q} \boldsymbol{K}^{T} \mathbf{x}^{T} \right), \label{eq:softmax}
\end{equation}&lt;p&gt;and augment the applied external fields with a &lt;em&gt;parametrized input-dependent memory&lt;/em&gt;,&lt;/p&gt;
\begin{equation}
\mathbf{x}_{i,t} \to \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right),
\end{equation}&lt;p&gt;then our forward pass looks like&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right) + \sum_{j} J_{ij} (\mathbf{x}_{t}) \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \mathrm{FFN}\left( \mathbf{x}_{i,t} \right) + \sum_{j} J_{ij} (\mathbf{x}_{t}) \mathbf{m}_{j,t-1} \rVert^2 / R^2 }},
\end{equation}&lt;p&gt;which resembles a parallel transformer block as introduced in GPT-J and used in PaLM, with the notable difference that the &amp;ldquo;values&amp;rdquo; here correspond to the outputs (magnetizations) of the previous time step instead of some linear transformation applied to the inputs at the current time step. Making the applied external fields as well as the couplings input-dependent leads to a &lt;em&gt;highly adaptive system&lt;/em&gt; where the interaction landscape itself is dynamically shaped by the inputs.&lt;/p&gt;
&lt;p&gt;We can choose to have our module keep track of the previous state so that one forward pass corresponds to taking a single time step. If we care more about the steady state, we can also immediately compute the fixed point of the time evolution using a differentiable fixed-point solver, in which case one forward pass corresponds to jumping to the time-evolution fixed point. The latter approach is reminiscent of deep equilibrium models and certain recursive reasoning approaches.&lt;/p&gt;
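&lt;p&gt;As a rough illustration of the two modes described above, here is a minimal NumPy sketch of a single time step and of a plain fixed-point iteration. Several things are assumptions made only for this example and are not taken from the post: the FFN is a hypothetical two-layer ReLU network with weights &lt;code&gt;W1&lt;/code&gt; and &lt;code&gt;W2&lt;/code&gt;, all parameters are random, and the simple repeat-until-converged loop stands in for a differentiable fixed-point solver.&lt;/p&gt;

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def spin_step(x, m_prev, Q, K, W1, W2, beta=1.0):
    """One time step: map inputs x and previous magnetizations to new ones, shape (N, D)."""
    D = x.shape[-1]
    R = np.sqrt(D / 2 - 1)                        # natural length scale, needs D above 2
    J = softmax(x @ Q @ K.T @ x.T)                # input-dependent couplings
    memory = np.maximum(x @ W1, 0.0) @ W2         # assumed FFN: two layers with a ReLU
    theta = x + memory + J @ m_prev               # effective field on every spin
    norm_sq = np.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + np.sqrt(1.0 + beta**2 * norm_sq / R**2))


def fixed_point(x, params, beta=1.0, tol=1e-8, max_iter=500):
    """Repeat the time step until the magnetizations stop changing (plain iteration)."""
    m = np.zeros_like(x)
    for n_steps in range(1, max_iter + 1):
        m_next = spin_step(x, m, *params, beta=beta)
        if np.allclose(m_next, m, rtol=0.0, atol=tol):
            break
        m = m_next
    return m_next, n_steps


rng = np.random.default_rng(0)
N, D, H = 6, 16, 32                               # illustrative sizes only
x = rng.normal(size=(N, D))
params = (0.1 * rng.normal(size=(D, D)), 0.1 * rng.normal(size=(D, D)),
          0.1 * rng.normal(size=(D, H)), 0.1 * rng.normal(size=(H, D)))

m_one = spin_step(x, np.zeros_like(x), *params)   # single time step from a zero state
m_star, n_steps = fixed_point(x, params)          # iterate to the steady state
print(n_steps, np.abs(m_star - m_one).max())
```

&lt;p&gt;With row-stochastic couplings and the saturating nonlinearity, the plain iteration should be well behaved at $\beta = 1$ in this sketch; larger values of $\beta$ may call for damping or a proper solver.&lt;/p&gt;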
&lt;h3 id="entropy-production"&gt;Entropy production&lt;/h3&gt;
&lt;p&gt;Following
, the entropy production for the kinetic Ising model, assuming a non-equilibrium steady state, is given by&lt;/p&gt;
\begin{equation}
\sigma_{t} = \sum_{ij} \left(J_{ij} - J_{ji}\right) D_{ij,t} \geq 0,
\end{equation}&lt;p&gt;where $J_{ij}$ corresponds to the couplings and $D_{ij,t}$ denotes the time-delayed correlations. If we write this down for the vector-spin case,&lt;/p&gt;
\begin{equation}
D_{ij,t} = \int \mathrm{d} \mathbf{s}_{t} \int \mathrm{d} \mathbf{s}_{t-1} \; \left( \mathbf{s}_{i,t} - \mathbf{m}_{i,t} \right) \cdot \left( \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1}\right) \; P( \mathbf{s}_{t}, \mathbf{s}_{t-1} ),
\end{equation}&lt;p&gt;we can compute a first-order &lt;code&gt;Plefka[t-1,t]&lt;/code&gt; mean-field approximation for the time-delayed correlations, similar to the computations we did previously for the magnetizations in
, leading to something like&lt;/p&gt;
\begin{align}
D_{ij,t} = &amp;\frac{\beta J_{ij}}{1+\gamma_{i,t}} \left(R^2 - \mathbf{m}_{j,t-1}^2 \right) \nonumber\\\\
&amp;- \frac{\beta J_{ij}}{R^2 \gamma_{i,t} \left( 1 + \gamma_{j,t-1} \right)} \mathbf{m}_{i,t}^2 \nonumber\\\\
&amp;+ \frac{\beta J_{ij}}{R^4 \gamma_{i,t} \gamma_{j,t-1}} \left( \mathbf{m}_{i,t} \cdot \mathbf{m}_{j,t-1} \right)^2,
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma_{i,t} &amp;= \sqrt{1 + \beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 } \\\\
\boldsymbol{\theta}_{i,t} &amp;= \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}
\end{align}&lt;h3 id="vibe-check"&gt;Vibe check&lt;/h3&gt;
&lt;p&gt;Let us try to get a feel for what the entropy production looks like for vector-spin models using some rough back-of-the-envelope estimations. Assume both vectors $\mathbf{m}_{i,t}$ and $\mathbf{m}_{j,t-1}$ have a norm $\mathcal{O}(R)$, then the time-delayed correlations behave approximately like&lt;/p&gt;
\begin{align}
D_{ij,t} \sim J_{ij} \cos^2 \alpha_{(i,t)(j,t-1)},
\end{align}&lt;p&gt;where $\alpha_{(i,t)(j,t-1)}$ denotes the angle between the magnetization vectors. So the entropy production looks approximately like&lt;/p&gt;
\begin{equation}
\sigma_{t} \sim \sum_{ij} \left(J_{ij}^2 - J_{ij} J_{ji}\right) \cos^2 \alpha_{(i,t)(j,t-1)},
\end{equation}&lt;p&gt;which, in general, is minimized for symmetric coupling matrices or orthogonal embeddings and maximized for fully-asymmetric couplings or (anti-)parallel embeddings.&lt;/p&gt;
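&lt;p&gt;To make these limiting cases concrete, the back-of-the-envelope estimate can be evaluated numerically. The sketch below is only a sanity check of the rough scaling expression above, not of the full mean-field formula; the function name and the toy inputs are made up for illustration.&lt;/p&gt;

```python
import numpy as np


def entropy_production_estimate(J, m_now, m_prev):
    """Rough estimate: sum_ij (J_ij^2 - J_ij J_ji) cos^2 of the angle between
    m_now[i] (spin i at time t) and m_prev[j] (spin j at time t-1)."""
    unit_now = m_now / np.linalg.norm(m_now, axis=-1, keepdims=True)
    unit_prev = m_prev / np.linalg.norm(m_prev, axis=-1, keepdims=True)
    cos_sq = (unit_now @ unit_prev.T) ** 2
    return np.sum((J**2 - J * J.T) * cos_sq)


rng = np.random.default_rng(0)
N, D = 5, 8
A = rng.random((N, N))
m_now, m_prev = rng.normal(size=(N, D)), rng.normal(size=(N, D))

J_symmetric = (A + A.T) / 2                # symmetric couplings
J_one_way = np.tril(A, k=-1)               # fully asymmetric: J_ij J_ji vanishes
parallel = np.ones((N, D))                 # all embeddings point the same way

print(entropy_production_estimate(J_symmetric, m_now, m_prev))    # vanishes
print(entropy_production_estimate(J_one_way, parallel, parallel)) # sum of J_ij^2
print(np.sum(J_one_way**2))
```

&lt;p&gt;For symmetric couplings every term cancels regardless of the embeddings, while for one-way couplings with (anti-)parallel embeddings the estimate reduces to the squared Frobenius norm of the couplings.&lt;/p&gt;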
&lt;p&gt;But for the softmax attention matrix Eq. \eqref{eq:softmax}, we have additional constraints $J_{ij} \geq 0$ as well as a Frobenius norm of at most $\sqrt{N}$ (every row is a probability vector), preventing unbounded growth under maximization. Additionally, imposing a causal mask on the couplings to do autoregressive modeling leads to even more constraints since then the upper triangular part of $J_{ij}$ is fixed to zero. So it feels like maximizing entropy production for causal softmax couplings promotes some kind of compromise between &lt;em&gt;sparse attention&lt;/em&gt; (intuitively, if the upper-triangular part is zero then it is favorable to push the lower-triangular elements close to zero as well) and &lt;em&gt;clustering of embeddings&lt;/em&gt; (weighted maximization of cosine similarity).&lt;/p&gt;
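&lt;p&gt;The constraints on causal softmax couplings are easy to inspect numerically. The following sketch builds such couplings from random inputs and projections (all values are placeholders chosen for illustration) and prints the quantities discussed above: the row sums, the entries above the diagonal, and the Frobenius norm next to $\sqrt{N}$.&lt;/p&gt;

```python
import numpy as np


def causal_softmax_couplings(x, Q, K):
    """Row-wise softmax of x Q K^T x^T with entries above the diagonal masked out."""
    N = x.shape[0]
    scores = x @ Q @ K.T @ x.T
    allowed = np.tril(np.ones((N, N), dtype=bool))   # spin i only sees spins j up to i
    scores = np.where(allowed, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                         # exp(-inf) gives exactly zero
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
N, D = 6, 8
x = rng.normal(size=(N, D))
Q, K = 0.3 * rng.normal(size=(D, D)), 0.3 * rng.normal(size=(D, D))

J = causal_softmax_couplings(x, Q, K)
print(J.sum(axis=-1))                     # every row sums to one
print(np.abs(np.triu(J, k=1)).max())      # upper-triangular part is zero
print(np.linalg.norm(J), np.sqrt(N))      # Frobenius norm next to sqrt(N)
```

&lt;p&gt;Note that the first row always has a single nonzero entry equal to one, so the first spin can only couple to itself under the causal mask.&lt;/p&gt;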
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="experiments"&gt;Experiments&lt;/h2&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="model-behavior-in-a-noisy-environment"&gt;Model behavior in a noisy environment&lt;/h3&gt;
&lt;p&gt;Interfaces, sensors and effectors.&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="global-coherence-from-local-backpropagation"&gt;Global coherence from local backpropagation&lt;/h3&gt;
&lt;p&gt;We test a stack of spin-transformer modules in a toy femtoscale online learning setup and try to see if we can make
when maximizing per-layer entropy-production losses &lt;em&gt;independently&lt;/em&gt;. If we detach module outputs after applying each layer, we end up with systems communicating via their input/output interfaces but without gradients backpropagating through the whole stack. (Pretty unlikely that the entropy-production losses on their own provide enough signal though.)&lt;/p&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h3 id="growing-network-topologies"&gt;Growing network topologies&lt;/h3&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="discussion-and-related-work"&gt;Discussion and related work&lt;/h2&gt;
&lt;p&gt;&amp;hellip;&lt;/p&gt;
&lt;h2 id="references"&gt;References&lt;/h2&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
by
Miguel Aguilera, S. Amin Moosavi, and Hideaki Shimazaki&lt;/li&gt;
&lt;li&gt;
by Jacob Mitchell Gold&lt;/li&gt;
&lt;li&gt;
by Giovanni Pezzulo and Michael Levin&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2026,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Entropy Production in Non-Equilibrium Neural Networks},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2026},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {?},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/entropy-production-in-non-equilibrium-neural-networks/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h1 id="footnotes"&gt;Footnotes&lt;/h1&gt;</description></item><item><title>Spin-Model Transformers</title><link>https://mcbal.github.io/post/spin-model-transformers/</link><pubDate>Sun, 19 Jun 2022 09:28:17 +0100</pubDate><guid>https://mcbal.github.io/post/spin-model-transformers/</guid><description>&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TL;DR:&lt;/strong&gt; &lt;em&gt;We interpret and implement transformer modules as driven, disordered vector-spin models whose response behavior can be shaped by learning parameterized interactions, gradually steering a cascade of near-equilibrium steady-state magnetizations towards solving a given objective. Using dynamical mean-field theory, we show that a first-order approximation of the update equations for the magnetizations reproduces residual and attention terms. Going to second-order adds explicit expressions for feed-forward-like correction terms that are fully determined by the mean-field structure of the underlying spin model. By blending ideas from deep learning and statistical mechanics, we hope our work can help open up broader interdisciplinary bridges to improve our understanding of learning and generalization in transformer neural networks.&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;GitHub repository:
&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In a series of previous
, we have tried to connect the forward pass of a transformer neural-network module to computing mean magnetizations in disordered Ising-like vector-spin models with parameterized couplings and external magnetic fields. According to this perspective, the forward pass of a transformer module can be understood as computing statistical observables given a specific realization of quenched couplings and external magnetic fields while the backward pass nudges the parameterized couplings and external magnetic fields. Physically, the transformer module represents an interacting many-body system modulating its behavior by learning to respond to being probed and driven in all kinds of ways.&lt;/p&gt;
&lt;p&gt;However, both the mean-field message-passing approach of
and the saddle-point free-energy approach of
inherently rely on methods that are only well-defined for spin models with symmetric coupling matrices, whose stochastic dynamics obey detailed balance and converge to a steady-state equilibrium characterized by the Boltzmann distribution. The softmax attention matrix in transformers is famously asymmetric though, so we had better come up with a more convincing approach to establish a correspondence.&lt;/p&gt;
&lt;p&gt;To capture spin models with asymmetric coupling matrices, we turn to non-equilibrium spin systems, whose dynamics can be pretty wild yet gentle enough to support regimes where relaxation to a non-equilibrium or near-equilibrium steady state can occur. In the past few decades, dynamical mean-field approaches have been developed for the binary kinetic Ising model, which exhibits non-equilibrium behavior for asymmetric couplings or when parameters are subject to rapid changes.&lt;/p&gt;
&lt;p&gt;In this post, we generalize a particular dynamical mean-field approach from binary spins to vector spins and relate the resulting mean-field update equations for the magnetizations to the forward pass of a transformer module. We find that the spin-model structure is rich enough for the update equations to yield residual connections, attention terms, and feed-forward correction terms, motivating a family of physics-inspired transformers.&lt;/p&gt;
&lt;h1 id="mean-field-theory-of-asymmetric-ising-models-with-binary-spins"&gt;Mean-field theory of asymmetric Ising models with binary spins&lt;/h1&gt;
&lt;p&gt;In this preliminary section, we review known results on mean-field theory approaches capturing the stochastic dynamics of binary kinetic Ising models. Readers familiar with this framework can skip ahead to
where we develop a generalization to vector spins. We primarily follow the discussion outlined in
. At the end of the section, we implement the mean-field update equations for the mean magnetizations in JAX and run a few numerical experiments.&lt;/p&gt;
&lt;h2 id="setting-the-scene-the-kinetic-ising-model"&gt;Setting the scene: the kinetic Ising model&lt;/h2&gt;
&lt;img src="binary_spins.png" alt="Random Ising model configuration with binary spins" width="250px"/&gt;
&lt;p&gt;We consider a kinetic Ising model describing a system made up of $N$ interacting binary spins $s_{i,t} \in \{-1, 1\}$ that evolve in discrete time steps $t$ according to synchronous dynamics, i.e. all spins get updated at the same time in parallel. Given a configuration $\mathbf{s}_{t-1} = \{ s_{1,t-1}, s_{2,t-1}, \ldots, s_{N,t-1} \}$ at time $t-1$, we consider the spins $\mathbf{s}_{t}$ at time $t$ to be conditionally independent random variables captured by a discrete-time Markov chain transition probability&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{\sum_{s_{i,t}} \mathrm{e}^{s_{i,t} h_{i,t}}} = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}}}{2 \cosh h_{i,t}}, \label{eq:pcond}
\end{equation}&lt;p&gt;where the effective external field is given by&lt;/p&gt;
\begin{equation}
h_{i,t} = x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1}.
\end{equation}&lt;p&gt;Here, the parameters $\mathbf{x}$ represent the (possibly time-dependent) local external fields at each site while the coupling parameters $\mathbf{J}$ are a specific realization of quenched disorder encoding the interactions between pairs of spins. Using the probability mass function of the previous state $P( \mathbf{s}_{t-1} )$ we can write the distribution of the current state as&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1}} P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P( \mathbf{s}_{t-1} ), \label{eq:marginal}
\end{equation}&lt;p&gt;which, when applied recursively, traces the evolution of the system starting from some initial distribution $P( \mathbf{s}_{0} )$. Unless we turn off the couplings by setting $\mathbf{J} = \mathbf{0}$, the marginal distribution $P( \mathbf{s}_{t} )$ is not factorized and tends to be quite complicated. Our goal is to compute statistical properties of the system, such as the mean magnetizations&lt;/p&gt;
\begin{equation}
m_{i,t} = \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t} ),
\end{equation}&lt;p&gt;as well as correlations&lt;/p&gt;
\begin{equation}
C_{ik,t} = \sum_{\mathbf{s}_{t}} s_{i,t} s_{k,t} P( \mathbf{s}_{t} ) - m_{i,t} m_{k,t},
\end{equation}&lt;p&gt;and delayed correlations&lt;/p&gt;
\begin{equation}
D_{il,t} = \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1}} s_{i,t} s_{l,t-1} P( \mathbf{s}_{t}, \mathbf{s}_{t-1} ) - m_{i,t} m_{l,t-1}.
\end{equation}&lt;p&gt;Since the above expressions involve summing over an exponentially large number (at least $2^N$) of possible spin configurations, they are not very useful in practice. So we will try to approximate the tricky marginal distribution $P( \mathbf{s}_{t} )$ defined in Eq. \eqref{eq:marginal} using a mean-field theory approach.&lt;/p&gt;
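&lt;p&gt;For a handful of spins, the exact sums can still be carried out by brute force, which gives a useful reference point for any approximation. The sketch below enumerates all configurations, builds the full transition matrix from Eq. \eqref{eq:pcond}, and applies Eq. \eqref{eq:marginal} recursively. It assumes time-independent fields $\mathbf{x}$ for simplicity, and the function name and toy parameters are illustrative only.&lt;/p&gt;

```python
import itertools
import numpy as np


def exact_magnetizations(x, J, p0, steps):
    """Evolve P(s_t) exactly and return the mean magnetizations at every step."""
    N = len(x)
    configs = np.array(list(itertools.product([-1, 1], repeat=N)))   # shape (2^N, N)
    h = x + configs @ J.T            # h[a, i]: field on spin i given previous state a
    # transition[a, b] = P(s_t = configs[b] | s_{t-1} = configs[a])
    log_p = configs[None, :, :] * h[:, None, :] - np.log(2 * np.cosh(h))[:, None, :]
    transition = np.exp(log_p.sum(axis=-1))
    p, history = p0, []
    for _ in range(steps):
        p = p @ transition           # marginalize over the previous state
        history.append(p @ configs)  # m_i = sum over states of s_i P(s)
    return np.array(history)


rng = np.random.default_rng(0)
N = 4
x = 0.5 * rng.normal(size=N)
J = rng.normal(size=(N, N)) / np.sqrt(N)     # asymmetric random couplings
p0 = np.full(2**N, 1.0 / 2**N)               # uniform initial distribution

print(exact_magnetizations(x, J, p0, steps=5))
```

&lt;p&gt;The transition matrix has $2^N \times 2^N$ entries, so this approach stops being feasible beyond a few dozen spins at most.&lt;/p&gt;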
&lt;h2 id="mean-field-theory-and-kullback-leibler-divergence"&gt;Mean-field theory and Kullback-Leibler divergence&lt;/h2&gt;
&lt;p&gt;Mean-field theory tries to approximate a complicated object ${\color{red}P}$ by wiggling around the parameters of a simple, analytically tractable parameterized ansatz ${\color{green}Q_{\theta}}$ to get as close as possible to ${\color{red}P}$. At risk of inducing headaches in mathematicians by calling everything a manifold, we can picture what is going on geometrically as trying to approximate a target probability distribution $P( \mathbf{s}_{t} \vert \mathbf{x}, \mathbf{J})$ and its statistical properties $\mathbf{m}_{t}$, $\mathbf{C}_{t}$, and $\mathbf{D}_{t}$ by restricting ourselves to a submanifold of tractable probability distributions. A particularly convenient submanifold is that of factorized models, where each point on the submanifold corresponds to a distribution parameterized by a vector $\boldsymbol{\theta}_{t}$,&lt;/p&gt;
\begin{equation}
Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} \theta_{i,t}}}{2 \cosh \theta_{i,t}}, \label{eq:q}
\end{equation}&lt;p&gt;so that the mean magnetizations are simply given by&lt;/p&gt;
\begin{equation}
m_{i,t} = \tanh \theta_{i,t} \label{eq:meanmagstanh}
\end{equation}&lt;p&gt;as there are no couplings between spins. The factorized model $Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}^{*}_{t} )$ that minimizes the Kullback-Leibler (KL) divergence&lt;/p&gt;
\begin{equation}
D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) = \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log \frac{P( \mathbf{s}_{t})}{Q_{\theta}( \mathbf{s}_{t})} \label{eq:kl}
\end{equation}&lt;p&gt;has mean magnetizations $\mathbf{m}_{t}$ identical to those of the target distribution $P( \mathbf{s}_{t})$ since, for all spins $i=1,2,\ldots,N$, we find that&lt;/p&gt;
\begin{align}
\frac{\partial D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} &amp;= - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \frac{\partial \log Q_{\theta}( \mathbf{s}_{t}) }{\partial \theta_{i, t}} \Biggr\rvert_{\boldsymbol{\theta}_{t}=\boldsymbol{\theta}^{*}_{t}} \\\\
&amp;= - \sum_{\mathbf{s}_{t}} s_{i,t} P( \mathbf{s}_{t}) + \tanh \theta^{*}_{i,t} \\\\
&amp;= -m^{{\color{red}P}}_{i,t} + m^{{\color{green}Q_{\theta^{*}}}}_{i,t} = 0, \label{eq:klm}
\end{align}&lt;p&gt;where $m^{{\color{red}P}}_{i,t}$ and $m^{{\color{green}Q_{\theta^{*}}}}_{i,t}$ respectively denote the expectation values of $s_{i,t}$ with respect to ${\color{red}P}$ and ${\color{green}Q_{\theta^{*}}}$. Indeed, minimizing $D_{\mathrm{KL}} ({\color{red}P}\vert\vert{\color{green}Q_{\theta}})$ tries to cover the modes of ${\color{red}P}$ by moment matching since the expectation value in Eq. \eqref{eq:kl} is calculated with respect to ${\color{red}P}$.&lt;/p&gt;
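&lt;p&gt;The moment-matching property can be checked numerically on a tiny system. In the sketch below, the target $P$ is an arbitrary random distribution over three binary spins (a stand-in chosen for illustration, not the kinetic Ising marginal), and we verify that the fields $\theta^{*}_{i} = \operatorname{artanh} m^{P}_{i}$ make the gradient of the KL divergence vanish.&lt;/p&gt;

```python
import itertools
import numpy as np


def factorized_kl(p, configs, theta):
    """KL(P || Q_theta) for a factorized Q with fields theta over binary spins."""
    log_q = (configs * theta - np.log(2 * np.cosh(theta))).sum(axis=-1)
    return np.sum(p * (np.log(p) - log_q))


rng = np.random.default_rng(0)
N = 3
configs = np.array(list(itertools.product([-1, 1], repeat=N)))   # all 2^N states
p = rng.random(len(configs)) + 0.1       # arbitrary strictly positive target
p = p / p.sum()

m_target = p @ configs                   # magnetizations of the target P
theta_star = np.arctanh(m_target)        # solves tanh(theta) = m for every spin

# Central finite differences of the KL divergence at theta_star.
eps = 1e-5
grad = np.array([
    (factorized_kl(p, configs, theta_star + eps * e)
     - factorized_kl(p, configs, theta_star - eps * e)) / (2 * eps)
    for e in np.eye(N)
])
print(grad)                                          # entries close to zero
print(factorized_kl(p, configs, theta_star))         # KL at the matched fields
print(factorized_kl(p, configs, theta_star + 0.3))   # larger after shifting theta
```

&lt;p&gt;Because the KL divergence is convex in $\boldsymbol{\theta}$ for this factorized family, the stationary point found by moment matching is also the global minimum. The remaining KL value at $\boldsymbol{\theta}^{*}$ measures the correlations in $P$ that no factorized model can capture.&lt;/p&gt;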
&lt;h2 id="the-plefka-expansion-interpolating-distributions"&gt;The Plefka expansion: interpolating distributions&lt;/h2&gt;
&lt;p&gt;Great, but is it even possible to find the parameters&lt;/p&gt;
\begin{equation}
\boldsymbol{\theta}^{*}_{t} = \operatorname*{arg\,min}_{\boldsymbol{\theta}_{t}} \left( - \sum_{\mathbf{s}_{t}} P( \mathbf{s}_{t}) \log Q_{\theta}( \mathbf{s}_{t}) \right)
\end{equation}&lt;p&gt;that minimize the KL divergence? Well, that&amp;rsquo;s going to be hard, unless you already know the target distribution $P( \mathbf{s}_{t})$, or you have a clever way of approximately evaluating the expectation value of $\log {\color{green}Q_{\theta}}$ with respect to ${\color{red}P}$. So let us introduce some more distributions to get around this issue. To apply the Plefka expansion to our problem, we introduce the conditional distribution&lt;/p&gt;
\begin{equation}
P_{\alpha}( \mathbf{s}_{t}\vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{s_{i,t} h_{i,t}(\alpha) }}{2 \cosh h_{i,t}(\alpha)}, \label{eq:pcondalt}
\end{equation}\begin{equation}
h_{i,t}(\alpha) = (1-\alpha) \theta_{i,t} + \alpha \left( x_{i,t} + \sum_{j=1}^{N} J_{ij} s_{j,t-1} \right), \label{eq:pcondalth}
\end{equation}&lt;p&gt;parameterized by a scalar $\alpha$ interpolating between $P_{\alpha=0}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} )$ (Eq. \eqref{eq:q}) and $P_{\alpha=1}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} )$ (Eq. \eqref{eq:pcond}). Using Eq. \eqref{eq:pcondalt}, we can construct an approximate marginal distribution $P_{\alpha}( \mathbf{s}_{t})$, leading to $\alpha$-dependent statistical properties $\mathbf{m}_{t}(\alpha)$, $\mathbf{C}_{t}(\alpha)$, and $\mathbf{D}_{t}(\alpha)$ for the approximate system. The Plefka expansion then boils down to writing these properties as Taylor series expansions around the factorized model $\alpha=0$. For the mean magnetizations, the expansion up to $n$-th order looks like&lt;/p&gt;
\begin{equation}
\mathbf{m}_{t}(\alpha) = \mathbf{m}_{t}(\alpha=0) + \sum_{k=1}^{n} \frac{\alpha^k}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}), \label{eq:mtaylor}
\end{equation}&lt;p&gt;where all coefficients in the expansion are functions of $\boldsymbol{\theta}_{t}$ via Eq. \eqref{eq:pcondalth}. The mean-field approximation is computed by setting $\alpha=1$ so that the original marginal distribution is recovered and Eq. \eqref{eq:klm} holds, which implies that $\mathbf{m}_{t}(\alpha=1) = \mathbf{m}_{t}(\alpha=0)$ and thus&lt;/p&gt;
\begin{equation}
\sum_{k=1}^{n} \frac{1}{k!} \frac{\partial^{k} \mathbf{m}_{t}(\alpha=0)}{\partial \alpha^{k}} + \mathcal{O}(\alpha^{n+1}) = 0. \label{eq:mftheta}
\end{equation}&lt;p&gt;Finally, we solve Eq. \eqref{eq:mftheta} for $\boldsymbol{\theta}_{t}$ to find the mean-field values $\boldsymbol{\theta}^{*}_{t}$ of the parameters of the distribution Eq. \eqref{eq:q}. Physically, we are tuning the effective external magnetic fields of the factorized ansatz to $\boldsymbol{\theta}^{*}_{t}$ so that its approximate mean magnetizations get as close as possible to the true ones.&lt;/p&gt;
&lt;h2 id="naive-mean-field-and-thouless-anderson-palmer-approximations"&gt;Naive mean-field and Thouless-Anderson-Palmer approximations&lt;/h2&gt;
&lt;p&gt;We now consider first- and second-order approximations of the mean magnetizations Eq. \eqref{eq:mtaylor} to recover, respectively, the naive mean-field and Thouless-Anderson-Palmer (TAP) approximations for the binary kinetic Ising model. The starting point is a Plefka expansion around factorized models at times $t-1$ and $t$. From Eq. \eqref{eq:marginal} and Eq. \eqref{eq:pcondalt}, we construct a marginal probability distribution&lt;/p&gt;
\begin{equation}
P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ),
\end{equation}&lt;p&gt;interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. The corresponding mean magnetizations are&lt;/p&gt;
\begin{align}
m_{i,t}(\alpha) &amp;= \sum_{\mathbf{s}_{t},\mathbf{s}_{t-1},\mathbf{s}_{t-2}} s_{i,t} \, P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \\\\
&amp;= \sum_{\mathbf{s}_{t-1},\mathbf{s}_{t-2}} \tanh h_{i,t}(\alpha) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} )
\end{align}&lt;p&gt;Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} = \left( 1-m^{2}_{i,t} \right) \left( -\theta_{i,t} + x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right) = 0,
\end{equation}&lt;p&gt;so that $\theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1}$ and we end up with the naive mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} \right)} \label{eq:naivem}
\end{equation}&lt;p&gt;Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\frac{\partial m_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial\alpha^2} = 0,
\end{equation}&lt;p&gt;where the second-order derivative, neglecting terms higher than $\mathcal{O}(\alpha^2)$, is&lt;/p&gt;
\begin{equation}
\frac{\partial^{2} m_{i,t}(\alpha=0)}{\partial\alpha^2} \approx -2 m_{i,t} \left( 1-m^{2}_{i,t} \right) \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{equation}
\theta^{*}_{i,t} = x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right)
\end{equation}&lt;p&gt;and we end up with the TAP mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{m_{i,t} = \tanh \left( x_{i,t} + \sum_{j} J_{ij} m_{j,t-1} - m_{i,t} \sum_{j} J^{2}_{ij} \left( 1-m^{2}_{j,t-1} \right) \right)} \label{eq:tapm}
\end{equation}&lt;p&gt;which includes the so-called Onsager correction term. The mean-field equations obtained above can also be elegantly derived using a Legendre transformation of the generating functional of the set of trajectories of the model, as outlined in e.g.
. We can also derive second-order TAP approximations of the correlations&lt;/p&gt;
\begin{equation}
C_{ik,t} = \begin{cases}
1 - m^{2}_{i,t} &amp; i = k \\\\
\left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{k,t} \right) \sum_{j} J_{ij} J_{kj} \left( 1-m^{2}_{j,t-1} \right) &amp; i \neq k \label{eq:tapc}
\end{cases}
\end{equation}&lt;p&gt;and delayed correlations&lt;/p&gt;
\begin{equation}
D_{il,t} = J_{il} \left( 1-m^{2}_{i,t} \right) \left( 1-m^{2}_{l,t-1} \right) \left( 1 + 2 J_{il} m_{i,t} m_{l,t-1} \right). \label{eq:tapd}
\end{equation}&lt;p&gt;We refer to
for full derivations of the above mean-field results as well as variations based on different approximations of the marginal distribution $P( \mathbf{s}_{t} )$.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;In summary, given the mean magnetizations $\mathbf{m}_{t-1}$ of the system at time $t-1$, we can use Eqs. \eqref{eq:tapm}, \eqref{eq:tapc}, and \eqref{eq:tapd} to compute a tuple $(\mathbf{m}_{t},\mathbf{C}_{t},\mathbf{D}_{t})$ of approximate statistical properties of the system at time $t$. The time evolution of the system can be captured at the mean-field level by recursively computing $\mathbf{m}_{t}$ starting from an initial state $\mathbf{m}_{0}$ (with approximation errors likely accumulating over the course of the time evolution).&lt;/p&gt;
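&lt;p&gt;To make the summary concrete, here is a minimal JAX sketch of Eqs. \eqref{eq:tapc} and \eqref{eq:tapd}, assuming the magnetizations $\mathbf{m}_{t}$ and $\mathbf{m}_{t-1}$ have already been computed. The function name &lt;code&gt;tap_correlations&lt;/code&gt; and its argument names are illustrative choices, not part of the implementation below.&lt;/p&gt;

```python
import jax.numpy as jnp


def tap_correlations(m_t, m_prev, J):
    """Sketch of the TAP equal-time (C) and delayed (D) correlations.

    m_t, m_prev: magnetizations at times t and t-1, shape (N,).
    J: coupling matrix, shape (N, N).
    """
    var_t = 1.0 - m_t**2
    var_prev = 1.0 - m_prev**2

    # Off-diagonal part: (1 - m_i^2)(1 - m_k^2) sum_j J_ij J_kj (1 - m_j^2).
    off_diag = jnp.outer(var_t, var_t) * jnp.einsum("ij,kj,j->ik", J, J, var_prev)

    # Diagonal part: 1 - m_i^2.
    eye = jnp.eye(m_t.shape[0])
    C = eye * var_t + (1.0 - eye) * off_diag

    # Delayed correlations between spin i at time t and spin l at time t-1.
    D = J * jnp.outer(var_t, var_prev) * (1.0 + 2.0 * J * jnp.outer(m_t, m_prev))
    return C, D
```

&lt;p&gt;Both outputs are $N \times N$ matrices. For vanishing magnetizations the sketch reduces to $\mathbf{C}_{t}$ with unit diagonal and $\mathbf{D}_{t} = \mathbf{J}$.&lt;/p&gt;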
&lt;h2 id="a-simple-jax-implementation"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;p&gt;To get more insight into what is going on, let us turn the mean-field update equations \eqref{eq:naivem} and \eqref{eq:tapm} for the mean magnetizations into code. But before we show a few plots, we need to know a bit more background about the model we are about to simulate. In
, the authors derive a solution of the asymmetric version of the kinetic
using a generating functional or dynamical partition function approach to capture the distribution of trajectories. They consider the same kinetic Ising model as in Eq. \eqref{eq:pcond} but with an inverse temperature parameter $\beta$ in the exponentials:&lt;/p&gt;
\begin{equation}
P( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta s_{i,t} h_{i,t}}}{2 \cosh \beta h_{i,t}}. \label{eq:pcondwithbeta}
\end{equation}&lt;p&gt;For Gaussian couplings $J_{ij} \sim \mathcal{N}\left( J_{\mu} / N, J^{2}_{\sigma} / N\right)$ and uniformly distributed external magnetic fields $x_{i} \sim \mathcal{U}(-X_{0}, X_{0})$, they show the existence of a ferromagnetic phase transition. In particular for $X_{0}=0.5$, $J_{\mu}=1.0$, and $J_{\sigma}=0.1$, a phase transition happens when tuning $\beta$ to a critical value $\beta_{c} \approx 1.1108$.&lt;/p&gt;
&lt;h3 id="simulating-magnetization-trajectories"&gt;Simulating magnetization trajectories&lt;/h3&gt;
&lt;p&gt;We first present a JAX implementation of the mean-field time evolution of the magnetizations according to the model described above. We use &lt;code&gt;jax.lax.scan&lt;/code&gt; to implement the time evolution and &lt;code&gt;jax.vmap&lt;/code&gt; to parallelize trajectories starting from a batch of initial magnetization configurations $\mathbf{m}_{0}$. For the second-order TAP equations, &lt;code&gt;jaxopt&lt;/code&gt;&amp;rsquo;s Anderson acceleration is used to find the fixed point magnetizations $\mathbf{m}_{t}$ given $\mathbf{m}_{t-1}$.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax.numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;jnp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;jaxopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_naive_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (22).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (26).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;init_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="n"&gt;minval&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxval&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;**-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;**-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;update_tap_mf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;init_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J_sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="naive-mean-field-vs-thouless-anderson-palmer-tap"&gt;Naive mean-field vs. Thouless-Anderson-Palmer (TAP)&lt;/h3&gt;
&lt;p&gt;We fix the seed and randomly initialize model parameters $\mathbf{x}$ and $\mathbf{J}$ to simulate $N=512$ spins at the critical inverse temperature $\beta_{c}$ for $t=128$ time steps starting from an all-ones initial state. We first consider the naive mean-field update step.&lt;/p&gt;
&lt;img src="binary_plot_1.png" width="600px"/&gt;
&lt;p&gt;The left axis shows the individual magnetization trajectories for each spin plotted horizontally, while the red line associated with the right axis describes the average of the magnetizations across all spins for each time step. We observe convergence to what looks like a &lt;em&gt;non-equilibrium / near-equilibrium steady state&lt;/em&gt; (NESS).&lt;/p&gt;
&lt;img src="binary_plot_2.png" width="550px"/&gt;
&lt;p&gt;Comparing the naive first-order mean-field update equations to the second-order Thouless-Anderson-Palmer (TAP) ones, we observe lower values for the mean magnetization across all spins, which
showed to be closer to ground truth values (not shown) obtained via sampling and averaging spin configurations.&lt;/p&gt;
&lt;h3 id="sampling-trajectories"&gt;Sampling trajectories&lt;/h3&gt;
&lt;p&gt;Let us consider 100 random initial states and simulate their associated trajectories in three different model regimes: below the critical point ($\beta=\beta_c / 2$), at the critical point ($\beta=\beta_c$), and above the critical point ($\beta=2 \beta_c$).&lt;/p&gt;
&lt;img src="binary_plot_3.png" width="650px"/&gt;
&lt;p&gt;We observe that the trajectories of randomly-initialized initial states converge to identical final states in each regime. These final states map to a simple ferromagnetic Ising phase diagram, where a high-temperature disordered phase $\langle m_{i,t} \rangle \to 0$ (left) is separated from a low-temperature locally-ordered phase $\langle m_{i,t} \rangle \to \pm 1$ (right) by a critical point (center). The behavior around $\beta=\beta_{c}$ is pretty interesting: &lt;em&gt;the non-trivial steady state looks like an attractor implicitly encoded in the dynamics of the model&lt;/em&gt;. If we were to parameterize the couplings, we could train the system to act as an associative memory.&lt;/p&gt;
&lt;h3 id="sampling-model-parameters"&gt;Sampling model parameters&lt;/h3&gt;
&lt;p&gt;We now go back to considering just a single trajectory since we just saw that trajectories seem to converge to the same final steady-state magnetizations for fixed model parameters. To get a feel for the variation of these values across different realizations of model parameters, we plot the absolute value&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; $\| \langle m_{i} \rangle \|$ of the final steady-state magnetizations across 100 samples of model parameters and a range of inverse temperatures. We are using JAX, so we can easily sample model parameters by &lt;code&gt;vmap&lt;/code&gt;&amp;rsquo;ing the random key fed into the &lt;code&gt;simulate&lt;/code&gt; function followed by another &lt;code&gt;vmap&lt;/code&gt; to sweep across $\beta$.&lt;/p&gt;
&lt;img src="binary_plot_4.png" width="500px"/&gt;
&lt;p&gt;Every curve in the above plot describes the final steady-state value of the &amp;ldquo;order parameter&amp;rdquo; $\| \langle m_{i} \rangle \|$ for a fixed set of model parameters sweeping across $\beta$. We observe a greater spread of values near the critical point and hence an improved capacity to map input external fields to a range of output magnetizations. If we were to let the number of spins $N \to \infty$ and average over a large number of model parameter samples, the finite-size results above would probably transform into a sharp curve with zero magnetization below the critical point and a sudden non-zero magnetization emerging at the critical point.&lt;/p&gt;
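&lt;p&gt;The nested &lt;code&gt;vmap&lt;/code&gt; pattern described above can be sketched as follows. This is a self-contained illustration that uses the naive mean-field update rather than TAP, a smaller system of $N=64$ spins, and a hypothetical seed and $\beta$ grid, so it is not expected to reproduce the plot exactly. The helper name &lt;code&gt;steady_state_order_parameter&lt;/code&gt; is an assumption made for this sketch.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def steady_state_order_parameter(
    key, beta, N=64, steps=128, X0=0.5, J_mu=1.0, J_sigma=0.1
):
    """Return |mean_i m_i| after `steps` naive mean-field updates."""
    x_key, J_key = jax.random.split(key)
    # Absorb the inverse temperature into the fields and couplings.
    x = beta * jax.random.uniform(x_key, (N,), minval=-X0, maxval=X0)
    J = beta * (J_mu / N + J_sigma / jnp.sqrt(N) * jax.random.normal(J_key, (N, N)))

    def step(m_prev, _):
        return jnp.tanh(x + J @ m_prev), None

    # Start from the all-ones state and keep only the final magnetizations.
    m_final, _ = jax.lax.scan(step, jnp.ones(N), None, length=steps)
    return jnp.abs(jnp.mean(m_final))


keys = jax.random.split(jax.random.PRNGKey(0), 100)  # hypothetical seed
betas = jnp.linspace(0.5, 2.0, 16)  # hypothetical sweep range

# Inner vmap sweeps beta for a fixed key; outer vmap sweeps the keys.
sweep = jax.vmap(
    jax.vmap(steady_state_order_parameter, in_axes=(None, 0)), in_axes=(0, None)
)
order = sweep(keys, betas)  # shape (100, 16): one curve per parameter sample
```

&lt;p&gt;Each row of &lt;code&gt;order&lt;/code&gt; corresponds to one curve in a plot of this kind. Mapping over the key means every row uses an independent draw of $\mathbf{x}$ and $\mathbf{J}$.&lt;/p&gt;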
&lt;h1 id="mean-field-theory-of-asymmetric-ising-models-with-vector-spins"&gt;Mean-field theory of asymmetric Ising models with vector spins&lt;/h1&gt;
&lt;p&gt;We now transpose the binary-spin results of the previous section to a setting where local spin degrees of freedom are $D$-dimensional vector spins restricted to wiggle around on $(D-1)$-dimensional spheres. We start by generalizing the conditional distribution Eq. \eqref{eq:pcondalt} to vector spins. Next, we motivate the limit of large vector dimension and derive first-order and second-order mean-field update equations for the mean magnetizations. We finish this section with a JAX implementation and some toy numerical experiments.&lt;/p&gt;
&lt;h2 id="vector-spins-distributions-on-hyperspheres"&gt;Vector spins: distributions on hyperspheres&lt;/h2&gt;
&lt;img src="vector_spins.png" alt="Random Ising model configuration with vector spins" width="250px"/&gt;
&lt;p&gt;A vector-spin equivalent of Eq. \eqref{eq:pcondalt} looks something like&lt;/p&gt;
\begin{equation}
P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s}_{i,t} \; \mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \mathbf{h}_{i,t}(\alpha)} }, \label{eq:pcondaltvector}
\end{equation}&lt;p&gt;where we immediately included an inverse temperature $\beta$ like in Eq. \eqref{eq:pcondwithbeta}. A vector-spin equivalent of Eq. \eqref{eq:pcondalth} is&lt;/p&gt;
\begin{equation}
\mathbf{h}_{i,t}(\alpha) = (1-\alpha) \boldsymbol{\theta}_{i,t} + \alpha \left( \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} \right) \equiv \boldsymbol{\theta}_{i,t} + \alpha \Delta \mathbf{h}_{i,t}, \label{eq:pcondalthvector}
\end{equation}&lt;p&gt;where $S_{D-1}(R) = \{ x \in \mathbb{R}^{D} : \lVert x \rVert = R \}$ denotes the $(D-1)$-dimensional sphere with radius $R$ embedded in $D$ dimensions. Let us focus on the distribution for a single site and drop all subscripts and dependencies for clarity:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }. \label{eq:pcondsinglesitevector}
\end{equation}&lt;p&gt;The normalization constant in the denominator can be shown to be (see
)&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } \equiv Z(\beta, R, \lVert \mathbf{h}\rVert) \label{eq:partfun}
\end{equation}&lt;p&gt;where $I_{\nu}(z)$ denotes the modified Bessel function of the first kind and $\lVert \mathbf{h} \rVert = \sqrt{\mathbf{h} \cdot \mathbf{h}}$. Physically, we can think of this single-site distribution as measuring dot-product alignment to an effective external magnetic field $\mathbf{h}$ at inverse temperature $\beta$.&lt;/p&gt;
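&lt;p&gt;As a quick sanity check of Eq. \eqref{eq:partfun}, consider $D=3$, where the closed form $I_{1/2}(z) = \sqrt{2/(\pi z)} \, \sinh z$ is available. Substituting it gives&lt;/p&gt;
\begin{equation}
Z(\beta, R, \lVert \mathbf{h}\rVert) \big\vert_{D=3} = \frac{ \left( 2 \pi R \right)^{3/2} }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{1/2} } \sqrt{\frac{2}{\pi \beta R \lVert \mathbf{h}\rVert}} \, \sinh \left( \beta R \lVert \mathbf{h}\rVert \right) = \frac{4 \pi R \, \sinh \left( \beta R \lVert \mathbf{h}\rVert \right)}{\beta \lVert \mathbf{h}\rVert},
\end{equation}&lt;p&gt;which matches the direct computation in spherical coordinates, $2 \pi R^{2} \int_{0}^{\pi} \mathrm{d}\vartheta \, \sin\vartheta \; \mathrm{e}^{\beta R \lVert \mathbf{h}\rVert \cos\vartheta}$, assuming the measure in Eq. \eqref{eq:pcondsinglesitevector} is the surface measure on $S_{2}(R)$.&lt;/p&gt;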
&lt;p&gt;If we consider spins living on the unit sphere $R=1$ as well as unit vectors $\mathbf{h}$, the distribution boils down to a von Mises-Fisher distribution with mean direction $\boldsymbol{\mu} \equiv \mathbf{h}$ and concentration parameter $\kappa \equiv \beta$. This distribution is unimodal for $\kappa &gt; 0$ and can be derived from restricting an isotropic multivariate Gaussian to the unit hypersphere. The greater the value of $\kappa$ (the inverse temperature $\beta$), the higher the concentration of the distribution around the mean direction $\boldsymbol{\mu}$ (the more the spin tends to align with the effective external field $\mathbf{h}$). Instead of a fixed parameter $\boldsymbol{\mu}$, though, we have a very funky parameter Eq. \eqref{eq:pcondalthvector} that depends on all other spins to spice things up.&lt;/p&gt;
&lt;h2 id="magnetizations-and-limit-of-large-vector-dimension"&gt;Magnetizations and limit of large vector dimension&lt;/h2&gt;
&lt;p&gt;Before we derive mean-field approximations for the mean magnetizations of our vector-spin system, let us first consider the decoupled $\alpha \to 0$ limit of the distribution Eq. \eqref{eq:pcondaltvector},&lt;/p&gt;
\begin{equation}
Q( \mathbf{s}_{t} \vert \boldsymbol{\theta}_{t} ) = \prod_{i=1}^{N} \frac{\mathrm{e}^{\beta \, \mathbf{s}_{i,t} \cdot \boldsymbol{\theta}_{i,t}}}{Z_{i,t}\left(\beta, R, \lVert \boldsymbol{\theta}_{i,t} \rVert\right)},
\end{equation}&lt;p&gt;and find an expression for its mean magnetizations. For every decoupled site, the mean magnetization can be shown to be (see
)&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \frac{I_{D/2}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)}{I_{D/2 - 1}(\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert)} \frac{R \boldsymbol{\theta}_{i,t}}{\lVert \boldsymbol{\theta}_{i,t} \rVert} \equiv \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right), \label{eq:meanmagsbessels}
\end{equation}&lt;p&gt;which plays the role of $m_{i,t} = \tanh \theta_{i,t}$ in the binary setting, see Eq. \eqref{eq:meanmagstanh}. Looking ahead at turning the above equation into code, we note that there exist
to compute the ratio of modified Bessel functions of the first kind. We implement a fast JAX version in
and show numerically how the ratio flattens out quickly for large values of the order $\nu = D/2 -1$, motivating some kind of large-order expansion.&lt;/p&gt;
&lt;p&gt;Remember that our goal is to make a connection to transformer neural networks. Since the vector dimension in dense transformer modules tends to be somewhere between $\mathcal{O}(10^2)$ and $\mathcal{O}(10^5)$, it is not nonsensical to focus on the large vector dimension limit. A relevant uniform asymptotic expansion of the ratio of modified Bessel functions of the first kind is
:&lt;/p&gt;
\begin{align}
\frac{I_{\nu+\alpha}(\nu x)}{I_{\nu}(\nu x)} = \left( \frac{x}{1+\sqrt{1+x^2}} \right)^{\alpha} \left( 1 - \frac{1+\alpha\sqrt{1+x^2}}{2(1+x^2)} \frac{\alpha}{\nu} + \mathcal{O}\left( \frac{1}{\nu^2} \right) \right)
\end{align}&lt;p&gt;Indeed, if we choose to tie the radius $R$ of our little spins to their vector dimension $D$ via&lt;/p&gt;
\begin{align}
\nu=D/2-1=R^2,
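\end{align}&lt;p&gt;we can write the argument of the Bessel functions as $\beta R \lVert \boldsymbol{\theta}_{i,t} \rVert = \nu x$ with $x = \beta \lVert \boldsymbol{\theta}_{i,t} \rVert / R$ and apply the leading order of the asymptotic expansion for an order shift $\alpha=1$ (not to be confused with the interpolation parameter $\alpha$) to Eq. \eqref{eq:meanmagsbessels} to find&lt;/p&gt;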
\begin{equation}
\mathbf{m}^{D \to \infty}_{i,t} \approx \frac{\beta}{1+\gamma( \lVert \boldsymbol{\theta}_{i,t} \rVert )} \boldsymbol{\theta}_{i,t} \equiv \boldsymbol{\varphi}^{D \to \infty} \left(\boldsymbol{\theta}_{i,t}\right), \label{eq:largedevmag}
\end{equation}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right) = \sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }.
\end{align}&lt;p&gt;From here on, we will default to using the large-$D$ approximation because keeping track of (derivatives of) Bessel functions gets boring real quick. The general case valid for all $D&gt;1$ leads to some truly outrageous expressions.&lt;/p&gt;
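&lt;p&gt;To get a feel for how good the leading-order approximation is, here is a minimal sketch in plain Python (not the JAX version mentioned above). The helper names &lt;code&gt;bessel_ratio&lt;/code&gt;, &lt;code&gt;large_order_ratio&lt;/code&gt;, and &lt;code&gt;relative_error&lt;/code&gt; are our own, and evaluating the exact ratio through a truncated continued fraction is an assumption made for illustration, not necessarily the algorithm used in the accompanying code.&lt;/p&gt;

```python
import math


def bessel_ratio(nu, z, depth=500):
    """Approximate I_{nu+1}(z) / I_nu(z) for z > 0.

    Uses the recurrence r_nu = 1 / (2 (nu + 1) / z + r_{nu+1}),
    evaluated backwards from r_{nu+depth} = 0 (a truncation assumption).
    """
    r = 0.0
    for k in range(depth, 0, -1):
        r = 1.0 / (2.0 * (nu + k) / z + r)
    return r


def large_order_ratio(nu, z):
    """Leading order of the uniform expansion for order shift 1, with x = z / nu."""
    x = z / nu
    return x / (1.0 + math.sqrt(1.0 + x * x))


def relative_error(nu, x):
    """Relative error of the leading-order term at argument z = nu * x."""
    exact = bessel_ratio(nu, nu * x)
    return abs(exact - large_order_ratio(nu, nu * x)) / exact


if __name__ == "__main__":
    for D in (8, 64, 512, 4096):
        nu = D / 2 - 1
        print(D, relative_error(nu, 1.0))
```

&lt;p&gt;At fixed $x$, the relative error should shrink roughly like $1/\nu$, in line with the first correction term of the expansion.&lt;/p&gt;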
&lt;h2 id="first-order-naive-mean-field-approximation"&gt;First-order naive mean-field approximation&lt;/h2&gt;
&lt;p&gt;All right, let&amp;rsquo;s go. Closely mimicking the binary case, we start from the following approximated marginal probability distribution&lt;/p&gt;
\begin{equation}
P^{[t-1:t]}_{\alpha}( \mathbf{s}_{t} ) = \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ),
\end{equation}&lt;p&gt;interpolating between $P^{[t-1:t]}_{\alpha=0}( \mathbf{s}_{t} ) = Q( \mathbf{s}_{t} )$ and $P^{[t-1:t]}_{\alpha=1}( \mathbf{s}_{t} ) = P( \mathbf{s}_{t} )$. Our lazy integral notation $\int \mathrm{d} \mathbf{s}_{t}$ should be understood as $\int \prod_{i=1}^{N} \mathrm{d}^{D} \mathbf{s}_{i, t}$, i.e. integrating over all the little spins at a fixed time $t$. The estimated mean magnetizations are&lt;/p&gt;
\begin{align}
\mathbf{m}_{i,t}(\alpha) &amp;= \int \mathrm{d} \mathbf{s}_{t} \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \mathbf{s}_{i,t} P_{\alpha}( \mathbf{s}_{t} \vert \mathbf{s}_{t-1} ) P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;= \int \mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \; \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) P( \mathbf{s}_{t-2} ).
\end{align}&lt;p&gt;The first-order derivative with respect to $\alpha$ is then given by&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha)}{\partial\alpha} = \int &amp;\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \Biggr) P( \mathbf{s}_{t-2} ), \label{eq:mitfirstorderalpha}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\frac{\partial \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} = \frac{\beta}{1+\gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right)} \Delta \mathbf{h}_{i,t} - \frac{\beta}{R^2} \frac{ \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) }{ \gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right) } \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \label{eq:firstorderphialpha}
\end{align}&lt;p&gt;Evaluating \eqref{eq:mitfirstorderalpha} at $\alpha=0$, the second term drops out because the first-order derivative of $P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )$ becomes independent of $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)$ and $\int \mathrm{d} \mathbf{s}_{t-1} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=1$. We thus end up with&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = \frac{\beta}{1+\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right)}\boldsymbol{v}_{i,t} - \frac{\beta}{R^2}\frac{\left( \mathbf{m}_{i,t} \cdot \boldsymbol{v}_{i,t} \right)}{\gamma \left(\lVert \boldsymbol{\theta}_{i,t} \rVert\right)} \mathbf{m}_{i,t} \label{eq:mfirstorderalphazero}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\boldsymbol{v}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{m}_{j,t-1} \label{eq:vmf}
\end{align}&lt;p&gt;captures the result of integrating $\Delta \mathbf{h}_{i,t}$ over the spins $\mathbf{s}_{t-1}$. Following Eq. \eqref{eq:mftheta}, the first-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^2\right)\right]_{\alpha=1},\label{eq:firstorderapproxreqs}
\end{equation}&lt;p&gt;so that we are encouraged to set $\boldsymbol{v}_{i,t}=0$ and hence $\boldsymbol{\theta}^{*}_{i,t} = \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1}$, leading to the naive mean-field equations:&lt;/p&gt;
\begin{equation}
\boxed{ \mathbf{m}_{i,t} = \frac{\beta \left( \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \right)}{1+\sqrt{1+\beta^2 \lVert \mathbf{x}_{i,t} + \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \rVert^2 / R^2 }} } \label{eq:naivemvector}
\end{equation}&lt;p&gt;Looking ahead at the transformer-module correspondence, we squint our eyes and recognize a scaled sum of a residual connection and an attention term. No feed-forward terms though.&lt;/p&gt;
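&lt;p&gt;The closed-form derivative in Eq. \eqref{eq:firstorderphialpha} is easy to get wrong by a sign or a factor, so a quick finite-difference check does not hurt. The sketch below is plain Python with hypothetical helper names and arbitrary test values. It assumes $\mathbf{h}_{i,t}(\alpha) = \boldsymbol{\theta}_{i,t} + \alpha \Delta \mathbf{h}_{i,t}$, which is consistent with the way $\Delta \mathbf{h}_{i,t}$ enters the derivation, and it does not tie $R$ to the vector dimension since the identity holds for any $R$.&lt;/p&gt;

```python
import math


def gamma(h, beta, R):
    return math.sqrt(1.0 + beta**2 * sum(c * c for c in h) / R**2)


def phi(h, beta, R):
    """Large-D magnetization function applied to a single vector h."""
    scale = beta / (1.0 + gamma(h, beta, R))
    return [scale * c for c in h]


def dphi_dalpha(h, dh, beta, R):
    """Closed-form derivative of phi(h(alpha)) for dh = d h / d alpha."""
    g = gamma(h, beta, R)
    p = phi(h, beta, R)
    p_dot_dh = sum(a * b for a, b in zip(p, dh))
    return [
        beta / (1.0 + g) * d - beta / R**2 * p_dot_dh / g * c
        for d, c in zip(dh, p)
    ]


def finite_difference(theta, dh, alpha, beta, R, eps=1e-6):
    """Central difference of phi(theta + alpha * dh) with respect to alpha."""
    plus = phi([t + (alpha + eps) * d for t, d in zip(theta, dh)], beta, R)
    minus = phi([t + (alpha - eps) * d for t, d in zip(theta, dh)], beta, R)
    return [(a - b) / (2.0 * eps) for a, b in zip(plus, minus)]


if __name__ == "__main__":
    beta, R, alpha = 1.5, 2.0, 0.3  # arbitrary test values
    theta = [0.4, -1.0, 0.7, 0.2]
    dh = [0.3, 0.5, -0.2, 1.1]
    h = [t + alpha * d for t, d in zip(theta, dh)]
    print(dphi_dalpha(h, dh, beta, R))
    print(finite_difference(theta, dh, alpha, beta, R))
```

&lt;p&gt;The two printed vectors should agree up to finite-difference error.&lt;/p&gt;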
&lt;p&gt;Before moving on to the second-order approximation, let us end this section with an interesting observation about Eq. \eqref{eq:mfirstorderalphazero}. One can show that the variance matrix of a single spin in the large-$D$ limit equals a rank-1 perturbation of a diagonal matrix,&lt;/p&gt;
\begin{align}
\mathrm{Var} [ \mathbf{s}_{i,t} ] &amp;= \frac{\mathbb{1}}{1+\gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right)} - \frac{ \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \otimes \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) }{ R^2 \gamma \left(\lVert \mathbf{h}_{i,t}(\alpha) \rVert\right) }. \label{eq:spinvariance}
\end{align}&lt;p&gt;Taking the $\alpha \to 0$ limit of the above expression, we can reinterpret Eq. \eqref{eq:mfirstorderalphazero} as the matrix-vector multiplication of the decoupled spin&amp;rsquo;s variance matrix with $\boldsymbol{v}_{i,t}$,&lt;/p&gt;
\begin{align}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} = \beta \mathrm{Var} [ \mathbf{s}_{i,t} ] \boldsymbol{v}_{i,t}.
\end{align}&lt;h2 id="second-order-thouless-anderson-palmer-approximation"&gt;Second-order Thouless-Anderson-Palmer approximation&lt;/h2&gt;
&lt;p&gt;Let us now try to find out whether going to the second-order approximation spits out additional feed-forward-like Onsager correction terms in the update equations for the magnetizations.&lt;/p&gt;
&lt;p&gt;Again following Eq. \eqref{eq:mftheta}, the second-order approximation should satisfy&lt;/p&gt;
\begin{equation}
\left[ \alpha \frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} \right]_{\alpha=1} + \left[ \frac{\alpha^2}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}\right]_{\alpha=1} = \mathbf{0} + \left[ \mathcal{O}\left(\alpha^3\right)\right]_{\alpha=1}, \label{eq:secondorderconstraint}
\end{equation}&lt;p&gt;where the second-order derivative is given by&lt;/p&gt;
\begin{align}
\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha)}{\partial\alpha^2} = \int &amp;\mathrm{d} \mathbf{s}_{t-1} \int \mathrm{d} \mathbf{s}_{t-2} \Biggl( \frac{\partial^{2}\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha^2} \, P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} ) \nonumber\\\\
&amp;+ 2\frac{\partial\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)}{\partial\alpha} \, \frac{\partial P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha} \nonumber \\\\
&amp;+ \boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right) \, \frac{\partial^{2} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )}{\partial\alpha^2} \Biggr) P( \mathbf{s}_{t-2} ). \label{eq:mhasecordder}
\end{align}&lt;p&gt;Evaluated at $\alpha=0$, the third term in the expression above will drop out because the derivative becomes independent of $\boldsymbol{\varphi} \left(\mathbf{h}_{i,t}(\alpha)\right)$ and $\int \mathrm{d} \mathbf{s}_{t-1} P_{\alpha}( \mathbf{s}_{t-1} \vert \mathbf{s}_{t-2} )=1$.&lt;/p&gt;
&lt;p&gt;The first term in Eq. \eqref{eq:mhasecordder} can be shown to look something like&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = &amp; \frac{\beta^2}{R^4} \frac{ 1+\gamma_{i,t}(\alpha) }{ \gamma_{i,t}(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta}{R^2} \frac{1}{\gamma_{i,t}(\alpha)} \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta}{R^2} \frac{1}{\gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha} \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha) } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t},
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\gamma_{i,t} (\alpha) \equiv \gamma\left( \lVert \mathbf{h}_{i,t}(\alpha) \rVert \right) = \sqrt{1+\beta^2 \lVert \mathbf{h}_{i,t}(\alpha) \rVert^2 / R^2 },
\end{align}&lt;p&gt;which, after substituting the first-order derivative Eq. \eqref{eq:firstorderphialpha}, simplifies to&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha))}{\partial\alpha^2} = &amp; \frac{\beta^2}{R^4} \frac{ 1+3\gamma_{i,t}(\alpha) }{ \gamma_{i,t}(\alpha)^3 } \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right)^2 \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha)} \left( \Delta \mathbf{h}_{i,t} \cdot \Delta \mathbf{h}_{i,t} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \nonumber \\\\
&amp;- \frac{\beta^2}{R^2} \frac{2}{\gamma_{i,t}(\alpha)^2 + \gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \Delta \mathbf{h}_{i,t} \right) \Delta \mathbf{h}_{i,t} . \label{eq:secondorderphialpha}
\end{align}&lt;p&gt;The second term in Eq. \eqref{eq:mhasecordder} contains non-vanishing contributions in the $\alpha \to 0$ limit coming from the $\sum_{j=1}^{N} J_{ij} \mathbf{s}_{j, t-1}$ terms in $\Delta \mathbf{h}_{i,t}$. One can show that the surviving terms in the integrand are proportional to&lt;/p&gt;
\begin{align}
\sum_{j} J_{ij} \Biggl( &amp;\frac{2 \beta^2}{1+\gamma_{i,t}(\alpha)} \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \nonumber \\\\
&amp;- \frac{2 \beta^2}{R^2 \gamma_{i,t}(\alpha)} \left( \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \cdot \frac{\partial\mathbf{m}_{j, t-1}(\alpha)}{\partial\alpha} \right) \boldsymbol{\varphi}(\mathbf{h}_{i,t}(\alpha)) \Biggr),
\end{align}&lt;p&gt;which we can ignore since they are $\mathcal{O}(\alpha)$ on their own, and thus $\mathcal{O}(\alpha^3)$ when multiplied with $\alpha^2$ in the second-order approximation.&lt;/p&gt;
&lt;p&gt;Before taking the $\alpha \to 0$ limit of whatever is left in Eq. \eqref{eq:mhasecordder}, we list a few useful tricks to make the evaluation easier. First of all, we use Eq. \eqref{eq:vmf} to introduce the following sneaky substitution&lt;/p&gt;
\begin{align}
\Delta \mathbf{h}_{i,t} = -\boldsymbol{\theta}_{i,t} + \mathbf{x}_{i,t} + \sum_{j=1}^{N} J_{ij} \mathbf{s}_{j,t-1} = \boldsymbol{v}_{i,t} + \sum_{j=1}^{N} J_{ij} \left( \mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \right),
\end{align}&lt;p&gt;which conveniently separates terms with fluctuating spin variables from magnetizations that can be pulled out of the integrals. Secondly, all terms that contain only one spin variable with a dependence looking like $\mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1}$ drop out because, schematically,&lt;/p&gt;
\begin{align}
\mathbf{s}_{j,t-1} - \mathbf{m}_{j,t-1} \overset{\int \mathrm{d} \mathbf{s}_{t-1}}{\to} \boldsymbol{\varphi}(\mathbf{h}_{j,t-1}(\alpha)) - \mathbf{m}_{j,t-1} \overset{\alpha \to 0}{\to} \mathbf{0}.
\end{align}&lt;p&gt;Thirdly, since the $\alpha \to 0$ limit decouples all spins $\mathbf{s}_{t-1}$, any term containing dot products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \cdot (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$ of two spin variables averages to zero for $j \neq k$ and to $R^2 - \mathbf{m}^2_{j,t-1}$ for $j=k$. We will also encounter terms containing (tensor contractions with) outer products $(\mathbf{s}_{j,t-1}-\mathbf{m}_{j,t-1}) \otimes (\mathbf{s}_{k,t-1}-\mathbf{m}_{k,t-1})$, which we can think of as projection operators. For $j \neq k$, these and similar terms again evaluate to zero, while, for $j=k$, we get the variance contributions of Eq. \eqref{eq:spinvariance} from the end of the previous section.&lt;/p&gt;
&lt;p&gt;Finally, we take the $\alpha \to 0$ limit of Eq. \eqref{eq:mhasecordder} only to end up with the following mess:&lt;/p&gt;
\begin{align}
&amp;\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \label{eq:secondordercorrections} \\\\
\end{align}&lt;p&gt;
&lt;/p&gt;
\begin{align}
&amp;\hspace{-1em}\frac{\beta^2}{R^4} \frac{1+3\gamma_{i,t}(0)}{\gamma_{i,t}(0)^3} \left( \left( \mathbf{m}_{i,t} \cdot \mathbf{v}_{i,t} \right)^2 + \sum_{j} J_{ij}^2 \left( \frac{\mathbf{m}_{i,t}^2}{1+\gamma_{j,t-1}(0)} - \frac{\left(\mathbf{m}_{i,t}\cdot\mathbf{m}_{j,t-1}\right)^2}{R^2 \gamma_{j,t-1}(0)} \right) \right) \mathbf{m}_{i,t} \nonumber \\\\
&amp;\hspace{-1em}- \frac{\beta^2}{R^2} \frac{1}{\gamma_{i,t}^2 (0) + \gamma_{i,t}(0)} \left( \mathbf{v}_{i,t}^2 + \sum_{j} J_{ij}^2 \left( R^2 - \mathbf{m}_{j,t-1}^2 \right) \right) \mathbf{m}_{i,t} \nonumber \\\\
&amp;\hspace{-1em}- \frac{\beta^2}{R^2} \frac{2}{\gamma_{i,t}^2 (0) + \gamma_{i,t}(0)} \Biggl( \mathbf{v}_{i,t} \otimes \mathbf{v}_{i,t} + \sum_{j} J_{ij}^2 \left( \frac{\mathbb{1}}{1+\gamma_{j,t-1}(0)} - \frac{\mathbf{m}_{j,t-1}\otimes\mathbf{m}_{j,t-1}}{R^2 \gamma_{j,t-1}(0)} \right) \Biggr) \mathbf{m}_{i,t}. \nonumber
\end{align}&lt;p&gt;At this point, it is too late. We should have remembered that the second-order approximation lives in the neighborhood of the first-order approximation. We probably ended up doing too much work by taking terms into account that are of higher order in $\alpha$. We can always drop terms later on if it turns out they are negligible at $\mathcal{O}(\alpha^2)$.&lt;/p&gt;
&lt;p&gt;To get to the second-order mean-field equations for the magnetizations, we have to solve Eq. \eqref{eq:secondorderconstraint} for the optimal parameters $\boldsymbol{\theta}^{*}_{i,t}$, i.e.,&lt;/p&gt;
\begin{equation}
\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right).
\end{equation}&lt;p&gt;Let us substitute $\frac{\partial \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha}$ from Eq. \eqref{eq:mfirstorderalphazero} but keep $\frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}$ for generality,&lt;/p&gt;
\begin{align}
\beta \left( \frac{\mathbb{1}}{1+\gamma_{i,t}(0)} - \frac{\mathbf{m}_{i,t}\otimes\mathbf{m}_{i,t}}{R^2 \gamma_{i,t}(0)} \right) \mathbf{v}_{i,t} + \frac{1}{2} \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} = \mathbf{0} + \mathcal{O}\left(\alpha^3\right),
\end{align}&lt;p&gt;so that we can then isolate $\boldsymbol{\theta}_{i,t}$ in $\mathbf{v}_{i,t}$ to find&lt;/p&gt;
\begin{align}
\boldsymbol{\theta}_{i,t} = \mathbf{x}_{i,t} &amp;+ \sum_{j} J_{ij} \mathbf{m}_{j,t-1} \nonumber \\\\
&amp;+ \frac{1+\gamma_{i,t}(0)}{2\beta} \left( \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2} + \frac{\mathbf{m}_{i,t} \cdot \frac{\partial^{2} \mathbf{m}_{i,t}(\alpha=0)}{\partial\alpha^2}}{\frac{R^2 \gamma_{i,t}(0)}{1+\gamma_{i,t}(0)} - \mathbf{m}_{i,t}^2} \mathbf{m}_{i,t} \right),\label{eq:ftheta}
\end{align}&lt;p&gt;where we have used the Sherman-Morrison formula to compute the inverse of the variance matrix. Since the expression on the right-hand side &lt;em&gt;also&lt;/em&gt; depends on $\boldsymbol{\theta}_{i,t}$, we seem to have stumbled upon a set of fixed-point equations which we should solve for $\boldsymbol{\theta}^{*}_{i,t}$,&lt;/p&gt;
\begin{align}
\boldsymbol{\theta}_{i,t} = \mathbf{f} (\boldsymbol{\theta}_{i,t}, \mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1}), \label{eq:thetafp}
\end{align}&lt;p&gt;where the function $\mathbf{f}$ is given by the right-hand side of Eq. \eqref{eq:ftheta}. The second-order mean-field equations then become &lt;em&gt;yet another&lt;/em&gt; set of fixed-point equations&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left(\boldsymbol{\theta}^{*}_{i,t}(\mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1})\right)
\end{equation}&lt;p&gt;because of the dependence of $\boldsymbol{\theta}^{*}_{i,t}$ on $\mathbf{m}_{i,t}$. Similar to the binary TAP approximation Eq. \eqref{eq:tapm}, this dependency suggests that we should solve for fixed-point magnetizations $\mathbf{m}^{*}_{i,t}$. However, in contrast to the binary case, the dependence here is &lt;em&gt;implicit&lt;/em&gt; since $\boldsymbol{\theta}^{*}_{i,t}$ is itself obtained from solving fixed-point equations Eq. \eqref{eq:thetafp}, which, in turn, also depend on $\mathbf{m}_{i,t}$.&lt;/p&gt;
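&lt;p&gt;The variance-matrix inverse that went into Eq. \eqref{eq:ftheta} can be checked numerically. The following plain-Python sketch uses hypothetical helper names and arbitrary test values. It applies the rank-1 perturbed matrix of Eq. \eqref{eq:spinvariance} and its closed-form inverse to a vector and verifies that the two operations undo each other.&lt;/p&gt;

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def var_times(u, m, g, R):
    """Apply Var = 1 / (1 + g) - m (x) m / (R^2 g) to the vector u."""
    c = dot(m, u) / (R**2 * g)
    return [ui / (1.0 + g) - c * mi for ui, mi in zip(u, m)]


def var_inverse_times(u, m, g, R):
    """Apply the closed-form rank-1 inverse of Var to the vector u."""
    c = dot(m, u) / (R**2 * g / (1.0 + g) - dot(m, m))
    return [(1.0 + g) * (ui + c * mi) for ui, mi in zip(u, m)]


if __name__ == "__main__":
    beta, R = 1.5, 2.0  # arbitrary test values
    theta = [0.4, -1.0, 0.7, 0.2]
    g = math.sqrt(1.0 + beta**2 * dot(theta, theta) / R**2)
    m = [beta / (1.0 + g) * t for t in theta]  # large-D magnetization
    u = [0.3, 0.5, -0.2, 1.1]
    print(var_times(var_inverse_times(u, m, g, R), m, g, R))  # should reproduce u
```

&lt;p&gt;For $\mathbf{m}_{i,t} = \boldsymbol{\varphi}(\boldsymbol{\theta}_{i,t})$ the denominator $R^2 \gamma / (1+\gamma) - \mathbf{m}_{i,t}^2$ reduces to $R^2 / (1+\gamma)$, so the inverse is well defined in this case.&lt;/p&gt;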
&lt;p&gt;The problem setup looks like a bi-level optimization problem, where the solutions to the inner-level fixed-point equations are fed as parameters to the outer-level fixed-point equations. Because of the hierarchical relationship and the implicit dependence of the outer solution on the inner problem&amp;rsquo;s parameters, bi-level optimization can be computationally demanding and unstable. Let us try to sidestep this dreadfulness by writing all instances of $\boldsymbol{\theta}_{i,t}$ in Eq. \eqref{eq:ftheta} in terms of $\mathbf{m}_{i,t}$ by inverting Eq. \eqref{eq:largedevmag} so that, for $\mathbf{m}^2_{i,t} &lt; R^2$,&lt;/p&gt;
\begin{equation}
\boldsymbol{\theta}_{i,t} = \frac{2 R^2}{\beta \left( R^2 - \mathbf{m}^2_{i,t} \right)} \mathbf{m}_{i,t},\label{eq:invphi}
\end{equation}&lt;p&gt;leading to a set of fixed-point equations in terms of only $\mathbf{m}_{i,t}$,&lt;/p&gt;
\begin{equation}
\boxed{\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left( \mathbf{f} (\mathbf{x}_{i,t}, \mathbf{m}_{i,t}, \mathbf{m}_{t-1})\right) } \label{eq:tapmvector}
\end{equation}&lt;p&gt;Looking ahead at the transformer-module correspondence, we recognize a scaled sum of a residual connection, an attention term, and a self-consistent expression in terms of magnetizations and couplings taking on the role of the feed-forward network. Interestingly, these second-order correction terms require &lt;em&gt;no additional free parameters&lt;/em&gt; since they are &lt;em&gt;fully determined by the mean-field structure&lt;/em&gt; of the underlying spin model.&lt;/p&gt;
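&lt;p&gt;The inversion in Eq. \eqref{eq:invphi} follows from $\mathbf{m}^2_{i,t} = R^2 (\gamma - 1)/(\gamma + 1)$, which gives $1 + \gamma = 2R^2 / (R^2 - \mathbf{m}^2_{i,t})$. As a sanity check, the plain-Python sketch below (hypothetical helper names, arbitrary test values) maps a vector through the large-$D$ magnetization function and back.&lt;/p&gt;

```python
import math


def phi(theta, beta, R):
    """Large-D magnetization function for a single vector theta."""
    g = math.sqrt(1.0 + beta**2 * sum(c * c for c in theta) / R**2)
    return [beta / (1.0 + g) * c for c in theta]


def inv_phi(m, beta, R):
    """Inverse of phi, valid while the squared norm of m stays below R^2."""
    m2 = sum(c * c for c in m)
    return [2.0 * R**2 / (beta * (R**2 - m2)) * c for c in m]


if __name__ == "__main__":
    beta, R = 1.5, 2.0  # arbitrary test values
    theta = [0.4, -1.0, 0.7, 0.2]
    m = phi(theta, beta, R)
    print(sum(c * c for c in m), R**2)  # squared norm of m versus R^2
    print(inv_phi(m, beta, R))  # should reproduce theta
```

&lt;p&gt;Since $\gamma$ is finite for any finite $\boldsymbol{\theta}_{i,t}$, the magnetization produced by the large-$D$ function always stays strictly inside the ball of radius $R$, so the inverse is defined on every output of the forward map.&lt;/p&gt;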
&lt;h2 id="a-simple-jax-implementation-1"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;p&gt;We now turn to a JAX implementation of the mean-field time evolution of the magnetizations according to the vector-spin model introduced in the previous sections. Compared to the binary-spin simulations, we will not attempt to precisely tune the vector-spin model since computing its critical temperature and quirky phase-diagram properties is well beyond the scope of this work. We will instead take an empirical approach and play around with a numerical implementation to figure out what works. Along the way, we provide some physical intuition.&lt;/p&gt;
&lt;h3 id="simulating-magnetization-trajectories-1"&gt;Simulating magnetization trajectories&lt;/h3&gt;
&lt;p&gt;The JAX reference implementation looks very similar to the binary-spin case. Essentially, we have to keep track of an additional vector dimension and replace the update equations with the vector equivalents introduced in the previous sections. We deliberately do not fiddle with the hyperparameters of the fixed-point solver &lt;code&gt;AndersonAcceleration&lt;/code&gt; to ensure robustness of exploratory results.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;jax.numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;jnp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;jaxopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (39).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (38).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_naive_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (47).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (64).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (58).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d, j d, i e, j e -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, j -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d, i f -&amp;gt; i f&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, i d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;i j, i d, j d, j f -&amp;gt; i f&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (61).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;g1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_inv_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i d, i d -&amp;gt; i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d2_m_d_alpha_2&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;g1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;i j, j d -&amp;gt; i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;ff&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;See Eq. (65).&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;m1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;update_tap_mf&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;time_evolution&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_fun&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wrapped_time_evolution&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_carry&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stacked_outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="playing-with-parameter-scales-an-exploration"&gt;Playing with parameter scales: an exploration&lt;/h3&gt;
&lt;p&gt;To get a feel for the complexity, let us visualize an $N=64$ sample of a coupling matrix $\mathbf{J} \in \mathbb{R}^{N \times N}$ with entries drawn from $\mathcal{N}\left( 0, 1/N \right)$ using a visually appealing yet utterly pointless ball-of-yarn plot:&lt;/p&gt;
&lt;img src="vector_plot_1.png" width="500px"/&gt;
&lt;p&gt;We randomly initialize the external magnetic fields $\mathbf{x} \in \mathbb{R}^{N \times D}$ and the coupling matrix $\mathbf{J} \in \mathbb{R}^{N \times N}$ by drawing their entries from $\mathcal{N}\left( 0, 1\right)$ and $\mathcal{N}\left( 0, 1/N \right)$, respectively. We then simulate $N=1024$ vector spins of dimension $D=512$ at inverse temperature $\beta=1.0$ for $t=20$ time steps, starting from an initial state $\mathbf{m}_{0} \in \mathbb{R}^{N \times D}$ of all-ones vectors. We normalize all $\mathbf{x}$ vectors to lie on the spherical shell of radius $R$, so that $\mathbf{x}_{i} \to R \mathbf{x}_{i} / \lVert\mathbf{x}_{i}\rVert$. We apply the same external magnetic fields at all time steps ($\mathbf{x}_{t} \equiv \mathbf{x}$, $\forall t \geq 0$) so that the probing of the system is time-independent and relentless.&lt;/p&gt;
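&lt;p&gt;The initialization described above could be sketched roughly as follows. The helper name &lt;code&gt;init_system&lt;/code&gt;, the PRNG seed, the small sizes, and the choice $R = \sqrt{D}$ in the usage lines are illustrative assumptions and not taken from the original code.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def init_system(key, N, D, R):
    """Hypothetical helper: draw fields x, couplings J and an all-ones state m0."""
    key_x, key_J = jax.random.split(key)
    # External fields: standard-normal entries, then each row rescaled to norm R.
    x = jax.random.normal(key_x, (N, D))
    x = R * x / jnp.linalg.norm(x, axis=-1, keepdims=True)
    # Couplings: variance 1/N, i.e. standard deviation 1/sqrt(N).
    J = jax.random.normal(key_J, (N, N)) / jnp.sqrt(N)
    # Initial magnetizations: one all-ones vector per site.
    m0 = jnp.ones((N, D))
    return x, J, m0


# Usage with small illustrative sizes; R = sqrt(D) is an assumption.
N, D = 64, 16
x, J, m0 = init_system(jax.random.PRNGKey(0), N, D, R=D**0.5)
```

&lt;p&gt;Since &lt;code&gt;simulate&lt;/code&gt; applies &lt;code&gt;jax.vmap&lt;/code&gt; over the leading axis of its &lt;code&gt;m0&lt;/code&gt; argument, a single initial state like the one above would presumably be passed as &lt;code&gt;m0[None]&lt;/code&gt;.&lt;/p&gt;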
&lt;p&gt;We first consider the first-order naive mean-field update equations. To visualize a set of vectors evolving in time, we track their directionalities with respect to reference states using cosine similarities and their magnitudes using Euclidean norms.&lt;/p&gt;
&lt;img src="vector_plot_2.png" width="300px"/&gt;
&lt;p&gt;The top plot shows the cosine-similarity alignments of individual magnetization trajectories $\mathbf{m}_{i,t}$ compared to respectively $\mathbf{m}_{i,t-1}$ (green, magnetizations at previous time step to track convergence), $\mathbf{m}_{0}$ (yellow, magnetizations at initial time step to track drift from initial conditions), and $\mathbf{x}_{i}$ (blue, time-independent external magnetic fields to track alignment with the &amp;ldquo;residual stream&amp;rdquo;). The bottom plot tracks the evolution of the norms of $\mathbf{m}_{i,t}$ during time evolution. From the tracked metrics, we observe convergence to what looks like a &lt;em&gt;non-equilibrium / near-equilibrium steady state&lt;/em&gt; (NESS) with magnetizations remaining dynamically stable at the mean-field level.&lt;/p&gt;
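&lt;p&gt;The tracked quantities can be computed along the following lines. This is a sketch that assumes the trajectory is stored as an array &lt;code&gt;ms&lt;/code&gt; of shape $(T, N, D)$; the function names are hypothetical and not taken from the plotting code.&lt;/p&gt;

```python
import jax.numpy as jnp


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity along the last axis, with broadcasting over leading axes."""
    num = jnp.sum(a * b, axis=-1)
    den = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    return num / jnp.maximum(den, eps)


def trajectory_metrics(ms, x):
    """Per-site metrics for a magnetization trajectory `ms` of shape (T, N, D).

    Returns cosine similarities with the previous step, with the initial state,
    and with the fields `x`, plus Euclidean norms. Each array has shape (T - 1, N).
    """
    cos_prev = cosine_similarity(ms[1:], ms[:-1])  # convergence
    cos_init = cosine_similarity(ms[1:], ms[:1])   # drift from initial conditions
    cos_x = cosine_similarity(ms[1:], x[None])     # alignment with the fields
    norms = jnp.linalg.norm(ms[1:], axis=-1)
    return cos_prev, cos_init, cos_x, norms
```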
&lt;p&gt;To compare the naive first-order mean-field update equations to the second-order Thouless-Anderson-Palmer (TAP) ones, we plot the mean magnetization trajectories across all sites and add shading to denote the spread of maximum and minimum values.&lt;/p&gt;
&lt;img src="vector_plot_3.png" width="500px"/&gt;
&lt;p&gt;We observe that the final TAP magnetizations are slightly different for our particular choice of parameters. The Onsager correction term seems to account for at least some correlations, lowering the local effective mean field and hence the magnitude of the magnetizations. If we lower the temperature to $\beta = 2.0$ while keeping all other parameters fixed, the difference becomes more pronounced:&lt;/p&gt;
&lt;img src="vector_plot_4.png" width="500px"/&gt;
&lt;p&gt;Lowering the temperature further while keeping all other parameters fixed starts leading to convergence issues for the TAP equations. If we go back to $\beta=1.0$ but (1) increase the random interaction strengths by doubling the elements of the coupling matrix and (2) reduce the influence of the random external magnetic fields by normalizing all $\mathbf{x}$ vectors to lie on the unit sphere, we end up in a regime where we observe that the naive mean-field equations have trouble converging whereas the TAP magnetizations quickly settle into a small-norm fixed-point solution:&lt;/p&gt;
&lt;img src="vector_plot_5.png" width="500px"/&gt;
&lt;h3 id="playing-with-parameter-scales-an-explanation"&gt;Playing with parameter scales: an explanation&lt;/h3&gt;
&lt;p&gt;To better understand the behavior of the system, we focus on the inverse temperature $\beta$, the magnitudes $\lVert\mathbf{x}_{i,t}\rVert$ of the external magnetic fields, the scale of the coupling matrix elements $J_{ij}$, and the vector-spin radius $R=\sqrt{D/2-1}$. The latter is fixed for fixed dimension $D$ and provides a natural length scale. In spin-glass mean-field theory, the random coupling matrix is usually chosen to have a variance of $1/N$ to ensure the existence of a proper thermodynamic limit. The magnitudes of the external magnetic fields determine to what extent the vector spins will try to align with their imposed external environment or yield to the influence of their neighbours. The relation between the scales of the couplings and the fields should be such that meaningful competition between the external magnetic fields and the intrinsic spin-spin interactions can occur. Finally, the system&amp;rsquo;s overall behavior is further governed by the thermal noise introduced via the inverse temperature $\beta$.&lt;/p&gt;
&lt;p&gt;Revisiting the magnetization equation Eq. \eqref{eq:largedevmag},&lt;/p&gt;
\begin{equation}
\mathbf{m}_{i,t} = \boldsymbol{\varphi} \left(\boldsymbol{\theta}_{i,t}\right) = \frac{\beta}{1+\sqrt{1+\beta^2 \lVert \boldsymbol{\theta}_{i,t} \rVert^2 / R^2 }} \boldsymbol{\theta}_{i,t},
\end{equation}&lt;p&gt;we observe that the infinite-temperature limit $\beta \to 0$ pushes the magnitude of the magnetization to $0$ whereas the zero-temperature limit $\beta \to \infty$ snaps to the spherical shell at radius $R$. We plot the norm of this equation for different values of $\beta$ as a function of $\lVert\boldsymbol{\theta}\rVert$ in $D=512$ dimensions below. The dashed horizontal and vertical lines indicate the value of $R=\sqrt{D/2-1}\approx 15.9687$.&lt;/p&gt;
&lt;img src="vector_plot_6.png" width="450px"/&gt;
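&lt;p&gt;The two limits can also be checked numerically. The sketch below transcribes the magnetization equation directly; the name &lt;code&gt;phi&lt;/code&gt; and the test vector are chosen here for illustration only.&lt;/p&gt;

```python
import jax.numpy as jnp


def phi(theta, beta, R):
    """Magnetization map: beta * theta / (1 + sqrt(1 + beta^2 |theta|^2 / R^2))."""
    norm_sq = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * norm_sq / R**2))


D = 512
R = (D / 2 - 1) ** 0.5    # approximately 15.9687
theta = jnp.ones((1, D))  # a single vector with norm sqrt(D)

# Small beta: norm close to beta * |theta| / 2. Large beta: norm close to R.
for beta in (1e-3, 1.0, 1e3):
    print(beta, float(jnp.linalg.norm(phi(theta, beta, R))))
```

&lt;p&gt;For small $\beta$ the square root is close to $1$, so the output norm behaves like $\beta \lVert\boldsymbol{\theta}\rVert / 2$; for large $\beta$ the prefactor approaches $R / \lVert\boldsymbol{\theta}\rVert$, so the output norm approaches $R$ from below.&lt;/p&gt;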
&lt;p&gt;This plot partly explains why the TAP equations start showing convergence issues at lower temperatures. Large values of $\beta$ push the norm of the magnetizations towards $R$, but that in turn leads to $\boldsymbol{\theta}$ blowing up because of the $R^2-\mathbf{m}^2_{i,t}$ factors in the denominators of Eq. \eqref{eq:ftheta} and Eq. \eqref{eq:invphi}. This is no surprise since the Plefka expansion is in fact a high-temperature expansion. Indeed, if we write out the mean-field update equations, we find that the first-order terms scale as $\beta$ and the second-order terms as $\beta^2$. Additionally, we know from mean-field theory of binary spin glasses that the TAP equations break down when crossing the so-called de Almeida-Thouless line (AT line) in the $(\beta, x)$ phase diagram. Assuming an
, it might be worth rederiving the Onsager term as was done for binary spins in
to make sure its time indices are more geared towards convergence. But even then we would still not be able to cross the AT line and find mean-field solutions at lower temperatures.&lt;/p&gt;
&lt;p&gt;But we have to ask ourselves whether we actually care about this low-temperature failure mode for our purposes. Do we want a spin-transformer module to inhabit a complex spin-glass phase full of local minima containing frozen disordered spins that cannot respond to external magnetic fields? No. We would like our system to be able to fluidly and adaptively respond to its environment.&lt;/p&gt;
&lt;h1 id="spin-transformer-modules-a-family-of-transformer-like-modules"&gt;Spin-transformer modules: a family of transformer-like modules&lt;/h1&gt;
&lt;p&gt;In this final section, we propose a physics-inspired class of transformer modules based on the mean-field update equations for the vector-spin magnetizations derived in the previous section. We highlight conceptual similarities, physical interpretations, and potential benefits of exploiting spin-model structure to reduce parameter count.&lt;/p&gt;
&lt;h2 id="connecting-the-dots"&gt;Connecting the dots&lt;/h2&gt;
&lt;p&gt;Following
and
, we interpret a transformer module as a differentiable vector-spin system that is driven by data and whose collective behavior can be shaped through training. Intuitively, there is little difference here compared to the work mentioned above: we still probe a spin system and observe its response. But, technically and conceptually, the shift to dynamical mean-field expressions enables us to solidify the correspondence by moving past symmetric coupling matrices and equilibrium free energies.&lt;/p&gt;
&lt;p&gt;We define a &lt;em&gt;spin-transformer module&lt;/em&gt; as a wrapper around a vector-spin model where module inputs $\mathbf{x} \in \mathbb{R}^{N \times D}$ get routed to external magnetic fields. Inside the module, we evolve a set of initial magnetizations in time using either the first-order (Eq. \eqref{eq:naivemvector}) or the second-order (Eq. \eqref{eq:tapmvector}) mean-field update equations. Only the second-order update equations exhibit feed-forward-like corrections. We choose to relentlessly apply the same external magnetic fields at all time steps ($\mathbf{x}_{t} \equiv \mathbf{x}$, $\forall t \geq 0$) and construct input-dependent couplings using the row-stochastic attention matrix,&lt;/p&gt;
\begin{equation}
\mathbf{J}(\mathbf{x}) = \mathrm{softmax}\left( \frac{\mathbf{x} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \mathbf{x}^{T}}{\sqrt{D}} \right), \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;where $\boldsymbol{W}_{\boldsymbol{Q}}$ and $\boldsymbol{W}_{\boldsymbol{K}}$ denote linear query- and key-mappings. Adding bias terms to these linear transformations would introduce intrinsic interactions between the spins that persist even in the absence of the external magnetic fields. Essentially, we recognize the softmax attention matrix as a parametrized flavor of the (asymmetric) coupling matrix of a vector-spin model. The external magnetic fields thus not only affect the vector spins directly, but also indirectly by altering the interaction strengths between them. This setup leads to a highly adaptive system where the interaction landscape itself is dynamically shaped by the inputs.&lt;/p&gt;
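&lt;p&gt;A minimal single-head sketch of Eq. \eqref{eq:softmaxcouplings} is given below. The function name and the random bias-free weights are assumptions for illustration; in a trained module, $\boldsymbol{W}_{\boldsymbol{Q}}$ and $\boldsymbol{W}_{\boldsymbol{K}}$ would be learned.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def softmax_couplings(x, W_q, W_k):
    """Row-stochastic couplings J(x) = softmax(x W_Q W_K^T x^T / sqrt(D))."""
    D = x.shape[-1]
    logits = (x @ W_q) @ (x @ W_k).T / jnp.sqrt(D)
    return jax.nn.softmax(logits, axis=-1)  # each row sums to one


key_x, key_q, key_k = jax.random.split(jax.random.PRNGKey(0), 3)
N, D = 8, 16
x = jax.random.normal(key_x, (N, D))
W_q = jax.random.normal(key_q, (D, D)) / jnp.sqrt(D)
W_k = jax.random.normal(key_k, (D, D)) / jnp.sqrt(D)

J = softmax_couplings(x, W_q, W_k)  # shape (N, N), generally asymmetric
```

&lt;p&gt;Because the softmax normalizes each row separately, $J_{ij} \neq J_{ji}$ in general, which is why the dynamical (rather than equilibrium) mean-field treatment is needed.&lt;/p&gt;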
&lt;img src="arch_comparison.png" alt="Comparison between vanilla transformer module and spin-transformer module" width="550px"/&gt;
&lt;p&gt;What does the spin-transformer module return? The within-module time evolution is said to converge when the mean magnetizations collectively reach some kind of &lt;em&gt;non-equilibrium / near-equilibrium steady state&lt;/em&gt; (NESS). Convergence is not guaranteed a priori and requires us to make sure the couplings, inverse temperature, and normalizations are sensibly chosen. In fact, for the parameter regimes we would want to consider, the behavior of the vector-spin model might very well be quite equilibrium-like. This is probably what we want to aim for anyway, given that oscillations, instabilities, and divergences are always lurking close by in the perilous phase spaces of these systems. If the within-module time evolution converges, we return the magnetizations $\mathbf{m}^{\mathrm{NESS}} \in \mathbb{R}^{N \times D}$ as module outputs. Instead of time evolving for a number of steps until convergence, we could also try hunting for the NESS directly by assuming it exists and solving for it as if it were a fixed point of the time evolution.&lt;/p&gt;
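&lt;p&gt;To make the notion of evolving until convergence concrete, here is a toy sketch that iterates a first-order update until the magnetizations stop changing. It assumes the first-order update takes the form $\mathbf{m}_{t} = \boldsymbol{\varphi}(\mathbf{x} + \mathbf{J}\mathbf{m}_{t-1})$, which is an assumption here since Eq. \eqref{eq:naivemvector} is not restated. The names &lt;code&gt;phi&lt;/code&gt; and &lt;code&gt;evolve_until_steady&lt;/code&gt;, the tolerance, and the random row-stochastic couplings are all illustrative stand-ins.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def phi(theta, beta, R):
    """Magnetization map applied row-wise to an (N, D) array of effective fields."""
    norm_sq = jnp.sum(theta**2, axis=-1, keepdims=True)
    return beta * theta / (1.0 + jnp.sqrt(1.0 + beta**2 * norm_sq / R**2))


def evolve_until_steady(m0, x, J, beta, R, tol=1e-5, maxiter=200):
    """Iterate m = phi(x + J m) until the largest entry-wise change drops below tol."""

    def cond(state):
        _, residual, step = state
        return jnp.logical_and(residual > tol, maxiter > step)

    def body(state):
        m, _, step = state
        m_next = phi(x + J @ m, beta, R)  # assumed first-order update
        return m_next, jnp.max(jnp.abs(m_next - m)), step + 1

    init = (m0, jnp.full((), jnp.inf, dtype=m0.dtype), jnp.zeros((), dtype=jnp.int32))
    return jax.lax.while_loop(cond, body, init)


key_x, key_J = jax.random.split(jax.random.PRNGKey(0))
N, D, beta = 8, 16, 1.0
R = (D / 2 - 1) ** 0.5
x = jax.random.normal(key_x, (N, D))
x = R * x / jnp.linalg.norm(x, axis=-1, keepdims=True)
# Stand-in row-stochastic couplings (not built from queries and keys here).
J = jax.nn.softmax(jax.random.normal(key_J, (N, N)), axis=-1)

m_ness, residual, steps = evolve_until_steady(jnp.ones((N, D)), x, J, beta, R)
```

&lt;p&gt;Note that &lt;code&gt;jax.lax.while_loop&lt;/code&gt; does not support reverse-mode differentiation, so this sketch is only suitable for inspecting the forward dynamics.&lt;/p&gt;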
&lt;p&gt;To wrap up this section, we list a few conceptual similarities and features below to close the gap between vector-spin models and transformer modules:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Attention heads:&lt;/strong&gt; Multiple attention heads can be implemented by embedding $N_{h}$ coupling matrices into a head-block-diagonal coupling tensor. Effectively, this operation stacks $N_{h}$ smaller-dimensional spin models where each submodel processes a disjoint $D_{h}-$dimensional piece of the full $D-$dimensional vector space. Mixing between subspaces can occur because (1) each individual coupling matrix is still constructed from query and key mappings $\mathbb{R}^{D} \to \mathbb{R}^{N_{h} \times D_{h}}$ acting on the full input space, and (2) the dot products in the second-order correction terms Eq. \eqref{eq:secondordercorrections} naturally mix channels.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Causal masks:&lt;/strong&gt; Since we identify the attention matrix with the spin model&amp;rsquo;s couplings, autoregressive modeling can be done by applying the appropriate triangular mask to the coupling matrix instead. The causal structure is preserved during the within-module time evolution. More generally, we expect any kind of masking that can be done on the level of the attention matrix to transfer to the coupling matrix.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Cross-attention:&lt;/strong&gt; The framework described above implements self-attention by constructing both queries and keys from the inputs $\mathbf{x}$ according to Eq. \eqref{eq:softmaxcouplings}. Decoder layers in encoder-decoder models, however, rely on cross-attention, where keys (and values) from the encoder output are sent to the decoder input as context. We can accommodate this scenario by feeding the spin-transformer module an additional set of context vectors $\mathbf{c}$ to build the coupling matrix, i.e.,&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
\begin{equation}
\mathbf{J}(\mathbf{x}, \mathbf{c}) = \mathrm{softmax}\left( \frac{\mathbf{x} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \mathbf{c}^{T}}{\sqrt{D}} \right). \label{eq:crosssoftmaxcouplings}
\end{equation}&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Normalization:&lt;/strong&gt; A flavor of
(RMSNorm) naturally appears in expression Eq. \eqref{eq:largedevmag} for the magnetization in the limit of large vector dimension as well as in all the mean-field update equations derived from it.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Queries, keys, and values:&lt;/strong&gt; The &lt;em&gt;queries&lt;/em&gt; and &lt;em&gt;keys&lt;/em&gt; are used to define the interactions between the spins from the external magnetic fields via Eq. \eqref{eq:softmaxcouplings}. In a sense, these linear transformations remain quite arbitrary since our framework is agnostic to the nature of the coupling matrix. But the &lt;em&gt;values&lt;/em&gt; do have an interpretation as the magnetizations $\mathbf{m}_{t-1}$ at the previous time step, or, in case of convergence, the steady-state magnetizations $\mathbf{m}^{\mathrm{NESS}}$.&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
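&lt;p&gt;The causal-mask and cross-attention items in the list above amount to small changes in how the coupling matrix is built. The following single-head sketch illustrates both; the function name &lt;code&gt;masked_couplings&lt;/code&gt; and its arguments are hypothetical, and the weights are random placeholders.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def masked_couplings(x, W_q, W_k, context=None, mask=None):
    """Softmax couplings with optional context vectors and an optional boolean mask."""
    c = x if context is None else context  # self-attention vs cross-attention
    logits = (x @ W_q) @ (c @ W_k).T / jnp.sqrt(x.shape[-1])
    if mask is not None:
        # Masked-out pairs get -inf logits and hence exactly zero coupling.
        logits = jnp.where(mask, logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1)


key_x, key_c, key_q, key_k = jax.random.split(jax.random.PRNGKey(0), 4)
N, M, D = 6, 4, 8
x = jax.random.normal(key_x, (N, D))
c = jax.random.normal(key_c, (M, D))
W_q = jax.random.normal(key_q, (D, D)) / jnp.sqrt(D)
W_k = jax.random.normal(key_k, (D, D)) / jnp.sqrt(D)

# Lower-triangular mask: site i only couples to sites j with j at most i.
causal = jnp.tril(jnp.ones((N, N), dtype=bool))
J_causal = masked_couplings(x, W_q, W_k, mask=causal)  # shape (N, N)
J_cross = masked_couplings(x, W_q, W_k, context=c)     # shape (N, M)
```

&lt;p&gt;Every row of the causal mask keeps its diagonal entry, so the softmax never sees a row consisting only of $-\infty$ logits.&lt;/p&gt;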
&lt;h2 id="fast--and-slow-moving-parameters"&gt;Fast- and slow-moving parameters&lt;/h2&gt;
&lt;p&gt;We now provide some additional physical intuition. As mentioned ad nauseam in
and
, each example in a batch of sequential data can be thought of as probing a spin-transformer module in a particular way. The response of the many-body system depends on the context provided by the applied external fields. We can tune the collective response behavior by parametrizing the couplings and making sure the whole probe-response stack is differentiable.&lt;/p&gt;
&lt;img src="spin_model_transformer_module.png" alt="Focus on a spin-transformer module in a stack of layers" width="450px"/&gt;
&lt;p&gt;Physically, the &lt;em&gt;fast-moving&lt;/em&gt; parametrized couplings $\mathbf{J}(\mathbf{x})$ are determined by the &lt;em&gt;fast-moving&lt;/em&gt; external fields $\mathbf{x}$, which, in a stack of transformer modules, depend on the magnetizations of the previous layer and ultimately on the input data. The external fields act as an environment of contextual patterns that gets transformed instantly into the values of the coupling matrix, effectively inducing some kind of state of quenched disorder. The &lt;em&gt;slow-moving&lt;/em&gt; parameters are those receiving gradient updates during training, e.g., the query-key matrices in the softmax couplings. On the level of a spin-transformer module, training can be understood as &lt;em&gt;shaping the input-dependent distribution of coupling parameters&lt;/em&gt; by amassing information from a huge number of quenched-disorder realizations, sculpting a spin glass with data.&lt;/p&gt;
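&lt;p&gt;The split between fast and slow parameters can be illustrated with a toy gradient step. In this sketch the couplings are recomputed from every input, while only the query and key matrices receive gradients. The loss, the single response $\mathbf{J}(\mathbf{x})\mathbf{x}$, the target, and the step size are hypothetical choices made purely for illustration and do not correspond to an actual training setup.&lt;/p&gt;

```python
import jax
import jax.numpy as jnp


def couplings(params, x):
    """Fast-moving couplings J(x), rebuilt from the inputs on every call."""
    W_q, W_k = params
    logits = (x @ W_q) @ (x @ W_k).T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(logits, axis=-1)


def toy_loss(params, x, target):
    """Hypothetical loss on a single response J(x) x, for illustration only."""
    response = couplings(params, x) @ x
    return jnp.mean((response - target) ** 2)


key_q, key_k, key_x, key_t = jax.random.split(jax.random.PRNGKey(0), 4)
N, D = 8, 16
params = (
    jax.random.normal(key_q, (D, D)) / jnp.sqrt(D),  # slow: W_Q
    jax.random.normal(key_k, (D, D)) / jnp.sqrt(D),  # slow: W_K
)
x = jax.random.normal(key_x, (N, D))  # fast: changes with every example
target = jax.random.normal(key_t, (N, D))

# Gradients flow to the slow parameters only; x is treated as given data.
grads = jax.grad(toy_loss)(params, x, target)
params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
```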
&lt;h2 id="a-simple-jax-implementation-2"&gt;A simple JAX implementation&lt;/h2&gt;
&lt;p&gt;Let us wrap up this post with some code showing how one could implement a spin-transformer module based on the recipe described above. We choose to normalize input vectors to have norm $R$, and, because of this choice, we set the softmax temperature in the couplings Eq. \eqref{eq:softmaxcouplings} to $1$ instead of $\sqrt{D}$ to make sure the scale of the matrix elements is similar to that in scaled dot-product attention. As we have seen in
, lowering the norm of the input vectors decreases the strength of the applied magnetic fields and increases the influence of the spin-spin interactions. Other normalization conventions might turn out to work better in actual training scenarios. Additionally, since different flavors of mean-field approximations lead to different update equations for the magnetizations, we want to stress that the approach we took in this post is just one possible option, which might not be the most useful one in practice.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;We use
to implement our neural network modules. We could replace the fixed-step &lt;code&gt;lax.scan&lt;/code&gt; time evolution of
with an &lt;code&gt;equinox.internal.while_loop&lt;/code&gt; to implement early stopping when convergence occurs in a way that supports reverse-mode autodifferentiation. But then we would have to make sure to stop gradients so that only the values of the final iteration, corresponding to the steady-state magnetizations $\mathbf{m}^{\mathrm{NESS}}$, contribute to the gradient computation. To make things easier in the implementation below, we are going to assume the NESS exists and solve for it as if it were a fixed point of the time evolution. Implicit differentiation of the fixed-point solver then takes care of the (near-)equilibrium gradients. So we only need the following function:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Find fixed-point vector magnetizations of second-order mean-field update equations.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_m_ness&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;AndersonAcceleration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;fixed_point_fun&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_m_ness&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;maxiter&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_phi&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We implement a spin-transformer module by wrapping a little boilerplate around the &lt;code&gt;vector_tap_fp&lt;/code&gt; function. We construct the spin-model couplings from the input vectors and mimic multi-head attention by &lt;code&gt;vmap&lt;/code&gt;&amp;rsquo;ing the fixed-point solve for the magnetizations across &lt;code&gt;num_heads&lt;/code&gt; spin models, each of which acts on an equal-size subspace of the full vector dimension.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Fix multi-head case (it&amp;rsquo;s not just &lt;code&gt;vmap&lt;/code&gt;&amp;rsquo;ing the full thing).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;typing&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Callable&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;equinox&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;eqx&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;einops&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Callable&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vector_tap_fp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... h n d -&amp;gt; ... n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... n (h d) -&amp;gt; ... h n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;... i d, ... j d -&amp;gt; ... i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Masked-out pairs get the most negative finite value, so softmax gives them ~0 weight.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;min&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;... n (h d) -&amp;gt; ... h n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;m0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vector_tap_fp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;in_axes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_J&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;... h n d -&amp;gt; ... n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let&amp;rsquo;s run a forward pass of the spin-transformer module&amp;hellip;&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2666&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mod_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mod_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[[&lt;/span&gt; &lt;span class="mf"&gt;0.46483648&lt;/span&gt; &lt;span class="mf"&gt;0.3805422&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.44913006&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.02650307&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.36570293&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.23443604&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.37061682&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.42315483&lt;/span&gt; &lt;span class="mf"&gt;0.1197958&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.6265602&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.61598897&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.5583689&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.21803643&lt;/span&gt; &lt;span class="mf"&gt;0.17418407&lt;/span&gt; &lt;span class="mf"&gt;0.22512378&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.82831764&lt;/span&gt; &lt;span class="mf"&gt;0.13957487&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.17361565&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.03738704&lt;/span&gt; &lt;span class="mf"&gt;0.10310851&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.12114237&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.17507279&lt;/span&gt; &lt;span class="mf"&gt;0.30361462&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.09653477&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.4211655&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.20545821&lt;/span&gt; &lt;span class="mf"&gt;0.12954816&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.74708706&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.35752055&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5818469&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;1.149747&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.6245326&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.28383803&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.31866318&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.13622926&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.52548647&lt;/span&gt;&lt;span class="p"&gt;]]]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&amp;hellip; and a backward pass.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nd"&gt;@eqx.filter_jit&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Differentiate w.r.t. the floating-point array leaves of the module and inspect the query/key projection gradient.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;filter_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;transformer_module&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_qk&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;weight&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[&lt;/span&gt; &lt;span class="mf"&gt;6.84143470e-06&lt;/span&gt; &lt;span class="mf"&gt;1.26781670e-04&lt;/span&gt; &lt;span class="mf"&gt;3.00350985e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.42774186e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;6.56897682e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.09572255e-04&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;2.77053477e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.62737968e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;9.00395680e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.95370322e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.99462512e-05&lt;/span&gt; &lt;span class="mf"&gt;5.35702784e-05&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.52689070e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.44067290e-05&lt;/span&gt; &lt;span class="mf"&gt;1.77498405e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.35530383e-04&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;7.19401141e-05&lt;/span&gt; &lt;span class="mf"&gt;1.22722937e-04&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.90037055e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.04181963e-04&lt;/span&gt; &lt;span class="mf"&gt;4.73747787e-06&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.87275892e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;5.93782897e-06&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.02471051e-05&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;4.34355170e-05&lt;/span&gt; &lt;span class="mf"&gt;3.30054972e-05&lt;/span&gt; &lt;span class="mf"&gt;1.77152877e-04&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.20974844e-04&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.17946729e-04&lt;/span&gt; &lt;span class="mf"&gt;4.90189996e-06&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;3.79099110e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.06873820e-04&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;8.71618904e-05&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;4.89293416e-05&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;8.51267905e-05&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.46996666e-04&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
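&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Going beyond a single spin-transformer module, we can stack modules sequentially to create a spin-transformer model, initialising the layers with &lt;code&gt;eqx.filter_vmap&lt;/code&gt; and applying them in order with &lt;code&gt;jax.lax.scan&lt;/code&gt;:&lt;/p&gt;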
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SpinTransformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;modules&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;make_modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;SpinTransformerModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Initialise all `depth` modules in one vmapped call; every array leaf gains a leading axis of size `depth`.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;filter_vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;make_modules&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Split the stacked modules into array leaves (scanned over) and the remaining static leaves.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dynamic_modules&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;static_modules&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;partition&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;modules&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_array&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;f&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_dynamic_module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;module&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eqx&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;combine&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_dynamic_module&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;static_modules&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;module&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scan&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dynamic_modules&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;SpinTransformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;depth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mod_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;[[[&lt;/span&gt; &lt;span class="mf"&gt;0.20396525&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.06002701&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.24426042&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.25347382&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.01503923&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.15146086&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3552067&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.4154298&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.2159235&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.68296695&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.18692644&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.20893992&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.03525298&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11836862&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.13671912&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.22646151&lt;/span&gt; &lt;span class="mf"&gt;0.18905625&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.05829766&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11216182&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.26305646&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.31211302&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.27817503&lt;/span&gt; &lt;span class="mf"&gt;0.25123474&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.11120855&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.17170963&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.33360714&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.12762357&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.70538384&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.04229175&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5447842&lt;/span&gt; &lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="mf"&gt;0.5191558&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5662918&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.33646253&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="mf"&gt;0.4568781&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.04439414&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;0.18843232&lt;/span&gt;&lt;span class="p"&gt;]]]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we have shown how
can be generalized to capture asymmetric coupling matrices like softmax attention. We observed that dynamical mean-field descriptions of vector-spin models exhibit structure capable of yielding residual connections, attention terms, and feed-forward-like correction terms, motivating a physics-inspired class of spin-transformer modules. By blending ideas from deep learning and statistical mechanics, we hope our work can help open up broader interdisciplinary bridges to improve our understanding of learning and generalization in transformer neural networks.&lt;/p&gt;
&lt;p&gt;From a theoretical point of view, it would be interesting to further explore and develop connections to the physics of vector spin glasses and properly study transformers as statistical-mechanical systems. Computationally, we look forward to experiments at scale to get more insight into potential benefits and bottlenecks of spin-transformer models in terms of
, representational power, and scaling behavior. In any case, it is fun to think about transformers as a collective of driven, disordered vector-spin models whose response behavior can be shaped by learning parameterized interactions, gradually steering a cascade of near-equilibrium steady-state magnetizations towards solving a given objective.&lt;/p&gt;
&lt;h1 id="references"&gt;References&lt;/h1&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;F. Nicoletti, Low energy excitations of vector spin glasses, PhD thesis (2023)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;M. Aguilera, S.A. Moosavi, and H. Shimazaki, A unifying framework for mean-field theories of asymmetric kinetic Ising systems, &lt;em&gt;Nat Commun&lt;/em&gt; &lt;strong&gt;12&lt;/strong&gt;, 1197 (2021)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Y. Roudi and J. Hertz, Dynamical TAP equations for non-equilibrium Ising spin glasses, &lt;em&gt;J. Stat. Mech.&lt;/em&gt;, P03031 (2011)
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;H.J. Kappen and J.J. Spanjers, Mean field theory for asymmetric neural networks, &lt;em&gt;Phys. Rev. E&lt;/em&gt; &lt;strong&gt;61&lt;/strong&gt;, 5658 (2000)&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;G. Parisi, Asymmetric neural networks and the process of learning, &lt;em&gt;J. Phys. A: Math. Gen.&lt;/em&gt; &lt;strong&gt;19&lt;/strong&gt; L675 (1986)&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2023spinmodeltransformers,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Spin-Model Transformers},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2023},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {December},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/spin-model-transformers}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h1 id="appendices"&gt;Appendices&lt;/h1&gt;
&lt;h2 id="a1-vector-spin-distribution-normalization-constant"&gt;A.1. Vector-spin distribution: normalization constant&lt;/h2&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Let $Z(\beta, R, \mathbf{h})=\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}$. We switch to $D$-dimensional spherical coordinates to make our life easier and use rotational symmetry to choose the polar axis parallel to $\mathbf{h}$,&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = R^{D-1} \int_{\Omega} \int_{0}^{\pi} \mathrm{d}^{D-2} \Omega \;\mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta ,
\end{equation}&lt;p&gt;where $h=\lVert\mathbf{h}\rVert$ and where $\int_{\Omega} \mathrm{d}^{D-2} \Omega$ represents the integral over all other spherical angles, which coincides with the surface area of the unit sphere in $D-1$ dimensions,&lt;/p&gt;
\begin{equation}
S_{D-1} = \frac{2\pi^{\frac{D-1}{2}}}{\Gamma\left( \frac{D-1}{2} \right)},
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{0}^{\pi} \mathrm{d}\theta \; \mathrm{e}^{\beta R h \cos \theta } \sin^{D-2} \theta .
\end{equation}&lt;p&gt;If we now let $u = \cos \theta$, then&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{2 \pi^{\frac{D-1}{2}} R^{D-1}}{\Gamma\left( \frac{D-1}{2} \right)} \int_{-1}^{1} \mathrm{d}u \; \mathrm{e}^{\beta R h u } \left(1 - u^2\right)^{(D-3)/2} .
\end{equation}&lt;p&gt;Recognizing the integral representation of the modified Bessel function of the first kind,&lt;/p&gt;
\begin{equation}
I_{\nu}(z) = \frac{2^{-\nu}}{\sqrt{\pi}\, \Gamma\left(\nu+\frac{1}{2}\right)} z^{\nu} \int_{-1}^{1} \mathrm{d}t \; \mathrm{e}^{\pm zt} \left(1-t^2\right)^{\nu-\frac{1}{2}},
\end{equation}&lt;p&gt;we identify $\nu = D/2 - 1$ and $z = \beta R h$ to find&lt;/p&gt;
\begin{equation}
Z(\beta, R, h) = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R h) }{ \left(\beta h\right)^{D/2-1} }.
\end{equation}&lt;h2 id="a2-vector-spin-distribution-expected-value-first-moment"&gt;A.2. Vector-spin distribution: expected value (first moment)&lt;/h2&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Starting from the expression of the normalization constant Eq. \eqref{eq:partfun},&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) ,
\end{equation}&lt;p&gt;we write the expected value as&lt;/p&gt;
\begin{equation}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathbf{s} \, \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}
\end{equation}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert\mathbf{h} \rVert) }{ \left(\beta \lVert\mathbf{h}\rVert \right)^{D/2-1} } \right)
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \left( \frac{I'_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} - \frac{ D/2-1 }{ \beta R \lVert\mathbf{h}\rVert} \right) \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}.
\end{align}&lt;p&gt;Using the recurrence relations for modified Bessel functions of the first kind,&lt;/p&gt;
\begin{align}
I_{\nu-1}(z) - I_{\nu+1}(z) &amp;= \frac{2\nu}{z} I_{\nu}(z), \label{eq:irecurr}\\\\
I_{\nu-1}(z) + I_{\nu+1}(z) &amp;= 2 I'_{\nu}(z), \label{eq:irecurrderiv}
\end{align}&lt;p&gt;we end up with&lt;/p&gt;
\begin{align}
\mathbb{E}_{p} [ \mathbf{s} ] = \frac{I_{D/2}(\beta R \lVert \mathbf{h}\rVert)}{I_{D/2 - 1}(\beta R \lVert\mathbf{h}\rVert)} \frac{R \mathbf{h}}{\lVert\mathbf{h}\rVert}\equiv \boldsymbol{\varphi} (\mathbf{h}). \label{eq:app:expectedvalue}
\end{align}&lt;h2 id="a3-vector-spin-distribution-variance-second-moment"&gt;A.3. Vector-spin distribution: variance (second moment)&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Add variance for general case.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We consider the single-site vector-spin distribution Eq. \eqref{eq:pcondsinglesitevector}:&lt;/p&gt;
\begin{equation}
p ( \mathbf{s} ; \beta, \mathbf{h}) = \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} }.
\end{equation}&lt;p&gt;Using the expression of the normalization constant Eq. \eqref{eq:partfun},&lt;/p&gt;
\begin{equation}
\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} = \frac{ \left( 2 \pi R \right)^{D/2} I_{D/2 - 1}(\beta R \lVert \mathbf{h}\rVert) }{ \left(\beta \lVert \mathbf{h}\rVert\right)^{D/2-1} } = Z(\beta, R, \lVert \mathbf{h}\rVert) ,
\end{equation}&lt;p&gt;we write the symmetric outer-product variance matrix as&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{1}{Z} \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} \, ( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])( \mathbf{s} - \mathbb{E}_{p} [ \mathbf{s} ])^{T} \\\\
&amp;= \frac{1}{\beta^2 Z} \frac{ \partial^2 }{ \partial \mathbf{h} \partial \mathbf{h}^{T} } \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}} - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T},
\end{align}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{1}{\beta Z} \frac{ \partial }{ \partial \mathbf{h} } \left( Z \mathbb{E}_{p} [ \mathbf{s} ]^{T} \right) - \mathbb{E}_{p} [ \mathbf{s} ] \mathbb{E}_{p} [ \mathbf{s} ]^{T}, \\\\
&amp;= \frac{1}{\beta} \frac{ \partial }{ \partial \mathbf{h} } \mathbb{E}_{p} [ \mathbf{s} ]^{T},
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \ldots \label{eq:app:var}
\end{align}&lt;p&gt;for the general case with the expected value given by Eq. \eqref{eq:app:expectedvalue} and to&lt;/p&gt;
\begin{align}
\mathrm{Var}_{p} [ \mathbf{s} ] &amp;= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\beta^2\mathbf{h} \otimes \mathbf{h}}{R^2\gamma(\mathbf{h})\left(1+\gamma(\mathbf{h})\right)^2}\\\\
&amp;= \frac{\mathbb{1}}{1+\gamma(\mathbf{h})} - \frac{\boldsymbol{\varphi} (\mathbf{h}) \otimes \boldsymbol{\varphi}(\mathbf{h})}{R^2\gamma(\mathbf{h})}
\end{align}&lt;p&gt;for the large-$D$ limit with the expected value given by Eq. \eqref{eq:largedevmag}, where&lt;/p&gt;
\begin{align}
\gamma(\mathbf{h}) = \sqrt{1+\beta^{2}\lVert\mathbf{h}\rVert^{2}/R^2}
\end{align}&lt;h2 id="a4-ratio-of-modified-bessel-functions-of-the-first-kind"&gt;A.4. Ratio of modified Bessel functions of the first kind&lt;/h2&gt;
&lt;p&gt;To compute the ratio $I_{\nu+1}(x) / I_{\nu}(x)$ of modified Bessel functions of the first kind for $\nu \geq 0$ and $x \geq 0$, we implement a
of the algorithm described in
. A pseudocode implementation can be found in
. We compare our implementation against explicitly calculating the ratio using
across a range of orders $\nu$ for several different values of $x$ to get a feel for its behavior.&lt;/p&gt;
&lt;img src="bessel_plot_1.png" width="500px"/&gt;
&lt;p&gt;We observe a satisfying agreement between the two approaches. For $x=\sqrt{\nu}$, the ratio takes on very small values for large orders. For $x=\nu^2$, the opposite happens and we see saturation. The case $x=\nu$ seems to sit in between, which suggests it might be opportune to fix the radius of our little spins to $R=\sqrt{D}$ so that with $\lVert\mathbf{h}\rVert \sim \mathcal{O}(\sqrt{D})$ we might maximize the &amp;ldquo;sensitivity&amp;rdquo; of the expected value. In this regime, we can get away with
for large $\nu$ given that the ratio flattens out quickly.&lt;/p&gt;
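&lt;p&gt;As a rough cross-check that does not rely on the algorithm referenced above, the ratio can also be evaluated directly from SciPy&amp;rsquo;s exponentially scaled modified Bessel function &lt;code&gt;scipy.special.ive&lt;/code&gt;, whose scaling factor cancels in the ratio. The sketch below is illustrative only: the helper name &lt;code&gt;bessel_ratio&lt;/code&gt;, the order $\nu = 100$, and the test points are assumptions made for this example. For $\nu = 1/2$ (that is, $D = 3$) the ratio reduces to the closed form $\coth x - 1/x$, which provides an independent check.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
from scipy.special import ive


def bessel_ratio(nu, x):
    """Ratio I_{nu+1}(x) / I_nu(x) of modified Bessel functions of the first kind."""
    # ive(v, x) equals iv(v, x) * exp(-abs(x)) for real x, so the exponential
    # scaling cancels in the ratio and large arguments do not overflow.
    return ive(nu + 1.0, x) / ive(nu, x)


# Independent check: for nu = 1/2 the ratio equals coth(x) - 1/x.
x = np.array([0.5, 1.0, 5.0, 50.0])
closed_form = 1.0 / np.tanh(x) - 1.0 / x
print(np.allclose(bessel_ratio(0.5, x), closed_form))

# The three regimes discussed above, for a single (hypothetical) large order.
nu = 100.0
for label, arg in [("sqrt(nu)", np.sqrt(nu)), ("nu", nu), ("nu^2", nu**2)]:
    print(label, bessel_ratio(nu, arg))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The three printed ratios should increase from a small value at $x=\sqrt{\nu}$ towards saturation near one at $x=\nu^2$, in line with the plot.&lt;/p&gt;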
&lt;h2 id="a5-general-case-partial-derivatives-with-respect-to"&gt;A.5. General case: partial derivatives with respect to $\alpha$&lt;/h2&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;✨ &lt;strong&gt;TODO:&lt;/strong&gt; Clean up and verify (haha, no).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We are interested in computing the first- and second-order derivatives with respect to $\alpha$ of the function&lt;/p&gt;
\begin{equation}
\boldsymbol{\varphi}(\mathbf{h}(\alpha)) = \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \mathbf{h}(\alpha)}{\lVert \mathbf{h}(\alpha) \rVert},
\end{equation}&lt;p&gt;where $\mathbf{h}(\alpha) = \boldsymbol{\theta} + \alpha \Delta \mathbf{h}$. Using&lt;/p&gt;
\begin{equation}
\frac{\partial \lVert \mathbf{h}(\alpha) \rVert}{\partial\alpha} = \frac{\mathbf{h}(\alpha) \cdot \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert}
\end{equation}&lt;p&gt;and Eqs. \eqref{eq:irecurr}-\eqref{eq:irecurrderiv}, we find&lt;/p&gt;
\begin{align}
\frac{\partial \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} = \beta &amp;\lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \label{eq:generalgradalphafirstorder}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{equation}
\lambda_{D} (x) = \frac{I^2_{D/2-1}(x)}{I^2_{D/2}(x)} - \frac{D}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} - 1. \label{eq:app:lambda}
\end{equation}&lt;p&gt;For the second-order derivative, we need to slog through even more tedious algebra,&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha^2}
= \beta &amp;\frac{\partial}{\partial\alpha}\biggl( \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \biggr) \nonumber \\\\
&amp;+ \frac{\partial}{\partial\alpha}\biggl( \frac{I_{D/2}(\beta R \lVert \mathbf{h}(\alpha) \rVert)}{I_{D/2 - 1}(\beta R \lVert \mathbf{h}(\alpha) \rVert)} \frac{R \Delta \mathbf{h}}{\lVert \mathbf{h}(\alpha) \rVert} \biggr) ,
\end{align}&lt;p&gt;which eventually leads to something like&lt;/p&gt;
\begin{align}
\frac{\partial^2 \boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha^2}
= -2\beta^2 &amp; \, \kappa_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right)^{2} \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \cdot \Delta \mathbf{h} \right) \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \nonumber \\\\
&amp;+ \beta \lambda_{D} (\beta R \lVert \mathbf{h}(\alpha) \rVert) \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \frac{\partial\boldsymbol{\varphi}(\mathbf{h}(\alpha))}{\partial\alpha} \nonumber \\\\
&amp;- \frac{D}{\lVert \mathbf{h}(\alpha) \rVert^2} \left( \boldsymbol{\varphi}(\mathbf{h}(\alpha)) \cdot \Delta \mathbf{h} \right) \Delta \mathbf{h} , \label{eq:generalgradalphasecondorder}
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
\kappa_{D} (x) = \lambda^2_{D} (x) + \left( 1 + \frac{D/2 + 1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)} \right) \lambda_{D} (x) + \frac{1}{x} \frac{I_{D/2-1}(x)}{I_{D/2}(x)}.
\end{align}&lt;p&gt;Equation \eqref{eq:generalgradalphasecondorder} can be simplified further by substituting the first-order derivative Eq. \eqref{eq:generalgradalphafirstorder}. The derivation of the mean-field equations proceeds in a similar fashion as in the main text, but uses Eqs. \eqref{eq:generalgradalphafirstorder} and \eqref{eq:generalgradalphasecondorder} as expressions for the partial derivatives instead of their large-$D$ approximations.&lt;/p&gt;
&lt;p&gt;Another useful derivative is that of the single-site probability distribution \eqref{eq:pcondsinglesitevector},&lt;/p&gt;
\begin{align}
\frac{\partial}{\partial\alpha} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) = \frac{\partial}{\partial\mathbf{h}(\alpha)} \left( \frac{\mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)}}{\int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} } \right) \cdot \Delta \mathbf{h},
\end{align}&lt;p&gt;which evaluates to&lt;/p&gt;
\begin{align}
\beta \left( \mathbf{s} - \boldsymbol{\varphi}\left(\mathbf{h}(\alpha)\right) \right) \cdot \Delta \mathbf{h} \frac{ \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} }{ \int_{S_{D-1}} \mathrm{d}^{D} \mathbf{s} \; \mathrm{e}^{\beta \, \mathbf{s} \cdot \mathbf{h}(\alpha)} }
\end{align}&lt;p&gt;and can be used to calculate derivatives of the conditional distribution \eqref{eq:pcondaltvector}.&lt;/p&gt;
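&lt;p&gt;Since the algebra in this appendix is flagged as unverified, a numerical spot check is cheap insurance. The sketch below compares the first-order derivative Eq. \eqref{eq:generalgradalphafirstorder} against a central finite difference of $\boldsymbol{\varphi}(\mathbf{h}(\alpha))$, using SciPy&amp;rsquo;s &lt;code&gt;ive&lt;/code&gt; for the Bessel functions. All function names, the random test vectors, and the choices $D=8$, $\beta=1$, $R=\sqrt{D}$ are assumptions made for illustration, and the check says nothing about the second-order expression.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
from scipy.special import ive


def ratio(D, x):
    # I_{D/2}(x) / I_{D/2-1}(x); the exponential scaling of ive cancels.
    return ive(D / 2, x) / ive(D / 2 - 1, x)


def phi(h, D, beta, R):
    # Expected value of the single-site vector-spin distribution.
    norm = np.linalg.norm(h)
    return ratio(D, beta * R * norm) * R * h / norm


def lam(D, x):
    # lambda_D(x) as defined above.
    inv = 1.0 / ratio(D, x)  # I_{D/2-1}(x) / I_{D/2}(x)
    return inv**2 - D / x * inv - 1.0


def dphi_dalpha(theta, dh, alpha, D, beta, R):
    # Closed-form first-order derivative with h(alpha) = theta + alpha * dh.
    h = theta + alpha * dh
    norm = np.linalg.norm(h)
    x = beta * R * norm
    p = phi(h, D, beta, R)
    return beta * lam(D, x) * (p @ dh) * p + ratio(D, x) * R * dh / norm


rng = np.random.default_rng(0)
D, beta, alpha, eps = 8, 1.0, 0.3, 1e-5
R = np.sqrt(D)
theta, dh = rng.normal(size=D), rng.normal(size=D)

analytic = dphi_dalpha(theta, dh, alpha, D, beta, R)
numeric = (phi(theta + (alpha + eps) * dh, D, beta, R)
           - phi(theta + (alpha - eps) * dh, D, beta, R)) / (2 * eps)
print(np.max(np.abs(analytic - numeric)))  # largest componentwise deviation
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;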
&lt;h1 id="footnotes"&gt;Footnotes&lt;/h1&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;We plot the absolute value to get rid of artificial &amp;ldquo;jumps&amp;rdquo; between the two branches. These occur because all models are simulated independently when sweeping across $\beta$ and some combinations of initial state and model parameters might just happen to bounce to the other branch when $\beta$ changes in the $\beta &gt; \beta_c$ regime.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Transformers Are Secretly Collectives of Spin Systems</title><link>https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/</link><pubDate>Tue, 23 Nov 2021 12:17:17 +0100</pubDate><guid>https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (April 2023):&lt;/strong&gt; Consider reading
where we continue building on the intuition of probing a spin system to engineer its collective response but get rid of the assumption of symmetric coupling matrices by shifting focus from equilibrium free energies to dynamical mean-field approximations of non-equilibrium vector-spin models.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;In this post, we try to distill a unifying perspective out of ideas developed in a series of longer posts on understanding transformers as physical systems:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;We argue that a blueprint of the neural-network architecture of the archetypical transformer module can be derived from the structure of physical spin systems familiar from classical statistical mechanics. More specifically, we claim that the forward pass of transformer modules maps onto computing magnetizations in vector-spin models in response to incoming data. We imagine transformers as collectives of differentiable spin systems whose behavior can be shaped through training.&lt;/p&gt;
&lt;h1 id="where-does-the-transformer-module-architecture-come-from"&gt;Where does the transformer module architecture come from?&lt;/h1&gt;
&lt;p&gt;Taking a bird&amp;rsquo;s eye view of the ever-growing zoo of transformer architectures in natural language processing and computer vision suggests that the design pattern introduced in &lt;em&gt;
&lt;/em&gt;&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; is still dominant. Almost all architectural variations of transformer modules published in the last four years have stuck to a successful combination of residual connections, an attention-like operation (token-mixing), normalization layers, and a feed-forward-like operation (channel-mixing).&lt;/p&gt;
&lt;p&gt;&lt;a href="https://arxiv.org/abs/2111.11418" target=_blank&gt;&lt;img src="metaformer.png" alt="MetaFormer architecture comparison" width="400px"/&gt;&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;Recent work like &lt;em&gt;
&lt;/em&gt;&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; appropriately shifts focus to the high-level architecture of the transformer module and argues that its full structure, rather than just the token-mixing attention operation, is essential for transformers to achieve competitive performance.&lt;/p&gt;
&lt;p&gt;So where does this archetypical design pattern come from? Why does it seem to stick around? Is there any physical intuition behind its structure?&lt;/p&gt;
&lt;h1 id="deriving-attention-from-energy-functions-only-gets-you-so-far"&gt;Deriving attention from energy functions only gets you so far&lt;/h1&gt;
&lt;p&gt;Recent papers like &lt;em&gt;
&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;&lt;/em&gt; and &lt;em&gt;
&lt;/em&gt;&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt; have looked for physical intuition behind attention mechanisms using an
phrased in terms of modern continuous Hopfield networks. The main idea is to derive the softmax-attention update rule&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}' = \text{softmax}\left( \frac{\boldsymbol{Q} \boldsymbol{K}^T}{\sqrt{d}} \right) \boldsymbol{K}
\end{equation}&lt;p&gt;by taking a large gradient descent update step using the derivative with respect to input queries $\boldsymbol{Q}$ of some judiciously chosen energy function&lt;/p&gt;
\begin{equation}
E = \frac{1}{2} \boldsymbol{Q} \boldsymbol{Q}^T -\mathrm{logsumexp} \left( \frac{\boldsymbol{Q} \boldsymbol{K}^T}{\sqrt{d}} \right). \label{eq:logsumexp}
\end{equation}&lt;p&gt;In this way, vanilla softmax attention can be recast as taking a
. The energy landscape defined by Eq. \eqref{eq:logsumexp} implements an associative memory system for storing and retrieving vector patterns where queries flow towards valleys associated with their nearest keys (see
):&lt;/p&gt;
&lt;p&gt;&lt;a href="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/" target=_blank&gt;&lt;img src="landscape.png" alt="Logsumexp energy function landscape" width="300px"/&gt;&lt;/a&gt;&lt;/p&gt;
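&lt;p&gt;The gradient-step picture can be checked numerically for a single query vector. The sketch below uses NumPy and SciPy and assumes an inverse-temperature prefactor $\beta^{-1} = \sqrt{d}$ in front of the logsumexp term, a convention that the compact notation of Eq. \eqref{eq:logsumexp} leaves implicit, so that a unit-size gradient step lands exactly on the softmax-attention output. The function names and the finite-difference gradient are illustrative assumptions, not part of any library.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
from scipy.special import logsumexp, softmax


def energy(q, K, beta):
    # Single-query energy: 0.5 * |q|^2 - (1 / beta) * logsumexp(beta * K q).
    return 0.5 * q @ q - logsumexp(beta * (K @ q)) / beta


def attention(q, K, beta):
    # Softmax attention with keys doubling as values, as in the update rule above.
    return softmax(beta * (K @ q)) @ K


def numerical_grad(f, q, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    eye = np.eye(q.size)
    return np.array([(f(q + eps * e) - f(q - eps * e)) / (2 * eps) for e in eye])


rng = np.random.default_rng(0)
d, num_keys = 16, 10
q, K = rng.normal(size=d), rng.normal(size=(num_keys, d))
beta = 1.0 / np.sqrt(d)

grad = numerical_grad(lambda v: energy(v, K, beta), q)
q_new = q - grad  # one gradient-descent step with unit step size
print(np.allclose(q_new, attention(q, K, beta), atol=1e-6))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The quadratic term contributes $\boldsymbol{q}$ to the gradient, which cancels the current query in the update and leaves only the attention-weighted sum of keys.&lt;/p&gt;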
&lt;p&gt;But there is more to transformer modules than just attention. In practice, we know that residual connections, normalization layers, and feed-forward layers are all essential to achieve good empirical performance.&lt;/p&gt;
&lt;p&gt;Can we generalize this physical intuition of taking derivatives with respect to an energy function to recover the full transformer module? Yes, we can. But we have to take a step back from energy functions and focus on their underlying physical systems instead.&lt;/p&gt;
&lt;h1 id="back-to-the-roots-physical-spin-systems-and-vector-spin-models"&gt;Back to the roots: physical spin systems and vector-spin models&lt;/h1&gt;
&lt;p&gt;Energy functions in classical statistical mechanics are succinct descriptions encoding interactions and constraints in physical systems. Spin systems are prototypical physical systems which often serve as toy models for all kinds of phenomena&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The
is a simple toy model describing a classical binary spin system with local spin degrees of freedom at every site pointing either up or down. The energy function of the binary random Ising model for $N$ spins in the presence of a site-dependent external magnetic field is given by&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \sigma_{i} \sigma_{j} - \sum_{i=1}^{N} h_{i} \sigma_{i}, \label{eq:binaryrandomising}
\end{equation}&lt;p&gt;where the $J_{ij}$ encode coupling strengths between all pairs of spins and the external magnetic fields $h_{i}$ act as biases by providing a preferential value of alignment at every site. The model defined by \eqref{eq:binaryrandomising} is also known as a
or
. A cartoon of this model looks like a graph of little arrows that are pairwise coupled&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;img src="binary_ising.png" alt="Random Ising model configuration with binary spins" width="200px"/&gt;
&lt;p&gt;At thermal equilibrium, the Boltzmann probability distribution $e^{-\beta E\left( \sigma \right)} / Z$ reflects what patterns of up-down spins, or &lt;em&gt;spin configurations&lt;/em&gt;, are preferred. The partition function $Z = \sum_{\sigma} e^{-\beta E\left( \sigma \right)}$ of a spin system is not only a normalization constant but also a magical object relating the microscopic world of fluctuating spins to thermodynamic, observable quantities via the free energy $F = - \beta^{-1} \log Z$. Even for simple spin systems, computing partition functions by summing over all possible configurations is a shockingly hard thing to do in most scenarios.&lt;/p&gt;
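&lt;p&gt;For a handful of binary spins, the sum over configurations can still be done by brute force. The following minimal Python sketch enumerates all $2^N$ configurations of a toy instance of Eq. \eqref{eq:binaryrandomising}. The system size, inverse temperature, random couplings and fields, and all variable names are illustrative assumptions.&lt;/p&gt;

```python
import itertools
import numpy as np

# Hypothetical toy instance: N = 4 binary spins with random symmetric couplings.
rng = np.random.default_rng(0)
N, beta = 4, 1.0
J = rng.normal(size=(N, N))
J = 0.5 * (J + J.T)        # symmetrize the couplings
np.fill_diagonal(J, 0.0)   # no self-couplings
h = rng.normal(size=N)     # site-dependent external field


def energy(sigma):
    # E = - sum_ij J_ij s_i s_j - sum_i h_i s_i
    return -sigma @ J @ sigma - h @ sigma


# Enumerate all 2^N up/down configurations.
configs = np.array(list(itertools.product([-1, 1], repeat=N)), dtype=float)
energies = np.array([energy(s) for s in configs])

weights = np.exp(-beta * energies)
Z = weights.sum()          # partition function
probs = weights / Z        # Boltzmann distribution over configurations
F = -np.log(Z) / beta      # free energy
most_likely = configs[np.argmax(probs)]

print("Z =", Z, "F =", F, "most likely configuration:", most_likely)
```

&lt;p&gt;Since the number of configurations doubles with every added spin, this kind of enumeration quickly becomes infeasible, which is why approximations are needed for anything but tiny systems.&lt;/p&gt;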
&lt;p&gt;Binary spin models are nice but rarely excite machine learning practitioners nowadays. Modern neural networks like transformers act on sequences of vectors such as token embeddings or image patches. Instead of abandoning spin models altogether, we could consider &lt;em&gt;vector-spin models&lt;/em&gt;. Replacing binary degrees of freedom with $d$-dimensional vector degrees of freedom, we can define a spin-model energy function&lt;/p&gt;
\begin{align}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}, \label{eq:vectorrandomising}
\end{align}&lt;p&gt;where the scalar products have turned into dot products. Models of this form first popped up in 1960s statistical mechanics literature as
. They also appear in recent studies on higher-dimensional generalizations of spin glass models&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;img src="vector_ising.png" alt="Random Ising model configuration with vector spins" width="200px"/&gt;
&lt;p&gt;Now how can we relate vector-spin systems like Eq. \eqref{eq:vectorrandomising} to modern neural networks?&lt;/p&gt;
&lt;h1 id="why-dont-we-just-probe-a-vector-spin-system-with-data"&gt;Why don’t we just probe a vector-spin system with data?&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s pursue an intuitive idea. Imagine we want to expose our vector-spin system Eq. \eqref{eq:vectorrandomising} to a sequence of vector data. We can do this by having the sequence act as the spin system&amp;rsquo;s external magnetic field $(\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. We would then like to observe how the spin system responds to this particular environment of patterns.&lt;/p&gt;
&lt;p&gt;If all of the steps in the computation of the spin system&amp;rsquo;s responses can be implemented in a differentiable way, we should be able to engineer its collective behavior by optimizing the coupling parameters to better respond to future incoming data. We propose to observe spin-system responses in terms of &lt;em&gt;magnetizations computed from free energies&lt;/em&gt;.&lt;/p&gt;
&lt;h1 id="a-slice-of-statistical-mechanics-magnetizations-and-free-energies"&gt;A slice of statistical mechanics: magnetizations and free energies&lt;/h1&gt;
&lt;p&gt;For ease of notation, let&amp;rsquo;s call the model parameters $\theta \equiv \{ J_{ij} \}$, the spins $\sigma \equiv \{ \boldsymbol{\sigma}_{i} \}$, and the external magnetic fields $h \equiv (\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. We can then schematically write our spin system&amp;rsquo;s partition function as&lt;/p&gt;
\begin{align}
Z_{\theta} \left( h \right) = \int \mathrm{d} \sigma \ \mathrm{e}^{ - \beta E_{\theta}\left( \sigma, h \right) } \label{eq:partfun}
\end{align}&lt;p&gt;and the corresponding free energy as $F_{\theta} \left( h \right) = - \beta^{-1} \log Z_{\theta} \left( h \right)$.&lt;/p&gt;
&lt;p&gt;Magnetizations are responses of our spin system to the external magnetic field imposed by $(\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. From standard thermodynamics, we know that we can calculate magnetizations from the free energy by differentiating with respect to the external field&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} = - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}} = \langle \boldsymbol{\sigma}_{i} \rangle , \label{eq:sigma}
\end{align}&lt;p&gt;which, in this case, boils down to calculating spin expectation values. The magnetization for every site depends on the couplings and, through the couplings between spins, on the values of the external field at all sites. Magnetizations reveal how spins will collectively tend to align themselves when we place the spin system in an environment of patterns.&lt;/p&gt;
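&lt;p&gt;The identity in Eq. \eqref{eq:sigma} can be checked numerically on a system small enough to enumerate. The sketch below uses binary spins as a stand-in for vector spins, which is an assumption made purely to keep the sum over configurations exact. It compares a central finite-difference derivative of the free energy with the spin expectation values computed directly from the Boltzmann distribution. All sizes, seeds, and function names are hypothetical.&lt;/p&gt;

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)
N, beta = 4, 0.7
J = rng.normal(size=(N, N))
J = 0.5 * (J + J.T)
np.fill_diagonal(J, 0.0)
h = rng.normal(size=N)
configs = np.array(list(itertools.product([-1, 1], repeat=N)), dtype=float)


def energies(h):
    # One energy per configuration: - s^T J s - h . s
    return -np.sum((configs @ J) * configs, axis=1) - configs @ h


def free_energy(h):
    return -np.log(np.exp(-beta * energies(h)).sum()) / beta


def expectation_values(h):
    probs = np.exp(-beta * energies(h))
    probs /= probs.sum()
    return probs @ configs   # Boltzmann average of every spin


# Magnetizations as minus the derivative of F with respect to h_i,
# approximated with central finite differences.
eps = 1e-5
m_fd = np.array([
    -(free_energy(h + eps * e) - free_energy(h - eps * e)) / (2 * eps)
    for e in np.eye(N)
])
m_exact = expectation_values(h)

print(np.allclose(m_fd, m_exact, atol=1e-6))
```

&lt;p&gt;In a differentiable-programming setting, the finite differences would be replaced by automatic differentiation of an (approximate) free energy.&lt;/p&gt;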
&lt;p&gt;Before we move on, we have to account for one more complication. If we want to draw a correspondence between transformer modules and vector-spin systems, we will have to allow for couplings that depend on the external magnetic field. For example, the attention matrix in vanilla transformers looks something like&lt;/p&gt;
\begin{equation}
J_{ij} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right) = \left[\mathrm{softmax}\left( \frac{\boldsymbol{H} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{H}^{T}}{\sqrt{d}} \right)\right]_{ij}, \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;where the matrix $\boldsymbol{H}$ denotes the stack of external magnetic field vectors. The interactions between spins are determined dynamically based on the inputs. From a physics perspective, these &amp;ldquo;amortized&amp;rdquo; couplings are very weird and highly unusual, but such is the transformer.&lt;/p&gt;
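&lt;p&gt;To make the input dependence of Eq. \eqref{eq:softmaxcouplings} concrete, here is a small NumPy sketch that builds the coupling matrix from a stack of field vectors. The shapes, the random query and key weights, and the function name &lt;code&gt;softmax_couplings&lt;/code&gt; are assumptions chosen for illustration.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(2)
N, d = 5, 8
H = rng.normal(size=(N, d))                  # stack of external field vectors
W_Q = rng.normal(size=(d, d)) / np.sqrt(d)   # hypothetical query weights
W_K = rng.normal(size=(d, d)) / np.sqrt(d)   # hypothetical key weights


def softmax_couplings(H, W_Q, W_K):
    # Scores H W_Q W_K^T H^T / sqrt(d), followed by a row-wise softmax.
    scores = (H @ W_Q) @ (H @ W_K).T / np.sqrt(H.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


J = softmax_couplings(H, W_Q, W_K)
J_shifted = softmax_couplings(H + 0.1, W_Q, W_K)

print(J.sum(axis=1))                 # every row sums to one
print(np.abs(J - J_shifted).max())   # couplings move when the inputs move
```

&lt;p&gt;Unlike the fixed couplings of a textbook spin model, these couplings are positive, row-normalized, generally not symmetric, and recomputed for every input sequence.&lt;/p&gt;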
&lt;p&gt;The potential dependency of the couplings on the external field changes the magnetization of Eq. \eqref{eq:sigma} to an expression of the form&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} &amp;= - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}} \nonumber \\ &amp;= \langle \boldsymbol{\sigma}_{i} \rangle + \sum_{m,n} \langle \boldsymbol{\sigma}_{m} \cdot \boldsymbol{\sigma}_{n} \rangle \frac{\partial J_{mn} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right) }{ \partial \boldsymbol{h}_{i} } , \label{eq:sigmaweird}
\end{align}&lt;p&gt;where two-point correlation functions are seen to act as weights for the coupling contributions&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;. In practice, we should of course let an automatic differentiation framework keep track of dependencies so that we can get away with simply computing&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} = - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}}, \label{eq:magnetization}
\end{align}&lt;p&gt;assuming we have a differentiable expression for the (approximate) free energy available.&lt;/p&gt;
&lt;h1 id="turning-a-differentiable-spin-system-into-a-neural-network"&gt;Turning a differentiable spin system into a neural network&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s now use the ingredients introduced above to construct a neural network module which wraps around a vector-spin system. Given the energy function Eq. \eqref{eq:vectorrandomising} and the free energy $F_{\theta} \left( h \right) = - \beta^{-1} \log \int \mathrm{d} \sigma \ \mathrm{e}^{ - \beta E_{\theta}\left( \sigma, h \right) }$, we let incoming data play the role of the external magnetic field and return magnetizations in response.&lt;/p&gt;
&lt;img src="spinmodule_new.png" alt="Spin system as a neural network" width="600px"/&gt;
&lt;p&gt;Nice. But didn&amp;rsquo;t we mention before that partition functions (and hence free energies and thus magnetizations) are shockingly hard to compute? Why introduce all these formal expressions if we cannot compute anything?&lt;/p&gt;
&lt;p&gt;Looking back at statistical mechanics papers from the 1950s-1970s, it turns out that physicists have already developed several tricks and approximation methods that can be applied to deal with vector-spin systems. Computational evidence that the partition function approach outlined above &lt;em&gt;is&lt;/em&gt; possible for vector-spin systems can be found in
(below, left) and
(below, right).&lt;/p&gt;
&lt;img src="arch_dia_afem_new.png" alt="Deep implicit attention and approximate free-energy minimization" width="600px"/&gt;
&lt;p&gt;In these examples, approximations of the partition function Eq. \eqref{eq:partfun} were obtained following respectively a mean-field theory and a steepest-descent approach. Our
of both approaches rely internally on
to ensure that fixed-point calculations and root-solving steps are efficiently differentiable.&lt;/p&gt;
&lt;h1 id="an-exercise-in-squinting-recognizing-the-transformer-module"&gt;An exercise in squinting: recognizing the transformer module&lt;/h1&gt;
&lt;p&gt;Computing magnetizations according to Eq. \eqref{eq:magnetization} from the (approximate) free energies obtained in
and
reveals a high-level structure that is surprisingly familiar: a pattern of residual connections, token-mixing, normalization, and channel-mixing. Approaching the crux from the other direction, we argue that transformer modules react to inputs by implementing particular approximations to the general magnetization response Eq. \eqref{eq:sigmaweird}.&lt;/p&gt;
&lt;p&gt;Residual connections are proportional to the inputs and arise from the presence of the external magnetic field. Token-mixing contributions emerge from the coupling terms in the energy function and mix inputs without acting on the local vector-spin dimension. Normalization follows from requiring that the energy of the spin system remain linearly proportional to the number of lattice sites and from normalizing the external magnetic field vectors. Channel-mixing contributions include terms in the magnetization that can be applied locally, like Onsager self-correction terms in mean-field approaches or (approximations to) contributions coming from input-dependent couplings in Eq. \eqref{eq:sigmaweird}.&lt;/p&gt;
&lt;p&gt;Taken together, these observations suggest that we can picture the forward pass of a transformer module as a wrapper around a vector-spin system: module inputs are routed to the external magnetic field (and, optionally, to a parametrized couplings function) after which magnetizations are returned as outputs. The transformer module bears an uncanny resemblance to a differentiable physical system whose collective behavior we can control through training.&lt;/p&gt;
&lt;h1 id="training-transformer-modules-shapes-collective-behavior"&gt;Training transformer modules shapes collective behavior&lt;/h1&gt;
&lt;p&gt;Now that we can picture transformer modules as physical spin systems responding to getting probed with data, let&amp;rsquo;s imagine what training them looks like.&lt;/p&gt;
&lt;p&gt;On the level of the energy function of our spin system Eq. \eqref{eq:vectorrandomising}, we can model the training process of a transformer module by introducing a (discrete) time dimension and making the external magnetic field time-dependent, leading to&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
\begin{equation}
E(t) = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i}(t) \cdot \boldsymbol{\sigma}_{i} \label{eq:sloppyenergy}
\end{equation}&lt;p&gt;At every training step $t$, a sequence of incoming data $\{ \boldsymbol{h}_{1}(t), \boldsymbol{h}_{2}(t), \ldots, \boldsymbol{h}_{N}(t) \}$ takes on the role of the external magnetic field. During the forward pass, magnetizations $\boldsymbol{m}_{i}$ are computed in a differentiable way according to the current model parameters and in the presence of the current external magnetic field. Physically, we consider &amp;ldquo;quenched&amp;rdquo; systems with &amp;ldquo;frozen&amp;rdquo; couplings at every training step. During the backward pass, the module&amp;rsquo;s coupling parameters $J_{ij}$ get updated, nudging the interactions in the spin system so as to influence its magnetization responses to similar data in future iterations.&lt;/p&gt;
&lt;img src="spinmoduletraining_new.png" alt="Training a spin system as a neural network" width="600px"/&gt;
&lt;p&gt;We can think about this training process as gradually shaping the collective behavior of a differentiable vector-spin system that is driven by data. If the couplings depend on the inputs, like in Eq. \eqref{eq:softmaxcouplings}, we should make the couplings time-dependent as well in Eq. \eqref{eq:sloppyenergy}. In that case, the external magnetic fields as well as the parametrized couplings change instantaneously at every training step.&lt;/p&gt;
&lt;h1 id="training-deep-transformers-orchestrates-spin-system-collectives"&gt;Training deep transformers orchestrates spin-system collectives&lt;/h1&gt;
&lt;p&gt;Training a deep transformer model corresponds to orchestrating a stack of transformer modules by building up a differentiable structure of correlations where the magnetizations of one spin system drive the next one. Wiggling (billions of) parameters during training nudges the cascading response behavior of the collective of spin systems to better adapt to the collective&amp;rsquo;s (meta-)tasks as specified by the data and the loss function.&lt;/p&gt;
&lt;img src="transformertraining_new.png" alt="Training a transformer" width="500px"/&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we argued that the forward pass of a transformer module maps onto computing magnetizations in a vector-spin model responding to data. Generalizing previous work on understanding softmax attention modules in terms of modern continuous Hopfield networks by taking derivatives of a judiciously chosen &lt;em&gt;energy&lt;/em&gt; function, we propose to take derivatives of the &lt;em&gt;free energy&lt;/em&gt; of a general vector-spin system to get to a blueprint of the architecture of a full transformer module.&lt;/p&gt;
&lt;p&gt;By zooming out and approaching transformers from a tangential, statistical-mechanical point of view, we arrived at a physical intuition of transformers that seems hard to obtain when restricting oneself to perpetually perturbing explicit neural network architectures. Recognizing transformer modules as spin models in disguise might not only unify architectural variations as different ways to approximately compute magnetizations but also elucidate the empirical success of transformers in deep learning.&lt;/p&gt;
&lt;h1 id="acknowledgements"&gt;Acknowledgements&lt;/h1&gt;
&lt;p&gt;We would like to thank
for hosting its research jams and providing a friendly environment to present ideas.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021isingisallyouneed,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Transformers Are Secretly Collectives of Spin Systems},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {November},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin,
(2017)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;em&gt;Weihao Yu, Mi Luo, Pan Zhou, Chenyang Si, Yichen Zhou, Xinchao Wang, Jiashi Feng, and Shuicheng Yan,
(2021)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;em&gt;Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Consider reading the Physics Today article on
for an introduction to disordered systems, spin glasses, Ising spin systems, emergent collective computational abilities, associative memories, Hopfield models, and the idea of learning patterns as shaping the behavior of systems. Essentially, what we&amp;rsquo;re trying to do in this post is to figure out a way to relate modern transformer models back to these old ideas.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;We plot spin sites at random positions to emphasize that there is no spatial notion of &amp;ldquo;closeness&amp;rdquo; in a fully-connected system: every site is just a hop away. To avoid overloading the graph, we only draw the connections that are strongest in absolute value.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;For example, see &lt;em&gt;
&lt;/em&gt; and &lt;em&gt;
&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;For example, see the content of Chapter 2 in the
by Thierry Giamarchi.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;In the absence of an explicit expression for the free energy, one of the feed-forward network&amp;rsquo;s roles might be to try to approximate the complicated dependencies in the magnetization expression Eq. \eqref{eq:sigmaweird}, at the cost of introducing a large number of additional free parameters beyond just the coupling parameters. It would be interesting to look into this numerically at scale using the free energy expression obtained in
.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;The time-dependence in Eq. \eqref{eq:sloppyenergy} smells of non-equilibrium statistical mechanics. Incoming data might be considered as time-dependent &amp;ldquo;probes&amp;rdquo; which inject energy (and useful information if its content is low-entropy enough) into a non-equilibrium system. By nudging its dynamical response behavior across spatiotemporal scales, the system could potentially learn how to deal with being driven by all kinds of patterns in incoming data. For an interesting toy example of such behavior, see
by Jeremy England on &lt;em&gt;Low rattling: a principle for understanding driven many-body self-organization&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Transformers from Spin Models: Approximate Free Energy Minimization</title><link>https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/</link><pubDate>Tue, 12 Oct 2021 18:40:17 +0100</pubDate><guid>https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Consider reading
for a high-level overview of some of the ideas outlined in this post.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ TL;DR:&lt;/strong&gt; &lt;em&gt;We consider transformer modules as wrappers around a differentiable steepest-descent approximation of simple Ising-like vector-spin models familiar from statistical mechanics. We observe that a blueprint of the successful transformer-like architectural pattern of token-mixing (attention) and channel-mixing (feed-forward) naturally emerges when computing spin expectation values in vector-spin models with input-dependent couplings. Feel free to skip to the
for a visual comparison of this work to vanilla transformers, deep equilibrium transformers, and deep implicit attention.&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Code: A PyTorch implementation of the ideas outlined in this blog post is available in the GitHub repository
.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In
, we introduced a mean-field theory perspective on transformer modules. We showed how their outputs can be understood as mean-field spin expectation values of simple Ising-like vector-spin systems. Physically, the process of training a transformer module can be understood as driving a classical many-body system with data and iteratively shaping its collective response behaviour through coupling-weight parameter updates. Stacking transformer modules corresponds to building up a differentiable structure of correlations by using the spin expectation values of one physical system to drive the next one.&lt;/p&gt;
&lt;p&gt;In this post, we flesh out the idea of looking at transformer modules as physical systems. Having identified vector-spin systems as plausible physical models underlying transformers, we turn to 1960s statistical-mechanics literature to look for inspiration on how to deal with their partition functions&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;. We rediscover that the partition function of a particular class of vector-spin models can be approximated in the limit of large local spin dimension using steepest descent, leading to approximate yet tractable expressions for the free energy and other derived quantities.&lt;/p&gt;
&lt;p&gt;Combining these canonical results from statistical mechanics with modern differentiable programming, we implement a differentiable vector-spin model based on an approximate free-energy minimization algorithm. Internally, the model uses an implicit layer to solve for the stationary point of the partition function in a differentiable way. We then construct a transformer-like attention module which encapsulates the spin model by routing inputs to applied magnetic fields and spin expectation values to outputs. The latter are obtained by following the familiar recipe of statistical mechanics: differentiating the spin model&amp;rsquo;s $\log Z$ with respect to conjugate input variables. Finally, we contextualize our approach by comparing it to vanilla transformers, deep equilibrium transformers, and deep implicit attention.&lt;/p&gt;
&lt;h1 id="massaging-partition-functions"&gt;Massaging partition functions&lt;/h1&gt;
&lt;p&gt;In this section, we set out to derive an approximate, analytical expression for the free energy of a classical disordered vector-spin system exposed to a site-dependent external magnetic field. In deriving the results below, we found inspiration in H. E. Stanley&amp;rsquo;s
and Chapter 5 of R. J. Baxter&amp;rsquo;s bible on
.&lt;/p&gt;
&lt;h2 id="a-vector-spin-model-and-its-partition-function"&gt;A vector-spin model and its partition function&lt;/h2&gt;
&lt;p&gt;We start from the following Hamiltonian (or energy function) of a classical vector-spin system of $N$ spins in a site-dependent external magnetic field,&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}, \label{eq:vectrandomising}
\end{equation}&lt;p&gt;where both $\boldsymbol{\sigma}_{i} = \left[ \sigma_{1}(i), \sigma_{2}(i), \ldots, \sigma_{D}(i) \right]$ and $\boldsymbol{h}_{i} = \left[ h_{1}(i), h_{2}(i), \ldots, h_{D}(i) \right]$ are vectors of dimension $D$. The coupling matrix $\boldsymbol{J}$ is assumed to be traceless and symmetric but can otherwise have real elements with both negative and positive signs. We take the vector degrees of freedom $\boldsymbol{\sigma}_{i}$ to be constrained by a set of $N$ constraints&lt;/p&gt;
\begin{equation}
\lVert \boldsymbol{\sigma}_{i} \rVert _{2}^{2} = \sum_{a=1}^{D} \sigma_{a}^{2}(i) = D, \quad i = 1,2,\ldots,N,
\end{equation}&lt;p&gt;so that their magnitudes equal $\sqrt{D}$. One can picture the classical spin degrees of freedom as arrows rotating along the surface of $(D-1)$-dimensional spheres at every site.&lt;/p&gt;
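&lt;p&gt;As a concrete illustration of the model definition, the sketch below samples a random spin configuration satisfying the norm constraints and evaluates the energy of Eq. \eqref{eq:vectrandomising}. The sizes, the Gaussian couplings and fields, and the helper names &lt;code&gt;random_spins&lt;/code&gt; and &lt;code&gt;energy&lt;/code&gt; are assumptions made for this example only.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(3)
N, D = 6, 16

# Traceless, symmetric coupling matrix with entries of both signs.
J = rng.normal(size=(N, N)) / np.sqrt(N)
J = 0.5 * (J + J.T)
np.fill_diagonal(J, 0.0)
H = rng.normal(size=(N, D))   # one D-dimensional external field per site


def random_spins(rng, N, D):
    # Normalized Gaussian vectors, rescaled so that every spin has norm sqrt(D).
    sigma = rng.normal(size=(N, D))
    return np.sqrt(D) * sigma / np.linalg.norm(sigma, axis=1, keepdims=True)


def energy(sigma, J, H):
    # E = - sum_ij J_ij sigma_i . sigma_j - sum_i h_i . sigma_i
    return -np.sum(J * (sigma @ sigma.T)) - np.sum(H * sigma)


sigma = random_spins(rng, N, D)
print("squared norms:", np.sum(sigma**2, axis=1))
print("energy:", energy(sigma, J, H))
```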
&lt;figure class="spin-system-figure"&gt;
&lt;img src="spin_system.png" alt="Cartoon of vector-spin system"&gt;
&lt;figcaption&gt;Cartoon of vector-spin system&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;In statistical mechanics, the model Eq. \eqref{eq:vectrandomising} is known as a
whose familiar small-$D$ cases include the
($D=1$), the
($D=2$), and the
($D=3$). For infinite-dimensional spins $D \to \infty$, one can show that the system approaches the
. The model defined by \eqref{eq:vectrandomising} can also be regarded as a vector generalization of
or
or disordered
(but with just a single sample of non-local couplings instead of an underlying probability distribution). Similar models also appear in recent studies on higher-dimensional generalizations of spin glass models&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The partition function for our spin system looks like:&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;\left( \beta, J_{ij}, \{ \boldsymbol{h}_{i} \} \right) \nonumber \\
&amp;= \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \qquad \times \prod_{j=1}^{N} \delta \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \nonumber \\
&amp; \qquad \times \exp \left[ \beta \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} + \beta \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i} \right] \label{eq:fullpartfun}
\end{align}&lt;p&gt;where we have made all dependencies explicit. This looks absolutely mental. We somehow need to find a way to do $N \times D$ integrals while taking into account all the constraints and interactions.&lt;/p&gt;
&lt;h2 id="peeking-into-a-physicists-bag-of-tricks"&gt;Peeking into a physicist&amp;rsquo;s bag of tricks&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s first of all get rid of the explicit Dirac delta functions by substituting their complex integral representations&lt;/p&gt;
\begin{align}
\delta \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) = \frac{\beta}{2 \pi i} \int_{-i\infty}^{i\infty} \mathrm{d} t_{j} \exp \left[ \beta t_{j} \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \right]
\end{align}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;= \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \times \int_{-i\infty}^{i\infty} \cdots \int_{-i\infty}^{i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \exp \left( \beta D \sum_{j=1}^{N} t_{j} \right)\nonumber \\
&amp; \times \prod_{\alpha=1}^{D} \exp \left[ -\beta \sum_{i,j=1}^{N} \left(t_{j}\delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right] \nonumber
\end{align}&lt;p&gt;Great, even more integrals. The next frustrating trick involves writing the number 1 as a judiciously chosen exponential,&lt;/p&gt;
\begin{align}
\exp \left( \beta \sum_{j=1}^{N} a \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \right) = 1,
\end{align}&lt;p&gt;for some arbitrary constant $a$, which, inside the integral, indeed evaluates to $\exp (0) = 1$ because of the constraints. Inserting this expression gives&lt;/p&gt;
\begin{align}
&amp;Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \times \int_{-i\infty}^{i\infty} \cdots \int_{-i\infty}^{i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \exp \left( \beta D \sum_{j=1}^{N} \left( t_{j} + a\right) \right)\nonumber \\
&amp; \times \prod_{\alpha=1}^{D} \exp \left[ -\beta \sum_{i,j=1}^{N} \left( \left( t_{j} + a \right) \delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right] \nonumber
\end{align}&lt;p&gt;Next, we&amp;rsquo;d like to swap the order of the $\mathrm{d}\sigma_{\alpha}(j)$ and $\mathrm{d}t_{j}$ integrations to start integrating. But we are only allowed to do this if we assume $a$ to be a sufficiently large positive real number. Why? Essentially, we are deforming the contours of the complex integrals sufficiently far to the right such that the real part of the quadratic form appearing in the exponential is positive definite, see e.g.
.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s go ahead and assume that everything is fine. We swap integrals and do a change of variables $t_j \to t_j + a$ so that&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;= \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \nonumber \\
&amp; \times \exp \left( \beta D \sum_{j=1}^{N} t_{j} \right) \prod_{\alpha=1}^{D} I_{\alpha} \left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right)\nonumber
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
I_{\alpha} &amp;\left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right) = \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{\alpha}(1) \cdots \mathrm{d}\sigma_{\alpha}(N) \nonumber \\
&amp; \times \exp \left[ -\beta \sum_{i,j=1}^{N} \left( t_{j} \delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right]\nonumber
\end{align}&lt;p&gt;Notice how the integrals have kind of factorized over the vector dimension: for every $\alpha$-component we can evaluate an $N$-dimensional Gaussian integral with a linear term. The $I_{\alpha}$ functions depend on the sources $\{ \boldsymbol{h}_{i} \}$ through their $\alpha$-components, that is, sliced along the local dimension instead of along the spin index. Introducing the symmetric $N \times N$ matrix $V_{ij} = t_{j} \delta_{ij}-J_{ij}$, we can evaluate the Gaussian integrals and find&lt;/p&gt;
\begin{align}
I_{\alpha} &amp;\left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right) = \left( \frac{\pi}{\beta} \right)^{N/2} \left[ \det \left( \boldsymbol{V} \right) \right]^{-1/2} \exp \left(\frac{\beta}{4} \boldsymbol{h}_{\alpha}^{T} \boldsymbol{V}^{-1} \boldsymbol{h}_{\alpha} \right) \nonumber
\end{align}&lt;p&gt;where $\boldsymbol{h}_{\alpha} = \left[ h_{\alpha}(1), h_{\alpha}(2), \ldots, h_{\alpha}(N) \right]$ denote $N$-dimensional vectors. The expression for the partition function becomes&lt;/p&gt;
\begin{align}
&amp;Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} \left( \frac{\pi}{\beta} \right)^{DN/2} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \nonumber \\
&amp; \times \exp \left( D \left( \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) \right) \right) \exp \left( \frac{\beta}{4} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \right) \nonumber
\end{align}&lt;p&gt;where we have introduced the matrix notation $\boldsymbol{H} \in \mathbb{R}^{N \times D}$ to group the vectors $\{ \boldsymbol{h}_{i} \}$.&lt;/p&gt;
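&lt;p&gt;As a quick sanity check of the Gaussian integral used above, here is a minimal sketch for the simplest case of a single spin component ($N=1$), where $\boldsymbol{V}$ reduces to a positive scalar $v$. The values of &lt;code&gt;beta&lt;/code&gt;, &lt;code&gt;v&lt;/code&gt; and &lt;code&gt;h&lt;/code&gt; are arbitrary illustrative choices and are not taken from the model.&lt;/p&gt;

```python
import numpy as np

beta, v, h = 1.3, 0.8, 0.5  # arbitrary illustrative values, with v positive

# Brute-force Riemann sum of exp(-beta*v*s^2 + beta*h*s) over a wide grid.
s = np.linspace(-20.0, 20.0, 400001)
ds = s[1] - s[0]
numeric = np.sum(np.exp(-beta * v * s**2 + beta * h * s)) * ds

# Closed form for N = 1: (pi/beta)^(1/2) * v^(-1/2) * exp(beta * h^2 / (4 v)).
closed_form = np.sqrt(np.pi / beta) / np.sqrt(v) * np.exp(beta * h**2 / (4.0 * v))

print(numeric, closed_form)
```

&lt;p&gt;The two printed numbers should agree closely. The same check fails to converge once $v$ is no longer positive, which is the one-dimensional shadow of the positive-definiteness requirement on $\boldsymbol{V}$.&lt;/p&gt;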
&lt;h2 id="steepest-descent-hunting-for-the-saddle"&gt;Steepest descent: hunting for the saddle&lt;/h2&gt;
&lt;p&gt;But there are still $N$ complex integrals over the auxiliary variables $\{ t_{j} \}$ left to do. Can we avoid doing them? Maybe. Let&amp;rsquo;s rewrite our partition function as&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} &amp;\left( \frac{\pi}{\beta} \right)^{DN/2} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \mathrm{e}^{D \varphi \left(\boldsymbol{t} \right) } \label{eq:partfunsteep}
\end{align}&lt;p&gt;with&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4D} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \label{eq:varphi}
\end{align}&lt;p&gt;As $D \to \infty$, the steepest-descent approximation suggests that the partition function will be dominated by its largest contribution, i.e. by the neighbourhood of the maximum $\varphi(\boldsymbol{t^{*}})$ along the integration paths.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Hmm, this doesn&amp;rsquo;t quite seem right #1:&lt;/strong&gt; What does $D \to \infty$ even look like for the last term in Eq. \eqref{eq:varphi}? What does it mean for the input vectors $\{ \boldsymbol{h}_{i} \}$ to become infinite-dimensional? Good points, but let&amp;rsquo;s carry on.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;The saddle-point values $\boldsymbol{t^{*}}$ are obtained from the set of stationary conditions&lt;/p&gt;
\begin{align}
\frac{\partial \varphi \left( \boldsymbol{t} \right)}{\partial t_j} \Biggr\rvert_{t_j = t^{*}_{j}} = 0, \qquad j=1,\ldots,N \label{eq:statcond}
\end{align}
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Hmm, this doesn&amp;rsquo;t quite seem right #2:&lt;/strong&gt; In the single-variable case,
argues that $\varphi (t)$ is analytic for $\mathrm{Re}(t)&gt;0$ and that we should consider $\varphi (t)$ first for $t$ real and positive. For positive $\beta$ and non-zero magnetic field, the function tends to plus infinity as $t$ tends to either zero or infinity. Thus in between $\varphi(t)$ must have a &lt;em&gt;minimum&lt;/em&gt; at some positive value $t^{*}$ of $t$. Since $\varphi''(t) &gt; 0$ there is also only one such minimum. If we take the constant $a$ in the integral limits to be $t^{*}$, then along the (imaginary) integration path $\varphi (t)$ has a &lt;em&gt;maximum&lt;/em&gt; at $t=t^{*}$. We naively assume that this kind of saddle-point reasoning transfers to our case in several complex variables with $\varphi : \mathbb{C}^{N} \to \mathbb{C}$ where the equivalent of $\mathrm{Re}(t)&gt;0$ is to try to steer clear of the singularity at $\det \left( \boldsymbol{V} \right)=0$. We will check the numerical behaviour of our $\varphi$-function in
.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Expanding $\varphi$ around $\boldsymbol{t^{*}}$ and then taking the logarithm of Eq. \eqref{eq:partfunsteep} leads to&lt;/p&gt;
\begin{align}
\ln Z_{N}^{(D)} = \frac{DN}{2} \ln \left( \frac{\pi}{\beta} \right) + D \varphi \left( \boldsymbol{t^{*}} \right) + \ln R \nonumber
\end{align}&lt;p&gt;where we have collected all higher-order contributions and remaining nastiness in $R$. Following
, the free energy in the limit of large local dimension $D \to \infty$ then becomes&lt;/p&gt;
\begin{align}
-\beta f_{N}^{(\infty)} = \lim_{D \to \infty} D^{-1} \ln \left( Z_{N}^{(D)} / Z_{N}^{(D)}(0) \right) \nonumber
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
Z_{N}^{(D)}(0) = \left( \left(\pi\right)^{D/2} D^{(D-1)/2} / \Gamma \left(D/2\right) \right)^{N} \nonumber
\end{align}&lt;p&gt;is a normalization factor&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; accounting for the surface areas of the $(D-1)$-dimensional spheres with radius $\sqrt{D}$ associated with each and every spin degree of freedom. After applying Stirling&amp;rsquo;s approximation to the $\Gamma$-function in the normalization factor and doing some algebra, we end up with&lt;/p&gt;
\begin{align}
\boxed{-\beta f_{N}^{(\infty)} = - \frac{N}{2} - \frac{N}{2} \ln \left( 2\beta \right) + \varphi \left( \boldsymbol{t^{*}} \right)} \label{eq:afe}
\end{align}&lt;p&gt;where we have dropped the remainder term $\lim_{D \to \infty} D^{-1} \ln R$, assuming it tends to zero. Since $\varphi \left( \boldsymbol{t^{*}} \right) \propto N$, the last term in Eq. \eqref{eq:afe} also survives the limit $N \to \infty$.&lt;/p&gt;
&lt;h2 id="taking-stock-of-what-we-have-done"&gt;Taking stock of what we have done&lt;/h2&gt;
&lt;p&gt;We have derived a closed-form expression Eq. \eqref{eq:afe} for the approximate free energy of a vector-spin model in the limit of large local spin dimension. Let us take a brief moment to reflect on what we have done and touch on some tangential points.&lt;/p&gt;
&lt;h4 id="questioning-steepest-descent-and-the-large--limit"&gt;Questioning steepest descent and the large-$D$ limit&lt;/h4&gt;
&lt;p&gt;The result \eqref{eq:afe} is only sensible if steepest descent is a valid thing to do, which depends on how outrageous the landscape defined by the $\varphi$-function \eqref{eq:varphi} really is. More practically, we will also never &lt;em&gt;really&lt;/em&gt; let the vector-spin dimension $D$ tend towards infinity since our goal is to implement a numerical attention-like neural network module. So large but finite vector dimensions better behave as if they were sufficiently close to infinity. We will find out in
to what extent these assumptions are valid in practice.&lt;/p&gt;
&lt;h4 id="energy-based-models-and-effective-energy-functions"&gt;Energy-based models and effective energy functions&lt;/h4&gt;
&lt;p&gt;Let us take another look at our model&amp;rsquo;s partition function \eqref{eq:fullpartfun} from an energy-based perspective. For ease of notation, let us call the model parameters $\theta \equiv \{ J_{ij} \}$, the spins $\sigma \equiv \{ \boldsymbol{\sigma}_{i} \}$, and the external magnetic fields $h \equiv \{ \boldsymbol{h}_{i} \}$. We can schematically write our model&amp;rsquo;s partition function as&lt;/p&gt;
\begin{align}
Z_{\theta} \left( h \right) = \int \mathrm{d} \sigma \ \mathrm{e}^{ - E_{\theta}\left( \sigma, h \right) }
\end{align}&lt;p&gt;where $E_{\theta}\left( \sigma, h \right)$ denotes the energy function Eq. \eqref{eq:vectrandomising}. If we now introduce an energy-based model $p_{\theta} \left( \sigma, h \right) = \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)} / Z_{\theta}$, we can define the marginal distribution&lt;/p&gt;
\begin{align}
p_{\theta} \left( h \right) = \frac{\int \mathrm{d} \sigma \ \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)}}{Z_{\theta}} = \frac{\mathrm{e}^{-E_{\theta}\left( h \right)}}{Z_{\theta}} \label{eq:ph}
\end{align}&lt;p&gt;where the applied magnetic fields act as observables and the spins as latent variables. The effective energy $E_{\theta}\left( h \right)$ equals $E_{\theta}\left( h \right) = - \log \int \mathrm{d} \sigma \ \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)} \approx - \log Z^{\ast}_{\theta} \left( h \right)$, where we have used the steepest-descent approximation for the integral. Taking the logarithm of Eq. \eqref{eq:ph}, we find that $\log p_{\theta} \left( h \right) \approx \log Z^{\ast}_{\theta} \left( h \right) - \log \int \mathrm{d} h \ Z^{\ast}_{\theta} \left( h \right)$.&lt;/p&gt;
&lt;h4 id="spin-glasses-and-mean-field-approximation"&gt;Spin glasses and mean-field approximation&lt;/h4&gt;
&lt;p&gt;Ordered systems have a long history in statistical mechanics. Couplings in these models often encode a translation-invariant lattice geometry, e.g. nearest-neighbour interactions between spins living on a $d$-dimensional hypercubic lattice. One reason for this focus is practical: the regularity in these systems enables mathematical physicists to deploy all kinds of tricks and make progress towards some kind of understanding. In contrast, disordered systems, like spin glasses, are a mess and studying them is all about
. From the perspective of spin glasses, we can summarize our approach as follows: we want to arrive at an approximate yet tractable mean-field spin-glass model where its couplings are treated as parameters learned from data&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Fully-connected models like Sherrington-Kirkpatrick spin-glass models (or Eq. \eqref{eq:vectrandomising}) naturally lead to mean-field theory because the couplings $J_{ij}$ encode long-range interactions where every other spin is just a hop away, see e.g.
. Intuitively, all-to-all interactions correspond to the mean-field limit of infinite spatial dimension. To see this, consider a spin in a local nearest-neighbour lattice model getting ever more neighbours as the spatial dimension grows: the notion of nearest neighbours melts away and all spins effectively become connected to each other&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;. Fully-connected non-local couplings and the limit of infinite spatial dimension are two sides of the same mean-field coin.&lt;/p&gt;
&lt;h1 id="implementing-approximate-free-energy-minimization"&gt;Implementing approximate free-energy minimization&lt;/h1&gt;
&lt;p&gt;In this section, we turn the equations of the previous section into the algorithmic backbone of a differentiable vector-spin model. We begin by sketching an approximate free-energy minimization algorithm. We then show how to wrap around the spin model to turn it into an attention module.&lt;/p&gt;
&lt;h2 id="the-algorithm-bold-moves-on-a-tricky-landscape"&gt;The algorithm: bold moves on a tricky landscape&lt;/h2&gt;
&lt;p&gt;Our goal is to compute the steepest-descent approximation of our model&amp;rsquo;s partition function in a differentiable way. Essentially, we need to solve the set of equations&lt;/p&gt;
\begin{align}
\frac{\partial \varphi \left( \boldsymbol{t} \right)}{\partial t_j} \Biggr\rvert_{t_j = t^{*}_{j}} = 0, \qquad j=1,\ldots,N
\end{align}&lt;p&gt;which corresponds to finding a value $\boldsymbol{t^{*}} = \mathrm{argmin}_{\boldsymbol{t}} \varphi \left( \boldsymbol{t} \right)$ for which the scalar function&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4D} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \nonumber
\end{align}&lt;p&gt;attains its minimum, or, equivalently, we need to solve for the root of $\nabla \varphi \left( \boldsymbol{t} \right)$.&lt;/p&gt;
&lt;h4 id="initialization-and-normalization"&gt;Initialization and normalization&lt;/h4&gt;
&lt;p&gt;Until now we have not been explicit about the values of the couplings $\boldsymbol{J}$ and inputs $\boldsymbol{H}$. If we want to implement any of this, we have to be more careful. Recall that the energy function of our model looks like&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}
\end{equation}&lt;p&gt;where all spins $\boldsymbol{\sigma}_{i}$ are fixed to norm $\sqrt{D}$. We&amp;rsquo;d like this energy to remain linearly proportional to the number of lattice sites. Numerically, we observe that stable root-finding is possible when initializing the couplings according to
&lt;/p&gt;
\begin{equation}
J_{ij} \sim \mathcal{N} (0, 1/\sqrt{ND} )
\end{equation}&lt;p&gt;
The factor $1/\sqrt{N}$ can be explained from spin-glass mean-field theory&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt; whereas the $1/\sqrt{D}$ factor follows from additionally normalizing with respect to the vector dimension to ensure $\sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} \sim \mathcal{O}(N)$. One strategy to normalize the inputs $\boldsymbol{H}$ is to feed them into a layer normalization layer so that $\left\lVert \boldsymbol{h}_{i} \right\rVert \sim \mathcal{O}(\sqrt{D})$ and then explicitly divide by $\sqrt{D}$ to make them $\mathcal{O}(1)$. A practical consequence of these initialization and normalization choices at the level of the energy function is that the $\varphi$-function changes to&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \label{eq:varphinorm}
\end{align}&lt;p&gt;where the prefactor in the last term changed since we decided on explicitly dividing the layer-normalized $\boldsymbol{H}$ by $\sqrt{D}$.&lt;/p&gt;
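&lt;p&gt;The initialization and normalization choices above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: $1/\sqrt{ND}$ is read as a standard deviation, the couplings are symmetrized by hand, and the layer normalization is the plain variant without learnable scale and shift. None of this is claimed to match the actual implementation.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 32, 128

# Couplings: zero-mean Gaussian, ASSUMING 1/sqrt(N*D) is the standard deviation.
J = rng.normal(0.0, 1.0 / np.sqrt(N * D), size=(N, N))
J = 0.5 * (J + J.T)  # symmetrize (this slightly rescales the off-diagonal entries)

# Inputs: plain layer normalization over the feature axis (no learnable affine),
# which gives rows of norm sqrt(D), followed by the explicit division by sqrt(D).
X = rng.normal(size=(N, D))
X_ln = (X - X.mean(axis=1, keepdims=True)) / X.std(axis=1, keepdims=True)
H = X_ln / np.sqrt(D)

print(np.linalg.norm(X_ln, axis=1)[:3])  # row norms after layer normalization
print(np.linalg.norm(H, axis=1)[:3])     # row norms after dividing by sqrt(D)
```

&lt;p&gt;With this convention every row of &lt;code&gt;H&lt;/code&gt; has unit norm, which is what makes the $1/D$ in front of the trace term disappear.&lt;/p&gt;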
&lt;h4 id="implicit-layers-for-steepest-descent-root-finding"&gt;Implicit layers for steepest-descent root-finding&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s now find the root of the gradient of $\varphi$ in a differentiable way by combining
with a black-box root-finding algorithm like
, which requires access to both a function (the gradient of $\varphi$) and its gradient (the Jacobian of the gradient of $\varphi$, i.e. the Hessian of $\varphi$). We could rely on automatic differentiation to calculate these gradients, but we might just as well exploit the fact that we have the analytical expression Eq. \eqref{eq:varphinorm}. Grabbing a coffee and peeking at the
, we can figure out what happens&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;when we wiggle around $t_{i}$ (the gradient vector at $\boldsymbol{t}$)&lt;/li&gt;
&lt;/ul&gt;
\begin{align}
\left[ \nabla \varphi \left( \boldsymbol{t} \right) \right]_{i} = \beta - \frac{1}{2} \left[ \boldsymbol{V}^{-1} \right]_{ii} - \frac{\beta}{4} \left[ \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \right]_{ii} \nonumber
\end{align}&lt;ul&gt;
&lt;li&gt;when we wiggle around both $t_{i}$ and $t_{j}$ (the symmetric Hessian matrix at $\boldsymbol{t}$)&lt;/li&gt;
&lt;/ul&gt;
\begin{align}
\left[ \boldsymbol{J}(\nabla \varphi \left( \boldsymbol{t} \right)) \right]_{ij} = \frac{1}{2} &amp;\left[ \boldsymbol{V}^{-1} \odot \boldsymbol{V}^{-T} \right]_{ij} \nonumber \\
&amp;+ \frac{\beta}{4} \left[ \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \boldsymbol{V}^{-T} \odot \boldsymbol{I} \right]_{ij} \nonumber \\
&amp;+ \frac{\beta}{4} \left[ \boldsymbol{V}^{-T} \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \odot \boldsymbol{I} \right]_{ij} \nonumber
\end{align}&lt;p&gt;Given an initial guess $\boldsymbol{t_{0}} \in \mathbb{R}^{N}_{&gt;0}$ and input data $\boldsymbol{H} \in \mathbb{R}^{N \times D}$, we can now construct a differentiable root-solver which returns $\boldsymbol{t^{*}}$. It is important to keep in mind that the stationary value $\boldsymbol{t^{*}}$ actually depends on $\left(\beta, \boldsymbol{J}, \boldsymbol{H} \right)$ implicitly. Since we make use of implicit layers within an automatic differentiation framework, these dependencies are kept track of and are included in the computational graph.&lt;/p&gt;
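&lt;p&gt;Before handing the analytical gradient to a solver, it is cheap to check it against finite differences of Eq. \eqref{eq:varphinorm}. The sketch below does this in NumPy for a small random model. The helper names &lt;code&gt;phi&lt;/code&gt; and &lt;code&gt;grad_phi&lt;/code&gt;, the sizes, and the test point $\boldsymbol{t} = 2 \cdot \boldsymbol{1}_{N}$ are hypothetical choices made for illustration only.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, beta = 6, 16, 1.0

J = rng.normal(0.0, 1.0 / np.sqrt(N * D), size=(N, N))
J = 0.5 * (J + J.T)                       # symmetric couplings
H = rng.normal(size=(N, D)) / np.sqrt(D)  # inputs with rows of norm ~ 1

def phi(t):
    # phi(t) = beta * sum(t) - 1/2 log det V + beta/4 Tr(H^T V^{-1} H)
    V = np.diag(t) - J
    _, logdet = np.linalg.slogdet(V)
    trace_term = np.trace(H.T @ np.linalg.solve(V, H))
    return beta * t.sum() - 0.5 * logdet + 0.25 * beta * trace_term

def grad_phi(t):
    # Analytical gradient: beta - 1/2 [V^{-1}]_ii - beta/4 [V^{-T} H H^T V^{-T}]_ii
    V_inv = np.linalg.inv(np.diag(t) - J)
    A = V_inv.T @ H @ H.T @ V_inv.T
    return beta - 0.5 * np.diag(V_inv) - 0.25 * beta * np.diag(A)

t = np.full(N, 2.0)  # far enough from the singular region for this small model
eps = 1e-6
fd = np.array([(phi(t + eps * e) - phi(t - eps * e)) / (2 * eps) for e in np.eye(N)])

print(np.max(np.abs(fd - grad_phi(t))))  # largest deviation between the two gradients
```

&lt;p&gt;The same pattern can be reused for the Hessian by differencing &lt;code&gt;grad_phi&lt;/code&gt; instead of &lt;code&gt;phi&lt;/code&gt;.&lt;/p&gt;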
&lt;h4 id="fun-with-free-energies"&gt;Fun with free energies&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s test the algorithm by initializing a random vector-spin model and applying a random magnetic field at every site. For visualization purposes, we restrict the auxiliary variables to be effectively one-dimensional by defining $\boldsymbol{t} = t \boldsymbol{1}_{N}$ with just a single scalar parameter $t \in \mathbb{R}_{&gt;0}$. We can probe a &lt;code&gt;VectorSpinModel&lt;/code&gt; and get the approximate free energy for a given set of parameters and inputs by running the following script:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;afem.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;VectorSpinModel&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VectorSpinModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;t0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;afe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;t0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_afe&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;afe&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Inside the forward pass, the root $\boldsymbol{t^{*}}$ is computed and then fed into Eq. \eqref{eq:afe} to calculate the approximate free energy. We can verify that our algorithm is doing something sensible by sweeping across the auxiliary $t$-values and plotting $\varphi$ and its derivatives:&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-sweep-across-auxiliary-variable"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_4aa84911571ee37e.webp 320w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_f6f0b5f47ed137ab.webp 480w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_f752654317dd2531.webp 640w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_4aa84911571ee37e.webp"
width="640"
height="480"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Sweep across auxiliary variable
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The region close to $t=0$ looks terrifying. In this regime, $t$ is likely not large enough to overshadow the largest eigenvalue of the couplings so we lose positive definiteness and its nice properties. Let&amp;rsquo;s try to stay away from that region by always initializing $\boldsymbol{t}_{0}$ sufficiently far from it. Depending on the parameters and initial guess provided to the solver, one can of course end up in less favourable landscapes where root-solving can become difficult due to zero gradients or extreme sensitivity to initial conditions. Fortunately, when the root-solving step fails, it tends to fail spectacularly.&lt;/p&gt;
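&lt;p&gt;To make the one-dimensional picture concrete, here is a minimal sketch of a safeguarded Newton iteration for the restricted case $\boldsymbol{t} = t \boldsymbol{1}_{N}$ with symmetric couplings. It is not the solver used inside &lt;code&gt;VectorSpinModel&lt;/code&gt;. It diagonalizes $\boldsymbol{J}$ once so that $\varphi'(t)$ and $\varphi''(t)$ become sums over eigenvalues, and the safeguard (never moving more than halfway towards the largest eigenvalue) is an assumption added here to stay out of the singular region.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, beta = 32, 128, 1.0

J = rng.normal(0.0, 1.0 / np.sqrt(N * D), size=(N, N))
J = 0.5 * (J + J.T)
H = rng.normal(size=(N, D)) / np.sqrt(D)

# Diagonalize J once: V = t*I - J then has eigenvalues t - lam[k].
lam, U = np.linalg.eigh(J)          # eigenvalues in ascending order
c = np.sum((U.T @ H) ** 2, axis=1)  # weight of H H^T along each eigenvector

def dphi(t):
    # First derivative of phi(t * 1) with respect to the scalar t.
    return beta * N - 0.5 * np.sum(1.0 / (t - lam)) - 0.25 * beta * np.sum(c / (t - lam) ** 2)

def d2phi(t):
    # Second derivative, positive whenever t exceeds the largest eigenvalue.
    return 0.5 * np.sum(1.0 / (t - lam) ** 2) + 0.5 * beta * np.sum(c / (t - lam) ** 3)

t = 1.0  # initial guess, chosen well above the largest eigenvalue of J
for _ in range(50):
    newton = t - dphi(t) / d2phi(t)
    t = max(newton, 0.5 * (t + lam[-1]))  # safeguard against overshooting

print(t, dphi(t))  # stationary point and the residual derivative
```

&lt;p&gt;Starting closer to the largest eigenvalue makes the first steps much more violent, which is the numerical counterpart of the terrifying region in the plot.&lt;/p&gt;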
&lt;p&gt;Let&amp;rsquo;s now sweep across inverse temperature $\beta$ to get some intuition. From the analytical expression of the free energy, we can deduce that for small $\beta$ (high temperature) the entropy term reigns while for large $\beta$ (low temperature) the energy terms take over.&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-sweep-across-inverse-temperature"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="alt text"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/beta_sweep.gif"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Sweep across inverse temperature
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Finally, let&amp;rsquo;s lift the one-dimensional restriction on $\boldsymbol{t}$ and plot $\varphi (\boldsymbol{t})$ for two spins. In that case, $\boldsymbol{t}$ is also just two-dimensional so we can still visualize the optimization landscape.&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-two-dimensional-auxiliary-variables"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_7b965a283d029a52.webp 320w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_af3a40a7eaf8a93b.webp 480w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_cae17aa20904ce69.webp 545w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_7b965a283d029a52.webp"
width="545"
height="459"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Two-dimensional auxiliary variables
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h2 id="the-attention-module-probing-spins-with-data"&gt;The attention module: probing spins with data&lt;/h2&gt;
&lt;p&gt;In the previous section, we showed how to numerically compute the steepest-descent approximation of a vector-spin model&amp;rsquo;s partition function and hence its free energy. Since this approximation is fully differentiable, we can also take derivatives with respect to conjugate variables. Let&amp;rsquo;s use this observation to construct an attention module.&lt;/p&gt;
&lt;h4 id="spin-expectation-values"&gt;Spin expectation values&lt;/h4&gt;
&lt;p&gt;We can calculate spin expectation values or magnetizations from our partition function approximation by differentiating with respect to the applied magnetic fields:&lt;/p&gt;
\begin{align}
\langle \boldsymbol{\sigma}_{i} \rangle = \frac{\mathrm{d} \log Z \left( \boldsymbol{t}, \boldsymbol{H} \right)}{\mathrm{d} \boldsymbol{h}_{i}} = \frac{\partial \varphi}{\partial \boldsymbol{t}} \frac{\partial \boldsymbol{t}}{\partial \boldsymbol{h}_{i}} + \frac{\partial \varphi}{\partial \boldsymbol{h}_{i}} \label{eq:spinevgeneral}
\end{align}&lt;p&gt;If we evaluate the partition function approximation at the stationary point $\boldsymbol{t^{\ast}}$, the first term drops out because $\partial_{\boldsymbol{t}} \varphi \rvert_{\boldsymbol{t}=\boldsymbol{t^{\ast}}} = 0$. Assuming that the matrix $\boldsymbol{V}$ (and hence the couplings $\boldsymbol{J}$) do not depend on the inputs $\boldsymbol{H}$, the spin expectation value boils down to&lt;/p&gt;
\begin{align}
\langle \boldsymbol{\sigma}_{i} \rangle = \frac{\partial \varphi}{\partial \boldsymbol{h}_{i}} = \frac{\beta}{2} \sum_{j} \boldsymbol{V}^{-1}_{ij} \boldsymbol{h}_{j} \label{eq:spinev}
\end{align}&lt;p&gt;which, for every site, is just a weighted sum of inputs. In the language of transformers, Eq. \eqref{eq:spinev} resembles an update step where $\boldsymbol{V}^{-1}$ can be interpreted as a symmetric attention matrix. Expanding the matrix inverse reveals a residual connection as the zeroth-order contribution&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Since the couplings are scalars at the level of the energy function Eq. \eqref{eq:vectrandomising}, getting terms to act on the hidden dimension seems to be impossible. But by considering couplings $\boldsymbol{J}(\boldsymbol{H})$ which do depend on inputs, additional terms can appear in Eq. \eqref{eq:spinev} propagating via dependencies in $\boldsymbol{V}$. Instead of calculating these gradients analytically, we should of course just let our automatic differentiation framework compute them for us.&lt;/p&gt;
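&lt;p&gt;For input-independent symmetric couplings, the closed form in Eq. \eqref{eq:spinev} can be checked directly. The sketch below compares $\frac{\beta}{2} \boldsymbol{V}^{-1} \boldsymbol{H}$ with finite differences of the $\boldsymbol{H}$-dependent term of Eq. \eqref{eq:varphinorm} at a fixed $\boldsymbol{t}$. The partial derivative with respect to $\boldsymbol{h}_{i}$ has this form at any fixed $\boldsymbol{t}$, so the value of $\boldsymbol{t}$ used here is an arbitrary stand-in for the saddle point, and the helper name &lt;code&gt;source_term&lt;/code&gt; is hypothetical.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, beta = 6, 16, 1.0

J = rng.normal(0.0, 1.0 / np.sqrt(N * D), size=(N, N))
J = 0.5 * (J + J.T)
H = rng.normal(size=(N, D)) / np.sqrt(D)
V = np.diag(np.full(N, 2.0)) - J  # fixed t (a stand-in for t*), symmetric V

def source_term(H_in):
    # Only this term of phi depends on H when J is input-independent.
    return 0.25 * beta * np.trace(H_in.T @ np.linalg.solve(V, H_in))

# Closed form: every output row is a weighted sum over all input rows.
magnetization = 0.5 * beta * np.linalg.solve(V, H)

# Central finite differences with respect to every entry of H.
eps = 1e-6
fd = np.zeros_like(H)
for i in range(N):
    for a in range(D):
        E = np.zeros_like(H)
        E[i, a] = eps
        fd[i, a] = (source_term(H + E) - source_term(H - E)) / (2 * eps)

print(magnetization.shape, np.max(np.abs(fd - magnetization)))
```

&lt;p&gt;The output has the same shape as the input, and the mixing happens only across sites, never across the hidden dimension.&lt;/p&gt;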
&lt;h4 id="wrapping-around-the-spin-model"&gt;Wrapping around the spin model&lt;/h4&gt;
&lt;p&gt;At this point, we have done all the heavy lifting. All that remains is to write a wrapper so that we can use our module just like any other explicit attention module:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;afem.attention&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;VectorSpinAttention&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VectorSpinAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (1, 32, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Inside the forward pass of &lt;code&gt;VectorSpinAttention&lt;/code&gt;, (normalized) inputs are sent to an internal &lt;code&gt;VectorSpinModel&lt;/code&gt; which solves for the saddle point $\boldsymbol{t^{*}}$ and then feeds it into the steepest descent partition function to calculate magnetizations according to Eq. \eqref{eq:spinevgeneral}.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s finish this section by discussing some of the peculiarities of our approach:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Stability and symmetry:&lt;/strong&gt; The root-finding is stable as long as $\det \boldsymbol{V} &gt; 0$, which ensures that $\boldsymbol{V}$ is nonsingular and which is guaranteed as long as the quadratic form is positive definite. A quadratic form involving a general $\boldsymbol{V}$ (i.e. with nonsymmetric couplings $\boldsymbol{J}$) is positive definite iff its symmetric part has only positive eigenvalues. When this is no longer the case, things tend to blow up.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Scaling:&lt;/strong&gt; Our approach is kind of slow because calculating inverses scales as $\mathcal{O}\left(N^3\right)$. Yet there might be ways to approximate the slow parts of the algorithm similar to how vanilla transformers can be understood to approximate mean-field fixed-point equations&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Lack of permutation invariance:&lt;/strong&gt; Our model is not permutation invariant with the default choice of input-independent couplings: every spin has a role to play.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Input-dependent couplings:&lt;/strong&gt; Because our default model assumes input-independent couplings $\boldsymbol{J}$, Eq. \eqref{eq:spinev} features just a &amp;ldquo;token-mixing&amp;rdquo; attention operation. Channel-mixing terms can appear when we consider the physically very weird setup where the couplings are made dependent on the applied magnetic fields. One possible choice could be:
\begin{align}
\boldsymbol{J}(\boldsymbol{H}) = \frac{\tanh \left( \boldsymbol{H} \boldsymbol{Q} \boldsymbol{K}^T \boldsymbol{H}^T \cdot \sqrt{D} \right)}{\sqrt{ND}} \nonumber
\end{align}
where $\boldsymbol{Q}$ and $\boldsymbol{K}$ are linear transformations acting on the hidden dimension and where the scaling factors have been inserted because of the normalization conventions we discussed in
. We hypothesize that additional terms in the spin expectation value Eq. \eqref{eq:spinev} arising from input-dependent couplings might be related to channel-mixing feed-forward networks in transformer modules.&lt;/li&gt;
&lt;/ul&gt;
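&lt;p&gt;To make the input-dependent choice $\boldsymbol{J}(\boldsymbol{H})$ from the last bullet concrete, here is a minimal, dependency-free sketch. The helper names (&lt;code&gt;matmul&lt;/code&gt;, &lt;code&gt;transpose&lt;/code&gt;, &lt;code&gt;input_dependent_couplings&lt;/code&gt;), the toy sizes, and the random $\boldsymbol{Q}$ and $\boldsymbol{K}$ are assumptions made for illustration and are not taken from the reference implementation.&lt;/p&gt;

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    """Transpose a matrix stored as nested lists."""
    return [list(row) for row in zip(*a)]


def input_dependent_couplings(h, q, k):
    """J(H) = tanh(H Q K^T H^T * sqrt(D)) / sqrt(N * D) for H of shape (N, D)."""
    n, d = len(h), len(h[0])
    scores = matmul(matmul(matmul(h, q), transpose(k)), transpose(h))  # (N, N)
    return [[math.tanh(s * math.sqrt(d)) / math.sqrt(n * d) for s in row]
            for row in scores]


random.seed(0)
N, D = 4, 3  # hypothetical toy sizes
H = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
Q = [[random.gauss(0.0, 0.1) for _ in range(D)] for _ in range(D)]
K = [[random.gauss(0.0, 0.1) for _ in range(D)] for _ in range(D)]

J = input_dependent_couplings(H, Q, K)
print(len(J), len(J[0]))  # 4 4
```

&lt;p&gt;Because of the $\tanh$, every entry of the resulting $N \times N$ coupling matrix is bounded in absolute value by $1/\sqrt{ND}$, whatever the inputs are.&lt;/p&gt;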
&lt;h4 id="comparison-with-vanilla-transformers"&gt;Comparison with vanilla transformers&lt;/h4&gt;
&lt;p&gt;In this final section, let&amp;rsquo;s summarize our approach on a high level by visually comparing it to vanilla transformers and deep equilibrium approaches.&lt;/p&gt;
&lt;img src="arch_vanilla_deq.png" alt="Vanilla transformer and deep equilibrium transformer" width="500px"/&gt;
&lt;p&gt;The vanilla transformer
(left above) is an explicit architecture which processes input sequences sequentially through a stack of transformer modules. Deep equilibrium transformers
(right above) compute the output of a transformer module by implicitly solving for the fixed point of $f(z, x) = z$ where $f$ denotes the explicit transformer module. Data is repeatedly inserted by adding it to the current iteration of $z$ inside the module until fixed-point convergence. The converged fixed point is considered the output of the module. Backpropagation through the iterations of the solver is avoided by using the implicit function theorem to calculate gradients directly at the equilibrium point. Instead of a stack of layers, there&amp;rsquo;s just a single layer.&lt;/p&gt;
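&lt;p&gt;The fixed-point solve and the implicit gradient can be illustrated on a scalar toy problem. The sketch below assumes a hypothetical module $f(z, x) = \tanh(w z + x)$ with a fixed weight $w$, solves $f(z, x) = z$ by naive iteration, and compares the implicit-function-theorem gradient $\mathrm{d}z^{*}/\mathrm{d}x = (\partial f / \partial x) / (1 - \partial f / \partial z)$ with a finite difference taken through the solver. Real deep equilibrium models use vector-valued states and faster solvers; all names here are illustrative.&lt;/p&gt;

```python
import math


def f(z, x, w=0.5):
    """Toy module: the input x is injected at every iteration."""
    return math.tanh(w * z + x)


def solve_fixed_point(x, w=0.5, num_iters=200):
    """Naive fixed-point iteration z = f(z, x), starting from z = 0."""
    z = 0.0
    for _ in range(num_iters):
        z = f(z, x, w)
    return z


def implicit_gradient(z_star, w=0.5):
    """dz*/dx from the implicit function theorem: (df/dx) / (1 - df/dz)."""
    sech2 = 1.0 - z_star ** 2  # derivative of tanh evaluated at the fixed point
    return sech2 / (1.0 - w * sech2)


x = 0.3
z_star = solve_fixed_point(x)
grad_ift = implicit_gradient(z_star)

# Central finite difference through the solver, for comparison only.
eps = 1e-5
grad_fd = (solve_fixed_point(x + eps) - solve_fixed_point(x - eps)) / (2 * eps)
print(abs(grad_ift - grad_fd))
```

&lt;p&gt;The gradient only needs quantities evaluated at $z^{*}$, which is why the solver iterations never have to be stored for the backward pass.&lt;/p&gt;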
&lt;p&gt;But deep equilibrium transformers still treat the transformer module as a black box. In
we looked for a physical spin-model interpretation of the deep equilibrium fixed-point procedure (left below). We argued how the update step of a vanilla transformer module resembled mean-field fixed-point equations of a vector-spin model, explaining the successful pattern of token-mixing, residual connections, normalization layers, and feed-forward or channel-mixing modules from a physical spin systems&amp;rsquo; perspective.&lt;/p&gt;
&lt;img src="arch_dia_afem.png" alt="Deep implicit attention and approximate free-energy minimization" width="600px"/&gt;
&lt;p&gt;In this work (right above), we continued on the path of spin expectation values but replaced solving mean-field fixed-point equations with directly taking derivatives of the steepest-descent partition function of a particular class of vector-spin models. The fixed-point procedure is replaced with a root-solving step to determine the steepest-descent partition function. The structure of our module&amp;rsquo;s output reveals the same successful transformer-like pattern of token-mixing (attention) and channel-mixing (feed-forward) interspersed with normalization layers and residual connections.&lt;/p&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we introduced transformer modules as wrappers around statistical-mechanical vector-spin models. We used implicit layers to construct a class of approximate yet tractable vector-spin models whose couplings act as parameters that can be learned from data. We showed how these models can act as transformer-like attention modules by routing inputs to applied magnetic fields and returning spin expectation values derived from their steepest-descent partition function.&lt;/p&gt;
&lt;p&gt;By zooming out and approaching transformers from a tangential, statistical-mechanical point of view, we were able to develop a physical intuition of transformers that seems hard to arrive at when restricting oneself to perturbing explicit neural network architectures. Recognizing transformer modules as spin models in disguise might not only unify architectural variations but also elucidate the high-level architectural convergence and empirical success of transformers in deep learning.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021afem,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Transformers from Spin Models: Approximate Free Energy Minimization},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {October},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;We could have turned to the mean-field free energies associated with the adaptive TAP equations discussed in
, but we decided on attacking the problem from the steepest-descent angle on the full partition function.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;For example, see &lt;em&gt;
&lt;/em&gt; and &lt;em&gt;
&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;The original 1968 paper has a small typo here: the $\nu$ in the paper&amp;rsquo;s Eq. (23) should be $\nu^{1/2}$ for the surface area of a $\nu-1$-dimensional sphere with radius $R=\nu^{1/2}$ embedded in $\nu$ dimensions. Using the paper&amp;rsquo;s formula, an annoying $\ln \nu$ term won&amp;rsquo;t cancel out in the limiting free energy calculation.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;In contrast to spin glasses however, we do not (yet want to go full Bayesian and) treat the couplings as drawn from some kind of probability distribution. For now, we settle for obtaining point estimates of model parameters.&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;By promoting sparseness in the couplings, a model might become less mean-field-y, which might be one of the reasons behind the success of scaled &lt;code&gt;softmax&lt;/code&gt; attention in vanilla transformers.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;From
: &lt;em&gt;The mean-field limit to infinite dimensions or long-range interaction introduces a new large scale. To make the thermodynamic limit meaningful the dependence of the energy on this new large scale must be compensated by rescaling the non-local spin exchange so that the energy remains linearly proportional to the volume or the number of lattice sites (spins).&lt;/em&gt;&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;We can expand the right-hand side using a
to find
&lt;/p&gt;
\begin{align}
\boldsymbol{V}^{-1} &amp;= \left( \mathrm{diag} ( \boldsymbol{t} ) - \boldsymbol{J} \right)^{-1} = \sum_{k=0}^{\infty} \left( \mathrm{diag} \left( \boldsymbol{t}^{-1} \right) \boldsymbol{J} \right)^{k} \mathrm{diag} \left( \boldsymbol{t}^{-1} \right) \nonumber
\end{align}&lt;p&gt;
which converges if the largest absolute value of the eigenvalues of the matrix inside the power-brackets is less than 1. So the spin expectation value looks like a sum of contributions that mix and weigh inputs of different sites.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;As discussed previously in
. In that setting, calculating inverses was sidestepped by approximating part of the solution with a feed-forward neural network.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms</title><link>https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/</link><pubDate>Wed, 07 Apr 2021 15:17:17 +0100</pubDate><guid>https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Consider reading
for a high-level overview of some of the ideas outlined in this post.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Code: A reference PyTorch implementation of the ideas outlined in this blog post is available in the repository
. Comments welcome.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;To explore progress beyond the cage of softmax attention, we have previously looked at energy-based perspectives on attention mechanisms:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The main take-away so far has been that you can think of softmax attention as implementing a single, big gradient step of some energy function and that training transformers is akin to meta-learning how to best tune a stack of attention and feed-forward modules to perform well on some auxiliary (meta-)task(s). But what can an energy-based perspective actually provide beyond quaint and hand-wavy statements like &lt;em&gt;implicit energy landscapes are sculpted every time you train a transformer&lt;/em&gt;?&lt;/p&gt;
&lt;p&gt;In this post, we approach attention in terms of the &lt;em&gt;collective response of a statistical-mechanical system&lt;/em&gt;. Attention is interpreted as an inner-loop fixed-point optimization step which returns the approximate response of a system being probed by data. This response is a differentiable compromise between the system&amp;rsquo;s internal dynamics and the data it&amp;rsquo;s being exposed to. To better respond to incoming data, outer-loop optimization steps can nudge the interactions and the self-organizing behaviour of the system.&lt;/p&gt;
&lt;p&gt;To implement our proposal, we combine old ideas and new technology to construct a family of attention mechanisms based on fixed points. We use
to solve a set of self-consistent mean-field equations of a vector generalization of the random Ising spin-model. By approximating these equations, we arrive at simplified update steps which mirror the vanilla transformer architecture. We conclude by showing how transformers can be understood from a mean-field theory perspective.&lt;/p&gt;
&lt;h1 id="mean-field-theory-for-disordered-systems"&gt;Mean-field theory for disordered systems&lt;/h1&gt;
&lt;p&gt;In physics, mean-field theory
is an approximation method to study models made up of many individual degrees of freedom that interact with each other. Mean-field theory approximates the effect of the environment on any given individual degree of freedom by a single, averaged effect, and thus reduces a many-body problem to an (effective) one-body problem. This is a drastic approximation. Whether mean-field theory is a sensible thing to do depends on the problem and the properties of your variational ansatz.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Mean-field theory &amp;amp; variational methods:&lt;/strong&gt; From the point of view of variational methods, mean-field theory tries to approximate a complicated object (like a partition function of a statistical-mechanical system) by wiggling around the parameters of a tractable variational ansatz to get as close as possible to the real thing. You can picture this process as projecting down a complicated object living in a high-dimensional space to its shadow in an easier-to-handle subspace (&lt;em&gt;I can hear a mathematician fainting in the background&lt;/em&gt;). This effectively reduces the problem to optimizing for the best possible approximation within your variational class. A lot of mean-field machinery also shows up in probability theory, statistics, and machine learning where it appears in belief propagation, approximate variational inference, expectation propagation, etc.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In the next two subsections, we introduce random Ising models and sketch a physics-inspired approach to deal with disordered models using mean-field theory. In
we will then generalize these results to vector spin degrees of freedom and propose two flavours of attention models.&lt;/p&gt;
&lt;h2 id="random-ising-models-or-boltzmann-machines-or-"&gt;Random Ising models (or Boltzmann machines or &amp;hellip;)&lt;/h2&gt;
&lt;p&gt;The random Ising model is a prototypical model in the study of spin glasses and disordered random systems, where it is often referred to as the Sherrington&amp;ndash;Kirkpatrick model
, famous for its replica-method solution by Giorgio Parisi in 1979. Its energy function with external field for $N$ classical, binary spin variables looks like&lt;/p&gt;
\begin{equation}
E = \sum_{i,j} J_{ij} S_{i} S_{j} + \sum_{i} x_{i} S_{i}, \label{eq:randomising}
\end{equation}&lt;p&gt;where the couplings $J_{ij}$ between degrees of freedom are randomly distributed according to some probability distribution and self-interactions are absent ($J_{ii} = 0$). The external magnetic fields $x_{i}$ provide a preferential direction of alignment at every local site. Since the elements in the coupling matrix can have both negative and positive signs, the system is said to have both frustrated ferro- as well as antiferromagnetic couplings. The model defined by \eqref{eq:randomising} is also known as a
or a
.&lt;/p&gt;
&lt;p&gt;In contrast with disordered systems, we expect the couplings in the context of artificial neural networks to no longer be randomly drawn from a distribution but to reflect structure and organization between spins after being exposed to data. The system should self-organize in order to better respond to incoming data.&lt;/p&gt;
&lt;p&gt;A cartoon of a spin configuration of a 7-spin system looks something like
&lt;img src="binary_ising.png" alt="Random Ising model configuration with binary spins" width="250px"/&gt;
where we have only drawn the connections strongest in absolute value. It&amp;rsquo;s helpful to think of classical spin degrees of freedom as arrows. For vector spins, we can imagine lifting the up/down restriction and letting the arrows rotate freely.&lt;/p&gt;
&lt;h2 id="adaptive-thoulessandersonpalmer-mean-field-theory"&gt;Adaptive Thouless&amp;ndash;Anderson&amp;ndash;Palmer mean-field theory&lt;/h2&gt;
&lt;p&gt;One of the approaches physicists have come up with to tackle disordered random systems with pairwise interactions like those in Eq. \eqref{eq:randomising} is
. The TAP equations improve mean-field theory results by adding a so-called &lt;em&gt;Onsager self-correction term&lt;/em&gt; calculated from the couplings&amp;rsquo; distribution.&lt;/p&gt;
&lt;p&gt;
adapted this method to probabilistic modeling to be able to deal with scenarios where the distribution of the couplings between spins is not known a priori. To compensate for the lack of knowledge of the couplings&amp;rsquo; distribution, they introduced a self-consistent computation to adapt the Onsager correction to the &lt;em&gt;actual&lt;/em&gt; couplings using the cavity method and linear response relations. We will sketch the adaptive TAP approach below but refer to
and
for more details and derivations.&lt;/p&gt;
&lt;h3 id="single-site-partition-function-from-cavity-method"&gt;Single-site partition function from cavity method&lt;/h3&gt;
&lt;p&gt;The adaptive TAP equations can be derived using the cavity method, where a cavity field distribution is introduced to rewrite the marginal distributions of the spins. The cavity corresponds to the &amp;ldquo;hole&amp;rdquo; left by removing a single spin. By assuming a Gaussian cavity distribution in the large connectivity limit, one can show that the single-site partition function looks like&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \int \mathrm{d} S \ \rho_{i}\left(S\right) \exp \left[ S \left( a_{i} + x_{i} \right) + \frac{V_{i} S^2}{2} \right]
\end{equation}&lt;p&gt;where the $a_i$ denote &lt;em&gt;cavity means&lt;/em&gt; and the $V_i$ &lt;em&gt;cavity variances&lt;/em&gt;. The single-site partition function can be integrated to yield an explicit expression after choosing well-behaved priors $\rho_{i}(S)$ for the spins. For binary spins $S=\pm 1$, we can pick $\rho_{i}(S)=\frac{1}{2}\left( \delta(S-1) + \delta(S+1) \right)$ to find&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \cosh \left( a_{i} + x_{i} \right). \label{eq:partfunbinaryspins}
\end{equation}&lt;h3 id="cavity-means-and-onsager-correction-term"&gt;Cavity means and Onsager correction term&lt;/h3&gt;
&lt;p&gt;The cavity means can be shown to be given by
&lt;/p&gt;
\begin{equation}
a_{i} = \sum_{j} J_{ij} \langle S_{j} \rangle - V_{i} \langle S_{i} \rangle. \label{eq:cavitymean}
\end{equation}&lt;p&gt;where the last term is the &lt;em&gt;Onsager correction term&lt;/em&gt;, a self-correction term for every spin which depends on the cavity variances.&lt;/p&gt;
&lt;h3 id="cavity-variances-and-linear-response"&gt;Cavity variances and linear response&lt;/h3&gt;
&lt;p&gt;The cavity variances are determined self-consistently, i.e. by calculating the same quantity in two different ways and demanding the obtained expressions to be equal. To do this, we introduce the matrix of susceptibilities&lt;/p&gt;
\begin{equation}
\chi_{ij} = \langle S_{i} S_{j} \rangle - \langle S_{i} \rangle \langle S_{j} \rangle = \frac{\partial^2}{\partial x_{i}\partial x_{j}} \log Z_{0}^{(i)}
\end{equation}
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;The susceptibility matrix $\chi_{ij}$ is a covariance matrix and should thus be positive semi-definite, which is a criterion for the mean-field solution to be consistent. As soon as this property is lost, the fixed-point procedure will no longer be stable.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Its diagonal elements $\chi_{ii}$ can be obtained both from the explicit calculation of the spin variances from the partition function&lt;/p&gt;
\begin{equation}
\chi_{ii} = \langle S_{i}^2 \rangle - \langle S_{i} \rangle^2 = \frac{\partial^2}{\partial x_{i}^2} \log Z_{0}^{(i)} \label{eq:chiii}
\end{equation}&lt;p&gt;but also from a linear response calculation assuming fixed $V_i$,&lt;/p&gt;
\begin{align}
\chi_{ij} = \frac{\partial \langle S_{i} \rangle}{\partial x_{j}} = \frac{\partial \langle S_{i} \rangle}{\partial x_{i}} \left( \delta_{ij} + \sum_{k} \left( J_{ik} - V_{k} \delta_{ik} \right) \chi_{kj} \right) \label{eq:chiijlinrespexp}
\end{align}&lt;p&gt;which can be solved for $\chi_{ij}$ to yield
&lt;/p&gt;
\begin{equation}
\chi_{ij} = \left[ \left( \boldsymbol{\Lambda} - \boldsymbol{J} \right)^{-1} \right]_{ij} \label{eq:chiijlinresp}
\end{equation}&lt;p&gt;
where
&lt;/p&gt;
\begin{align}
\boldsymbol{\Lambda} = \mathrm{diag} \left( \Lambda_1, \ldots, \Lambda_{N} \right),\\\\
\Lambda_i = V_i + \left( \frac{\partial \langle S_{i} \rangle}{\partial x_{i}} \right)^{-1}.
\end{align}&lt;p&gt;The cavity variances $V_i$ are then determined by equating \eqref{eq:chiii} to the diagonal elements of \eqref{eq:chiijlinresp} and solving the following consistency condition for $V_i$
&lt;/p&gt;
\begin{equation}
\frac{1}{\Lambda_i - V_i} = \left[ \left( \boldsymbol{\Lambda} - \boldsymbol{J} \right)^{-1} \right]_{ii}. \label{eq:viselfcons}
\end{equation}&lt;p&gt;Given updated values for the cavity means $a_i$ and the cavity variances $V_i$, spin means and spin variances can then be updated as follows:&lt;/p&gt;
\begin{align}
\langle S_{i} \rangle &amp;= \frac{\partial}{\partial x_{i}} \log Z_{0}^{(i)} (x_{i}, a_{i}, V_{i}),\\\\
\langle S_{i}^2 \rangle - \langle S_{i} \rangle^2 &amp;= \frac{\partial^2}{\partial x_{i}^2} \log Z_{0}^{(i)} (x_{i}, a_{i}, V_{i}).
\end{align}&lt;p&gt;These equations reduce to explicit expressions given an explicit expression for $Z_{0}^{(i)}$. For the binary-spin partition function \eqref{eq:partfunbinaryspins} where $S=\pm 1$, we get a set of fixed-point equations for the spin means that look like&lt;/p&gt;
\begin{equation}
\langle S_{i} \rangle = \tanh \left( \sum_{j} J_{ij} \langle S_{j} \rangle - V_{i} \langle S_{i} \rangle + x_{i} \right)
\end{equation}&lt;p&gt;with spin variances $\chi_{ii} = 1 - \langle S_{i} \rangle^2$.&lt;/p&gt;
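&lt;p&gt;The binary-spin equations above can be put together in a small, dependency-free sketch. It alternates the $\tanh$ update for the spin means with the self-consistency condition Eq. \eqref{eq:viselfcons} for the cavity variances, using $\partial \langle S_{i} \rangle / \partial x_{i} = 1 - \langle S_{i} \rangle^2$. The function names (&lt;code&gt;invert&lt;/code&gt;, &lt;code&gt;adaptive_tap&lt;/code&gt;), the update order, the fixed number of iterations, and the toy couplings and fields are assumptions made for illustration; a tensor library would be the natural choice for anything beyond a handful of spins.&lt;/p&gt;

```python
import math


def invert(matrix):
    """Invert a small dense matrix with Gauss-Jordan elimination (partial pivoting)."""
    n = len(matrix)
    aug = [list(row) + [float(i == j) for j in range(n)] for i, row in enumerate(matrix)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [v / scale for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * p for v, p in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def adaptive_tap(J, x, num_iters=200):
    """Iterate spin means m_i and cavity variances V_i for binary spins."""
    n = len(x)
    m, V = [0.0] * n, [0.0] * n
    for _ in range(num_iters):
        # Spin means: tanh of cavity mean plus external field.
        m = [math.tanh(sum(J[i][j] * m[j] for j in range(n)) - V[i] * m[i] + x[i])
             for i in range(n)]
        # Lambda_i = V_i + (d m_i / d x_i)^{-1}, with d m_i / d x_i = 1 - m_i^2.
        lam = [V[i] + 1.0 / (1.0 - m[i] ** 2) for i in range(n)]
        # Linear-response susceptibility chi = (Lambda - J)^{-1}.
        chi = invert([[(lam[i] if i == j else 0.0) - J[i][j] for j in range(n)]
                      for i in range(n)])
        # Self-consistency: 1 / (Lambda_i - V_i) = chi_ii, solved for V_i.
        V = [lam[i] - 1.0 / chi[i][i] for i in range(n)]
    return m, V


# Hypothetical 3-spin system with small symmetric couplings and no self-couplings.
J = [[0.0, 0.2, -0.1],
     [0.2, 0.0, 0.3],
     [-0.1, 0.3, 0.0]]
x = [0.5, -0.2, 0.1]
m, V = adaptive_tap(J, x)
print(m, V)
```

&lt;p&gt;For weak couplings the converged cavity variances come out close to $\sum_{j} J_{ij}^2 \left( 1 - \langle S_{j} \rangle^2 \right)$, the familiar TAP form of the Onsager correction.&lt;/p&gt;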
&lt;h1 id="attention-as-a-fixed-point-method"&gt;Attention as a fixed-point method&lt;/h1&gt;
&lt;p&gt;In this section, we attempt to generalize the mean-field equations obtained in the previous section to random Ising-like models with vector spin degrees of freedom. We then recognize the physical system as an attention model and provide both a slow, explicit implementation and a faster, neural one.&lt;/p&gt;
&lt;h2 id="generalizing-spin-models-to-vector-degrees-of-freedom"&gt;Generalizing spin models to vector degrees of freedom&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s return to our Ising model cartoon and replace the scalar spin degrees of freedom $S_i$ at every site with vectors $\boldsymbol{S}_i \in \mathbb{R}^d$, which we visualize using arrows below&lt;/p&gt;
&lt;img src="featured.png" alt="Random Ising model configuration with vector spins" width="250px"/&gt;
&lt;p&gt;Let&amp;rsquo;s consider a system of $N$ $d$-dimensional spins and let&amp;rsquo;s label site indices with $i,j,\ldots$ and internal vector-space indices with Greek letters $\alpha,\beta,\ldots$. We let the coupling weight matrix become a tensor $\boldsymbol{J}_{ij} = J_{ij}^{\alpha\beta}$ (matrices coupling every pair of sites) and remove self-couplings by enforcing the couplings&amp;rsquo; block-diagonal to be zero. Additionally, we can symmetrize both the internal dimension and the sites to end up with $N(N-1)/2$ times $d(d+1)/2$ effective free parameters for the couplings. If we also turn the external fields into vectors, we obtain a vector generalization of Eq. \eqref{eq:randomising}:&lt;/p&gt;
\begin{equation}
E = \sum_{i,j} \boldsymbol{S}_{i}^{T} \boldsymbol{J}_{ij} \boldsymbol{S}_{j} + \sum_{i} \boldsymbol{X}_{i} \cdot \boldsymbol{S}_{i}. \label{eq:vectrandomising}
\end{equation}&lt;h2 id="deep-implicit-attention-attention-as-a-collective-response"&gt;Deep implicit attention: attention as a collective response&lt;/h2&gt;
&lt;p&gt;Remember that our goal is to understand attention as the collective response of a statistical-mechanical system. Let&amp;rsquo;s now relate vector models like Eq. \eqref{eq:vectrandomising} to attention models by treating the external magnetic fields $\boldsymbol{X}_{i}$ as input data. Batches of sequences applied to every site act as probes for the system, pushing its behaviour into a certain direction. The system&amp;rsquo;s mean-field average magnetizations $\langle \boldsymbol{S}_{i} \rangle$ are an approximation of the collective response at every site: what is the expected value of this particular vector spin? We interpret solving mean-field equations for $\langle \boldsymbol{S}_{i} \rangle$ in the presence of input injections $\boldsymbol{X}_{i}$ as an attention operation. If the whole system is differentiable, we can tune the couplings $\boldsymbol{J}_{ij}$ in an outer-loop optimization to steer the system&amp;rsquo;s behaviour to better&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; respond to future incoming data.&lt;/p&gt;
&lt;h2 id="slow-and-explicit-solving-the-adaptive-tap-equations"&gt;Slow and explicit: solving the adaptive TAP equations&lt;/h2&gt;
&lt;p&gt;What changes do we have to make to the adaptive TAP mean-field equations to turn them into a vector-based attention module and how can we implement them? Let&amp;rsquo;s explicitly enumerate the objects introduced in
together with their (generalized) tensor shapes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;Iteratively determined fixed-point variables&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Spin means $\langle \boldsymbol{S}_{i} \rangle = \left[ \langle \boldsymbol{S}_{i} \rangle \right]^{\alpha}$ &lt;code&gt;(batch_size, N, d)&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;Cavity variances $\boldsymbol{V}_{i} = V_{i}^{\alpha\beta}$ &lt;code&gt;(N, d, d)&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Other variables calculated during fixed-point iteration&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Cavity means $\boldsymbol{a}_{i} = a_{i}^{\alpha}$ &lt;code&gt;(batch_size, N, d)&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;Spin variances $\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \boldsymbol{\chi}_{ii} = \chi_{ii}^{\alpha\beta}$ &lt;code&gt;(N, d, d)&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For every site, the scalar spin and cavity variances have turned into $d \times d$ (inverse) covariance matrices on the level of the local dimension. Note that the &amp;ldquo;system properties&amp;rdquo; in the above list have no batch size: their values are identical across all examples and capture the properties of the system irrespective of the input injections $\boldsymbol{X}_i$.&lt;/p&gt;
&lt;p&gt;The vector translation of the single-site partition function looks like&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \int \mathrm{d}^{d} \boldsymbol{S} \ \rho_{i}\left(\boldsymbol{S}\right) \exp \left[ \boldsymbol{S} \cdot \left( \boldsymbol{a}_{i} + \boldsymbol{X}_{i} \right) + \frac{1}{2} \boldsymbol{S}^T \boldsymbol{V}_{i} \boldsymbol{S} \right]
\end{equation}&lt;p&gt;where&lt;/p&gt;
\begin{equation}
\boldsymbol{a}_{i} = \sum_{j} \boldsymbol{J}_{ij} \langle \boldsymbol{S}_{j} \rangle - \boldsymbol{V}_{i}\langle \boldsymbol{S}_{i} \rangle. \label{eq:veccavmeans}
\end{equation}&lt;p&gt;Spin means and variances are then computed from&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \frac{\partial}{\partial\boldsymbol{X}_{i}} \log Z_{0}^{(i)} (\boldsymbol{X}_{i}, \boldsymbol{a}_{i}, \boldsymbol{V}_{i})
\end{equation}\begin{equation}
\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \frac{\partial^2}{\partial\boldsymbol{X}_{i}^2} \log Z_{0}^{(i)} (\boldsymbol{X}_{i}, \boldsymbol{a}_{i}, \boldsymbol{V}_{i})
\end{equation}&lt;p&gt;As a spin prior $\rho_{i}\left(\boldsymbol{S}\right)$, we pick a simple diagonal multivariate Gaussian $\mathcal{N} \left( \boldsymbol{\mu} = \boldsymbol{0}_{d}, \boldsymbol{\Sigma}= \boldsymbol{1}_{d \times d} \right)$ at every site, leading to the explicit equations:&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \left( \boldsymbol{\Sigma}^{-1} - \boldsymbol{V}_{i} \right)^{-1} \left( \boldsymbol{a}_{i} + \boldsymbol{X}_{i} \right)
\end{equation}\begin{equation}
\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \left( \boldsymbol{\Sigma}^{-1} - \boldsymbol{V}_{i} \right)^{-1}
\end{equation}&lt;h3 id="generalizing-the-cavity-variance-calculation"&gt;Generalizing the cavity variance calculation&lt;/h3&gt;
&lt;p&gt;The cavity variance computation can be done by generalizing Eqs. \eqref{eq:chiijlinrespexp}&amp;ndash;\eqref{eq:chiijlinresp} and solving the following system of equations for $\boldsymbol{\chi}_{ij}$,&lt;/p&gt;
\begin{equation}
\left( \delta_{ik} \otimes \boldsymbol{1}_{d} - \boldsymbol{\Sigma}_{i} \boldsymbol{J}_{ik} + \boldsymbol{\Sigma}_{i} \boldsymbol{V}_{i} \delta_{ik} \right)\boldsymbol{\chi}_{kj} = \boldsymbol{\Sigma}_{i} \delta_{ij}
\end{equation}&lt;p&gt;The generalization of the self-consistency condition Eq. \eqref{eq:viselfcons} is then obtained by solving $\boldsymbol{\chi}_{ii} \boldsymbol{V}_{i} = \boldsymbol{\chi}_{ii} \boldsymbol{\Lambda}_{i} - \boldsymbol{1}_{N \times d \times d}$ for $\boldsymbol{V}_{i}$, where $\boldsymbol{\Lambda}_{i} = \boldsymbol{V}_{i} + \boldsymbol{\Sigma}^{-1}$ is computed using the current values of $\boldsymbol{V}_{i}$. The price to pay for this added complexity is a computational cost of $O(N^3d^3)$ and an excruciatingly slow backward pass. The algorithm works, but it ain&amp;rsquo;t pretty.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Implementation:&lt;/strong&gt; To avoid &lt;code&gt;torch.solve&lt;/code&gt; crashing on singular matrices during the fixed-point calculation, we found it crucial for stability and learning behaviour to initialize the couplings $J_{ij}^{\alpha\beta} \sim \mathcal{N}(0, \sigma^2)$ with small values $\sigma^2 = 1 / (N*d^2)$ to ensure $|J| \sim \mathcal{O}(1)$. It&amp;rsquo;s also beneficial if the sources satisfy $|\boldsymbol{X}_{i}| \sim \mathcal{O}(1)$ so that terms are balanced in the update step, all together adding up to $\mathcal{O}(1)$.&lt;/p&gt;
&lt;/blockquote&gt;
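&lt;p&gt;One way to read the initialization scale in the note above: with $\sigma^2 = 1 / (N d^2)$, the squared couplings attached to any single site sum to roughly one. The sketch below checks this numerically with the standard library only. The toy sizes, the nested-list layout, and this reading of $|J| \sim \mathcal{O}(1)$ are assumptions, not a description of the reference implementation.&lt;/p&gt;

```python
import random

random.seed(0)
N, d = 16, 8  # hypothetical toy sizes
sigma = (1.0 / (N * d ** 2)) ** 0.5

# J[i][j] is the d x d matrix coupling sites i and j; self-couplings are zeroed.
J = [[[[0.0 if i == j else random.gauss(0.0, sigma) for _ in range(d)]
       for _ in range(d)] for j in range(N)] for i in range(N)]

# Sum of squared couplings attached to each site.
site_norms = [sum(v * v for j in range(N) for row in J[i][j] for v in row)
              for i in range(N)]
print(sum(site_norms) / N)  # close to (N - 1) / N
```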
&lt;h2 id="fast-and-neural-parametrizing-the-onsager-self-correction-term"&gt;Fast and neural: parametrizing the Onsager self-correction term&lt;/h2&gt;
&lt;p&gt;Can we somehow approximate the slow and explicit calculation of the cavity variances? Since $\boldsymbol{z}^{*} = \left( \langle \boldsymbol{S}_{i}^{*} \rangle, \boldsymbol{V}_{i}^{*} \right)$ at the fixed point, the Onsager self-correction term in Eq. \eqref{eq:veccavmeans} converges to a constant vector $\boldsymbol{V}_{i}^{*}\langle \boldsymbol{S}_{i}^{*} \rangle$ for every site. We propose to make a bold move by getting rid of the cavity variables altogether and reducing the equations for the fixed-point update step to&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \sum_{j} \boldsymbol{J}_{ij} \langle \boldsymbol{S}_{j} \rangle - f_{\theta} \left( \langle \boldsymbol{S}_{i} \rangle \right) + \boldsymbol{X}_{i}, \label{eq:diaupdate}
\end{equation}&lt;p&gt;where $f_{\theta}$ is a neural network parametrizing the action of the cavity variances on the spin means. Since the parameters $\theta$ stay fixed during the inner-loop fixed-point calculation, we have effectively lifted the optimization of the self-correction term to the outer-loop, which also optimizes the weights $\boldsymbol{J}_{ij}$.&lt;/p&gt;
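&lt;p&gt;A minimal sketch of iterating Eq. \eqref{eq:diaupdate} to a fixed point, using only the standard library. Several simplifying assumptions are made here: each $\boldsymbol{J}_{ij}$ acts as a scalar times the identity, the learned network $f_{\theta}$ is replaced by a fixed elementwise stand-in, and plain iteration is used in place of a dedicated fixed-point solver. The couplings are chosen small enough for the update to be a contraction.&lt;/p&gt;

```python
import math

# Hypothetical toy system: N = 3 sites, d = 2 internal dimensions.
# Simplifying assumption: every J_ij acts as a scalar times the identity.
J = [[0.0, 0.2, -0.1],
     [0.2, 0.0, 0.3],
     [-0.1, 0.3, 0.0]]
X = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.2]]


def f_theta(s):
    """Stand-in for the learned self-correction network (an assumption)."""
    return [0.1 * math.tanh(v) for v in s]


def update(S):
    """One step of S_i = sum_j J_ij S_j - f_theta(S_i) + X_i."""
    n, d = len(S), len(S[0])
    new_S = []
    for i in range(n):
        mixed = [sum(J[i][j] * S[j][a] for j in range(n)) for a in range(d)]
        corr = f_theta(S[i])
        new_S.append([mixed[a] - corr[a] + X[i][a] for a in range(d)])
    return new_S


S = [[0.0] * 2 for _ in range(3)]
for _ in range(100):  # plain iteration; the couplings are small enough to contract
    S = update(S)

residual = max(abs(u - v) for ru, rv in zip(update(S), S) for u, v in zip(ru, rv))
print(residual)
```

&lt;p&gt;The parameters of the stand-in for $f_{\theta}$ never change inside the loop, mirroring the split between the inner-loop fixed-point calculation and the outer-loop optimization described above.&lt;/p&gt;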
&lt;p&gt;All of this starts to look an awful lot like a transformer module. Before making an explicit comparison with transformers later in this post, let&amp;rsquo;s finish this section with a simple example model.&lt;/p&gt;
&lt;h3 id="simple-example-mnist"&gt;Simple example: MNIST&lt;/h3&gt;
&lt;p&gt;A simple image classification model for MNIST using a convolutional feature extractor and a deep implicit attention layer could look something like&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MNISTNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MNISTNet&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_patch_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Conv2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; 26 x 26&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MaxPool2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; 12 x 12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Conv2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; 10 x 10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MaxPool2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; 4 x 4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s1"&gt;&amp;#39;b c h w -&amp;gt; b (h w) c&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;deq_atn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;DEQFixedPoint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;DEQMeanFieldAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;weight_sym_internal&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;weight_sym_sites&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;lin_response&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;anderson&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;solver_fwd_max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;40&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;solver_fwd_tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;solver_bwd_max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;40&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;solver_bwd_tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_patch_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;cls_tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;cls_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;deq_atn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The ViT-style classification token is interpreted as an additional site in the system, which is probed with a learnable input injection that is shared across examples. The model uses the classification token&amp;rsquo;s output response to do the final classification. The system has to self-organize its behaviour so that the classification token gets all the information it needs.&lt;/p&gt;
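&lt;p&gt;The shape comments in the model definition can be checked with a few lines of arithmetic. The helper &lt;code&gt;out_size&lt;/code&gt; below is a hypothetical name; it assumes unpadded layers with unit dilation and floor rounding, which are the PyTorch defaults for &lt;code&gt;nn.Conv2d&lt;/code&gt; and &lt;code&gt;nn.MaxPool2d&lt;/code&gt;.&lt;/p&gt;

```python
def out_size(size, kernel, stride=1):
    """Output side length of an unpadded convolution or pooling layer."""
    return (size - kernel) // stride + 1

size = 28                                  # MNIST images are 28 x 28
size = out_size(size, kernel=3)            # Conv2d(kernel_size=3)
size = out_size(size, kernel=3, stride=2)  # MaxPool2d(3, stride=2)
size = out_size(size, kernel=3)            # Conv2d(kernel_size=3)
size = out_size(size, kernel=3, stride=2)  # MaxPool2d(3, stride=2)

num_patch_tokens = size * size             # matches the default num_spins=16
num_sites = num_patch_tokens + 1           # plus the classification token
print(size, num_patch_tokens, num_sites)   # 4 16 17
```

&lt;p&gt;This is why the attention layer is constructed with &lt;code&gt;num_spins + 1&lt;/code&gt; sites: sixteen patch embeddings plus the classification token.&lt;/p&gt;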
&lt;img src="vit_mnist.gif" alt="ViT-style model with deep implicit attention layer on MNIST" width="500px"/&gt;
&lt;p&gt;You can train this small model (26k parameters) on MNIST to find a test set accuracy hovering around 99.1%. The animation above shows a graph reflecting the (directed) connection strengths between spins during training as measured by the Frobenius norms of the matrices $\boldsymbol{J}_{ij}$. Almost all major organization of connections is seen to happen in the first few iterations. One imagines the model getting frustrated at zeros which &lt;em&gt;really&lt;/em&gt; look like nines and just flat-out refusing to remember edge cases out of spite.&lt;/p&gt;
&lt;h1 id="a-mean-field-theory-perspective-on-transformers"&gt;A mean-field theory perspective on transformers&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s conclude this post by applying the mean-field theory perspective on attention to the transformer architecture. Schematically, a vanilla transformer module looks like&lt;/p&gt;
&lt;img src="vanilla_transformer_module.png" alt="Vanilla transformer module" width="200px"/&gt;
&lt;p&gt;which consists of an attention module acting on all vectors in the sequence input followed by a feed-forward layer acting &amp;ldquo;locally&amp;rdquo; across individual vectors in the sequence, mixed with some residual connections and layer normalizations.&lt;/p&gt;
&lt;h2 id="parametrizing-the-couplings-sparse-graph-structure-from-inputs"&gt;Parametrizing the couplings: sparse graph structure from inputs&lt;/h2&gt;
&lt;p&gt;Transformers can be interpreted as fully-connected graph neural networks acting on sets of vectors. Inside an attention module, the row-stochastic attention matrix corresponds to a particular parametrization of the couplings&lt;/p&gt;
\begin{equation}
J_{ij} = \left[\mathrm{softmax}\left( \frac{\boldsymbol{X} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{X}^{T}}{\sqrt{d}} \right)\right]_{ij}, \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;which swaps storing explicit coupling weights for parameters of linear query-key transformations. By dynamically determining the connectivity of the sites based on the inputs $\boldsymbol{X}$ according to Eq. \eqref{eq:softmaxcouplings}, the coupling weights are no longer completely free parameters. The introduction of queries and keys can be seen as a neural network approach to &amp;ldquo;amortizing&amp;rdquo; the coupling tensor while the softmax temperature promotes sparsity. Multiple attention heads correspond to imposing a block-diagonal structure in the hidden dimensions of the couplings: the dot product gets cut into disjoint pieces, one for each attention head.&lt;/p&gt;
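&lt;p&gt;A small numerical sketch of the softmax parametrization of the couplings, for a single head and in NumPy. The function names and the toy sizes are illustrative assumptions; the point is only that the resulting matrix is row-stochastic and is computed from the inputs rather than stored.&lt;/p&gt;

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_couplings(X, W_q, W_k):
    """J = softmax(X W_Q W_K^T X^T / sqrt(d)), normalized over the column index j."""
    d = W_q.shape[1]
    return softmax((X @ W_q) @ (X @ W_k).T / np.sqrt(d), axis=-1)

rng = np.random.default_rng(0)
N, d = 5, 4
X = rng.standard_normal((N, d))            # one input vector per site
W_q, W_k = rng.standard_normal((2, d, d))  # query and key projections
J = softmax_couplings(X, W_q, W_k)
print(J.sum(axis=1))                       # every row sums to one
```

&lt;p&gt;The only learnable quantities here are the two projection matrices, whose size is independent of the number of sites, whereas the explicit couplings of the previous sections grow quadratically with it.&lt;/p&gt;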
&lt;h2 id="softmax-attention-does-a-single-naive-mean-field-update-step"&gt;Softmax attention does a single, naive mean-field update step&lt;/h2&gt;
&lt;p&gt;Looking at the update step \eqref{eq:diaupdate} and the softmax couplings \eqref{eq:softmaxcouplings}, we observe that the softmax attention module does a single, naive mean-field update step without a self-correction term. Ignoring layer normalizations, the attention update step for every input vector looks like&lt;/p&gt;
\begin{equation}
\boldsymbol{X}'_{i} = \sum_{j} \left[ \mathrm{softmax} \left( \frac{\boldsymbol{X} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{X}^{T}}{\sqrt{d}} \right) \right]_{ij} \left[ \boldsymbol{X} \boldsymbol{W}_{\boldsymbol{V}} \right]_{j} + \boldsymbol{X}_{i}, \nonumber
\label{eq:vanilla-attention}
\end{equation}&lt;p&gt;where, crucially, the residual connection is responsible for adding the source term to the update step. Without a residual connection, the applied magnetic field is effectively turned off and the signal would only be able to propagate via the coupling term.&lt;/p&gt;
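&lt;p&gt;The role of the residual connection as a source term can be made concrete with a short NumPy sketch. The function &lt;code&gt;attention_update&lt;/code&gt; and its &lt;code&gt;residual&lt;/code&gt; flag are hypothetical names introduced for illustration, and layer normalization is ignored as in the equation above.&lt;/p&gt;

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_update(X, W_q, W_k, W_v, residual=True):
    """Single naive mean-field step: coupling term, optionally plus the source X."""
    d = W_q.shape[1]
    J = softmax((X @ W_q) @ (X @ W_k).T / np.sqrt(d))  # input-dependent couplings
    coupling_term = J @ (X @ W_v)                      # sum_j J_ij [X W_V]_j
    return coupling_term + X if residual else coupling_term

rng = np.random.default_rng(0)
N, d = 5, 4
X = rng.standard_normal((N, d))
W_q, W_k, W_v = rng.standard_normal((3, d, d)) / np.sqrt(d)

with_source = attention_update(X, W_q, W_k, W_v, residual=True)
without_source = attention_update(X, W_q, W_k, W_v, residual=False)
print(np.abs(with_source - without_source - X).max())
```

&lt;p&gt;The difference between the two outputs is exactly the input, which is the sense in which dropping the residual connection switches off the applied field.&lt;/p&gt;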
&lt;h2 id="feed-forward-layer-corrects-naive-mean-field-update"&gt;Feed-forward layer corrects naive mean-field update&lt;/h2&gt;
&lt;p&gt;Looking at the Onsager self-correction term $f_{\theta} \left( \langle \boldsymbol{S}_{i} \rangle \right)$ in the update step \eqref{eq:diaupdate}, we observe that the full transformer attention module emerges when we replace $\langle \boldsymbol{S}_{i} \rangle$ with its naive mean-field value, leading to&lt;/p&gt;
\begin{equation}
\mathrm{Attention}(\boldsymbol{X})_{i} = \boldsymbol{X}'_{i} + \mathrm{FeedForward}\left( \boldsymbol{X}'_{i} \right),
\end{equation}&lt;p&gt;with $\boldsymbol{X}'_{i}$ defined above. Again, the residual connection appears to be crucial for the structure of the mean-field theory equations to match the vanilla transformer module&amp;rsquo;s architecture. As previously discussed, we hypothesize that feed-forward networks in transformer modules &amp;ldquo;amortize&amp;rdquo; the linear response self-corrections.&lt;/p&gt;
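&lt;p&gt;Putting the two pieces together gives a bare-bones transformer module. The NumPy sketch below ignores layer normalization and multiple heads, uses a two-layer ReLU network as a stand-in for the feed-forward layer, and relies on the sign in front of $f_{\theta}$ being absorbed into the learned network. All function names and sizes are illustrative assumptions.&lt;/p&gt;

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_update(X, W_q, W_k, W_v):
    """Naive mean-field step: softmax couplings acting on values, plus the source X."""
    d = W_q.shape[1]
    J = softmax((X @ W_q) @ (X @ W_k).T / np.sqrt(d))
    return J @ (X @ W_v) + X

def feed_forward(X, W_1, W_2):
    """Two-layer ReLU network applied to every site independently."""
    return np.maximum(X @ W_1, 0.0) @ W_2

def transformer_block(X, params):
    X_prime = attention_update(X, *params["attn"])
    # Amortized self-correction evaluated at the naive mean-field value X'.
    return X_prime + feed_forward(X_prime, *params["ffn"])

rng = np.random.default_rng(0)
N, d, d_hidden = 5, 4, 16
X = rng.standard_normal((N, d))
params = {
    "attn": tuple(rng.standard_normal((3, d, d)) / np.sqrt(d)),
    "ffn": (rng.standard_normal((d, d_hidden)) / np.sqrt(d),
            rng.standard_normal((d_hidden, d)) / np.sqrt(d_hidden)),
}
out = transformer_block(X, params)
print(out.shape)
```

&lt;p&gt;Because the correction acts on each site separately and the couplings depend only on pairs of inputs, relabelling the sites simply relabels the outputs, as expected for a system of interacting spins without positional information.&lt;/p&gt;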
&lt;h2 id="mean-field-theory-framework-for-transformer-architectures"&gt;Mean-field theory framework for transformer architectures&lt;/h2&gt;
&lt;p&gt;Within the general mean-field (or
) structure outlined above, there is considerable freedom in parametrizing the interaction and self-correction terms. Most transformer papers parametrize the self-correction terms with a feed-forward layer, i.e. some variation of an MLP. In
the authors went even further and dropped the softmax parametrization of the interaction term to approximate the full action of summing over couplings with an MLP as well. Related papers like
,
, and
can all be considered as explorations of different parametrizations of the mean-field interaction terms. In the large-scale regime, it seems like the softmax attention module can be swapped for just about any function which mixes tokens as long as the structure of residual connections and self-correction terms is preserved.&lt;/p&gt;
&lt;h2 id="comparison-with-energy-based-perspective"&gt;Comparison with energy-based perspective&lt;/h2&gt;
&lt;p&gt;In a previous post on
, we introduced a picture of attention modules in transformers as stacks of energy functions which are defined dynamically at every layer depending on the outputs of the previous layer (so ultimately on the inputs of the first layer). Looking back, this interpretation feels kind of forced and is also unable to explain the presence of skip connections and fully-connected layers surrounding the attention modules. The mean-field perspective seems more interesting since it (1) relies on just one layer (one energy function) whose fixed point (an infinite number of &amp;ldquo;layers&amp;rdquo;) gets calculated, and (2) explains the presence of skip connections (source terms) and fully-connected layers (amortized self-correction terms).&lt;/p&gt;
&lt;h1 id="conclusion-and-outlook"&gt;Conclusion and outlook&lt;/h1&gt;
&lt;p&gt;We have shown how attention can be understood as the mean-field response of Ising-like spin systems being probed by data. By thinking of incoming data as applied magnetic fields and the output of attention modules as spin expectation values, attention can be interpreted as a fixed-point optimization process solving for a compromise between a system&amp;rsquo;s internal dynamics and the data it&amp;rsquo;s being exposed to. Since the whole system is differentiable, we can optimize the interaction weights in an outer loop to nudge the system&amp;rsquo;s behaviour.&lt;/p&gt;
&lt;p&gt;We have also seen how transformers fit into the mean-field theory framework. For scalability, transformers introduce two additional constraints/approximations on top of the mean-field approximation: (1) replacing explicit couplings with parametrized couplings that are dynamically computed from the input via linear transformations (softmax query-key-value attention), and (2) replacing the expensive self-consistent computation of Onsager self-correction terms with a neural network (feed-forward layer).&lt;/p&gt;
&lt;p&gt;Looking ahead, the methods introduced in this post could provide ways to implicitly train mean-field approximations of Boltzmann machines and have them serve as distributed attention modules in larger interconnected systems. To go beyond mean-field approaches, it could be interesting to look at tensor network approaches. Conceptually, the physical interpretation of attention as an interacting many-body system modulating its behaviour by &lt;em&gt;learning to respond to being driven in particular ways&lt;/em&gt; is fun to think about.&lt;/p&gt;
&lt;h1 id="related-work"&gt;Related work&lt;/h1&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;On deep equilibrium models:
(2019) by Shaojie Bai, Zico Kolter, Vladlen Koltun and
of the
by Zico Kolter, David Duvenaud, and Matt Johnson&lt;/li&gt;
&lt;li&gt;On the adaptive Thouless-Anderson-Palmer (TAP) mean-field approach in disorder physics:
(2001) by Manfred Opper and Ole Winther&lt;/li&gt;
&lt;li&gt;On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:
(2014) by Jack Raymond, Andre Manoel, Manfred Opper&lt;/li&gt;
&lt;li&gt;On Boltzmann machines and mean-field theory:
(1998) by H. J. Kappen and
F. B. Rodríguez and
(1998) by Toshiyuki Tanaka&lt;/li&gt;
&lt;li&gt;On approximate message passing (AMP) methods in statistics:
(2021) by Oliver Y. Feng, Ramji Venkataramanan, Cynthia Rush, Richard J. Samworth: the example on page 2 basically describes how transformers implement approximate message passing: an iterative algorithm with a &amp;ldquo;denoising&amp;rdquo; step (attention) followed by a &amp;ldquo;memory term&amp;rdquo; or Onsager correction term (feed-forward layer)&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021deepimplicitattention,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {May},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Whatever &amp;ldquo;better&amp;rdquo; means depends on the system&amp;rsquo;s (meta-)loss function, e.g. predicting corrupted tokens BERT-style or aligning representations to a teacher BYOL/DINO-style.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Attention as Energy Minimization: Visualizing Energy Landscapes</title><link>https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/</link><pubDate>Wed, 17 Mar 2021 22:36:17 +0100</pubDate><guid>https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/</guid><description>&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;📓 Colab notebook available
. Comments welcome.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Recent work &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; has shown that the softmax-attention update step in transformer models can be interpreted as a one-step gradient update or &amp;ldquo;inference&amp;rdquo; step of a judiciously chosen energy function. An overview of these ideas can be found in previous blog posts:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The goal of this educational blog post is to explicitly show how vanilla softmax attention is related to energy minimization approaches and how the former can be substituted for the latter. For pedagogical purposes, we will focus purely on the attention operation. However, for transformer models to perform well in practice, it is necessary to wrap attention in residual connections and point-wise feedforward processing layers, see e.g.
.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Summary:&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;We provide a pedagogical energy-based attention module that stays as close as possible to vanilla softmax attention for ease of comparison.&lt;/li&gt;
&lt;li&gt;We walk through the correspondence between modern Hopfield networks and vanilla softmax attention by gradually adding complexity.&lt;/li&gt;
&lt;li&gt;We present visualizations of energy landscapes and trajectories associated to attention update steps for two-dimensional toy patterns.&lt;/li&gt;
&lt;/ol&gt;
&lt;h1 id="prelude-pattern-terminology"&gt;Prelude: pattern terminology&lt;/h1&gt;
&lt;p&gt;Transformer literature almost exclusively talks about queries, keys, and values. For self-attention, these are all obtained from different linear transformations acting on the same set of &lt;em&gt;input patterns&lt;/em&gt;. For cross-attention, only the queries derive from the &lt;em&gt;input patterns&lt;/em&gt;; the keys and values are obtained from a different set of &lt;em&gt;context patterns&lt;/em&gt;: think of a decoder architecture attending to encoded translations or the
model attending to multimodal input.&lt;/p&gt;
&lt;p&gt;Hopfield networks literature starts from the idea of trying to implement an associative memory system for storing and retrieving patterns. Patterns stored in memory are called &lt;em&gt;stored patterns&lt;/em&gt;. A &lt;em&gt;state pattern&lt;/em&gt; is an input prompt for the associative memory system: what patterns stored in memory are closest to this particular prompt?&lt;/p&gt;
&lt;p&gt;Depending on the context (heh), we can refer to input patterns as state patterns or queries and to context patterns as stored patterns or memory or keys.&lt;/p&gt;
&lt;h1 id="attention-modules"&gt;Attention modules&lt;/h1&gt;
&lt;h2 id="explicit-vanilla-softmax-attention"&gt;Explicit vanilla softmax attention&lt;/h2&gt;
&lt;p&gt;To compare the behavior of explicit attention modules to that of energy-based attention modules, we need to first of all define a vanilla softmax attention module. The annotated implementation below features a &lt;code&gt;bare_attn&lt;/code&gt; toggle in the forward pass for ease of comparison with the &amp;ldquo;bare&amp;rdquo; modern continuous Hopfield energy function we will discuss later on. The flag essentially disables all linear mappings so input and context patterns are processed &amp;ldquo;raw&amp;rdquo;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Vanilla softmax attention.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Adapted from https://github.com/lucidrains/perceiver-pytorch (commit 37e2eb6).
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Inner dimension is expressed in terms of head count and dimensionality&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# and thus decoupled from query_dim/context_dim (heads always &amp;#34;fit&amp;#34;).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Linear transformations (queries, keys, values, head-mixing).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# To facilitate comparison with modern Hopfield networks, setting `bare_attn`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# to `True` disables all linear mappings, asserts there&amp;#39;s only a single head and&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# reduces the module to bare-bones attention which takes in &amp;#34;raw&amp;#34; queries or state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# patterns and attends to a &amp;#34;raw&amp;#34; context/memory of stored patterns.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;only a single head when bare attention&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;query_dim/context_dim must match&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Adaptive scale.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Take context either from elsewhere or from self (attention vs. self-attention).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Map x to queries and context to keys and values.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Split up latent dimension into subspaces for heads to act on.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Head dimension becomes part of batch dimension (=&amp;gt; parallel processing of heads).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b n (h d) -&amp;gt; (b h) n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Scaled dot product of all queries against all keys (sum over per-head dimension `d`).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i d, b j d -&amp;gt; b i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Optional masking.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b j -&amp;gt; (b h) () j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Softmax operation across &amp;#34;keys&amp;#34; sequence dimension.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Contract attention matrix with values.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i j, b j d -&amp;gt; b i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Move head dimension out of batch again.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;(b h) n d -&amp;gt; b n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Mix all the heads&amp;#39; outputs; stir well and serve immediately.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="implicit-energy-based-attention"&gt;Implicit energy-based attention&lt;/h2&gt;
&lt;p&gt;Next, we define our energy-based attention module. Its forward pass uses the simple gradient descent function defined below to minimize the energy and update the queries accordingly.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;minimize_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Minimize energy function with respect to queries.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Keeps track of energies and trajectories for logging and plotting.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;defaultdict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;grad_queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;autograd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;grad_outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;create_graph&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# enables double backprop for optimization&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;step_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;grad_queries&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;energies&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;energies&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The &lt;code&gt;EnergyBasedAttention&lt;/code&gt; module below has been structured to look as similar as possible to the &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt; module defined above. The main difference is the appearance of an energy function and the energy minimization call in the forward pass where the softmax attention used to be. Other differences include the absence of a linear map to &amp;ldquo;values&amp;rdquo; and masking being pushed into the energy function.&lt;/p&gt;
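&lt;p&gt;To see why a descent step on an energy can stand in for softmax attention, here is a minimal dependency-free sketch of a single update for one query. It assumes a log-sum-exp energy of the modern Hopfield type, which may differ from the energy function used by the module, and it replaces &lt;code&gt;torch.autograd.grad&lt;/code&gt; with the analytic gradient. The names &lt;code&gt;hopfield_energy&lt;/code&gt;, &lt;code&gt;energy_grad&lt;/code&gt;, &lt;code&gt;gradient_step&lt;/code&gt; and &lt;code&gt;beta&lt;/code&gt; are hypothetical and only serve this illustration.&lt;/p&gt;

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_scores(query, keys, beta=1.0):
    """Scaled dot products of one query against every key."""
    return [beta * sum(q * k for q, k in zip(query, key)) for key in keys]


def hopfield_energy(query, keys, beta=1.0):
    """Assumed energy: 0.5 * |q|^2 - (1 / beta) * logsumexp(beta * K q)."""
    scores = dot_scores(query, keys, beta)
    top = max(scores)
    lse = top + math.log(sum(math.exp(s - top) for s in scores))
    return 0.5 * sum(q * q for q in query) - lse / beta


def energy_grad(query, keys, beta=1.0):
    """Analytic gradient of the assumed energy: q - K^T softmax(beta * K q)."""
    weights = softmax(dot_scores(query, keys, beta))
    readout = [
        sum(w * key[d] for w, key in zip(weights, keys))
        for d in range(len(query))
    ]
    return [q - r for q, r in zip(query, readout)]


def gradient_step(query, keys, step_size=1.0, beta=1.0):
    """One descent step, mirroring `queries - step_size * grad_queries`."""
    grad = energy_grad(query, keys, beta)
    return [q - step_size * g for q, g in zip(query, grad)]


# Toy memory of three stored patterns and one query (illustrative values).
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
query = [0.8, 0.1]

# Full step: the `q` terms cancel and only the softmax readout remains.
updated = gradient_step(query, keys, step_size=1.0)

# Softmax attention over the same keys, with the keys doubling as values.
weights = softmax(dot_scores(query, keys))
attention_readout = [
    sum(w * key[d] for w, key in zip(weights, keys)) for d in range(2)
]

print(updated)
print(attention_readout)
```

&lt;p&gt;Under these assumptions, a single step with &lt;code&gt;step_size=1.0&lt;/code&gt; reproduces the softmax-weighted average of the keys, which is consistent with the absence of a separate linear map to values. Smaller step sizes or several steps interpolate between the original query and that readout.&lt;/p&gt;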
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Linear transformations (queries, keys, output).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Bare checks.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;only a single head when bare attention&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;query_dim/context_dim must match&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b n (h d) -&amp;gt; (b h) n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b j -&amp;gt; (b h) () j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Minimize energy with respect to queries.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;minimize_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;(b h) n d -&amp;gt; b n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="from-modern-hopfield-networks-to-multi-head-attention"&gt;From modern Hopfield networks to multi-head attention&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s start with the simplest possible case: bare attention. We disable all linear mappings (queries, keys, values, output) so that input and context patterns are processed &amp;ldquo;raw&amp;rdquo;, and we restrict ourselves to a single attention head. We then numerically verify that a &amp;ldquo;bare&amp;rdquo; explicit attention module returns the same result as a single, big step of energy minimization with respect to the input state patterns. Put differently and more to the point, we merely show that automatic differentiation works.&lt;/p&gt;
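&lt;p&gt;As a rough illustration of what &amp;ldquo;bare&amp;rdquo; means here, the sketch below writes single-head attention without any learned projections. The helper name &lt;code&gt;bare_softmax_attention&lt;/code&gt; and the tensor sizes are assumptions made for this example only, and it assumes that the raw context patterns serve as both keys and values; it is not the module used in the rest of this post.&lt;/p&gt;

```python
import torch


def bare_softmax_attention(queries, context, scale):
    """Single-head attention with no learned projections (hypothetical helper).

    queries: (batch, n_queries, dim); context: (batch, n_context, dim).
    The raw context patterns act as both keys and values.
    """
    scores = scale * torch.einsum("bid,bjd->bij", queries, context)
    weights = scores.softmax(dim=-1)  # each row of weights sums to one
    return torch.einsum("bij,bjd->bid", weights, context)


torch.manual_seed(0)
dim = 16
queries = torch.randn(1, 4, dim)
context = torch.randn(1, 6, dim)
out = bare_softmax_attention(queries, context, scale=dim**-0.5)
print(out.shape)  # torch.Size([1, 4, 16])
```

&lt;p&gt;Every output row is a convex combination of the context patterns, which is the form a retrieval step of an associative memory would be expected to take.&lt;/p&gt;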
&lt;h2 id="energy-function"&gt;Energy function&lt;/h2&gt;
&lt;p&gt;Consider the energy function of a modern continuous Hopfield network for a set of state patterns $\boldsymbol{\Xi}$ and stored patterns $\boldsymbol{X}$:&lt;/p&gt;
\begin{equation}
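E(\boldsymbol{\Xi}; \boldsymbol{X}) = \frac{1}{2} \boldsymbol{\Xi}^T \boldsymbol{\Xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\Xi} \right).\label{eq:energy}
\end{equation}&lt;p&gt;Think of this model as the scoring function of an associative memory system. For now, we&amp;rsquo;d like to keep the stored patterns fixed as memory slots and wiggle around the state patterns. We can translate this energy function into the following (batched) function, which additionally takes an inverse temperature &lt;code&gt;scale&lt;/code&gt;: the dot products are multiplied by &lt;code&gt;scale&lt;/code&gt; and the log-sum-exp term is divided by it, so the equation above is the case &lt;code&gt;scale=1&lt;/code&gt;:&lt;/p&gt;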
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# Quadratic term: half the squared norm of every state pattern, shape (b, i).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;kinetic&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i d, b i d -&amp;gt; b i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# Scaled similarities between state patterns (i) and stored patterns (j).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;scaled_dot_product&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="s2"&gt;&amp;#34;b i d, b j d -&amp;gt; b i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stored_patterns&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# Masked-out stored patterns get a very negative score, so they drop out of logsumexp.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;max_neg_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# Log-sum-exp over stored patterns, rescaled by the inverse temperature.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;potential&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logsumexp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;kinetic&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;potential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="verifying-the-update-rule"&gt;Verifying the update rule&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s sample some state patterns and stored patterns, and enable gradient tracking for the state patterns since we want to take derivatives with respect to them later on.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="cross-attention"&gt;Cross-attention&lt;/h3&gt;
&lt;p&gt;First up is cross-attention. We feed state patterns as input and stored patterns as context into a vanilla softmax attention module.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now we do the same for an energy-based attention module and tell it to take a single, big gradient update step.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now let&amp;rsquo;s compare the outputs of the two methods:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_bare_softmax_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;True
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Both tensors are approximately equal: bare softmax attention corresponds to taking a single gradient step of &lt;code&gt;step_size=1.0&lt;/code&gt; with respect to the state patterns using the energy function of modern Hopfield networks as a loss. For more details on this correspondence, we refer to
.&lt;/p&gt;
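&lt;p&gt;As a dependency-free sanity check of this correspondence, here is a minimal sketch in plain Python rather than PyTorch. It assumes the energy has the standard modern Hopfield form, $E(\boldsymbol{q}) = \frac{1}{2} \boldsymbol{q}^T \boldsymbol{q} - \beta^{-1} \log \sum_j \exp(\beta\, \boldsymbol{q}^T \boldsymbol{k}_j)$, which may differ in constants from &lt;code&gt;hopfield_energy&lt;/code&gt;. All helper names and toy numbers are made up for illustration, and central finite differences stand in for autograd.&lt;/p&gt;

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def energy(q, keys, scale):
    # Assumed form: 0.5 * |q|^2 - (1 / scale) * logsumexp(scale * q.k_j)
    logits = [scale * dot(q, k) for k in keys]
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return 0.5 * dot(q, q) - lse / scale


def numerical_grad(q, keys, scale, eps=1e-5):
    # Central finite differences (a stand-in for autograd in this sketch).
    grad = []
    for i in range(len(q)):
        hi = q[:i] + [q[i] + eps] + q[i + 1:]
        lo = q[:i] + [q[i] - eps] + q[i + 1:]
        grad.append((energy(hi, keys, scale) - energy(lo, keys, scale)) / (2 * eps))
    return grad


def gradient_step(q, keys, scale, step_size=1.0):
    g = numerical_grad(q, keys, scale)
    return [qi - step_size * gi for qi, gi in zip(q, g)]


def bare_softmax_attention(q, keys, scale):
    # Softmax-weighted average of the keys, with no learned projections.
    weights = softmax([scale * dot(q, k) for k in keys])
    return [sum(w * k[i] for w, k in zip(weights, keys)) for i in range(len(q))]


query = [0.3, -1.2, 0.7]
keys = [[1.0, 0.0, -0.5], [0.2, 0.9, 0.4], [-0.7, 0.3, 1.1], [0.5, -0.8, 0.0]]
scale = 3 ** -0.5

stepped = gradient_step(query, keys, scale)
attended = bare_softmax_attention(query, keys, scale)
max_gap = max(abs(a - b) for a, b in zip(stepped, attended))
print(max_gap)
```

&lt;p&gt;The printed gap should sit at the level of finite-difference error, because the gradient of this energy is the query minus its softmax-weighted average of the keys.&lt;/p&gt;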
&lt;h3 id="self-attention"&gt;Self-attention&lt;/h3&gt;
&lt;p&gt;Let&amp;rsquo;s do the same check for self-attention, which boils down to only inputting state patterns. Internally, the modules will consider the state patterns as stored patterns and effectively make the patterns pay attention to themselves.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_softmax_self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_bare_softmax_self_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Norm between input state patterns and energy-minimized patterns: &amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;True
Norm between input state patterns and energy-minimized patterns: 5.553587470785715e-06
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The pattern update step looks almost like an identity operation, which is to be expected for &amp;ldquo;bare&amp;rdquo; self-attention&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;. Without any linear transformations to map state patterns to queries and keys, every state pattern starts off already close to a local minimum since it coincides with itself as a stored pattern. In other words, each query starts off on top of its own key because the query-key mappings are identities. We will visualize this behavior in
for two-dimensional patterns.&lt;/p&gt;
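&lt;p&gt;A rough way to see why the update is nearly an identity: for standard-normal patterns in $d$ dimensions, the scaled logit of a pattern with itself is about $\sqrt{d}$, while the logits with other patterns are of order one, so the softmax puts almost all of its weight on the diagonal. The sketch below checks this in plain Python. The dimension, token count, and seed are arbitrary choices for illustration and do not reproduce the numbers printed above.&lt;/p&gt;

```python
import math
import random

random.seed(0)
dim, num_tokens = 256, 8  # illustrative sizes, not the ones used in the post
scale = dim ** -0.5
patterns = [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_tokens)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


# Bare self-attention: queries and keys are the very same patterns.
weights = [softmax([scale * dot(q, k) for k in patterns]) for q in patterns]
updated = [
    [sum(w * k[i] for w, k in zip(row, patterns)) for i in range(dim)]
    for row in weights
]

# Smallest attention weight that any pattern assigns to itself.
min_self_weight = min(weights[i][i] for i in range(num_tokens))

# Overall displacement of the patterns relative to their overall norm.
displacement = math.sqrt(sum(
    (u - p) ** 2 for urow, prow in zip(updated, patterns) for u, p in zip(urow, prow)
))
total_norm = math.sqrt(sum(p * p for row in patterns for p in row))
relative_shift = displacement / total_norm
print(min_self_weight, relative_shift)
```

&lt;p&gt;Once learned query and key projections are switched on, this diagonal dominance is no longer guaranteed, which is what makes the update non-trivial.&lt;/p&gt;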
&lt;h2 id="adding-queries-keys-and-values"&gt;Adding queries, keys, and values&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s now move closer to proper vanilla softmax attention by enabling linear transformations which map state patterns to queries and stored patterns to keys (and values). These parameters are able to move patterns around on the energy landscape before (queries, keys) and after (values) paying attention.&lt;/p&gt;
&lt;p&gt;We recycle the previously instantiated patterns and modules and compare outputs again, making sure the parameters are equal and omitting the &lt;code&gt;bare_attn&lt;/code&gt; flag:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;strict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_energy_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;False
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Why don&amp;rsquo;t the outputs match? We have to make sure we compare apples to apples and be mindful of the fact that the energy minimization step only knows about keys. Indeed, as shown previously in
, the one-step energy minimization, expressed in terms of queries and keys, effectively implements&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}^{\text{new}} = \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{K}
\end{equation}&lt;p&gt;instead of the vanilla softmax attention step&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}^{\text{new}} = \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{V}
\end{equation}&lt;p&gt;We can approximately undo this mapping to make a forced comparison for fixed parameters:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_energy_attn_transformed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_energy_attn&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pinverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;weight&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;output_energy_attn_transformed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;tensor(0.0005, grad_fn=&amp;lt;CopyBackwards&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Yet since all these parameters would be optimized in a real-world scenario, we should only care about whether the representational power of the modules is similar. To make the single-head energy-based attention module more expressive, we can always add an output layer, parametrized by weights $W_{O}$, to the module. As long as the composition of linear transformations $W_{K}W_{O}$ doesn&amp;rsquo;t collapse and its rank does not fall below that of the softmax attention&amp;rsquo;s $W_{V}$, things should be okay.&lt;/p&gt;
&lt;h2 id="adding-masking-and-multiple-attention-heads"&gt;Adding masking and multiple attention heads&lt;/h2&gt;
&lt;p&gt;Finally, let us tie up some loose ends and complete the correspondence between vanilla softmax attention and energy-based minimization.&lt;/p&gt;
&lt;h3 id="masking"&gt;Masking&lt;/h3&gt;
&lt;p&gt;Since masking boils down to putting restrictions on what patterns in the inputs are allowed to talk to each other, it can just as well be done at the level of the energy function. By filling the tensor inside the &lt;code&gt;logsumexp&lt;/code&gt; operator in &lt;code&gt;hopfield_energy&lt;/code&gt; with $-\infty$ values at to-be-masked-out positions, we get the same effect as the masking operation in the forward pass of &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt;. Boolean masks can be passed to the &lt;code&gt;EnergyBasedAttention&lt;/code&gt;&amp;rsquo;s forward function and propagate to the energy function.&lt;/p&gt;
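&lt;p&gt;The effect of the $-\infty$ fill can be illustrated on a single row of logits. The following is a minimal sketch in plain Python, not the code of &lt;code&gt;hopfield_energy&lt;/code&gt; itself. The function name is hypothetical, and the convention that &lt;code&gt;True&lt;/code&gt; means &amp;ldquo;keep&amp;rdquo; is an assumption. Since the gradient of &lt;code&gt;logsumexp&lt;/code&gt; with respect to each logit is the corresponding softmax weight, masked positions receive exactly zero weight.&lt;/p&gt;

```python
import math

NEG_INF = float("-inf")


def masked_logsumexp(logits, mask):
    # mask[j] is True where position j may be attended to (assumed convention).
    # Assumes at least one position is kept, otherwise the maximum is -inf.
    filled = [l if keep else NEG_INF for l, keep in zip(logits, mask)]
    m = max(filled)
    lse = m + math.log(sum(math.exp(x - m) for x in filled))
    return lse, filled


logits = [0.5, 2.0, -1.0, 1.5]
mask = [True, False, True, True]

lse, filled = masked_logsumexp(logits, mask)

# d(logsumexp)/d(logit_j) = exp(logit_j - lse), i.e. the softmax weight.
weights = [math.exp(x - lse) for x in filled]
print(weights)
```

&lt;p&gt;The remaining weights renormalize over the unmasked positions, just as they would with a masked softmax in the forward pass.&lt;/p&gt;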
&lt;h3 id="multi-head-attention"&gt;Multi-head attention&lt;/h3&gt;
&lt;p&gt;Up to now, we have only considered a single attention head. Essentially, multiple attention heads subdivide the latent space into equal parts and process these subproblems in parallel. The head dimension becomes part of the batch dimension. This translates to having parallel energy minimizations going on for different heads, each acting on its own subspace. Since our &lt;code&gt;hopfield_energy&lt;/code&gt; function is already batched, we can use the same machinery from the previous sections, as shown below.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mha_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mha_energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;tensor([[[-0.0514, -0.0353, 0.0243, ..., -0.0335, -0.0060, 0.0243],
[-0.1004, -0.0136, -0.0297, ..., 0.0079, 0.0083, 0.0336],
[-0.0507, -0.0369, -0.0219, ..., -0.0022, -0.0246, -0.0223],
...,
[-0.0388, -0.0217, -0.0470, ..., -0.0067, 0.0020, -0.0139],
[-0.0283, -0.0699, -0.0205, ..., -0.0261, -0.0667, 0.0052],
[-0.0262, -0.0360, -0.0139, ..., -0.0011, -0.0199, -0.0004]]],
grad_fn=&amp;lt;AddBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;It is hard to compare with the exact output of the equivalent &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt; module for fixed module parameters. For multi-head attention, the updated queries coming out of the separate energy minimization steps will have summed over each head&amp;rsquo;s keys instead of its values. For a single attention head we could approximately undo the keys&amp;rsquo; transformation by acting with the pseudo-inverse of the keys&amp;rsquo; weights. For multiple attention heads, that is no longer possible.&lt;/p&gt;
&lt;p&gt;Again, since all these parameters would be optimized in a real-world scenario, we should only care about whether the representational power of the modules is similar. One approach would be to add parameters inside the energy function that take care of mapping to &amp;ldquo;values&amp;rdquo; on the level of the heads.&lt;/p&gt;
&lt;h1 id="attention-in-flatland-visualizing-energy-landscapes"&gt;Attention in flatland: visualizing energy landscapes&lt;/h1&gt;
&lt;p&gt;We now leave the world of high-dimensional latent spaces behind us and focus on the toy model scenario of just two latent space dimensions. We only consider a single attention head because having just two heads, each with dimension one, is just silly. For every two-dimensional token pattern vector, a third dimension will be provided by the value of the scalar energy function at that point.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s sample some tiny toy patterns to play around with.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h4 id="bare-cross-attention"&gt;Bare cross-attention&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s plot our tiny toy patterns taking a big gradient step!&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-cross-attention"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_16e739ac33b2bfa8.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_89bbd1b398e61db4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_c884d6bda828a1f1.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_16e739ac33b2bfa8.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare cross-attention
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;In the figure above, the blue open circles correspond to the stored patterns (memory, context, keys, &amp;hellip;), the red circles denote the initial state patterns (inputs, queries, probes, &amp;hellip;), and the red crosses mark the updated queries obtained after &lt;code&gt;num_steps&lt;/code&gt; steps of energy minimization. The red arrows denote the trajectory in the energy landscape.&lt;/p&gt;
&lt;p&gt;We will now illustrate some example scenarios.&lt;/p&gt;
&lt;h4 id="small-steps-go-nowhere"&gt;Small steps go nowhere&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-small-steps"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_913d624b8622617a.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_ac2668d10a8d96ce.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_20adc42ede72878c.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_913d624b8622617a.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Small steps
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
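&lt;p&gt;One way to read the figure above, assuming the energy has the standard modern Hopfield form with gradient $\nabla_{\boldsymbol{Q}} E = \boldsymbol{Q} - \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{K}$: a gradient step of size $\eta$ interpolates between the current queries and their full attention update,&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}^{\text{new}} = \boldsymbol{Q} - \eta \nabla_{\boldsymbol{Q}} E = (1 - \eta)\, \boldsymbol{Q} + \eta\, \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{K}
\end{equation}&lt;p&gt;Under that assumption, $\eta = 0.1$ moves every pattern only one tenth of the way towards its attention output, which would explain the short arrows, while $\eta = 1$ recovers the full softmax attention step.&lt;/p&gt;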
&lt;h4 id="lots-of-big-steps-converge-near-global-minimum-or-repeated-softmax-iterations-make-all-token-representations-identical"&gt;Lots of (big) steps converge near (global) minimum or repeated softmax iterations make all token representations identical&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-big-steps"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Energy landscape for two-dimensional toy patterns after many big update steps"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_d991410d3d62e314.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_4174b72e3e8a46a4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_68a97e48bda407ba.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_d991410d3d62e314.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Big steps
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
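&lt;p&gt;As a rough, self-contained illustration of the collapse in the heading above, the sketch below repeats a full softmax update against fixed stored patterns. It does not use the notebook helpers: the corner patterns, the random queries, and the deliberately small scale of 0.1 (chosen so that the update map is a contraction for these patterns) are all assumptions made for this example.&lt;/p&gt;

```python
import numpy as np


def softmax(logits, axis=-1):
    # Numerically stable softmax along the given axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)


def attention_step(queries, keys, scale):
    # One full update (step size 1): every query is replaced by a
    # softmax-weighted average of the stored patterns.
    return softmax(scale * queries @ keys.T) @ keys


def spread(points):
    # Largest coordinate-wise extent of a set of points.
    return float((points.max(axis=0) - points.min(axis=0)).max())


# Hypothetical toy data (not the notebook's): four stored patterns on the
# corners of a square and eight random two-dimensional queries.
rng = np.random.default_rng(0)
keys = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
queries = rng.normal(size=(8, 2))

scale = 0.1  # assumption: small enough to make the update a contraction
initial_spread = spread(queries)
for _ in range(10):
    queries = attention_step(queries, keys, scale)
final_spread = spread(queries)

print("spread before:", initial_spread)
print("spread after: ", final_spread)
```

&lt;p&gt;For these symmetric patterns the unique fixed point is the origin, so all queries end up on top of each other there, regardless of where they started.&lt;/p&gt;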
&lt;h4 id="decreasing-the-scale-increasing-the-temperature-makes-the-landscape-smoother-and-encourages-convergence-to-same-global-minimum"&gt;Decreasing the scale (increasing the temperature) makes the landscape smoother and encourages convergence to the same (global) minimum&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-decrease-scale"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Energy landscape for two-dimensional toy patterns at decreased scale (higher temperature)"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_ed4cc3ea5763a50b.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_b9ea06f031681b75.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_65e05b2705206111.webp 574w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_ed4cc3ea5763a50b.webp"
width="574"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Decrease scale
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="increasing-the-scale-lowering-the-temperature-creates-disconnected-valleys-in-the-energy-landscape-inhabited-by-stored-patterns-which-act-as-attractors-for-any-query-that-happens-to-be-in-its-basin-of-attraction"&gt;Increasing the scale (lowering the temperature) creates &amp;ldquo;disconnected&amp;rdquo; valleys in the energy landscape, each inhabited by a stored pattern that acts as an attractor for any query that happens to be in its basin of attraction&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-increase-scale"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Energy landscape for two-dimensional toy patterns at increased scale (lower temperature)"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_11c5d579723c22a8.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_b9fdc7b06e7f4bd4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_36776070bfc9a5b8.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_11c5d579723c22a8.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Increase scale
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
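&lt;p&gt;The effect of the scale in the last two experiments can also be read off the attention weights directly. The following is a minimal sketch, assuming four hypothetical stored patterns and a single query (none of which come from the notebook): it multiplies the default scale by the same factors 0.1, 1 and 10 used above and prints the weights, their entropy, and the retrieved point.&lt;/p&gt;

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over a one-dimensional array.
    weights = np.exp(logits - logits.max())
    return weights / weights.sum()


def entropy(weights):
    # Shannon entropy in nats; the small constant guards against log(0).
    return float(-(weights * np.log(weights + 1e-12)).sum())


# Hypothetical stored patterns and a query that lies closest to the first one.
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
query = np.array([0.9, 0.2])

base_scale = 2 ** -0.5  # default scale for two-dimensional patterns
results = {}
for factor in (0.1, 1.0, 10.0):
    weights = softmax(factor * base_scale * keys @ query)
    retrieved = weights @ keys
    results[factor] = (weights, entropy(weights), retrieved)
    print(f"factor {factor}: weights {np.round(weights, 3)}, "
          f"entropy {entropy(weights):.3f}, retrieved {np.round(retrieved, 3)}")
```

&lt;p&gt;At low scale the weights are close to uniform, so every query is pulled towards the mean of the stored patterns. At high scale the weights are close to one-hot, so the update effectively retrieves the nearest stored pattern.&lt;/p&gt;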
&lt;h4 id="adding-linear-query-key-value-transformations"&gt;Adding linear query-key-value transformations&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# As commented on before, the value transformation is applied&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# after the update step so that effectively the product&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# W_K x W_V is applied to the updated state patterns.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;values_post_processing_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_grid_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-adding-query-key-value-mappings"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Energy landscape for two-dimensional toy patterns with linear query-key-value transformations"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_7f82d31be0acfa17.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_1c2fe38a2f7233eb.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_39d9fe4f4110b9e2.webp 587w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_7f82d31be0acfa17.webp"
width="587"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Adding query-key-value mappings
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The yellow arrows point from the final, energy-minimized query updates to the &amp;ldquo;value-transformed&amp;rdquo; output queries, which are marked with yellow crosses. Rerunning this cell in the Colab notebook gives different landscapes and trajectories every time, since the queries and keys depend on the randomly initialized linear layers. The differences are more pronounced when increasing the scale (lowering the temperature).&lt;/p&gt;
&lt;p&gt;Since the value transformation is applied after the energy minimization, it can and does undo some of the influence of the keys&amp;rsquo; attractors, e.g. by sending updated queries to &amp;ldquo;uphill&amp;rdquo; regions of the energy landscape defined at that layer. This suggests that the value transformation should not be seen as part of the core attention mechanism, but that its role is rather to learn during training how to best hop to different regions in preparation for whatever the next layer needs.&lt;/p&gt;
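&lt;p&gt;To make the &amp;ldquo;uphill&amp;rdquo; remark concrete, here is a small numerical sketch. It assumes a log-sum-exp energy of the modern Hopfield form (the notebook's own energy helper may differ in constants or conventions), hypothetical stored patterns, and a hand-picked value matrix that rotates by 90 degrees and doubles the norm. It compares the energy before the update, after the update, and after the value transformation.&lt;/p&gt;

```python
import numpy as np


def lse_energy(query, keys, scale):
    # Assumed energy: E(q) = -(1/scale) * log(sum_i exp(scale * k_i . q))
    #                        + 0.5 * |q|^2, dropping query-independent constants.
    logits = scale * keys @ query
    log_sum_exp = logits.max() + np.log(np.exp(logits - logits.max()).sum())
    return float(-log_sum_exp / scale + 0.5 * query @ query)


def attention_update(query, keys, scale):
    # One full softmax update of a single query against the stored patterns.
    logits = scale * keys @ query
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return weights @ keys


# Hypothetical toy setup (not taken from the notebook).
keys = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
query = np.array([0.8, 0.1])
scale = 4.0
value_matrix = np.array([[0.0, -2.0], [2.0, 0.0]])  # rotate 90 degrees, scale by 2

updated = attention_update(query, keys, scale)
transformed = value_matrix @ updated

e_before = lse_energy(query, keys, scale)
e_update = lse_energy(updated, keys, scale)
e_value = lse_energy(transformed, keys, scale)

print("energy before update:         ", round(e_before, 4))
print("energy after update:          ", round(e_update, 4))
print("energy after value transform: ", round(e_value, 4))
```

&lt;p&gt;In this setup the update lowers the energy, while the value transformation moves the result to a point of higher energy in the same landscape, which is the behaviour described above.&lt;/p&gt;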
&lt;h4 id="bare-self-attention-on-the-importance-of-scale-and-why-multiple-heads"&gt;Bare self-attention: on the importance of scale and why multiple heads&lt;/h4&gt;
&lt;p&gt;Since all of the flatland examples so far have been for cross-attention, let&amp;rsquo;s also visualize a self-attention update below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-self-attention-visualization"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Energy landscape for two-dimensional toy patterns under a bare self-attention update"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_78f0dd4fca34764a.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_28e0072d72ffd56f.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_df1d0c6c073c1e65.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_78f0dd4fca34764a.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare self-attention visualization
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Wait, what? Why did the updated state patterns move from their initialization? Didn&amp;rsquo;t we see before that the norm between inputs and outputs hardly changed at all for bare self-attention?&lt;/p&gt;
&lt;p&gt;To look into this, let&amp;rsquo;s plot the norm of the difference between inputs and outputs as a function of the latent dimension, while varying the scale (inverse temperature) relative to the transformer default $\beta = 1/\sqrt{\mathrm{d_k}}$. We sample fresh toy patterns for every dimension/scale combination to get an idea of the statistical behavior.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dims&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;int32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta_scale&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bare_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta_scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Suppresses a warning.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ma&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Plot data.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gca&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;contourplot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contourf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;colors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LogNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vmin&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vmax&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;levels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;d_k&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;scale / sqrt(d_k)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;contourplot&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;format&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;%.e&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ticks&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;ticker&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LogLocator&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer_default_scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-self-attention-experiment"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Contour plot of norm differences for the bare self-attention experiment"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_c6264e1df494a014.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_d419bf703bcb545.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_bef84266f747834.webp 588w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_c6264e1df494a014.webp"
width="588"
height="485"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare self-attention experiment
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;This contour plot shows the norm differences between the inputs and outputs of a bare self-attention step for a sweep across latent dimensions and inverse temperature scale factors. The horizontal red line corresponds to the scale factor used by default in most transformer implementations. Some comments:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;For a fixed latent dimension, increasing the scale factor corresponds to smaller norm differences, i.e. more pronounced valleys that are much harder to escape, especially if you start at the bottom and there is no query-key-value mapping taking you elsewhere.&lt;/li&gt;
&lt;li&gt;The vertical red line at a latent dimension of 512 corresponds to the earlier bare self-attention result. The intersection point indeed corresponds to a norm difference of the order we saw previously. The value for a latent dimension of 2 (left border of plot) suggests that patterns do move around quite a bit, confirming our visualization above.&lt;/li&gt;
&lt;li&gt;Setting the scale for bare multi-head attention proportionally to the (smaller) head dimension instead of the full latent dimension corresponds to moving leftwards along the horizontal red line. The norm difference increases so that, for bare multi-head self-attention, patterns in multiple small heads tend to bounce around more than they would in a single big head. This might be one of the reasons why multiple heads help with training transformers: since the effective temperature is lower in the smaller latent spaces, the topography of the lower-dimensional energy landscapes is more pronounced and individual heads can go explore a bit to find their niche valley.&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;Using the tools presented in this blog post, we have shown that it is possible to swap the explicit attention module in a transformer for an implicit energy minimization method. What happens when we start playing around with different energy functions? Can we make patterns interact? Can we make the energy minimization step more efficient by treating it as a fixed-point problem? It remains to be seen whether all of this is a useful thing to do.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021visualizingattention,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Attention as Energy Minimization: Visualizing Energy Landscapes},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {March},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;strong&gt;Caveat:&lt;/strong&gt; For the special case of bare energy-based self-attention, state patterns actually appear quadratically in the argument of the &lt;code&gt;logsumexp&lt;/code&gt; part of the energy function. Taking the derivative using &lt;code&gt;minimize_energy(..)&lt;/code&gt; however assumes the context is a different node in the computational graph, which, in this case, where we &lt;em&gt;should&lt;/em&gt; be taking the derivative of &lt;code&gt;energy(x, x)&lt;/code&gt; instead of &lt;code&gt;energy(x, context)&lt;/code&gt;, yields a gradient that misses a factor of 2. But ensuring the gradient is &amp;ldquo;correct&amp;rdquo; for this special case would of course screw up the cancellation of the state pattern with itself for &lt;code&gt;step_size=1.0&lt;/code&gt; and &lt;code&gt;num_steps=1&lt;/code&gt; so that the updated query would no longer match the output of bare vanilla softmax attention. Proper treatment of doing multiple steps of bare energy-based self-attention should also include manually setting the context to the updated queries (since the queries themselves change every update step). Luckily no one would seriously consider using bare energy-based self-attention.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Transformer Attention as an Implicit Mixture of Effective Energy-Based Models</title><link>https://mcbal.github.io/post/transformer-attention-as-an-implicit-mixture-of-effective-energy-based-models/</link><pubDate>Tue, 22 Dec 2020 10:03:17 +0100</pubDate><guid>https://mcbal.github.io/post/transformer-attention-as-an-implicit-mixture-of-effective-energy-based-models/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Please consider reading
for an arguably more comprehensive approach towards understanding transformers from a physics perspective.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;!-- In this post, I will try to partly address the concerns of the following critic:
&gt; _In your [previous post](https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/), you introduced the energy function of modern Hopfield networks without explanation. Where does it come from? What's up with the logarithm? Is there actually any other interpretation then it being reverse-engineered from the Transformers' attention step? Is this all a desperate attempt to make Hopfield networks cool again? Also, I cannot see the value of looking at attention from an energy-based perspective if it doesn't help me achieve SOTA. Weak reject._ --&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;In a previous post, I provided an overview of attention in Transformer models and summarized its connections to modern Hopfield networks. We saw that the energy-based model
&lt;/p&gt;
\begin{equation}
E(\boldsymbol{\Xi}; \boldsymbol{X}) = \frac{1}{2} \boldsymbol{\Xi}^T \boldsymbol{\Xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\Xi} \right).
\label{eq:mhnenergy}
\end{equation}&lt;p&gt;
enables fast pattern storage and retrieval through its simple and robust dynamics, leading to rapid convergence
&lt;/p&gt;
\begin{align}
\boldsymbol{\Xi}_{n+1} = \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\Xi}_{n}\right)
\label{eq:mhnupdate}
\end{align}&lt;p&gt;
of input queries $\boldsymbol{\Xi}_{n}$ to updated queries $\boldsymbol{\Xi}_{n+1}$ lying in the convex hull of stored patterns $\boldsymbol{X}$. I also argued by means of handwaving that optimizing a Transformer looks like meta-learning from the point of view of its attention modules, sculpting energy landscapes to accommodate statistical patterns found in data.&lt;/p&gt;
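&lt;p&gt;As a rough numerical check of the statement above, the following NumPy sketch compares one step of the update rule with a gradient descent step of size one on the energy, using a finite-difference gradient. The sizes, the random seed, and the helper names (&lt;code&gt;energy&lt;/code&gt;, &lt;code&gt;update&lt;/code&gt;, &lt;code&gt;numerical_grad&lt;/code&gt;) are illustrative assumptions and not part of the original code.&lt;/p&gt;

```python
import numpy as np


def logsumexp(z):
    # Numerically stable log(sum(exp(z))) for a 1D array.
    m = np.max(z)
    return m + np.log(np.sum(np.exp(z - m)))


def softmax(z):
    # Numerically stable softmax for a 1D array.
    e = np.exp(z - np.max(z))
    return e / np.sum(e)


def energy(xi, X):
    # E(xi; X) = 0.5 * xi^T xi - logsumexp(X^T xi), patterns stored as columns of X.
    return 0.5 * xi @ xi - logsumexp(X.T @ xi)


def update(xi, X):
    # One step of the update rule: xi_new = X softmax(X^T xi).
    return X @ softmax(X.T @ xi)


def numerical_grad(f, x, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
d, N = 4, 6                      # hypothetical sizes
X = rng.normal(size=(d, N))      # stored patterns as columns
xi = rng.normal(size=d)          # a single query

grad = numerical_grad(lambda q: energy(q, X), xi)
gradient_step = xi - 1.0 * grad  # gradient descent with step size 1
weights = softmax(X.T @ xi)      # coefficients of the convex combination

print(np.allclose(gradient_step, update(xi, X), atol=1e-6))
print(weights.min(), weights.sum())
```

&lt;p&gt;The softmax weights are positive and sum to one, which is why the updated query is a convex combination of the columns of $\boldsymbol{X}$.&lt;/p&gt;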
&lt;p&gt;The main goal of this post is to build on these insights and highlight how an energy-based perspective can be a useful, complementary approach towards improving attention-based neural network modules. Parallel to scaling compute and making (self-)attention more efficient, it might be worthwhile to try to scale learning itself by experimenting with radically different attention mechanisms.&lt;/p&gt;
&lt;p&gt;To this end, we will first revisit ancient ideas at the boundary of statistical physics and machine learning and show how vanilla attention looks like a mixture of simple energy-based models. We will then argue how going beyond these simple models could benefit from thinking in terms of implicit instead of explicit attention modules, suggesting opportunities to put ideas from deep implicit layers to work.&lt;/p&gt;
&lt;h1 id="attention-from-effective-energy-based-models"&gt;Attention from effective energy-based models&lt;/h1&gt;
&lt;p&gt;In this section, we will introduce Restricted Boltzmann Machines as a particular class of energy-based models, focusing on their capacity to capture effective correlations. After identifying classical discrete Hopfield networks and modern discrete Hopfield networks, we will demonstrate a naive way to fit modern continuous Hopfield networks into this framework. Throughout this section, we will rely heavily on the wonderful review by Mehta and collaborators
&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;h2 id="restricted-boltzmann-machines"&gt;Restricted Boltzmann Machines&lt;/h2&gt;
&lt;p&gt;A Restricted Boltzmann Machine (RBM) is an energy-based model with a bipartite structure imposed on visible and hidden degrees of freedom: visible and hidden degrees of freedom interact with each other but do not interact among themselves (this is the &amp;ldquo;restriction&amp;rdquo;). The energy function looks like&lt;/p&gt;
\begin{equation}
E \left( \boldsymbol{v}, \boldsymbol{h} \right) = - \sum_{i} a_{i} (v_{i}) - \sum_{\mu} b_{\mu} (h_{\mu}) - \sum_{i \mu} W_{i \mu} v_{i} h_{\mu},
\end{equation}&lt;p&gt;where the matrix $W_{i \mu}$ encodes the coupling between hidden and visible units and where $a_{i} (\cdot)$ and $b_{\mu} (\cdot)$ are functions that can be chosen at will. Popular options are:&lt;/p&gt;
\begin{align}
a_{i} (\cdot) =
\begin{cases}
a_{i} v_{i} &amp; \text{if $v_{i} \in \{0,1\}$ is binary (Bernoulli)}\\\\
-\frac{v_{i}^2}{2\sigma_{i}^{2}} &amp; \text{if $v_{i} \in \mathbb{R}$ is continuous (Gaussian)}\\
\end{cases}
\end{align}&lt;p&gt;and similarly for $b_{\mu} (\cdot)$.&lt;/p&gt;
&lt;h2 id="why-hidden-units"&gt;Why hidden units?&lt;/h2&gt;
&lt;p&gt;Introducing hidden or latent variables is a powerful technique to encode interactions between visible units. Complex correlations between visible units can be captured at the cost of introducing new degrees of freedom and letting them interact with visible units in a simpler way. Since this trick often relies on exploiting Gaussian integral identities and physicists like their Gaussians, it shows up in several places across physics, e.g. in the Hubbard-Stratonovich transformation.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Renormalization group&lt;/strong&gt;: Rather than trying to fix the interactions in the &amp;ldquo;microscopic theory&amp;rdquo; as is done in the modeling scenario above, physicists are more familiar with the &amp;ldquo;reverse&amp;rdquo; procedure of deducing what effective theory emerges at large scales from a given microscopic theory. Indeed, integrating out degrees of freedom in physical theories can lead to complex, effective interactions between remaining degrees of freedom. This insight crystallized in the development of renormalization group theory in the early 1970s. By focusing on theories defined at different length scales, Kenneth Wilson and his contemporaries introduced and unified the notions of flows, fixed points, and universality in theory space to understand the behavior of physical systems under a change of scale.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;As we will see in the next sections, the bipartite structure of RBMs enables pairwise and higher-order correlations to emerge between visible units after integrating out hidden units. Additionally, the conditional independence of visible and hidden units enables tractable training methods like (block) Gibbs sampling and contrastive divergence&lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;. We will not consider explicitly training RBMs in this post but will instead reflect on the idea of implicitly training these models, which is what seems to be happening inside Transformers.&lt;/p&gt;
&lt;h2 id="effective-energies-and-correlations"&gt;Effective energies and correlations&lt;/h2&gt;
&lt;p&gt;Let us now consider what kind of correlations between visible degrees of freedom are supported by RBMs. The distribution of the visible degrees of freedom can be obtained by marginalizing over the hidden degrees of freedom:&lt;/p&gt;
\begin{equation}
p \left( \boldsymbol{v} \right) = \int \mathrm{d} \boldsymbol{h} \ p \left( \boldsymbol{v}, \boldsymbol{h} \right) = \int \mathrm{d} \boldsymbol{h} \ \frac{\mathrm{e}^{- E \left( \boldsymbol{v}, \boldsymbol{h} \right)}}{Z}
\end{equation}&lt;p&gt;We try to find an expression for the marginalized energy $E (\boldsymbol{v})$ by defining&lt;/p&gt;
\begin{equation}
p \left( \boldsymbol{v} \right) = \frac{\mathrm{e}^{- E (\boldsymbol{v})}}{Z}
\end{equation}&lt;p&gt;so that we can identify&lt;/p&gt;
\begin{align}
E \left( \boldsymbol{v} \right) &amp;= - \mathrm{log} \int \mathrm{d} \boldsymbol{h} \ \mathrm{e}^{- E \left( \boldsymbol{v}, \boldsymbol{h} \right)} \\\\
&amp;= - \sum_{i} a_{i} (v_{i}) - \sum_{\mu} \log \int \mathrm{d} h_{\mu}\ \mathrm{e}^{b_{\mu}(h_{\mu}) + \sum_{i} W_{i\mu} v_{i} h_{\mu}} \label{eq:effvisenergy}
\end{align}&lt;p&gt;Following the review by Mehta and collaborators, we can try to better understand the correlations in $p(\boldsymbol{v})$ by introducing the (prior) distribution&lt;/p&gt;
\begin{equation}
q_{\mu} \left( h_{\mu} \right) = \frac{\mathrm{e}^{b_{\mu} (h_{\mu})}}{Z}
\end{equation}&lt;p&gt;for the hidden units $h_{\mu}$, ignoring the interactions between $\boldsymbol{v}$ and $\boldsymbol{h}$. Additionally, we can introduce the cumulant generating function of the hidden unit&amp;rsquo;s distribution,
&lt;/p&gt;
\begin{align}
K_{\mu} (t) &amp;= \mathrm{log}\ \mathbb{E} \left[ \mathrm{e}^{t h_{\mu}} \right] \\\\
&amp;= \mathrm{log} \int \mathrm{d} h_{\mu} \ q_{\mu} \left( h_{\mu} \right) \mathrm{e}^{t h_{\mu}}\\\\
&amp;= \sum_{n=1}^{\infty} \kappa_{\mu}^{(n)} \frac{t^{n}}{n!},
\end{align}&lt;p&gt;which is defined such that the $n^{\mathrm{th}}$ cumulant $\kappa_{\mu}^{(n)}$ of $q_{\mu} \left( h_{\mu} \right)$ can be obtained by taking derivatives $\kappa_{\mu}^{(n)} = \partial_{t}^{n} K_{\mu} \rvert_{t=0}$.&lt;/p&gt;
&lt;p&gt;Looking back at the effective energy function \eqref{eq:effvisenergy} for the visible units, we find that, up to an additive constant, the effective energy can be expressed in terms of cumulants:&lt;/p&gt;
\begin{align}
E \left( \boldsymbol{v} \right) &amp;= - \sum_{i} a_{i} \left(v_{i}\right) - \sum_{\mu} K_{\mu} \left( \sum_{i} W_{i\mu} v_{i} \right) \\\\
&amp;= - \sum_{i} a_{i} \left(v_{i}\right) - \sum_{\mu} \sum_{n=1}^{\infty} \kappa_{\mu}^{(n)} \frac{\left( \sum_{i} W_{i\mu} v_{i} \right)^{n}}{n!} \\\\
&amp;= - \sum_{i} a_{i} \left(v_{i}\right) - \sum_{i} \left( \sum_{\mu} \kappa_{\mu}^{(1)} W_{i\mu} \right) v_{i} \\\\
&amp;\ \ \ \ \ - \frac{1}{2} \sum_{ij} \left( \sum_{\mu} \kappa_{\mu}^{(2)} W_{i\mu} W_{j\mu} \right) v_{i} v_{j} + \ldots \label{eq:effectivenergy}
\end{align}&lt;p&gt;We see that the auxiliary, hidden degrees of freedom induce effective pairwise and higher-order correlations among visible degrees of freedom. Each hidden unit $h_{\mu}$ can encode interactions of arbitrarily high order, with the $n$-th order cumulants of $q_{\mu} \left( h_{\mu} \right)$ weighting the $n$-th order interactions. By combining many hidden units and/or stacking layers, RBMs can in principle encode complex interactions at all orders and learn them from data.&lt;/p&gt;
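&lt;p&gt;For completeness, the step from the integral form to the cumulant form can be sketched as follows. The symbols $\mathcal{N}_{\mu} = \int \mathrm{d} h_{\mu}\ \mathrm{e}^{b_{\mu}(h_{\mu})}$ for the normalization of the prior $q_{\mu}$ and $t_{\mu} = \sum_{i} W_{i\mu} v_{i}$ are labels introduced here for illustration:&lt;/p&gt;
\begin{align}
\log \int \mathrm{d} h_{\mu}\ \mathrm{e}^{b_{\mu}(h_{\mu}) + t_{\mu} h_{\mu}} &amp;= \log \left( \mathcal{N}_{\mu} \int \mathrm{d} h_{\mu}\ q_{\mu} \left( h_{\mu} \right) \mathrm{e}^{t_{\mu} h_{\mu}} \right) \\\\
&amp;= \log \mathcal{N}_{\mu} + K_{\mu} \left( t_{\mu} \right)
\end{align}
&lt;p&gt;The term $\log \mathcal{N}_{\mu}$ does not depend on $\boldsymbol{v}$, so it only shifts the effective energy by a constant and drops out of $p(\boldsymbol{v})$ after normalization.&lt;/p&gt;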
&lt;p&gt;Let us now recover some known models by picking a suitable prior distribution for the hidden units:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Classical discrete Hopfield networks&lt;/strong&gt;: Consider a Bernoulli distribution for the visible units and a standard Gaussian distribution for the hidden units. For a standard Gaussian, the mean $\kappa_{\mu}^{(1)} = 0$, the variance $\kappa_{\mu}^{(2)} = 1$, and $\kappa_{\mu}^{(n)} = 0$, $\forall n\geq 3$, leading to the quadratic energy function of Hopfield networks:
&lt;/p&gt;
\begin{align}
E \left( \boldsymbol{v} \right) = - \sum_{i} a_{i} v_{i} - \frac{1}{2} \sum_{ij} \left( \sum_{\mu} W_{i\mu} W_{j\mu} \right) v_{i} v_{j}
\end{align}&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Modern discrete Hopfield networks&lt;/strong&gt;: Consider a Bernoulli distribution for the visible units. Since it can be shown that the normal distribution is the only distribution whose cumulant generating function is a polynomial, i.e. the only distribution having a finite number of non-zero cumulants&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, it looks like we cannot model polynomial interactions of finite order higher than two in this framework. But we can model an exponential interaction by considering a Poisson distribution $\mathrm{Pois}(\lambda)$ with rate $\lambda=1$ for the hidden units, whose cumulants are all equal to the rate, i.e. $\kappa_{\mu}^{(n)} = 1$, $\forall n\geq 1$. Up to a constant, we then obtain an exponential interaction
&lt;/p&gt;
\begin{align}
E \left( \boldsymbol{v} \right) = - \sum_{i} a_{i} v_{i} - \sum_{\mu} \exp \left( \sum_{i} W_{i\mu} v_{i} \right)
\end{align}&lt;/li&gt;
&lt;/ul&gt;
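&lt;p&gt;The two cumulant generating functions used above can be checked numerically. The sketch below evaluates $K(t)$ directly from its definition for a Poisson prior with rate 1 and for a standard Gaussian prior, and prints the closed forms $\mathrm{e}^{t} - 1$ and $t^{2}/2$ next to them. The function names, the truncation of the Poisson sum, and the integration grid are assumptions made for illustration.&lt;/p&gt;

```python
import math

import numpy as np


def cgf_poisson(t, rate=1.0, h_max=100):
    # K(t) = log sum_h q(h) exp(t h) for a Poisson prior, truncating the sum at h_max.
    log_terms = [
        -rate + h * math.log(rate) - math.lgamma(h + 1) + t * h
        for h in range(h_max + 1)
    ]
    m = max(log_terms)
    return m + math.log(sum(math.exp(lt - m) for lt in log_terms))


def cgf_standard_gaussian(t, half_width=12.0, num_points=20001):
    # K(t) = log of the integral of q(h) exp(t h) for a standard Gaussian prior,
    # approximated by a Riemann sum on a uniform grid.
    h = np.linspace(-half_width, half_width, num_points)
    dh = h[1] - h[0]
    q = np.exp(-0.5 * h**2) / np.sqrt(2 * np.pi)
    return float(np.log(np.sum(q * np.exp(t * h)) * dh))


for t in (-1.0, 0.3, 1.5):
    # Poisson with rate 1: closed form K(t) = exp(t) - 1.
    print(t, cgf_poisson(t), math.exp(t) - 1.0)
    # Standard Gaussian: closed form K(t) = t^2 / 2.
    print(t, cgf_standard_gaussian(t), 0.5 * t**2)
```

&lt;p&gt;The constant $-1$ in the Poisson case is the constant that the exponential interaction above is defined up to.&lt;/p&gt;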
&lt;p&gt;Other kinds of effective interactions can be obtained by substituting the cumulants of your favorite probability distribution. The
induce interactions of all orders. Considering exponential or Laplacian distributions where $\kappa^{(n)} \sim (n-1)!$ seems to lead to funky logarithmic interactions.&lt;/p&gt;
&lt;h2 id="modern-hopfield-networks-as-mixtures-of-effective-rbms"&gt;Modern Hopfield networks as mixtures of effective RBMs&lt;/h2&gt;
&lt;p&gt;Let us now turn to the energy function of modern Hopfield networks for a single query $\boldsymbol{\xi} \in \mathbb{R}^{d}$ and $N$ stored patterns encoded by $\boldsymbol{X} \in \mathbb{R}^{d \times N}$,
&lt;/p&gt;
\begin{equation}
E(\boldsymbol{\xi}; \boldsymbol{X}) = \frac{1}{2} \boldsymbol{\xi}^T \boldsymbol{\xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right),
\end{equation}&lt;p&gt;
which we can transform into the RBM notation of the previous section by changing the names of variables and transposing the stored pattern matrix,
&lt;/p&gt;
\begin{equation}
E(\boldsymbol{v}; W) = \frac{1}{2} \sum_{i} v_{i}^{2} -\log \left( \sum_{\mu} \exp \left( \sum_{i} W_{\mu i} v_{i} \right) \right).
\end{equation}&lt;p&gt;Is there a simple way to interpret this energy function in terms of (effective) RBMs? Let&amp;rsquo;s imagine this energy to be an effective energy $E(\boldsymbol{v})$ for the visible units with probability distribution
&lt;/p&gt;
\begin{equation}
p(\boldsymbol{v}) = \frac{\mathrm{e}^{-E(\boldsymbol{v})}}{Z} = \frac{1}{Z} \sum_{\mu} \mathrm{e}^{-\frac{1}{2} \sum_{i} v_{i}^{2} + \sum_{i} W_{\mu i} v_{i}},
\end{equation}&lt;p&gt;
where the partition function $Z$ follows from doing a Gaussian integral over the visible units,
&lt;/p&gt;
\begin{equation}
Z = \sum_{\mu} Z_{\mu} = (2\pi)^{d/2} \sum_{\mu} \mathrm{e}^{\frac{1}{2} \sum_{i} W_{\mu i} W_{i\mu}}
\end{equation}&lt;p&gt;We can then identify the probability distribution $p(\boldsymbol{v})$ with a mixture of effective energy-based models&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;
&lt;/p&gt;
\begin{equation}
p(\boldsymbol{v}) = \sum_{\mu} w_{\mu} \frac{\mathrm{e}^{-\frac{1}{2} \sum_{i} v_{i}^{2} + \sum_{i} W_{\mu i} v_{i}}}{Z_{\mu}} = \sum_{\mu} w_{\mu} \frac{ \mathrm{e}^{ -E_{\mu}(\boldsymbol{v}) }}{Z_{\mu}}
\end{equation}&lt;p&gt;
where $w_{\mu} = Z_{\mu} / Z$ so that $\sum_{\mu} w_{\mu} = 1$. During training, the model can control prior weights $w_{\mu}$ by adjusting relative norms of patterns. If the difference in norms between the stored patterns is not too wild, $w_{\mu} \approx 1/N$.&lt;/p&gt;
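&lt;p&gt;The mixture identification can be verified numerically with a small sketch. It assumes that $Z_{\mu}$ denotes the full Gaussian integral of component $\mu$, including the factor $(2\pi)^{d/2}$, so that each component is a normalized unit-variance Gaussian centred on a stored pattern. Sizes, seed, and variable names are hypothetical.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)
d, N = 3, 5                      # hypothetical sizes
W = rng.normal(size=(N, d))      # row mu holds stored pattern mu
v = rng.normal(size=d)           # an arbitrary visible configuration

# Left-hand side: exp(-E(v)) / Z with E(v) = 0.5 |v|^2 - log sum_mu exp(W_mu . v).
sq_norms = np.sum(W**2, axis=1)
Z_mu = (2 * np.pi) ** (d / 2) * np.exp(0.5 * sq_norms)   # Gaussian integral per component
Z = Z_mu.sum()
energy = 0.5 * v @ v - np.log(np.sum(np.exp(W @ v)))
p_energy = np.exp(-energy) / Z

# Right-hand side: mixture of unit-variance Gaussians centred on the patterns.
w = Z_mu / Z                                             # prior weights
gauss = np.exp(-0.5 * np.sum((v - W) ** 2, axis=1)) / (2 * np.pi) ** (d / 2)
p_mixture = np.sum(w * gauss)

print(p_energy, p_mixture)
print(w, w.sum())
```

&lt;p&gt;In this convention the prior weights are a softmax over half the squared pattern norms, which makes explicit how the relative norms of the patterns control $w_{\mu}$.&lt;/p&gt;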
&lt;p&gt;A single model in the mixture has an effective energy function derived from a joint energy function with just a single hidden unit,&lt;/p&gt;
\begin{equation}
E_{\mu} \left( \boldsymbol{v}, h_{\mu} \right) = - \sum_{i} a_{i} (v_{i}) - b_{\mu} (h_{\mu}) - \sum_{i} W_{i \mu} v_{i} h_{\mu}
\end{equation}&lt;p&gt;Looking back at \eqref{eq:effectivenergy}, we see that we can recover $E_{\mu}(\boldsymbol{v})$ by picking a hidden prior distribution that is a constant random variable so that $\kappa_{\mu}^{(1)}=1$ is the only non-zero cumulant. This frozen property of hidden units seems to agree with the fast dynamics of memory neurons in the dynamical systems model proposed in
&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
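&lt;p&gt;As a short worked check of this identification, assume the hidden unit is frozen at $h_{\mu} = 1$ and take $a_{i}(v_{i}) = -\frac{1}{2} v_{i}^{2}$, the choice that matches the quadratic term of the modern Hopfield energy. Here $W_{i\mu}$ is the transpose of the $W_{\mu i}$ used in the mixture above. The cumulant generating function of a constant is linear, and the effective energy reduces to a single mixture component:&lt;/p&gt;
\begin{align}
K_{\mu} (t) &amp;= \mathrm{log}\ \mathbb{E} \left[ \mathrm{e}^{t h_{\mu}} \right] = \mathrm{log}\ \mathrm{e}^{t} = t, \\\\
E_{\mu} \left( \boldsymbol{v} \right) &amp;= - \sum_{i} a_{i} \left(v_{i}\right) - K_{\mu} \left( \sum_{i} W_{i\mu} v_{i} \right) = \frac{1}{2} \sum_{i} v_{i}^{2} - \sum_{i} W_{i\mu} v_{i}
\end{align}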
&lt;p&gt;In conclusion, the energy-based model underlying vanilla Transformer attention is not terribly exciting.&lt;/p&gt;
&lt;h1 id="attention-as-implicit-energy-minimization"&gt;Attention as implicit energy minimization&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s finish this post with some comments on how one could leverage the idea of implicit energy minimization to develop novel attention mechanisms.&lt;/p&gt;
&lt;h2 id="bending-the-explicit-architecture"&gt;Bending the explicit architecture&lt;/h2&gt;
&lt;p&gt;A lot of work on post-vanilla Transformer architectures tries to improve softmax-attention by making it more efficient through approximations and/or modifications at the level of the architecture. Kernel-based approaches like
have shown not only that softmax attention can be efficiently approximated by a generalized attention mechanism but also that generalized ReLU-based attention performed better in practice. Papers like
show how we can replace the softmax non-linearity in \eqref{eq:mhnupdate} with pure normalization and still end up with a competitive algorithm, noting that the updated query being restricted to lie in the convex hull of the stored patterns is a bias we might want to question.&lt;/p&gt;
&lt;p&gt;From the above examples, it seems like at least a part of current research on attention is trying to break away from the confines of existing, explicit attention architectures but doesn&amp;rsquo;t quite know how to do so in a principled way. Does an energy-based perspective help to understand these developments?&lt;/p&gt;
&lt;h2 id="from-explicit-architectures-to-implicit-energy-minimization"&gt;From explicit architectures to implicit energy minimization&lt;/h2&gt;
&lt;p&gt;We have seen in this post that the energy function behind the &lt;code&gt;softmax&lt;/code&gt; attention mechanism can be understood as a mixture of simple energy-based models. But what can we actually do with this information? Especially since we know from language modeling experiments that &amp;ldquo;just scaling&amp;rdquo; these simple models to billions of parameters enables them to store enough patterns to be useful. Despite huge progress, important challenges remain in terms of efficiency and generalizability. Considering slightly less trivial energy-based models might address both by adding interactions in such a way that attention modules are able to return a &lt;em&gt;collective response&lt;/em&gt; rather than a sum of decoupled contributions.&lt;/p&gt;
&lt;p&gt;To some extent, the additional linear transformations on the input patterns in the query-key-value formulation of Transformer self-attention already try to address this:
&lt;/p&gt;
\begin{equation}
\mathrm{Attention}\left( \mathbf{Q}, \mathbf{K}, \mathbf{V} \right) = \mathrm{softmax} \left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d}} \right) \mathbf{V}
\label{eq:vanilla-attention}
\end{equation}&lt;p&gt;
These linear transformations slightly generalize the &amp;ldquo;naked&amp;rdquo; explicit gradient step of \eqref{eq:mhnupdate} and can in principle learn to cluster and direct patterns to neighborhoods in the energy landscape, parametrizing the energy function. But why stop there?&lt;/p&gt;
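&lt;p&gt;To make the relation between the two formulas concrete, the sketch below checks that the attention formula with identity projections and a scale of 1 reproduces the bare update, and then applies random (untrained) projection matrices as stand-ins for learned ones. The helper names, sizes, and the random projections are assumptions for illustration only.&lt;/p&gt;

```python
import numpy as np


def softmax_rows(z):
    # Row-wise numerically stable softmax.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention(Q, K, V, scale):
    # softmax(scale * Q K^T) V with one query per row of Q.
    return softmax_rows(scale * Q @ K.T) @ V


rng = np.random.default_rng(2)
d, N, M = 8, 5, 3                 # hypothetical sizes: dimension, patterns, queries
X = rng.normal(size=(d, N))       # stored patterns as columns
Xi = rng.normal(size=(d, M))      # queries as columns

# Bare update: each query column maps to X softmax(X^T xi).
logits = X.T @ Xi                                    # shape (N, M)
col_softmax = np.exp(logits - logits.max(axis=0))
col_softmax = col_softmax / col_softmax.sum(axis=0)
bare = X @ col_softmax                               # shape (d, M)

# Same result from the attention formula with identity projections and scale 1.
same = attention(Xi.T, X.T, X.T, scale=1.0).T

# Random projections stand in for learned query, key, and value maps.
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
projected = attention(Xi.T @ W_q, X.T @ W_k, X.T @ W_v, scale=1.0 / np.sqrt(d))

print(np.allclose(bare, same))
print(projected.shape)
```

&lt;p&gt;Note the change of convention: the update rule stores patterns as columns, while the attention formula stores one pattern per row, hence the transposes.&lt;/p&gt;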
&lt;h2 id="deep-implicit-layers-for-attention-dynamics"&gt;Deep implicit layers for attention dynamics&lt;/h2&gt;
&lt;p&gt;An interesting way forward might be to integrate attention with &lt;em&gt;deep implicit layers&lt;/em&gt;. Funnily enough, the authors of the NeurIPS 2020 tutorial on
list self-attention as a prime example of an explicit layer in their
. Approaches like
implicitly train DEQ-Transformers but still consider the attention module itself an explicit function.&lt;/p&gt;
&lt;p&gt;Yet we have seen in a previous post that self-attention can, and perhaps should, actually be considered an implicit layer solving for a fixed-point query. Because of the lack of dynamics of the current generation of attention mechanisms, this can be done in a single big gradient step, removing the need to iterate. Attention models with more complicated dynamics might benefit from a differentiable solver to find a fixed point and return the most appropriate result in a given context.&lt;/p&gt;
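&lt;p&gt;A minimal sketch of the fixed-point view, assuming a naive forward iteration in place of a proper differentiable solver: the query is updated until it stops changing, and the result is compared with the output of a single explicit step. The function &lt;code&gt;solve_fixed_point&lt;/code&gt;, the tolerance, and the random patterns are hypothetical choices made for illustration.&lt;/p&gt;

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax for a 1D array.
    e = np.exp(z - np.max(z))
    return e / np.sum(e)


def update(xi, X):
    # One explicit step: xi_new = X softmax(X^T xi).
    return X @ softmax(X.T @ xi)


def solve_fixed_point(xi, X, tol=1e-10, max_steps=100):
    # Naive forward iteration; a differentiable solver would replace this loop.
    for step in range(1, max_steps + 1):
        xi_next = update(xi, X)
        if np.allclose(xi_next, xi, rtol=0.0, atol=tol):
            return xi_next, step
        xi = xi_next
    return xi, max_steps


rng = np.random.default_rng(3)
d, N = 64, 10                                # hypothetical sizes
X = rng.normal(size=(d, N))                  # stored patterns as columns
query = X[:, 0] + 0.1 * rng.normal(size=d)   # noisy version of the first pattern

one_step = update(query, X)
fixed_point, steps = solve_fixed_point(query, X)

print(steps)
print(np.linalg.norm(one_step - fixed_point))
print(np.linalg.norm(fixed_point - X[:, 0]))
```

&lt;p&gt;For well-separated random patterns in a high-dimensional space, one would expect the loop to stop after very few iterations and the single explicit step to land very close to the fixed point, in line with the argument above. An energy function with richer dynamics would presumably need more iterations.&lt;/p&gt;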
&lt;p&gt;Compared to modifying explicit architectures, the implicit-layer perspective seems to act on a different &amp;ldquo;conceptual level&amp;rdquo; of neural network architecture design. This raises a lot of questions. Which families of attention architectures can be expressed in terms of implicit energy functions like softmax-attention? How many of these have efficient minimization properties with closed-form gradients? Beyond closed-form gradients, how far can we go in parametrizing more general energy-based attention models and still end up with an efficient algorithm? What does the trade-off look like between an attention model&amp;rsquo;s complexity and it still being implicitly trainable?&lt;/p&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;Looking back and reversing causation, one could argue that the now-famous dot-product attention module introduced in
&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt; could only have been arrived at because of the properties of its implicit energy function \eqref{eq:mhnenergy}. Indeed, it is only because of the associative memory&amp;rsquo;s decoupled and rather crude way of storing patterns in isolated, high-dimensional valleys that expensive, implicit energy minimization steps can be traded for a cheap, explicit one-step gradient update like \eqref{eq:mhnupdate}.&lt;/p&gt;
&lt;p&gt;The obvious pitfall of continuing to hold on to the conceptual framework introduced by this shortcut is that a potentially far richer picture of (sparse) attention dynamics remains obscured. Rather than perpetually rethinking what is all you &lt;em&gt;really&lt;/em&gt; need within the confines of existing, explicit attention modules, why not opt for implicit modules built on top of an energy-based perspective to try to push things forward?&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2020attentionrbms,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Transformer Attention as an Implicit Mixture of Effective Energy-Based Models},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2020},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {December},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/transformer-attention-as-an-implicit-mixture-of-effective-energy-based-models/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Pankaj Mehta, Marin Bukov, Ching-Hao Wang, Alexandre G.R. Day, Clint Richardson, Charles K. Fisher, David J. Schwab,
(2019)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Proof by Marcinkiewicz (1935) according to
.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;We are aware that this identification might be tremendously trivial when considering prior work on
or, more generally, mixture models in the context of
.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;&lt;em&gt;Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin,
(2017)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>An Energy-Based Perspective on Attention Mechanisms in Transformers</title><link>https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/</link><pubDate>Sat, 28 Nov 2020 10:54:21 +0100</pubDate><guid>https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/</guid><description>&lt;p align="center"&gt;
&lt;a href="https://xkcd.com/793/"&gt;XKCD 793: A physicist encountering machine learning for the first time&lt;/a&gt;
&lt;/p&gt;
&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Please consider reading
for an arguably more comprehensive approach towards understanding transformers from a physics perspective.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;In 2017,
&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; demonstrated state-of-the-art performance in neural machine translation by stacking only (self-)attention layers. Compared to recurrent neural networks, Transformer models exhibit efficient parallel processing of tokens, leading to better modeling of long-range correlations and, most importantly,
. Since then, Transformers seem to have taken over natural language processing. Widespread adoption of attention-based architectures seems likely given recent work like
and the flurry of developments addressing the architecture&amp;rsquo;s quadratic scaling bottlenecks.&lt;/p&gt;
&lt;p&gt;Recently, the papers
&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt; and
&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt; provided complementary post-facto explanations of some of the success of Transformers from the perspective of energy-based models. In this post, I provide a biased overview of (self-)attention in Transformers and summarize its connections to modern Hopfield networks. Along the way, I look for intuition from physics and indulge in hand-wavy arguments on how an energy-based perspective can shed light on training and improving Transformer models.&lt;/p&gt;
&lt;h1 id="a-growing-zoo-of-transformers"&gt;A growing zoo of Transformers&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s start off with an overview of the components in a vanilla Transformer model. Since our focus is on (self-)attention, I am going to assume some prior knowledge&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt; and skip comprehensive architecture descriptions and experimental results. In
, we will start from scratch and use Hopfield networks to build back up to the attention module described below.&lt;/p&gt;
&lt;h2 id="vanilla-transformers"&gt;Vanilla Transformers&lt;/h2&gt;
&lt;p&gt;The proto-Transformer was introduced in an encoder-decoder context for machine translation in
. The original motivation seems to have been mostly driven by engineering efforts to model long-range correlations in sequence data and the recent successes of attention mechanisms stacked on top of recurrent neural networks. The main contribution and selling point of the paper was making an attention-only approach to sequence modeling work.&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-vanilla-transformers-encoder-decoder-architecture"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/vanilla_transformer_hu_21b41c96d8b8878c.webp 320w, https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/vanilla_transformer_hu_a8681765fafad7cf.webp 480w, https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/vanilla_transformer_hu_abaa71d70b75783b.webp 516w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/vanilla_transformer_hu_21b41c96d8b8878c.webp"
width="516"
height="760"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Vanilla Transformers encoder-decoder architecture
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s focus on the encoder on the left and ignore the decoder on the right. Transformer models accept (batches of) sets of vectors, which covers most inputs people care about in machine learning. Text can be modelled as a sequence of embedded tokens. Images can be viewed as a snaky sequence of embedded pixels or embedded patches of pixels. Since sets have no notion of ordering, learned or fixed positional information needs to be explicitly added to the input vectors.&lt;/p&gt;
&lt;p&gt;The main module in the Transformer encoder block is the multi-head &lt;em&gt;self-attention&lt;/em&gt;, which is based on a (scaled) dot-product attention mechanism acting on a set of $d$-dimensional vectors:&lt;/p&gt;
\begin{equation}
\mathrm{Attention}\left( \mathbf{Q}, \mathbf{K}, \mathbf{V} \right) = \mathrm{softmax} \left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d}} \right) \mathbf{V}
\label{eq:vanilla-attention}
\end{equation}&lt;p&gt;Here, queries $\mathbf{Q}$, keys $\mathbf{K}$, and values $\mathbf{V}$ are matrices obtained from acting with different linear transformations &amp;mdash; parametrized respectively by weights $\mathbf{W}_{\mathbf{Q}}$, $\mathbf{W}_{\mathbf{K}}$, and $\mathbf{W}_{\mathbf{V}}$ &amp;mdash; on the same set of $d$-dimensional inputs. &lt;em&gt;Cross-attention&lt;/em&gt; takes the inputs for its queries from a different source than for its keys and values, as can be glimpsed from the decoder part of the architecture on the right.&lt;/p&gt;
&lt;p&gt;For every input query, the updated output query of \eqref{eq:vanilla-attention} is a linear combination of values weighted by an attention vector quantifying the overlap of the input query with the keys corresponding to these values. Stacking input query attention vectors leads to an attention matrix. Since all objects are vectors and the attention mechanism is just a dot product between vectors, we can think of the attention module as matching query vectors to their &amp;ldquo;closest&amp;rdquo; key vectors in latent space and summing up contributions from value vectors, weighted by the &amp;ldquo;closeness&amp;rdquo; of their keys to the queries.&lt;/p&gt;
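&lt;p&gt;As a sketch of the previous paragraph in components (the row labels $\mathbf{q}_{i}$, $\mathbf{k}_{j}$, and $\mathbf{v}_{j}$ for the rows of $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ are notation introduced here for illustration), the updated query for a single input query $\mathbf{q}_{i}$ can be written as&lt;/p&gt;
\begin{equation}
\mathbf{q}_{i}^{\mathrm{updated}} = \sum_{j} a_{ij} \mathbf{v}_{j}, \qquad a_{ij} = \frac{\exp \left( \mathbf{q}_{i} \cdot \mathbf{k}_{j} / \sqrt{d} \right)}{\sum_{l} \exp \left( \mathbf{q}_{i} \cdot \mathbf{k}_{l} / \sqrt{d} \right)}
\end{equation}&lt;p&gt;The coefficients $a_{ij}$ are non-negative and sum to one over $j$, so every row of the attention matrix is a probability distribution over keys.&lt;/p&gt;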
&lt;p&gt;The remaining components of the Transformer encoder block are needed to make the module work properly in practice:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;The &lt;em&gt;multi-headedness&lt;/em&gt; of the attention module refers to chunking up the dimension of the vector space and having multiple attention operations running in parallel in the same module, yet with each acting on a lower-dimensional segment of the full space. This is a trick to (1) get around the fact that every input vector only couples to one query at a time to calculate its attention coefficient, and (2) provide multiple starting points in the subspaces for the queries, which might help to avoid bad local minima in parameter space during optimization.&lt;/li&gt;
&lt;li&gt;A positional feed-forward network, made up of two linear layers with a non-linearity in between, is inserted at the end of the module. Folklore wisdom tells us that the feed-forward layer needs to blow up the dimension of the latent space by a factor of four for it to be able to &amp;ldquo;disentangle&amp;rdquo; the representation. More likely though, it&amp;rsquo;s a way to increase model capacity and warp latent spaces since the attention modules on their own are pretty much linear apart from the $\mathrm{softmax}$-operator used to obtain the normalized attention coefficients.&lt;/li&gt;
&lt;li&gt;Residual connections are added to control the flow of gradients.&lt;/li&gt;
&lt;li&gt;Layer normalisation is used to control learning dynamics and keep vector norms from exploding.&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="beyond-vanilla-confronting-quadratic-scaling"&gt;Beyond vanilla: confronting quadratic scaling&lt;/h2&gt;
&lt;p&gt;Most architectural variations of the vanilla Transformer are targeted at the attention module, which scales poorly with respect to the input sequence length $N$. Since the overlap of all queries with all keys is required, calculating a dense attention matrix scales like $\mathcal{O}(N^2)$ in time and space. Limits on the context window of the attention mechanism during training prevent the model from learning how to deal with long sequences and long-range correlations. The majority of post-vanilla Transformer species can be classified into one of the following buckets&lt;sup id="fnref1:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Low-rank approximations: truncate the matrix product $\mathbf{Q} \mathbf{K}^T$ since it&amp;rsquo;s likely not full rank for structured data&lt;/li&gt;
&lt;li&gt;Sparsification: reduce the attention calculation from all query-key pairs to a subset because not all of them feel the need to talk to each other&lt;/li&gt;
&lt;li&gt;Recurrence: keep track of a (compressed) history of context&lt;/li&gt;
&lt;li&gt;Kernels: approximate the attention operation with kernel methods&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For the remainder of our discussion, we will focus on vanilla Transformers. One of the goals of this blog post is to explore how a different perspective on the &lt;em&gt;function&lt;/em&gt; of attention-based algorithms might lead to qualitatively different improvements beyond what is possible by relying on scaling and reducing computational complexity alone.&lt;/p&gt;
&lt;h1 id="from-hopfield-networks-to-transformers"&gt;From Hopfield networks to Transformers&lt;/h1&gt;
&lt;p&gt;In this section, we provide a short history of Hopfield networks and gradually build up intuition until we can recognize the Transformer self-attention mechanism for what it really is. We refer to the
accompanying
for more details and insightful visualizations of pattern storage and retrieval.&lt;/p&gt;
&lt;h2 id="classical-discrete-hopfield-networks"&gt;Classical discrete Hopfield networks&lt;/h2&gt;
&lt;p&gt;A
is a simple model for associative memory popularized by John Hopfield in his 1982 paper
&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;. The task of an associative memory is to store and retrieve patterns, preferably in a way that allows one to recover stored patterns quickly with a low error rate.&lt;/p&gt;
&lt;p&gt;The basic idea of the Hopfield network &amp;mdash; and other energy-based models like
&amp;mdash; is to construct an &lt;em&gt;energy function&lt;/em&gt; which defines an &lt;em&gt;energy landscape&lt;/em&gt; containing basins of attraction around patterns we want to store. Starting at any pattern, we want to have an update rule pointing towards the closest stored pattern, guided by a scalar &amp;ldquo;closeness&amp;rdquo; score provided by the energy function.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s make this a bit more formal but not too formal. Consider trying to store a set of $N$ binary patterns $\{\boldsymbol{x}_{i}\}_{i=1}^{N}$ where each pattern $\boldsymbol{x}_{i}$ is a $d$-dimensional vector whose entries are either $-1$ or $1$. For example, in the case of storing black-and-white images, every image would correspond to a string of pixel values, a binary pattern $\boldsymbol{x}_{i}$.&lt;/p&gt;
&lt;p&gt;For any query $\boldsymbol{\xi} \in \mathbb{R}^{d}$, or &lt;em&gt;state pattern&lt;/em&gt;, we want to find a way to retrieve the closest &lt;em&gt;stored pattern&lt;/em&gt;. In his paper, Hopfield considered the energy function&lt;/p&gt;
\begin{equation}
E = - \frac{1}{2} \boldsymbol{\xi}^{T} \boldsymbol{W} \boldsymbol{\xi} + \boldsymbol{\xi}^{T} \boldsymbol{b} = - \frac{1}{2} \sum_{i=1}^{d} \sum_{j=1}^{d} w_{ij} \xi_{i} \xi_{j} + \sum_{i=1}^{d} b_{i} \xi_{i} ,
\label{eq:ising}
\end{equation}&lt;p&gt;where $\boldsymbol{b} \in \mathbb{R}^{d}$ denotes a bias vector and the weights $\boldsymbol{W} \in \mathbb{R}^{d \times d}$ are set to the sum of the outer products of the patterns we want to store&lt;/p&gt;
\begin{equation}
\boldsymbol{W} = \sum_{i=1}^{N} \boldsymbol{x}_{i} \otimes \boldsymbol{x}_{i}^{T}.
\end{equation}&lt;p&gt;The state pattern update rule is given by the sign of the negative gradient of \eqref{eq:ising} with respect to $\boldsymbol{\xi}$ and can be applied in one step (synchronously) or separately for every component of the vector (asynchronously):&lt;/p&gt;
\begin{equation}
\boldsymbol{\xi}_{n+1} = \mathrm{sgn} \left( \boldsymbol{W} \boldsymbol{\xi}_{n} - \boldsymbol{b} \right).
\end{equation}&lt;p&gt;The storage capacity of this system for retrieval of patterns with a small number of errors can be shown to be $C \cong 0.14 d$, scaling linearly with the dimension of the pattern vector.&lt;/p&gt;
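&lt;p&gt;A minimal worked example, with numbers chosen here purely for illustration: store a single pattern $\boldsymbol{x}_{1} = (1, -1, 1)^{T}$ with $d = 3$, set $\boldsymbol{b} = \boldsymbol{0}$, and start from the corrupted state $\boldsymbol{\xi}_{0} = (1, 1, 1)^{T}$. Since $\boldsymbol{W} = \boldsymbol{x}_{1} \boldsymbol{x}_{1}^{T}$ in this case, a single synchronous update gives&lt;/p&gt;
\begin{equation}
\boldsymbol{\xi}_{1} = \mathrm{sgn} \left( \boldsymbol{x}_{1} \left( \boldsymbol{x}_{1}^{T} \boldsymbol{\xi}_{0} \right) \right) = \mathrm{sgn} \left( 1 \cdot \boldsymbol{x}_{1} \right) = (1, -1, 1)^{T} = \boldsymbol{x}_{1}
\end{equation}&lt;p&gt;so the flipped second component is repaired in one step. With many stored patterns, cross-talk between the outer products in $\boldsymbol{W}$ is what eventually limits the capacity.&lt;/p&gt;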
&lt;h3 id="physical-intuition"&gt;Physical intuition&lt;/h3&gt;
&lt;p&gt;Physicists immediately recognize the energy function \eqref{eq:ising} as an incarnation of the
. Spin degrees of freedom $\xi_{i}$ are grouped into patterns $\boldsymbol{\xi}$ that are equivalent to &lt;em&gt;spin configurations&lt;/em&gt; of $d$ spins. The weight matrix is a sum of outer products of stored-pattern spin configurations, which serve as attractors for the state-pattern spin configuration. The couplings $w_{ij}$ can be regarded as a sum of samples of an underlying pattern data distribution. They are not restricted to (nearest-)neighbors and their values are neither uniform like in exactly solvable models nor totally random like in spin glass models.&lt;/p&gt;
&lt;!-- After identifying relevant degrees of freedom, physicists combine appropriate conceptual structures with arguments based on locality, symmetry, and physical and mathematical intuition to write down a model description, usually after a lot of hard work and trial-and-error.--&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Neural networks and spin glasses&lt;/strong&gt;: There is some literature on connections between
and
. Spin glasses are phases of matter describing disordered magnetic systems exhibiting both
and frustration. Spin glasses were a major inspiration for Hopfield networks, as beautifully explained by the condensed matter physicist
in a
(1988-1990). However, apart from
&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;, I could not find any recent papers that point to a productive research direction beyond qualitative statements like &amp;ldquo;here&amp;rsquo;s two hard problems where symmetry and order will not help you solve them&amp;rdquo;.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 id="modern-discrete-hopfield-networks"&gt;Modern discrete Hopfield networks&lt;/h2&gt;
&lt;p&gt;Modern discrete Hopfield networks (or &lt;em&gt;dense&lt;/em&gt; associative memories) introduced the following family of energy functions to improve pattern storage capacity and pattern separation capabilities &lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
\begin{equation}
E = - \sum_{i=1}^{N} F \left( \boldsymbol{x}_{i}^{T} \cdot \boldsymbol{\xi} \right)
\end{equation}&lt;p&gt;Compared to the classical discrete Hopfield network energy function \eqref{eq:ising}, the explicit weight matrix is gone and the energy has been reduced to a sum of a function of dot products between the state pattern $\boldsymbol{\xi}$ and every stored pattern $\boldsymbol{x}_i$. For a polynomial interaction function $F(x) = x^{a}$, low-error storage capacity is $C \cong d^{a-1}$. The quadratic, classical discrete Hopfield network is recovered by setting $a=2$.&lt;/p&gt;
&lt;p&gt;Essentially, the role of $F(x)$ is to separate close patterns by blowing up differences in dot product values. Few things blow up better than exponentials, so
we can generalize the energy to&lt;/p&gt;
\begin{equation}
E = - \sum_{i=1}^{N} \exp \left( \boldsymbol{x}_{i}^{T} \cdot \boldsymbol{\xi} \right)
\end{equation}&lt;p&gt;with storage capacity $C \cong 2^{d/2}$. The corresponding update rules for modern discrete Hopfield networks can be shown to converge quickly with high probability&lt;sup id="fnref1:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;h2 id="modern-continuous-hopfield-networks"&gt;Modern continuous Hopfield networks&lt;/h2&gt;
&lt;p&gt;Most machine learning applications are tailored to work with continuous embeddings (vector representations) rather than discrete patterns. Is there a way to generalize modern Hopfield networks to continuous data? Recently,
proposed the following energy function to deal with continuous $d$-dimensional patterns&lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
\begin{equation}
E(\boldsymbol{\xi}; \boldsymbol{X}) = \frac{1}{2} \boldsymbol{\xi}^T \boldsymbol{\xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right),
\label{eq:energyfunc}
\end{equation}&lt;p&gt;which we consider to be a function of the state pattern $\boldsymbol{\xi} \in \mathbb{R}^{d}$ and parametrized by $N$ stored patterns $\boldsymbol{X} = (\mathbf{x}_{1}, \ldots, \mathbf{x}_{N}) \in \mathbb{R}^{d \times N}$. From the point of view of
, the stored patterns $\boldsymbol{X}^T$ can also be interpreted as weights mapping $\boldsymbol{\xi}$ to hidden units&lt;sup id="fnref1:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Smoothly taking a maximum&lt;/strong&gt;: The $\mathrm{logsumexp}$ operator is defined for vectors $\mathbf{x}$ as
&lt;/p&gt;
\begin{equation}
\mathrm{logsumexp} \left( \mathbf{x} \right) = \log \left( \sum_{i=1}^{N} \mathrm{e}^{x_i} \right)
\end{equation}&lt;p&gt;
while for matrix arguments (like a batch of vectors), the $\mathrm{sumexp}$ is understood to apply to just one dimension after which the $\log$ acts element-wise on the resulting vector.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="physical-intuition-1"&gt;Physical intuition&lt;/h3&gt;
&lt;p&gt;We assume that the stored patterns equilibrate much quicker than the state pattern so that the former can effectively be considered &amp;ldquo;frozen&amp;rdquo;. The energy function \eqref{eq:energyfunc} looks deceptively simple: there is a single state pattern and there are no interactions among stored patterns. The first term takes care of making sure the norm of the input state pattern is finite, while the second term scores the query&amp;rsquo;s overlap based on its individual alignment with every stored pattern. The exponential function in the term&lt;/p&gt;
\begin{equation}
\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right) = \log \left( \sum_{i=1}^{N} \mathrm{e}^{\mathbf{x}_i \cdot \boldsymbol{\xi}} \right)
\end{equation}&lt;p&gt;is used to pull apart close patterns by blowing up differences in the dot product between state pattern and stored patterns. From the perspective of the query, it is not so much an interaction term but rather a measure of the alignment of the query to external &amp;ldquo;magnetic fields&amp;rdquo; generated by the stored patterns.&lt;/p&gt;
&lt;h3 id="deriving-the-update-rule"&gt;Deriving the update rule&lt;/h3&gt;
&lt;p&gt;In the spirit of hand-waving, let us refuse to resort to the dynamical systems machinery used in the original references &lt;sup id="fnref1:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref2:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt; and rather derive the update rule for the state pattern $\boldsymbol{\xi}$ by taking the derivative of the energy function \eqref{eq:energyfunc} with respect to $\boldsymbol{\xi}$&lt;/p&gt;
\begin{equation}
\nabla_{\boldsymbol{\xi}} E(\boldsymbol{\xi}; \boldsymbol{X}) = \boldsymbol{\xi} - \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right).
\end{equation}&lt;p&gt;A gradient descent update with step size $\gamma$ looks like&lt;/p&gt;
\begin{equation}
\boldsymbol{\xi}_{n+1} = \boldsymbol{\xi}_{n} - \gamma \left( \boldsymbol{\xi}_{n} - \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi}_{n}\right) \right).
\label{eq:conthopfupdate}
\end{equation}&lt;p&gt;We are very confident that the topography of the energy landscape allows us to take big steps and boldly set $\gamma = 1$ to recover the familiar update rule&lt;/p&gt;
\begin{align}
\boldsymbol{\xi}_{n+1} = \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi}_{n}\right) .
\end{align}&lt;p&gt;The updated vector is a linear combination of all stored patterns, weighted by an attention vector quantifying the overlap with the input pattern.&lt;/p&gt;
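&lt;p&gt;For completeness, the step behind the gradient used above can be sketched in components. Differentiating the $\mathrm{logsumexp}$ term with respect to $\boldsymbol{\xi}$ gives&lt;/p&gt;
\begin{equation}
\nabla_{\boldsymbol{\xi}} \log \left( \sum_{i=1}^{N} \mathrm{e}^{\mathbf{x}_{i} \cdot \boldsymbol{\xi}} \right) = \sum_{i=1}^{N} \mathbf{x}_{i} \frac{\mathrm{e}^{\mathbf{x}_{i} \cdot \boldsymbol{\xi}}}{\sum_{j=1}^{N} \mathrm{e}^{\mathbf{x}_{j} \cdot \boldsymbol{\xi}}} = \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right)
\end{equation}&lt;p&gt;while the quadratic term contributes $\boldsymbol{\xi}$. Setting the full gradient to zero shows that stationary points of the energy satisfy $\boldsymbol{\xi}^{*} = \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi}^{*} \right)$, so the $\gamma = 1$ update can be read as a fixed-point iteration for this condition.&lt;/p&gt;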
&lt;h2 id="modern-continuous-hopfield-networks-as-energy-based-models"&gt;Modern continuous Hopfield Networks as energy-based models&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s now try to connect the system defined by the energy function \eqref{eq:energyfunc} to the statistical mechanics framework of energy-based models &lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;h3 id="energy-based-models-a-gentle-introduction"&gt;Energy-based models: a gentle introduction&lt;/h3&gt;
&lt;p&gt;Energy-based models learn a parametrized energy function $E_{\theta}$ which maps data points $\boldsymbol{x}$ to real, scalar energy values $E_{\theta}(\boldsymbol{x})$. The data distribution is modeled by the
,
&lt;/p&gt;
\begin{equation}
p_{\theta}(\boldsymbol{x}) = \frac{\mathrm{e}^{ - E_{\theta}(\boldsymbol{x}) }}{Z(\theta)},
\label{eq:boltzmann}
\end{equation}&lt;p&gt;
where $Z(\theta) = \int \mathrm{d} \boldsymbol{x} \ \mathrm{e}^{-E_{\theta}(\boldsymbol{x})}$ denotes the system&amp;rsquo;s partition function. Configurations $\boldsymbol{x}$ with low energies $E_{\theta}(\boldsymbol{x})$ are considered more likely and their weight contributes more strongly to the partition function.&lt;/p&gt;
&lt;p&gt;To steer the model distribution $p_{\theta}$ towards a target data distribution $p_{\mathrm{data}}$, we can try to minimize the negative log-likelihood loss function&lt;/p&gt;
\begin{equation}
\mathcal{L}_{\mathrm{ML}} (\theta) = \mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}} \left[ -\log p_{\theta} (\boldsymbol{x}) \right],
\label{eq:nll}
\end{equation}&lt;p&gt;where the negative log-likelihood equals&lt;/p&gt;
\begin{equation}
-\log p_{\theta} (\boldsymbol{x}) = E_{\theta} (\boldsymbol{x}) + \log Z (\theta).
\end{equation}&lt;p&gt;This is a hard optimization problem because calculating $\log Z (\theta)$ is hard for the vast majority of high-dimensional data distributions we care about. In practice, people resort to approximations like contrastive divergence to push the energy down on &amp;ldquo;positive examples&amp;rdquo; drawn from the data distribution while pushing up on &amp;ldquo;negative examples&amp;rdquo; obtained from sampling the model distribution. Even though sampling from \eqref{eq:boltzmann} can be done with methods like Markov Chain Monte Carlo, it is computationally expensive to do so, especially as part of an inner-loop optimization step&lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
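&lt;p&gt;To make the &amp;ldquo;push down, push up&amp;rdquo; picture explicit, here is the standard gradient of the negative log-likelihood, written as a sketch that assumes derivatives and integrals can be exchanged. Using $\nabla_{\theta} \log Z(\theta) = - \mathbb{E}_{\boldsymbol{x} \sim p_{\theta}} \left[ \nabla_{\theta} E_{\theta} (\boldsymbol{x}) \right]$, one finds&lt;/p&gt;
\begin{equation}
\nabla_{\theta} \mathcal{L}_{\mathrm{ML}} (\theta) = \mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}} \left[ \nabla_{\theta} E_{\theta} (\boldsymbol{x}) \right] - \mathbb{E}_{\boldsymbol{x} \sim p_{\theta}} \left[ \nabla_{\theta} E_{\theta} (\boldsymbol{x}) \right]
\end{equation}&lt;p&gt;Under gradient descent, the first expectation lowers the energy of data samples, while the second one raises the energy of samples drawn from the model. The second term is the one that requires expensive sampling.&lt;/p&gt;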
&lt;h3 id="exactly-optimizing-modern-continuous-hopfield-networks"&gt;Exactly optimizing modern continuous Hopfield networks&lt;/h3&gt;
&lt;p&gt;So what about the system defined by the energy function \eqref{eq:energyfunc}? Let&amp;rsquo;s consider the stored patterns $\mathbf{X} \in \mathbb{R}^{d \times N}$ as the model parameters we want to optimise. The task for the model is then to try to memorise incoming state patterns $\boldsymbol{\xi} \in \mathbb{R}^{d}$ drawn from some data distribution $p_{\mathrm{data}}$ by deciding what kind of patterns to store. The partition function looks like&lt;/p&gt;
\begin{equation}
Z = \int \mathrm{d} \boldsymbol{\xi} \ \mathrm{e}^{-E(\boldsymbol{\xi})} = \int \mathrm{d} \boldsymbol{\xi} \ \mathrm{e}^{-\frac{1}{2} \boldsymbol{\xi}^T \boldsymbol{\xi}} \left( \sum_{i=1}^{N} \mathrm{e}^{ \boldsymbol{x}^{T}_{i} \cdot \boldsymbol{\xi} } \right)
\label{eq:zforcontinuoushopfield}
\end{equation}&lt;p&gt;which, because of the $\log$ in the &amp;ldquo;interaction term&amp;rdquo;, boils down to a sum of
&lt;/p&gt;
\begin{equation}
\begin{aligned}
Z = (2\pi)^{d/2} \sum_{i=1}^{N} \mathrm{e}^{ \frac{1}{2} \boldsymbol{x}_{i}^{T} \cdot \boldsymbol{x}_{i} }
\end{aligned}
\end{equation}&lt;p&gt;After taking the logarithm, we end up with the $\mathrm{logsumexp}$ operator:&lt;/p&gt;
\begin{equation}
\log Z = \frac{d}{2} \log \left( 2\pi \right) + \mathrm{logsumexp} \left( \frac{1}{2} \mathrm{diag} \left( \boldsymbol{X}^{T} \boldsymbol{X} \right) \right)
\end{equation}&lt;p&gt;where the $\mathrm{diag}$ operator is understood to turn the diagonal of its matrix argument into a vector. Plugging this expression into \eqref{eq:nll} leads to the following loss function for the matrix of stored patterns&lt;/p&gt;
\begin{align}
\mathcal{L}_{\mathrm{ML}} (\mathbf{X}) = &amp; \mathbb{E}_{\boldsymbol{\xi} \sim p_{\mathrm{data}}} \left[ \frac{1}{2} \boldsymbol{\xi}^T \boldsymbol{\xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right) \right] \nonumber \\\\
&amp; + \mathrm{logsumexp} \left( \frac{1}{2} \mathrm{diag} \left( \boldsymbol{X}^{T} \boldsymbol{X} \right) \right) + \frac{d}{2} \log \left( 2\pi \right)
\end{align}&lt;p&gt;and a gradient&lt;/p&gt;
\begin{align}
\nabla_{\mathbf{X}} \mathcal{L}_{\mathrm{ML}} (\mathbf{X}) = &amp; - \mathbb{E}_{\boldsymbol{\xi} \sim p_{\mathrm{data}}} \left[ \boldsymbol{\xi} \otimes \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right) \right] \nonumber \\\\
&amp; + \boldsymbol{X} \ \mathrm{softmax} \left( \frac{1}{2} \mathrm{diag} \left( \boldsymbol{X}^{T} \boldsymbol{X} \right) \right)
\end{align}&lt;p&gt;and an update with step size $\gamma$&lt;/p&gt;
\begin{align}
\mathbf{X}_{n+1} = \ \mathbf{X}_{n} &amp;+ \gamma \ \mathbb{E}_{\boldsymbol{\xi} \sim p_{\mathrm{data}}} \left[ \boldsymbol{\xi} \otimes \mathrm{softmax} \left( \boldsymbol{X}^T_{n} \boldsymbol{\xi} \right) \right] \nonumber \\\\
&amp; - \gamma \ \mathbf{X}_{n} \ \mathrm{softmax} \left( \frac{1}{2} \mathrm{diag} \left( \boldsymbol{X}^{T}_{n} \boldsymbol{X}_{n} \right) \right)
\end{align}&lt;p&gt;Let&amp;rsquo;s try to guess what this means for a single input state pattern. The first gradient term pushes all stored patterns towards the sample but weighted by a dot-product attention vector quantifying their overlap with the input pattern, similar to \eqref{eq:conthopfupdate} but in the other direction. The second gradient term comes from the partition function and acts as a regularizer by keeping the norms of the stored patterns in check. Regularization keeps pattern values within a reasonable range and pushes the system towards regions in parameter space with non-trivial small dot-product values.&lt;/p&gt;
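&lt;p&gt;Read column by column, the gradient above can be sketched for a single stored pattern $\boldsymbol{x}_{i}$. The weights $p_{i}(\boldsymbol{\xi})$ and $r_{i}$ are shorthand introduced here for the $i$-th components of the two softmax vectors:&lt;/p&gt;
\begin{equation}
\nabla_{\boldsymbol{x}_{i}} \mathcal{L}_{\mathrm{ML}} = - \mathbb{E}_{\boldsymbol{\xi} \sim p_{\mathrm{data}}} \left[ p_{i}(\boldsymbol{\xi}) \, \boldsymbol{\xi} \right] + r_{i} \, \boldsymbol{x}_{i}, \qquad p_{i}(\boldsymbol{\xi}) = \frac{\mathrm{e}^{\boldsymbol{x}_{i} \cdot \boldsymbol{\xi}}}{\sum_{j=1}^{N} \mathrm{e}^{\boldsymbol{x}_{j} \cdot \boldsymbol{\xi}}}, \qquad r_{i} = \frac{\mathrm{e}^{\frac{1}{2} \boldsymbol{x}_{i} \cdot \boldsymbol{x}_{i}}}{\sum_{j=1}^{N} \mathrm{e}^{\frac{1}{2} \boldsymbol{x}_{j} \cdot \boldsymbol{x}_{j}}}
\end{equation}&lt;p&gt;In this form, the second term is each stored pattern scaled by its own weight $r_{i}$, so the matrix-valued regularizer has columns $r_{i} \boldsymbol{x}_{i}$. A gradient descent step moves $\boldsymbol{x}_{i}$ towards the attention-weighted samples and shrinks it in proportion to $r_{i}$, which is largest for the stored patterns with the largest norms.&lt;/p&gt;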
&lt;h2 id="transformers-store-and-retrieve-context-dependent-patterns"&gt;Transformers store and retrieve context-dependent patterns&lt;/h2&gt;
&lt;p&gt;Making the leap from modern continuous Hopfield networks to the vanilla Transformer (self-)attention mechanism we encountered in
requires a few additional steps, as explained in detail in the
accompanying
.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;We want to act on multiple $d$-dimensional state patterns at the same time in order to retrieve multiple updated patterns in parallel:
\begin{align}
\boldsymbol{\xi} \in \mathbb{R}^{d} \to \boldsymbol{\Xi} = (\boldsymbol{\xi}_{1}, \ldots, \boldsymbol{\xi}_{S}) \in \mathbb{R}^{d \times S}
\end{align}
so that
\begin{align}
\boldsymbol{\Xi}_{n+1} = \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\Xi}_{n}\right) .
\end{align}
In practice, the number of state patterns $S$ is often taken to be equal to the number of stored patterns $N$.&lt;/li&gt;
&lt;li&gt;We want to map stored patterns $\mathbf{X}$ and state patterns $\boldsymbol{\Xi}$ respectively to &lt;em&gt;keys&lt;/em&gt; $\mathbf{K} \in \mathbb{R}^{N \times d}$ and &lt;em&gt;queries&lt;/em&gt; $\mathbf{Q} \in \mathbb{R}^{S \times d}$ in a common feature space using linear transformations $\mathbf{W_{K}}$ and $\mathbf{W_{Q}}$.&lt;/li&gt;
&lt;li&gt;We want to introduce another linear transformation $\mathbf{W_{V}}$ on stored patterns to transform them into &lt;em&gt;values&lt;/em&gt; $\mathbf{V} \in \mathbb{R}^{N \times d}$ appropriate for the keys&amp;rsquo; content.&lt;/li&gt;
&lt;li&gt;We want to modify the learning dynamics by decreasing the inverse temperature to $\beta = 1 / \sqrt{d}$, effectively making the $\mathrm{softmax}$ softer by increasing the temperature of the system&lt;sup id="fnref:15"&gt;&lt;a href="#fn:15" class="footnote-ref" role="doc-noteref"&gt;15&lt;/a&gt;&lt;/sup&gt;. Physically, this might correspond to warming up the system just enough to get out of the spin-glass phase while not introducing too much thermal noise&lt;sup id="fnref1:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The result is the update rule we stated without explanation in
:
&lt;/p&gt;
\begin{equation}
\mathbf{Q}^{\mathrm{updated}} = \mathrm{Attention}\left( \mathbf{Q}, \mathbf{K}, \mathbf{V} \right) = \mathrm{softmax} \left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d}} \right) \mathbf{V},
\label{eq:transformerattnupdate}
\end{equation}&lt;p&gt;
where the $\mathrm{softmax}$ acts row-wise. In practice, the vanilla Transformer module additionally wraps the above attention module in (1) residual connections to control the flow of gradients, (2) layer norms to control pattern normalisations and learning dynamics, and (3) a positional feed-forward network for additional model capacity.&lt;/p&gt;
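&lt;p&gt;The steps above can be checked numerically. The following is a minimal NumPy sketch, not code from the original post: the function names &lt;code&gt;hopfield_update&lt;/code&gt; and &lt;code&gt;attention&lt;/code&gt;, the toy sizes, and the choice of identity maps for $\mathbf{W_{Q}}$, $\mathbf{W_{K}}$, and $\mathbf{W_{V}}$ are all assumptions made for illustration. Under those assumptions, the multi-pattern Hopfield update with $\beta = 1 / \sqrt{d}$ should coincide with \eqref{eq:transformerattnupdate} up to a transpose.&lt;/p&gt;

```python
import numpy as np


def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def hopfield_update(X, Xi, beta=1.0):
    """One update of S state patterns Xi (d x S) against N stored patterns X (d x N)."""
    # The softmax normalises over the stored patterns, i.e. over rows of X^T Xi.
    return X @ softmax(beta * (X.T @ Xi), axis=0)


def attention(Q, K, V):
    """Scaled dot-product attention with a row-wise softmax."""
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d), axis=-1) @ V


rng = np.random.default_rng(0)
d, N, S = 8, 5, 3                 # toy sizes, chosen arbitrarily
X = rng.normal(size=(d, N))       # stored patterns as columns
Xi = rng.normal(size=(d, S))      # state patterns as columns

# Assumption: identity maps for W_Q, W_K, W_V, so queries, keys and values
# are just the transposed state and stored patterns.
Q, K, V = Xi.T, X.T, X.T
same = np.allclose(attention(Q, K, V), hopfield_update(X, Xi, beta=1 / np.sqrt(d)).T)
print(same)
```

&lt;p&gt;Replacing the identity maps with learned matrices changes what is stored and queried, but not the form of the update.&lt;/p&gt;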
&lt;h2 id="where-are-patterns-stored-in-a-transformer"&gt;Where are patterns stored in a Transformer?&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s try to digest the implications of these quite substantial changes. It&amp;rsquo;s useful to think of Transformer (self-)attention modules as dynamic pattern storage and retrieval systems. In modern continuous Hopfield networks, stored patterns are considered a given. However, in the Transformer (self-)attention module, patterns to be matched and retrieved are &lt;em&gt;dependent on inputs&lt;/em&gt; and &lt;em&gt;implicitly stored in the weights&lt;/em&gt; $\mathbf{W_{Q}}$, $\mathbf{W_{K}}$, and $\mathbf{W_{V}}$ of the linear transformations. In every layer, the module needs to learn how to map a set of inputs to patterns it wants to store (keys and values) as well as how to best retrieve them (queries). Within the same layer, dynamically generated queries are matched to keys within the same latent space. Between attention modules of neighboring layers, the non-linear activation function in the positional feed-forward network warps latent spaces.&lt;/p&gt;
&lt;!-- ## Transformer self-attention as energy-based models
For completeness, we can try to write down the energy function of the Transformer self-attention module. Starting from \eqref{eq:energyfunc}
Instead of stored patterns $X$ we considered fixed, the energy function
that is implicitly being optimised for by making the necessary substitutions in Eq [] :
\begin{equation}
E(\boldsymbol{\xi}; \boldsymbol{X}) = \frac{1}{2} \mathrm{diag} \left( \boldsymbol{\Xi}^T \boldsymbol{\Xi} \right) -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\Xi} \right),
\label{eq:transformerenergy}
\end{equation}
\begin{equation}
E(\boldsymbol{\xi}; \boldsymbol{X}) = \frac{1}{2} \mathrm{diag} \left( \boldsymbol{\Xi}^T \boldsymbol{\Xi} \right) -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\Xi} \right),
\label{eq:transformerenergy}
\end{equation}
Checking whether this transformed energy function still leads to a tractable Gaussian partition function (possibly involving the determinant of a sum of products of linear transformation matrices), is left as an exercise for the reader. --&gt;
&lt;h1 id="training-transformers"&gt;Training Transformers&lt;/h1&gt;
&lt;p&gt;Now that we are aware of an energy-based interpretation of dot-product (self-)attention, we can start hand-waving about what could be going on during the supervised training procedure of Transformer models and how energy-based models suggest a qualitatively different approach to improving attention mechanisms.&lt;/p&gt;
&lt;h2 id="pretraining-loss-functions"&gt;Pretraining loss functions&lt;/h2&gt;
&lt;p&gt;The goal of pretraining loss functions is to induce &lt;em&gt;useful&lt;/em&gt; data-dependent pattern storage and retrieval behavior. Pretraining strategies for Transformer-based language models rely on loss functions derived from auxiliary tasks to learn statistical patterns in natural language. Starting from almost identical model architectures, autoregressive models like GPT-3 leverage all their parameters to predict the next token in a sequence given previous tokens while autoencoding models like BERT try to reconstruct corrupted tokens. In both cases, the loss function is a cross-entropy loss involving predictions in the space of the model&amp;rsquo;s token vocabulary.&lt;/p&gt;
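&lt;p&gt;To make the two objectives concrete, here is a small NumPy sketch under stated assumptions: the random &lt;code&gt;logits&lt;/code&gt; stand in for model outputs, the sizes and the masked positions are hypothetical, and nothing here reproduces the actual GPT-3 or BERT training code. It only illustrates that both losses are cross-entropies over the vocabulary and differ in which positions and targets they use.&lt;/p&gt;

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean cross-entropy of integer targets under softmax(logits) over the vocabulary."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


rng = np.random.default_rng(0)
T, V = 6, 11                             # toy sequence length and vocabulary size
tokens = rng.integers(0, V, size=T)      # toy token ids
logits = rng.normal(size=(T, V))         # stand-in for model outputs, one row per position

# Autoregressive objective: position t predicts token t + 1.
ar_loss = cross_entropy(logits[:-1], tokens[1:])

# Autoencoding objective: only corrupted (masked) positions contribute.
masked = np.array([1, 4])                # hypothetical masked positions
mlm_loss = cross_entropy(logits[masked], tokens[masked])

# Sanity check: uniform predictions cost log(V) per token.
uniform_loss = cross_entropy(np.zeros((T, V)), tokens)
print(ar_loss, mlm_loss, uniform_loss)
```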
&lt;h2 id="stepping-through-the-transformer-implicit-energy-minimization"&gt;Stepping through the Transformer: implicit energy minimization&lt;/h2&gt;
&lt;p&gt;Although no energy function is &lt;em&gt;explicitly&lt;/em&gt; optimized during training&lt;sup id="fnref:16"&gt;&lt;a href="#fn:16" class="footnote-ref" role="doc-noteref"&gt;16&lt;/a&gt;&lt;/sup&gt;, let&amp;rsquo;s see how far we can push hand-wavy energy-based arguments by stepping through the forward and backward pass of a Transformer model. We have learned that the attention update \eqref{eq:transformerattnupdate} in every Transformer layer is actually a hidden gradient step. This trivial insight leads to a trio of trivial observations.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Trivial Observation #1:&lt;/strong&gt; &lt;em&gt;During training, the update step \eqref{eq:transformerattnupdate} of the attention mechanism in a Transformer layer acts as an inner-loop optimization step, minimizing an implicit energy function determined by the queries, keys, and values constructed from the output of the previous layer.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Trivial Observation #2:&lt;/strong&gt; &lt;em&gt;During the forward pass of a deep Transformer model, a nested hierarchy of energy functions is minimized.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Trivial Observation #3:&lt;/strong&gt; &lt;em&gt;During the backward pass of a deep Transformer model, the parameters of its attention modules get updated such that the inner-loop optimization steps conspire to pattern match queries to keys in such a way that the sequentially-updated final latent representations are useful for improving the loss.&lt;/em&gt;&lt;/p&gt;
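&lt;p&gt;The claim that the attention update is a hidden gradient step can be probed on a toy example. The sketch below assumes the single-pattern energy $E(\boldsymbol{\xi}) = \frac{1}{2} \boldsymbol{\xi}^T \boldsymbol{\xi} - \mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right)$ with fixed stored patterns and unit inverse temperature; the helper names and sizes are made up for illustration. Since the gradient of this energy is $\boldsymbol{\xi} - \boldsymbol{X} \ \mathrm{softmax} \left( \boldsymbol{X}^T \boldsymbol{\xi} \right)$, the update equals a gradient descent step with unit step size, and the recorded energies should not increase.&lt;/p&gt;

```python
import numpy as np


def logsumexp(z):
    """Numerically stable log(sum(exp(z))) for a 1-D array."""
    m = z.max()
    return m + np.log(np.exp(z - m).sum())


def energy(xi, X):
    """Assumed energy: 0.5 * xi.xi - logsumexp(X^T xi), dropping constants."""
    return 0.5 * xi @ xi - logsumexp(X.T @ xi)


def update(xi, X):
    """One retrieval step: xi_new = X softmax(X^T xi)."""
    z = X.T @ xi
    p = np.exp(z - z.max())
    return X @ (p / p.sum())


rng = np.random.default_rng(0)
X = rng.normal(size=(16, 10))    # 10 stored patterns of dimension 16 (toy sizes)
xi = rng.normal(size=16)         # a single state pattern

# The update is the same as xi - grad E(xi), with grad E(xi) = xi - update(xi, X).
grad = xi - update(xi, X)
is_gradient_step = np.allclose(xi - grad, update(xi, X))

energies = [energy(xi, X)]
for _ in range(5):
    xi = update(xi, X)
    energies.append(energy(xi, X))

print(is_gradient_step, np.round(energies, 4))
```

&lt;p&gt;In a Transformer, each layer performs only one such step, and the patterns themselves are regenerated from the previous layer&amp;rsquo;s output, which is what makes the hierarchy nested rather than a single descent.&lt;/p&gt;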
&lt;h2 id="meta-learning-and-few-shot-inference"&gt;Meta-learning and few-shot inference&lt;/h2&gt;
&lt;p&gt;Squinting our eyes, we can see traces of a &lt;em&gt;meta-learning&lt;/em&gt; problem: how to tune model parameters &amp;mdash; in particular the attention mechanisms&amp;rsquo; linear transformation matrices &amp;mdash; such that applying a sequence of one-step attention updates to sets of input patterns converges to representations useful for minimizing the (meta-)loss function. Learnable modules of a differentiable program can of course often be considered part of a larger meta-learning setup. But what this point of view suggests is that confining the one-step inner-loop update to a simple associative memory pattern lookup might be quite restrictive.&lt;/p&gt;
&lt;p&gt;Yet even with a simple dense associative memory, OpenAI&amp;rsquo;s paper
showed that large-capacity models like GPT-3 already exhibit quite impressive meta-learning capabilities. The energy-based perspective provides a naive yet attractive explanation for this phenomenon. At inference time, the few-shot demonstrations, which make up the initial part of a few-shot learning query, condition the sequential generation process by providing basins of attraction in the energy landscape for other energy minimization steps to be pulled towards. &lt;em&gt;The GPT-3 model is memorizing to the extent the demonstrations match patterns seen during training and generalizing within the possibilities of the rudimentary attention dynamics of the simple underlying energy functions.&lt;/em&gt;&lt;/p&gt;
&lt;h1 id="beyond-dot-product-attention"&gt;Beyond dot-product attention&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s conclude this post with two related thoughts inspired by an energy-based perspective on current attention architectures: attention dynamics and modeling very long sequences.&lt;/p&gt;
&lt;h2 id="attention-dynamics-embracing-collective-phenomena"&gt;Attention dynamics: embracing collective phenomena&lt;/h2&gt;
&lt;p&gt;We have seen that the energy function of a modern continuous Hopfield network \eqref{eq:energyfunc} is rather uninspiring from a physics perspective. Theoretically, the exponential storage and efficient retrieval of patterns is obtained by burning deep valleys into the energy landscape around stored patterns (keys) for neighbouring state patterns (queries) to quickly roll into. In practice, the authors of
observed three kinds of fixed-point behavior in a pretrained BERT model: (1) global fixed points averaging over all stored patterns, (2) metastable states averaging over a subset of stored patterns, and (3) fixed points returning a single, well-separated stored pattern.&lt;/p&gt;
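&lt;p&gt;The three regimes can be mimicked on a hand-built toy example. This is only a sketch under assumptions: the patterns are chosen by hand, a single update step stands in for a converged fixed point, and the inverse temperature &lt;code&gt;beta&lt;/code&gt; is used as a convenient knob for how sharply the softmax concentrates. It says nothing about what happens inside a pretrained BERT model.&lt;/p&gt;

```python
import numpy as np


def retrieve(X, xi, beta):
    """One update xi_new = X softmax(beta * X^T xi) with a stable softmax."""
    z = beta * (X.T @ xi)
    p = np.exp(z - z.max())
    return X @ (p / p.sum())


# Hand-picked toy patterns (columns): the first two are close, the others far apart.
X = np.array([
    [2.0, 2.0, 0.0, 0.0],
    [0.0, 0.2, 0.0, 0.0],
    [0.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 2.0],
])
xi = np.array([2.0, 0.1, 0.0, 0.0])   # query sitting between the two close patterns

global_avg = retrieve(X, xi, beta=1e-3)    # close to the mean of all four patterns
metastable = retrieve(X, xi, beta=2.0)     # close to the mean of the two close patterns
single = retrieve(X, xi, beta=500.0)       # close to the second pattern alone

print(np.round(global_avg, 3))
print(np.round(metastable, 3))
print(np.round(single, 3))
```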
&lt;p&gt;What does this tell us? Assuming the attention updates converge faithfully during training, the linear maps turning input vectors into queries, keys, and values can become bottlenecks in terms of being able to separate patterns and organise the energy landscape. Additionally, the lack of interactions among patterns and the decoupled dot-product overlap between queries and keys put considerable limits on how the network can process information. In practice, this is being partially addressed by using multiple attention heads (see
), but this solution does not feel satisfactory.&lt;/p&gt;
&lt;h2 id="why-very-long-sequences-should-not-be-needed"&gt;Why very long sequences should not be needed&lt;/h2&gt;
&lt;p&gt;Recurrent neural networks try to compress patterns in a single hidden state via sequential propagation but often fail to do so and forget stuff along the way. Transformers bake patterns into a hierarchical energy landscape but focus on a fixed-length context window to store and retrieve patterns. As we&amp;rsquo;ve seen in
, a lot of research on improving Transformers focuses on alleviating the $\mathcal{O}(N^2)$ bottleneck of the attention computation with the implicit goal of scaling to longer sequences and enabling larger context windows.&lt;/p&gt;
&lt;p&gt;But very long sequences should not be needed if patterns are allowed to talk to each other. A model should not need all of the world as context if patterns and emergent concepts can be connected. It&amp;rsquo;s definitely worthwhile to try to reduce the computational complexity of current attention architectures, but it might be far more valuable to swap the simple energy-based model \eqref{eq:energyfunc} for more interesting energy-based models. Why not dust off the old unrestricted Boltzmann machine once again? Or experiment with any one of a century&amp;rsquo;s worth of physics models? Not to train them explicitly, but have them serve as implicit models underlying more intricate attention mechanisms, mediated by (local) interactions among patterns. Naturally, after so much hand-waving, our journey has to end here.&lt;/p&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;Even if attention turns out to &lt;em&gt;not&lt;/em&gt; be all we need, (self-)attention modules have established themselves as highly parallelizable neural network building blocks capable of dynamically routing information based on context. We have seen that dot-product attention modules in Transformer models work by encoding high-dimensional patterns into the landscapes of simple energy functions, enabling fast pattern storage and retrieval. During training, these landscapes are sculpted to accommodate statistical patterns found in data by hierarchically matching and combining latent pattern representations through a sequence of implicit energy function minimizations.&lt;/p&gt;
&lt;p&gt;We argued that an energy-based perspective on attention provides an intuitive explanation of meta-learning capabilities of large-capacity language models and encourages the exploration of qualitatively different attention mechanisms for pattern storage and retrieval. Rather than naively scaling the current generation of Transformers, it might be more rewarding to scale learning itself by exploring more powerful, expressive, and computationally efficient attention mechanisms, guided by energy-based models. Perhaps we should consider looking at neural networks again like John Hopfield already did in 1982: &lt;em&gt;physical systems with emergent collective computational abilities&lt;/em&gt;.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2020energyattention,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {An Energy-Based Perspective on Attention Mechanisms in Transformers},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2020},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {December},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/an-energy-based-perspective-on-attention-mechanisms-in-transformers/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin,
(2017)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;em&gt;Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;em&gt;Johannes Brandstetter,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;em&gt;Johannes Brandstetter and Hubert Ramsauer,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;If you have only just joined the attention revolution, there are a lot of great resources out there to get you started. Yannic Kilcher provides a great introduction in his
. The
presented at
contain a thorough and visually appealing introduction to attention-based models. Because code is usually more to the point than papers that need to sell themselves, I highly recommend Phil Wang&amp;rsquo;s
showcasing some of the latest models and techniques.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;&lt;em&gt;John Hopfield,
(1982)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;&lt;em&gt;Alejandro Pozas-Kerstjens, Gorka Muñoz-Gil, Miguel Ángel García-March, Antonio Acín, Maciej Lewenstein, Przemysław R. Grzybowski,
(2019)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2016)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;&lt;em&gt;Mete Demircigil, Judith Heusel, Matthias Löwe, Sven Upgang, and Franck Vermet,
(2017)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;A physicist might consider these continuous patterns spin configurations of the degrees of freedom in a vector spin model where the internal dimension $D \sim 10^2-10^4$ is much bigger than familiar small-$D$ cases like the
or the
but much smaller than infinity.&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;&lt;em&gt;Yann LeCun, Sumit Chopra, Raia Hadsell, Marc&amp;rsquo;Aurelio Ranzato, and Fu Jie Huang,
(2006)&lt;/em&gt; and &lt;em&gt;Yann LeCun and Alfredo Canziani,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;&lt;em&gt;Pankaj Mehta, Marin Bukov, Ching-Hao Wang, Alexandre G.R. Day, Clint Richardson, Charles K. Fisher, and David J. Schwab,
(2019)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;The generator in a Generative Adversarial Network (GAN) setup can be considered a clever way to generate negative samples for the implicit energy function optimization taking place in the discriminator.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:15"&gt;
&lt;p&gt;As we have seen in
, the naive interpretation of $\beta$ as &lt;em&gt;the&lt;/em&gt; effective inverse temperature is tenuous in practice given the influence of the surrounding layer normalisation modules.&amp;#160;&lt;a href="#fnref:15" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:16"&gt;
&lt;p&gt;The implicitly defined energy functions in Transformer layers are not optimized directly because they arguably do not provide a meaningful training signal on their own. Verifying whether this is true or not could make for an interesting experiment.&amp;#160;&lt;a href="#fnref:16" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>