<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Boltzmann Machine | mcbal</title><link>https://mcbal.github.io/tags/boltzmann-machine/</link><atom:link href="https://mcbal.github.io/tags/boltzmann-machine/index.xml" rel="self" type="application/rss+xml"/><description>Boltzmann Machine</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-gb</language><lastBuildDate>Tue, 23 Nov 2021 12:17:17 +0100</lastBuildDate><image><url>https://mcbal.github.io/media/icon.svg</url><title>Boltzmann Machine</title><link>https://mcbal.github.io/tags/boltzmann-machine/</link></image><item><title>Transformers Are Secretly Collectives of Spin Systems</title><link>https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/</link><pubDate>Tue, 23 Nov 2021 12:17:17 +0100</pubDate><guid>https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (April 2023):&lt;/strong&gt; Consider reading the follow-up post,
where we continue building on the intuition of probing a spin system to engineer its collective response but drop the assumption of symmetric coupling matrices by shifting focus from equilibrium free energies to dynamical mean-field approximations of non-equilibrium vector-spin models.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;In this post, we try to distill a unifying perspective out of ideas developed in a series of longer posts on understanding transformers as physical systems:&lt;/p&gt;
&lt;ul&gt;
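&lt;li&gt;Deep Implicit Attention
&lt;/li&gt;
&lt;li&gt;&lt;a href="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/"&gt;Transformers from Spin Models: Approximate Free Energy Minimization&lt;/a&gt;
&lt;/li&gt;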
&lt;/ul&gt;
&lt;p&gt;We argue that a blueprint of the neural-network architecture of the archetypical transformer module can be derived from the structure of physical spin systems familiar from classical statistical mechanics. More specifically, we claim that the forward pass of transformer modules maps onto computing magnetizations in vector-spin models in response to incoming data. We imagine transformers as collectives of differentiable spin systems whose behavior can be shaped through training.&lt;/p&gt;
&lt;h1 id="where-does-the-transformer-module-architecture-come-from"&gt;Where does the transformer module architecture come from?&lt;/h1&gt;
&lt;p&gt;Taking a bird&amp;rsquo;s-eye view of the ever-growing zoo of transformer architectures in natural language processing and computer vision suggests that the design pattern introduced in &lt;em&gt;Attention Is All You Need
&lt;/em&gt;&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; is still dominant. Almost all architectural variations of transformer modules published in the last four years have stuck to a successful combination of residual connections, an attention-like operation (token-mixing), normalization layers, and a feed-forward-like operation (channel-mixing).&lt;/p&gt;
&lt;p&gt;&lt;a href="https://arxiv.org/abs/2111.11418" target=_blank&gt;&lt;img src="metaformer.png" alt="MetaFormer architecture comparison" width="400px"/&gt;&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;Recent work like &lt;em&gt;MetaFormer is Actually What You Need for Vision
&lt;/em&gt;&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; appropriately shifts focus to the high-level architecture of the transformer module and argues that its full structure, rather than just the token-mixing attention operation, is essential for transformers to achieve competitive performance.&lt;/p&gt;
&lt;p&gt;So where does this archetypical design pattern come from? Why does it seem to stick around? Is there any physical intuition behind its structure?&lt;/p&gt;
&lt;h1 id="deriving-attention-from-energy-functions-only-gets-you-so-far"&gt;Deriving attention from energy functions only gets you so far&lt;/h1&gt;
&lt;p&gt;Recent papers like &lt;em&gt;Hopfield Networks is All You Need
&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;&lt;/em&gt; and &lt;em&gt;Large Associative Memory Problem in Neurobiology and Machine Learning
&lt;/em&gt;&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt; have looked for physical intuition behind attention mechanisms using an energy-based perspective
phrased in terms of modern continuous Hopfield networks. The main idea is to derive the softmax-attention update rule&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}' = \text{softmax}\left( \frac{\boldsymbol{Q} \boldsymbol{K}^T}{\sqrt{d}} \right) \boldsymbol{K}
\end{equation}&lt;p&gt;by taking a large gradient descent update step using the derivative with respect to input queries $\boldsymbol{Q}$ of some judiciously chosen energy function&lt;/p&gt;
\begin{equation}
E = \frac{1}{2} \boldsymbol{Q} \boldsymbol{Q}^T -\mathrm{logsumexp} \left( \frac{\boldsymbol{Q} \boldsymbol{K}^T}{\sqrt{d}} \right). \label{eq:logsumexp}
\end{equation}&lt;p&gt;In this way, vanilla softmax attention can be recast as taking a single large gradient-descent step on this energy function. The energy landscape defined by Eq. \eqref{eq:logsumexp} implements an associative memory system for storing and retrieving vector patterns, where queries flow towards valleys associated with their nearest keys (see the post linked from the figure below):&lt;/p&gt;
&lt;p&gt;&lt;a href="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/" target=_blank&gt;&lt;img src="landscape.png" alt="Logsumexp energy function landscape" width="300px"/&gt;&lt;/a&gt;&lt;/p&gt;
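&lt;p&gt;As a rough numerical illustration of this argument, the sketch below treats a single query vector and uses the closely related energy $E(\boldsymbol{q}) = \frac{1}{2} \boldsymbol{q} \cdot \boldsymbol{q} - \beta^{-1} \mathrm{logsumexp}(\beta \boldsymbol{K} \boldsymbol{q})$. The explicit inverse temperature $\beta$ (playing the role of $1/\sqrt{d}$), the function names, and the toy keys are assumptions made for illustration only. With this choice the gradient is $\boldsymbol{q} - \mathrm{softmax}(\beta \boldsymbol{K} \boldsymbol{q}) \boldsymbol{K}$, so a gradient-descent step of size one lands exactly on the softmax-weighted sum of keys.&lt;/p&gt;

```python
import math


def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def energy(query, keys, beta):
    """E(q) = 0.5 * q.q - (1 / beta) * logsumexp(beta * K q)."""
    scores = [beta * dot(key, query) for key in keys]
    top = max(scores)
    logsumexp = top + math.log(sum(math.exp(s - top) for s in scores))
    return 0.5 * dot(query, query) - logsumexp / beta


def energy_gradient(query, keys, beta):
    """Analytic gradient of the energy: q - softmax(beta * K q) K."""
    weights = softmax([beta * dot(key, query) for key in keys])
    retrieved = [sum(w * key[a] for w, key in zip(weights, keys))
                 for a in range(len(query))]
    return [q - r for q, r in zip(query, retrieved)]


def attention_update(query, keys, beta, step=1.0):
    """One gradient-descent step; step=1.0 reproduces softmax attention."""
    grad = energy_gradient(query, keys, beta)
    return [q - step * g for q, g in zip(query, grad)]


# Toy data (hypothetical): three stored keys and one query in two dimensions.
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
query = [0.9, 0.2]
updated = attention_update(query, keys, beta=4.0)
print(updated)  # pulled towards the nearest key [1.0, 0.0]
```

&lt;p&gt;Choosing a step size other than one interpolates between the old query and the attention output, which is one way to read the residual-like behaviour of repeated updates.&lt;/p&gt;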
&lt;p&gt;But there is more to transformer modules than just attention. In practice, we know that residual connections, normalization layers, and feed-forward layers are all essential to achieve good empirical performance.&lt;/p&gt;
&lt;p&gt;Can we generalize this physical intuition of taking derivatives with respect to an energy function to recover the full transformer module? Yes, we can. But we have to take a step back from energy functions and focus on their underlying physical systems instead.&lt;/p&gt;
&lt;h1 id="back-to-the-roots-physical-spin-systems-and-vector-spin-models"&gt;Back to the roots: physical spin systems and vector-spin models&lt;/h1&gt;
&lt;p&gt;Energy functions in classical statistical mechanics are succinct descriptions encoding interactions and constraints in physical systems. Spin systems are prototypical physical systems which often serve as toy models for all kinds of phenomena&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The Ising model
is a simple toy model describing a classical binary spin system with local spin degrees of freedom at every site pointing either up or down. The energy function of the binary random Ising model for $N$ spins in the presence of a site-dependent external magnetic field is given by&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \sigma_{i} \sigma_{j} - \sum_{i=1}^{N} h_{i} \sigma_{i}, \label{eq:binaryrandomising}
\end{equation}&lt;p&gt;where the $J_{ij}$ encode coupling strengths between all pairs of spins and the external magnetic fields $h_{i}$ act as biases by providing a preferential value of alignment at every site. The model defined by \eqref{eq:binaryrandomising} is also known as a Boltzmann machine or Hopfield network. A cartoon of this model looks like a graph of little arrows that are pairwise coupled&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;img src="binary_ising.png" alt="Random Ising model configuration with binary spins" width="200px"/&gt;
&lt;p&gt;At thermal equilibrium, the Boltzmann probability distribution $e^{-\beta E\left( \sigma \right)} / Z$ reflects what patterns of up-down spins, or &lt;em&gt;spin configurations&lt;/em&gt;, are preferred. The partition function $Z = \sum_{\sigma} e^{-\beta E\left( \sigma \right)}$ of a spin system is not only a normalization constant but also a magical object relating the microscopic world of fluctuating spins to thermodynamic, observable quantities via the free energy $F = - \beta^{-1} \log Z$. Even for simple spin systems, computing partition functions by summing over all possible configurations is a shockingly hard thing to do in most scenarios.&lt;/p&gt;
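&lt;p&gt;For a handful of spins, the sum over configurations can still be done by brute force. The following minimal sketch enumerates all $2^{N}$ configurations of the binary model in Eq. \eqref{eq:binaryrandomising} to obtain $Z$, $F$, and the Boltzmann probabilities. The function names and the toy couplings and fields are hypothetical, and the exponential cost in $N$ is exactly why this stops being an option for larger systems.&lt;/p&gt;

```python
import itertools
import math


def ising_energy(spins, couplings, fields):
    """E = -sum_ij J_ij s_i s_j - sum_i h_i s_i, double sum over all i, j."""
    n = len(spins)
    pair = sum(couplings[i][j] * spins[i] * spins[j]
               for i in range(n) for j in range(n))
    field = sum(h * s for h, s in zip(fields, spins))
    return -pair - field


def boltzmann_distribution(couplings, fields, beta):
    """Return (Z, F, probabilities) by summing over all 2**N configurations."""
    configs = list(itertools.product((-1, 1), repeat=len(fields)))
    weights = [math.exp(-beta * ising_energy(s, couplings, fields))
               for s in configs]
    z = sum(weights)
    free_energy = -math.log(z) / beta
    probabilities = {s: w / z for s, w in zip(configs, weights)}
    return z, free_energy, probabilities


# Toy parameters (hypothetical): three spins with mixed-sign couplings.
couplings = [[0.0, 0.5, -0.2],
             [0.5, 0.0, 0.3],
             [-0.2, 0.3, 0.0]]
fields = [0.1, -0.4, 0.2]
z, f, probs = boltzmann_distribution(couplings, fields, beta=1.0)
print("Z =", z, "F =", f)
print("most likely configuration:", max(probs, key=probs.get))
```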
&lt;p&gt;Binary spin models are nice but rarely excite machine learning practitioners nowadays. Modern neural networks like transformers act on sequences of vectors like token embeddings or image patches. Instead of abandoning spin models altogether, we could consider &lt;em&gt;vector-spin models&lt;/em&gt;. Replacing binary degrees of freedom with $d$-dimensional vector degrees of freedom, we can define a spin-model energy function&lt;/p&gt;
\begin{align}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}, \label{eq:vectorrandomising}
\end{align}&lt;p&gt;where the scalar products have turned into dot products. Models of this form first popped up in 1960s statistical mechanics literature as classical $n$-vector models. They also appear in recent studies on higher-dimensional generalizations of spin glass models&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;img src="vector_ising.png" alt="Random Ising model configuration with vector spins" width="200px"/&gt;
&lt;p&gt;Now how can we relate vector-spin systems like Eq. \eqref{eq:vectorrandomising} to modern neural networks?&lt;/p&gt;
&lt;h1 id="why-dont-we-just-probe-a-vector-spin-system-with-data"&gt;Why don’t we just probe a vector-spin system with data?&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s pursue an intuitive idea. Imagine we want to expose our vector-spin system Eq. \eqref{eq:vectorrandomising} to a sequence of vector data. We can do this by having the sequence act as the spin system&amp;rsquo;s external magnetic field $(\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. We would then like to observe how the spin system responds to this particular environment of patterns.&lt;/p&gt;
&lt;p&gt;If all of the steps in the computation of the spin system&amp;rsquo;s responses can be implemented in a differentiable way, we should be able to engineer its collective behavior by optimizing the coupling parameters to better respond to future incoming data. We propose to observe spin-system responses in terms of &lt;em&gt;magnetizations computed from free energies&lt;/em&gt;.&lt;/p&gt;
&lt;h1 id="a-slice-of-statistical-mechanics-magnetizations-and-free-energies"&gt;A slice of statistical mechanics: magnetizations and free energies&lt;/h1&gt;
&lt;p&gt;For ease of notation, let&amp;rsquo;s call the model parameters $\theta \equiv \{ J_{ij} \}$, the spins $\sigma \equiv \{ \boldsymbol{\sigma}_{i} \}$, and the external magnetic fields $h \equiv (\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. We can then schematically write our spin system&amp;rsquo;s partition function as&lt;/p&gt;
\begin{align}
Z_{\theta} \left( h \right) = \int \mathrm{d} \sigma \ \mathrm{e}^{ - \beta E_{\theta}\left( \sigma, h \right) } \label{eq:partfun}
\end{align}&lt;p&gt;and the corresponding free energy as $F_{\theta} \left( h \right) = - \beta^{-1} \log Z_{\theta} \left( h \right)$.&lt;/p&gt;
&lt;p&gt;Magnetizations are responses of our spin system to the external magnetic field imposed by $(\boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N})$. From standard thermodynamics, we know that we can calculate magnetizations from the free energy by differentiating with respect to the external field&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} = - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}} = \langle \boldsymbol{\sigma}_{i} \rangle , \label{eq:sigma}
\end{align}&lt;p&gt;which, in this case, boils down to calculating spin expectation values. The magnetization for every site depends on the couplings and, through the couplings between spins, on the values of the external field at all sites. Magnetizations reveal how spins will collectively tend to align themselves when we place the spin system in an environment of patterns.&lt;/p&gt;
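&lt;p&gt;The identity $\boldsymbol{m}_{i} = \langle \boldsymbol{\sigma}_{i} \rangle$ is easy to check numerically in a setting where everything can be enumerated. The sketch below is an assumption-laden toy: it uses binary spins instead of vector spins, fixed couplings, and central finite differences in place of automatic differentiation. It compares $-\mathrm{d}F/\mathrm{d}h_{i}$ with the spin expectation value computed directly from the Boltzmann distribution.&lt;/p&gt;

```python
import itertools
import math


def energy(spins, couplings, fields):
    """E = -sum_ij J_ij s_i s_j - sum_i h_i s_i for binary spins."""
    n = len(spins)
    pair = sum(couplings[i][j] * spins[i] * spins[j]
               for i in range(n) for j in range(n))
    return -pair - sum(h * s for h, s in zip(fields, spins))


def free_energy(couplings, fields, beta):
    """F = -(1 / beta) * log Z, with Z summed over all configurations."""
    configs = itertools.product((-1, 1), repeat=len(fields))
    z = sum(math.exp(-beta * energy(s, couplings, fields)) for s in configs)
    return -math.log(z) / beta


def spin_expectations(couplings, fields, beta):
    """Expectation of every spin under the Boltzmann distribution."""
    n = len(fields)
    configs = list(itertools.product((-1, 1), repeat=n))
    weights = [math.exp(-beta * energy(s, couplings, fields)) for s in configs]
    z = sum(weights)
    return [sum(w * s[i] for w, s in zip(weights, configs)) / z
            for i in range(n)]


def magnetizations(couplings, fields, beta, eps=1e-5):
    """m_i = -dF/dh_i, estimated with central finite differences."""
    result = []
    for i in range(len(fields)):
        up = list(fields)
        up[i] += eps
        down = list(fields)
        down[i] -= eps
        slope = (free_energy(couplings, up, beta)
                 - free_energy(couplings, down, beta)) / (2 * eps)
        result.append(-slope)
    return result


# Toy parameters (hypothetical).
couplings = [[0.0, 0.5, -0.2],
             [0.5, 0.0, 0.3],
             [-0.2, 0.3, 0.0]]
fields = [0.1, -0.4, 0.2]
print(magnetizations(couplings, fields, beta=1.0))
print(spin_expectations(couplings, fields, beta=1.0))
```

&lt;p&gt;The two printed lists should agree up to finite-difference error. Changing a single entry of the field shifts every magnetization, which is the collective response described above.&lt;/p&gt;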
&lt;p&gt;Before we move on, we have to account for one more complication. If we want to draw a correspondence between transformer modules and vector-spin systems, we will have to allow for couplings that depend on the external magnetic field. For example, the attention matrix in vanilla transformers looks something like&lt;/p&gt;
\begin{equation}
J_{ij} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right) = \left[\mathrm{softmax}\left( \frac{\boldsymbol{H} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{H}^{T}}{\sqrt{d}} \right)\right]_{ij}, \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;where the matrix $\boldsymbol{H}$ denotes the stack of external magnetic field vectors. The interactions between spins are determined dynamically based on the inputs. From a physics perspective, these &amp;ldquo;amortized&amp;rdquo; couplings are very weird and highly unusual, but such is the transformer.&lt;/p&gt;
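&lt;p&gt;To make the input dependence concrete, here is a minimal sketch of Eq. \eqref{eq:softmaxcouplings} for a tiny stack of field vectors. The helper names and the toy matrices are hypothetical, and the scale $\sqrt{d}$ is taken to be the square root of the projected dimension (the number of columns of $\boldsymbol{W}_{\boldsymbol{Q}}$), which is an assumption.&lt;/p&gt;

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    top = max(row)
    weights = [math.exp(x - top) for x in row]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_couplings(fields, w_q, w_k):
    """J(H) = softmax(H W_Q W_K^T H^T / sqrt(d)), applied row by row."""
    queries = matmul(fields, w_q)
    keys = matmul(fields, w_k)
    scale = math.sqrt(len(w_q[0]))
    scores = matmul(queries, transpose(keys))
    return [softmax([s / scale for s in row]) for row in scores]


# Toy inputs (hypothetical): N = 3 field vectors of dimension 2.
fields = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.3]]
w_q = [[0.8, -0.1], [0.2, 0.6]]
w_k = [[0.5, 0.3], [-0.4, 0.9]]
couplings = softmax_couplings(fields, w_q, w_k)
for row in couplings:
    print(row)
```

&lt;p&gt;Every row sums to one, and the resulting matrix is in general not symmetric, which is part of what makes these couplings unusual from a physics perspective.&lt;/p&gt;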
&lt;p&gt;The potential dependency of the couplings on the external field changes the magnetization of Eq. \eqref{eq:sigma} to an expression of the form&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} &amp;= - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}} \nonumber \\ &amp;= \langle \boldsymbol{\sigma}_{i} \rangle + \sum_{m,n} \langle \boldsymbol{\sigma}_{m} \cdot \boldsymbol{\sigma}_{n} \rangle \frac{\partial J_{mn} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right) }{ \partial \boldsymbol{h}_{i} } , \label{eq:sigmaweird}
\end{align}&lt;p&gt;where two-point correlation functions are seen to act as weights for the coupling contributions&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;. In practice, we should of course let an automatic differentiation framework keep track of dependencies so that we can get away with simply computing&lt;/p&gt;
\begin{align}
\boldsymbol{m}_{i} = - \frac{\mathrm{d} F_{\theta} \left( \boldsymbol{h}_{1}, \boldsymbol{h}_{2}, \ldots, \boldsymbol{h}_{N} \right)}{\mathrm{d} \boldsymbol{h}_{i}}, \label{eq:magnetization}
\end{align}&lt;p&gt;assuming we have a differentiable expression for the (approximate) free energy available.&lt;/p&gt;
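&lt;p&gt;The structure of Eq. \eqref{eq:sigmaweird} can be checked on a toy system small enough to enumerate. The sketch below is a scalar, binary-spin analogue and not the vector-spin model itself: the coupling function &lt;code&gt;toy_couplings&lt;/code&gt; is a hypothetical row-wise softmax of products of fields, and finite differences stand in for automatic differentiation. It compares the total derivative $-\mathrm{d}F/\mathrm{d}h_{i}$, with the couplings recomputed from the fields, against the explicit sum of the spin expectation value and the correlation-weighted coupling derivatives.&lt;/p&gt;

```python
import itertools
import math


def toy_couplings(fields, weight=0.7):
    """Hypothetical input-dependent couplings: softmax over n of weight * h_m * h_n."""
    rows = []
    for h_m in fields:
        scores = [weight * h_m * h_n for h_n in fields]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


def energy(spins, couplings, fields):
    """E = -sum_mn J_mn s_m s_n - sum_i h_i s_i for binary spins."""
    n = len(spins)
    pair = sum(couplings[m][k] * spins[m] * spins[k]
               for m in range(n) for k in range(n))
    return -pair - sum(h * s for h, s in zip(fields, spins))


def free_energy(fields, beta):
    """F(h) with couplings recomputed from h, so F tracks the full dependence."""
    couplings = toy_couplings(fields)
    configs = itertools.product((-1, 1), repeat=len(fields))
    z = sum(math.exp(-beta * energy(s, couplings, fields)) for s in configs)
    return -math.log(z) / beta


def shifted(fields, i, delta):
    """Copy of the fields with entry i shifted by delta."""
    out = list(fields)
    out[i] += delta
    return out


def total_derivative(fields, beta, i, eps=1e-5):
    """Left-hand side: m_i = -dF/dh_i by central finite differences."""
    return -(free_energy(shifted(fields, i, eps), beta)
             - free_energy(shifted(fields, i, -eps), beta)) / (2 * eps)


def explicit_formula(fields, beta, i, eps=1e-5):
    """Right-hand side: E[s_i] + sum_mn E[s_m s_n] * dJ_mn/dh_i."""
    n = len(fields)
    couplings = toy_couplings(fields)
    configs = list(itertools.product((-1, 1), repeat=n))
    weights = [math.exp(-beta * energy(s, couplings, fields)) for s in configs]
    z = sum(weights)
    j_up = toy_couplings(shifted(fields, i, eps))
    j_down = toy_couplings(shifted(fields, i, -eps))
    value = sum(w * s[i] for w, s in zip(weights, configs)) / z
    for m in range(n):
        for k in range(n):
            corr = sum(w * s[m] * s[k] for w, s in zip(weights, configs)) / z
            value += corr * (j_up[m][k] - j_down[m][k]) / (2 * eps)
    return value


fields = [0.3, -0.8, 0.5]
for i in range(len(fields)):
    print(total_derivative(fields, 1.0, i), explicit_formula(fields, 1.0, i))
```

&lt;p&gt;Both columns should match up to finite-difference error, which illustrates why differentiating a single free-energy expression is enough: the correlation-weighted terms come along automatically.&lt;/p&gt;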
&lt;h1 id="turning-a-differentiable-spin-system-into-a-neural-network"&gt;Turning a differentiable spin system into a neural network&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s now use the ingredients introduced above to construct a neural network module which wraps around a vector-spin system. Given the energy function Eq. \eqref{eq:vectorrandomising} and the free energy $F_{\theta} \left( h \right) = - \beta^{-1} \log \int \mathrm{d} \sigma \ \mathrm{e}^{ - \beta E_{\theta}\left( \sigma, h \right) }$, we let incoming data play the role of the external magnetic field and return magnetizations in response.&lt;/p&gt;
&lt;img src="spinmodule_new.png" alt="Spin system as a neural network" width="600px"/&gt;
&lt;p&gt;Nice. But didn&amp;rsquo;t we mention before that partition functions (and hence free energies and thus magnetizations) are shockingly hard to compute? Why introduce all these formal expressions if we cannot compute anything?&lt;/p&gt;
&lt;p&gt;Looking back at statistical mechanics papers from the 1950s-1970s, it turns out that physicists have already developed several tricks and approximation methods that can be applied to deal with vector-spin systems. Computational evidence that the partition function approach outlined above &lt;em&gt;is&lt;/em&gt; feasible for vector-spin systems can be found in Deep Implicit Attention
(below, left) and Transformers from Spin Models: Approximate Free Energy Minimization
(below, right).&lt;/p&gt;
&lt;img src="arch_dia_afem_new.png" alt="Deep implicit attention and approximate free-energy minimization" width="600px"/&gt;
&lt;p&gt;In these examples, approximations of the partition function Eq. \eqref{eq:partfun} were obtained following, respectively, a mean-field theory and a steepest-descent approach. Our implementations
of both approaches rely internally on implicit layers
to ensure that fixed-point calculations and root-solving steps are efficiently differentiable.&lt;/p&gt;
&lt;h1 id="an-exercise-in-squinting-recognizing-the-transformer-module"&gt;An exercise in squinting: recognizing the transformer module&lt;/h1&gt;
&lt;p&gt;Computing magnetizations according to Eq. \eqref{eq:magnetization} from the (approximate) free energies obtained in Deep Implicit Attention
and Transformers from Spin Models: Approximate Free Energy Minimization
reveals a high-level structure that is surprisingly familiar: a pattern of residual connections, token-mixing, normalization, and channel-mixing. Approaching the crux from the other direction, we argue that transformer modules react to inputs by implementing particular approximations to the general magnetization response Eq. \eqref{eq:sigmaweird}.&lt;/p&gt;
&lt;p&gt;Residual connections are proportional to the inputs and arise from the presence of the external magnetic field. Token-mixing contributions emerge from the coupling terms in the energy function and mix inputs without acting on the local vector-spin dimension. Normalization follows from requiring that the energy of the spin system remain linearly proportional to the number of lattice sites and from normalizing the external magnetic field vectors. Channel-mixing contributions include terms in the magnetization that can be applied locally, like Onsager self-correction terms in mean-field approaches or (approximations to) contributions coming from input-dependent couplings in Eq. \eqref{eq:sigmaweird}.&lt;/p&gt;
&lt;p&gt;Taken together, these observations suggest that we can picture the forward pass of a transformer module as a wrapper around a vector-spin system: module inputs are routed to the external magnetic field (and, optionally, to a parametrized couplings function) after which magnetizations are returned as outputs. The transformer module bears an uncanny resemblance to a differentiable physical system whose collective behavior we can control through training.&lt;/p&gt;
&lt;h1 id="training-transformer-modules-shapes-collective-behavior"&gt;Training transformer modules shapes collective behavior&lt;/h1&gt;
&lt;p&gt;Now that we can picture transformer modules as physical spin systems responding to being probed with data, let&amp;rsquo;s imagine what training them looks like.&lt;/p&gt;
&lt;p&gt;On the level of the energy function of our spin system Eq. \eqref{eq:vectorrandomising}, we can model the training process of a transformer module by introducing a (discrete) time dimension and making the external magnetic field time-dependent, leading to&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
\begin{equation}
E(t) = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i}(t) \cdot \boldsymbol{\sigma}_{i} \label{eq:sloppyenergy}
\end{equation}&lt;p&gt;At every training step $t$, a sequence of incoming data $\{ \boldsymbol{h}_{1}(t), \boldsymbol{h}_{2}(t), \ldots, \boldsymbol{h}_{N}(t) \}$ takes on the role of external magnetic field. During the forward pass, magnetizations $\boldsymbol{m}_{i}$ are computed in a differentiable way according to the current model parameters and in the presence of the current external magnetic field. Physically, we consider &amp;ldquo;quenched&amp;rdquo; systems with &amp;ldquo;frozen&amp;rdquo; couplings at every training step. During the backward pass, the module&amp;rsquo;s coupling parameters $J_{ij}$ get updated, nudging the interactions in the spin system so as to influence its magnetization responses to similar data in future iterations.&lt;/p&gt;
&lt;img src="spinmoduletraining_new.png" alt="Training a spin system as a neural network" width="600px"/&gt;
&lt;p&gt;We can think about this training process as gradually shaping the collective behavior of a differentiable vector-spin system that is driven by data. If the couplings depend on the inputs, like in Eq. \eqref{eq:softmaxcouplings}, we should make the couplings time-dependent as well in Eq. \eqref{eq:sloppyenergy}. In that case, the external magnetic fields as well as the parametrized couplings change instantaneously at every training step.&lt;/p&gt;
&lt;h1 id="training-deep-transformers-orchestrates-spin-system-collectives"&gt;Training deep transformers orchestrates spin-system collectives&lt;/h1&gt;
&lt;p&gt;Training a deep transformer model corresponds to orchestrating a stack of transformer modules by building up a differentiable structure of correlations where the magnetizations of one spin system drive the next one. Wiggling (billions of) parameters during training nudges the cascading response behavior of the collective of spin systems to better adapt to the collective&amp;rsquo;s (meta-)tasks as specified by the data and the loss function.&lt;/p&gt;
&lt;img src="transformertraining_new.png" alt="Training a transformer" width="500px"/&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we argued that the forward pass of a transformer module maps onto computing magnetizations in a vector-spin model responding to data. Generalizing previous work on understanding softmax attention modules in terms of modern continuous Hopfield networks by taking derivatives of a judiciously chosen &lt;em&gt;energy&lt;/em&gt; function, we propose to take derivatives of the &lt;em&gt;free energy&lt;/em&gt; of a general vector-spin system to get to a blueprint of the architecture of a full transformer module.&lt;/p&gt;
&lt;p&gt;By zooming out and approaching transformers from a tangential, statistical-mechanical point of view, we arrived at a physical intuition of transformers that seems hard to obtain when restricting oneself to perpetually perturbing explicit neural network architectures. Recognizing transformer modules as spin models in disguise might not only unify architectural variations as different ways to approximately compute magnetizations but also elucidate the empirical success of transformers in deep learning.&lt;/p&gt;
&lt;h1 id="acknowledgements"&gt;Acknowledgements&lt;/h1&gt;
&lt;p&gt;We would like to thank
for hosting its research jams and providing a friendly environment to present ideas.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021isingisallyouneed,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Transformers Are Secretly Collectives of Spin Systems},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {November},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin, Attention Is All You Need
(2017)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;em&gt;Weihao Yu, Mi Luo, Pan Zhou, Chenyang Si, Yichen Zhou, Xinchao Wang, Jiashi Feng, and Shuicheng Yan, MetaFormer is Actually What You Need for Vision
(2021)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;em&gt;Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter, Hopfield Networks is All You Need
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield, Large Associative Memory Problem in Neurobiology and Machine Learning
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Consider reading the Physics Today article on
for an introduction to disordered systems, spin glasses, Ising spin systems, emergent collective computational abilities, associative memories, Hopfield models, and the idea of learning patterns as shaping the behavior of systems. Essentially, what we&amp;rsquo;re trying to do in this post is to figure out a way to relate modern transformer models back to these old ideas.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;We plot spin sites at random positions to emphasize that there is no spatial notion of &amp;ldquo;closeness&amp;rdquo; in a fully-connected system: every site is just a hop away. To avoid overloading the graph, we only draw the connections that are strongest in absolute value.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;For example, see &lt;em&gt;
&lt;/em&gt; and &lt;em&gt;
&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;For example, see the content of Chapter 2 in the
by Thierry Giamarchi.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;In the absence of an explicit expression for the free energy, one of the feed-forward network&amp;rsquo;s roles might be to try to approximate the complicated dependencies in the magnetization expression Eq. \eqref{eq:sigmaweird}, at the cost of introducing a large number of additional free parameters beyond just the coupling parameters. It would be interesting to look into this numerically at scale using the free energy expression obtained in
.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;The time-dependence in Eq. \eqref{eq:sloppyenergy} smells of non-equilibrium statistical mechanics. Incoming data might be considered as time-dependent &amp;ldquo;probes&amp;rdquo; which inject energy (and useful information if its content is low-entropy enough) into a non-equilibrium system. By nudging its dynamical response behavior across spatiotemporal scales, the system could potentially learn how to deal with being driven by all kinds of patterns in incoming data. For an interesting toy example of such behavior, see
by Jeremy England on &lt;em&gt;Low rattling: a principle for understanding driven many-body self-organization&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Transformers from Spin Models: Approximate Free Energy Minimization</title><link>https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/</link><pubDate>Tue, 12 Oct 2021 18:40:17 +0100</pubDate><guid>https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
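&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Consider reading &lt;a href="https://mcbal.github.io/post/transformers-are-secretly-collectives-of-spin-systems/"&gt;Transformers Are Secretly Collectives of Spin Systems&lt;/a&gt;
for a high-level overview of some of the ideas outlined in this post.&lt;/p&gt;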
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ TL;DR:&lt;/strong&gt; &lt;em&gt;We consider transformer modules as wrappers around a differentiable steepest-descent approximation of simple Ising-like vector-spin models familiar from statistical mechanics. We observe that a blueprint of the successful transformer-like architectural pattern of token-mixing (attention) and channel-mixing (feed-forward) naturally emerges when computing spin expectation values in vector-spin models with input-dependent couplings. Feel free to skip to the closing section
for a visual comparison of this work to vanilla transformers, deep equilibrium transformers, and deep implicit attention.&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Code: A PyTorch implementation of the ideas outlined in this blog post is available in the GitHub repository
.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In Deep Implicit Attention, we introduced a mean-field theory perspective on transformer modules. We showed how their outputs can be understood as mean-field spin expectation values of simple Ising-like vector-spin systems. Physically, the process of training a transformer module can be understood as driving a classical many-body system with data and iteratively shaping its collective response behaviour through coupling-weight parameter updates. Stacking transformer modules corresponds to building up a differentiable structure of correlations by using the spin expectation values of one physical system to drive the next one.&lt;/p&gt;
&lt;p&gt;In this post, we flesh out the idea of looking at transformer modules as physical systems. Having identified vector-spin systems as plausible physical models underlying transformers, we turn to 1960s statistical-mechanics literature to look for inspiration on how to deal with their partition functions&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;. We rediscover that the partition function of a particular class of vector-spin models can be approximated in the limit of large local spin dimension using steepest descent, leading to approximate yet tractable expressions for the free energy and other derived quantities.&lt;/p&gt;
&lt;p&gt;Combining these canonical results from statistical mechanics with modern differentiable programming, we implement a differentiable vector-spin model based on an approximate free-energy minimization algorithm. Internally, the model uses an implicit layer to solve for the stationary point of the partition function in a differentiable way. We then construct a transformer-like attention module which encapsulates the spin model by routing inputs to applied magnetic fields and spin expectation values to outputs. The latter are obtained by following the familiar recipe of statistical mechanics: differentiating the spin model&amp;rsquo;s $\log Z$ with respect to conjugate input variables. Finally, we contextualize our approach by comparing it to vanilla transformers, deep equilibrium transformers, and deep implicit attention.&lt;/p&gt;
&lt;h1 id="massaging-partition-functions"&gt;Massaging partition functions&lt;/h1&gt;
&lt;p&gt;In this section, we set out to derive an approximate, analytical expression for the free energy of a classical disordered vector-spin system exposed to a site-dependent external magnetic field. In deriving the results below, we found inspiration in H. E. Stanley&amp;rsquo;s
and Chapter 5 of R. J. Baxter&amp;rsquo;s bible on
.&lt;/p&gt;
&lt;h2 id="a-vector-spin-model-and-its-partition-function"&gt;A vector-spin model and its partition function&lt;/h2&gt;
&lt;p&gt;We start from the following Hamiltonian (or energy function) of a classical vector spin system of $N$ spins in a site-dependent external magnetic field,&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}, \label{eq:vectrandomising}
\end{equation}&lt;p&gt;where both $\boldsymbol{\sigma}_{i} = \left[ \sigma_{1}(i), \sigma_{2}(i), \ldots, \sigma_{D}(i) \right]$ and $\boldsymbol{h}_{i} = \left[ h_{1}(i), h_{2}(i), \ldots, h_{D}(i) \right]$ are vectors of dimension $D$. The coupling matrix $\boldsymbol{J}$ is assumed to be traceless and symmetric but can otherwise have real elements with both negative and positive signs. We take the vector degrees of freedom $\boldsymbol{\sigma}_{i}$ to be constrained by a set of $N$ constraints&lt;/p&gt;
\begin{equation}
\lVert \boldsymbol{\sigma}_{i} \rVert _{2}^{2} = \sum_{a=1}^{D} \sigma_{a}^{2}(i) = D, \quad i = 1,2,\ldots,N,
\end{equation}&lt;p&gt;so that their magnitudes equal $\sqrt{D}$. One can picture the classical spin degrees of freedom as arrows rotating along the surface of $(D-1)$-dimensional spheres at every site.&lt;/p&gt;
&lt;figure class="spin-system-figure"&gt;
&lt;img src="spin_system.png" alt="Cartoon of vector-spin system"&gt;
&lt;figcaption&gt;Cartoon of vector-spin system&lt;/figcaption&gt;
&lt;/figure&gt;
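&lt;p&gt;To make the setup concrete, here is a minimal NumPy sketch of the energy function Eq. \eqref{eq:vectrandomising} evaluated on spins that satisfy the norm constraints. The helper names &lt;code&gt;sample_spins&lt;/code&gt; and &lt;code&gt;energy&lt;/code&gt; and all numerical values are illustrative assumptions, not part of any library.&lt;/p&gt;

```python
import numpy as np

def sample_spins(num_spins, dim, rng):
    """Draw spins uniformly on spheres of radius sqrt(dim), one per site."""
    raw = rng.normal(size=(num_spins, dim))
    return np.sqrt(dim) * raw / np.linalg.norm(raw, axis=1, keepdims=True)

def energy(spins, J, H):
    """E = -sum_ij J_ij sigma_i . sigma_j - sum_i h_i . sigma_i"""
    gram = spins @ spins.T  # gram[i, j] = sigma_i . sigma_j
    return -np.sum(J * gram) - np.sum(H * spins)

rng = np.random.default_rng(0)
num_spins, dim = 4, 3

raw = rng.normal(size=(num_spins, num_spins))
J = 0.5 * (raw + raw.T)                 # symmetric couplings of either sign
np.fill_diagonal(J, 0.0)                # traceless
H = rng.normal(size=(num_spins, dim))   # site-dependent external fields
spins = sample_spins(num_spins, dim, rng)

print(np.sum(spins**2, axis=1))         # squared norms, each equal to dim up to rounding
print(energy(spins, J, H))
```

&lt;p&gt;Writing the interaction term as an elementwise product of $\boldsymbol{J}$ with the Gram matrix of the spins keeps the double sum over sites vectorized.&lt;/p&gt;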
&lt;p&gt;In statistical mechanics, the model Eq. \eqref{eq:vectrandomising} is known as a
whose familiar small-$D$ cases include the
($D=1$), the
($D=2$), and the
($D=3$). For infinite-dimensional spins $D \to \infty$, one can show that the system approaches the
. The model defined by \eqref{eq:vectrandomising} can also be regarded as a vector generalization of
or
or disordered
(but with just a single sample of non-local couplings instead of an underlying probability distribution). Similar models also appear in recent studies on higher-dimensional generalizations of spin glass models&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The partition function for our spin system looks like:&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;\left( \beta, J_{ij}, \{ \boldsymbol{h}_{i} \} \right) \nonumber \\
&amp;= \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \qquad \times \prod_{j=1}^{N} \delta \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \nonumber \\
&amp; \qquad \times \exp \left[ \beta \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} + \beta \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i} \right] \label{eq:fullpartfun}
\end{align}&lt;p&gt;where we have made all dependencies explicit. This looks absolutely mental. We somehow need to find a way to do $N \times D$ integrals while taking into account all the constraints and interactions.&lt;/p&gt;
&lt;h2 id="peeking-into-a-physicists-bag-of-tricks"&gt;Peeking into a physicist&amp;rsquo;s bag of tricks&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s first of all get rid of the explicit Dirac delta functions by substituting their complex integral representations&lt;/p&gt;
\begin{align}
\delta \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) = \frac{\beta}{2 \pi i} \int_{-i\infty}^{i\infty} \mathrm{d} t_{j} \exp \left[ \beta t_{j} \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \right]
\end{align}&lt;p&gt;so that&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;= \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \times \int_{-i\infty}^{i\infty} \cdots \int_{-i\infty}^{i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \exp \left( \beta D \sum_{j=1}^{N} t_{j} \right)\nonumber \\
&amp; \times \prod_{\alpha=1}^{D} \exp \left[ -\beta \sum_{i,j=1}^{N} \left(t_{j}\delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right] \nonumber
\end{align}&lt;p&gt;Great, even more integrals. The next frustrating trick involves writing the number 1 as a judiciously chosen exponential,&lt;/p&gt;
\begin{align}
\exp \left( \beta \sum_{j=1}^{N} a \left( D - \lVert \boldsymbol{\sigma}_{j} \rVert _{2}^{2} \right) \right) = 1,
\end{align}&lt;p&gt;for some arbitrary constant $a$, which, inside the integral, indeed evaluates to $\exp (0) = 1$ because of the constraints. Inserting this expression gives&lt;/p&gt;
\begin{align}
&amp;Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{1}(1) \cdots \mathrm{d}\sigma_{D}(N) \nonumber \\
&amp; \times \int_{-i\infty}^{i\infty} \cdots \int_{-i\infty}^{i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \exp \left( \beta D \sum_{j=1}^{N} \left( t_{j} + a\right) \right)\nonumber \\
&amp; \times \prod_{\alpha=1}^{D} \exp \left[ -\beta \sum_{i,j=1}^{N} \left( \left( t_{j} + a \right) \delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right] \nonumber
\end{align}&lt;p&gt;Next, we&amp;rsquo;d like to swap the order of the $\mathrm{d}\sigma_{\alpha}(j)$ and $\mathrm{d}t_{j}$ integrations to start integrating. But we are only allowed to do this if we assume $a$ to be a sufficiently large positive real number. Why? Essentially, we are deforming the contours of the complex integrals sufficiently far to the right such that the real part of the quadratic form appearing in the exponential is positive definite, see e.g.
.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s go ahead and assume that everything is fine. We swap integrals and do a change of variables $t_j \to t_j + a$ so that&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} &amp;= \left(\frac{\beta}{2 \pi i}\right)^{N} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \\
&amp; \times \exp \left( \beta D \sum_{j=1}^{N} t_{j} \right) \prod_{\alpha=1}^{D} I_{\alpha} \left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right)\nonumber
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
I_{\alpha} &amp;\left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right) = \int_{-\infty}^{\infty} \cdots \int_{-\infty}^{\infty} \, \mathrm{d}\sigma_{\alpha}(1) \cdots \mathrm{d}\sigma_{\alpha}(N) \nonumber \\
&amp; \times \exp \left[ -\beta \sum_{i,j=1}^{N} \left( t_{j} \delta_{ij}-J_{ij}\right) \; \sigma_{\alpha}(i) \sigma_{\alpha}(j) + \beta \sum_{i=1}^{N} h_{\alpha}(i) \sigma_{\alpha}(i) \right]\nonumber
\end{align}&lt;p&gt;Notice how the integrals have kind of factorized over the vector dimension: for every $\alpha$-component we can evaluate an $N$-dimensional Gaussian integral with a linear term. The $I_{\alpha}$ functions depend on the sources $\{ \boldsymbol{h}_{i} \}$ indexed along local dimension instead of spin. Introducing the symmetric $N \times N$ matrix $V_{ij} = t_{j} \delta_{ij}-J_{ij}$, we can evaluate the Gaussian integrals and find&lt;/p&gt;
\begin{align}
I_{\alpha} &amp;\left( \beta, \{ t_{j} \}, \{ h_{\alpha}(i) \} \right) = \left( \frac{\pi}{\beta} \right)^{N/2} \left[ \det \left( \boldsymbol{V} \right) \right]^{-1/2} \exp \left(\frac{\beta}{4} \boldsymbol{h}_{\alpha}^{T} \boldsymbol{V}^{-1} \boldsymbol{h}_{\alpha} \right) \nonumber
\end{align}&lt;p&gt;where $\boldsymbol{h}_{\alpha} = \left[ h_{\alpha}(1), h_{\alpha}(2), \ldots, h_{\alpha}(N) \right]$ denote $N$-dimensional vectors. The expression for the partition function becomes&lt;/p&gt;
\begin{align}
&amp;Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} \left( \frac{\pi}{\beta} \right)^{DN/2} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \nonumber \\
&amp; \times \exp \left( D \left( \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) \right) \right) \exp \left( \frac{\beta}{4} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \right) \nonumber
\end{align}&lt;p&gt;where we have introduced the matrix notation $\boldsymbol{H} \in \mathbb{R}^{N \times D}$ to group the vectors $\{ \boldsymbol{h}_{i} \}$.&lt;/p&gt;
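&lt;p&gt;As a sanity check on the Gaussian integration step, the sketch below compares a brute-force quadrature of $I_{\alpha}$ for $N=2$ with the closed-form expression above. The couplings, fields, and $t_{j}$ values are arbitrary illustrative choices, and the grid bounds are an assumption that works here because the integrand decays quickly when $\boldsymbol{V}$ is positive definite.&lt;/p&gt;

```python
import numpy as np

beta = 1.0
J = np.array([[0.0, 0.3], [0.3, 0.0]])  # traceless, symmetric couplings
t = np.array([2.0, 1.5])                # real, positive auxiliary variables
h = np.array([0.4, -0.7])               # one alpha-component of the fields
V = np.diag(t) - J                      # V_ij = t_j delta_ij - J_ij

# Brute-force quadrature of the two-dimensional integral on a uniform grid
grid = np.linspace(-8.0, 8.0, 801)
spacing = grid[1] - grid[0]
s1, s2 = np.meshgrid(grid, grid, indexing="ij")
quadratic = V[0, 0] * s1**2 + 2.0 * V[0, 1] * s1 * s2 + V[1, 1] * s2**2
linear = h[0] * s1 + h[1] * s2
numeric = np.sum(np.exp(-beta * quadratic + beta * linear)) * spacing**2

# Closed form: (pi / beta)^(N/2) det(V)^(-1/2) exp(beta / 4 * h^T V^-1 h)
closed_form = (
    (np.pi / beta) ** (len(t) / 2)
    / np.sqrt(np.linalg.det(V))
    * np.exp(0.25 * beta * h @ np.linalg.solve(V, h))
)

print(numeric, closed_form)
```

&lt;p&gt;The same comparison breaks down as soon as $\boldsymbol{V}$ stops being positive definite, which is the numerical counterpart of the contour-shifting argument.&lt;/p&gt;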
&lt;h2 id="steepest-descent-hunting-for-the-saddle"&gt;Steepest descent: hunting for the saddle&lt;/h2&gt;
&lt;p&gt;But there&amp;rsquo;s still $N$ complex integrals over the auxiliary variables $\{ t_{j} \}$ left to do. Can we avoid doing them? Maybe. Let&amp;rsquo;s rewrite our partition function as&lt;/p&gt;
\begin{align}
Z_{N}^{(D)} = \left(\frac{\beta}{2 \pi i}\right)^{N} &amp;\left( \frac{\pi}{\beta} \right)^{DN/2} \int_{a-i\infty}^{a+i\infty} \cdots \int_{a-i\infty}^{a+i\infty} \, \mathrm{d}t_{1} \cdots \mathrm{d}t_{N} \, \mathrm{e}^{D \varphi \left(\boldsymbol{t} \right) } \label{eq:partfunsteep}
\end{align}&lt;p&gt;with&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4D} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \label{eq:varphi}
\end{align}&lt;p&gt;As $D \to \infty$, the
suggests that the partition function will be dominated by its largest contribution, i.e. in the neighbourhood of the maximum $\varphi(\boldsymbol{t^{*}})$ along the integration paths.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Hmm, this doesn&amp;rsquo;t quite seem right #1:&lt;/strong&gt; What does $D \to \infty$ even look like for the last term in Eq. \eqref{eq:varphi}? What does it mean for the input vectors $\{ \boldsymbol{h}_{i} \}$ to become infinite-dimensional? Good points, but let&amp;rsquo;s carry on.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;The saddle-point values $\boldsymbol{t^{*}}$ are obtained from the set of stationary conditions&lt;/p&gt;
\begin{align}
\frac{\partial \varphi \left( \boldsymbol{t} \right)}{\partial t_j} \Biggr\rvert_{t_j = t^{*}_{j}} = 0, \qquad j=1,\ldots,N \label{eq:statcond}
\end{align}
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Hmm, this doesn&amp;rsquo;t quite seem right #2:&lt;/strong&gt; In the single-variable case,
argues that $\varphi (t)$ is analytic for $\mathrm{Re}(t)&gt;0$ and that we should consider $\varphi (t)$ first for $t$ real and positive. For positive $\beta$ and non-zero magnetic field, the function tends to plus infinity as $t$ tends to either zero or infinity. Thus in between $\varphi(t)$ must have a &lt;em&gt;minimum&lt;/em&gt; at some positive value $t^{*}$ of $t$. Since $\varphi''(t) &gt; 0$ there is also only one such minimum. If we take the constant $a$ in the integral limits to be $t^{*}$, then along the (imaginary) integration path $\varphi (t)$ has a &lt;em&gt;maximum&lt;/em&gt; at $t=t^{*}$. We naively assume that this kind of saddle-point reasoning transfers to our case in several complex variables with $\varphi : \mathbb{C}^{N} \to \mathbb{C}$ where the equivalent of $\mathrm{Re}(t)&gt;0$ is to try to steer clear of the singularity at $\det \left( \boldsymbol{V} \right)=0$. We will check the numerical behaviour of our $\varphi$-function in
.&lt;/p&gt;
&lt;/blockquote&gt;
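&lt;p&gt;The single-variable picture can be probed numerically. The sketch below assumes a single site ($N=1$, so the traceless coupling vanishes) and writes &lt;code&gt;h_sq&lt;/code&gt; for $\lVert \boldsymbol{h} \rVert^{2} / D$. Both the reduction and the numerical values are illustrative assumptions. It evaluates $\mathrm{Re} \, \varphi$ on a few points along the real and imaginary directions through the stationary point.&lt;/p&gt;

```python
import cmath
import math

beta, h_sq = 1.0, 0.8  # inverse temperature and |h|^2 / D (illustrative values)

def phi(t):
    """Single-site phi-function; t may be real or complex."""
    return beta * t - 0.5 * cmath.log(t) + 0.25 * beta * h_sq / t

# Positive root of phi'(t) = 0, i.e. of 4 beta t^2 - 2 t - beta h_sq = 0
t_star = (1.0 + math.sqrt(1.0 + 4.0 * beta**2 * h_sq)) / (4.0 * beta)

offsets = (-0.2, -0.1, 0.0, 0.1, 0.2)
real_axis = [phi(t_star + dx).real for dx in offsets]
imag_axis = [phi(t_star + 1j * dy).real for dy in offsets]

print("t* =", t_star)
print("along the real axis:     ", [round(v, 4) for v in real_axis])
print("along the imaginary axis:", [round(v, 4) for v in imag_axis])
```

&lt;p&gt;In this toy case the middle entry is the smallest value along the real direction and the largest along the imaginary one, which is the saddle structure the steepest-descent argument relies on.&lt;/p&gt;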
&lt;p&gt;Expanding $\varphi$ around $\boldsymbol{t^{*}}$ and then taking the logarithm of Eq. \eqref{eq:partfunsteep} leads to&lt;/p&gt;
\begin{align}
\ln Z_{N}^{(D)} = \frac{DN}{2} \ln \left( \frac{\pi}{\beta} \right) + D \varphi \left( \boldsymbol{t^{*}} \right) + \ln R \nonumber
\end{align}&lt;p&gt;where we have collected all higher-order contributions and remaining nastiness in $R$. Following
, the free energy in the limit of large local dimension $D \to \infty$ then becomes&lt;/p&gt;
\begin{align}
-\beta f_{N}^{(\infty)} = \lim_{D \to \infty} D^{-1} \ln \left( Z_{N}^{(D)} / Z_{N}^{(D)}(0) \right) \nonumber
\end{align}&lt;p&gt;where&lt;/p&gt;
\begin{align}
Z_{N}^{(D)}(0) = \left( \left(\pi\right)^{D/2} D^{(D-1)/2} / \Gamma \left(D/2\right) \right)^{N} \nonumber
\end{align}&lt;p&gt;is a normalization factor&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; accounting for the surface areas of the $(D-1)$-dimensional spheres with radius $\sqrt{D}$ associated to each and every spin degree of freedom. After applying
to the $\Gamma$-function in the normalization factor and doing some algebra, we end up with&lt;/p&gt;
\begin{align}
\boxed{-\beta f_{N}^{(\infty)} = - \frac{N}{2} - \frac{N}{2} \ln \left( 2\beta \right) + \varphi \left( \boldsymbol{t^{*}} \right)} \label{eq:afe}
\end{align}&lt;p&gt;where we have dropped the remainder $\lim_{D \to \infty} D^{-1} \ln R$, assuming it tends to zero. Since $\varphi \left( \boldsymbol{t^{*}} \right) \propto N$, the $\varphi$-term in Eq. \eqref{eq:afe} is extensive like the other two terms and also survives the limit $N \to \infty$.&lt;/p&gt;
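&lt;p&gt;The constants in Eq. \eqref{eq:afe} can be checked numerically. Assuming the normalization factor $Z_{N}^{(D)}(0)$ exactly as written above, the large-$D$ asymptotics of $\ln \Gamma$ give $D^{-1} \ln Z_{1}^{(D)}(0) \to \frac{1}{2} \left( 1 + \ln 2\pi \right)$ per spin. Subtracting this from $\frac{1}{2} \ln \left( \pi / \beta \right)$ leaves $-\frac{1}{2} - \frac{1}{2} \ln \left( 2 \beta \right)$. The function name below is a hypothetical helper for this sketch.&lt;/p&gt;

```python
import math

def log_z0_per_spin(dim):
    """Log of the single-spin factor pi^(D/2) D^((D-1)/2) / Gamma(D/2)."""
    return (
        0.5 * dim * math.log(math.pi)
        + 0.5 * (dim - 1) * math.log(dim)
        - math.lgamma(0.5 * dim)
    )

# Predicted large-D limit of D^-1 ln Z(0) per spin
limit = 0.5 * (1.0 + math.log(2.0 * math.pi))

for dim in (10, 100, 1000, 10000, 100000):
    print(dim, log_z0_per_spin(dim) / dim, limit)
```

&lt;p&gt;The gap to the limit shrinks roughly like $1/D$, which gives a feel for how large the local dimension has to be before the asymptotic constants become accurate.&lt;/p&gt;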
&lt;h2 id="taking-stock-of-what-we-have-done"&gt;Taking stock of what we have done&lt;/h2&gt;
&lt;p&gt;We have derived a closed-form expression Eq. \eqref{eq:afe} for the approximate free energy of a vector-spin model in the limit of large local spin dimension. Let us take a brief moment to reflect on what we have done and touch on some tangential points.&lt;/p&gt;
&lt;h4 id="questioning-steepest-descent-and-the-large--limit"&gt;Questioning steepest descent and the large-$D$ limit&lt;/h4&gt;
&lt;p&gt;The result \eqref{eq:afe} is only sensible if steepest descent is a valid thing to do, which depends on how outrageous the landscape defined by the $\varphi$-function \eqref{eq:varphi} really is. More practically, we will also never &lt;em&gt;really&lt;/em&gt; let the vector-spin dimension $D$ tend towards infinity since our goal is to implement a numerical attention-like neural network module. So large but finite vector dimensions better behave as if they were sufficiently close to infinity. We will find out in
to what extent these assumptions are valid in practice.&lt;/p&gt;
&lt;h4 id="energy-based-models-and-effective-energy-functions"&gt;Energy-based models and effective energy functions&lt;/h4&gt;
&lt;p&gt;Let us take another look at our model&amp;rsquo;s partition function \eqref{eq:fullpartfun} from an energy-based perspective. For ease of notation, let us call the model parameters $\theta \equiv \{ J_{ij} \}$, the spins $\sigma \equiv \{ \boldsymbol{\sigma}_{i} \}$, and the external magnetic fields $h \equiv \{ \boldsymbol{h}_{i} \}$. We can schematically write our model&amp;rsquo;s partition function as&lt;/p&gt;
\begin{align}
Z_{\theta} \left( h \right) = \int \mathrm{d} \sigma \ \mathrm{e}^{ - E_{\theta}\left( \sigma, h \right) }
\end{align}&lt;p&gt;where $E_{\theta}\left( \sigma, h \right)$ denotes the energy function Eq. \eqref{eq:vectrandomising}. If we now introduce an energy-based model $p_{\theta} \left( \sigma, h \right) = \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)} / Z_{\theta}$, we can define the marginal distribution&lt;/p&gt;
\begin{align}
p_{\theta} \left( h \right) = \frac{\int \mathrm{d} \sigma \ \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)}}{Z_{\theta}} = \frac{\mathrm{e}^{-E_{\theta}\left( h \right)}}{Z_{\theta}} \label{eq:ph}
\end{align}&lt;p&gt;where the applied magnetic fields act as observables and the spins as latent variables. The effective energy $E_{\theta}\left( h \right)$ equals $E_{\theta}\left( h \right) = - \log \int \mathrm{d} \sigma \ \mathrm{e}^{-E_{\theta}\left( \sigma, h \right)} \approx - \log Z^{\ast}_{\theta} \left( h \right)$, where we have used the steepest-descent approximation for the integral. Taking the logarithm of Eq. \eqref{eq:ph}, we find that $\log p_{\theta} \left( h \right) \approx \log Z^{\ast}_{\theta} \left( h \right) - \log \int \mathrm{d} h \ Z^{\ast}_{\theta} \left( h \right)$.&lt;/p&gt;
&lt;h4 id="spin-glasses-and-mean-field-approximation"&gt;Spin glasses and mean-field approximation&lt;/h4&gt;
&lt;p&gt;Ordered systems have a long history in statistical mechanics. Couplings in these models often encode a translation-invariant lattice geometry, e.g. nearest-neighbour interactions between spins living on a $d$-dimensional hypercubic lattice. One reason for this focus is practical: the regularity in these systems enables mathematical physicists to deploy all kinds of tricks and make progress towards some kind of understanding. In contrast, disordered systems, like spin glasses, are a mess and studying them is all about
. From the perspective of spin glasses, we can summarize our approach as follows: we want to arrive at an approximate yet tractable mean-field spin-glass model where its couplings are treated as parameters learned from data&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Fully-connected models like Sherrington-Kirkpatrick spin-glass models (or Eq. \eqref{eq:vectrandomising}) naturally lead to mean-field theory because the couplings $J_{ij}$ encode long-range interactions where every other spin is just a hop away, see e.g.
. Intuitively, all-to-all interactions correspond to the mean-field limit of infinite spatial dimension. To see this, consider a spin in a local nearest-neighbour lattice model getting ever more neighbours as the spatial dimension grows: the notion of nearest neighbours melts away and all spins effectively become connected to each other&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;. Fully-connected non-local couplings and the limit of infinite spatial dimension are two sides of the same mean-field coin.&lt;/p&gt;
&lt;h1 id="implementing-approximate-free-energy-minimization"&gt;Implementing approximate free-energy minimization&lt;/h1&gt;
&lt;p&gt;In this section, we turn the equations of the previous section into the algorithmic backbone of a differentiable vector-spin model. We begin by sketching an approximate free-energy minimization algorithm. We then show how to wrap around the spin model to turn it into an attention module.&lt;/p&gt;
&lt;h2 id="the-algorithm-bold-moves-on-a-tricky-landscape"&gt;The algorithm: bold moves on a tricky landscape&lt;/h2&gt;
&lt;p&gt;Our goal is to compute the steepest-descent approximation of our model&amp;rsquo;s partition function in a differentiable way. Essentially, we need to solve the set of equations&lt;/p&gt;
\begin{align}
\frac{\partial \varphi \left( \boldsymbol{t} \right)}{\partial t_j} \Biggr\rvert_{t_j = t^{*}_{j}} = 0, \qquad j=1,\ldots,N
\end{align}&lt;p&gt;which corresponds to finding a value $\boldsymbol{t^{*}} = \mathrm{argmin}_{\boldsymbol{t}} \varphi \left( \boldsymbol{t} \right)$ for which the scalar function&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4D} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \nonumber
\end{align}&lt;p&gt;attains its minimum, or, equivalently, we need to solve for the root of $\nabla \varphi \left( \boldsymbol{t} \right)$.&lt;/p&gt;
&lt;h4 id="initialization-and-normalization"&gt;Initialization and normalization&lt;/h4&gt;
&lt;p&gt;Until now we have not been explicit about the values of the couplings $\boldsymbol{J}$ and inputs $\boldsymbol{H}$. If we want to implement any of this, we have to be more careful. Recall that the energy function of our model looks like&lt;/p&gt;
\begin{equation}
E = - \sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} - \sum_{i=1}^{N} \boldsymbol{h}_{i} \cdot \boldsymbol{\sigma}_{i}
\end{equation}&lt;p&gt;where all spins $\boldsymbol{\sigma}_{i}$ are fixed to norm $\sqrt{D}$. We&amp;rsquo;d like this energy to remain linearly proportional to the number of lattice sites. Numerically, we observe that stable root-finding is possible when initializing the couplings according to
&lt;/p&gt;
\begin{equation}
J_{ij} \sim \mathcal{N} (0, 1/\sqrt{ND} )
\end{equation}&lt;p&gt;
The factor $1/\sqrt{N}$ can be explained from spin-glass mean-field theory&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt; whereas the $1/\sqrt{D}$ factor follows from additionally normalizing with respect to the vector dimension to ensure $\sum_{i,j=1}^{N} J_{ij} \; \boldsymbol{\sigma}_{i} \cdot \boldsymbol{\sigma}_{j} \sim \mathcal{O}(N)$. One strategy to normalize the inputs $\boldsymbol{H}$ is to feed them into a layer normalization layer so that $\left\lVert \boldsymbol{h}_{i} \right\rVert \sim \mathcal{O}(\sqrt{D})$ and then explicitly divide by $\sqrt{D}$ to make them $\mathcal{O}(1)$. A practical consequence of these initialization and normalization choices at the level of the energy function is that the $\varphi$-function changes to&lt;/p&gt;
\begin{align}
\varphi \left(\boldsymbol{t}; \beta, J_{ij} \right) = \beta \sum_{j=1}^{N} t_{j} - \frac{1}{2} \log \det \left( \boldsymbol{V} \right) + \frac{\beta}{4} \mathrm{Tr} \left( \boldsymbol{H}^{T} \boldsymbol{V}^{-1} \boldsymbol{H} \right) \label{eq:varphinorm}
\end{align}&lt;p&gt;where the prefactor in the last term changed since we decided on explicitly dividing the layer-normalized $\boldsymbol{H}$ by $\sqrt{D}$.&lt;/p&gt;
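&lt;p&gt;A minimal NumPy sketch of these initialization and normalization choices could look as follows. It assumes that $1/\sqrt{ND}$ is used as the standard deviation of the Gaussian, that the couplings are symmetrized and made traceless by hand, and that layer normalization is applied without learned scale or shift. The function names are hypothetical.&lt;/p&gt;

```python
import numpy as np

def init_couplings(num_spins, dim, rng):
    """Symmetric, traceless couplings with entries of scale 1 / sqrt(num_spins * dim)."""
    raw = rng.normal(0.0, 1.0 / np.sqrt(num_spins * dim), size=(num_spins, num_spins))
    J = 0.5 * (raw + raw.T)
    np.fill_diagonal(J, 0.0)
    return J

def normalize_inputs(X, eps=1e-5):
    """Layer-normalize each row (no learned scale or shift), then divide by sqrt(dim)."""
    mean = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mean) / np.sqrt(var + eps) / np.sqrt(X.shape[-1])

rng = np.random.default_rng(0)
num_spins, dim = 32, 128
J = init_couplings(num_spins, dim, rng)
H = normalize_inputs(3.0 * rng.normal(size=(num_spins, dim)) + 1.0)

print(np.linalg.norm(H, axis=1)[:4])  # row norms are close to 1
print(np.abs(J).max())
```

&lt;p&gt;After this step every input vector has a norm of order one regardless of $D$, which is what allows the prefactor of the last term in the $\varphi$-function to lose its explicit $1/D$.&lt;/p&gt;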
&lt;h4 id="implicit-layers-for-steepest-descent-root-finding"&gt;Implicit layers for steepest-descent root-finding&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s now find the root of the gradient of $\varphi$ in a differentiable way by combining
with a black-box root-finding algorithm like
, which requires access to both a function (the gradient of $\varphi$) and its gradient (the Jacobian of the gradient of $\varphi$). We could rely on automatic differentiation to calculate these gradients, but we might just as well exploit the fact that we have the analytical expression Eq. \eqref{eq:varphinorm}. Grabbing a coffee and peeking at the
, we can figure out what happens&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;when we wiggle around $t_{i}$ (the gradient vector at $\boldsymbol{t}$)&lt;/li&gt;
&lt;/ul&gt;
\begin{align}
\left[ \nabla \varphi \left( \boldsymbol{t} \right) \right]_{i} = \beta - \frac{1}{2} \left[ \boldsymbol{V}^{-1} \right]_{ii} - \frac{\beta}{4} \left[ \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \right]_{ii} \nonumber
\end{align}&lt;ul&gt;
&lt;li&gt;when we wiggle around both $t_{i}$ and $t_{j}$ (the symmetric Hessian matrix at $\boldsymbol{t}$)&lt;/li&gt;
&lt;/ul&gt;
\begin{align}
\left[ \boldsymbol{J}(\nabla \varphi \left( \boldsymbol{t} \right)) \right]_{ij} = \frac{1}{2} &amp;\left[ \boldsymbol{V}^{-1} \odot \boldsymbol{V}^{-T} \right]_{ij} \nonumber \\
&amp;+ \frac{\beta}{4} \left[ \left( \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \right) \odot \boldsymbol{V}^{-T} \right]_{ij} \nonumber \\
&amp;+ \frac{\beta}{4} \left[ \boldsymbol{V}^{-T} \odot \left( \boldsymbol{V}^{-T} \boldsymbol{H} \boldsymbol{H}^{T} \boldsymbol{V}^{-T} \right) \right]_{ij} \nonumber
\end{align}&lt;p&gt;Given an initial guess $\boldsymbol{t_{0}} \in \mathbb{R}^{N}_{&gt;0}$ and input data $\boldsymbol{H} \in \mathbb{R}^{N \times D}$, we can now construct a differentiable root-solver which returns $\boldsymbol{t^{*}}$. It is important to keep in mind that the stationary value $\boldsymbol{t^{*}}$ actually depends on $\left(\beta, \boldsymbol{J}, \boldsymbol{H} \right)$ implicitly. Since we make use of implicit layers within an automatic differentiation framework, these dependencies are kept track of and are included in the computational graph.&lt;/p&gt;
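&lt;p&gt;The root-finding step itself can be sketched without any deep-learning framework. The NumPy snippet below is a plain (non-differentiable) damped Newton iteration, swapped in here for the black-box solver and implicit layer, so it only illustrates the forward solve. The Hessian is obtained in this sketch by differentiating the gradient once more with $\partial \boldsymbol{V} / \partial t_{j} = \boldsymbol{e}_{j} \boldsymbol{e}_{j}^{T}$ and using the symmetry of $\boldsymbol{V}$. All function names, sizes, and tolerances are assumptions.&lt;/p&gt;

```python
import numpy as np

def phi(t, J, H, beta):
    """Normalized phi-function for real t with V = diag(t) - J positive definite."""
    V = np.diag(t) - J
    _, logdet = np.linalg.slogdet(V)
    return beta * np.sum(t) - 0.5 * logdet + 0.25 * beta * np.trace(H.T @ np.linalg.solve(V, H))

def grad_phi(t, J, H, beta):
    """Gradient: beta - [V^-1]_ii / 2 - beta [V^-1 H H^T V^-1]_ii / 4."""
    V_inv = np.linalg.inv(np.diag(t) - J)
    A = V_inv @ H @ H.T @ V_inv
    return beta - 0.5 * np.diag(V_inv) - 0.25 * beta * np.diag(A)

def hess_phi(t, J, H, beta):
    """Hessian from elementwise products of V^-1 with itself and with V^-1 H H^T V^-1."""
    V_inv = np.linalg.inv(np.diag(t) - J)
    A = V_inv @ H @ H.T @ V_inv
    return 0.5 * V_inv * V_inv + 0.5 * beta * V_inv * A

def is_admissible(t, J):
    """True if V = diag(t) - J is positive definite."""
    return bool(np.all(np.linalg.eigvalsh(np.diag(t) - J) > 0.0))

def solve_saddle(t0, J, H, beta, tol=1e-10, max_iter=50):
    """Damped Newton iteration for the root of grad_phi."""
    t = np.array(t0, dtype=float)
    for _ in range(max_iter):
        g = grad_phi(t, J, H, beta)
        if tol > np.linalg.norm(g):
            break
        step = np.linalg.solve(hess_phi(t, J, H, beta), g)
        scale = 1.0
        trial = t - step
        # Halve the step until V stays positive definite and the gradient norm does not grow
        while (not is_admissible(trial, J)
               or np.linalg.norm(grad_phi(trial, J, H, beta)) > np.linalg.norm(g)):
            scale *= 0.5
            trial = t - scale * step
        t = trial
    return t

rng = np.random.default_rng(0)
num_spins, dim, beta = 8, 16, 1.0
raw = rng.normal(0.0, 1.0 / np.sqrt(num_spins * dim), size=(num_spins, num_spins))
J = 0.5 * (raw + raw.T)
np.fill_diagonal(J, 0.0)
H = rng.normal(size=(num_spins, dim)) / np.sqrt(dim)

t_star = solve_saddle(np.ones(num_spins), J, H, beta)
print("t* =", np.round(t_star, 4))
print("phi(t*) =", phi(t_star, J, H, beta))
print("gradient norm at t* =", np.linalg.norm(grad_phi(t_star, J, H, beta)))
```

&lt;p&gt;Backtracking on the gradient norm keeps the iterates inside the region where $\boldsymbol{V}$ is positive definite. In a differentiable implementation, the derivatives of $\boldsymbol{t^{*}}$ with respect to $\left(\beta, \boldsymbol{J}, \boldsymbol{H} \right)$ would come from the implicit function theorem rather than from unrolling these iterations.&lt;/p&gt;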
&lt;h4 id="fun-with-free-energies"&gt;Fun with free energies&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s test the algorithm by initializing a random vector-spin model and applying a random magnetic field at every site. For visualization purposes, we restrict the auxiliary variables to be effectively one-dimensional by defining $\boldsymbol{t} = t \boldsymbol{1}_{N}$ with just a single scalar parameter $t \in \mathbb{R}_{&gt;0}$. We can probe a &lt;code&gt;VectorSpinModel&lt;/code&gt; and get the approximate free energy for a given set of parameters and inputs by running the following script:&lt;/p&gt;
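&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;afem.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;VectorSpinModel&lt;/span&gt;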
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VectorSpinModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;t0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;afe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;t0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_afe&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;afe&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Inside the forward pass, the root $\boldsymbol{t^{*}}$ is computed and then fed into Eq. \eqref{eq:afe} to calculate the approximate free energy. We can verify that our algorithm is doing something sensible by sweeping across the auxiliary $t$-values and plotting $\varphi$ and its derivatives:&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-sweep-across-auxiliary-variable"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Sweep across auxiliary variable"
srcset="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_4aa84911571ee37e.webp 320w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_f6f0b5f47ed137ab.webp 480w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_f752654317dd2531.webp 640w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_1d_plot_hu_4aa84911571ee37e.webp"
width="640"
height="480"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Sweep across auxiliary variable
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The region close to $t=0$ looks terrifying. In this regime, $t$ is likely not large enough to overshadow the largest eigenvalue of the couplings so we lose positive definiteness and its nice properties. Let&amp;rsquo;s try to stay away from that region by always initializing $\boldsymbol{t}_{0}$ sufficiently far from it. Depending on the parameters and initial guess provided to the solver, one can of course end up in less favourable landscapes where root-solving can become difficult due to zero gradients or extreme sensitivity to initial conditions. Fortunately, when the root-solving step fails, it tends to fail spectacularly.&lt;/p&gt;
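&lt;p&gt;The role of the largest eigenvalue of the couplings can be made explicit in the restricted case $\boldsymbol{t} = t \boldsymbol{1}_{N}$. The sketch below diagonalizes $\boldsymbol{J}$ once, so that $\boldsymbol{V} = t \boldsymbol{I} - \boldsymbol{J}$ is positive definite exactly when $t$ exceeds the largest eigenvalue, and then sweeps the normalized $\varphi$-function over admissible $t$-values. The eigendecomposition shortcut, the sizes, and the sweep range are assumptions of this sketch.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)
num_spins, dim, beta = 32, 128, 1.0

raw = rng.normal(0.0, 1.0 / np.sqrt(num_spins * dim), size=(num_spins, num_spins))
J = 0.5 * (raw + raw.T)
np.fill_diagonal(J, 0.0)
H = rng.normal(size=(num_spins, dim)) / np.sqrt(dim)

eig_J, U = np.linalg.eigh(J)            # eigenvalues in ascending order
lam_max = eig_J[-1]
proj = np.sum((U.T @ H) ** 2, axis=1)   # squared field components along each eigenvector

def phi_scalar(t):
    """Normalized phi-function restricted to t * ones(num_spins), valid for t > lam_max."""
    gaps = t - eig_J
    return beta * num_spins * t - 0.5 * np.sum(np.log(gaps)) + 0.25 * beta * np.sum(proj / gaps)

ts = np.linspace(lam_max + 1e-3, 3.0, 2000)
values = np.array([phi_scalar(t) for t in ts])
t_best = ts[np.argmin(values)]

print("largest eigenvalue of J:", lam_max)
print("grid minimizer of phi:  ", t_best)
```

&lt;p&gt;Starting the solver from an initial guess comfortably above &lt;code&gt;lam_max&lt;/code&gt; is one simple way to stay in the well-behaved part of the landscape.&lt;/p&gt;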
&lt;p&gt;Let&amp;rsquo;s now sweep across inverse temperature $\beta$ to get some intuition. From the analytical expression of the free energy, we can deduce that for small $\beta$ (high temperature) the entropy term reigns while for large $\beta$ (low temperature) the energy terms take over.&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-sweep-across-inverse-temperature"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Sweep across inverse temperature"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/beta_sweep.gif"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Sweep across inverse temperature
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Finally, let&amp;rsquo;s lift the one-dimensional restriction on $\boldsymbol{t}$ and plot $\varphi (\boldsymbol{t})$ for two spins. In that case, $\boldsymbol{t}$ is also just two-dimensional so we can still visualize the optimization landscape.&lt;/p&gt;
&lt;p&gt;
&lt;figure id="figure-two-dimensional-auxiliary-variables"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Two-dimensional auxiliary variables"
srcset="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_7b965a283d029a52.webp 320w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_af3a40a7eaf8a93b.webp 480w, https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_cae17aa20904ce69.webp 545w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/phi_2d_plot_hu_7b965a283d029a52.webp"
width="545"
height="459"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Two-dimensional auxiliary variables
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h2 id="the-attention-module-probing-spins-with-data"&gt;The attention module: probing spins with data&lt;/h2&gt;
&lt;p&gt;In the previous section, we showed how to numerically compute the steepest-descent approximation of a vector-spin model&amp;rsquo;s partition function and hence its free energy. Since this approximation is fully differentiable, we can also take derivatives with respect to conjugate variables. Let&amp;rsquo;s use this observation to construct an attention module.&lt;/p&gt;
&lt;h4 id="spin-expectation-values"&gt;Spin expectation values&lt;/h4&gt;
&lt;p&gt;We can calculate spin expectation values or magnetizations from our partition function approximation by differentiating with respect to the applied magnetic fields:&lt;/p&gt;
\begin{align}
\langle \boldsymbol{\sigma}_{i} \rangle = \frac{\mathrm{d} \log Z \left( \boldsymbol{t}, \boldsymbol{H} \right)}{\mathrm{d} \boldsymbol{h}_{i}} = \frac{\partial \varphi}{\partial \boldsymbol{t}} \frac{\partial \boldsymbol{t}}{\partial \boldsymbol{h}_{i}} + \frac{\partial \varphi}{\partial \boldsymbol{h}_{i}} \label{eq:spinevgeneral}
\end{align}&lt;p&gt;If we evaluate the partition function approximation at the stationary point $\boldsymbol{t^{\ast}}$, the first term drops out because $\partial_{\boldsymbol{t}} \varphi \rvert_{\boldsymbol{t}=\boldsymbol{t^{\ast}}} = 0$. Assuming that the matrix $\boldsymbol{V}$ (and hence the couplings $\boldsymbol{J}$) does not depend on the inputs $\boldsymbol{H}$, the spin expectation value boils down to&lt;/p&gt;
\begin{align}
\langle \boldsymbol{\sigma}_{i} \rangle = \frac{\partial \varphi}{\partial \boldsymbol{h}_{i}} = \frac{\beta}{2} \sum_{j} \boldsymbol{V}^{-1}_{ij} \boldsymbol{h}_{j} \label{eq:spinev}
\end{align}&lt;p&gt;which, for every site, is just a weighted sum of inputs. In the language of transformers, Eq. \eqref{eq:spinev} resembles an update step where $\boldsymbol{V}^{-1}$ can be interpreted as a symmetric attention matrix. Expanding the matrix inverse reveals a residual connection as the zeroth-order contribution&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
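&lt;p&gt;To make Eq. \eqref{eq:spinev} concrete, here is a minimal NumPy sketch that computes the magnetizations with a linear solve and compares them against the truncated series expansion of $\boldsymbol{V}^{-1}$ from the footnote. The couplings, the stationary-point values &lt;code&gt;t&lt;/code&gt;, and the fields &lt;code&gt;H&lt;/code&gt; are made-up illustrative numbers rather than outputs of the actual root-solving step, and we assume $\boldsymbol{V} = \mathrm{diag}(\boldsymbol{t}) - \boldsymbol{J}$ as in that footnote.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
num_spins, dim, beta = 4, 3, 1.0

# Hypothetical small symmetric couplings with zero diagonal (illustration only).
J = 0.1 * rng.standard_normal((num_spins, num_spins))
J = 0.5 * (J + J.T)
np.fill_diagonal(J, 0.0)

# Assumed stationary-point values t*, chosen large enough to dominate J.
t = np.full(num_spins, 2.0)
H = rng.standard_normal((num_spins, dim))  # applied fields, one row per site

# Direct evaluation: (beta / 2) * V^{-1} H with V = diag(t) - J.
V = np.diag(t) - J
magnetizations = 0.5 * beta * np.linalg.solve(V, H)

# Truncated series: sum_k (diag(1/t) J)^k diag(1/t) H.
D_inv = np.diag(1.0 / t)
term = D_inv @ H           # zeroth order: rescaled inputs (the residual connection)
series = term.copy()
for _ in range(50):
    term = D_inv @ J @ term    # each extra order mixes sites once more through J
    series += term
series *= 0.5 * beta

print(np.allclose(series, magnetizations))
```

&lt;p&gt;The zeroth-order term only rescales each input, while every higher order mixes information across sites through $\boldsymbol{J}$, which is the token-mixing reading of the expansion.&lt;/p&gt;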
&lt;p&gt;Since the couplings are scalars at the level of the energy function Eq. \eqref{eq:vectrandomising}, getting terms to act on the hidden dimension seems to be impossible. But by considering couplings $\boldsymbol{J}(\boldsymbol{H})$ which do depend on inputs, additional terms can appear in Eq. \eqref{eq:spinev} propagating via dependencies in $\boldsymbol{V}$. Instead of calculating these gradients analytically, we should of course just let our automatic differentiation framework compute them for us.&lt;/p&gt;
&lt;h4 id="wrapping-around-the-spin-model"&gt;Wrapping around the spin model&lt;/h4&gt;
&lt;p&gt;At this point, we have done all the heavy lifting. All that remains is to write a wrapper so that we can use our module just like any other explicit attention module:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;afem.attention&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;VectorSpinAttention&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VectorSpinAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (1, 32, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Inside the forward pass of &lt;code&gt;VectorSpinAttention&lt;/code&gt;, (normalized) inputs are sent to an internal &lt;code&gt;VectorSpinModel&lt;/code&gt; which solves for the saddle point $\boldsymbol{t^{*}}$ and then feeds it into the steepest descent partition function to calculate magnetizations according to Eq. \eqref{eq:spinevgeneral}.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s finish this section by discussing some of the peculiarities of our approach:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Stability and symmetry:&lt;/strong&gt; The root-finding is stable as long as $\det \boldsymbol{V} &gt; 0$, which ensures that $\boldsymbol{V}$ is nonsingular and which is guaranteed as long as the quadratic form is positive definite. A quadratic form involving a general $\boldsymbol{V}$ (i.e. with nonsymmetric couplings $\boldsymbol{J}$) is positive definite iff its symmetric part has all positive eigenvalues. When this is no longer the case, things tend to blow up.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Scaling:&lt;/strong&gt; Our approach is kind of slow because calculating inverses scales as $\mathcal{O}\left(N^3\right)$. Yet there might be ways to approximate the slow parts of the algorithm similar to how vanilla transformers can be understood to approximate mean-field fixed-point equations&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Lack of permutation invariance:&lt;/strong&gt; Our model is not permutation invariant with the default choice of input-independent couplings: every spin has a role to play.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Input-dependent couplings:&lt;/strong&gt; Because our default model assumes input-independent couplings $\boldsymbol{J}$, Eq. \eqref{eq:spinev} features just a &amp;ldquo;token-mixing&amp;rdquo; attention operation. Channel-mixing terms can appear when we consider the physically very weird setup where the couplings are made dependent on the applied magnetic fields. One possible choice could be:
\begin{align}
\boldsymbol{J}(\boldsymbol{H}) = \frac{\tanh \left( \boldsymbol{H} \boldsymbol{Q} \boldsymbol{K}^T \boldsymbol{H}^T \cdot \sqrt{D} \right)}{\sqrt{ND}} \nonumber
\end{align}
where $\boldsymbol{Q}$ and $\boldsymbol{K}$ are linear transformations acting on the hidden dimension and where the scaling factors have been inserted because of the normalization conventions we discussed in
. We hypothesize that additional terms in the spin expectation value Eq. \eqref{eq:spinev} arising from input-dependent couplings might be related to channel-mixing feed-forward networks in transformer modules.&lt;/li&gt;
&lt;/ul&gt;
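&lt;p&gt;The stability criterion from the first bullet can be checked numerically. The sketch below is only an illustration: the helper &lt;code&gt;is_positive_definite&lt;/code&gt; and the two-spin couplings are hypothetical and not part of the reference implementation, and we again assume $\boldsymbol{V} = \mathrm{diag}(\boldsymbol{t}) - \boldsymbol{J}$.&lt;/p&gt;

```python
import numpy as np

def is_positive_definite(V):
    """Check the quadratic form x^T V x via the eigenvalues of the symmetric part of V."""
    sym = 0.5 * (V + V.T)
    # eigvalsh expects a symmetric matrix and returns real eigenvalues.
    return bool(np.greater(np.linalg.eigvalsh(sym).min(), 0.0))

# Hypothetical nonsymmetric couplings between two spins, no self-couplings.
J = np.array([[0.0, 1.0],
              [0.2, 0.0]])

t_safe = np.array([1.0, 1.0])    # symmetric part of V has eigenvalues 0.4 and 1.6
t_risky = np.array([0.5, 0.5])   # symmetric part of V has eigenvalues -0.1 and 1.1

stable = is_positive_definite(np.diag(t_safe) - J)
unstable = is_positive_definite(np.diag(t_risky) - J)
print(stable, unstable)
```

&lt;p&gt;Only the symmetric part of $\boldsymbol{V}$ enters the quadratic form, so the antisymmetric part of nonsymmetric couplings plays no role in this particular check.&lt;/p&gt;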
&lt;h4 id="comparison-with-vanilla-transformers"&gt;Comparison with vanilla transformers&lt;/h4&gt;
&lt;p&gt;In this final section, let&amp;rsquo;s summarize our approach on a high level by visually comparing it to vanilla transformers and deep equilibrium approaches.&lt;/p&gt;
&lt;img src="arch_vanilla_deq.png" alt="Vanilla transformer and deep equilibrium transformer" width="500px"/&gt;
&lt;p&gt;The vanilla transformer
(left above) is an explicit architecture which processes input sequences sequentially through a stack of transformer modules. Deep equilibrium transformers
(right above) compute the output of a transformer module by implicitly solving for the fixed point of $f(z, x) = z$ where $f$ denotes the explicit transformer module. Data is repeatedly inserted by adding it to the current iteration of $z$ inside the module until fixed-point convergence. The converged fixed point is considered the output of the module. Backpropagation through the iterations of the solver is avoided by using the implicit function theorem to calculate gradients directly at the equilibrium point. Instead of a stack of layers, there&amp;rsquo;s just a single layer.&lt;/p&gt;
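&lt;p&gt;As a toy illustration of the implicit fixed-point idea, the sketch below iterates $z \leftarrow f(z, x)$ until convergence. The function &lt;code&gt;f&lt;/code&gt; is a hypothetical stand-in (a single $\tanh$ layer with input injection), not an actual transformer module, and plain forward iteration is used in place of the more sophisticated solvers used in practice. The implicit-function-theorem backward pass is not shown.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8

# Hypothetical weights, rescaled to spectral norm 0.5 so that f contracts in z.
W = rng.standard_normal((dim, dim))
W *= 0.5 / np.linalg.norm(W, 2)

def f(z, x):
    # Toy explicit module: the input x is injected at every iteration.
    return np.tanh(W @ z + x)

x = rng.standard_normal(dim)   # input data
z = np.zeros(dim)              # initial guess for the fixed point
for _ in range(100):
    z = f(z, x)

# At convergence the output satisfies f(z, x) = z up to numerical precision.
residual = np.linalg.norm(f(z, x) - z)
print(residual)
```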
&lt;p&gt;But deep equilibrium transformers still treat the transformer module as a black box. In
we looked for a physical spin-model interpretation of the deep equilibrium fixed-point procedure (left below). We argued that the update step of a vanilla transformer module resembles mean-field fixed-point equations of a vector-spin model, explaining the successful pattern of token-mixing, residual connections, normalization layers, and feed-forward or channel-mixing modules from the perspective of physical spin systems.&lt;/p&gt;
&lt;img src="arch_dia_afem.png" alt="Deep implicit attention and approximate free-energy minimization" width="600px"/&gt;
&lt;p&gt;In this work (right above), we continued on the path of spin expectation values but replaced solving mean-field fixed-point equations with directly taking derivatives of the steepest-descent partition function of a particular class of vector-spin models. The fixed-point procedure is replaced with a root-solving step to determine the steepest-descent partition function. The structure of our module&amp;rsquo;s output reveals the same successful transformer-like pattern of token-mixing (attention) and channel-mixing (feed-forward) interspersed with normalization layers and residual connections.&lt;/p&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;In this post, we introduced transformer modules as wrappers around statistical-mechanical vector-spin models. We used implicit layers to construct a class of approximate yet tractable vector-spin models whose couplings act as parameters that can be learned from data. We showed how these models can act as transformer-like attention modules by routing inputs to applied magnetic fields and returning spin expectation values derived from their steepest-descent partition function.&lt;/p&gt;
&lt;p&gt;By zooming out and approaching transformers from a tangential, statistical-mechanical point of view, we were able to develop a physical intuition of transformers that seems hard to arrive at when restricting oneself to perturbing explicit neural network architectures. Recognizing transformer modules as spin models in disguise might not only unify architectural variations but also elucidate the high-level architectural convergence and empirical success of transformers in deep learning.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021afem,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Transformers from Spin Models: Approximate Free Energy Minimization},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {October},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;We could have turned to the mean-field free energies associated with the adaptive TAP equations discussed in
, but we decided on attacking the problem from the steepest-descent angle on the full partition function.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;For example, see &lt;em&gt;
&lt;/em&gt; and &lt;em&gt;
&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;The original 1968 paper has a small typo here: the $\nu$ in the paper&amp;rsquo;s Eq. (23) should be $\nu^{1/2}$ for the surface area of a $(\nu-1)$-dimensional sphere with radius $R=\nu^{1/2}$ embedded in $\nu$ dimensions. Using the paper&amp;rsquo;s formula, an annoying $\ln \nu$ term won&amp;rsquo;t cancel out in the limiting free energy calculation.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;In contrast to spin glasses however, we do not (yet want to go full Bayesian and) treat the couplings as drawn from some kind of probability distribution. For now, we settle for obtaining point estimates of model parameters.&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;By promoting sparseness in the couplings, a model might become less mean-field-y, which might be one of the reasons behind the success of scaled &lt;code&gt;softmax&lt;/code&gt; attention in vanilla transformers.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;From
: &lt;em&gt;The mean-field limit to infinite dimensions or long-range interaction introduces a new large scale. To make the thermodynamic limit meaningful the dependence of the energy on this new large scale must be compensated by rescaling the non-local spin exchange so that the energy remains linearly proportional to the volume or the number of lattice sites (spins).&lt;/em&gt;&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;We can expand the right-hand side using a
to find
&lt;/p&gt;
\begin{align}
\boldsymbol{V}^{-1} &amp;= \left( \mathrm{diag} ( \boldsymbol{t} ) - \boldsymbol{J} \right)^{-1} = \sum_{k=0}^{\infty} \left( \mathrm{diag} \left( \boldsymbol{t}^{-1} \right) \boldsymbol{J} \right)^{k} \mathrm{diag} \left( \boldsymbol{t}^{-1} \right) \nonumber
\end{align}&lt;p&gt;
which converges if the largest absolute value of the eigenvalues of the matrix inside the power-brackets is less than 1. So the spin expectation value looks like a sum of contributions that mix and weigh inputs of different sites.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;As discussed previously in
. In that setting, calculating inverses was sidestepped by approximating part of the solution with a feed-forward neural network.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms</title><link>https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/</link><pubDate>Wed, 07 Apr 2021 15:17:17 +0100</pubDate><guid>https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/</guid><description>&lt;hr&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Update (November 2021):&lt;/strong&gt; Consider reading
for a high-level overview of some of the ideas outlined in this post.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;✨ Code: A reference PyTorch implementation of the ideas outlined in this blog post is available in the repository
. Comments welcome.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;To explore progress beyond the cage of softmax attention, we have previously looked at energy-based perspectives on attention mechanisms:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The main take-away so far has been that you can think of softmax attention as implementing a single, big gradient step of some energy function and that training transformers is akin to meta-learning how to best tune a stack of attention and feed-forward modules to perform well on some auxiliary (meta-)task(s). But what can an energy-based perspective actually provide beyond quaint and hand-wavy statements like &lt;em&gt;implicit energy landscapes are sculpted every time you train a transformer&lt;/em&gt;?&lt;/p&gt;
&lt;p&gt;In this post, we approach attention in terms of the &lt;em&gt;collective response of a statistical-mechanical system&lt;/em&gt;. Attention is interpreted as an inner-loop fixed-point optimization step which returns the approximate response of a system being probed by data. This response is a differentiable compromise between the system&amp;rsquo;s internal dynamics and the data it&amp;rsquo;s being exposed to. To better respond to incoming data, outer-loop optimization steps can nudge the interactions and the self-organizing behaviour of the system.&lt;/p&gt;
&lt;p&gt;To implement our proposal, we combine old ideas and new technology to construct a family of attention mechanisms based on fixed points. We use
to solve a set of self-consistent mean-field equations of a vector generalization of the random Ising spin-model. By approximating these equations, we arrive at simplified update steps which mirror the vanilla transformer architecture. We conclude by showing how transformers can be understood from a mean-field theory perspective.&lt;/p&gt;
&lt;h1 id="mean-field-theory-for-disordered-systems"&gt;Mean-field theory for disordered systems&lt;/h1&gt;
&lt;p&gt;In physics, mean-field theory
is an approximation method to study models made up of many individual degrees of freedom that interact with each other. Mean-field theory approximates the effect of the environment on any given individual degree of freedom by a single, averaged effect, and thus reduces a many-body problem to an (effective) one-body problem. This is a drastic approximation. Whether mean-field theory is a sensible thing to do depends on the problem and the properties of your variational ansatz.&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Mean-field theory &amp;amp; variational methods:&lt;/strong&gt; From the point of view of variational methods, mean-field theory tries to approximate a complicated object (like a partition function of a statistical-mechanical system) by wiggling around the parameters of a tractable variational ansatz to get as close as possible to the real thing. You can picture this process as projecting down a complicated object living in a high-dimensional space to its shadow in an easier-to-handle subspace (&lt;em&gt;I can hear a mathematician fainting in the background&lt;/em&gt;). This effectively reduces the problem to optimizing for the best possible approximation within your variational class. A lot of mean-field machinery also shows up in probability theory, statistics, and machine learning where it appears in belief propagation, approximate variational inference, expectation propagation, etc.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;In the next two subsections, we introduce random Ising models and sketch a physics-inspired approach to deal with disordered models using mean-field theory. In
we will then generalize these results to vector spin degrees of freedom and propose two flavours of attention models.&lt;/p&gt;
&lt;h2 id="random-ising-models-or-boltzmann-machines-or-"&gt;Random Ising models (or Boltzmann machines or &amp;hellip;)&lt;/h2&gt;
&lt;p&gt;The random Ising model is a prototypical model in the study of spin glasses and disordered random systems, where it is often referred to as the
, famous for its replica-method solution by Giorgio Parisi in 1979. Its energy function with external field for $N$ classical, binary spin variables looks like&lt;/p&gt;
\begin{equation}
E = \sum_{i,j} J_{ij} S_{i} S_{j} + \sum_{i} x_{i} S_{i}, \label{eq:randomising}
\end{equation}&lt;p&gt;where the couplings $J_{ij}$ between degrees of freedom are randomly distributed according to some probability distribution and self-interactions are absent ($J_{ii} = 0$). The external magnetic fields $x_{i}$ provide a preferential direction of alignment at every local site. Since the elements in the coupling matrix can have both negative and positive signs, the system is said to have both frustrated ferro- as well as antiferromagnetic couplings. The model defined by \eqref{eq:randomising} is also known as a
or a
.&lt;/p&gt;
&lt;p&gt;In contrast with disordered systems, we expect the couplings in the context of artificial neural networks to no longer be randomly drawn from a distribution but to reflect structure and organization between spins after being exposed to data. The system should self-organize in order to better respond to incoming data.&lt;/p&gt;
&lt;p&gt;A cartoon of a spin configuration of a 7-spin system looks something like
&lt;img src="binary_ising.png" alt="Random Ising model configuration with binary spins" width="250px"/&gt;
where we have only drawn the connections strongest in absolute value. It&amp;rsquo;s helpful to think of classical spin degrees of freedom as arrows. For vector spins, we can imagine lifting the up/down restriction and letting the arrows rotate freely.&lt;/p&gt;
&lt;h2 id="adaptive-thoulessandersonpalmer-mean-field-theory"&gt;Adaptive Thouless&amp;ndash;Anderson&amp;ndash;Palmer mean-field theory&lt;/h2&gt;
&lt;p&gt;One of the approaches physicists have come up with to tackle disordered random systems with pairwise interactions like those in Eq. \eqref{eq:randomising} is
. The TAP equations improve mean-field theory results by adding a so-called &lt;em&gt;Onsager self-correction term&lt;/em&gt; calculated from the couplings&amp;rsquo; distribution.&lt;/p&gt;
&lt;p&gt;
adapted this method to probabilistic modeling to be able to deal with scenarios where the distribution of the couplings between spins is not known a priori. To compensate for the lack of knowledge of the couplings&amp;rsquo; distribution, they introduced a self-consistent computation to adapt the Onsager correction to the &lt;em&gt;actual&lt;/em&gt; couplings using the cavity method and linear response relations. We will sketch the adaptive TAP approach below but refer to
and
for more details and derivations.&lt;/p&gt;
&lt;h3 id="single-site-partition-function-from-cavity-method"&gt;Single-site partition function from cavity method&lt;/h3&gt;
&lt;p&gt;The adaptive TAP equations can be derived using the cavity method, where a cavity field distribution is introduced to rewrite the marginal distributions of the spins. The cavity corresponds to the &amp;ldquo;hole&amp;rdquo; left by removing a single spin. By assuming a Gaussian cavity distribution in the large connectivity limit, one can show that the single-site partition function looks like&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \int \mathrm{d} S \ \rho_{i}\left(S\right) \exp \left[ S \left( a_{i} + x_{i} \right) + \frac{V_{i} S^2}{2} \right]
\end{equation}&lt;p&gt;where the $a_i$ denote &lt;em&gt;cavity means&lt;/em&gt; and the $V_i$ &lt;em&gt;cavity variances&lt;/em&gt;. The single-site partition function can be integrated to yield an explicit expression after choosing well-behaved priors $\rho_{i}(S)$ for the spins. For binary spins $S=\pm 1$, we can pick $\rho_{i}(S)=\frac{1}{2}\left( \delta(S-1) + \delta(S+1) \right)$ to find, up to a factor $e^{V_{i}/2}$ that does not depend on $x_{i}$ (since $S^2 = 1$),&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \cosh \left( a_{i} + x_{i} \right). \label{eq:partfunbinaryspins}
\end{equation}&lt;h3 id="cavity-means-and-onsager-correction-term"&gt;Cavity means and Onsager correction term&lt;/h3&gt;
&lt;p&gt;The cavity means can be shown to be given by
&lt;/p&gt;
\begin{equation}
a_{i} = \sum_{j} J_{ij} \langle S_{j} \rangle - V_{i} \langle S_{i} \rangle. \label{eq:cavitymean}
\end{equation}&lt;p&gt;where the last term is the &lt;em&gt;Onsager correction term&lt;/em&gt;, a self-correction term for every spin which depends on the cavity variances.&lt;/p&gt;
&lt;h3 id="cavity-variances-and-linear-response"&gt;Cavity variances and linear response&lt;/h3&gt;
&lt;p&gt;The cavity variances are determined self-consistently, i.e. by calculating the same quantity in two different ways and demanding the obtained expressions to be equal. To do this, we introduce the matrix of susceptibilities&lt;/p&gt;
\begin{equation}
\chi_{ij} = \langle S_{i} S_{j} \rangle - \langle S_{i} \rangle \langle S_{j} \rangle = \frac{\partial^2}{\partial x_{i}\partial x_{j}} \log Z_{0}^{(i)}
\end{equation}
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;The susceptibility matrix $\chi_{ij}$ is a covariance matrix and should thus be positive semi-definite, which is a criterion for the mean-field solution to be consistent. As soon as this property is lost, the fixed-point procedure will no longer be stable.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Its diagonal elements $\chi_{ii}$ can be obtained both from the explicit calculation of the spin variances from the partition function&lt;/p&gt;
\begin{equation}
\chi_{ii} = \langle S_{i}^2 \rangle - \langle S_{i} \rangle^2 = \frac{\partial^2}{\partial x_{i}^2} \log Z_{0}^{(i)} \label{eq:chiii}
\end{equation}&lt;p&gt;but also from a linear response calculation assuming fixed $V_i$,&lt;/p&gt;
\begin{align}
\chi_{ij} = \frac{\partial \langle S_{i} \rangle}{\partial x_{j}} = \frac{\partial \langle S_{i} \rangle}{\partial x_{i}} \left( \delta_{ij} + \sum_{k} \left( J_{ik} - V_{k} \delta_{ik} \right) \chi_{kj} \right) \label{eq:chiijlinrespexp}
\end{align}&lt;p&gt;which can be solved for $\chi_{ij}$ to yield
&lt;/p&gt;
\begin{equation}
\chi_{ij} = \left[ \left( \boldsymbol{\Lambda} - \boldsymbol{J} \right)^{-1} \right]_{ij} \label{eq:chiijlinresp}
\end{equation}&lt;p&gt;
where
&lt;/p&gt;
\begin{align}
\boldsymbol{\Lambda} = \mathrm{diag} \left( \Lambda_1, \ldots, \Lambda_{N} \right),\\\\
\Lambda_i = V_i + \left( \frac{\partial \langle S_{i} \rangle}{\partial x_{i}} \right)^{-1}.
\end{align}&lt;p&gt;The cavity variances $V_i$ are then determined by equating \eqref{eq:chiii} to the diagonal elements of \eqref{eq:chiijlinresp} and solving the following consistency condition for $V_i$
&lt;/p&gt;
\begin{equation}
\frac{1}{\Lambda_i - V_i} = \left[ \left( \boldsymbol{\Lambda} - \boldsymbol{J} \right)^{-1} \right]_{ii}. \label{eq:viselfcons}
\end{equation}&lt;p&gt;Given updated values for the cavity means $a_i$ and the cavity variances $V_i$, spin means and spin variances can then be updated as follows:&lt;/p&gt;
\begin{align}
\langle S_{i} \rangle &amp;= \frac{\partial}{\partial x_{i}} \log Z_{0}^{(i)} (x_{i}, a_{i}, V_{i}),\\\\
\langle S_{i}^2 \rangle - \langle S_{i} \rangle^2 &amp;= \frac{\partial^2}{\partial x_{i}^2} \log Z_{0}^{(i)} (x_{i}, a_{i}, V_{i}),
\end{align}&lt;p&gt;These equations reduce to explicit expressions given an explicit expression for $Z_{0}^{(i)}$. For the binary-spin partition function \eqref{eq:partfunbinaryspins} where $S=\pm 1$, we get a set of fixed-point equations for the spin means that look like&lt;/p&gt;
\begin{equation}
\langle S_{i} \rangle = \tanh \left( \sum_{j} J_{ij} \langle S_{j} \rangle - V_{i} \langle S_{i} \rangle + x_{i} \right)
\end{equation}&lt;p&gt;with spin variances $\chi_{ii} = 1 - \langle S_{i} \rangle^2$.&lt;/p&gt;
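&lt;p&gt;To get a feeling for these fixed-point equations, here is a minimal NumPy sketch of a damped iteration for binary spins. It makes a strong simplifying assumption: the cavity variances $V_i$ are held fixed at made-up values instead of being determined self-consistently from Eq. \eqref{eq:viselfcons}, so this is not the full adaptive TAP algorithm. The couplings and fields are random illustrative numbers, with the couplings kept weak so that the iteration contracts.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N = 6

# Hypothetical symmetric couplings without self-interactions, rescaled to be weak.
J = rng.standard_normal((N, N))
J = 0.5 * (J + J.T)
np.fill_diagonal(J, 0.0)
J *= 0.2 / np.linalg.norm(J, 2)

x = rng.standard_normal(N)   # external fields (input data)
V = np.full(N, 0.05)         # assumption: cavity variances held fixed

def update(m):
    # Right-hand side of the fixed-point equation, including the Onsager term.
    return np.tanh(J @ m - V * m + x)

m = np.zeros(N)
for _ in range(200):
    m = 0.5 * m + 0.5 * update(m)   # damped fixed-point step

residual = np.max(np.abs(update(m) - m))
chi_diag = 1.0 - m**2               # spin variances at the fixed point
print(residual)
```

&lt;p&gt;In the adaptive scheme, each update of the spin means would be interleaved with solving the consistency condition for the $V_i$, which is where the matrix inverse enters.&lt;/p&gt;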
&lt;h1 id="attention-as-a-fixed-point-method"&gt;Attention as a fixed-point method&lt;/h1&gt;
&lt;p&gt;In this section, we attempt to generalize the mean-field equations obtained in the previous section to random Ising-like models with vector spin degrees of freedom. We then recognize the physical system as an attention model and provide both a slow, explicit implementation and a faster, neural one.&lt;/p&gt;
&lt;h2 id="generalizing-spin-models-to-vector-degrees-of-freedom"&gt;Generalizing spin models to vector degrees of freedom&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s return to our Ising model cartoon and replace the scalar spin degrees of freedom $S_i$ at every site with vectors $\boldsymbol{S}_i \in \mathbb{R}^d$, which we visualize using arrows below&lt;/p&gt;
&lt;img src="featured.png" alt="Random Ising model configuration with vector spins" width="250px"/&gt;
&lt;p&gt;Let&amp;rsquo;s consider a system of $N$ $d$-dimensional spins and let&amp;rsquo;s label site indices with $i,j,\ldots$ and internal vector-space indices with Greek letters $\alpha,\beta,\ldots$. We let the coupling weight matrix become a tensor $\boldsymbol{J}_{ij} = J_{ij}^{\alpha\beta}$ (matrices coupling every pair of sites) and remove self-couplings by enforcing the couplings&amp;rsquo; block-diagonal to be zero. Additionally, we can symmetrize both the internal dimension and the sites to end up with $N(N-1)/2$ times $d(d+1)/2$ effective free parameters for the couplings. If we also turn the external fields into vectors, we obtain a vector generalization of Eq. \eqref{eq:randomising}:&lt;/p&gt;
\begin{equation}
E = \sum_{i,j} \boldsymbol{S}_{i}^{T} \boldsymbol{J}_{ij} \boldsymbol{S}_{j} + \sum_{i} \boldsymbol{X}_{i} \cdot \boldsymbol{S}_{i}. \label{eq:vectrandomising}
\end{equation}&lt;h2 id="deep-implicit-attention-attention-as-a-collective-response"&gt;Deep implicit attention: attention as a collective response&lt;/h2&gt;
&lt;p&gt;Remember that our goal is to understand attention as the collective response of a statistical-mechanical system. Let&amp;rsquo;s now relate vector models like Eq. \eqref{eq:vectrandomising} to attention models by treating the external magnetic fields $\boldsymbol{X}_{i}$ as input data. Batches of sequences applied to every site act as probes for the system, pushing its behaviour into a certain direction. The system&amp;rsquo;s mean-field average magnetizations $\langle \boldsymbol{S}_{i} \rangle$ are an approximation of the collective response at every site: what is the expected value of this particular vector spin? We interpret solving mean-field equations for $\langle \boldsymbol{S}_{i} \rangle$ in the presence of input injections $\boldsymbol{X}_{i}$ as an attention operation. If the whole system is differentiable, we can tune the couplings $\boldsymbol{J}_{ij}$ in an outer-loop optimization to steer the system&amp;rsquo;s behaviour to better&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; respond to future incoming data.&lt;/p&gt;
&lt;h2 id="slow-and-explicit-solving-the-adaptive-tap-equations"&gt;Slow and explicit: solving the adaptive TAP equations&lt;/h2&gt;
&lt;p&gt;What changes do we have to make to the adaptive TAP mean-field equations to turn them into a vector-based attention module and how can we implement them? Let&amp;rsquo;s explicitly enumerate the objects introduced in
together with their (generalized) tensor shapes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;Iteratively determined fixed-point variables&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Spin means $\langle \boldsymbol{S}_{i} \rangle = \left[ \langle \boldsymbol{S}_{i} \rangle \right]^{\alpha}$ &lt;code&gt;(batch_size, N, d)&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;Cavity variances $\boldsymbol{V}_{i} = V_{i}^{\alpha\beta}$ &lt;code&gt;(N, d, d)&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Other variables calculated during fixed-point iteration&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Cavity means $\boldsymbol{a}_{i} = a_{i}^{\alpha}$ &lt;code&gt;(batch_size, N, d)&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;Spin variances $\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \boldsymbol{\chi}_{ii} = \chi_{ii}^{\alpha\beta}$ &lt;code&gt;(N, d, d)&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For every site, the scalar spin and cavity variances have turned into $d \times d$ (inverse) covariance matrices on the level of the local dimension. Note that the variances in the above list carry no batch dimension: they are &amp;ldquo;system properties&amp;rdquo; whose values are identical across all examples and capture the properties of the system irrespective of the input injections $\boldsymbol{X}_i$.&lt;/p&gt;
&lt;p&gt;The vector translation of the single-site partition function looks like&lt;/p&gt;
\begin{equation}
Z_{0}^{(i)} = \int \mathrm{d}^{d} \boldsymbol{S} \ \rho_{i}\left(\boldsymbol{S}\right) \exp \left[ \boldsymbol{S} \cdot \left( \boldsymbol{a}_{i} + \boldsymbol{X}_{i} \right) + \frac{1}{2} \boldsymbol{S}^T \boldsymbol{V}_{i} \boldsymbol{S} \right]
\end{equation}&lt;p&gt;where&lt;/p&gt;
\begin{equation}
\boldsymbol{a}_{i} = \sum_{j} \boldsymbol{J}_{ij} \langle \boldsymbol{S}_{j} \rangle - \boldsymbol{V}_{i}\langle \boldsymbol{S}_{i} \rangle. \label{eq:veccavmeans}
\end{equation}&lt;p&gt;Spin means and variances are then computed from&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \frac{\partial}{\partial\boldsymbol{X}_{i}} \log Z_{0}^{(i)} (\boldsymbol{X}_{i}, \boldsymbol{a}_{i}, \boldsymbol{V}_{i})
\end{equation}\begin{equation}
\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \frac{\partial^2}{\partial\boldsymbol{X}_{i}^2} \log Z_{0}^{(i)} (\boldsymbol{X}_{i}, \boldsymbol{a}_{i}, \boldsymbol{V}_{i})
\end{equation}&lt;p&gt;As a spin prior $\rho_{i}\left(\boldsymbol{S}\right)$, we pick a simple diagonal multivariate Gaussian $\mathcal{N} \left( \boldsymbol{\mu} = \boldsymbol{0}_{d}, \boldsymbol{\Sigma}= \boldsymbol{1}_{d \times d} \right)$ at every site, leading to the explicit equations:&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \left( \boldsymbol{\Sigma}^{-1} - \boldsymbol{V}_{i} \right)^{-1} \left( \boldsymbol{a}_{i} + \boldsymbol{X}_{i} \right)
\end{equation}\begin{equation}
\langle \boldsymbol{S}_{i}^2 \rangle - \langle \boldsymbol{S}_{i} \rangle^2 = \left( \boldsymbol{\Sigma}^{-1} - \boldsymbol{V}_{i} \right)^{-1}
\end{equation}&lt;h3 id="generalizing-the-cavity-variance-calculation"&gt;Generalizing the cavity variance calculation&lt;/h3&gt;
&lt;p&gt;The cavity variance computation can be done by generalizing Eqs. \eqref{eq:chiijlinrespexp}&amp;ndash;\eqref{eq:chiijlinresp} and solving the following system of equations for $\boldsymbol{\chi}_{ij}$,&lt;/p&gt;
\begin{equation}
\left( \delta_{ik} \otimes \boldsymbol{1}_{d} - \boldsymbol{\Sigma}_{i} \boldsymbol{J}_{ik} + \boldsymbol{\Sigma}_{i} \boldsymbol{V}_{i} \delta_{ik} \right)\boldsymbol{\chi}_{kj} = \boldsymbol{\Sigma}_{i} \delta_{ij}
\end{equation}&lt;p&gt;The generalization of the self-consistency condition Eq. \eqref{eq:viselfcons} is then obtained by solving $\boldsymbol{\chi}_{ii} \boldsymbol{V}_{i} = \boldsymbol{\chi}_{ii} \boldsymbol{\Lambda}_{i} - \boldsymbol{1}_{N \times d \times d}$ for $\boldsymbol{V}_{i}$, where $ \boldsymbol{\Lambda}_{i} = \boldsymbol{V}_{i} + \boldsymbol{\Sigma}^{-1}$ is computed using the current values of $\boldsymbol{V}_{i}$. The price to pay for this added complexity is a computational cost of $O(N^3d^3)$ and an excruciatingly slow backward pass. The algorithm works, but it ain&amp;rsquo;t pretty.&lt;/p&gt;
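&lt;p&gt;The linear-response solve above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions rather than the implementation behind this post: the function name &lt;code&gt;solve_linear_response&lt;/code&gt; is hypothetical, $\boldsymbol{\Sigma}_{i}$ is assumed to be given as a stack of $d \times d$ matrices of shape &lt;code&gt;(N, d, d)&lt;/code&gt;, and &lt;code&gt;np.linalg.solve&lt;/code&gt; stands in for the PyTorch solver.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np


def solve_linear_response(J, V, Sigma):
    """Solve (1 - Sigma_i J_ik + Sigma_i V_i delta_ik) chi_kj = Sigma_i delta_ij.

    J has shape (N, N, d, d); V and Sigma have shape (N, d, d).
    Returns chi with shape (N, N, d, d).
    """
    N, d = V.shape[0], V.shape[-1]
    idx = np.arange(N)
    # Block (i, k) of the left-hand side operator.
    A = -(Sigma[:, None] @ J)
    A[idx, idx] += np.eye(d) + Sigma @ V
    # Right-hand side: Sigma_i on the site diagonal, zero elsewhere.
    B = np.zeros((N, N, d, d))
    B[idx, idx] = Sigma
    # Flatten (site, component) pairs into a single index of size N * d.
    A_flat = A.transpose(0, 2, 1, 3).reshape(N * d, N * d)
    B_flat = B.transpose(0, 2, 1, 3).reshape(N * d, N * d)
    chi_flat = np.linalg.solve(A_flat, B_flat)  # dense solve, cubic in N * d
    return chi_flat.reshape(N, d, N, d).transpose(0, 2, 1, 3)


rng = np.random.default_rng(0)
N, d = 3, 2
J = 0.1 * rng.normal(size=(N, N, d, d))
J[np.arange(N), np.arange(N)] = 0.0
V = np.stack([0.1 * np.eye(d)] * N)
Sigma = np.stack([np.eye(d)] * N)
chi = solve_linear_response(J, V, Sigma)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The dense solve over an $Nd \times Nd$ matrix is where the $O(N^3d^3)$ cost comes from, and it has to be repeated at every fixed-point iteration.&lt;/p&gt;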
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;Implementation:&lt;/strong&gt; To avoid &lt;code&gt;torch.solve&lt;/code&gt; crashing on singular matrices during the fixed-point calculation, we found it crucial for stability and learning behaviour to initialize the couplings $J_{ij}^{\alpha\beta} \sim \mathcal{N}(0, \sigma^2)$ with small values $\sigma^2 = 1 / (N d^2)$ to ensure $|J| \sim \mathcal{O}(1)$. It&amp;rsquo;s also beneficial if the sources satisfy $|\boldsymbol{X}_{i}| \sim \mathcal{O}(1)$ so that terms are balanced in the update step, all together adding up to $\mathcal{O}(1)$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 id="fast-and-neural-parametrizing-the-onsager-self-correction-term"&gt;Fast and neural: parametrizing the Onsager self-correction term&lt;/h2&gt;
&lt;p&gt;Can we somehow approximate the slow and explicit calculation of the cavity variances? Since $\boldsymbol{z}^{*} = \left( \langle \boldsymbol{S}_{i}^{*} \rangle, \boldsymbol{V}_{i}^{*} \right)$ at the fixed point, the Onsager self-correction term in Eq. \eqref{eq:veccavmeans} converges to a constant vector $\boldsymbol{V}_{i}^{*}\langle \boldsymbol{S}_{i}^{*} \rangle$ for every site. We propose to make a bold move by getting rid of the cavity variables altogether and reducing the equations for the fixed-point update step to&lt;/p&gt;
\begin{equation}
\langle \boldsymbol{S}_{i} \rangle = \sum_{j} \boldsymbol{J}_{ij} \langle \boldsymbol{S}_{j} \rangle - f_{\theta} \left( \langle \boldsymbol{S}_{i} \rangle \right) + \boldsymbol{X}_{i}, \label{eq:diaupdate}
\end{equation}&lt;p&gt;where $f_{\theta}$ is a neural network parametrizing the action of the cavity variances on the spin means. Since the parameters $\theta$ stay fixed during the inner-loop fixed-point calculation, we have effectively lifted the optimization of the self-correction term to the outer loop, which also optimizes the weights $\boldsymbol{J}_{ij}$.&lt;/p&gt;
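&lt;p&gt;As a rough illustration of the update step in Eq. \eqref{eq:diaupdate}, the NumPy sketch below iterates the map until it stops changing. Several things here are assumptions made for the sake of the example: the function name &lt;code&gt;fixed_point_response&lt;/code&gt; is hypothetical, plain forward iteration is used instead of an accelerated solver, $f_{\theta}$ is a toy &lt;code&gt;tanh&lt;/code&gt; layer with fixed random weights, and the couplings are scaled down so that the iteration is contractive and actually converges.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np


def fixed_point_response(X, J, f_theta, max_iter=200, tol=1e-8):
    """Iterate S = J S - f_theta(S) + X for inputs X of shape (batch, N, d)."""
    B, N, d = X.shape
    # Rows of J_flat are (i, a), columns are (j, b).
    J_flat = J.transpose(0, 2, 1, 3).reshape(N * d, N * d)
    S = np.zeros_like(X)
    for _ in range(max_iter):
        coupling = (S.reshape(B, N * d) @ J_flat.T).reshape(B, N, d)
        S_new = coupling - f_theta(S) + X
        if np.allclose(S_new, S, rtol=0.0, atol=tol):
            return S_new
        S = S_new
    return S  # last iterate if the tolerance was not reached


rng = np.random.default_rng(0)
B, N, d = 2, 5, 4
J = 0.1 / np.sqrt(N * d) * rng.normal(size=(N, N, d, d))
J[np.arange(N), np.arange(N)] = 0.0  # no self-couplings
W = rng.normal(size=(d, d)) / np.sqrt(d)


def f_theta(S):
    """Toy stand-in for the learned self-correction network."""
    return 0.1 * np.tanh(S @ W)


X = rng.normal(size=(B, N, d))
S_star = fixed_point_response(X, J, f_theta)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Note that, unlike the cavity variances, nothing in this loop requires a linear solve: every iteration is a matrix product plus a site-wise function evaluation.&lt;/p&gt;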
&lt;p&gt;All of this starts to look an awful lot like a transformer module. Before discussing an explicit comparison in
, let&amp;rsquo;s finish this section with a simple example model.&lt;/p&gt;
&lt;h3 id="simple-example-mnist"&gt;Simple example: MNIST&lt;/h3&gt;
&lt;p&gt;A simple image classification model for MNIST using a convolutional feature extractor and a deep implicit attention layer could look something like&lt;/p&gt;
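&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;einops.layers.torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Rearrange&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# DEQFixedPoint, DEQMeanFieldAttention and anderson are assumed to be&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# importable from the code accompanying this post.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MNISTNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;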
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MNISTNet&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_patch_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Conv2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;  &lt;span class="c1"&gt;# -&amp;gt; 26 x 26&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MaxPool2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;  &lt;span class="c1"&gt;# -&amp;gt; 12 x 12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Conv2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;  &lt;span class="c1"&gt;# -&amp;gt; 10 x 10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MaxPool2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;  &lt;span class="c1"&gt;# -&amp;gt; 4 x 4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;Rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="s1"&gt;&amp;#39;b c h w -&amp;gt; b (h w) c&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim_conv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;deq_atn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;DEQFixedPoint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;DEQMeanFieldAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                    &lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_spins&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                    &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                    &lt;span class="n"&gt;weight_sym_internal&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                    &lt;span class="n"&gt;weight_sym_sites&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                    &lt;span class="n"&gt;lin_response&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;anderson&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;solver_fwd_max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;40&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;solver_fwd_tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;solver_bwd_max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;40&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                &lt;span class="n"&gt;solver_bwd_tol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_patch_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;cls_tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;cls_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;deq_atn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The ViT-style classification token is interpreted as an additional site in the system, which is probed with a learnable input injection that is shared across examples. The model uses the classification token&amp;rsquo;s output response to do the final classification. The system has to self-organize its behaviour so that the classification token gets all the information it needs.&lt;/p&gt;
&lt;img src="vit_mnist.gif" alt="ViT-style model with deep implicit attention layer on MNIST" width="500px"/&gt;
&lt;p&gt;You can train
this small model (26k parameters) on MNIST to find a test set accuracy hovering around 99.1%. The animation above shows a graph reflecting the (directed) connection strengths between spins during training as measured by the Frobenius norms of the matrices $\boldsymbol{J}_{ij}$. Almost all major organization of connections is seen to happen in the first few iterations. One imagines the model getting frustrated at zeros which &lt;em&gt;really&lt;/em&gt; look like nines and just flat-out refusing to remember edge cases out of spite.&lt;/p&gt;
&lt;h1 id="a-mean-field-theory-perspective-on-transformers"&gt;A mean-field theory perspective on transformers&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s conclude this post by applying the mean-field theory perspective on attention to the transformer architecture. Schematically, a vanilla transformer module looks like&lt;/p&gt;
&lt;img src="vanilla_transformer_module.png" alt="Vanilla transformer module" width="200px"/&gt;
&lt;p&gt;which consists of an attention module acting on all vectors in the sequence input followed by a feed-forward layer acting &amp;ldquo;locally&amp;rdquo; across individual vectors in the sequence, mixed with some residual connections and layer normalizations.&lt;/p&gt;
&lt;h2 id="parametrizing-the-couplings-sparse-graph-structure-from-inputs"&gt;Parametrizing the couplings: sparse graph structure from inputs&lt;/h2&gt;
&lt;p&gt;Transformers can be interpreted as fully-connected graph neural networks acting on sets of vectors. Inside an attention module, the row-stochastic attention matrix corresponds to a particular parametrization of the couplings&lt;/p&gt;
\begin{equation}
J_{ij} = \left[\mathrm{softmax}\left( \frac{\boldsymbol{X} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{X}^{T}}{\sqrt{d}} \right)\right]_{ij}, \label{eq:softmaxcouplings}
\end{equation}&lt;p&gt;which swaps storing explicit coupling weights for parameters of linear query-key transformations. By dynamically determining the connectivity of the sites based on the inputs $\boldsymbol{X}$ according to Eq. \eqref{eq:softmaxcouplings}, the coupling weights are no longer completely free parameters. The introduction of queries and keys can be seen as a neural network approach to &amp;ldquo;amortizing&amp;rdquo; the coupling tensor while the softmax temperature promotes sparsity. Multiple attention heads correspond to imposing a block-diagonal structure in the hidden dimensions of the couplings: the dot product gets cut into disjoint pieces, one for each attention head.&lt;/p&gt;
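&lt;p&gt;A minimal NumPy sketch of the parametrization in Eq. \eqref{eq:softmaxcouplings} for a single head is given below. The function name &lt;code&gt;softmax_couplings&lt;/code&gt; and the random weights are placeholders chosen for illustration, and the inputs are assumed to have shape &lt;code&gt;(N, d)&lt;/code&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np


def softmax_couplings(X, W_Q, W_K):
    """Row-stochastic couplings J[i, j] computed from inputs X of shape (N, d)."""
    d = X.shape[-1]
    logits = (X @ W_Q) @ (X @ W_K).T / np.sqrt(d)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
N, d = 6, 8
X = rng.normal(size=(N, d))
W_Q = rng.normal(size=(d, d))
W_K = rng.normal(size=(d, d))
J = softmax_couplings(X, W_Q, W_K)  # shape (N, N), one scalar coupling per pair of sites
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Every row of the resulting matrix sums to one, and the couplings are in general not symmetric under exchanging $i$ and $j$, in contrast to the symmetrized coupling tensors considered earlier.&lt;/p&gt;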
&lt;h2 id="softmax-attention-does-a-single-naive-mean-field-update-step"&gt;Softmax attention does a single, naive mean-field update step&lt;/h2&gt;
&lt;p&gt;Looking at the update step \eqref{eq:diaupdate} and the softmax couplings \eqref{eq:softmaxcouplings}, we observe that the softmax attention module does a single, naive mean-field update step without a self-correction term. Ignoring layer normalizations, the attention update step for every input vector looks like&lt;/p&gt;
\begin{equation}
\boldsymbol{X}'_{i} = \sum_{j} \left[ \mathrm{softmax} \left( \frac{\boldsymbol{X} \boldsymbol{W}_{\boldsymbol{Q}} \boldsymbol{W}_{\boldsymbol{K}}^{T} \boldsymbol{X}^{T}}{\sqrt{d}} \right) \right]_{ij} \left[ \boldsymbol{X} \boldsymbol{W}_{\boldsymbol{V}} \right]_{j} + \boldsymbol{X}_{i}, \nonumber
\label{eq:vanilla-attention}
\end{equation}&lt;p&gt;where, crucially, the residual connection is responsible for adding the source term to the update step. Without a residual connection, the applied magnetic field is effectively turned off and the signal would only be able to propagate via the coupling term.&lt;/p&gt;
&lt;h2 id="feed-forward-layer-corrects-naive-mean-field-update"&gt;Feed-forward layer corrects naive mean-field update&lt;/h2&gt;
&lt;p&gt;Looking at the Onsager self-correction term $f_{\theta} \left( \langle \boldsymbol{S}_{i} \rangle \right)$ in the update step \eqref{eq:diaupdate}, we observe that the full transformer attention module emerges when we replace $\langle \boldsymbol{S}_{i} \rangle$ with its naive mean-field value, leading to&lt;/p&gt;
\begin{equation}
\mathrm{Attention}(\boldsymbol{X})_{i} = \boldsymbol{X}'_{i} + \mathrm{FeedForward}\left( \boldsymbol{X}'_{i} \right),
\end{equation}&lt;p&gt;with $\boldsymbol{X}'_{i}$ defined above. Again, the residual connection appears to be crucial for the structure of the mean-field theory equations to match the vanilla transformer module&amp;rsquo;s architecture. As previously discussed in
, we hypothesize that feed-forward networks in transformer modules &amp;ldquo;amortize&amp;rdquo; the linear response self-corrections.&lt;/p&gt;
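&lt;p&gt;Putting the two steps together, a stripped-down transformer module can be sketched as follows. This is a toy NumPy illustration and not a faithful implementation: layer normalizations are ignored, there is a single head, all function and weight names are hypothetical, and the minus sign in front of $f_{\theta}$ in Eq. \eqref{eq:diaupdate} is assumed to be absorbed into the learned feed-forward weights.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def naive_mean_field_step(X, W_Q, W_K, W_V):
    """Softmax attention plus residual: the input X re-enters as the source term."""
    d = X.shape[-1]
    J = softmax((X @ W_Q) @ (X @ W_K).T / np.sqrt(d))
    return J @ (X @ W_V) + X


def feed_forward(X, W_1, W_2):
    """Two-layer ReLU MLP applied to every site independently."""
    return np.maximum(X @ W_1, 0.0) @ W_2


def transformer_module(X, W_Q, W_K, W_V, W_1, W_2):
    """X' + FeedForward(X') with X' the naive mean-field update."""
    X_prime = naive_mean_field_step(X, W_Q, W_K, W_V)
    return X_prime + feed_forward(X_prime, W_1, W_2)


rng = np.random.default_rng(0)
N, d, d_hidden = 6, 8, 16
X = rng.normal(size=(N, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
W_1 = rng.normal(size=(d, d_hidden)) / np.sqrt(d)
W_2 = rng.normal(size=(d_hidden, d)) / np.sqrt(d_hidden)
out = transformer_module(X, W_Q, W_K, W_V, W_1, W_2)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Setting both the value weights and the last feed-forward layer to zero returns the input unchanged, which makes the role of the residual connections as source terms easy to see.&lt;/p&gt;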
&lt;h2 id="mean-field-theory-framework-for-transformer-architectures"&gt;Mean-field theory framework for transformer architectures&lt;/h2&gt;
&lt;p&gt;Within the general mean-field (or
) structure outlined above, there is considerable freedom in parametrizing the interaction and self-correction terms. Most transformer papers parametrize the self-correction terms with a feed-forward layer, i.e. some variation of an MLP. In
the authors went even further and dropped the softmax parametrization of the interaction term to approximate the full action of summing over couplings with an MLP as well. Related papers like
,
, and
can all be considered as explorations of different parametrizations of the mean-field interaction terms. In the large-scale regime, it seems like the softmax attention module can be swapped for just about any function which mixes tokens as long as the structure of residual connections and self-correction terms is preserved.&lt;/p&gt;
&lt;h2 id="comparison-with-energy-based-perspective"&gt;Comparison with energy-based perspective&lt;/h2&gt;
&lt;p&gt;In a previous post on
, we introduced a picture of attention modules in transformers as stacks of energy functions which are defined dynamically at every layer depending on the outputs of the previous layer (so ultimately on the inputs of the first layer). Looking back, this interpretation feels kind of forced and is also unable to explain the presence of skip connections and fully-connected layers surrounding the attention modules. The mean-field perspective seems more interesting since it (1) relies on just one layer (one energy function) whose fixed point (an infinite number of &amp;ldquo;layers&amp;rdquo;) gets calculated, and (2) explains the presence of skip connections (source terms) and fully-connected layers (amortized self-correction terms).&lt;/p&gt;
&lt;h1 id="conclusion-and-outlook"&gt;Conclusion and outlook&lt;/h1&gt;
&lt;p&gt;We have shown how attention can be understood as the mean-field response of Ising-like spin systems being probed by data. By thinking of incoming data as applied magnetic fields and the output of attention modules as spin expectation values, attention can be interpreted as a fixed-point optimization process solving for a compromise between a system&amp;rsquo;s internal dynamics and the data it&amp;rsquo;s being exposed to. Since the whole system is differentiable, we can optimize the interaction weights in an outer loop to nudge the system&amp;rsquo;s behaviour.&lt;/p&gt;
&lt;p&gt;We have also seen how transformers fit into the mean-field theory framework. For scalability, transformers introduce two additional constraints/approximations on top of the mean-field approximation: (1) replacing explicit couplings with parametrized couplings that are dynamically computed from the input via linear transformations (softmax query-key-value attention), and (2) replacing the expensive self-consistent computation of Onsager self-correction terms with a neural network (feed-forward layer).&lt;/p&gt;
&lt;p&gt;Looking ahead, the methods introduced in this post could provide ways to implicitly train mean-field approximations of Boltzmann machines and have them serve as distributed attention modules in larger interconnected systems. To go beyond mean-field approaches, it could be interesting to look at tensor network approaches. Conceptually, the physical interpretation of attention as an interacting many-body system modulating its behaviour by &lt;em&gt;learning to respond to being driven in particular ways&lt;/em&gt; is fun to think about.&lt;/p&gt;
&lt;h1 id="related-work"&gt;Related work&lt;/h1&gt;
&lt;p&gt;A non-exhaustive list of references and inspiration includes:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;On deep equilibrium models:
(2019) by Shaojie Bai, Zico Kolter, Vladlen Koltun and
of the
by Zico Kolter, David Duvenaud, and Matt Johnson&lt;/li&gt;
&lt;li&gt;On the adaptive Thouless-Anderson-Palmer (TAP) mean-field approach in disorder physics:
(2001) by Manfred Opper and Ole Winther&lt;/li&gt;
&lt;li&gt;On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:
(2014) by Jack Raymond, Andre Manoel, Manfred Opper&lt;/li&gt;
&lt;li&gt;On Boltzmann machines and mean-field theory:
(1998) by H. J. Kappen and
F. B. Rodríguez and
(1998) by Toshiyuki Tanaka&lt;/li&gt;
&lt;li&gt;On approximate message passing (AMP) methods in statistics:
(2021) by Oliver Y. Feng, Ramji Venkataramanan, Cynthia Rush, Richard J. Samworth: the example on page 2 basically describes how transformers implement approximate message passing: an iterative algorithm with a &amp;ldquo;denoising&amp;rdquo; step (attention) followed by a &amp;ldquo;memory term&amp;rdquo; or Onsager correction term (feed-forward layer)&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021deepimplicitattention,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {May},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Whatever &amp;ldquo;better&amp;rdquo; means depends on the system&amp;rsquo;s (meta-)loss function, e.g. predicting corrupted tokens BERT-style or aligning representations to a teacher BYOL/DINO-style.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>