
MonarchAttention: Zero-Shot Conversion to Fast, Hardware-Aware Structured Attention

Can Yaras, Alec S. Xu, Pierre Abillama, Changwoo Lee, Laura Balzano
Abstract

Transformers have achieved state-of-the-art performance across various tasks, but suffer from a notable quadratic complexity in sequence length due to the attention mechanism. In this work, we propose MonarchAttention – a novel approach to sub-quadratic attention approximation via Monarch matrices, an expressive class of structured matrices. Based on the variational form of softmax, we describe an efficient optimization-based algorithm to compute an approximate projection of softmax attention onto the class of Monarch matrices with $\Theta(N\sqrt{N}d)$ computational complexity and $\Theta(Nd)$ memory/IO complexity. Unlike previous approaches, MonarchAttention is both (1) transferable, yielding minimal performance loss with no additional training, even when replacing every attention layer of the transformer, and (2) hardware-efficient, utilizing the highest-throughput tensor core units on modern GPUs. With optimized kernels, MonarchAttention achieves substantial speed-ups in wall-time over FlashAttention-2: $1.4\times$ for shorter sequences ($N=256$), $4.5\times$ for medium-length sequences ($N=4$K), and $8.2\times$ for longer sequences ($N=16$K). We demonstrate the quality of MonarchAttention on diverse tasks and architectures in vision and language problems, showing that it flexibly and accurately approximates softmax attention in a variety of contexts. Our code is available at https://github.com/cjyaras/monarch-attention.

1 Introduction

Over the past decade, transformers (Vaswani et al., 2017) have become the dominant architecture for generating and processing various data modalities, such as text (Brown et al., 2020), images (Dosovitskiy et al., 2021), and speech (Radford et al., 2023). Central to the transformer’s success is attention, the mechanism through which complex interactions within sequential data are captured through weighted combinations of embeddings at every position in the sequence. Famously, the attention mechanism has a quadratic-time complexity $\Theta(N^{2}d)$ in the length of the sequence $N$, where $d$ is the head dimension, which is a key bottleneck for both training and inference, particularly in long sequence problems. To address this, numerous works have proposed sub-quadratic substitutes for attention. Yet, such approaches either (1) are not transferable, requiring training from scratch or fine-tuning of existing models, or (2) do not yield speed-ups in practice (except on extremely long sequences) due to a gap between theoretical complexity and practical considerations for modern GPUs, especially compared to highly optimized implementations (Dao et al., 2022b).
 
In this work, we propose MonarchAttention: a novel sub-quadratic attention substitute based on approximating the attention matrix via Monarch matrices (Dao et al., 2022a), a class of expressive structured matrices. At first glance, this is computationally infeasible – for sequence length $N$, computing an exact projection onto the set of Monarch matrices has a super-quadratic $O(N^{2}\sqrt{N})$-time complexity, not to mention that we need to form the entire $N\times N$ attention matrix. Instead, we reframe the computation of the attention matrix as an optimization problem in terms of the variational form of softmax, and exploit low-dimensional structure in the variational objective when constrained to the set of Monarch matrices – this yields a sub-quadratic $\Theta(N\sqrt{N}d)$-time approximation, where $d$ is the head dimension. This approach is analogous to optimization-based approaches for low-rank approximation of a matrix (Chi et al., 2019), where rather than computing a full SVD and truncating to the desired rank, one can more efficiently minimize a Frobenius norm objective constrained to the set of low-rank matrices. We briefly review prior work on structured matrices, including Monarch matrices, as well as existing approaches to efficient attention.

Figure 1: Approximation of softmax attention via MonarchAttention. By directly optimizing the softmax variational objective constrained to Monarch matrices, MonarchAttention yields accurate zero-shot approximation to softmax attention compared to other hardware-friendly, efficient attention baselines. Attention maps extracted from RoBERTa on the SQuAD dataset in Section 4.

Structured & Monarch Matrices.

We use the phrase “structured matrices” to mean those that admit sub-quadratic storage and matrix-vector multiplication, such as low-rank or sparse matrices. There are many useful classes of structured matrices, such as those with low displacement rank (Kailath et al., 1979), which include Toeplitz, Hankel, Vandermonde, and Cauchy matrices (Pan, 2001); orthogonal polynomial transforms (Chihara, 2014), which include discrete Fourier/cosine and Hadamard transforms; butterfly factorizations (Dao et al., 2019), which implement fast matrix-vector multiplication via a recursive divide-and-conquer algorithm similar to that of fast Fourier transforms (FFTs); and Monarch matrices, an expressive family of structured matrices (generalizing butterfly matrices and thereby many fast transforms) that overcome unfavorable memory access patterns typical of FFT-like algorithms by implementing matrix products via batched dense matrix multiplications (also called matmuls) on fast tensor cores found in modern GPUs.

Sub-Quadratic Attention.

Nearly all approaches to sub-quadratic attention approximate the attention matrix by a structured matrix, specifically low-rank and/or sparse.

  • Low-Rank. Motivated by Johnson-Lindenstrauss embeddings, Wang et al. (2020) propose sketching the key and value matrices along the sequence dimension via learnable projections. Katharopoulos et al. (2020) introduce linear attention, where the exponential kernel is approximated via inner products of queries and keys lifted via some feature map. Several follow-up works proposed various feature maps, such as the exponential linear unit (ELU) (Katharopoulos et al., 2020), random positive features (Choromanski et al., 2021), rectified linear unit (ReLU) with cosine reweighting (Qin et al., 2022), and learnable single-layer multi-layer perceptrons (MLPs) (Zhang et al., 2024). Xiong et al. (2021) use the Nyström method for computing low-rank approximations by sampling rows and columns.

  • Sparse. Child et al. (2019) introduce sparsity by applying fixed, structured sparse masks on the attention matrix. In particular, Chen et al. (2022) propose a particular block butterfly matrix for the sparse mask, which is more hardware-friendly at the cost of reduced expressiveness. Approaches that do not enforce a structure on the sparsity pattern include Kitaev et al. (2020); Daras et al. (2020), which utilize locality-sensitive hashing (LSH) on shared query/key vectors to compute attention only within clusters of similar tokens.

  • Low-Rank + Sparse. Inspired by robust PCA, Chen et al. (2021) decompose the attention matrix into a sum of two matrices: an unstructured sparse component using LSH and a low-rank component that is constructed via linear attention. Han et al. (2024) propose to subsample columns of the non-normalized attention matrix based on row norms of the value matrix, while estimating the softmax normalization factors from a few large elements via LSH.

We note that there are significant drawbacks to the approaches described above. Pure low-rank methods are often fast and hardware-friendly, but are not typically suitable as drop-in replacements for attention in pre-trained transformers due to the prevalence of “strongly diagonal”, high-rank attention matrices where attention weights are concentrated locally in a sequence. Making up for this with a fixed sparsity pattern does not allow for data-dependent support of the attention matrix, which is necessary for zero-shot conversion. Finally, sparsity/LSH-based approaches that do not have a fixed sparsity pattern improve on accuracy over low-rank approximations, but suffer from significant overhead due to GPU incompatibility. MonarchAttention, on the other hand, achieves the best of both worlds: it is fast and hardware-friendly due to utilization of tensor cores for batched matmuls, while computing highly accurate approximations to the extent that it can directly replace softmax attention with no additional training.
 
We conclude this section by discussing closely related works. Dao et al. (2022b) propose FlashAttention, an IO-aware streaming algorithm for computing exact softmax attention. We show in Section 3 that each step of MonarchAttention can be written as a FlashAttention-like computation, allowing for similar IO savings to FlashAttention – in fact, we demonstrate that MonarchAttention achieves a strictly better worst-case IO complexity compared to FlashAttention. We also note that Dao et al. (2022b) propose to further accelerate FlashAttention using block butterfly attention masks, so MonarchAttention can be viewed as a generalization of block-sparse FlashAttention to more general Monarch matrices. Finally, MonarchAttention is closely related to Monarch Mixer (Fu et al., 2023), a mixer-type architecture (Tolstikhin et al., 2021) that utilizes Monarch instead of dense matrices for token and channel mixing. MonarchAttention also uses Monarch matrices for mixing tokens – however, it is based on the attention operation which is data-dependent, unlike Monarch Mixer.

Organization.

In Section 2, we discuss preliminaries on (softmax) attention, and Monarch matrices. In Section 3, we describe the MonarchAttention algorithm and implementation. In Section 4, we evaluate MonarchAttention in a variety of settings for zero-shot conversion to sub-quadratic attention and benchmark its implementation. In Section 5, we discuss limitations and future directions.

2 Preliminaries

Notation.

We use $[N]$ to denote the index set $\{1,2,\dots,N\}$. We use $\Delta^{N}$ to denote the $(N-1)$-dimensional unit simplex, given by $\Delta^{N}=\{\bm{a}\in\mathbb{R}^{N}:\bm{a}\succeq 0,\langle\bm{1}_{N},\bm{a}\rangle=1\}$. We denote the $m\times m$ identity matrix by $\bm{I}_{m}$. We use the notation $\bm{A}_{ijk}$ to denote an element of a 3-way tensor, and $\bm{A}_{i,:,k}$ to denote a slice. We use $\delta_{kl}$ to denote the Kronecker delta that is $1$ if $k=l$ and $0$ otherwise.

Softmax.

The softmax function $\mathbb{R}^{N}\rightarrow\Delta^{N}$ maps $N$ real numbers to the $(N-1)$-dimensional unit simplex, and is defined as

[\operatorname{softmax}(\bm{z})]_{i}:=\frac{\exp(\bm{z}_{i})}{\sum_{j}\exp(\bm{z}_{j})},\;\forall i\in[N]. (1)

An alternative definition (Blondel et al., 2019) is given by the following variational form:

\operatorname{softmax}(\bm{z}):=\operatorname*{arg\,max}_{\bm{a}\in\Delta^{N}}\;\langle\bm{a},\bm{z}\rangle+H(\bm{a}), (2)

where $H(\bm{a})=-\sum_{i}\bm{a}_{i}\log\bm{a}_{i}$ is the Shannon entropy. See Appendix A for the equivalence of (1) and (2).
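As a quick sanity check of (2) (a minimal sketch, assuming only NumPy; the helper names are ours, not from the paper's code), the softmax of $\bm{z}$ should score at least as high on the objective $\langle\bm{a},\bm{z}\rangle+H(\bm{a})$ as random points on the simplex:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def objective(a, z):
    # <a, z> + H(a), with H the Shannon entropy
    return a @ z - np.sum(a * np.log(a))

rng = np.random.default_rng(0)
z = rng.normal(size=8)
a_star = softmax(z)
# Compare against many random points on the simplex.
best_random = max(objective(rng.dirichlet(np.ones(8)), z) for _ in range(1000))
assert objective(a_star, z) >= best_random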

Attention.

Given query, key, value matrices $\bm{Q},\bm{K},\bm{V}\in\mathbb{R}^{N\times d}$, where $N$ is the sequence length and $d$ is the head dimension, a single head of standard softmax attention computes (typically, the $\bm{Q}\bm{K}^{\top}$ matrix is scaled by a factor of $d^{-1/2}$, but this can be absorbed into $\bm{Q}$)

\bm{O}=\operatorname{softmax}\left(\bm{Q}\bm{K}^{\top}\right)\bm{V}, (3)

where the softmax function is applied across rows. The computational complexity of attention is $\Theta(N^{2}d)$ for each forward pass, because the matrices $\bm{Q},\bm{K},\bm{V}$ are data-dependent.

Monarch Matrices.

Given $N=m\times b$ for integers $m,b$, we define a block rank-one matrix $\bm{B}\in\mathbb{R}^{N\times N}$ as

\bm{B}=\begin{bmatrix}\bm{B}_{11}&\dots&\bm{B}_{1m}\\ \vdots&\ddots&\vdots\\ \bm{B}_{b1}&\dots&\bm{B}_{bm}\end{bmatrix},\quad\mbox{where}\quad\bm{B}_{jk}=\bm{L}_{jk}\bm{R}_{kj}^{\top}\in\mathbb{R}^{m\times b}

for some $\bm{L}_{jk}\in\mathbb{R}^{m},\bm{R}_{kj}\in\mathbb{R}^{b}$ for $j\in[b]$ and $k\in[m]$. It follows that

\bm{B}=\begin{bmatrix}\bm{L}_{1}^{\top}&&&\\ &\bm{L}_{2}^{\top}&&\\ &&\ddots&\\ &&&\bm{L}_{b}^{\top}\end{bmatrix}\bm{P}\begin{bmatrix}\bm{R}_{1}&&&\\ &\bm{R}_{2}&&\\ &&\ddots&\\ &&&\bm{R}_{m}\end{bmatrix},

where $\bm{L}_{j}\in\mathbb{R}^{m\times m}$ for $j\in[b]$ and $\bm{R}_{k}\in\mathbb{R}^{b\times b}$ for $k\in[m]$, and $\bm{P}\in\mathbb{R}^{N\times N}$ is a “transpose” permutation matrix ($\bm{P}\bm{x}$ corresponds to row-major reshaping $\bm{x}\in\mathbb{R}^{N}$ to $\mathbb{R}^{m\times b}$, transposing to $\mathbb{R}^{b\times m}$, then row-major flattening back to $\mathbb{R}^{N}$; see Appendix B for an example) whose $(i+1)$th row is given by $\bm{e}_{\sigma(i)+1}$, where

\sigma(i)=b\cdot(i\bmod m)+\left\lfloor\frac{i}{m}\right\rfloor,\quad i\in\{0,\dots,N-1\}.

Given the above, a Monarch matrix $\bm{M}\in\mathbb{R}^{N\times N}$ is given by $\bm{M}=\bm{P}^{\top}\bm{B}$ – in other words, it is a row-permuted block rank-one matrix. When $m=b=\sqrt{N}$, storing such a matrix requires only $\Theta(N\sqrt{N})$ space, while matrix multiplication (matmul) with a matrix $\bm{V}\in\mathbb{R}^{N\times d}$ can be computed efficiently in $\Theta(N\sqrt{N}d)$ operations (as opposed to $\Theta(N^{2}d)$ for dense matrices) with batched matmuls and transposes:

\bm{M}\bm{V}=\bm{O},\quad\mbox{where}\quad\bm{O}_{b\cdot(l-1)+j,v}=\sum_{k}\bm{L}_{jkl}\bm{Y}_{jkv},\quad\bm{Y}_{jkv}=\sum_{i}\bm{R}_{kji}\bm{V}_{b\cdot(k-1)+i,v} (4)

for $i,j\in[b]$, $k,l\in[m]$. A useful characterization of $\bm{M}$ is in block form:

\bm{M}=\begin{bmatrix}\bm{M}_{11}&\dots&\bm{M}_{1m}\\ \vdots&\ddots&\vdots\\ \bm{M}_{m1}&\dots&\bm{M}_{mm}\end{bmatrix}, (5)

where $\bm{M}_{lk}\in\mathbb{R}^{b\times b}$, $[\bm{M}_{lk}]_{ji}=\bm{L}_{jkl}\bm{R}_{kji}$, $\forall i,j\in[b],k,l\in[m]$.
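As a concrete illustration of (4) and (5), the following minimal NumPy sketch (variable names are ours, not the paper's released code) performs the Monarch multiply via two batched contractions and checks it against a densely materialized $\bm{M}$ (indices in the comments are zero-based):

import numpy as np

rng = np.random.default_rng(0)
m = b = 4
N, d = m * b, 8

L = rng.normal(size=(b, m, m))   # L[j, k, l]
R = rng.normal(size=(m, b, b))   # R[k, j, i]
V = rng.normal(size=(N, d))

# Structured multiply from (4): Theta(N sqrt(N) d) work when m = b = sqrt(N).
Vb = V.reshape(m, b, d)                              # Vb[k, i, v] = V[b*k + i, v]
Y = np.einsum('kji,kiv->jkv', R, Vb)                 # Y_{jkv}
O = np.einsum('jkl,jkv->ljv', L, Y).reshape(N, d)    # row b*l + j of the output

# Dense reference from (5): [M_lk]_{ji} = L[j,k,l] * R[k,j,i].
M = np.einsum('jkl,kji->ljki', L, R).reshape(N, N)
assert np.allclose(M @ V, O)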

3 MonarchAttention

The main goal of MonarchAttention is to find a Monarch matrix $\bm{M}\in\mathbb{R}^{N\times N}$ in $o(N^{2}d)$ time such that $\bm{M}\approx\operatorname{softmax}(\bm{Q}\bm{K}^{\top})$. Then, we can approximately compute the output $\bm{O}=\bm{M}\bm{V}$ using an efficient matmul. We can do this by viewing the softmax operation as an optimization problem via its variational form (2), whose objective can be efficiently maximized with exact alternating steps when constrained to Monarch matrices. As shown in Figure 1, this yields highly accurate approximations to the softmax attention matrix.

Softmax Objective.

First, from (2) we can write

\sigma(\bm{Q}\bm{K}^{\top})=\operatorname*{arg\,max}_{\bm{A}\in\Delta^{N\times N}}\;f(\bm{A};\bm{Q},\bm{K}):=\langle\bm{A},\bm{Q}\bm{K}^{\top}\rangle+H(\bm{A}), (6)

where $\Delta^{N\times N}$ denotes the set of $N\times N$ matrices whose rows lie on $\Delta^{N}$, and $H(\bm{A})=-\sum_{i,j}\bm{A}_{ij}\log\bm{A}_{ij}$. For a dense matrix $\bm{A}$, computing $f(\bm{A};\bm{Q},\bm{K})$ requires $\Theta(N^{2}d)$ operations, which is the same as computing $\sigma(\bm{Q}\bm{K}^{\top})$ directly. However, we are interested in the case where $\bm{A}$ is a Monarch matrix $\bm{M}=\bm{P}^{\top}\bm{B}$:

f(\bm{P}^{\top}\bm{B};\bm{Q},\bm{K})=\langle\bm{P}^{\top}\bm{B},\bm{Q}\bm{K}^{\top}\rangle+H(\bm{P}^{\top}\bm{B})=\langle\bm{B},\widetilde{\bm{Q}}\bm{K}^{\top}\rangle+H(\bm{B})=\sum_{j,k}f(\bm{B}_{jk};\widetilde{\bm{Q}}_{j},\bm{K}_{k}),

where $\widetilde{\bm{Q}}=\bm{P}\bm{Q}$, and $\widetilde{\bm{Q}}_{j}\in\mathbb{R}^{m\times d},\bm{K}_{k}\in\mathbb{R}^{b\times d}$ are the $j$th and $k$th blocks of rows of $\widetilde{\bm{Q}},\bm{K}$ respectively. Then, for each $j\in[b],k\in[m]$ we evaluate $f$ on the rank-one matrix $\bm{B}_{jk}=\bm{L}_{jk}\bm{R}_{kj}^{\top}$:

f(\bm{B}_{jk};\widetilde{\bm{Q}}_{j},\bm{K}_{k}) =\langle\bm{L}_{jk}\bm{R}_{kj}^{\top},\widetilde{\bm{Q}}_{j}\bm{K}_{k}^{\top}\rangle-\sum_{l,i}\bm{L}_{jkl}\bm{R}_{kji}\log(\bm{L}_{jkl}\bm{R}_{kji})
=\langle\widetilde{\bm{Q}}_{j}^{\top}\bm{L}_{jk},\bm{K}_{k}^{\top}\bm{R}_{kj}\rangle-\sum_{l,i}\bm{L}_{jkl}\bm{R}_{kji}\log\bm{L}_{jkl}-\sum_{l,i}\bm{L}_{jkl}\bm{R}_{kji}\log\bm{R}_{kji}
=\langle\widetilde{\bm{Q}}_{j}^{\top}\bm{L}_{jk},\bm{K}_{k}^{\top}\bm{R}_{kj}\rangle+\left(\bm{1}^{\top}\bm{R}_{kj}\right)\cdot H(\bm{L}_{jk})+\left(\bm{1}^{\top}\bm{L}_{jk}\right)\cdot H(\bm{R}_{kj}).

Thus, for each $j\in[b],k\in[m]$ we only need $\Theta((m+b)d)$ operations to compute $f(\bm{B}_{jk};\widetilde{\bm{Q}}_{j},\bm{K}_{k})$, due to the inner products $\widetilde{\bm{Q}}_{j}^{\top}\bm{L}_{jk}$ and $\bm{K}_{k}^{\top}\bm{R}_{kj}$. We emphasize that the rank-one structure implies separability of the entropy term, meaning we can compute the entropy on $\bm{L}_{jk}$ and $\bm{R}_{kj}$ individually and avoid the need to materialize $\bm{B}_{jk}$, which would incur $\Theta(mb)$ cost as opposed to $\Theta(m+b)$. Since there are $m\cdot b$ many $\bm{B}_{jk}$ matrices, we have in total $\Theta((m^{2}b+b^{2}m)d)$ operations to compute $f(\bm{M};\bm{Q},\bm{K})$, which for $m=b=\sqrt{N}$ is $\Theta(N\sqrt{N}d)$, improving on the dense computation by a factor of $\sqrt{N}$.
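The separable evaluation above can be checked numerically; the following is a minimal NumPy sketch (variable names are ours) that computes $f$ from the factors alone and compares it to the dense objective $\langle\bm{M},\bm{Q}\bm{K}^{\top}\rangle+H(\bm{M})$:

import numpy as np

rng = np.random.default_rng(0)
m = b = 4
N, d = m * b, 8

# Simplex-constrained factors: L[j, :, l] and R[k, j, :] lie on the unit simplex.
L = rng.dirichlet(np.ones(m), size=(b, m)).transpose(0, 2, 1)   # L[j, k, l]
R = rng.dirichlet(np.ones(b), size=(m, b))                      # R[k, j, i]
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))

Qb = Q.reshape(m, b, d).transpose(1, 0, 2)    # Qb[j, l, v] = row b*l + j of Q
Kb = K.reshape(m, b, d)                       # Kb[k, i, v] = row b*k + i of K

# Structured evaluation: Theta(N sqrt(N) d) work, never materializes M.
QL = np.einsum('jkl,jlv->jkv', L, Qb)         # Q~_j^T L_jk
KR = np.einsum('kji,kiv->kjv', R, Kb)         # K_k^T R_kj
inner = np.einsum('jkv,kjv->', QL, KR)
HL = -np.einsum('jkl,jkl->jk', L, np.log(L))  # entropy of each L_jk
HR = -np.einsum('kji,kji->kj', R, np.log(R))  # entropy of each R_kj
sumL = L.sum(axis=2)                          # 1^T L_jk, indexed [j, k]
sumR = R.sum(axis=2)                          # 1^T R_kj, indexed [k, j]
f_structured = inner + (sumR.T * HL).sum() + (sumL * HR.T).sum()

# Dense reference.
M = np.einsum('jkl,kji->ljki', L, R).reshape(N, N)
f_dense = np.sum(M * (Q @ K.T)) - np.sum(M * np.log(M))
assert np.isclose(f_structured, f_dense)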

Figure 2: Zero-shot conversion of attention layers for image classification and question answering. We vary hyperparameters for various baselines to evaluate model quality vs compute tradeoff. Left. Top-5 accuracy vs. total attention FLOPs across all layers for ViT on ImageNet. Right. F1 score vs total attention FLOPs across all layers for RoBERTa on SQuAD.

Alternating Maximization with Constraints.

We will now explain the alternating maximization approach for optimizing $f$. When $\bm{L}$ is fixed, the objective is concave in $\bm{R}$, and vice-versa – therefore, we can derive closed-form expressions via KKT conditions for $\bm{L}$ and $\bm{R}$ that maximize $f$ with one of $\bm{L}$ or $\bm{R}$ fixed, which will constitute a single update step. Evaluating (and therefore differentiating) $f$ w.r.t. $\bm{L}$ and $\bm{R}$ can be done in $\Theta(N\sqrt{N}d)$ time, which will be the same complexity as one of these steps. For $T$ steps, this will require $\Theta(TN\sqrt{N}d)$ computation; provided that $T=o(\sqrt{N})$, this will still be sub-quadratic. However, the constraint $\bm{M}\in\Delta^{N\times N}$ presents a challenge in its current form, since this requires materializing $\bm{M}$ to check that each entry is non-negative. Instead, we use the fact that

\bm{L}_{j,:,l}\in\Delta^{m},\;\bm{R}_{kj}\in\Delta^{b},\;\forall j\in[b],\forall k,l\in[m]\implies\bm{M}\in\Delta^{N\times N},

i.e., slices of $\bm{L},\bm{R}$ individually lying on the unit simplex is sufficient to enforce the constraint on $\bm{M}$. This is easily seen from (5) – obviously if $\bm{L}_{jkl},\bm{R}_{kji}\geq 0$, then $[\bm{M}_{lk}]_{ji}\geq 0$. Moreover, this also enforces the sum-to-one constraint, as rows of $\bm{M}$ sum as

\sum_{k,i}[\bm{M}_{lk}]_{ji}=\left(\sum_{k}\bm{L}_{jkl}\right)\left(\sum_{i}\bm{R}_{kji}\right)=1.

We now present the updates for $\bm{L},\bm{R}$. Initializing $\bm{L}_{jkl}^{(0)}=\delta_{kl}$ as block identity, we have

\bm{R}^{(t)} =\operatorname{softmax}_{i}(\bm{Z}_{R}^{(t)}),\quad\bm{Z}_{R,kji}^{(t)}=\bm{\beta}_{R,kji}^{(t)}/\bm{c}_{R,kj}^{(t)}, (7)
\bm{L}^{(t)} =\operatorname{softmax}_{k}(\bm{Z}_{L}^{(t)}),\quad\bm{Z}_{L,jkl}^{(t)}=\bm{\beta}_{L,jkl}^{(t)}-\bm{c}_{L,jk}^{(t)}, (8)

for $t\in[T]$, where $\operatorname{softmax}_{k},\operatorname{softmax}_{i}$ are applied along the $k$ and $i$ index dimensions respectively, and

\bm{\beta}_{R,kji}^{(t)} =\sum_{v}\bm{\alpha}_{R,kjv}^{(t)}\overline{\bm{K}}_{kiv},\quad\bm{\alpha}_{R,kjv}^{(t)}=\sum_{l}\bm{L}_{jkl}^{(t-1)}\overline{\bm{Q}}_{jlv},\quad\bm{c}_{R,kj}^{(t)}=\sum_{l}\bm{L}_{jkl}^{(t-1)}, (9)
\bm{\beta}_{L,jkl}^{(t)} =\sum_{v}\bm{\alpha}_{L,jkv}^{(t)}\overline{\bm{Q}}_{jlv},\quad\bm{\alpha}_{L,jkv}^{(t)}=\sum_{i}\bm{R}_{kji}^{(t)}\overline{\bm{K}}_{kiv},\quad\bm{c}_{L,jk}^{(t)}=\sum_{i}\bm{R}_{kji}^{(t)}\log\bm{R}_{kji}^{(t)}, (10)

where $\overline{\bm{Q}}_{jl},\overline{\bm{K}}_{ki}\in\mathbb{R}^{d}$ are the $(b\cdot(l-1)+j)$th and $(b\cdot(k-1)+i)$th rows of $\bm{Q}$ and $\bm{K}$ respectively. The full derivation is provided in Section C.1. After $T$ steps, we obtain the final Monarch approximation $\bm{M}^{(T)}\approx\sigma(\bm{Q}\bm{K}^{\top})$ with factors $\bm{L}^{(T)}$ and $\bm{R}^{(T)}$, from which we output $\bm{M}^{(T)}\bm{V}$ using (4). A naïve implementation of the full algorithm is provided in Section C.2. We discuss in Section C.3 how padding can be incorporated into MonarchAttention for when $N$ is not divisible by $b$.

Figure 3: Zero-shot conversion of attention layers for long sequence summarization. We vary the sequence length of the text to be summarized to evaluate model quality vs compute tradeoff. We report recall-based ROUGE-1 and ROUGE-L scores vs. total attention FLOPs across all layers for BART on BookSum-chapters.

Implementation.

To minimize data movement and memory usage on GPU, we do not materialize $\bm{L}$ or $\bm{R}$ in high-bandwidth memory (HBM). In addition to $\bm{Q},\bm{K},\bm{V},\bm{O}$, we only need to maintain the states $\bm{\alpha}_{R}^{(t)},\bm{\alpha}_{L}^{(t)},\bm{c}_{R}^{(t)},\bm{c}_{L}^{(t)}$ from (9) and (10), resulting in $\Theta(Nd)$ additional memory (the $\bm{\alpha}$ and $\bm{c}$ variables can share the same memory location, as those corresponding to $\bm{R}$ can be derived from $\bm{L}$ and vice-versa). All other intermediate values are only materialized in on-chip SRAM, fusing all operations between the $\bm{\alpha}$ (and similarly $\bm{c}$) variables. For instance, from the above update equations, the computation of $\bm{\alpha}_{L}^{(t)}$ from $\bm{\alpha}_{R}^{(t)}$ is given by

\bm{\alpha}_{L,jkv}^{(t)}=\operatorname{softmax}_{i}\left(\frac{\sum_{v}\bm{\alpha}_{R,kjv}^{(t)}\overline{\bm{K}}_{kiv}}{\bm{c}_{R,kj}^{(t)}}\right)\overline{\bm{K}}_{kiv},

which can be seen as a batched attention computation, meaning we can implement a FlashAttention-like kernel to reduce IO between HBM and on-chip SRAM. However, several aspects of this computation make it particularly IO-efficient. Besides the fact that $\overline{\bm{K}}$ acts as both the $\bm{K}$ and $\bm{V}$ matrices in (3), the effective sequence length is $\sqrt{N}$. This eliminates the need for tiling along the sequence length, except for very long sequences having $\Theta(\sqrt{N}d)>S$, where $S$ is the size of on-chip SRAM. This means that we have an optimal IO complexity of $\Theta(Nd)$ for a single call, as opposed to the worst-case $O(N^{2}d^{2}/S)$ complexity of FlashAttention. The computation of $\bm{\alpha}_{R}^{(t)}$ from $\bm{\alpha}_{L}^{(t)}$, as well as the Monarch matmul, can be written in a similar fashion. Based on this, MonarchAttention not only achieves significant speed-up over FlashAttention for longer sequences, but also for shorter ones. Python-like code for MonarchAttention is given in Figure 4.

import torch
from torch import bmm, log, ones, softmax

def al_cl_kernel(aR, cR, Kb):  # Computes aL, cL from aR, cR
    R = softmax(bmm(aR, Kb.transpose(1, 2)) / cR[:, :, None], dim=2)
    cL = torch.sum(R * log(R), dim=2).transpose(0, 1)
    aL = bmm(R, Kb).transpose(0, 1)
    return aL, cL

def ar_cr_kernel(aL, cL, Qb):  # Computes aR, cR from aL, cL
    L = softmax(bmm(aL, Qb.transpose(1, 2)) - cL[:, :, None], dim=1)
    cR = torch.sum(L, dim=2).transpose(0, 1)
    aR = bmm(L, Qb).transpose(0, 1)
    return aR, cR

def al_y_cl_kernel(aR, cR, Kb, Vb):  # Fuse al_cl_kernel + Monarch matmul 1st step
    R = softmax(bmm(aR, Kb.transpose(1, 2)) / cR[:, :, None], dim=2)
    cL = torch.sum(R * log(R), dim=2).transpose(0, 1)
    aL = bmm(R, Kb).transpose(0, 1)
    y = bmm(R, Vb).transpose(0, 1)
    return aL, y, cL

def z_kernel(aL, y, cL, Qb):  # Monarch matmul 2nd step
    L = softmax(bmm(Qb, aL.transpose(1, 2)) - cL[:, None, :], dim=2)
    z = bmm(L, y).transpose(0, 1)
    return z

def monarch_attention(Q, K, V, T):  # Q, K, V: (N, d), T: number of steps
    N, d = Q.shape
    b = int(round(N ** 0.5))  # block size; assumes N = m * b with b ~ sqrt(N)
    m = N // b
    Qb = Q.reshape(m, b, d).transpose(0, 1)
    Kb = K.reshape(m, b, d)
    Vb = V.reshape(m, b, d)
    aR = Q.reshape(m, b, d)
    cR = ones(m, b)

    for t in range(T - 1):
        aL, cL = al_cl_kernel(aR, cR, Kb)
        aR, cR = ar_cr_kernel(aL, cL, Qb)

    aL, y, cL = al_y_cl_kernel(aR, cR, Kb, Vb)
    z = z_kernel(aL, y, cL, Qb)
    o = z.reshape(N, d)
    return o
Figure 4: Python-like code for MonarchAttention. Each kernel materializes all intermediate arrays in SRAM to reduce data movement.

4 Experiments

(a) Softmax (b) MonarchAttention (c) Nyströmformer
Figure 5: Visual quality of generated images for zero-shot conversion of attention layers. Example images generated with softmax (left), MonarchAttention (middle), and Nyströmformer (right). Only the first half of the attention layers of DiT are replaced.

In this section, we evaluate the zero-shot performance (no additional training) of MonarchAttention for converting pre-trained/fine-tuned transformer attention layers to sub-quadratic attention in four different model/task settings. We compare with previous low-rank attention methods (Katharopoulos et al., 2020; Choromanski et al., 2021; Xiong et al., 2021; Qin et al., 2022); see Section D.1 for more details on the baselines. We specifically exclude low-rank methods with learnable components (Wang et al., 2020; Zhang et al., 2024), since we are focused on the zero-shot setting, as well as sparsity/LSH-based approaches (Kitaev et al., 2020; Daras et al., 2020; Chen et al., 2021; Han et al., 2024), since these do not admit efficient implementations on current GPUs. We also benchmark our fast implementation of MonarchAttention, comparing with FlashAttention-2 (Dao, 2023).

Image Classification with Vision Transformer.

We convert all 12 attention layers, each having 12 heads with sequence length $N=197$ and head dimension $d=64$, of the 87M parameter ViT-B (Dosovitskiy et al., 2021) that has been pre-trained on ImageNet-21K (Deng et al., 2009) and fine-tuned on ImageNet-1K (Russakovsky et al., 2015) for image classification. To evaluate the performance at different FLOP counts, we vary the number of steps $T$ for MonarchAttention, and vary the rank for Performer and Nyströmformer; see Section D.2 for more details on the set-up. The results are shown in the left panel of Figure 2. MonarchAttention achieves significant improvement over other baselines – compared to the original softmax attention, MonarchAttention loses only 5% accuracy to reduce attention FLOPs by 80%, or matches the performance to reduce attention FLOPs by 50%.

Question Answering with Encoder-Only Transformer.

We convert the initial 4 and final 4 layers of the 12 attention layers, each having 12 heads with sequence length $N=384$ and head dimension $d=64$, of the 125M parameter RoBERTa-B (Liu et al., 2019) that has been pre-trained on a large English corpus and fine-tuned on SQuAD1.1 (Rajpurkar et al., 2016) for question answering. As before, to evaluate the performance at different FLOP counts, we vary the block size $b$ for MonarchAttention, and vary the rank for Performer and Nyströmformer; see Section D.3 for more details on the set-up. The results are shown in the right panel of Figure 2. Once again, MonarchAttention achieves significant improvement over other baselines – compared to the original softmax attention, MonarchAttention loses only 10 points in F1 score to reduce attention FLOPs by 60%, or matches the performance to reduce attention FLOPs by 35%.

Summarization with Encoder-Decoder Transformer.

We convert all 6 attention layers, each having 12 heads with head dimension $d=64$, in the encoder of the 139M parameter BART-B (Lewis et al., 2020) that has been pre-trained on a large English corpus and fine-tuned on BookSum-chapters (Kryściński et al., 2022) for summarization. We only convert the encoder and leave the decoder intact. To evaluate the benefits of sub-quadratic attention for processing longer sequences, we truncate the text to be summarized to various sequence lengths $N$ for each method; see Section D.4 for more details on the set-up. The results are shown in Figure 3. We see that MonarchAttention achieves a strictly better ROUGE score (Lin, 2004) vs. FLOPs tradeoff than even softmax attention, due to accurate and efficient processing of longer sequences. In particular, the $N=8192$ MonarchAttention model improves on the $N=2048$ softmax attention model by $0.75$ on ROUGE-1 and $0.5$ on ROUGE-L with slightly fewer FLOPs, while the $N=8192$ Nyströmformer model with similar FLOPs does strictly worse than softmax.

Table 1: Quantitative results for zero-shot conversion of attention layers for image generation. We report FID and sFID (using the original softmax attention model as reference) of DiT when replacing all or half of the attention layers.

Layers Replaced  | Method           | Total Attention FLOPs ($10^9$) | FID ($\downarrow$) | sFID ($\downarrow$)
None (reference) | Softmax          | 8.46                           | –                  | –
All              | Nyströmformer    | 3.30                           | 5.97               | 13.47
All              | MonarchAttention | 3.44                           | 2.82               | 5.09
First Half       | Nyströmformer    | 5.88                           | 8.17               | 19.01
First Half       | MonarchAttention | 5.95                           | 0.39               | 0.66
Second Half      | Nyströmformer    | 5.88                           | 6.76               | 13.58
Second Half      | MonarchAttention | 5.95                           | 1.98               | 3.36

Image Generation with Diffusion Transformer.

We convert a subset of the 28 attention layers, each having 16 heads with sequence length $N=256$ and head dimension $d=72$, of the 675M parameter DiT-XL (Peebles and Xie, 2023) that has been trained on ImageNet (Deng et al., 2009). We consider replacing either all layers, the first 14 layers, or the last 14 layers; see Section D.5 for more details on the set-up. Examples of generated images with each method for replacing the first 14 layers are shown in Figure 5, where MonarchAttention produces clear images resembling those of softmax attention, while the Nyströmformer ones are extremely noisy. We also quantitatively evaluate the (s)FID scores (Heusel et al., 2017) of MonarchAttention compared with Nyströmformer, using images generated with the original softmax attention model as reference – the results are reported in Table 1. MonarchAttention noticeably outperforms Nyströmformer with similar FLOP count. In particular, using MonarchAttention in the first half of the DiT layers results in extremely small FID and sFID relative to the softmax attention model's images, while reducing FLOPs by nearly 30%.

Benchmarking MonarchAttention.

Finally, we validate that the computational/IO complexity reduction achieved by MonarchAttention translates into actual speed-ups on the NVIDIA A40, a modern GPU. We implement the pseudo-code described in Section 3 as four separate Triton kernels and compare it against the fastest available implementations of FlashAttention-2 – either the Triton implementation or PyTorch's scaled_dot_product_attention, which calls the CUDA backend for FlashAttention-2. Using a fixed batch size $E$, number of heads $H$, and head dimension $d$, we sweep the input sequence length $N$ and compare the run-times of FlashAttention-2 and MonarchAttention (with $b=\sqrt{N}$ and $T=1$) in Figure 6 (left). As sequence length increases, MonarchAttention consistently outperforms FlashAttention-2, notably achieving up to $8.2\times$ speed-up with $N=16384$. To highlight gains for shorter sequences, we implement MonarchAttention as a single fully-fused Triton kernel, with a single thread block computing a single head. For fixed sequence length $N=256$, number of heads $H$, and head dimension $d$, we sweep the batch size $E$ and compare the run-time of the fully-fused MonarchAttention kernel against FlashAttention-2 in Figure 6 (right). With smaller batch sizes, we have low utilization of hardware, since we compute a single head with a single thread block. However, as we increase the batch size, MonarchAttention achieves up to $1.4\times$ speed-up over FlashAttention-2.
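For readers who want to reproduce a rough comparison, the following PyTorch sketch shows a minimal timing harness (an illustration only: it assumes a CUDA device, uses PyTorch's scaled_dot_product_attention as the FlashAttention baseline, and the commented-out monarch_attention call is a hypothetical stand-in for an optimized kernel such as the one released with the paper):

import torch
import torch.nn.functional as F

def time_fn(fn, iters=50):
    # Warm up, then time with CUDA events (milliseconds per call).
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

E, H, N, d = 1, 12, 4096, 64
q, k, v = (torch.randn(E, H, N, d, device="cuda", dtype=torch.float16) for _ in range(3))

flash_ms = time_fn(lambda: F.scaled_dot_product_attention(q, k, v))
# monarch_ms = time_fn(lambda: monarch_attention(q, k, v, T=1))  # hypothetical optimized kernel
print(f"FlashAttention (SDPA backend): {flash_ms:.3f} ms")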

Figure 6: Run-times of MonarchAttention and FlashAttention-2 across various sequence lengths. Normalized runtime ($1=$ slowest, $0=$ fastest) of MonarchAttention and FlashAttention-2 on an NVIDIA A40 GPU. Left: sweep of sequence length $N$ with $E=1$, $H=12$, and $d=64$. Right: sweep of batch size $E$ with $N=256$, $H=12$, and $d=64$.

5 Conclusion

To conclude, we discuss several limitations and future directions for this work. First, the implementation of MonarchAttention can be improved, particularly on more recent GPUs with expanded capabilities such as the distributed shared memory found on Hopper architectures. Next, while we have presented MonarchAttention as a direct replacement for softmax attention with no additional training, in theory it can also accelerate training from scratch, or achieve better results by converting existing models with some fine-tuning. Moreover, MonarchAttention currently does not support causal masking; adding such support could accelerate the training of language models with next-token prediction. Finally, we believe that viewing fundamental operations such as softmax through their variational form is a powerful idea that can be generalized, allowing for more generic structured approximations beyond Monarch approximations to softmax attention.

Acknowledgement

LB and CY were supported in part by NSF CAREER award CCF-1845076 and an Intel Early Career award. LB was also supported by the University of Michigan Crosby award. CY and AX were supported by NSF CCF 312842. PA and CL were supported in part by COGNISENSE, one of seven centers in JUMP 2.0, a Semiconductor Research Corporation (SRC) program sponsored by DARPA. We thank Samet Oymak (University of Michigan) for discussion and use of computational resources provided by an Amazon Research Award on Foundation Model Development.

References

  • Blondel et al. (2019) Mathieu Blondel, Andre Martins, and Vlad Niculae. Learning classifiers with fenchel-young losses: Generalized entropies, margins, and algorithms. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 606–615. PMLR, 2019.
  • Brown et al. (2020) Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Chen et al. (2021) Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. Scatterbrain: Unifying sparse and low-rank attention. Advances in Neural Information Processing Systems, 34:17413–17426, 2021.
  • Chen et al. (2022) Beidi Chen, Tri Dao, Kaizhao Liang, Jiaming Yang, Zhao Song, Atri Rudra, and Christopher Re. Pixelated butterfly: Simple and efficient sparse training for neural network models. International Conference on Learning Representations, 2022.
  • Chi et al. (2019) Yuejie Chi, Yue M Lu, and Yuxin Chen. Nonconvex optimization meets low-rank matrix factorization: An overview. IEEE Transactions on Signal Processing, 67(20):5239–5269, 2019.
  • Chihara (2014) Theodore S Chihara. An Introduction to Orthogonal Polynomials. Courier Corporation, 2014.
  • Child et al. (2019) Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
  • Choromanski et al. (2021) Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. International Conference on Learning Representations, 2021.
  • Dao (2023) Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
  • Dao et al. (2019) Tri Dao, Albert Gu, Matthew Eichhorn, Atri Rudra, and Christopher Ré. Learning fast algorithms for linear transforms using butterfly factorizations. In International conference on machine learning, pages 1517–1527. PMLR, 2019.
  • Dao et al. (2022a) Tri Dao, Beidi Chen, Nimit S Sohoni, Arjun Desai, Michael Poli, Jessica Grogan, Alexander Liu, Aniruddh Rao, Atri Rudra, and Christopher Ré. Monarch: Expressive structured matrices for efficient and accurate training. In International Conference on Machine Learning, pages 4690–4721. PMLR, 2022a.
  • Dao et al. (2022b) Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in neural information processing systems, 35:16344–16359, 2022b.
  • Daras et al. (2020) Giannis Daras, Nikita Kitaev, Augustus Odena, and Alexandros G Dimakis. Smyrf-efficient attention using asymmetric clustering. Advances in Neural Information Processing Systems, 33:6476–6489, 2020.
  • Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. IEEE, 2009.
  • Dosovitskiy et al. (2021) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. International Conference on Learning Representations, 2021.
  • Fu et al. (2023) Dan Fu, Simran Arora, Jessica Grogan, Isys Johnson, Evan Sabri Eyuboglu, Armin Thomas, Benjamin Spector, Michael Poli, Atri Rudra, and Christopher Ré. Monarch mixer: A simple sub-quadratic gemm-based architecture. Advances in Neural Information Processing Systems, 36:77546–77603, 2023.
  • Han et al. (2024) Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. International Conference on Learning Representations, 2024.
  • Heusel et al. (2017) Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in neural information processing systems, 30, 2017.
  • Kailath et al. (1979) Thomas Kailath, Sun-Yuan Kung, and Martin Morf. Displacement ranks of matrices and linear equations. Journal of Mathematical Analysis and Applications, 68(2):395–407, 1979.
  • Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pages 5156–5165. PMLR, 2020.
  • Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kitaev et al. (2020) Nikita Kitaev, Lukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. International Conference on Learning Representations, 2020.
  • Kryściński et al. (2022) Wojciech Kryściński, Nazneen Rajani, Divyansh Agarwal, Caiming Xiong, and Dragomir Radev. Booksum: A collection of datasets for long-form narrative summarization. In Findings of the Association for Computational Linguistics: EMNLP 2022, pages 6536–6558, 2022.
  • Lewis et al. (2020) Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Veselin Stoyanov, and Luke Zettlemoyer. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 7871–7880, 2020.
  • Lhoest et al. (2021) Quentin Lhoest, Albert Villanova del Moral, Yacine Jernite, Abhishek Thakur, Patrick von Platen, Suraj Patil, Julien Chaumond, Mariama Drame, Julien Plu, Lewis Tunstall, et al. Datasets: A community library for natural language processing. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 175–184, 2021.
  • Lin (2004) Chin-Yew Lin. Rouge: A package for automatic evaluation of summaries. In Text summarization branches out, pages 74–81, 2004.
  • Liu et al. (2019) Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. Roberta: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019.
  • Pan (2001) Victor Pan. Structured matrices and polynomials: unified superfast algorithms. Springer Science & Business Media, 2001.
  • Peebles and Xie (2023) William Peebles and Saining Xie. Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF international conference on computer vision, pages 4195–4205, 2023.
  • Qin et al. (2022) Zhen Qin, Weixuan Sun, Hui Deng, Dongxu Li, Yunshen Wei, Baohong Lv, Junjie Yan, Lingpeng Kong, and Yiran Zhong. cosformer: Rethinking softmax in attention. International Conference on Learning Representations, 2022.
  • Radford et al. (2023) Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, and Ilya Sutskever. Robust speech recognition via large-scale weak supervision. In International conference on machine learning, pages 28492–28518. PMLR, 2023.
  • Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. Squad: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages 2383–2392, 2016.
  • Russakovsky et al. (2015) Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3):211–252, 2015. doi: 10.1007/s11263-015-0816-y.
  • Tolstikhin et al. (2021) Ilya O Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, et al. Mlp-mixer: An all-mlp architecture for vision. Advances in neural information processing systems, 34:24261–24272, 2021.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wang et al. (2020) Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  • Wolf et al. (2019) Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, et al. Huggingface’s transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • Xiong et al. (2021) Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, and Vikas Singh. Nyströmformer: A nyström-based algorithm for approximating self-attention. In Proceedings of the AAAI conference on artificial intelligence, volume 35, pages 14138–14148, 2021.
  • Zhang et al. (2024) Michael Zhang, Kush Bhatia, Hermann Kumbong, and Christopher Re. The hedgehog & the porcupine: Expressive linear attentions with softmax mimicry. International Conference on Learning Representations, 2024.

Appendix A Equivalence of Softmax Definitions

Consider the optimization problem in (2):

\max_{\bm{a}}\quad f(\bm{a}):=\sum_{i}\bm{a}_{i}\bm{z}_{i}-\sum_{i}\bm{a}_{i}\log\bm{a}_{i}\quad\mathrm{s.t.}\quad\bm{a}_{i}\geq 0,\quad\sum_{i}\bm{a}_{i}=1.

From the KKT stationarity condition, we have

\frac{\partial}{\partial\bm{a}_{i}}\left(f(\bm{a})+\lambda\left(1-\sum_{i}\bm{a}_{i}\right)+\sum_{i}\bm{\mu}_{i}\bm{a}_{i}\right)=0\implies\bm{z}_{i}-(1+\log\bm{a}_{i})-\lambda+\bm{\mu}_{i}=0,

where $\lambda\in\mathbb{R},\bm{\mu}\in\mathbb{R}^{N}$ are dual variables. From complementary slackness $\bm{a}_{i}\bm{\mu}_{i}=0$ and the fact that $\log\bm{a}_{i}$ is not defined for $\bm{a}_{i}=0$, we must have $\bm{\mu}_{i}=0$, which gives

\log\bm{a}_{i}=\bm{z}_{i}-\lambda-1\implies\bm{a}_{i}=\exp(\bm{z}_{i})/\exp(\lambda+1).

Finally, from the constraint $\sum_{i}\bm{a}_{i}=1$, we must have $\exp(\lambda+1)=\sum_{j}\exp(\bm{z}_{j})$, which gives the form of softmax in (1).

Appendix B Monarch Background

We provide an example of the transpose permutation $\bm{P}$ in Section 2. Recall that applying $\bm{P}\in\mathbb{R}^{N\times N}$ to a vector $\bm{x}\in\mathbb{R}^{N}$ corresponds to row-major reshaping $\bm{x}$ to $\mathbb{R}^{m\times b}$, transposing to $\mathbb{R}^{b\times m}$, then row-major flattening back to $\mathbb{R}^{N}$. This is equivalent to applying a permutation matrix whose $(i+1)$th row is given by $\bm{e}_{\sigma(i)+1}$, where

\sigma(i)=b\cdot(i\bmod m)+\left\lfloor\frac{i}{m}\right\rfloor,\quad i\in\{0,\dots,N-1\}.

As an illustrative example, let $N=6$, $b=3$, and $m=2$. The action of $\bm{P}$ is given by the following steps:

\begin{bmatrix}1\\ 2\\ 3\\ 4\\ 5\\ 6\end{bmatrix}\xrightarrow{\mathrm{reshape}\;2\times 3}\begin{bmatrix}1&2&3\\ 4&5&6\end{bmatrix}\xrightarrow{\mathrm{transpose}}\begin{bmatrix}1&4\\ 2&5\\ 3&6\end{bmatrix}\xrightarrow{\mathrm{flatten}}\begin{bmatrix}1\\ 4\\ 2\\ 5\\ 3\\ 6\end{bmatrix}.

In matrix form, we have

\bm{P}=\begin{bmatrix}1&0&0&0&0&0\\ 0&0&0&1&0&0\\ 0&1&0&0&0&0\\ 0&0&0&0&1&0\\ 0&0&1&0&0&0\\ 0&0&0&0&0&1\end{bmatrix}.
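A small NumPy check of this example (an illustration only; the array names are ours) confirms that the reshape-transpose-flatten description and the permutation $\sigma$ agree:

import numpy as np

N, b, m = 6, 3, 2
x = np.arange(1, N + 1)

# Reshape-transpose-flatten action of P.
Px_reshape = x.reshape(m, b).T.reshape(-1)

# Permutation-matrix action: the (i+1)th row of P is e_{sigma(i)+1}.
sigma = np.array([b * (i % m) + i // m for i in range(N)])
P = np.eye(N)[sigma]

assert np.array_equal(P @ x, Px_reshape)   # both give [1, 4, 2, 5, 3, 6]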

Appendix C Details for MonarchAttention

C.1 Updates

Derivatives.

We evaluate $f$ with $\bm{A}=\bm{M}$ as a Monarch matrix. Using (5), we have

f(\bm{M};\bm{Q},\bm{K}) =\sum\bm{L}_{jkl}\bm{R}_{kji}\overline{\bm{Q}}_{jlv}\overline{\bm{K}}_{kiv}-\sum\bm{L}_{jkl}\bm{R}_{kji}\log(\bm{L}_{jkl}\bm{R}_{kji})
=\sum\bm{L}_{jkl}\bm{R}_{kji}\overline{\bm{Q}}_{jlv}\overline{\bm{K}}_{kiv}-\sum\bm{L}_{jkl}\bm{R}_{kji}\log\bm{R}_{kji}-\sum\bm{R}_{kji}\bm{L}_{jkl}\log\bm{L}_{jkl}.

The derivatives of $f$ w.r.t. each factor are given by

\frac{\partial f(\bm{M};\bm{Q},\bm{K})}{\partial\bm{L}_{jkl}} =\bm{\beta}_{L,jkl}-\bm{c}_{L,jk}-(1+\log\bm{L}_{jkl})\bm{\gamma}_{L,jk}, (11)
\frac{\partial f(\bm{M};\bm{Q},\bm{K})}{\partial\bm{R}_{kji}} =\bm{\beta}_{R,kji}-\bm{\gamma}_{R,kj}-(1+\log\bm{R}_{kji})\bm{c}_{R,kj}, (12)

where

\bm{\beta}_{L,jkl} =\sum_{v}\overline{\bm{Q}}_{jlv}\sum_{i}(\bm{R}_{kji}\overline{\bm{K}}_{kiv}),\quad\bm{c}_{L,jk}=\sum_{i}\bm{R}_{kji}\log\bm{R}_{kji},\quad\bm{\gamma}_{L,jk}=\sum_{i}\bm{R}_{kji},
\bm{\beta}_{R,kji} =\sum_{v}\overline{\bm{K}}_{kiv}\sum_{l}(\bm{L}_{jkl}\overline{\bm{Q}}_{jlv}),\quad\bm{\gamma}_{R,kj}=\sum_{l}\bm{L}_{jkl}\log\bm{L}_{jkl},\quad\bm{c}_{R,kj}=\sum_{l}\bm{L}_{jkl}.

We derive updates for each factor based on maximizing $f$ with the other factor fixed.

$\bm{L}$ update.

First, we fix $\bm{R}\in\Delta^{m\times b\times b}$ and consider

\max_{\bm{L}}\quad f(\bm{M};\bm{Q},\bm{K})\quad\mathrm{s.t.}\quad\bm{L}_{jkl}\geq 0,\quad\sum_{k}\bm{L}_{jkl}=1.

From the KKT stationarity condition, we have

\frac{\partial}{\partial\bm{L}_{jkl}}\left(f(\bm{M};\bm{Q},\bm{K})+\sum\bm{\lambda}_{L,jl}\left(1-\sum_{k}\bm{L}_{jkl}\right)+\sum\bm{\mu}_{L,jkl}\bm{L}_{jkl}\right)=0,

where $\bm{\lambda}_{L}\in\mathbb{R}^{b\times m},\bm{\mu}_{L}\in\mathbb{R}^{b\times m\times m}$ are dual variables. Along with (11), we have

\bm{\beta}_{L,jkl}-\bm{c}_{L,jk}-(1+\log\bm{L}_{jkl})\bm{\gamma}_{L,jk}-\bm{\lambda}_{L,jl}+\bm{\mu}_{L,jkl}=0.

Now, from complementary slackness $\bm{\mu}_{L,jkl}\bm{L}_{jkl}=0$ and the fact that $\log\bm{L}_{jkl}$ is not defined for $\bm{L}_{jkl}=0$, we must have $\bm{\mu}_{L,jkl}=0$. Moreover, since $\bm{R}\in\Delta^{m\times b\times b}$, we have $\bm{\gamma}_{L,jk}=1$. Altogether, we have

\log\bm{L}_{jkl}=\bm{\beta}_{L,jkl}-\bm{c}_{L,jk}-\bm{\lambda}_{L,jl}-1\implies\bm{L}_{jkl}=\frac{\exp(\bm{\beta}_{L,jkl}-\bm{c}_{L,jk})}{\exp(\bm{\lambda}_{L,jl}+1)}.

Finally, from the constraint $\sum_{k}\bm{L}_{jkl}=1$, we must have $\exp(\bm{\lambda}_{L,jl}+1)=\sum_{k}\exp(\bm{\beta}_{L,jkl}-\bm{c}_{L,jk})$, which gives the final closed form update:

\bm{L}=\operatorname{softmax}_{k}(\bm{Z}_{L}),\quad\bm{Z}_{L,jkl}=\bm{\beta}_{L,jkl}-\bm{c}_{L,jk},

where $\operatorname{softmax}_{k}$ is applied along the $k$ index dimension.

$\bm{R}$ update.

Similarly, we fix $\bm{L}\in\Delta^{b\times m\times m}$ and consider

\max_{\bm{R}}\quad f(\bm{M};\bm{Q},\bm{K})\quad\mathrm{s.t.}\quad\bm{R}_{kji}\geq 0,\quad\sum_{i}\bm{R}_{kji}=1.

From the KKT stationarity condition, we have

\frac{\partial}{\partial\bm{R}_{kji}}\left(f(\bm{M};\bm{Q},\bm{K})+\sum\bm{\lambda}_{R,kj}\left(1-\sum_{i}\bm{R}_{kji}\right)+\sum\bm{\mu}_{R,kji}\bm{R}_{kji}\right)=0,

where $\bm{\lambda}_{R}\in\mathbb{R}^{m\times b},\bm{\mu}_{R}\in\mathbb{R}^{m\times b\times b}$ are dual variables. Along with (12), we have

\bm{\beta}_{R,kji}-\bm{\gamma}_{R,kj}-(1+\log\bm{R}_{kji})\bm{c}_{R,kj}-\bm{\lambda}_{R,kj}+\bm{\mu}_{R,kji}=0.

As before, from complementary slackness $\bm{\mu}_{R,kji}\bm{R}_{kji}=0$ and the fact that $\log\bm{R}_{kji}$ is not defined for $\bm{R}_{kji}=0$, we must have $\bm{\mu}_{R,kji}=0$. Altogether, we have

\log\bm{R}_{kji}=\frac{\bm{\beta}_{R,kji}-\bm{\gamma}_{R,kj}-\bm{\lambda}_{R,kj}}{\bm{c}_{R,kj}}-1\implies\bm{R}_{kji}=\frac{\exp(\bm{\beta}_{R,kji}/\bm{c}_{R,kj})}{\exp(\bm{\gamma}_{R,kj}/\bm{c}_{R,kj}+\bm{\lambda}_{R,kj}/\bm{c}_{R,kj}+1)}.

Finally, from the constraint $\sum_{i}\bm{R}_{kji}=1$, we must have $\exp(\bm{\gamma}_{R,kj}/\bm{c}_{R,kj}+\bm{\lambda}_{R,kj}/\bm{c}_{R,kj}+1)=\sum_{i}\exp(\bm{\beta}_{R,kji}/\bm{c}_{R,kj})$, which gives the final closed form update:

\bm{R}=\operatorname{softmax}_{i}(\bm{Z}_{R}),\quad\bm{Z}_{R,kji}=\bm{\beta}_{R,kji}/\bm{c}_{R,kj},

where $\operatorname{softmax}_{i}$ is applied along the $i$ index dimension.

C.2 Naïve Algorithm

We provide pseudo-code for the naive version of MonarchAttention in Figure 7. This is a direct implementation of alternating maximization for finding the Monarch factors, ignoring concerns about memory and IO. We alternate between updating 𝑳\bm{L} and 𝑹\bm{R} as described above for TT iterations. Specifically, we highlight the choices to (1) initialize 𝑳\bm{L} to be block identity, and (2) update 𝑹\bm{R} before 𝑳\bm{L} in each iteration.

# Q: array of size (N, d)
# K: array of size (N, d)
# V: array of size (N, d)
# T: number of steps

import torch

def monarch_attention(Q, K, V, T):
    N, d = Q.shape
    b = int(round(N ** 0.5))  # block size; assumes N = m * b with b ~ sqrt(N)
    m = N // b
    L = torch.stack(b * [torch.eye(m)])          # L[j, k, l], block-identity initialization
    Qb = Q.reshape(m, b, d).permute(1, 0, 2)     # Qb[j, l, v] = Q[b*l + j, v]
    Kb = K.reshape(m, b, d)                      # Kb[k, i, v] = K[b*k + i, v]

    # Alternating maximization for L, R
    for t in range(T):
        # R update
        aR = torch.einsum("jkl,jlv->kjv", L, Qb)
        bR = torch.einsum("kjv,kiv->kji", aR, Kb)
        cR = torch.einsum("jkl->kj", L)
        R = torch.softmax(bR / cR[:, :, None], dim=2)

        # L update
        aL = torch.einsum("kji,kiv->jkv", R, Kb)
        bL = torch.einsum("jkv,jlv->jkl", aL, Qb)
        cL = torch.einsum("kji->jk", R * torch.log(R))
        L = torch.softmax(bL - cL[:, :, None], dim=1)

    # Monarch multiply
    Vb = V.reshape(m, b, d)                      # Vb[k, i, v]
    Y = torch.einsum("kji,kiv->jkv", R, Vb)
    Z = torch.einsum("jkl,jkv->ljv", L, Y)
    O = Z.reshape(N, d)

    return O
Figure 7: Naïve implementation of MonarchAttention

C.3 Padding and Masking

In practice, the sequence length NN may not be divisible by the desired block size bb. In such cases, we round the number of blocks mm to m=N/bm^{\prime}=\lceil N/b\rceil, and set the new sequence length N=mbN^{\prime}=m^{\prime}b, post-padding 𝑸,𝑲,𝑽\bm{Q},\bm{K},\bm{V} to have NN^{\prime} rows. However, we need to take special care that the final NNN^{\prime}-N columns of the padded Monarch attention matrix 𝑴ΔN×N\bm{M}\in\Delta^{N^{\prime}\times N^{\prime}} are zero, since these correspond to padded rows of 𝑽\bm{V}. This is also an issue when batched sequences of different lengths are padded to a maximum length to avoid dynamic resizing.
 
From (5), it is clear that to set all columns of 𝑴\bm{M} beyond the NNth column to zero, it is sufficient to set 𝑹kji\bm{R}_{kji} to zero whenever b(k1)+i>Nb(k-1)+i>N. Thus, we simply form the mask 𝝎m×b×b\bm{\omega}\in\mathbb{R}^{m^{\prime}\times b\times b} given by

𝝎kji={0b(k1)+iNotherwise,\bm{\omega}_{kji}=\begin{cases}0&b(k-1)+i\leq N\\ -\infty&\mbox{otherwise},\end{cases}

which we then add to \bm{Z}_{R} before the softmax in (7). We can also pre-pad the sequence, which changes the above condition to b(k-1)+i>N^{\prime}-N.
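As an illustration, the mask can be formed and applied as in the following minimal NumPy sketch (the function name padding_mask is ours); indices are 0-based here, so the column index is kb+i and the validity condition becomes kb+i<N.

import numpy as np

def padding_mask(N, b):
    # Post-padding mask omega of shape (m', b, b), indexed [k, j, i]:
    # 0 for valid columns (k*b + i < N), -inf for padded columns.
    m_pad = -(-N // b)  # ceil(N / b)
    col = np.arange(m_pad * b).reshape(m_pad, 1, b)  # column index k*b + i
    mask = np.where(col < N, 0.0, -np.inf)
    return np.broadcast_to(mask, (m_pad, b, b))

# Usage with the names from Figure 7 (Q, K, V padded to N' = m' * b rows):
# R = softmax(bR / cR[:, :, None] + padding_mask(N, b), axis=2)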

Appendix D Experimental Details

D.1 Baselines

We describe the baselines used in Section 4.

  • linear-attention (Katharopoulos et al., 2020) approximates exp(𝒒𝒌)ϕ(𝒒)ϕ(𝒌)\exp(\bm{q}^{\top}\bm{k})\approx\phi(\bm{q})^{\top}\phi(\bm{k}) where ϕ:dr\phi:\mathbb{R}^{d}\rightarrow\mathbb{R}^{r} is a kernel feature map, resulting in a rank rr approximation to softmax attention:

    softmax(𝑸𝑲)ϕ(𝑸)ϕ(𝑲)ϕ(𝑸)ϕ(𝑲)𝟏N𝟏N.\operatorname{softmax}(\bm{Q}\bm{K}^{\top})\approx\frac{\phi(\bm{Q})\phi(\bm{K})^{\top}}{\phi(\bm{Q})\phi(\bm{K})^{\top}\bm{1}_{N}\bm{1}_{N}^{\top}}.

    Katharopoulos et al. (2020) propose the map \phi(\bm{x})=1+\operatorname{elu}(\bm{x}) with r=d, where \operatorname{elu} is the exponential linear unit applied element-wise; a minimal sketch of this construction is given after this list.

  • performer (Choromanski et al., 2021) is a linear attention method using the fact that

    exp(𝒒𝒌)=𝔼𝝎𝒩(𝟎,𝑰d)[exp(𝝎𝒒𝒒22)exp(𝝎𝒌𝒌22)]\exp(\bm{q}^{\top}\bm{k})=\mathbb{E}_{\bm{\omega}\sim\mathcal{N}(\bm{0},\bm{I}_{d})}\left[\exp\left(\bm{\omega}^{\top}\bm{q}-\frac{\|\bm{q}\|^{2}}{2}\right)\exp\left(\bm{\omega}^{\top}\bm{k}-\frac{\|\bm{k}\|^{2}}{2}\right)\right]

    to construct a random kernel feature map

    ϕ(𝒙)=1rexp(𝒙22)[exp(𝝎1𝒙)exp(𝝎r𝒙)],\phi(\bm{x})=\frac{1}{\sqrt{r}}\exp\left(-\frac{\|\bm{x}\|^{2}}{2}\right)\begin{bmatrix}\exp(\bm{\omega}_{1}^{\top}\bm{x})&\dots&\exp(\bm{\omega}_{r}^{\top}\bm{x})\end{bmatrix}^{\top},

    where 𝝎1,,𝝎riid𝒩(𝟎,𝑰d)\bm{\omega}_{1},\dots,\bm{\omega}_{r}\overset{iid}{\sim}\mathcal{N}(\bm{0},\bm{I}_{d}).

  • cosformer (Qin et al., 2022) is a linear attention method utilizing position-dependent kernel feature maps of the form

    ϕi(𝒙)=[sin(πi2N)relu(𝒙i)cos(πi2N)relu(𝒙i)],i[N],\displaystyle\phi_{i}(\bm{x})=\begin{bmatrix}\sin\left(\frac{\pi i}{2N}\right)\operatorname{relu}(\bm{x}_{i})&\cos\left(\frac{\pi i}{2N}\right)\operatorname{relu}(\bm{x}_{i})\end{bmatrix},\;\forall i\in[N],

    which produces a rank r=2dr=2d approximation.

  • nystromformer (Xiong et al., 2021) computes landmark 𝑸~,𝑲~r×d\tilde{\bm{Q}},\tilde{\bm{K}}\in\mathbb{R}^{r\times d} from 𝑸,𝑲\bm{Q},\bm{K} by averaging N/rN/r consecutive spans of rows, which are used to approximate softmax attention via the quadrature method:

    𝑭~=softmax(𝑸𝑲~),𝑩~=softmax(𝑸~𝑲),𝑨~=softmax(𝑸~𝑲~)\displaystyle\tilde{\bm{F}}=\operatorname{softmax}(\bm{Q}\tilde{\bm{K}}^{\top}),\;\tilde{\bm{B}}=\operatorname{softmax}(\tilde{\bm{Q}}\bm{K}^{\top}),\;\tilde{\bm{A}}=\operatorname{softmax}(\tilde{\bm{Q}}\tilde{\bm{K}}^{\top})
    softmax(𝑸𝑲)𝑭~𝑨~+𝑩~,\displaystyle\operatorname{softmax}(\bm{Q}\bm{K}^{\top})\approx\tilde{\bm{F}}\tilde{\bm{A}}^{+}\tilde{\bm{B}},

    where 𝑨~+\tilde{\bm{A}}^{+} denotes the pseudoinverse of 𝑨~\tilde{\bm{A}}, producing a rank rr approximation.
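For reference, the following minimal NumPy sketch (our own illustration, not the baseline implementation used in the experiments) spells out linear attention with the feature map phi(x) = 1 + elu(x) from the first bullet above.

import numpy as np

def elu(x):
    # exponential linear unit, applied element-wise (clamped to avoid overflow in exp)
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)

def linear_attention(Q, K, V):
    # phi(x) = 1 + elu(x) > 0, so all attention weights are nonnegative
    phi_Q, phi_K = 1.0 + elu(Q), 1.0 + elu(K)  # (N, d) each
    KV = phi_K.T @ V                           # (d, d), shared across all queries
    norm = phi_Q @ phi_K.sum(axis=0)           # (N,) row sums of phi(Q) phi(K)^T
    return (phi_Q @ KV) / norm[:, None]        # (N, d), O(N d^2) time overall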

D.2 Image Classification with ViT

The ViT-B model pre-trained on ImageNet-21K and fine-tuned on ImageNet-1K is retrieved from the Hugging Face transformers library (Wolf et al., 2019) as google/vit-base-patch16-224. The ImageNet-1K evaluation dataset is retrieved from the Hugging Face datasets library (Lhoest et al., 2021) as imagenet-1k using the validation split. A minimal retrieval sketch is given after the list below. We vary the following hyperparameters:

  • monarch-attention: b=14b=14 and T{1,2,3}T\in\{1,2,3\}

  • performer: r{16,32,48,64,80,96}r\in\{16,32,48,64,80,96\}

  • nystromformer: r{16,24,32,40}r\in\{16,24,32,40\}
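As referenced above, the checkpoint and evaluation split can be loaded as in the following minimal sketch, which assumes the standard Hugging Face APIs; swapping the attention modules for MonarchAttention and the evaluation loop itself are omitted.

from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset

# Fine-tuned ViT-B checkpoint and its preprocessing pipeline
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# ImageNet-1K validation split used for evaluation (a gated dataset on the Hub)
dataset = load_dataset("imagenet-1k", split="validation")

The checkpoints and evaluation datasets in the following subsections are retrieved analogously.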

D.3 Question Answering with RoBERTa

The RoBERTa-B model fine-tuned on SQuAD1.1 is retrieved from the Hugging Face transformers library as csarron/roberta-base-squad-v1. The SQuAD1.1 evaluation dataset is retrieved from the Hugging Face datasets library as squad using the validation split. For evaluation, we truncate and pad to a sequence length of 384. We vary the following hyperparameters:

  • monarch-attention: T=1T=1 and b{24,48,96,128}b\in\{24,48,96,128\}

  • performer: r{32,64,96,128,160,192}r\in\{32,64,96,128,160,192\}

  • nystromformer: r{16,32,48,64}r\in\{16,32,48,64\}

D.4 Summarization with BART

The pre-trained BART-B model is retrieved from the Hugging Face transformers library as facebook/bart-base. The BookSum-chapters training/evaluation dataset is retrieved from the Hugging Face datasets library as kmfoda/booksum using the train and validation splits, respectively. BART employs learned positional embeddings for sequence lengths up to 1024; since we are interested in long-sequence summarization with inputs of up to 8192 tokens, we linearly interpolate the encoder positional embeddings to 8192 positions before fine-tuning on BookSum-chapters, leaving the decoder positional embeddings intact (a brief sketch of this interpolation is given after Table 2). We fine-tune for 5 epochs with a batch size of 32 and a learning rate of 10^{-4} using the Adam optimizer (Kingma and Ba, 2014) without weight decay, with the input and summary sequences truncated and padded to 8192 and 512 tokens, respectively. For evaluation, we truncate the input sequence to the corresponding sequence length in Figure 3. The hyperparameters for each method across sequence lengths are shown in Table 2.

Table 2: Hyperparameters used for BART summarization.
Sequence Length | Method           | (b, T)  | r   | Total Attention FLOPs (10^9)
1024            | Softmax          | -       | -   | 9.66
                | Nyströmformer    | -       | 64  | 1.93
                | MonarchAttention | (32, 3) | -   | 1.96
2048            | Softmax          | -       | -   | 38.7
                | Nyströmformer    | -       | 80  | 4.41
                | MonarchAttention | (32, 2) | -   | 3.93
4096            | Softmax          | -       | -   | 155.
                | Nyströmformer    | -       | 112 | 10.6
                | MonarchAttention | (64, 2) | -   | 10.9
8192            | Softmax          | -       | -   | 619.
                | Nyströmformer    | -       | 160 | 35.0
                | MonarchAttention | (64, 2) | -   | 31.4
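The encoder positional-embedding interpolation described above can be sketched as follows; this is a minimal PyTorch illustration under our own assumptions (the function name and the commented attribute path are ours, and BART's learned position-offset convention is not handled), not the exact fine-tuning code.

import torch
import torch.nn.functional as F

def interpolate_positional_embeddings(weight, new_len):
    # weight: (old_len, d) learned positional-embedding table; returns (new_len, d)
    w = weight.detach().T.unsqueeze(0)  # (1, d, old_len)
    w = F.interpolate(w, size=new_len, mode="linear", align_corners=True)
    return w.squeeze(0).T               # (new_len, d)

# Usage sketch: stretch the encoder table from 1024 to 8192 positions before fine-tuning.
# old = model.model.encoder.embed_positions.weight  # assumed attribute path
# new = interpolate_positional_embeddings(old, 8192)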

D.5 Image Generation with DiT

The pre-trained DiT-XL model is retrieved from the Hugging Face diffusers library as facebook/DiT-XL-2-256. Following Peebles and Xie (2023), we generate images using 32 sampling steps, a 2×2 patch size, and a classifier-free guidance scale of 1.5. We use the following hyperparameters:

  • monarch-attention: b=16b=16 and T=3T=3

  • nystromformer: r=32r=32

To create the images in Figure 5, we used a random seed of 0 and input the same 36 random Gaussian samples into all three models. To obtain the results in Table 1, we again used a random seed of 0 and generated 50K images from each type of model, using the same 50K random samples across all models.