Deep Learning CNN Models (Part 26): MobileViT Explained and Built with PyTorch

MobileViT is a hybrid architecture that combines CNNs and Transformers. Original paper: MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

Understanding the Network Architecture

Preface

Problems with current pure Transformer models:

  • Many parameters and high compute requirements (for the ViT-L/16 model, the weights alone are over 1 GB);
  • No spatial inductive bias;
  • Transferring to other tasks is cumbersome (compared with CNNs);

Why is it cumbersome?

Mainly because of the positional encoding. Vision Transformer uses absolute positional embeddings, and the sequence length of the absolute positional embedding has to match the length of the input token sequence. In other words, once the input image size is fixed at training time, the sequence length of the absolute positional embedding is fixed as well. If the input image size is changed later, the token sequence generated from the image and the absolute positional embedding no longer have the same length, so they can no longer be added together or processed further.

The simplest existing remedy is interpolation: the absolute positional embedding is interpolated to the same length as the new input token sequence. But interpolation introduces a new problem: if the interpolated model is used directly, accuracy usually drops. A CNN trained at 224x224 and evaluated at 384x384 typically gains roughly one point on ImageNet, but a Transformer evaluated at a higher resolution with a simply interpolated positional embedding tends to lose points.

So after interpolating the absolute positional embedding, the model usually has to be fine-tuned as well. Having to re-interpolate and fine-tune every time the image size changes is rather tedious.
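
As a concrete illustration, here is a minimal sketch (not code from the paper) of how an absolute positional embedding can be interpolated to a new token grid. The function name and shapes are hypothetical and assume a ViT variant without a class token:

import torch
import torch.nn.functional as F

def resize_abs_pos_embed(pos_embed: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    # pos_embed: [1, old_grid * old_grid, C], i.e. one embedding per patch token
    c = pos_embed.shape[-1]
    p = pos_embed.reshape(1, old_grid, old_grid, c).permute(0, 3, 1, 2)  # [1, C, H, W]
    p = F.interpolate(p, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    return p.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, c)

# 224/16 = 14 tokens per side -> 384/16 = 24 tokens per side
new_pe = resize_abs_pos_embed(torch.randn(1, 14 * 14, 768), old_grid=14, new_grid=24)
print(new_pe.shape)  # torch.Size([1, 576, 768])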

One might suggest using the relative positional bias adopted in Swin Transformer. Indeed, Swin Transformer's relative positional bias is insensitive to the input image size and depends only on the window size. But if the image size used for training differs a lot from the image size of the downstream task, the window size is usually adjusted anyway. For example, pre-training on ImageNet may use 224x224 inputs, while transferring to object detection may use 1280x1280 inputs; that is a huge jump from 224 to 1280, and keeping the window size unchanged would hurt performance, so the window is usually made larger. Once the window size changes, however, the sequence length of the relative positional bias changes as well, and we run into the same problem described above.

So the positional encodings in use today still leave plenty of room for optimization. The Swin Transformer V2 paper, for instance, improves on the relative positional bias used in Swin Transformer V1.

  • Models are hard to train (compared with CNNs)

Experience so far shows that training a Transformer usually needs more training data, more epochs, stronger L2 regularization, and more data augmentation, and the result is quite sensitive to the augmentation used.

A good answer to the problems above is to mix CNN and Transformer architectures. A CNN naturally carries spatial inductive bias, so no separate positional encoding or positional bias is needed, and adding CNN components also speeds up convergence and makes the whole training process more stable.

Comparison of MobileViT with relatively lightweight ViT models that were mainstream at the time

In the figure above, Augmentation refers to two augmentation settings, basic and advanced. Basic means the simple augmentation used for ResNet, i.e. random cropping plus random horizontal flipping, while the advanced setting includes many more augmentation techniques.

Table (b) in the figure shows that even though MobileViT only uses the basic augmentation, it still reaches 74.8 and 78.4 Top-1 accuracy, which suggests that MobileViT is not very sensitive to data augmentation and has strong learning capacity.

Comparison of MobileViT with relatively lightweight and heavyweight CNN models that were mainstream at the time

The comparison above between MobileViT and both lightweight and heavyweight models shows that the CNN-plus-Transformer MobileViT indeed performs very well.

Comparison of the training speed of MobileViT and CNN networks

Model Structure Analysis

A Brief Look at the Vision Transformer Structure

This is the standard Vision Transformer structure given by the authors in the paper, and it differs slightly from the Vision Transformer covered earlier: most notably, there is no class token here. The class token was borrowed from BERT, but it is not strictly necessary for vision tasks, so the figure below shows a more standard Vision Transformer architecture for vision.

Standard visual transformer (ViT)

The input image is first divided into patches, each patch is flattened, and the flattened patch is passed through a linear projection to obtain the token for that patch (each token is simply a vector). Collecting these tokens gives a token sequence (in the actual implementation, this flatten-plus-linear-projection step can be done directly with a single convolution). A positional encoding, or positional bias, is then added (absolute or relative), the sequence goes through L Transformer Blocks (a global pooling layer can be inserted between the Transformer Blocks and the fully connected layer), and a fully connected layer finally produces the output.
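
As a minimal sketch of the remark that flatten-plus-linear-projection can be implemented with a single convolution (the patch size and embedding dimension below are the usual ViT-Base values and are purely illustrative):

import torch
import torch.nn as nn

patch_size, in_chans, embed_dim = 16, 3, 768  # illustrative ViT-style values
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, in_chans, 224, 224)
tokens = proj(x).flatten(2).transpose(1, 2)  # [1, 196, 768]: one token per 16x16 patch
print(tokens.shape)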

MobileViT

Overall Architecture

MobileViT visual transformer

MV2

MV2 is the Inverted Residual Block proposed in MobileNetV2. Some MV2 blocks in the diagram carry a downward arrow, which means that block downsamples the feature map.
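
For a quick shape check of a downsampling MV2 block, using the InvertedResidual class implemented later in this post (the channel and expansion values come from the xx_small configuration shown at the end; model.py from this project is assumed to be importable):

import torch
from model import InvertedResidual

# stride=2 corresponds to the MV2 blocks drawn with a downward arrow: H and W are halved
blk = InvertedResidual(in_channels=16, out_channels=24, stride=2, expand_ratio=2)
print(blk(torch.randn(1, 16, 112, 112)).shape)  # torch.Size([1, 24, 56, 56])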

MobileViT block

A feature map of size $H \times W \times C$ is fed in, and a local representation is computed first (Local representations is implemented with an $n \times n$ convolution; in the code it is a 3x3 convolution followed by a 1x1 convolution that adjusts the number of channels).

Next comes the global representation (Global representations consists of an Unfold, L Transformer Blocks, and a Fold that turns the result back into a feature map).

Then a 1x1 convolution adjusts the number of channels back to C, matching the input feature map. A shortcut concatenates the resulting feature map with the input, giving 2C channels, and an $n \times n$ convolution fuses the features (in the source code n is 3, i.e. a 3x3 convolution).

That is the entire MobileViT block; its core is the global representation part. A quick shape check of the whole block is sketched below.
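
Here is that shape check, using the MobileViTBlock class implemented later (the values follow the xx_small layer3 configuration; model.py and transformer.py from this project are assumed to be importable):

import torch
from model import MobileViTBlock

blk = MobileViTBlock(in_channels=48, transformer_dim=64, ffn_dim=128,
                     n_transformer_blocks=2, head_dim=16, patch_h=2, patch_w=2)
x = torch.randn(1, 48, 28, 28)
print(blk(x).shape)  # torch.Size([1, 48, 28, 28]) - same spatial size and channels as the input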

MV2 and MobileViT Block

The Transformer inside the Global Representation

In the middle of the figure below, channels are omitted for simplicity. For a standard transformer block (transformer encoder), the feature map is usually flattened directly into a sequence and fed into the transformer block.

In self-attention, each pixel (each token) attends to every other token. MobileViT, however, does not do it that way.

The input feature map is first divided into patches; the middle of the figure below uses 2x2 patches as an example.

After the division, attention is computed only among the tokens that sit at the same position within each patch, i.e. only the same-colored tokens in the middle of the figure attend to each other. This reduces the computational cost of attention.

With the original self-attention (every token attends to all tokens), the cost of computing attention between one token and all other tokens is $HWC$, since it has to attend to every one of the $HW$ tokens;

in MobileViT, only the same-colored tokens attend to each other. With 2x2 patches, as in the figure below, the per-token cost becomes $\frac{HWC}{4}$; because the patch size is 2x2, the cost drops to $\frac{1}{4}$ of the original.

Note that this only reduces the cost of the self-attention itself; the cost of the remaining parts of the transformer block (transformer encoder) is unchanged, because layers such as Norm and the MLP on the left of the figure operate on each token independently. A rough numerical check follows.
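
A rough back-of-the-envelope check of that saving, counting only the per-token cost of the similarity scores (the 32x32 / 96-channel sizes are arbitrary example values, not from the paper):

H, W, C = 32, 32, 96          # example feature-map size
standard = H * W * C          # one token attends to all H*W tokens
mobilevit = (H * W // 4) * C  # 2x2 patches: only tokens at the same in-patch position
print(standard, mobilevit, standard / mobilevit)  # 98304 24576 4.0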

Why is this acceptable?

Because images contain a great deal of redundant information, especially at high resolution. On relatively low-level feature maps, i.e. when H and W are large, the difference in information between neighboring pixels is small, and having every token look at every other token during self-attention wastes compute.

That does not mean attending to neighboring pixels or tokens is useless; it is just that on high-resolution feature maps the benefit is small, and the extra compute far outweighs the accuracy gain. Besides, the Local representations step has already done local modeling before the global representation, so the global step does not need to be that fine-grained.

The Transformer block in the global representation

Unfold and Fold in the Global Representation

In MobileViT, only the same-colored tokens attend to each other; tokens of different colors exchange no information. So the Unfold in the paper gathers the same-colored tokens into one sequence each. With 2x2 patches, for example, Unfold produces 4 sequences.

Each sequence is then fed into the Transformer Block for global modeling. The sequences can be processed in parallel when they enter the Transformer Block, so this step is quite fast.

Finally, Fold folds these features back into the original feature-map layout.

So the Unfold and Fold in the global representation are nothing more than splitting the feature map apart and folding it back together, as the toy example below illustrates.
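
A toy illustration of what Unfold does, using a single-channel 4x4 map and 2x2 patches. This is a simplified reshaping for illustration only; the actual unfolding method further below arranges the batch and patch dimensions slightly differently:

import torch

x = torch.arange(16).reshape(1, 1, 4, 4)          # [B, C, H, W]
p = x.reshape(1, 1, 2, 2, 2, 2)                   # [B, C, n_h, p_h, n_w, p_w]
p = p.permute(0, 3, 5, 2, 4, 1).reshape(4, 4, 1)  # [P, N, C]: 4 sequences of 4 tokens each
print(p[:, :, 0])
# tensor([[ 0,  2,  8, 10],
#         [ 1,  3,  9, 11],
#         [ 4,  6, 12, 14],
#         [ 5,  7, 13, 15]])
# each row gathers the tokens at the same position inside every 2x2 patch ("the same color")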

Unfold and Fold in the global representation

Effect of Patch Size on Performance

The authors ran two sets of comparison experiments with patch sizes of 8, 4, 2 and 2, 2, 2, where the three numbers correspond to the feature maps downsampled by 8x, 16x, and 32x. As shown in the figure below, the comparison covers classification, object detection, and segmentation. The x-axis is inference time (smaller is better) and the y-axis is the metric of each task (generally larger is better), so points closer to the upper left indicate better overall performance.

Effect of patch size on performance (two experiment settings)

Detailed Model Configurations

There are three model configurations in total:

  • MobileViT-S (small);
  • MobileViT-XS (extra small);
  • MobileViT-XXS (extra extra small)

Model configuration: MobileViT-XXS

  • out_channels: the number of channels of the feature map output by each layer;
  • mv2_exp: the expansion ratio in the Inverted Residual blocks;
  • transformer_channels: the vector length of each token fed into the transformer block, i.e. the channel number of the feature map that enters the transformer;
  • ffn_dim: the number of nodes in the hidden layer of the MLP inside the transformer block;
  • patch_h / patch_w: the patch size;
  • num_heads: the number of heads in the Multi-Head Self-Attention inside the transformer block (see the small arithmetic check after this list).
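
For example, taking the layer3 values of the xx_small configuration shown later in model_config.py, the per-head dimension works out as:

transformer_channels, num_heads = 64, 4       # "layer3" entry of the xx_small config below
head_dim = transformer_channels // num_heads  # 16: the dimension handled by each attention head
print(head_dim)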

Building the Model with PyTorch

Project Layout

├── Test15_MobileViT
├── model.py (model definition)
├── my_dataset.py (dataset loading and preprocessing)
├── train.py (trains the model; automatically generates class_indices.json and the trained .pth weight file)
├── predict.py (runs prediction with the trained model)
├── model_config.py (detailed parameters of the three model variants)
├── transformer.py
├── utils.py (utility functions used throughout)
├── tulip.jpg (sample image used to test prediction with the trained weights)
└── mobilevit_xxs.pt (pre-trained mobilevit_xxs.pt weights, downloaded in advance for transfer learning)
└── data_set
└── data (the dataset)

model

"""
original code from apple:
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
"""

from typing import Optional, Tuple, Union, Dict
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from transformer import TransformerEncoder
from model_config import get_config


def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
# ......


class ConvLayer(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = 1,
groups: Optional[int] = 1,
bias: Optional[bool] = False,
use_norm: Optional[bool] = True,
use_act: Optional[bool] = True,
) -> None:
# ......

def forward(self, x: Tensor) -> Tensor:
# ......


class InvertedResidual(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
expand_ratio: Union[int, float],
skip_connection: Optional[bool] = True,
) -> None:
# ......

def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
# ......


class MobileViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
transformer_dim: int,
ffn_dim: int,
n_transformer_blocks: int = 2,
head_dim: int = 32,
attn_dropout: float = 0.0,
dropout: float = 0.0,
ffn_dropout: float = 0.0,
patch_h: int = 8,
patch_w: int = 8,
conv_ksize: Optional[int] = 3,
*args,
**kwargs
) -> None:
# ......

def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
# ......

def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
# ......

def forward(self, x: Tensor) -> Tensor:
# ......


class MobileViT(nn.Module):
def __init__(self, model_cfg: Dict, num_classes: int = 1000):
# ......

def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
# ......

@staticmethod
def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
# ......

@staticmethod
def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
# ......

@staticmethod
def init_parameters(m):
# ......

def forward(self, x: Tensor) -> Tensor:
# ......


def mobile_vit_xx_small(num_classes: int = 1000):
# ......


def mobile_vit_x_small(num_classes: int = 1000):
# ......


def mobile_vit_small(num_classes: int = 1000):
# ......

The ConvLayer class

class ConvLayer(nn.Module):
"""
Applies a 2D convolution over an input
Args:
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
groups (Optional[int]): Number of groups in convolution. Default: 1
bias (Optional[bool]): Use bias. Default: ``False``
use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
Default: ``True``
Shape:
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
.. note::
For depth-wise convolution, `groups=C_{in}=C_{out}`.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = 1,
groups: Optional[int] = 1,
bias: Optional[bool] = False,
use_norm: Optional[bool] = True,
use_act: Optional[bool] = True,
) -> None:
super().__init__()

if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)

if isinstance(stride, int):
stride = (stride, stride)

assert isinstance(kernel_size, Tuple)
assert isinstance(stride, Tuple)

padding = (
int((kernel_size[0] - 1) / 2),
int((kernel_size[1] - 1) / 2),
)

block = nn.Sequential()

conv_layer = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
groups=groups,
padding=padding,
bias=bias
)

block.add_module(name="conv", module=conv_layer)

if use_norm:
norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
block.add_module(name="norm", module=norm_layer)

if use_act:
act_layer = nn.SiLU()
block.add_module(name="act", module=act_layer)

# expose the assembled nn.Sequential block
self.block = block

def forward(self, x: Tensor) -> Tensor:
return self.block(x)

MV2 (the InvertedResidual class)

MV2 and MobileViT Block

skip_connection: whether to use the shortcut connection

hidden_dim: the number of channels the feature map is expanded to by the first 1x1 convolution (computed with the make_divisible helper sketched below)
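
The make_divisible helper is elided ("# ......") in the model.py skeleton above. A commonly used form of this helper in the MobileNet family of code bases looks roughly like the sketch below; treat it as an assumption rather than a byte-for-byte copy of the original file:

def make_divisible(v, divisor=8, min_value=None):
    # round v to the nearest multiple of `divisor`, never reducing it by more than 10%
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v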

class InvertedResidual(nn.Module):
"""
This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper
Args:
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
stride (int): Use convolutions with a stride. Default: 1
expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
skip_connection (Optional[bool]): Use skip-connection. Default: True
Shape:
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
.. note::
If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
"""

def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
expand_ratio: Union[int, float],
skip_connection: Optional[bool] = True,
) -> None:
assert stride in [1, 2]
hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)

super().__init__()

block = nn.Sequential()
if expand_ratio != 1:
block.add_module(
name="exp_1x1",
module=ConvLayer(
in_channels=in_channels,
out_channels=hidden_dim,
kernel_size=1
),
)

block.add_module(
name="conv_3x3",
module=ConvLayer(
in_channels=hidden_dim,
out_channels=hidden_dim,
stride=stride,
kernel_size=3,
groups=hidden_dim
),
)

block.add_module(
# pw-linear: 1x1 projection conv (the depth-wise conv is the conv_3x3 block added above)
name="red_1x1",
module=ConvLayer(
in_channels=hidden_dim,
out_channels=out_channels,
kernel_size=1,
use_act=False,
use_norm=True,
),
)

self.block = block
self.in_channels = in_channels
self.out_channels = out_channels
self.exp = expand_ratio
self.stride = stride
self.use_res_connect = (
self.stride == 1 and in_channels == out_channels and skip_connection
)

def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
if self.use_res_connect:
return x + self.block(x)
else:
return self.block(x)

MobileViTBlock

MobileViT Block

transformer_dim: the vector length (embedding dimension) of each token fed into the Transformer Encoder block;

ffn_dim: the number of nodes in the first fully connected layer of the MLP inside the Transformer Encoder block;

n_transformer_blocks: how many Transformer Encoder blocks are stacked in the global representations;

head_dim: the dimension handled by each head in the Multi-Head Self-Attention;

class MobileViTBlock(nn.Module):
"""
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
Args:
opts: command line arguments
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
transformer_dim (int): Input dimension to the transformer unit
ffn_dim (int): Dimension of the FFN block
n_transformer_blocks (int): Number of transformer blocks. Default: 2
head_dim (int): Head dimension in the multi-head attention. Default: 32
attn_dropout (float): Dropout in multi-head attention. Default: 0.0
dropout (float): Dropout rate. Default: 0.0
ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
patch_h (int): Patch height for unfolding operation. Default: 8
patch_w (int): Patch width for unfolding operation. Default: 8
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
"""

def __init__(
self,
in_channels: int,
transformer_dim: int,
ffn_dim: int,
n_transformer_blocks: int = 2,
head_dim: int = 32,
attn_dropout: float = 0.0,
dropout: float = 0.0,
ffn_dropout: float = 0.0,
patch_h: int = 8,
patch_w: int = 8,
conv_ksize: Optional[int] = 3,
*args,
**kwargs
) -> None:
super().__init__()

conv_3x3_in = ConvLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
stride=1
)
conv_1x1_in = ConvLayer(
in_channels=in_channels,
out_channels=transformer_dim,
kernel_size=1,
stride=1,
use_norm=False,
use_act=False
)

conv_1x1_out = ConvLayer(
in_channels=transformer_dim,
out_channels=in_channels,
kernel_size=1,
stride=1
)
conv_3x3_out = ConvLayer(
in_channels=2 * in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
stride=1
)

self.local_rep = nn.Sequential()
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

assert transformer_dim % head_dim == 0
num_heads = transformer_dim // head_dim

global_rep = [
TransformerEncoder(
embed_dim=transformer_dim,
ffn_latent_dim=ffn_dim,
num_heads=num_heads,
attn_dropout=attn_dropout,
dropout=dropout,
ffn_dropout=ffn_dropout
)
for _ in range(n_transformer_blocks)
]
global_rep.append(nn.LayerNorm(transformer_dim))
self.global_rep = nn.Sequential(*global_rep)

self.conv_proj = conv_1x1_out
self.fusion = conv_3x3_out

self.patch_h = patch_h
self.patch_w = patch_w
self.patch_area = self.patch_w * self.patch_h

self.cnn_in_dim = in_channels
self.cnn_out_dim = transformer_dim
self.n_heads = num_heads
self.ffn_dim = ffn_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.ffn_dropout = ffn_dropout
self.n_blocks = n_transformer_blocks
self.conv_ksize = conv_ksize

The unfolding method

def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
patch_w, patch_h = self.patch_w, self.patch_h
patch_area = patch_w * patch_h
batch_size, in_channels, orig_h, orig_w = x.shape

# round up to a multiple of the patch size
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

# interpolate the feature map to new_h x new_w so that it can be evenly divided into patches
interpolate = False
if new_w != orig_w or new_h != orig_h:
# Note: Padding can be done, but then it needs to be handled in attention function.
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
interpolate = True

# number of patches along width and height
num_patch_w = new_w // patch_w # n_w
num_patch_h = new_h // patch_h # n_h
num_patches = num_patch_h * num_patch_w # N

# gather the tokens of the same color (the same position inside each patch)
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
x = x.transpose(1, 2)
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
# [B, C, N, P] -> [B, P, N, C]
x = x.transpose(1, 3)
# [B, P, N, C] -> [BP, N, C]
x = x.reshape(batch_size * patch_area, num_patches, -1)

info_dict = {
"orig_size": (orig_h, orig_w),
"batch_size": batch_size,
"interpolate": interpolate,
"total_patches": num_patches,
"num_patches_w": num_patch_w,
"num_patches_h": num_patch_h,
}

return x, info_dict

The folding method

def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
n_dim = x.dim()
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
x.shape
)
# [BP, N, C] --> [B, P, N, C]
x = x.contiguous().view(
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
)

batch_size, pixels, num_patches, channels = x.size()
num_patch_h = info_dict["num_patches_h"]
num_patch_w = info_dict["num_patches_w"]

# [B, P, N, C] -> [B, C, N, P]
x = x.transpose(1, 3)
# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
x = x.transpose(1, 2)
# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
if info_dict["interpolate"]:
x = F.interpolate(
x,
size=info_dict["orig_size"],
mode="bilinear",
align_corners=False,
)
return x

The forward propagation method

def forward(self, x: Tensor) -> Tensor:
res = x

fm = self.local_rep(x)

# convert feature map to patches
patches, info_dict = self.unfolding(fm)

# learn global representations
for transformer_layer in self.global_rep:
patches = transformer_layer(patches)

# [B x Patch x Patches x C] -> [B x C x Patches x Patch]
fm = self.folding(x=patches, info_dict=info_dict)

fm = self.conv_proj(fm)

fm = self.fusion(torch.cat((res, fm), dim=1))
return fm

The MobileViT class

class MobileViT(nn.Module):
"""
This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
"""
def __init__(self, model_cfg: Dict, num_classes: int = 1000):
super().__init__()

image_channels = 3
out_channels = 16

self.conv_1 = ConvLayer(
in_channels=image_channels,
out_channels=out_channels,
kernel_size=3,
stride=2
)

self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])

exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
self.conv_1x1_exp = ConvLayer(
in_channels=out_channels,
out_channels=exp_channels,
kernel_size=1
)

self.classifier = nn.Sequential()
self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
self.classifier.add_module(name="flatten", module=nn.Flatten())
if 0.0 < model_cfg["cls_dropout"] < 1.0:
self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))

# weight init
self.apply(self.init_parameters)

The _make_layer method

def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
block_type = cfg.get("block_type", "mobilevit")
if block_type.lower() == "mobilevit":
return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
else:
return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)

The _make_mobilenet_layer method

@staticmethod
def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
output_channels = cfg.get("out_channels")
num_blocks = cfg.get("num_blocks", 2)
expand_ratio = cfg.get("expand_ratio", 4)
block = []

for i in range(num_blocks):
stride = cfg.get("stride", 1) if i == 0 else 1

layer = InvertedResidual(
in_channels=input_channel,
out_channels=output_channels,
stride=stride,
expand_ratio=expand_ratio
)
block.append(layer)
input_channel = output_channels

return nn.Sequential(*block), input_channel

The _make_mit_layer method

@staticmethod
def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
stride = cfg.get("stride", 1)
block = []

if stride == 2:
layer = InvertedResidual(
in_channels=input_channel,
out_channels=cfg.get("out_channels"),
stride=stride,
expand_ratio=cfg.get("mv_expand_ratio", 4)
)

block.append(layer)
input_channel = cfg.get("out_channels")

transformer_dim = cfg["transformer_channels"]
ffn_dim = cfg.get("ffn_dim")
num_heads = cfg.get("num_heads", 4)
head_dim = transformer_dim // num_heads

if transformer_dim % head_dim != 0:
raise ValueError("Transformer input dimension should be divisible by head dimension. "
"Got {} and {}.".format(transformer_dim, head_dim))

block.append(MobileViTBlock(
in_channels=input_channel,
transformer_dim=transformer_dim,
ffn_dim=ffn_dim,
n_transformer_blocks=cfg.get("transformer_blocks", 1),
patch_h=cfg.get("patch_h", 2),
patch_w=cfg.get("patch_w", 2),
dropout=cfg.get("dropout", 0.1),
ffn_dropout=cfg.get("ffn_dropout", 0.0),
attn_dropout=cfg.get("attn_dropout", 0.1),
head_dim=head_dim,
conv_ksize=3
))

return nn.Sequential(*block), input_channel

The init_parameters method

@staticmethod
def init_parameters(m):
if isinstance(m, nn.Conv2d):
if m.weight is not None:
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
if m.weight is not None:
nn.init.ones_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.Linear,)):
if m.weight is not None:
nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
else:
pass

The forward method

def forward(self, x: Tensor) -> Tensor:
x = self.conv_1(x)
x = self.layer_1(x)
x = self.layer_2(x)

x = self.layer_3(x)
x = self.layer_4(x)
x = self.layer_5(x)
x = self.conv_1x1_exp(x)
x = self.classifier(x)
return x
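
A quick end-to-end shape check of the assembled model, using one of the factory functions listed in the model.py skeleton above (mobile_vit_xx_small):

import torch
from model import mobile_vit_xx_small

model = mobile_vit_xx_small(num_classes=5)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 5])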

transformer

from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
"""
This layer applies a multi-head self- or cross-attention as described in
`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
Args:
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
num_heads (int): Number of heads in multi-head attention
attn_dropout (float): Attention dropout. Default: 0.0
bias (bool): Use bias or not. Default: ``True``
Shape:
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
and :math:`C_{in}` is input embedding dim
- Output: same shape as the input
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
attn_dropout: float = 0.0,
bias: bool = True,
*args,
**kwargs
) -> None:
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
"Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
self.__class__.__name__, embed_dim, num_heads
)
)

self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

self.attn_dropout = nn.Dropout(p=attn_dropout)
self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5
self.softmax = nn.Softmax(dim=-1)
self.num_heads = num_heads
self.embed_dim = embed_dim

def forward(self, x_q: Tensor) -> Tensor:
# [N, P, C]
b_sz, n_patches, in_channels = x_q.shape

# self-attention
# [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

# [N, P, 3, h, c] -> [N, h, 3, P, C]
qkv = qkv.transpose(1, 3).contiguous()

# [N, h, 3, P, C] -> [N, h, P, C] x 3
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

query = query * self.scaling

# [N, h, P, c] -> [N, h, c, P]
key = key.transpose(-1, -2)

# QK^T
# [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
attn = torch.matmul(query, key)
attn = self.softmax(attn)
attn = self.attn_dropout(attn)

# weighted sum
# [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
out = torch.matmul(attn, value)

# [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
out = self.out_proj(out)

return out


class TransformerEncoder(nn.Module):
"""
This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
Args:
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
ffn_latent_dim (int): Inner dimension of the FFN
num_heads (int) : Number of heads in multi-head attention. Default: 8
attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
dropout (float): Dropout rate. Default: 0.0
ffn_dropout (float): Dropout between FFN layers. Default: 0.0
Shape:
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
and :math:`C_{in}` is input embedding dim
- Output: same shape as the input
"""

def __init__(
self,
embed_dim: int,
ffn_latent_dim: int,
num_heads: Optional[int] = 8,
attn_dropout: Optional[float] = 0.0,
dropout: Optional[float] = 0.0,
ffn_dropout: Optional[float] = 0.0,
*args,
**kwargs
) -> None:
super().__init__()

attn_unit = MultiHeadAttention(
embed_dim,
num_heads,
attn_dropout=attn_dropout,
bias=True
)

self.pre_norm_mha = nn.Sequential(
nn.LayerNorm(embed_dim),
attn_unit,
nn.Dropout(p=dropout)
)

self.pre_norm_ffn = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
nn.SiLU(),
nn.Dropout(p=ffn_dropout),
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
nn.Dropout(p=dropout)
)
self.embed_dim = embed_dim
self.ffn_dim = ffn_latent_dim
self.ffn_dropout = ffn_dropout
self.std_dropout = dropout

def forward(self, x: Tensor) -> Tensor:
# multi-head attention
res = x
x = self.pre_norm_mha(x)
x = x + res

# feed forward network
x = x + self.pre_norm_ffn(x)
return x
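
A quick shape check of the TransformerEncoder above. The sizes follow the xx_small layer3 setting (64-dim tokens, 4 heads); with 2x2 patches on a 32x32 feature map and a batch of 8, unfolding would produce a [B*P, N, C] = [32, 256, 64] tensor:

import torch
from transformer import TransformerEncoder

enc = TransformerEncoder(embed_dim=64, ffn_latent_dim=128, num_heads=4)
x = torch.randn(8 * 4, 256, 64)  # [B*P, N, C] as produced by unfolding
print(enc(x).shape)              # torch.Size([32, 256, 64]) - same shape as the input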

unfold_test

The author of the original tutorial rewrote the part of the code that pulls out the same-colored tokens, which makes it easier to understand. (I follow this step from the figure, but I have not yet fully worked out how the code concatenates the same-colored tokens into one sequence exactly as drawn in the figure.)

import time
import torch

batch_size = 8
in_channels = 32
patch_h = 2
patch_w = 2
num_patch_h = 16
num_patch_w = 16
num_patches = num_patch_h * num_patch_w
patch_area = patch_h * patch_w


def official(x: torch.Tensor):
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
x = x.transpose(1, 2)
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
# [B, C, N, P] -> [B, P, N, C]
x = x.transpose(1, 3)
# [B, P, N, C] -> [BP, N, C]
x = x.reshape(batch_size * patch_area, num_patches, -1)

return x


def my_self(x: torch.Tensor):
# [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
# [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
x = x.transpose(3, 4)
# [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
# [B, C, N, P] -> [B, P, N, C]
x = x.transpose(1, 3)
# [B, P, N, C] -> [BP, N, C]
x = x.reshape(batch_size * patch_area, num_patches, -1)

return x


if __name__ == '__main__':
t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
print(torch.equal(official(t), my_self(t)))

t1 = time.time()
for _ in range(1000):
official(t)
print(f"official time: {time.time() - t1}")

t1 = time.time()
for _ in range(1000):
my_self(t)
print(f"self time: {time.time() - t1}")

model_config

Model configuration: MobileViT-XXS

def get_config(mode: str = "xxs") -> dict:
if mode == "xx_small":
mv2_exp_mult = 2
config = {
"layer1": {
"out_channels": 16,
"expand_ratio": mv2_exp_mult,
"num_blocks": 1,
"stride": 1,
"block_type": "mv2",
},
"layer2": {
"out_channels": 24,
"expand_ratio": mv2_exp_mult,
"num_blocks": 3,
"stride": 2,
"block_type": "mv2",
},
"layer3": { # 28x28
"out_channels": 48,
"transformer_channels": 64,
"ffn_dim": 128,
"transformer_blocks": 2,
"patch_h": 2, # 8,
"patch_w": 2, # 8,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer4": { # 14x14
"out_channels": 64,
"transformer_channels": 80,
"ffn_dim": 160,
"transformer_blocks": 4,
"patch_h": 2, # 4,
"patch_w": 2, # 4,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer5": { # 7x7
"out_channels": 80,
"transformer_channels": 96,
"ffn_dim": 192,
"transformer_blocks": 3,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"last_layer_exp_factor": 4,
"cls_dropout": 0.1
}
elif mode == "x_small":
mv2_exp_mult = 4
config = {
"layer1": {
"out_channels": 32,
"expand_ratio": mv2_exp_mult,
"num_blocks": 1,
"stride": 1,
"block_type": "mv2",
},
"layer2": {
"out_channels": 48,
"expand_ratio": mv2_exp_mult,
"num_blocks": 3,
"stride": 2,
"block_type": "mv2",
},
"layer3": { # 28x28
"out_channels": 64,
"transformer_channels": 96,
"ffn_dim": 192,
"transformer_blocks": 2,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer4": { # 14x14
"out_channels": 80,
"transformer_channels": 120,
"ffn_dim": 240,
"transformer_blocks": 4,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer5": { # 7x7
"out_channels": 96,
"transformer_channels": 144,
"ffn_dim": 288,
"transformer_blocks": 3,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"last_layer_exp_factor": 4,
"cls_dropout": 0.1
}
elif mode == "small":
mv2_exp_mult = 4
config = {
"layer1": {
"out_channels": 32,
"expand_ratio": mv2_exp_mult,
"num_blocks": 1,
"stride": 1,
"block_type": "mv2",
},
"layer2": {
"out_channels": 64,
"expand_ratio": mv2_exp_mult,
"num_blocks": 3,
"stride": 2,
"block_type": "mv2",
},
"layer3": { # 28x28
"out_channels": 96,
"transformer_channels": 144,
"ffn_dim": 288,
"transformer_blocks": 2,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer4": { # 14x14
"out_channels": 128,
"transformer_channels": 192,
"ffn_dim": 384,
"transformer_blocks": 4,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"layer5": { # 7x7
"out_channels": 160,
"transformer_channels": 240,
"ffn_dim": 480,
"transformer_blocks": 3,
"patch_h": 2,
"patch_w": 2,
"stride": 2,
"mv_expand_ratio": mv2_exp_mult,
"num_heads": 4,
"block_type": "mobilevit",
},
"last_layer_exp_factor": 4,
"cls_dropout": 0.1
}
else:
raise NotImplementedError

for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

return config
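
A small usage check of get_config; the printed values are the ones visible in the xx_small branch above:

from model_config import get_config

cfg = get_config("xx_small")
print(cfg["layer3"]["transformer_channels"])               # 64
print(cfg["layer3"]["patch_h"], cfg["layer3"]["patch_w"])  # 2 2
print(cfg["last_layer_exp_factor"], cfg["cls_dropout"])    # 4 0.1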

train

import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import mobile_vit_xx_small as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

if os.path.exists("./weights") is False:
os.makedirs("./weights")

tb_writer = SummaryWriter()

train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

img_size = 224
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

# instantiate the training dataset
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])

# instantiate the validation dataset
val_dataset = MyDataSet(images_path=val_images_path,
images_class=val_images_label,
transform=data_transform["val"])

batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=nw,
collate_fn=val_dataset.collate_fn)

model = create_model(num_classes=args.num_classes).to(device)

if args.weights != "":
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
weights_dict = torch.load(args.weights, map_location=device)
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
# drop the classifier weights (the number of classes differs)
for k in list(weights_dict.keys()):
if "classifier" in k:
del weights_dict[k]
print(model.load_state_dict(weights_dict, strict=False))

if args.freeze_layers:
for name, para in model.named_parameters():
# freeze all weights except the classifier head
if "classifier" not in name:
para.requires_grad_(False)
else:
print("training {}".format(name))

pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)

best_acc = 0.
for epoch in range(args.epochs):
# train
train_loss, train_acc = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch)

# validate
val_loss, val_acc = evaluate(model=model,
data_loader=val_loader,
device=device,
epoch=epoch)

tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
tb_writer.add_scalar(tags[0], train_loss, epoch)
tb_writer.add_scalar(tags[1], train_acc, epoch)
tb_writer.add_scalar(tags[2], val_loss, epoch)
tb_writer.add_scalar(tags[3], val_acc, epoch)
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "./weights/best_model.pth")

torch.save(model.state_dict(), "./weights/latest_model.pth")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--lr', type=float, default=0.0002)

# root directory of the dataset
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="D:/python_test/deep-learning-for-image-processing/data_set/flower_data/flower_photos")

# path to pre-trained weights; set to an empty string to skip loading
parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
help='initial weights path')
# whether to freeze weights
parser.add_argument('--freeze-layers', type=bool, default=False)
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

opt = parser.parse_args()

main(opt)

Training Results

Training results

predict

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import mobile_vit_xx_small as create_model


def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

img_size = 224
data_transform = transforms.Compose(
[transforms.Resize(int(img_size * 1.14)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img_path = "tulip.jpg"
assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
img = Image.open(img_path)
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

with open(json_path, "r") as f:
class_indict = json.load(f)

# create model
model = create_model(num_classes=5).to(device)
# load model weights
model_weight_path = "./weights/best_model.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()

print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
predict[i].numpy()))
plt.show()


if __name__ == '__main__':
main()

Prediction Results

Prediction results