深度学习模型之CNN(二十三)Swin Transformer网络结构详解
前言
Swin Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper
的荣誉称号。Swin Transformer网络是Transformer模型在视觉领域的又一次碰撞。该论文一经发表就已在多项视觉任务中霸榜(下图State of the Art
表示第一)。原论文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
网络整体框架
下图(左)是Swin Transformer,下图(右)是之前讲的Vision Transformer。通过对比至少可以看出两点不同:
- Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的backbone有助于在此基础上构建目标检测、实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变。
- 在Swin Transformer中使用一个个窗口的形式将图片分隔开,窗口和窗口之间没有重叠,即Windows Multi-Head Self-Attention(W-MSA)的概念,比如在下图的4倍下采样和8倍下采样中,将特征图划分成了多个不相交的区域(Window),并且Multi-Head Self-Attention只在每个窗口(Window)内进行。相对于Vision Transformer中直接对整个(Global)特征图进行分割,即Multi-Head Self-Attention,这样做的目的是能够减少计算量的,尤其是在浅层特征图很大的时候。这样做虽然减少了计算量但也会隔绝不同窗口之间的信息传递,所以在论文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递。
原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。通过图(a)可以看出整个框架的基本流程如下:
- 首先将图片输入到Patch Partition模块中进行分块,即每4x4相邻的像素为一个Patch,然后在channel方向展平(flatten)。假设输入的是RGB三通道图片,那么每个patch就有4x4=16个像素,然后每个像素有R、G、B三个值所以展平后是16x3=48,所以通过Patch Partition后图像shape由
[H, W, 3]
变成了[H/4, W/4, 48]
。然后在通过Linear Embedding层对每个像素的channel数据做线性变换,由48变成C,即图像shape再由[H/4, W/4, 48]
变成了[H/4, W/4, C]
。其实在源码中Patch Partition和Linear Embedding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。
-
然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样,然后都是重复堆叠Swin Transformer Block注意这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用下图b)。
-
最后对于分类网络,在Stage4之后还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。
接下来,在分别对Patch Merging、W-MSA、SW-MSA以及使用到的相对位置偏执(relative position bias)进行详解。关于Swin Transformer Block中的MLP结构和Vision Transformer中的结构是一样的,所以这里也不在赘述。
Patch Merging详解
即上图的Stage2、3、4。在每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素
给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
W-MSA详解
- 目的:减少计算量;
- 缺点:窗口之间无法进行信息交互;
引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch)在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。
两者的计算量具体差多少呢?原论文中有给出下面两个公式,这里忽略了Softmax的计算复杂度。
- h代表feature map的高度
- w代表feature map的宽度
- C代表feature map的深度
- M代表每个窗口(Windows)的大小
h = w = 112,M= 7,C = 128,节省40124743680 FLOPS
Self-Attention公式
MSA模块计算量
对于feature map中的每个像素(或称作token,patch),都要通过$W_q$,$W_k$,$W_v$生成对应的query(q),key(k)以及value(v)。这里假设q, k, v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:
- $A^{hw*C}$:将所有像素(token)拼接在一起得到的矩阵(一共有hw个像素,每个像素的深度为C);
- $W_q^{C*C}$:生成query的变换矩阵;
- $Q^{hwC}$:所有像素通过$W_q^{CC}$得到的query拼接后的矩阵
矩阵乘法当中,$A^{ab}·B^{bc}$的FLOPs为a x b x c
根据矩阵运算的计算量公式可以得到生成Q的计算量为hw x C x C,,生成K和V同理都是$hwC^2$,那么总共是$3hwC^2$。接下来$Q$和$K$相乘,对应计算量为$(hw)^2C$:
接下来忽略除以$\sqrt{d}$以及softmax的计算量,假设得到$Λ^{hw * hw}$(通过softmax之后得到的输出),最后还要乘以V,对应的计算量为$(hw)^2C$:
那么对应单头的Self-Attention模块,总共需要$3hwC^2+(hw)^2C+(hw)^2C=3hwC^2+2(hw)^2C$。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵$W^o$的计算量$(hw)^2C$。
所以总共加起来是:$4hwC^2+2(hw)^2C$
W-MSA模块计算量
对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到$\frac{h}{M}✖\frac{w}{M}$个窗口,然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为$4hwC^2+2(hw)^2C$,这里每个窗口的高为M宽为M,带入公式得:
又因为有$\frac{h}{M}✖\frac{w}{M}$个窗口,则::
故使用W-MSA模块的计算量为:
假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:$4(hw)^2C-2M^2hwC = 2✖112^4✖128-2✖7^2✖112^2✖128 = 40124743680$
SW-MSA详解
采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。
如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了$\lfloor\frac{M}{2}\rfloor$个像素,up的视频有动图,十分清晰明了)。看下偏移后的窗口(右侧图),比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。
根据上图,可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,这样做感觉又变麻烦了。为了解决这个麻烦,作者又提出而了Efficient batch computation for shifted configuration
,一种更加高效的计算方法。下面是原论文给的示意图。
为描述上图Efficient batch computation for shifted configuration的过程,up画了下面几幅图来更清晰的讲解。
下图左侧是刚刚通过偏移窗口后得到的新窗口,右侧是为了方便大家理解,对每个窗口加上了一个标识。然后0对应的窗口标记为区域A,3和6对应的窗口标记为区域B,1和2对应的窗口标记为区域C。
然后先将区域A和C移到最下方。
接着,再将区域A和B移至最右侧。
移动完后,4是一个单独的窗口;将3和5合并成一个窗口;7和1合并成一个窗口;8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口了,所以能够保证计算量是一样的。这里肯定有人会想,把不同的区域合并在一起(比如5和3)进行MSA,这信息不就乱窜了吗?是的,为了防止这个问题,在实际计算中使用的是masked MSA即带蒙板mask的MSA
,这样就能够通过设置蒙板来隔绝不同区域的信息了。关于mask如何使用,可以看下图,下图是以上面的区域5和区域3为例。
对于该窗口内的每一个像素(或称token,patch)在进行MSA计算时,都要先生成对应的query(q),key(k),value(v)。假设对于上图的像素0而言,得到$q^0$后要与每一个像素的k进行匹配(match),假设$\alpha_{0,0}$代表$q^0$与像素0对应的$k^0$进行匹配的结果,那么同理可以得到$\alpha_{0,0}$至$\alpha_{0,15}$。按照普通的MSA计算,接下来就是SoftMax操作了。
但对于这里的masked MSA
,像素0是属于区域5的,我们只想让它和区域5内的像素进行匹配(只想计算区域5内部的MSA计算,不希望引入区域3的信息)。那么我们可以将像素0与区域3中的所有像素匹配结果都减去100(例如$\alpha_{0,2}$,$\alpha_{0,3}$,$\alpha_{0,6}$,$\alpha_{0,7}$,由于$\alpha$的值都很小,一般都是零点几的数,将其中一些数减去100后在通过SoftMax得到对应的权重都等于0了。所以对于像素0而言实际上还是只和区域5内的像素进行了MSA。对于其他像素也是同理。
注意,在计算完后还要把数据给挪回到原来的位置上(例如上上图的A,B,C区域)。
举例实操SW-MSA
偏移量为M/2向下取整,下图中M = 3,所以偏移量为1,因此需要改动下图(左)的第一行和第一列,改动完之后是下图(右),将第一行补到最后一行,第一列补到最后一列。
对于上图(右)黑色的分割线对应的是上一层的分割线(即还没有使用偏移量的分割线),此时在挪动了feature map上使用3x3的window(窗口)来进行分割,即如下图所示。
对于4个橙色的window区域,可以直接进行MSA操作,因为每个window内部数据本身是连续的,且通过window分割之后会发现每个window都能和融合上一层4个window的信息。
对于紫色的window区域,因为并不是连续的,所以需要用到具有masked的MSA
Relative Position Bias详解
关于相对位置偏执,使用了相对位置偏执后给够带来明显的提升。根据原论文中的表4可以看出,在Imagenet数据集上如果不使用任何位置偏执,top-1为80.1,但使用了相对位置偏执(rel. pos.)后top-1为83.3,提升还是很明显的。(w/o:without)
那这个相对位置偏执是加在哪的呢,根据论文中提供的公式可知是在Q和K进行匹配并除以$\sqrt d $后加上了相对位置偏执B。
如下图,假设输入的feature map高宽都为2,那么首先可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是( 0 , 0 ) 。接下来再看看相对位置索引,首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引
。例如黄色像素的绝对位置索引是( 0 , 1 ),则它相对蓝色像素的相对位置索引为(0,0)-(0,1)=(0,-1)
。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵
。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平
,并拼接在一起可以得到下面的4x4矩阵 。
请注意,这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为( 0 , − 1 ) 。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为( 0 , − 1 ) 。可以发现这两者的相对位置索引都是( 0 , − 1 ) ,所以他们使用的相对位置偏执参数都是一样的。
但在源码中作者为了方便把二维索引给转成了一维索引。首先在原始的相对位置索引上加上M-1(M为窗口的大小,在本示例中M=2),加上之后索引中就不会有负数了。
接着将所有的行标都乘上2M-1。
最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现仅仅将相对位置信息左简单相加而导致0 + ( − 1 ) = ( − 1 ) + 0的问题了。
这是一种从向量化索引角度出发的计算方式,不理解的用二维方式来理解:
- 二维的相对偏移不过就是上下、左右两类,这两类的偏移距离范围都是**[-M+1,M-1]**;
- 那么我们可以构建一个(2M-1)x(2M-1)大小的矩阵,其中的值表示position bias;
- 以上述构建的矩阵为table,针对不同相对位置,直接以(上下偏移距离,左右偏移距离)为索引去这个table里取值,就是我们要的relative position bias。
当把上述3步都理解之后,再用向量化的方式理解,就清晰为什么作者这么做了。
之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数$\widehat B$是保存在relative position bias table表里的,这个表的长度是等于(2M-1)x(2M-1)的。那么上述公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table
表得到的,如下图所示。
(2M-1)x(2M-1)
:行和列的范围都是[-M+1,M-1],以M = 2为例,范围都是[-1,1],可以取到的值为-1,0,1三种,所以进行行和列的排列组合之后一共会有3x3=9中可能,即对应(2M-1)x(2M-1),也就是下图中relative position bias table的个数。
注意:上图中得到的relative position bias才是最终要带入下图公式的B,也就是说最终训练过程中用到的参数实际上是relative position bias table对应的参数。而relative position index只要窗口大小是固定的,那么它本身也会是固定的。
模型详细配置参数
下图(表7)是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:
- concat 4x4,96-d,LN
:Patch Partition+Linear Embedding,相当于Stage2、3、4中的Patch Partition,都是对输入特征矩阵进行下采样以及调整特征矩阵的channel,再通过一个Linear Norm进行输出,
4x4即将高和宽下采样4倍,
96`即通过Linear Embedding之后特征矩阵channel变为96; win. sz. 7x7
:使用的窗口(Windows)的大小;dim
:feature map的channel深度(或者说token的向量长度),例如dim 96
即通过swin transformer block之后输出的特征矩阵的channel是96;head
:多头注意力模块中head的个数。
注意:Swin Transformer Block在堆叠过程当中是两两一组,即W-MSA和SW-MSA,所以堆叠Block都是偶数倍。