Deep Learning Models, CNN Series (22): Building a Vision Transformer (ViT) Model with PyTorch

Project layout

├── vision transformer
│   ├── vit_model.py (model definition)
│   ├── my_dataset.py (dataset / data loading)
│   ├── train.py (training script; automatically generates class_indices.json and the trained weight files)
│   ├── predict.py (inference script)
│   ├── utils.py (helper functions used by train.py and predict.py)
│   ├── tulip.jpg (sample image used to test predict.py against the trained model)
│   └── vit_base_patch16_224_in21k.pth (pre-downloaded ImageNet-21k pretrained weights for transfer learning)
└── data_set
    └── data (the dataset)

vit_model.py

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
# ......

class DropPath(nn.Module):
def __init__(self, drop_prob=None):
# ......

def forward(self, x):
# ......

class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
# ......

def forward(self, x):
# ......

class Attention(nn.Module):
def __init__(self,
dim, # 输入token的dim
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_ratio=0.,
proj_drop_ratio=0.):
# ......

def forward(self, x):
# ......

class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
# ......

def forward(self, x):
# ......

class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_ratio=0.,
attn_drop_ratio=0.,
drop_path_ratio=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
# ......

def forward(self, x):
# ......

class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None):
# ......

def forward_features(self, x):
# ......

def forward(self, x):
# ......

def _init_vit_weights(m):
# ......

def vit_base_patch16_224(num_classes: int = 1000):
# ......

def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
# ......

def vit_base_patch32_224(num_classes: int = 1000):
# ......

def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
# ......

def vit_large_patch16_224(num_classes: int = 1000):
# ......

def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
# ......

def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
# ......

def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
# ......

The DropPath class

In its forward pass, the DropPath class simply calls the drop_path function. This is Stochastic Depth, which was explained in detail in the EfficientNet post; the implementation is carried over here unchanged.

During the forward pass the input feature map goes through one block after another, and each block can be viewed as a residual structure: the main branch produces its output through some function $f$, while the shortcut routes the input straight to the output. With a certain probability the whole main branch is dropped, which amounts to passing the previous layer's output directly to the next layer, as if this layer did not exist. This is Stochastic Depth (the "depth" refers to the network depth, because entire blocks are dropped at random).

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
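
As a quick, informal sanity check (not part of the original repository), the snippet below shows DropPath acting per sample: with drop_prob=0.25 and the module in training mode, roughly a quarter of the samples have their residual branch zeroed, while the survivors are scaled by 1/keep_prob so the expected value is preserved; in eval mode it is the identity.

import torch

torch.manual_seed(0)
dp = DropPath(drop_prob=0.25)
dp.train()                       # drop_path is only active in training mode
x = torch.ones(8, 197, 768)      # [batch, tokens, dim]
out = dp(x)
print(out[:, 0, 0])              # each sample is either 0 or 1 / 0.75
dp.eval()
print(torch.equal(dp(x), x))     # True: identity at inference time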

The PatchEmbed class

This corresponds to the Patch Embedding layer: a convolution with a 16x16 kernel and stride 16, followed by a flatten operation, which maps the input RGB image from [224, 224, 3] to [14, 14, 768] and then to [196, 768].

(Figure: the Patch Embedding layer)

The initializer

  • grid_size: img_size divided by patch_size, i.e. 224 // 16 = 14, so grid_size is 14x14, the spatial size of the feature map produced by the convolution;
  • num_patches: the number of patches, 14 x 14 = 196;
  • proj: the projection convolution layer;
  • norm: norm_layer defaults to None; if a norm_layer is passed in it is instantiated here, otherwise nn.Identity() is used, i.e. no normalization is applied.

The forward pass

The feature map is passed into the forward function, which first checks the input x: if the height or width of the input differs from the preset value, the assertion fails and the program stops with an error.

Note: unlike the CNN models discussed earlier, this ViT model cannot simply accept a different input image size. The input resolution must be fixed, because the later fully connected layers are not given any special handling for other sizes.

The data then passes through the convolution, producing a tensor of shape [B, C, H, W] (batch, channel, height, width). flatten(2) flattens dimension 2 and everything after it, i.e. H and W, so [B, C, H, W] -> [B, C, HW]; transpose(1, 2) then swaps dimensions 1 and 2, i.e. C and HW, giving [B, C, HW] -> [B, HW, C].

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # 14x14
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        # 14x14 = 196
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # equivalent to splitting the whole image into a 14x14 grid of 16x16 patches
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # norm_layer is None by default, so nothing is done here
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        # [batch, 3, 224, 224]
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW] == [B, 768, 14, 14] -> [B, 768, 196]
        # transpose: [B, C, HW] -> [B, HW, C] == [B, 768, 196] -> [B, 196, 768]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
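
The shape transformation described above can be verified with a quick illustrative check (a dummy batch, not from the original code):

import torch

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)      # [B, 3, 224, 224]
print(patch_embed(dummy).shape)          # torch.Size([2, 196, 768])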

The Attention class

Implements the Multi-Head Self-Attention module of the Transformer.

The initializer

class Attention(nn.Module):
    def __init__(self,
                 dim,   # dim of each input token
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

  • dim: the dimension of each token;
  • qkv_bias: whether to use a bias when generating q, k and v;
  • head_dim: the dimension handled by each head, i.e. dim // num_heads = 768 // 8 = 96;

$q$, $k$, $v$ are generated through $W^q$, $W^k$, $W^v$ and then split evenly according to the number of heads (for example, with 2 heads as in the figure below, $q$, $k$, $v$ are each split into 2 parts). The dimension used by each head is the original dimension divided by the number of heads, i.e. head_dim = dim // num_heads, which gives the $qkv$ dimension of every head.

dim is the chosen total embedding size (768 here); dividing it by the number of heads says how many attention heads, each producing a sub-embedding, are combined to form that 768.

(Figure: heads in Multi-Head Attention)

  • qk_scale: if qk_scale is given, self.scale = qk_scale; otherwise self.scale is one over the square root of head_dim, i.e. $\text{self.scale} = 1/\sqrt{\text{head\_dim}}$;

(Figure: the attention formula, $\mathrm{Attention}(Q,K,V) = \mathrm{softmax}(QK^T/\sqrt{d_k})\,V$)

  • self.qkv: $q$, $k$ and $v$ are generated directly by a single fully connected layer.

Note: some implementations use three separate fully connected layers to obtain $q$, $k$ and $v$; here a single linear layer produces all three at once. There is no real difference: one linear layer with dim*3 output nodes is equivalent to three linear layers with dim output nodes each (some split it into three, possibly for better parallelization).

  • attn_drop: a Dropout layer whose drop rate is the given attn_drop_ratio;
  • proj: another fully connected layer, with dim input and dim output nodes;

In Multi-Head Attention the outputs of all heads are concatenated and then mapped through $W^o$.

(Figure: concatenation of the heads in Multi-Head Attention)

  • proj_drop: another Dropout layer, with drop rate proj_drop_ratio.

The forward pass

The input x has shape [batch_size, num_patches + 1, total_embed_dim]. batch_size is the number of images in the current batch; num_patches is the number of patches, i.e. $(224 / 16)^2 = 14 \times 14 = 196$; the +1 accounts for the [class] token concatenated after the Patch Embedding layer, so 196 + 1 = 197; total_embed_dim is 768.

qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

qkv(): x first goes through the qkv fully connected layer (qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]);

reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head];

permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head] (permute reorders several dimensions at once; this layout makes the following operations convenient). The leading dimension of size 3 holds the three tensors corresponding to the Queries (Q), Keys (K) and Values (V).

q, k, v = qkv[0], qkv[1], qkv[2]: q, k and v are obtained by slicing.

At this point q, k and v each have shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head].

  • transpose: swaps two given dimensions;

For example, transpose(-2, -1) swaps the last two dimensions, while the later transpose(1, 2) gives -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head].

  • permute: reorders any number of dimensions at once.

attn = (q @ k.transpose(-2, -1)) * self.scale

@ is the matrix multiplication operator.

k has shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head]; after transpose swaps its last two dimensions it becomes [batch_size, num_heads, embed_dim_per_head, num_patches + 1], i.e. the matrix is transposed.

Multiplying q by the transposed k (the multiplication acts on the last two dimensions) gives [batch_size, num_heads, num_patches + 1, embed_dim_per_head] x [batch_size, num_heads, embed_dim_per_head, num_patches + 1] = [batch_size, num_heads, num_patches + 1, num_patches + 1] (an a x b matrix times a b x a matrix yields an a x a matrix).

The result is then multiplied by scale and passed through softmax, which implements the formula shown below.

(Figure: the attention formula)

attn.softmax(dim=-1): dim = -1 means softmax is applied over each row of the matrix; dim = -2 would apply it over each column.

The attention weights are then passed through the attn_drop Dropout layer.

To complete the formula shown above, the softmax weights are used to take a weighted sum over V, i.e. the attn @ v in (attn @ v).transpose(1, 2).reshape(B, N, C). After this weighted sum the shape is [batch_size, num_heads, num_patches + 1, embed_dim_per_head]; transpose(1, 2) swaps the heads and token dimensions, and reshape merges the last two dimensions back into total_embed_dim.

self.proj(x): the result still needs to be mapped through $W^o$, which is done by the proj fully connected layer;

proj_drop: a final Dropout layer produces the output.

def forward(self, x):
    # [batch_size, num_patches + 1, total_embed_dim], e.g. [batch_size, 197, 768]
    B, N, C = x.shape

    # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
    # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
    # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    # [batch_size, num_heads, num_patches + 1, embed_dim_per_head] = [batch, 8, 197, 96]
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
    # @: matrix multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    # attn @ v completes the attention formula; the remaining ops restore the tensor layout
    # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
    # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
    # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
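
A minimal sketch (for illustration only) confirming that Attention keeps the token layout unchanged: input and output are both [batch_size, num_patches + 1, total_embed_dim], while internally each of the 8 heads works on 768 // 8 = 96 dimensions.

import torch

attn = Attention(dim=768, num_heads=8)
x = torch.randn(2, 197, 768)     # [batch, 196 patches + 1 [class] token, 768]
print(attn(x).shape)             # torch.Size([2, 197, 768])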

The Mlp class

Corresponds to the MLP Block inside the Encoder Block described in the previous post, i.e. the structure shown below:

(Figure: MLP Block structure)

  • hidden_features: the number of nodes in the first fully connected layer, usually 4 times in_features;
  • out_features: the same number of nodes as in_features.
class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # the Block class does not pass out_features, so it defaults to None
        # and this line sets out_features = in_features
        out_features = out_features or in_features
        # hidden_features keeps the value passed in by Block (768 * 4 = 3072)
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

The Block class

Corresponds to the Encoder Block from the previous post, with the structure shown below:

(Figure: Encoder Block structure)

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 # VisionTransformer passes in values from an arithmetic sequence
                 # running from 0 to drop_path_ratio over the 12 blocks
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        # mlp_hidden_dim = 768 * 4 = 3072
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        # two residual connections
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
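
Because both residual branches return tensors of the input shape, an Encoder Block is shape preserving, which is what allows depth of them to be chained with nn.Sequential later on. A small illustrative check (not from the original code):

import torch

block = Block(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True)
x = torch.randn(2, 197, 768)
print(block(x).shape)            # torch.Size([2, 197, 768])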

The VisionTransformer class

(Figure: VisionTransformer model structure)

The initializer

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # embed_layer=PatchEmbed
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        # num_patches = 196
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        # num_patches + self.num_tokens = 196 + 1 = 197
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # dpr is a list: an increasing arithmetic sequence of drop_path ratios, one per block.
        # x.item() converts each element of torch.linspace(0, drop_path_ratio, depth) to a Python float
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        # the Encoder Blocks (depth = 12, so the Block is repeated 12 times)
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        # representation_size=None, distilled=False
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

  • depth: how many times the Encoder Block is repeated inside the Transformer Encoder, 12 here;
  • representation_size: the number of nodes in the fully connected layer of the Pre-Logits part of the MLP Head; if None, the Pre-Logits part is not built and the MLP Head consists of a single fully connected layer;
  • embed_layer: the Patch Embedding layer.

norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6): if no norm_layer is passed in, partial is used to fix the eps argument of nn.LayerNorm to 1e-6.

nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

LayerNorm is another normalization method. Unlike BatchNorm, which normalizes each channel using statistics computed across the whole batch, LayerNorm normalizes each sample (here, each token's embedding vector) on its own, independently of the batch.

norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
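
To make the effect of partial and the normalization axis concrete, here is a small sketch (illustration only, not part of the original code): the partial call only pins eps, and the resulting LayerNorm normalizes every token's 768-dimensional vector on its own, whatever the batch size.

from functools import partial
import torch
import torch.nn as nn

norm_layer = partial(nn.LayerNorm, eps=1e-6)   # later called as norm_layer(embed_dim)
ln = norm_layer(768)
print(ln.eps)                                  # 1e-06

x = torch.randn(2, 197, 768)
y = ln(x)
print(y.mean(dim=-1).abs().max() < 1e-5)       # each token is zero-mean over its 768 features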

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)): nn.Parameter creates a trainable parameter, initialized with zeros, of shape [1, 1, embed_dim] (the batch dimension, the 1 of the 1x768 [class] token, and 768);

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)): in the structure diagram, the Position Embedding has the same shape as the token sequence after concatenation, [197, 768]. It is likewise built with nn.Parameter as a trainable parameter initialized with zeros; the leading 1 is the batch dimension (which can be ignored), num_patches + self.num_tokens = 14 x 14 + 1 = 196 + 1 = 197, and embed_dim is the value passed in;

self.pos_drop = nn.Dropout(p=drop_ratio): this is the Dropout layer placed right before the Transformer Encoder;

dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]: builds an arithmetic sequence from 0 to drop_path_ratio with depth elements, so the drop_path ratio used by each successive Encoder Block inside the Transformer Encoder increases. With the default drop_path_ratio of 0 all entries are 0;
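
For example (an illustration; with the default drop_path_ratio of 0 every entry is simply 0), setting drop_path_ratio=0.1 and depth=12 produces an arithmetic sequence from 0 to 0.1, so deeper Encoder Blocks are dropped more often:

import torch

dpr = [round(x.item(), 4) for x in torch.linspace(0, 0.1, 12)]
print(dpr)   # [0.0, 0.0091, 0.0182, ..., 0.0909, 0.1]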

self.blocks = nn.Sequential(*[
    Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
          drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
          norm_layer=norm_layer, act_layer=act_layer)
    for i in range(depth)
])

This builds the stack of Blocks, i.e. the Transformer Encoder with its 12 repetitions. Looping over depth appends one Block (an Encoder Block) to the list per iteration; all arguments stay the same except drop_path_ratio=dpr[i], which increases from block to block. nn.Sequential then packs all modules in the list into one module, assigned to self.blocks.

representation_size: the number of nodes in the fully connected layer of the Pre-Logits part of the MLP Head; if None, the Pre-Logits part is not built.

self.pre_logits = nn.Sequential(OrderedDict([
    ("fc", nn.Linear(embed_dim, representation_size)),
    ("act", nn.Tanh())
]))

pre_logits is built with nn.Sequential and an OrderedDict: a fully connected layer followed by a Tanh activation.

self.head: the final fully connected layer.

The forward pass

The forward_features function

  • The input x is first passed to patch_embed, i.e. the Patch Embedding stage;
  • cls_token is then expanded: it was created with shape [1, 1, 768], and this line replicates it along the batch dimension batch_size times, giving shape [B, 1, 768];
  • self.dist_token is None by default, so cls_token is concatenated with x along dimension 1 (the 196-token dimension), giving [B, 197, 768];
  • The concatenated result is added to self.pos_embed and then passed through pos_drop, the dropout layer;
  • self.pre_logits(x[:, 0]): the output is obtained by taking index 0 along the second dimension (i.e. the [class] token, leaving the batch dimension untouched) and passing it through pre_logits.

The forward function

forward simply calls forward_features and then applies x = self.head(x), the final fully connected layer in the structure diagram.

def forward_features(self, x):
    # [B, C, H, W] -> [B, num_patches, embed_dim]
    x = self.patch_embed(x)  # [B, 196, 768]
    # [1, 1, 768] -> [B, 1, 768]
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

def forward(self, x):
    x = self.forward_features(x)
    if self.head_dist is not None:
        x, x_dist = self.head(x[0]), self.head_dist(x[1])
        if self.training and not torch.jit.is_scripting():
            # during training, return both classifier predictions
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2
    else:
        x = self.head(x)
    return x
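
An end-to-end sanity check (a sketch, not part of the original file): with num_classes=5 the model maps a batch of [B, 3, 224, 224] images to [B, 5] logits through the [class] token and the final head.

import torch

model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                          num_heads=12, num_classes=5)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 5])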

Instantiating the model

  • Layers: how many times the Encoder Block is stacked inside the Transformer Encoder;
  • Hidden Size: the dim of each token after the embedding layer (the length of the vector);
  • MLP size: the number of nodes in the first fully connected layer of the MLP Block in the Transformer Encoder (4 times Hidden Size);
  • Heads: the number of heads in the Transformer's Multi-Head Attention.

Model      Patch Size   Layers   Hidden Size D   MLP size   Heads   Params
ViT-Base   16x16        12       768             3072       12      86M
ViT-Large  16x16        24       1024            4096       16      307M
ViT-Huge   14x14        32       1280            5120       16      632M
def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  access code: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  access code: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  access code: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
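
The Params column of the table above can be checked directly; the quick illustration below (not part of the original file) sums the parameter counts of ViT-B/16 and should land at roughly 86M, give or take the classification head.

model = vit_base_patch16_224(num_classes=1000)
n_params = sum(p.numel() for p in model.parameters())
print("{:.1f}M parameters".format(n_params / 1e6))   # roughly 86M for ViT-B/16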

train.py

import os
import math
import argparse

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes, has_logits=False).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # remove the weights we do not need
        del_keys = ['head.weight', 'head.bias'] if model.has_logits \
            else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except head and pre_logits
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)

    # root directory of the dataset
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="D:/python_test/deep-learning-for-image-processing/data_set/flower_data/flower_photos")
    parser.add_argument('--model-name', default='', help='create model name')

    # path to pretrained weights; set to an empty string if you do not want to load any
    parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
                        help='initial weights path')
    # whether to freeze the backbone weights
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

Training results

(Figure: training log)

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # load image
    img_path = "tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5, has_logits=False).to(device)
    # load model weights
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

Prediction results

(Figure: prediction result)

flops.py

import torch
from fvcore.nn import FlopCountAnalysis

from vit_model import Attention


def main():
    # Self-Attention
    a1 = Attention(dim=512, num_heads=1)
    a1.proj = torch.nn.Identity()  # remove Wo

    # Multi-Head Attention
    a2 = Attention(dim=512, num_heads=8)

    # [batch_size, num_tokens, total_embed_dim]
    t = (torch.rand(32, 1024, 512),)

    flops1 = FlopCountAnalysis(a1, t)
    print("Self-Attention FLOPs:", flops1.total())

    flops2 = FlopCountAnalysis(a2, t)
    print("Multi-Head Attention FLOPs:", flops2.total())


if __name__ == '__main__':
    main()
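
As a rough analytical cross-check (an approximation that counts one operation per multiply-accumulate and ignores softmax, bias terms and the exact conventions of fvcore, so the numbers will not match FlopCountAnalysis exactly), the cost of the module is dominated by the qkv projection, the two attention matmuls and the optional Wo projection:

def approx_attention_macs(batch, tokens, dim, with_proj=True):
    qkv = batch * tokens * dim * 3 * dim                     # x -> q, k, v
    scores = batch * tokens * tokens * dim                   # q @ k^T, summed over all heads
    weighted = batch * tokens * tokens * dim                 # attn @ v
    proj = batch * tokens * dim * dim if with_proj else 0    # Wo
    return qkv + scores + weighted + proj

print(approx_attention_macs(32, 1024, 512, with_proj=False))  # compare with a1
print(approx_attention_macs(32, 1024, 512, with_proj=True))   # compare with a2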

utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # each sub-folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to keep the order consistent across platforms
    flower_class.sort()
    # map class names to numeric indices
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # labels (indices) of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # labels (indices) of the validation images
    every_class_num = []     # number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect all file paths with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort to keep the order consistent across platforms
        images.sort()
        # index of this class
        image_class = class_indices[cla]
        # number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images according to val_rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # the rest go to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add the count on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # see the official default_collate implementation for reference:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
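
A minimal usage sketch (the file names below are placeholders, not files from the project): MyDataSet pairs image paths with integer labels, and collate_fn stacks the transformed images and labels into batch tensors.

from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
dataset = MyDataSet(images_path=["flower1.jpg", "flower2.jpg"],   # hypothetical image files
                    images_class=[0, 1],
                    transform=transform)
loader = DataLoader(dataset, batch_size=2, collate_fn=MyDataSet.collate_fn)
for images, labels in loader:
    print(images.shape, labels)   # torch.Size([2, 3, 224, 224]) tensor([0, 1])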