MatMuls are Enough for Linear-Time Dense Attention
Abstract
Transformers, despite empowering the current AI revolution, are bottlenecked by suboptimal hardware utilization and the quadratic runtime complexity of softmax attention w.r.t. input sequence length. Many recent architectures aspire to bring the complexity down to a sub-quadratic level without compromising modeling quality. However, they are either much slower on all but very long sequences or rely on low-level code tailored to a narrow subset of modern hardware. To simultaneously achieve linear complexity, hardware efficiency, and portability, we completely eliminate softmax from self-attention; remove, modify, or rearrange other transformations in the Transformer block; and reduce the number of attention heads. The resulting architecture, DenseAttention Network, is composed entirely of dense matrix multiplications in the attention, which allows for efficient training and inference in both quadratic and linear modes. It performs on par with the standard Transformer in language modeling and surpasses the previous Transformer-based SOTA by 5% on the challenging Long Range Arena benchmarks. The DenseAttention model, written in plain PyTorch, is up to 57% faster than a Transformer augmented with the low-level FlashAttention kernel even at a small context size of 512, and faster by orders of magnitude on longer sequences.
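The abstract's key claim rests on a simple algebraic fact: once softmax is removed, attention is a chain of plain matrix multiplications, so the same output can be computed in either quadratic or linear order by associativity. The sketch below illustrates only this fact, with hypothetical shapes and names; it is not the paper's exact DenseAttention formulation.

```python
import torch

# Hypothetical sizes for illustration: sequence length N, head dimension d.
N, d = 512, 64
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Quadratic mode: materialize the N x N score matrix, O(N^2 * d) work.
quadratic = (Q @ K.T) @ V

# Linear mode: associate the other way, materializing only a d x d state,
# O(N * d^2) work -- linear in sequence length.
linear = Q @ (K.T @ V)

# Both orderings yield the same result up to floating-point error.
assert torch.allclose(quadratic, linear, rtol=1e-4, atol=1e-4)
```

Because every step is a dense matmul, both modes map directly onto standard GPU/TPU kernels without custom low-level code, which is what the abstract refers to by hardware efficiency and portability.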