转载

OpenAI研究 使用稀疏变换器进行生成建模

我们开发了 Sparse Transformer,这是一种深度神经网络,它在预测序列中接下来会发生什么方面创造了新记录——无论是文本、图像还是声音。它使用注意力机制的算法改进,  从比以前长 30 倍的序列中提取模式。

AI 研究中的一个现有挑战是对复杂数据(如图像、视频或声音)中的远程、微妙的相互依赖性进行建模。稀疏变压器包含一个 ()奥( ñ否的)的重新制定 (2个)奥( ñ2个) Transformer  self-attention 机制,连同其他几项改进,将其直接应用于这些丰富的数据类型。以前,用于这些数据的模型是专门为一个领域设计的,或者很难扩展到长度超过几千个元素的序列。相比之下,我们的模型可以使用数百层对具有数万个元素的序列进行建模,从而在多个领域实现最先进的性能。在 OpenAI,我们正在使用它来帮助我们构建具有更强理解世界能力的人工智能系统。

深度关注

在 Transformers 中,每个输出元素都连接到每个输入元素,并且它们之间的权重是根据情况动态计算的,这个过程称为 注意力。虽然人们认为这允许 Transformers 比具有固定连接模式的模型更灵活,但实际上它需要创建一个 ×否×否  每层和注意力头的注意力矩阵,当应用于具有许多元素的数据类型(如图像或原始音频)时,它会消耗大量内存。

数据类型存储重新计算
1024 个文本标记(几个段落)1.0GB16MB
32x32x3 像素(CIFAR-10 图像)9.6GB151MB
64x64x3 像素(Imagenet 64 图像)154GB2.4GB
24,000 个样本(约 2 秒的 12 kHz 音频)590GB9.2GB

当矩阵存储在内存中或在反向传递期间重新计算时,注意深度 Transformer(64 层和 4 头)的内存使用情况。作为参考,用于深度学习的标准 GPU 通常具有 12-32 GB 的内存。

减少这种情况的一种方法是通过在反向传播期间从检查点重新计算注意力矩阵  ,这是深度学习中一种行之有效的技术,用于以更多计算为代价减少内存使用。当对 Transformers 中的注意力矩阵完成时,这意味着最大的内存成本变得与层数无关,让我们训练的网络深度比以前大得多。在实践中,我们发现深度高达 128 层的 Transformer 在基准任务(如 CIFAR-10)上的表现优于较浅的网络。

为了以更高的深度训练这些模型,我们对 transformer 中的操作顺序进行了一些调整,并修改了初始化方案。完整的细节可以在我们的论文中看到。

注意力稀疏

然而,即使计算单个注意力矩阵对于非常大的输入也变得不切实际。我们改为使用稀疏注意力模式,其中每个输出位置仅计算输入位置子集的权重。当子集相对于整个输入集较小时(例如, 的元素代替  否元素),即使对于非常长的序列,由此产生的注意力计算也变得易于处理,算法复杂度为 ()奥( ñ否的) 代替 (2个)奥( ñ2个).

为了评估该方法的可行性,我们首先在图像上可视化深度变形金刚学习到的注意力模式,发现许多表现出可解释和结构化的稀疏模式。下面的每个图像都显示了给定的注意力头关注哪些输入像素(以白色突出显示)以预测图像中的下一个值。当输入部分集中在小子集上并显示出高度规律性时,该层可以进行稀疏化。此处显示了 CIFAR-10 图像上 128 层模型的示例:

第 19 层

第 19 层

第 20 层

第 20 层

为 128 层 CIFAR-10 网络的多个层学习的注意模式(白色突出显示)。这些层学会了在两个维度上分离注意力。第 19 层汇总每一行的信息,第 20 层按列聚合这些汇总,从而实现全注意力操作的有效分解。
第 6 层

第 6 层

第 36 层

第 36 层

一些层学会了访问位置记忆,通常关注相似的位置,而不管输入数据或时间步长(第 6 层)。其他层学习了高度依赖数据的访问模式(第 36 层)。

虽然许多层显示出稀疏结构,但一些层清楚地显示了延伸到整个图像的动态注意力。为了保持我们的网络学习这种模式的能力,我们实现了注意力矩阵的二维分解,其中网络可以通过稀疏注意力的两个步骤关注所有位置。

