<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Dynamical Systems | mcbal</title><link>https://mcbal.github.io/tags/dynamical-systems/</link><atom:link href="https://mcbal.github.io/tags/dynamical-systems/index.xml" rel="self" type="application/rss+xml"/><description>Dynamical Systems</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-gb</language><lastBuildDate>Wed, 17 Mar 2021 22:36:17 +0100</lastBuildDate><image><url>https://mcbal.github.io/media/icon.svg</url><title>Dynamical Systems</title><link>https://mcbal.github.io/tags/dynamical-systems/</link></image><item><title>Attention as Energy Minimization: Visualizing Energy Landscapes</title><link>https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/</link><pubDate>Wed, 17 Mar 2021 22:36:17 +0100</pubDate><guid>https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/</guid><description>&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;&lt;strong&gt;📓 Colab notebook available
. Comments welcome.&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Recent work &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; has shown that the softmax-attention update step in transformer models can be intepreted as a one-step gradient update or &amp;ldquo;inference&amp;rdquo; step of a judiciously chosen energy function. An overview of these ideas can be found in previous blog posts:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The goal of this educational blog post is to explicitly show how vanilla softmax attention is related to energy minimization approaches and how the former can be substituted for the latter. For pedagogical purposes, we will focus purely on the attention operation. However, for transformer models to perform well in practice, it is necessary to wrap attention in residual connections and point-wise feedforward processing layers, see e.g.
.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Summary:&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;We provide a pedagogical energy-based attention module that stays as close as possible to vanilla softmax attention for ease of comparison.&lt;/li&gt;
&lt;li&gt;We walk through the correspondence between modern Hopfield networks and vanilla softmax attention by gradually adding complexity.&lt;/li&gt;
&lt;li&gt;We present visualizations of energy landscapes and trajectories associated to attention update steps for two-dimensional toy patterns.&lt;/li&gt;
&lt;/ol&gt;
&lt;h1 id="prelude-pattern-terminology"&gt;Prelude: pattern terminology&lt;/h1&gt;
&lt;p&gt;Transformer literature almost exclusively talks about queries, keys, and values. For self-attention, these are all obtained from different linear transformations acting on the same set of &lt;em&gt;input patterns&lt;/em&gt;. For cross-attention, only the queries derive from the &lt;em&gt;input patterns&lt;/em&gt;; the keys and values are obtained from a different set of &lt;em&gt;context patterns&lt;/em&gt;: think of a decoder architecture attending to encoded translations or the
model attending to multimodal input.&lt;/p&gt;
&lt;p&gt;Hopfield networks literature starts from the idea of trying to implement an associative memory system for storing and retrieving patterns. Patterns stored in memory are called &lt;em&gt;stored patterns&lt;/em&gt;. A &lt;em&gt;state pattern&lt;/em&gt; is an input prompt for the associative memory system: what patterns stored in memory are closest to this particular prompt?&lt;/p&gt;
&lt;p&gt;Depending on the context (heh), we can refer to input patterns as state patterns or queries and to context patterns as stored patterns or memory or keys.&lt;/p&gt;
&lt;h1 id="attention-modules"&gt;Attention modules&lt;/h1&gt;
&lt;h2 id="explicit-vanilla-softmax-attention"&gt;Explicit vanilla softmax attention&lt;/h2&gt;
&lt;p&gt;To compare the behavior of explicit attention modules to that of energy-based attention modules, we need to first of all define a vanilla softmax attention module. The annotated implementation below features a &lt;code&gt;bare_attn&lt;/code&gt; toggle in the forward pass for ease of comparison with the &amp;ldquo;bare&amp;rdquo; modern continuous Hopfield energy function we will discuss later on. The flag essentially disables all linear mappings so input and context patterns are processed &amp;ldquo;raw&amp;rdquo;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Vanilla softmax attention.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Adapted from https://github.com/lucidrains/perceiver-pytorch (commit 37e2eb6).
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Inner dimension is expressed in terms of head count and dimensionality&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# and thus decoupled from query_dim/context_dim (heads always &amp;#34;fit&amp;#34;).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Linear transformations (queries, keys, values, head-mixing).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# To facilitate comparison with modern Hopfield networks, setting `bare_attn`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# to `True` disables all linear mappings, assures there&amp;#39;s only a single head and&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# reduces the module to a barebone attention which takes in &amp;#34;raw&amp;#34; queries or state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# patterns and attends to a &amp;#34;raw&amp;#34; context/memory of stored patterns.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;only a single head when bare attention&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;query_dim/context_dim must match&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Adaptive scale.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Take context either from elsewhere of from self (attention vs. self-attention).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Map x to queries and context to keys and values.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Split up latent dimension into subspaces for heads to act on.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Head dimension becomes part of batch dimension (=&amp;gt; parallel processing of heads).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b n (h d) -&amp;gt; (b h) n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Scaled dot product of all queries against all keys (sum over `inner_dim`).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i d, b j d -&amp;gt; b i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Optional masking.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b j -&amp;gt; (b h) () j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Softmax operation across &amp;#34;keys&amp;#34; sequence dimension.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Contract attention matrix with values.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i j, b j d -&amp;gt; b i d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Move head dimension out of batch again.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;(b h) n d -&amp;gt; b n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Mix all the heads&amp;#39; outputs; stir well and serve immediately.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="implicit-energy-based-attention"&gt;Implicit energy-based attention&lt;/h2&gt;
&lt;p&gt;Next, we define our energy-based attention module. Its forward pass will make use of the simple gradient descent function defined below to do energy minimization and update queries accordingly.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;minimize_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Minimize energy function with respect to queries.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Keeps track of energies and trajectories for logging and plotting.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;defaultdict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;grad_queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;autograd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;grad_outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;create_graph&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# enables double backprop for optimization&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;step_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;grad_queries&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;energies&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;energies&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queries&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The &lt;code&gt;EnergyBasedAttention&lt;/code&gt; module below has been structured to look as similar as possible to the the &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt; module defined above. The main difference is the appearance of an energy function and the energy minimization call in the forward pass where the softmax attention used to be. Other differences include the absence of a linear map to &amp;ldquo;values&amp;rdquo; and masking being pushed into the energy function.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Linear transformations (queries, keys, output).&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inner_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Bare checks.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;only a single head when bare attention&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;query_dim/context_dim must match&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b n (h d) -&amp;gt; (b h) n d&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;b j -&amp;gt; (b h) () j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Minimize energy with respect to queries.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;minimize_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;return_trajs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rearrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;(b h) n d -&amp;gt; b n (h d)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_out&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="from-modern-hopfield-networks-to-multi-head-attention"&gt;From modern Hopfield networks to multi-head attention&lt;/h1&gt;
&lt;p&gt;Let&amp;rsquo;s start with the simplest possible case: bare attention. We disable all linear mappings to queries/keys/values/output to make sure input and context patterns are processed &amp;ldquo;raw&amp;rdquo; and restrict ourselves to a single attention head. We numerically verify that a &amp;ldquo;bare&amp;rdquo; explicit attention module indeed returns the same result as doing a single, big step of energy minimization with respect to input state patterns. Put differently and more to the point, we merely show that automatic differentiation works.&lt;/p&gt;
&lt;h2 id="energy-function"&gt;Energy function&lt;/h2&gt;
&lt;p&gt;Consider the energy function of a modern continuous Hopfield network for a set of state patterns $\boldsymbol{\Xi}$ and stored patterns $\boldsymbol{X}$:&lt;/p&gt;
\begin{equation}
E(\boldsymbol{\Xi}; \boldsymbol{X}) = \frac{1}{2} \boldsymbol{\Xi}^T \boldsymbol{\Xi} -\mathrm{logsumexp} \left( \boldsymbol{X}^T \boldsymbol{\Xi} \right),\label{eq:energy}
\end{equation}&lt;p&gt;Think of this model as the scoring function of an associative memory system. For now, we&amp;rsquo;d like to keep the stored patterns fixed as memory slots and wiggle around the state patterns. We can translate this energy function into the following (batched) function:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;kinetic&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;b i d, b i d -&amp;gt; b i&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scaled_dot_product&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;b i d, b j d -&amp;gt; b i j&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stored_patterns&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;finfo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_neg_value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;potential&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logsumexp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;kinetic&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;potential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="verifying-the-update-rule"&gt;Verifying the update rule&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s sample some state patterns and stored patterns and enable gradient tracking for the state patterns since we want to take derivatives with respect to these parameters later on.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="cross-attention"&gt;Cross-attention&lt;/h3&gt;
&lt;p&gt;First up is cross-attention. We feed state patterns as input and stored patterns as context into a vanilla softmax attention module.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now we do the same for an energy-based attention module and tell it to take a single, big gradient update step.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now let&amp;rsquo;s compare the outputs of the two methods:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_bare_softmax_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;True
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Both tensors are approximately equal: bare softmax attention corresponds to taking a single gradient step of &lt;code&gt;step_size=1.0&lt;/code&gt; with respect to the state patterns using the energy function of modern Hopfield networks as a loss. For more details on this correspondence, we refer to
.&lt;/p&gt;
&lt;h3 id="self-attention"&gt;Self-attention&lt;/h3&gt;
&lt;p&gt;Let&amp;rsquo;s do the same check for self-attention, which boils down to only inputting state patterns. Internally, the modules will consider the state patterns as stored patterns and effectively make the patterns pay attention to themselves.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_softmax_self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_bare_softmax_self_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Norm between input state patterns and energy-minimized patterns: &amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;output_bare_energy_self_attn&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;True
Norm between input state patterns and energy-minimized patterns: 5.553587470785715e-06
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The pattern update step looks almost like an an identity operation, which is to be expected for &amp;ldquo;bare&amp;rdquo; self-attention&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;. Without any linear transformations to map state patterns to queries and keys, every state pattern starts off already close to a local minimum since it coincides with itself as a stored pattern. The query starts off close to the key since the query-key mappings are identities. We will visualize this behavior in
for two-dimensional patterns.&lt;/p&gt;
&lt;h2 id="adding-queries-keys-and-values"&gt;Adding queries, keys, and values&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s now move closer to proper vanilla softmax attention by enabling linear transformations which map state patterns to queries and stored patterns to keys (and values). These parameters are able to move patterns around on the energy landscape before (queries, keys) and after (values) paying attention.&lt;/p&gt;
&lt;p&gt;We recycle the previously instantiated patterns and modules and compare outputs again, making sure the parameters are equal and omitting the &lt;code&gt;bare_attn&lt;/code&gt; flag:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;strict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_energy_attn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;False
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Why don&amp;rsquo;t the outputs match? We have to make sure we compare apples to apples and be mindful of the fact that the energy minimization step only knows about keys. Indeed, as shown previously in
, the one-step energy minimization, expressed in terms of queries and keys, effectively implements&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}^{\text{new}} = \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{K}
\end{equation}&lt;p&gt;instead of the vanilla softmax attention step&lt;/p&gt;
\begin{equation}
\boldsymbol{Q}^{\text{new}} = \text{softmax}\left( \frac{1}{\sqrt{d_k}} \boldsymbol{Q} \boldsymbol{K}^T \right) \boldsymbol{V}
\end{equation}&lt;p&gt;We can approximately undo this mapping to make a forced comparison for fixed parameters:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output_energy_attn_transformed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;softmax_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_energy_attn&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pinverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energy_attn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;weight&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output_softmax_attn&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;output_energy_attn_transformed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;tensor(0.0005, grad_fn=&amp;lt;CopyBackwards&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Yet since all these parameters would be optimized in a real-world scenario, we should only care about whether the representational power of the modules is similar. To make the single-head energy-based attention module more expressive, we can always add an output layer, parametrized by weights $W_{O}$, to the module. As long as the composition of linear transformations $W_{K}W_{O}$ doesn&amp;rsquo;t collapse and its rank does not fall below that of the softmax attention&amp;rsquo;s $W_{V}$, things should be okay.&lt;/p&gt;
&lt;h2 id="adding-masking-and-multiple-attention-heads"&gt;Adding masking and multiple attention heads&lt;/h2&gt;
&lt;p&gt;Finally, let us tie up some loose ends and complete the correspondence between vanilla softmax attention and energy-based minimization.&lt;/p&gt;
&lt;h3 id="masking"&gt;Masking&lt;/h3&gt;
&lt;p&gt;Since masking boils down to putting restrictions on what patterns in the inputs are allowed to talk to each other, it can just as well be done at the level of the energy function. By filling the tensor inside the &lt;code&gt;logsumexp&lt;/code&gt; operator in &lt;code&gt;hopfield_energy&lt;/code&gt; with $-\infty$ values at to-be-masked-out positions, we get the same effect as the masking operation in the forward pass of &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt;. Boolean masks can be passed to the &lt;code&gt;EnergyBasedAttention&lt;/code&gt;&amp;rsquo;s forward function and propagate to the energy function.&lt;/p&gt;
&lt;h3 id="multi-head-attention"&gt;Multi-head attention&lt;/h3&gt;
&lt;p&gt;Up to now, we have only considered a single attention head. Essentially, multiple attention heads subdivide the latent space into equal parts and process these subproblems in parallel. The head dimension becomes part of the batch dimension. This translates to having parallel energy minimizations going on for different heads, each acting on their own subspace. Since our &lt;code&gt;hopfield_energy&lt;/code&gt; function is already batched, we can use the same machinery of the previous sections, as shown below.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mha_energy_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;EnergyBasedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mha_energy_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;pre&gt;&lt;code&gt;tensor([[[-0.0514, -0.0353, 0.0243, ..., -0.0335, -0.0060, 0.0243],
[-0.1004, -0.0136, -0.0297, ..., 0.0079, 0.0083, 0.0336],
[-0.0507, -0.0369, -0.0219, ..., -0.0022, -0.0246, -0.0223],
...,
[-0.0388, -0.0217, -0.0470, ..., -0.0067, 0.0020, -0.0139],
[-0.0283, -0.0699, -0.0205, ..., -0.0261, -0.0667, 0.0052],
[-0.0262, -0.0360, -0.0139, ..., -0.0011, -0.0199, -0.0004]]],
grad_fn=&amp;lt;AddBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;It is hard to compare with the exact output of the equivalent &lt;code&gt;VanillaSoftmaxAttention&lt;/code&gt; module for fixed module parameters. For multi-head attention, the updated queries coming out of the separate energy minimization steps will have summed over each heads&amp;rsquo; keys instead of its values. For a single attention head we could undo the keys&amp;rsquo; transformation by acting with the inverse of the keys&amp;rsquo; weights. For multiple attention heads, that is no longer possible.&lt;/p&gt;
&lt;p&gt;Again, since all these parameters would be optimized in a real-world scenario, we should only care about whether the representational power of the modules is similar. One approach would be to add parameters inside the energy function that take care of mapping to &amp;ldquo;values&amp;rdquo; on the level of the heads.&lt;/p&gt;
&lt;h1 id="attention-in-flatland-visualizing-energy-landscapes"&gt;Attention in flatland: visualizing energy landscapes&lt;/h1&gt;
&lt;p&gt;We now leave the world of high-dimensional latent spaces behind us and focus on the toy model scenario of just two latent space dimensions. We only consider a single attention head because having just two heads, each with dimension one, is just silly. For every two-dimensional token pattern vector, a third dimension will be provided by the value of the scalar energy function at that point.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s sample some tiny toy patterns to play around with.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h4 id="bare-cross-attention"&gt;Bare cross-attention&lt;/h4&gt;
&lt;p&gt;Let&amp;rsquo;s plot our tiny toy patterns taking a big gradient step!&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-cross-attention"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_16e739ac33b2bfa8.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_89bbd1b398e61db4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_c884d6bda828a1f1.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_54_0_hu_16e739ac33b2bfa8.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare cross-attention
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;In the figure above, the blue open circles correspond to the stored patterns (memory, context, keys, &amp;hellip;), the red circles denote the initial state patterns (inputs, queries, probes, &amp;hellip;) and the red crosses the updated queries obtained after &lt;code&gt;n_steps&lt;/code&gt; of energy minimization. The red arrows denote the trajectory in the energy landscape.&lt;/p&gt;
&lt;p&gt;We will now illustrate some example scenarios.&lt;/p&gt;
&lt;h4 id="small-steps-go-nowhere"&gt;Small steps go nowhere&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-small-steps"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_913d624b8622617a.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_ac2668d10a8d96ce.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_20adc42ede72878c.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_56_0_hu_913d624b8622617a.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Small steps
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="lots-of-big-steps-converge-near-global-minimum-or-repeated-softmax-iterations-make-all-token-representations-identical"&gt;Lots of (big) steps converge near (global) minimum or repeated softmax iterations make all token representations identical&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-big-steps"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_d991410d3d62e314.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_4174b72e3e8a46a4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_68a97e48bda407ba.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_57_0_hu_d991410d3d62e314.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Big steps
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="decreasing-the-scale-increasing-the-temperature-makes-the-landscape-smoother-and-encourages-convergence-to-same-global-minimum"&gt;Decreasing the scale (increasing the temperature) makes the landscape smoother and encourages convergence to same (global) minimum&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-decrease-scale"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_ed4cc3ea5763a50b.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_b9ea06f031681b75.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_65e05b2705206111.webp 574w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_58_0_hu_ed4cc3ea5763a50b.webp"
width="574"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Decrease scale
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="increasing-the-scale-lowering-the-temperature-creates-disconnected-valleys-in-the-energy-landscape-inhabited-by-stored-patterns-which-act-as-attractors-for-any-query-that-happens-to-be-in-its-basin-of-attraction"&gt;Increasing the scale (lowering the temperature) creates &amp;ldquo;disconnected&amp;rdquo; valleys in the energy landscape inhabited by stored patterns which act as attractors for any query that happens to be in its basin of attraction&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-increase-scale"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_11c5d579723c22a8.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_b9fdc7b06e7f4bd4.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_36776070bfc9a5b8.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_59_0_hu_11c5d579723c22a8.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Increase scale
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="adding-linear-query-key-value-transformations"&gt;Adding linear query-key-value transformations&lt;/h4&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# As commented on before, the value transformation is applied&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# after the update step so that effectively the product&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# W_K x W_V is applied to the updated state patterns.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;to_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;to_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;to_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_stored_patterns&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;values_post_processing_func&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;to_v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_grid_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-adding-query-key-value-mappings"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_7f82d31be0acfa17.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_1c2fe38a2f7233eb.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_39d9fe4f4110b9e2.webp 587w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_60_0_hu_7f82d31be0acfa17.webp"
width="587"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Adding query-key-value mappings
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The yellow arrows point from the final, energy-minimized, query updates to the &amp;ldquo;value-transformed&amp;rdquo; output queries, which are denoted with yellow crosses. Running this cell again in the colab notebook will give different landscapes and trajectories every time since the queries and keys depend on the random linear layers. The differences are more pronounced when increasing the scale (lowering the temperature).&lt;/p&gt;
&lt;p&gt;Since the value transformation is done after the energy minimization, it can and does undo some of the influence of the keys&amp;rsquo; attractors, e.g. sending updated queries to &amp;ldquo;uphill&amp;rdquo; regions in the energy landscape defined at that that layer. This suggests that the value transformation should not be seen as part of the core attention mechanism but that its role is rather to learn during training how to best hop to different regions in preparation for whatever the next layer needs.&lt;/p&gt;
&lt;h4 id="bare-self-attention-on-the-importance-of-scale-and-why-multiple-heads"&gt;Bare self-attention: on the importance of scale and why multiple heads&lt;/h4&gt;
&lt;p&gt;Since all of the flatland examples so far have been for cross-attention, let&amp;rsquo;s also visualize a self-attention update below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;simulate_and_plot_patterns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hopfield_energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;copy_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;toy_state_patterns&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;step_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plot_title&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Energy landscape for two-dimensional toy patterns&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-self-attention-visualization"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_78f0dd4fca34764a.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_28e0072d72ffd56f.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_df1d0c6c073c1e65.webp 577w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_62_0_hu_78f0dd4fca34764a.webp"
width="577"
height="496"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare self-attention visualization
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Wait, what? Why did the updated state patterns move from their initialization? Didn&amp;rsquo;t we see before that the norm between inputs and outputs hardly changed at all for bare self-attention?&lt;/p&gt;
&lt;p&gt;To look into this, let&amp;rsquo;s plot the norm between inputs and outputs in function of the latent dimension, while scaling the scale or inverse temperature relative to the transformer default $\beta = 1/\sqrt{\mathrm{d_k}}$. We sample toy patterns repeatedly for every dimension/scale combination to get an idea of the statistical behavior.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dims&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;int32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bare_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;VanillaSoftmaxAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim_head&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta_scale&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bare_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bare_attn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta_scale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Suppresses a warning.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ma&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;norms&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Plot data.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gca&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;contourplot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contourf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dims&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta_scales&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norms&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;colors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LogNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vmin&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vmax&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;levels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;d_k&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;scale / sqrt(d_k)&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;contourplot&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;format&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;%.e&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ticks&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;ticker&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LogLocator&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;transformer_default_scale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure id="figure-bare-self-attention-experiment"&gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="alt text"
srcset="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_c6264e1df494a014.webp 320w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_d419bf703bcb545.webp 480w, https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_bef84266f747834.webp 588w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/vanilla_softmax_attention_final_refactor_68_0_hu_c6264e1df494a014.webp"
width="588"
height="485"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;figcaption&gt;
Bare self-attention experiment
&lt;/figcaption&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;In this contour plot, we plot the norm differences between inputs and outputs of a bare self-attention step for a sweep across latent dimensions and inverse temperature scale factors. The horizontal red line corresponds to the scale factor used by default in most transformer implementations. Some comments:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;For a fixed latent dimension, we see that increasing the scale factor corresponds to smaller norm differences, i.e. more pronounced valleys where it&amp;rsquo;s much harder to get out of, especially if you start at the bottom and there is no query-key-value mapping taking you elsewhere.&lt;/li&gt;
&lt;li&gt;The vertical red line corresponds to the earlier bare self-attention result using a latent dimension of 512. The intersection point indeed corresponds a norm difference of the order we saw previously. The value for a latent dimension of 2 (left border of plot) suggests that patterns do move around quite a bit, confirming our visualization above.&lt;/li&gt;
&lt;li&gt;Setting the scale for bare multi-head attention proportionally to the (smaller) head dimension instead of the full latent dimension corresponds to moving leftwards along the horizontal red line. The norm difference increases so that, for bare multi-head self-attention, patterns in multiple small heads tend to bounce around more than they would in a single big head. This might be one of the reasons why multiple heads help with training transformers: since the effective temperature is lower in the smaller latent spaces, the topography of the lower-dimensional energy landscapes is more pronounced and individual heads can go explore a bit to find their niche valley.&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;Using the tools presented in this blog post, we have shown that it is possible to swap the explicit attention module in a transformer for an implicit energy minimization method. What happens when we start playing around with different energy functions? Can we make patterns interact? Can we make the energy minimization step more efficient by treating it as a fixed-point problem? It remains to be seen whether all of this is a useful thing to do.&lt;/p&gt;
&lt;h1 id="references--footnotes"&gt;References &amp;amp; footnotes&lt;/h1&gt;
&lt;p&gt;If you happen to find this work useful, please consider citing it as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{bal2021visualizingattention,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = {Attention as Energy Minimization: Visualizing Energy Landscapes},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = {Bal, Matthias},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = {2021},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; month = {March},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = {https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/},
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;em&gt;Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;em&gt;Dmitry Krotov and John Hopfield,
(2020)&lt;/em&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;strong&gt;Caveat:&lt;/strong&gt; For the special case of bare energy-based self-attention, state patterns actually appear quadratically in the argument of the &lt;code&gt;logsumexp&lt;/code&gt; part of the energy function. Taking the derivative using &lt;code&gt;minimize_energy(..)&lt;/code&gt; however assumes the context is a different node in the computational graph, which, in this case, where we &lt;em&gt;should&lt;/em&gt; be taking the derivative of &lt;code&gt;energy(x, x)&lt;/code&gt; instead of &lt;code&gt;energy(x, context)&lt;/code&gt;, yields a gradient that misses a factor of 2. But ensuring the gradient is &amp;ldquo;correct&amp;rdquo; for this special case would of course screw up the cancellation of the state pattern with itself for &lt;code&gt;step_size=1.0&lt;/code&gt; and &lt;code&gt;num_steps=1&lt;/code&gt; so that the updated query would no longer match the output of bare vanilla softmax attention. Proper treatment of doing multiple steps of bare energy-based self-attention should also include manually setting the context to the updated queries (since the queries themselves change every update step). Luckily no one would seriously consider using bare energy-based self-attention.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>