Deep Learning CNN Models (20): Building the EfficientNetV2 Network with PyTorch

Project layout

├── Test11_efficientnetV2
│   ├── model.py (model definition)
│   ├── my_dataset.py (dataset loading and preprocessing)
│   ├── train.py (drives training; automatically generates class_indices.json and the EfficientNet.pth weight file)
│   ├── predict.py (runs the trained model for prediction)
│   ├── utils.py (utility functions used throughout)
│   ├── tulip.jpg (sample image for predicting a class with the trained weights)
│   └── pre_efficientnetv2-s.pth (official pre-trained efficientnetv2-s weights, downloaded in advance for transfer learning)
└── data_set
    └── data (the dataset)

model.py

from collections import OrderedDict
from functools import partial
from typing import Callable, Optional

import torch.nn as nn
import torch
from torch import Tensor


def drop_path(x, drop_prob: float = 0., training: bool = False):
    # ......

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        # ......

    def forward(self, x):
        # ......

class ConvBNAct(nn.Module):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        # ......

class SqueezeExcite(nn.Module):
    def __init__(self,
                 input_c: int,   # block input channel
                 expand_c: int,  # block expand channel
                 se_ratio: float = 0.25):
        # ......

class MBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        # ......

    def forward(self, x: Tensor) -> Tensor:
        # ......

class FusedMBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        # ......

    def forward(self, x: Tensor) -> Tensor:
        # ......

class EfficientNetV2(nn.Module):
    def __init__(self,
                 model_cnf: list,
                 num_classes: int = 1000,
                 num_features: int = 1280,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2):
        # ......

def efficientnetv2_s(num_classes: int = 1000):
    # ......


def efficientnetv2_m(num_classes: int = 1000):
    # ......

def efficientnetv2_l(num_classes: int = 1000):
    # ......

The DropPath class

It simply calls the drop_path function; the idea is the same as in the previous post and is described below.

The "Dropout" in EfficientNet is not quite the same as the Dropout in the networks covered so far. In a classical Dropout layer, a rate of 0.2 means a fixed 20% of units are dropped; in EfficientNet, Dropout = 0.2 means the per-block drop probability ramps up from 0 to 0.2 over the depth of the network. The reference paper is Deep Networks with Stochastic Depth.

The figure below can be read as follows: during forward propagation, the input feature map passes through one block after another, and each block can be viewed as a residual structure. The main branch produces its output through a function $f$, while the shortcut routes the input straight to the output. Along the way, the main branch is dropped with a certain probability (the entire main branch is discarded, so the previous layer's output feeds directly into the next layer, as if this block did not exist). This is Stochastic Depth: "depth" refers to the network depth, because any block may be randomly dropped.

The figure shows the survival probability falling gradually from 1.0 to 0.5. In EfficientNetV2, however, drop_prob ranges over 0~0.2 (which speeds up training and slightly improves accuracy).

Note: the dropout layer here refers only to the DropPath inside the Fused-MBConv and MBConv modules, not the dropout layer before the final fully connected layer.

Figure: Dropout (stochastic depth) during forward propagation
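To make the schedule concrete, here is a minimal sketch (illustration only, not part of model.py) of how the per-block drop rate is derived; the expression mirrors the drop_connect_rate * block_id / total_blocks line inside the EfficientNetV2 class further below:

# EfficientNetV2-S stacks 2+4+4+6+9+15 = 40 blocks in total
repeats = [2, 4, 4, 6, 9, 15]
total_blocks = sum(repeats)  # 40
drop_connect_rate = 0.2

rates = [drop_connect_rate * block_id / total_blocks
         for block_id in range(total_blocks)]
print(rates[0], rates[20], rates[-1])  # 0.0 0.1 0.195 -- rises linearly toward 0.2

The first block is never dropped (rate 0), and the deepest block is dropped with probability just under 0.2.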

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
    This function is taken from rwightman:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L140
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
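As a quick sanity check (a minimal sketch, assuming the two definitions above are in scope): in eval mode DropPath is the identity, while in training mode each sample in the batch is either zeroed out entirely or scaled by 1/keep_prob, so the expected output still matches the input.

import torch

m = DropPath(drop_prob=0.2)
x = torch.ones(4, 8, 16, 16)  # (batch, channels, H, W)

m.eval()
assert torch.equal(m(x), x)   # identity at inference time

m.train()
out = m(x)
# each sample is either all zeros or scaled by 1/0.8 = 1.25,
# so E[out] == x across many draws
print(out[:, 0, 0, 0])        # e.g. tensor([1.2500, 0.0000, 1.2500, 1.2500])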

The ConvBNAct class

class ConvBNAct(nn.Module):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        super(ConvBNAct, self).__init__()

        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.SiLU  # alias Swish (torch>=1.7)

        self.conv = nn.Conv2d(in_channels=in_planes,
                              out_channels=out_planes,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              groups=groups,
                              bias=False)

        self.bn = norm_layer(out_planes)
        self.act = activation_layer()

    def forward(self, x):
        result = self.conv(x)
        result = self.bn(result)
        result = self.act(result)

        return result
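Because padding is fixed at (kernel_size - 1) // 2, the convolution behaves like "same" padding: stride 1 preserves the spatial size and stride 2 halves it. A quick check (sketch only, assuming ConvBNAct is in scope):

import torch

x = torch.randn(1, 24, 56, 56)
print(ConvBNAct(24, 48, kernel_size=3, stride=1)(x).shape)  # torch.Size([1, 48, 56, 56])
print(ConvBNAct(24, 48, kernel_size=3, stride=2)(x).shape)  # torch.Size([1, 48, 28, 28])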

The SqueezeExcite class

The SE module, shown below, consists of one global average pooling layer and two fully connected layers. The first fully connected layer has 1/4 as many nodes as the channels of the feature map entering the MBConv block (the block input channels, not the expanded channels), and uses the Swish activation function. The second fully connected layer has as many nodes as the channels of the feature map output by the Depthwise Conv layer, and uses the Sigmoid activation function.

Figure: the SE module

class SqueezeExcite(nn.Module):
    def __init__(self,
                 input_c: int,   # block input channel
                 expand_c: int,  # block expand channel
                 se_ratio: float = 0.25):
        super(SqueezeExcite, self).__init__()
        squeeze_c = int(input_c * se_ratio)
        self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
        self.act1 = nn.SiLU()  # alias Swish
        self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
        self.act2 = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        scale = x.mean((2, 3), keepdim=True)  # global average pooling
        scale = self.conv_reduce(scale)
        scale = self.act1(scale)
        scale = self.conv_expand(scale)
        scale = self.act2(scale)
        return scale * x
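A small sketch (illustration only) confirming the two points above: the squeeze width comes from the block input channels rather than the expanded channels, and the output keeps the input's shape, since SE only reweights channels. The numbers correspond to a hypothetical MBConv stage with input_c=64 and expand_ratio=4:

import torch

se = SqueezeExcite(input_c=64, expand_c=256, se_ratio=0.25)
print(se.conv_reduce)  # Conv2d(256, 16, ...) -> 64 * 0.25 = 16 squeeze channels

x = torch.randn(2, 256, 14, 14)
print(se(x).shape)     # torch.Size([2, 256, 14, 14]) -- per-channel reweighting only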

The MBConv class

Figure: the MBConv module

class MBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(MBConv, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.has_shortcut = (stride == 1 and input_c == out_c)

        activation_layer = nn.SiLU  # alias Swish
        expanded_c = input_c * expand_ratio

        # In EfficientNetV2, MBConv is never used with expansion == 1,
        # so the expand conv (conv_pw) always exists
        assert expand_ratio != 1
        # Point-wise expansion
        self.expand_conv = ConvBNAct(input_c,
                                     expanded_c,
                                     kernel_size=1,
                                     norm_layer=norm_layer,
                                     activation_layer=activation_layer)

        # Depth-wise convolution
        self.dwconv = ConvBNAct(expanded_c,
                                expanded_c,
                                kernel_size=kernel_size,
                                stride=stride,
                                groups=expanded_c,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer)

        self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()

        # Point-wise linear projection
        self.project_conv = ConvBNAct(expanded_c,
                                      out_planes=out_c,
                                      kernel_size=1,
                                      norm_layer=norm_layer,
                                      activation_layer=nn.Identity)  # no activation here, so pass Identity

        self.out_channels = out_c

        # the dropout (DropPath) layer is only used when there is a shortcut connection
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        result = self.expand_conv(x)
        result = self.dwconv(result)
        result = self.se(result)
        result = self.project_conv(result)

        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x

        return result
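For example (a sketch with illustrative numbers; drop_rate=0.05 is arbitrary here), the first block of Stage 4 in the -S config maps 64 channels to 128 with stride 2, so it has no shortcut:

import torch
import torch.nn as nn
from functools import partial

norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
block = MBConv(kernel_size=3, input_c=64, out_c=128, expand_ratio=4,
               stride=2, se_ratio=0.25, drop_rate=0.05, norm_layer=norm_layer)

x = torch.randn(1, 64, 28, 28)
print(block(x).shape)      # torch.Size([1, 128, 14, 14])
print(block.has_shortcut)  # False -- stride 2 and changed channels, so no residual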

The FusedMBConv class

Figure: the Fused-MBConv module

class FusedMBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(FusedMBConv, self).__init__()

        assert stride in [1, 2]
        assert se_ratio == 0  # Fused-MBConv has no SE module

        self.has_shortcut = stride == 1 and input_c == out_c
        self.has_expansion = expand_ratio != 1

        activation_layer = nn.SiLU  # alias Swish
        expanded_c = input_c * expand_ratio

        # the expand conv only exists when expand_ratio != 1
        if self.has_expansion:
            # Expansion convolution
            self.expand_conv = ConvBNAct(input_c,
                                         expanded_c,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         norm_layer=norm_layer,
                                         activation_layer=activation_layer)

            self.project_conv = ConvBNAct(expanded_c,
                                          out_c,
                                          kernel_size=1,
                                          norm_layer=norm_layer,
                                          activation_layer=nn.Identity)  # note: no activation
        else:
            # only project_conv in this case
            self.project_conv = ConvBNAct(input_c,
                                          out_c,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          norm_layer=norm_layer,
                                          activation_layer=activation_layer)  # note: with activation

        self.out_channels = out_c

        # the dropout (DropPath) layer is only used when there is a shortcut connection
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        if self.has_expansion:
            result = self.expand_conv(x)
            result = self.project_conv(result)
        else:
            result = self.project_conv(x)

        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)

            result += x

        return result
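The key difference from MBConv: when expand_ratio != 1, the 1x1 expand conv and 3x3 depthwise conv are fused into a single ordinary 3x3 conv (followed by the 1x1 projection), and when expand_ratio == 1 the whole block collapses into one 3x3 conv with activation. A quick sketch (illustrative values only):

import torch
import torch.nn as nn
from functools import partial

norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

fused = FusedMBConv(kernel_size=3, input_c=48, out_c=64, expand_ratio=4,
                    stride=2, se_ratio=0, drop_rate=0., norm_layer=norm_layer)
x = torch.randn(1, 48, 56, 56)
print(fused(x).shape)        # torch.Size([1, 64, 28, 28])

fused1 = FusedMBConv(kernel_size=3, input_c=24, out_c=24, expand_ratio=1,
                     stride=1, se_ratio=0, drop_rate=0., norm_layer=norm_layer)
print(fused1.has_expansion)  # False -- only the single fused 3x3 conv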

The EfficientNetV2 class

Figure: the EfficientNetV2-S architecture

class EfficientNetV2(nn.Module):
    def __init__(self,
                 model_cnf: list,
                 num_classes: int = 1000,
                 num_features: int = 1280,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2):
        super(EfficientNetV2, self).__init__()

        for cnf in model_cnf:
            assert len(cnf) == 8

        # eps=1e-3 follows the paper's setting and should not be changed
        norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        stem_filter_num = model_cnf[0][4]

        self.stem = ConvBNAct(3,
                              stem_filter_num,
                              kernel_size=3,
                              stride=2,
                              norm_layer=norm_layer)  # activation defaults to SiLU

        total_blocks = sum([i[0] for i in model_cnf])
        block_id = 0
        blocks = []
        for cnf in model_cnf:
            repeats = cnf[0]
            op = FusedMBConv if cnf[-2] == 0 else MBConv
            for i in range(repeats):
                blocks.append(op(kernel_size=cnf[1],
                                 input_c=cnf[4] if i == 0 else cnf[5],
                                 out_c=cnf[5],
                                 expand_ratio=cnf[3],
                                 stride=cnf[2] if i == 0 else 1,
                                 se_ratio=cnf[-1],
                                 drop_rate=drop_connect_rate * block_id / total_blocks,
                                 norm_layer=norm_layer))
                block_id += 1
        self.blocks = nn.Sequential(*blocks)

        head_input_c = model_cnf[-1][-3]
        head = OrderedDict()

        head.update({"project_conv": ConvBNAct(head_input_c,
                                               num_features,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})  # activation defaults to SiLU

        head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
        head.update({"flatten": nn.Flatten()})

        if dropout_rate > 0:
            head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
        head.update({"classifier": nn.Linear(num_features, num_classes)})

        self.head = nn.Sequential(head)

        # initial weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)  # the 3x3 conv of Stage 0
        x = self.blocks(x)
        x = self.head(x)

        return x

Instantiating the EfficientNetV2 model

  • r is the number of times the Operator is repeated in the current Stage;
  • k is the kernel_size;
  • s is the stride;
  • e is the expansion ratio;
  • i is the number of input channels;
  • o is the number of output channels;
  • c is the conv_type: 1 means Fused-MBConv, 0 means MBConv (MBConv is the default);
  • se indicates that an SE module is used, together with its se_ratio (see the parsing sketch after the config lists below)
#################### EfficientNet V2 configs ####################
v2_s_block = [  # about base * (width1.4, depth1.8)
    'r2_k3_s1_e1_i24_o24_c1',
    'r4_k3_s2_e4_i24_o48_c1',
    'r4_k3_s2_e4_i48_o64_c1',
    'r6_k3_s2_e4_i64_o128_se0.25',
    'r9_k3_s1_e6_i128_o160_se0.25',
    'r15_k3_s2_e6_i160_o256_se0.25',
]
v2_m_block = [  # about base * (width1.6, depth2.2)
    'r3_k3_s1_e1_i24_o24_c1',
    'r5_k3_s2_e4_i24_o48_c1',
    'r5_k3_s2_e4_i48_o80_c1',
    'r7_k3_s2_e4_i80_o160_se0.25',
    'r14_k3_s1_e6_i160_o176_se0.25',
    'r18_k3_s2_e6_i176_o304_se0.25',
    'r5_k3_s1_e6_i304_o512_se0.25',
]
v2_l_block = [  # about base * (width2.0, depth3.1)
    'r4_k3_s1_e1_i32_o32_c1',
    'r7_k3_s2_e4_i32_o64_c1',
    'r7_k3_s2_e4_i64_o96_c1',
    'r10_k3_s2_e4_i96_o192_se0.25',
    'r19_k3_s1_e6_i192_o224_se0.25',
    'r25_k3_s2_e6_i224_o384_se0.25',
    'r7_k3_s1_e6_i384_o640_se0.25',
]
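These strings are just a compact spelling of the numeric rows used by the factory functions below. As a cross-check, here is a minimal sketch with a hypothetical helper (not part of the original code) that parses one block string into the [repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio] row format. Note the encoding flip: 'c1' (Fused-MBConv) becomes operator 0, because the EfficientNetV2 class selects FusedMBConv when cnf[-2] == 0:

import re

def parse_block_str(s):
    # hypothetical helper, for illustration only
    opts = dict(re.findall(r'([a-z]+)([\d.]+)', s))
    operator = 0 if int(opts.get('c', 0)) == 1 else 1  # c1 (Fused-MBConv) -> 0
    return [int(opts['r']), int(opts['k']), int(opts['s']), int(opts['e']),
            int(opts['i']), int(opts['o']), operator, float(opts.get('se', 0))]

print(parse_block_str('r4_k3_s2_e4_i24_o48_c1'))       # [4, 3, 2, 4, 24, 48, 0, 0.0]
print(parse_block_str('r6_k3_s2_e4_i64_o128_se0.25'))  # [6, 3, 2, 4, 64, 128, 1, 0.25]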

Instantiating from the configuration

def efficientnetv2_s(num_classes: int = 1000):
    """
    EfficientNetV2
    https://arxiv.org/abs/2104.00298
    """
    # train_size: 300, eval_size: 384

    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
    model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
                    [4, 3, 2, 4, 24, 48, 0, 0],
                    [4, 3, 2, 4, 48, 64, 0, 0],
                    [6, 3, 2, 4, 64, 128, 1, 0.25],
                    [9, 3, 1, 6, 128, 160, 1, 0.25],
                    [15, 3, 2, 6, 160, 256, 1, 0.25]]

    model = EfficientNetV2(model_cnf=model_config,
                           num_classes=num_classes,
                           dropout_rate=0.2)
    return model


def efficientnetv2_m(num_classes: int = 1000):
    """
    EfficientNetV2
    https://arxiv.org/abs/2104.00298
    """
    # train_size: 384, eval_size: 480

    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
    model_config = [[3, 3, 1, 1, 24, 24, 0, 0],
                    [5, 3, 2, 4, 24, 48, 0, 0],
                    [5, 3, 2, 4, 48, 80, 0, 0],
                    [7, 3, 2, 4, 80, 160, 1, 0.25],
                    [14, 3, 1, 6, 160, 176, 1, 0.25],
                    [18, 3, 2, 6, 176, 304, 1, 0.25],
                    [5, 3, 1, 6, 304, 512, 1, 0.25]]

    model = EfficientNetV2(model_cnf=model_config,
                           num_classes=num_classes,
                           dropout_rate=0.3)
    return model


def efficientnetv2_l(num_classes: int = 1000):
    """
    EfficientNetV2
    https://arxiv.org/abs/2104.00298
    """
    # train_size: 384, eval_size: 480

    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
    model_config = [[4, 3, 1, 1, 32, 32, 0, 0],
                    [7, 3, 2, 4, 32, 64, 0, 0],
                    [7, 3, 2, 4, 64, 96, 0, 0],
                    [10, 3, 2, 4, 96, 192, 1, 0.25],
                    [19, 3, 1, 6, 192, 224, 1, 0.25],
                    [25, 3, 2, 6, 224, 384, 1, 0.25],
                    [7, 3, 1, 6, 384, 640, 1, 0.25]]

    model = EfficientNetV2(model_cnf=model_config,
                           num_classes=num_classes,
                           dropout_rate=0.4)
    return model
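A quick smoke test (sketch only): feed a dummy batch at the -S training resolution and check the output shape.

import torch
from model import efficientnetv2_s

model = efficientnetv2_s(num_classes=5)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 300, 300))  # train_size of the -S variant
print(out.shape)  # torch.Size([1, 5])
print(sum(p.numel() for p in model.parameters()))  # roughly 2e7 parameters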

train.py

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pre-trained weights if they exist
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            # keep only the weights whose shapes match the current model
            # (e.g. the classifier head differs when num_classes != 1000)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # root directory of the dataset
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="D:/python_test/deep-learning-for-image-processing/data_set/flower_data/flower_photos")

    # download model weights
    # link: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  password: 5gu1
    parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

Training results

Figure: training results

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnetv2_s as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = transforms.Compose(
        [transforms.Resize(img_size[num_model][1]),
         transforms.CenterCrop(img_size[num_model][1]),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-29.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

Prediction results

Figure: prediction results

utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk the folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to keep the order consistent across platforms
    flower_class.sort()
    # build the class name -> numeric index mapping
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class indices of the validation images
    every_class_num = []     # number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # walk the files under each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all supported files
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort to keep the order consistent across platforms
        images.sort()
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:                     # everything else goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add count labels on top of the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB means a color image, L a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
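A usage sketch pairing MyDataSet with a DataLoader via its custom collate_fn (the two image paths below are hypothetical; the files must exist and be RGB for the sketch to run):

from torchvision import transforms
from torch.utils.data import DataLoader

dataset = MyDataSet(images_path=["flower_data/roses/001.jpg",    # hypothetical paths
                                 "flower_data/tulips/002.jpg"],
                    images_class=[0, 1],
                    transform=transforms.Compose([transforms.RandomResizedCrop(300),
                                                  transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=2, collate_fn=MyDataSet.collate_fn)

images, labels = next(iter(loader))
print(images.shape, labels)  # torch.Size([2, 3, 300, 300]) tensor([0, 1])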