Normal transformer visualization in an attention matrix
Normal transformer visualization in an attention matrix
普通变压器
Strided attention transformer visualization in an attention matrix
Strided attention transformer visualization in an attention matrix
大步注意
Fixed attention transformer visualization in an attention matrix
Variants Fixed Step2 1
专注

第一个版本,  strided  attention,大致相当于每个位置注意它的行和它的列,类似于上面网络学习的注意力模式。(请注意,列注意力可以等效地表述为关注转置矩阵的行)。第二个版本, 固定 注意力,关注固定列和最新列元素之后的元素,我们发现这种模式在数据不适合二维结构(如文本)时很有用。有关更多详细信息,我们建议读者参阅我们的论文。

实验结果

Sparse Transformers 为 CIFAR-10、Enwik8 和 Imagenet 64 的密度估计设置了新的最先进分数。

CIFAR10每暗淡的位数
PixelCNN++(Salimans 等人,2017 年)2.92
Image Transformer(Parmar 等人,2018 年)2.90
PixelSNAIL(Chen 等人,2017 年)2.85
稀疏变压器 59M (256W, 128L, 2H)2.80
恩维克8每字节位数
Deeper Self-Attention(Al-Rfou 等人,2018 年)1.06
Transformer-XL 88M(Dai 等人,2018 年)1.03
Transformer-XL 277M(Dai 等人,2018 年)0.99
稀疏变压器95M(512W,30L,8H)0.99
ImageNet 64x64每暗淡的位数
门控 PixelCNN(van den Oord 等人,2016 年)3.57
并行多尺度(Reed 等人,2017 年)3.7
SPN 150M(Menick 和 Kalchbrenner,2018 年)3.52
稀疏变压器152M(512W、48L、16H)3.44

各种基准数据集上的密度建模性能,以每字节位数(或暗淡)为单位。M 表示网络中使用的数百万个参数,W 表示网络的宽度,L 表示层数,H 表示头数。

我们还发现,稀疏注意力的损失比全注意力低,而且速度明显更快(请参阅我们的论文进行比较)。这可能表明我们的稀疏模式存在有用的归纳偏差,或者存在密集关注的潜在优化问题。

生成图像

使用稀疏注意力的 Transformer 似乎具有全局结构的概念,可以通过查看图像补全来对其进行定性评估。在这里,我们可视化训练的模型 64×6464×64图片网:

Half-images used to train a sparse attention transformer to complete images
Grid of images completed using a sparse attention transformer
Row of photographs used as samples to train a sparse attention transformer to complete images

我们还生成了完全无条件的样本,其未调整的 softmax 温度为 1.0。这些模型使用最大似然目标进行训练,众所周知,最大似然目标涵盖数据的所有模式(包括可能不存在的模式),而不是增加一小部分数据的保真度。从这些具有未调整温度的模型中采样,让我们看到模型认为世界上存在的图像的完整分布。因此,一些样本可能看起来很奇怪。

Sample images used to train a sparse attention transformer model

模型样品

Imagenet True Data

真实数据

生成原始音频波形

稀疏变换器也可以通过简单地改变位置嵌入来适应生成原始音频而不是图像。随着深度学习扩展到新的数据类型,我们相信使用此类网络轻松指定归纳偏差将是一个有用的工具。

该模型在原始古典音乐片段上进行训练,并使用稀疏注意力生成长度为 65,000 的序列。这相当于大约 5 秒的原始音频,我们在下面的每个剪辑中将几个样本连接在一起。

代码发布

通常,实现稀疏注意力会涉及在块中对查询和关键矩阵进行切片,因此为了简化实验,我们实现了一组块 稀疏内核 ,可以在 GPU 上高效地执行这些操作。我们开源这些内核并在 此存储库中提供示例稀疏注意力函数。

未来的工作和限制

  • 我们介绍的稀疏注意力模式只是长序列高效建模方向的初步步骤。我们认为探索稀疏性的不同模式和组合是有用的,并且学习稀疏模式是下一代神经网络架构的一个特别有前途的研究途径。
  • 即使有了我们上面描述的改进,自回归序列生成对于非常高分辨率的图像或视频似乎仍然不切实际。然而,我们引入的优化注意力操作可能是有用的原语,可以与其他方法结合起来对高维数据进行建模,例如多尺度方法。