DiT:Scalable Diffusion Models with Transformers

2212.09748v2

摘要

我们探索了一种基于transformer体系结构的新型diffusion模型。我们训练图像的潜在扩散模型,用一个操作潜在补丁的变压器取代常用的U-Net主干。我们通过用浮点运算次数(Gflops)衡量的前向传播复杂度来分析Diffusion Transformers(DiTs)的可扩展性。我们发现,DiTs具有较高的Gflops——通过增加变压器的深度/宽度或增加输入令牌的数量——始终具有较低的FID。除了具有良好的可扩展性特性外,我们最大的DiT-XL/2模型在条件分类ImageNet512×512和256×256基准上优于所有先前的扩散模型,在后者上实现了最先进的的FID 2.27。

导言

机器学习正在经历一场由transformers驱动的复兴。在过去的五年里,自然语言处理[8,42]、视觉[10]和其他几个领域的神经结构在很大程度上被变形者[60]所包含。许多类型的图像级生成模型仍然坚持这一趋势,尽管transformers在自回归模型[3,6,43,47]中被广泛使用,但它们在其他生成建模框架中的采用较少。例如,扩散模型一直处于图像级生成模型[9,46]最新进展的前沿;然而,它们都采用卷积U-Net架构作为主干的骨干选择。

Ho等人[19]的开创性工作首先引入了扩散模型的U-Net主干。最初在像素级自回归模型和条件GANs [23]中获得成功后,U-Net继承了PixelCNN++ [52,58],只做了一些变化。该模型是卷积的,主要由ResNet [15]块组成。与标准的U-Net [49]相比,额外的空间自注意块是变压器的基本组件,以较低的分辨率分散。达里瓦尔和Nichol [9]取消了UNet的几种架构选择,例如使用自适应归一化层[40]来为卷积层注入条件信息和信道计数。然而,Ho等人的UNet的高级设计基本上保持完整。

通过这项工作,我们的目的是揭开扩散模型中架构选择的重要性,并为未来的生成建模研究提供经验基线。我们表明,U-Net感应偏差对扩散模型的性能不是至关重要的,它们可以很容易地替换为标准设计,如transformers。因此,扩散模型可以很好地从最近的架构统一趋势中获益——例如,通过继承来自其他领域的最佳实践和训练配方,以及保留诸如可扩展性、鲁棒性和效率等有利的特性。一个标准化的架构也将为跨领域的研究开辟新的可能性。

本文主要研究了一类新的基于transformers的扩散模型。我们称它们为Diffusion Transformers,或简称为DiTs。DiTs坚持视觉变压器(ViTs)[10]的最佳实践,该[10]已被证明比传统的卷积网络更有效地扩展到视觉识别领域。

更具体地说,我们研究了transformers在网络复杂度与样本质量方面的尺度行为。我们表明,通过在潜在扩散模型(LDMs)[48]框架下构建和建立倾斜设计空间的基准测试,其中扩散模型在VAE的潜在空间内进行训练,我们可以成功地用变压器取代U-Net主干。我们进一步证明,DiTs对于扩散模型是可扩展的架构:网络复杂度(通过Gflops度量)和样本质量(通过FID度量)之间有很强的相关性。通过简单地扩展DiT和训练一个具有高容量主干(118.6 Gflops)的LDM,我们能够在类条件256×256 ImageNet生成基准上获得2.27 FID的最新结果。

Diffusion Transformers

Preliminaries

  • Diffusion formulation

  • Classifier-free guidance.

  • Latent diffusion models

Diffusion Transformer Design Space

我们介绍了Diffusion Transformers(DiTs),一种新的扩散模型体系结构。我们的目标是尽可能忠实于标准的变压器架构,以保持其缩放特性。由于我们的重点是训练图像的ddpm(特别是图像的空间表示),DiT基于视觉转换器(ViT)架构,它对[10]补丁序列进行操作。DiT保留了vit的许多最佳实践。图3显示了完整的DiT体系结构的概述。在本节中,我们将描述DiT的正向通过,以及DiT类的设计空间的组件。

