OpenAI研究 使用稀疏变换器进行生成建模
我们开发了 Sparse Transformer,这是一种深度神经网络,它在预测序列中接下来会发生什么方面创造了新记录——无论是文本、图像还是声音。它使用注意力机制的算法改进, 从比以前长 30 倍的序列中提取模式。
AI 研究中的一个现有挑战是对复杂数据(如图像、视频或声音)中的远程、微妙的相互依赖性进行建模。稀疏变压器包含一个 奥( ñ否
深度关注
在 Transformers 中,每个输出元素都连接到每个输入元素,并且它们之间的权重是根据情况动态计算的,这个过程称为 注意力。虽然人们认为这允许 Transformers 比具有固定连接模式的模型更灵活,但实际上它需要创建一个 否×否 每层和注意力头的注意力矩阵,当应用于具有许多元素的数据类型(如图像或原始音频)时,它会消耗大量内存。
数据类型 | 存储 | 重新计算 |
1024 个文本标记(几个段落) | 1.0GB | 16MB |
32x32x3 像素(CIFAR-10 图像) | 9.6GB | 151MB |
64x64x3 像素(Imagenet 64 图像) | 154GB | 2.4GB |
24,000 个样本(约 2 秒的 12 kHz 音频) | 590GB | 9.2GB |
当矩阵存储在内存中或在反向传递期间重新计算时,注意深度 Transformer(64 层和 4 头)的内存使用情况。作为参考,用于深度学习的标准 GPU 通常具有 12-32 GB 的内存。
减少这种情况的一种方法是通过在反向传播期间从检查点重新计算注意力矩阵 ,这是深度学习中一种行之有效的技术,用于以更多计算为代价减少内存使用。当对 Transformers 中的注意力矩阵完成时,这意味着最大的内存成本变得与层数无关,让我们训练的网络深度比以前大得多。在实践中,我们发现深度高达 128 层的 Transformer 在基准任务(如 CIFAR-10)上的表现优于较浅的网络。
为了以更高的深度训练这些模型,我们对 transformer 中的操作顺序进行了一些调整,并修改了初始化方案。完整的细节可以在我们的论文中看到。
注意力稀疏
然而,即使计算单个注意力矩阵对于非常大的输入也变得不切实际。我们改为使用稀疏注意力模式,其中每个输出位置仅计算输入位置子集的权重。当子集相对于整个输入集较小时(例如, 否
为了评估该方法的可行性,我们首先在图像上可视化深度变形金刚学习到的注意力模式,发现许多表现出可解释和结构化的稀疏模式。下面的每个图像都显示了给定的注意力头关注哪些输入像素(以白色突出显示)以预测图像中的下一个值。当输入部分集中在小子集上并显示出高度规律性时,该层可以进行稀疏化。此处显示了 CIFAR-10 图像上 128 层模型的示例:

第 19 层

第 20 层

第 6 层

第 36 层
虽然许多层显示出稀疏结构,但一些层清楚地显示了延伸到整个图像的动态注意力。为了保持我们的网络学习这种模式的能力,我们实现了注意力矩阵的二维分解,其中网络可以通过稀疏注意力的两个步骤关注所有位置。






第一个版本, strided attention,大致相当于每个位置注意它的行和它的列,类似于上面网络学习的注意力模式。(请注意,列注意力可以等效地表述为关注转置矩阵的行)。第二个版本, 固定 注意力,关注固定列和最新列元素之后的元素,我们发现这种模式在数据不适合二维结构(如文本)时很有用。有关更多详细信息,我们建议读者参阅我们的论文。
实验结果
Sparse Transformers 为 CIFAR-10、Enwik8 和 Imagenet 64 的密度估计设置了新的最先进分数。
CIFAR10 | 每暗淡的位数 |
PixelCNN++(Salimans 等人,2017 年) | 2.92 |
Image Transformer(Parmar 等人,2018 年) | 2.90 |
PixelSNAIL(Chen 等人,2017 年) | 2.85 |
稀疏变压器 59M (256W, 128L, 2H) | 2.80 |
恩维克8 | 每字节位数 |
Deeper Self-Attention(Al-Rfou 等人,2018 年) | 1.06 |
Transformer-XL 88M(Dai 等人,2018 年) | 1.03 |
Transformer-XL 277M(Dai 等人,2018 年) | 0.99 |
稀疏变压器95M(512W,30L,8H) | 0.99 |
ImageNet 64x64 | 每暗淡的位数 |
门控 PixelCNN(van den Oord 等人,2016 年) | 3.57 |
并行多尺度(Reed 等人,2017 年) | 3.7 |
SPN 150M(Menick 和 Kalchbrenner,2018 年) | 3.52 |
稀疏变压器152M(512W、48L、16H) | 3.44 |
各种基准数据集上的密度建模性能,以每字节位数(或暗淡)为单位。M 表示网络中使用的数百万个参数,W 表示网络的宽度,L 表示层数,H 表示头数。
我们还发现,稀疏注意力的损失比全注意力低,而且速度明显更快(请参阅我们的论文进行比较)。这可能表明我们的稀疏模式存在有用的归纳偏差,或者存在密集关注的潜在优化问题。
生成图像
使用稀疏注意力的 Transformer 似乎具有全局结构的概念,可以通过查看图像补全来对其进行定性评估。在这里,我们可视化训练的模型 64×64图片网:



我们还生成了完全无条件的样本,其未调整的 softmax 温度为 1.0。这些模型使用最大似然目标进行训练,众所周知,最大似然目标涵盖数据的所有模式(包括可能不存在的模式),而不是增加一小部分数据的保真度。从这些具有未调整温度的模型中采样,让我们看到模型认为世界上存在的图像的完整分布。因此,一些样本可能看起来很奇怪。

模型样品

真实数据
生成原始音频波形
稀疏变换器也可以通过简单地改变位置嵌入来适应生成原始音频而不是图像。随着深度学习扩展到新的数据类型,我们相信使用此类网络轻松指定归纳偏差将是一个有用的工具。
该模型在原始古典音乐片段上进行训练,并使用稀疏注意力生成长度为 65,000 的序列。这相当于大约 5 秒的原始音频,我们在下面的每个剪辑中将几个样本连接在一起。