Attention (as Discrete-Time Markov) Chains

*Denotes equal contribution
¹Tel Aviv University  ²MPI for Informatics

Abstract

We introduce a new interpretation of the attention matrix as a discrete-time Markov chain. Our interpretation treats common operations on attention scores, such as selection, summation, and averaging, within a unified framework. It further extends them by considering indirect attention, propagated through the Markov chain, whereas previous studies only model immediate effects. Our main observation is that tokens corresponding to semantically similar regions form a set of metastable states, where attention clusters, while noisy attention scores tend to disperse. Metastable states and their prevalence can be computed with simple matrix multiplication and eigenanalysis, respectively. Using these lightweight tools, we demonstrate state-of-the-art zero-shot segmentation. Lastly, we define TokenRank, the steady-state vector of the Markov chain, which measures global token importance. We demonstrate that using it brings improvements in unconditional image generation. We believe our framework offers a fresh view of how tokens are attended to in modern visual transformers.
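For concreteness, the core construction can be stated in a few lines (the notation below is ours, chosen for illustration): a row-softmaxed attention matrix is row-stochastic, so it can be read directly as the transition matrix of a discrete-time Markov chain (DTMC) over the tokens.

```latex
% A softmaxed attention matrix A over N tokens is row-stochastic,
A_{ij} \ge 0, \qquad \sum_{j=1}^{N} A_{ij} = 1,
% hence it defines a discrete-time Markov chain over token states,
\Pr(X_{n+1} = j \mid X_n = i) = A_{ij}.
% Indirect attention after n bounces is the matrix power A^n, and
% TokenRank is the steady state \pi of the chain:
\pi = \pi A, \qquad \sum_{i=1}^{N} \pi_i = 1.
```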

Toy Example

[Figure: toy example of the Markov-chain intuition; caption below]

Left: Attention matrix A with sequence length 5. Middle: A DTMC with transition probabilities defined by matrix A; only strong connections are shown. Right (One-Hot): To evaluate where state-4 attends, we can apply one step of the power method starting from a one-hot vector (n=0), which reproduces the row-select operation (n=1). However, this first-order approximation is insufficient: state-0 mostly transitions to state-3, so state-4 indirectly attends to state-3. This becomes evident as we iterate further (n=2). Right (Uniform): To compute a global token ranking, we can iterate starting from a uniform state (n=0), where the first step reproduces the per-column sum operation (n=1). This ranks state-0 as most important, because many states have a high probability of transitioning into state-0. However, state-0 maps to state-3 with high probability, and state-3 maps to state-4 with high probability, so the importance of state-4 should be elevated. At the second bounce (n=2), more probability mass is directed into state-3, and with a sufficient number of iterations the steady state ranks state-4 as the most important state globally, in line with the intuition above.
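The walkthrough above is easy to reproduce numerically. Below is a minimal NumPy sketch; the figure's exact values are not given here, so the matrix is a hypothetical one constructed to show the same behavior (state-4 mostly attends to state-0, state-0 to state-3, and state-3 to state-4):

```python
import numpy as np

# Hypothetical row-stochastic attention matrix A (sequence length 5),
# constructed to reproduce the figure's story; not the figure's actual values.
A = np.array([
    [0.10, 0.025, 0.025, 0.80, 0.05],   # state-0 mostly -> state-3
    [0.60, 0.20,  0.10,  0.05, 0.05],   # state-1 mostly -> state-0
    [0.60, 0.10,  0.20,  0.05, 0.05],   # state-2 mostly -> state-0
    [0.05, 0.025, 0.025, 0.10, 0.80],   # state-3 mostly -> state-4
    [0.55, 0.05,  0.05,  0.05, 0.30],   # state-4 mostly -> state-0
])
assert np.allclose(A.sum(axis=1), 1.0)  # each row is a transition distribution

# One-hot start: where does state-4 attend, directly and indirectly?
x = np.eye(5)[4]                # n = 0
print("n=1:", x @ A)            # row-select: direct attention, peaks at state-0
print("n=2:", x @ A @ A)        # second bounce: indirect attention, peaks at state-3

# Uniform start: global token ranking.
u = np.full(5, 1 / 5)           # n = 0
print("n=1:", u @ A)            # per-column sums (times 1/5): state-0 looks most important
for _ in range(200):            # power iteration toward the steady state
    u = u @ A
print("steady:", u)             # the steady state ranks state-4 highest
```

Running this shows the same reversal as in the figure: the column-sum step (n=1) favors state-0, while the steady state favors state-4.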


Applications

To demonstrate the usefulness of our framework, we show improvements on downstream tasks such as zero-shot segmentation and unconditional image generation.

[Figure: zero-shot segmentation results]

TokenRank

TokenRank is the unique steady-state vector of the Markov chain induced by the attention matrix. It measures global token importance, summarizing both incoming and outgoing attention, and can serve as a standard tool for visualizing self-attention.
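As a reference, here is a minimal NumPy sketch of computing such a steady state by power iteration. The uniform damping term (in the style of PageRank) is our assumption, added to guarantee a unique steady state for arbitrary attention matrices; the paper's exact procedure may differ:

```python
import numpy as np

def token_rank(A: np.ndarray, alpha: float = 0.95, tol: float = 1e-10) -> np.ndarray:
    """Steady-state vector of the DTMC whose transition matrix is the
    row-stochastic attention matrix A.

    The (1 - alpha) uniform-jump term is an assumption (PageRank-style
    damping) that makes the chain ergodic, so a unique steady state exists.
    """
    n = A.shape[0]
    M = alpha * A + (1.0 - alpha) / n   # ergodic mix of A and uniform jumps
    pi = np.full(n, 1.0 / n)            # start from the uniform distribution
    while True:
        nxt = pi @ M                    # one left multiplication = one bounce
        if np.abs(nxt - pi).sum() < tol:
            return nxt
        pi = nxt

# Usage on a hypothetical row-softmaxed attention matrix:
rng = np.random.default_rng(0)
logits = rng.normal(size=(16, 16))
A = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(token_rank(A).argsort()[::-1])    # token indices, most important first
```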

[Figure: TokenRank visualization]

BibTeX