-
Notifications
You must be signed in to change notification settings - Fork 352
/
frcnn.py
110 lines (104 loc) · 4.66 KB
/
frcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch.nn as nn
from nets.classifier import Resnet50RoIHead, VGG16RoIHead
from nets.resnet50 import resnet50
from nets.rpn import RegionProposalNetwork
from nets.vgg16 import decom_vgg16
class FasterRCNN(nn.Module):
    """Faster R-CNN detector: backbone extractor + region proposal network + RoI head.

    Two backbones are supported, selected with ``backbone``:

    * ``'vgg'``      -- ``decom_vgg16`` extractor, RPN fed with 512 input
      channels, ``VGG16RoIHead`` with ``roi_size=7``.
    * ``'resnet50'`` -- ``resnet50`` extractor, RPN fed with 1024 input
      channels, ``Resnet50RoIHead`` with ``roi_size=14``.

    Args:
        num_classes: Number of object classes. The RoI head is built with
            ``num_classes + 1`` outputs (presumably the extra one is the
            background class -- confirm against the head implementation).
        mode: Forwarded unchanged to ``RegionProposalNetwork``.
        feat_stride: Stored on the instance and forwarded to the RPN.
        anchor_scales: Anchor scales forwarded to the RPN.
            Defaults to ``[8, 16, 32]``.
        ratios: Anchor aspect ratios forwarded to the RPN.
            Defaults to ``[0.5, 1, 2]``.
        backbone: ``'vgg'`` or ``'resnet50'``.
        pretrained: Forwarded to the backbone factory function.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, num_classes,
                 mode="training",
                 feat_stride=16,
                 anchor_scales=None,
                 ratios=None,
                 backbone='vgg',
                 pretrained=False):
        super(FasterRCNN, self).__init__()

        # Fail fast: an unknown backbone used to produce a module without
        # extractor / rpn / head, which only blew up later with an
        # AttributeError inside forward().
        if backbone not in ('vgg', 'resnet50'):
            raise ValueError(
                "Unsupported backbone {!r}; expected 'vgg' or 'resnet50'.".format(backbone)
            )

        # Build fresh lists per instance instead of sharing mutable
        # default-argument objects between every model that gets created.
        if anchor_scales is None:
            anchor_scales = [8, 16, 32]
        if ratios is None:
            ratios = [0.5, 1, 2]

        self.feat_stride = feat_stride

        #---------------------------------#
        #   Two backbones are available:
        #   vgg and resnet50
        #---------------------------------#
        if backbone == 'vgg':
            self.extractor, classifier = decom_vgg16(pretrained)
            #---------------------------------#
            #   Build the region proposal network
            #---------------------------------#
            self.rpn = RegionProposalNetwork(
                512, 512,
                ratios=ratios,
                anchor_scales=anchor_scales,
                feat_stride=self.feat_stride,
                mode=mode
            )
            #---------------------------------#
            #   Build the classifier (RoI head)
            #---------------------------------#
            self.head = VGG16RoIHead(
                n_class=num_classes + 1,
                roi_size=7,
                spatial_scale=1,
                classifier=classifier
            )
        else:
            # backbone == 'resnet50' (validated above)
            self.extractor, classifier = resnet50(pretrained)
            #---------------------------------#
            #   Build the region proposal network
            #---------------------------------#
            self.rpn = RegionProposalNetwork(
                1024, 512,
                ratios=ratios,
                anchor_scales=anchor_scales,
                feat_stride=self.feat_stride,
                mode=mode
            )
            #---------------------------------#
            #   Build the classifier (RoI head)
            #---------------------------------#
            self.head = Resnet50RoIHead(
                n_class=num_classes + 1,
                roi_size=14,
                spatial_scale=1,
                classifier=classifier
            )

    def forward(self, x, scale=1., mode="forward"):
        """Run the whole detector or a single stage of it, selected by ``mode``.

        Args:
            x: Stage-dependent input.

                * ``"forward"`` / ``"extractor"``: the image tensor.
                * ``"rpn"``: a pair ``(base_feature, img_size)``.
                * ``"head"``: a 4-tuple
                  ``(base_feature, rois, roi_indices, img_size)``.
            scale: Forwarded to the RPN (used by ``"forward"`` and ``"rpn"``).
            mode: One of ``"forward"``, ``"extractor"``, ``"rpn"``, ``"head"``.

        Returns:
            * ``"forward"``: ``(roi_cls_locs, roi_scores, rois, roi_indices)``.
            * ``"extractor"``: the backbone feature map.
            * ``"rpn"``: ``(rpn_locs, rpn_scores, rois, roi_indices, anchor)``.
            * ``"head"``: ``(roi_cls_locs, roi_scores)``.

        Raises:
            ValueError: If ``mode`` is not one of the four supported values
                (previously this silently returned ``None``).
        """
        # Sub-modules are invoked by calling them rather than through
        # ``.forward`` so that registered forward hooks are honoured.
        if mode == "forward":
            #---------------------------------#
            #   Size of the input image: every dimension after the first
            #   two (presumably (H, W) for an NCHW batch -- confirm).
            #---------------------------------#
            img_size = x.shape[2:]
            #---------------------------------#
            #   Extract features with the backbone
            #---------------------------------#
            base_feature = self.extractor(x)
            #---------------------------------#
            #   Obtain the region proposals
            #---------------------------------#
            _, _, rois, roi_indices, _ = self.rpn(base_feature, img_size, scale)
            #---------------------------------------#
            #   Classification and regression results of the RoI head
            #---------------------------------------#
            roi_cls_locs, roi_scores = self.head(base_feature, rois, roi_indices, img_size)
            return roi_cls_locs, roi_scores, rois, roi_indices

        if mode == "extractor":
            #---------------------------------#
            #   Extract features with the backbone
            #---------------------------------#
            return self.extractor(x)

        if mode == "rpn":
            base_feature, img_size = x
            #---------------------------------#
            #   Obtain the region proposals
            #---------------------------------#
            rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(base_feature, img_size, scale)
            return rpn_locs, rpn_scores, rois, roi_indices, anchor

        if mode == "head":
            base_feature, rois, roi_indices, img_size = x
            #---------------------------------------#
            #   Classification and regression results of the RoI head
            #---------------------------------------#
            roi_cls_locs, roi_scores = self.head(base_feature, rois, roi_indices, img_size)
            return roi_cls_locs, roi_scores

        raise ValueError(
            "Unsupported mode {!r}; expected 'forward', 'extractor', 'rpn' or 'head'.".format(mode)
        )

    def freeze_bn(self):
        """Switch every ``nn.BatchNorm2d`` sub-module to evaluation mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()