# Deep Implicit Attention

Experimental implementation of deep implicit attention in PyTorch.

**Summary:** Using deep equilibrium models to implicitly solve a set of self-consistent mean-field equations of a random Ising model implements attention as a collective response 🤗. This perspective provides insight into the transformer architecture by connecting it to mean-field theory, message-passing algorithms, and Boltzmann machines.

**Blog post:** [Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms](https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/)

## Mean-field theory framework for transformer architectures

Transformer architectures can be understood as particular approximations of a parametrized mean-field description of a vector Ising model being probed by incoming data x_i:

z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network acting on every vector z_i and the z_i are solved for iteratively.
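As a minimal illustration (not this package's actual API; the function and variable names below are made up for the sketch), the fixed-point equation above can be solved by naive forward iteration with a coupling matrix J and a small feed-forward network f:

```python
import torch
import torch.nn as nn

# Minimal sketch (illustrative, not this repo's API): iterate
#   z_i = sum_j J_ij z_j - f(z_i) + x_i
# to a fixed point for a batch of N d-dimensional vector spins.
def solve_fixed_point(J, f, x, num_iter=50, tol=1e-4):
    # J: (N, N) couplings, f: module acting on the feature dim, x: (batch, N, d)
    z = torch.zeros_like(x)
    for _ in range(num_iter):
        z_next = torch.einsum('ij,bjd->bid', J, z) - f(z) + x
        if (z_next - z).norm() < tol:
            return z_next
        z = z_next
    return z

N, d = 16, 32
J = 0.01 * torch.randn(N, N)                         # couplings between sites
f = nn.Sequential(nn.Linear(d, d), nn.Tanh(), nn.Linear(d, d))
x = torch.randn(2, N, d)                             # input injection / magnetic fields
z_star = solve_fixed_point(J, f, x)                  # (2, N, d) collective response
```

In practice, deep equilibrium models solve for and differentiate through such fixed points with root-finding and implicit differentiation rather than plain unrolled iteration.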

### DEQMLPMixerAttention

A deep equilibrium version of MLP-Mixer transformer attention (https://arxiv.org/abs/2105.02723, https://arxiv.org/abs/2105.01601):

z_i = g({z_j}) - f(z_i) + x_i

where g is an MLP acting across the sequence dimension instead of the feature dimension (so across patches). The network f parametrizes the self-correction term and acts across the feature dimension (so individually on every sequence element).

Compared to a vanilla softmax attention transformer module (see below), the sum over couplings has been "amortized" and parametrized by an MLP. The fixed-point variables z_i are also fed straight into the feed-forward self-correction term. One could feed the naive mean-field update g({z_j}) + x_i instead to fully mimic the residual connection in the explicit MLP-Mixer architecture.
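A hedged sketch of this update rule, assuming illustrative module and parameter names (the class below is not necessarily how the package spells it), with naive unrolled iteration standing in for a proper equilibrium solver:

```python
import torch
import torch.nn as nn

# Hedged sketch of the DEQMLPMixerAttention update (illustrative names only):
# g mixes tokens across the sequence dimension (patches), f self-corrects
# across the feature dimension, and z is iterated to a fixed point.
class MixerFixedPoint(nn.Module):
    def __init__(self, num_tokens, dim, hidden=64):
        super().__init__()
        # g: token-mixing MLP acting on the sequence axis
        self.g = nn.Sequential(
            nn.Linear(num_tokens, hidden), nn.GELU(), nn.Linear(hidden, num_tokens)
        )
        # f: channel-wise self-correction acting on the feature axis
        self.f = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)
        )

    def forward(self, x, num_iter=30):
        # x, z: (batch, num_tokens, dim)
        z = torch.zeros_like(x)
        for _ in range(num_iter):
            token_mixed = self.g(z.transpose(1, 2)).transpose(1, 2)  # g({z_j})
            z = token_mixed - self.f(z) + x  # z_i = g({z_j}) - f(z_i) + x_i
        return z

x = torch.randn(2, 16, 32)
z_star = MixerFixedPoint(num_tokens=16, dim=32)(x)
```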

### DEQVanillaSoftmaxAttention

A deep equilibrium version of vanilla softmax transformer attention (https://arxiv.org/abs/1706.03762):

z_i = sum_j J_ij z_j - f(z_i) + x_i

where

J_ij = [softmax(X W_Q W_K^T X^T / sqrt(dim))]_ij

Transformer attention takes the couplings J_ij to depend parametrically on x_i and treats the fixed-point equation above as a single-step update. Compared to the explicit vanilla softmax attention transformer module, there are no value projections, and the fixed-point variables z_i are fed straight into the feed-forward self-correction term.
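A hedged sketch of this variant (again with made-up names, and naive unrolled iteration instead of a proper DEQ solver):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hedged sketch (illustrative names, naive unrolled iteration): the couplings
# J_ij are computed from the inputs x via queries and keys, there is no value
# projection, and z is iterated to a fixed point.
class SoftmaxFixedPoint(nn.Module):
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.f = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x, num_iter=30):
        # J_ij = [softmax(X W_Q W_K^T X^T / sqrt(dim))]_ij depends on x only
        q, k = self.w_q(x), self.w_k(x)
        J = F.softmax(q @ k.transpose(-2, -1) / x.shape[-1] ** 0.5, dim=-1)
        z = torch.zeros_like(x)
        for _ in range(num_iter):
            z = J @ z - self.f(z) + x  # z_i = sum_j J_ij z_j - f(z_i) + x_i
        return z

x = torch.randn(2, 16, 32)
z_star = SoftmaxFixedPoint(dim=32)(x)
```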

### DEQMeanFieldAttention

Fast and neural deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/.

Schematically, the fixed-point mean-field equations including the Onsager self-correction term look like:

z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network parametrizing the self-correction term for every site and x_i denotes the input injection or magnetic field applied at site i. Mean-field results are obtained by dropping the self-correction term. This model generalizes the current generation of transformers in the sense that its couplings are free parameters, independent of the incoming data x_i.
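A hedged sketch of this idea (names are illustrative; the actual module may differ): the couplings J are free parameters, and dropping the parametrized self-correction recovers the plain mean-field update:

```python
import torch
import torch.nn as nn

# Hedged sketch (illustrative names): the couplings J are free parameters,
# independent of the incoming data x, and f parametrizes the Onsager
# self-correction. Setting self_correction=False recovers naive mean field.
class MeanFieldFixedPoint(nn.Module):
    def __init__(self, num_spins, dim, hidden=64):
        super().__init__()
        self.J = nn.Parameter(0.01 * torch.randn(num_spins, num_spins))
        self.f = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x, num_iter=30, self_correction=True):
        z = torch.zeros_like(x)
        for _ in range(num_iter):
            correction = self.f(z) if self_correction else torch.zeros_like(z)
            z = torch.einsum('ij,bjd->bid', self.J, z) - correction + x
        return z

x = torch.randn(2, 16, 32)
z_star = MeanFieldFixedPoint(num_spins=16, dim=32)(x)
```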

### DEQAdaTAPMeanFieldAttention

Slow and explicit deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/ (this variant served as the grounding and inspiration for the fast, neural one above).

An Ising-like vector model with a multivariate Gaussian prior over the spins, generalizing the adaptive TAP mean-field approach from systems of binary/scalar spins to vector spins. Schematically, the fixed-point mean-field equations including the Onsager term look like:

S_i ~ sum_j J_ij S_j - V_i S_i + x_i

where the V_i are self-corrections obtained self-consistently and x_i denotes the input injection or magnetic field applied at site i. The linear-response correction step involves solving a system of equations, leading to a complexity of ~O(N^3 d^3). Mean-field results are obtained by setting V_i = 0.

Given the couplings between spins and a prior distribution for the single-spin partition function, the adaptive TAP framework provides a closed-form solution in terms of sets of equations that should be solved for a fixed point. The algorithm is related to expectation propagation (see Section 4.3 in https://arxiv.org/abs/1409.6179) and boils down to matching the first and second moments assuming a Gaussian cavity distribution.
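To make the structure concrete, here is a highly simplified sketch of the fixed-point update for *given* self-corrections V_i; it omits the linear-response (cavity) step that the actual adaptive TAP algorithm uses to determine the V_i self-consistently, and all names are illustrative:

```python
import torch

# Highly simplified sketch (illustrative only): iterate
#   S_i = sum_j J_ij S_j - V_i S_i + x_i
# for *given* diagonal self-corrections V_i. The real adaptive TAP algorithm
# determines the V_i self-consistently via a linear-response (cavity) step,
# which is the expensive ~O(N^3 d^3) part and is omitted here.
def naive_tap_update(J, V, x, num_iter=50):
    # J: (N, N) couplings, V: (N,) self-corrections, x: (batch, N, d)
    S = torch.zeros_like(x)
    for _ in range(num_iter):
        S = torch.einsum('ij,bjd->bid', J, S) - V[None, :, None] * S + x
    return S

N, d = 8, 4
J = 0.05 * torch.randn(N, N)
V = torch.zeros(N)          # V_i = 0 recovers plain (naive) mean field
x = torch.randn(2, N, d)
S_star = naive_tap_update(J, V, x)
```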

## Setup

Install the package in editable mode:

$ pip install -e .

Run tests with:

$ python -m unittest

## References

### Selection of literature

On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:

On the adaptive Thouless-Anderson-Palmer (TAP) mean-field approach in disorder physics:

On Boltzmann machines and mean-field theory:

On deep equilibrium models:

On approximate message passing (AMP) methods in statistics:

- *A unifying tutorial on Approximate Message Passing* (2021) by Oliver Y. Feng, Ramji Venkataramanan, Cynthia Rush, Richard J. Samworth: the example on page 2 basically describes how transformers implement approximate message passing: an iterative algorithm with a "denoising" step (attention) followed by a "memory term" or Onsager correction term (feed-forward layer).

### Code inspiration