斯坦福大学CS博士新作Attention提速24倍,BERT
机器之心报道
编辑:陈萍
FlashAttention是一种具有IO感知,且兼具快速、内存高效的新型注意力算法。
一种快速、内存高效的注意力算法来了,被命名为FlashAttention。通过减少GPU内存读取写入,FlashAttention的运行速度比PyTorch标准注意力快24倍,所需内存减少520倍。
这项研究由斯坦福大学、纽约州立大学布法罗分校的研究者共同完成。共同一作是两位斯坦福计算机博士生TriDao和DanFu。
下面我们介绍一下论文具体内容。
FlashAttention
Transformer已然成为自然语言处理和图像分类等应用中最广泛使用的架构。随着研究的不断前进,Transformer尺寸变得越来越大、层数也越来越深,但是给Transformer配备更长的上下文仍然很困难,因为Transformer核心自注意力模块的时间复杂度以及内存复杂度在序列长度上是二次方的。
有研究者提出一些近似注意力的方法,旨在减少注意力计算和内存需求。这些方法包括稀疏近似、低秩近似以及它们的组合。从序列长度来看,尽管这些方法可以将计算降低到线性或接近线性,但它们并没有显示出针对标准注意力的wallclock加速,因而没有被广泛使用。这其中一个主要原因是这些研究专注于减少FLOP(这可能与wallclock速度无关)并且倾向于忽略来自内存访问(IO)的开销。
在本文中,该研究认为应该让注意力算法具有IO感知即考虑显存级间的读写。现代GPU计算速度超过了内存速度,transformer中的大多数操作都被内存访问所阻塞。IO感知算法对于类似的内存绑定操作至关重要,这种重要性体现在当读写数据占据很大运行时例如数据库连接、图像处理、数值线性代数等。然而,用于深度学习的常见Python接口,如PyTorch和Tensorflow,不允许对内存访问进行细粒度控制。
论文地址:https:arxiv。orgpdf2205。14135。pdf
GitHub地址:https:github。comHazyResearchflashattention
该研究提出了一种新的注意力算法FlashAttention,它可以使用更少的内存访问来计算精确的注意力。FlashAttention旨在避免从HBM(HighBandwidthMemory)中读取和写入注意力矩阵。这需要做到:(i)在不访问整个输入的情况下计算softmaxreduction;(ii)在后向传播中不能存储中间注意力矩阵。
该研究采用两种成熟的技术来应对这些挑战:
(i)该研究重组注意力计算,将输入分成块,并在输入块上进行多次传递,从而逐步执行softmaxreduction(也称为tiling);
(ii)该研究存储前向传递的softmax归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。
该研究在CUDA中实现FlashAttention,以达到对内存访问的细粒度控制,并将所有注意力操作融合到一个GPU内核中。即使由于重新计算导致FLOPs增加,但其运行速度更快(在GPT2上高达7。6倍,图1右图)并且使用更少的内存(序列长度线性),主要是因为大大减少了HBM访问量。
该研究分析了FlashAttention的IO复杂度,证明它需要(221)HBM访问,其中是head维度,是SRAM的大小,而标准的注意力需要(2)HBM访问。对于和的典型值,与标准注意力相比,FlashAttention需要的HBM访问次数要少很多(最多减少9倍,如图2所示)。此外,该研究还提供了一个下限,表明没有精确的注意力算法可以渐近地提高所有SRAM大小的HBM访问次数。
该研究还表明,FlashAttention可以作为一种原语(primitive),通过克服内存访问开销问题来实现近似注意力算法。作为概念证明,该研究实现了块稀疏FlashAttention,这是一种稀疏注意力算法,比FlashAttention快24倍,可扩展到64k的序列长度。该研究证明了块稀疏FlashAttention比FlashAttention具有更好的IO复杂度。
值得一提的是,该研究还开源了FlashAttention。
实验结果
BERT:FlashAttention得到了最快的单节点BERT训练速度。该研究在Wikipedia上用FlashAttention训练了一个BERTlarge模型。表1将FlashAttention训练时间与NvidiaMLPerf1。1进行了比较,结果表明FlashAttention的训练速度提高了15。
GPT2:表2显示,与HuggingFace相比,FlashAttention端到端加速可达3倍,与MegatronLM相比,加速可达1。7倍
LongrangeArena:该研究在longrangearena(LRA)基准上进行了实验,他们测量了准确率、吞吐量、训练时间。每个任务有不同的序列长度,从1024到4096不等。此外,实验遵循Tay和Xiong等人的实验设置。表3显示,与标准注意力相比,FlashAttention的速度提高了2。4倍。块稀疏FlashAttention比所有近似注意力方法都要快。
具有长上下文的语言模型:FlashAttention的运行时间和内存效率允许我们将GPT2的上下文长度增加4倍,同时仍然比MegatronLM的运行更快。从表4可以看出,上下文长度为4K的FlashAttentionGPT2仍然比上下文长度为1K的Megatron的GPT2快30,同时perplexity提高了0。7。
表5表明,在MIMIC上序列长度为16K的性能比长度为512的高出4。3个点,而在ECtHR上,序列长度为8K的比长度512高出8。5个点。
表6展示了Transformer模型可以解决PathX、Path256问题。该研究在Path64上预训练transformer,然后通过空间插值位置嵌入迁移到PathX。FlashAttention在PathX上达到61。4的准确率。此外,块稀疏FlashAttention使得Transformers将序列扩展到64K,在Path256实现63。1的准确率。
图3(左)报告了以毫秒为单位的FlashAttention和块稀疏FlashAttention前向后向传播的运行时间与基准比较,图3(右)显示了与各种精确、近似和稀疏注意基线相比,FlashAttention和块稀疏FlashAttention的内存占用情况。