image.png
图3.扩散变压器(DiT)结构。左:我们训练条件潜在DiT模型。输入被分解为patches,并由多个DiT块进行处理。右:DiT 块的细节。我们实验了标准变压器块的变体,它们通过adaptive layer norm, cross-attention and extra input tokens。Adaptive layer norm的效果最好。

Patchify

DiT的输入是一个空间表示z(对于256×256×3图像,z的形状为32×32×4)。DiT的第一层是“patchify”,它通过线性嵌入每个补丁,将空间输入转换为一个T标记序列,每个维度为d。在模式化之后,我们将标准的ViT基于频率的位置嵌入(正弦-余弦版本)应用到所有输入标记。由补丁化创建的令牌T的数量由补丁化大小的超参数p决定。如图4所示,p减半将使T增加四倍,因此至少总变压器gflop增加四倍。尽管它对Gflops有显著影响,但请注意,更改p对下游参数计数没有有意义的影响。

我们将 p=2,4,8添加到DiT设计空间。

![image.png](Scalable+Diffusion+Models+with+Transformers/image 1.png)

图4。DiT的输入规格。给定patch大小$p×p$,形状$I×I×C$的空间表示(来自VAE的潜在噪声)被“patchified”成一个长度为$T =(I/p)^2$的序列,隐藏维度为d。较小的补丁大小p会导致更长的序列长度,从而产生更多的Gflops。

DiT block design

在patchify之后,输入标记由一系列变压器块进行处理。除了噪声图像输入外,扩散模型有时还会处理额外的条件信息,如噪声时间步长t、类标签c、自然语言等。我们探索了以不同处理条件输入的变压器块的四种变体。这些设计对标准的ViT块设计进行了小但重要的修改。所有方块的设计如图所示:

  • In-context conditioning。我们简单地将t和c的向量嵌入作为输入序列中的两个附加标记附加,处理它们与图像标记没有什么区别。这类似于ViTs中的cls令牌,它允许我们使用无需修改的标准ViT块。在最后一个块之后,我们从序列中删除条件反射令牌。这种方法在模型中引入了可忽略不计的新Gflops。

  • Cross-attention block。我们将t和c的嵌入连接到一个长度为2的序列中,与图像标记序列分开。变压器块经过修改,在多头自注意块之后包括一个额外的多头交叉注意层,类似于Vaswani等人[60]的原始设计,也类似于LDM用于类标签的调节。交叉关注为该模型增加了最多的Gflops,大约有15%的开销。

  • Adaptive layer norm (adaLN) block。随着自适应归一化层[40]在GANs [2,28]和UNet骨干[9]扩散模型中的广泛应用,我们探索用自适应层范数(adaLN)替换变压器块中的标准层范数层。我们不是直接学习维数尺度和位移参数γ和β,而是从t和c的嵌入向量的和中回归它们。在我们探索的三个块设计中,adaLN添加了最小的Gflops,因此是计算效率最高的。它也是唯一一种被限制为对所有标记应用相同功能的调节机制。

  • adaLN-Zero block。先前对ResNets的研究发现,将每个残差块初始化为身份函数是有益的。例如,Goyal等人发现,零初始化最终批范数尺度因子γ加速了监督学习设置[13]中的大规模训练。扩散U-Net模型使用类似的初始化策略,在任何剩余连接之前,对每个块中的最终卷积层进行零初始化。我们探索了对同样的adaLN DiT块的修改。除了回归γ和β之外,我们还回归了维度尺度参数α之前应用的。

Model size

![image.png](Scalable+Diffusion+Models+with+Transformers/image 2.png)

Transformer decoder

在最后的DiT块之后,我们需要将我们的图像标记序列解码为输出噪声预测和输出对角线协方差预测。这两个输出的形状都等于原始的空间输入。我们使用一个标准的线性解码器来做到这一点;我们应用最后一层范数(如果使用adaLN则自适应),并将每个标记线性解码为$p×p×2C$张量,其中C是空间输入的通道数。最后,我们将解码后的标记重新排列成原始的空间布局,得到预测的噪声和协方差。

Last updated