Commit

Change learning rate, freeze bn, add demo, fix topk
v.toandm2 committed Dec 6, 2019
1 parent a19c913 commit 5a08602
Showing 4 changed files with 65 additions and 44 deletions.
5 changes: 5 additions & 0 deletions datasets/augmentation.py
@@ -25,9 +25,14 @@ def get_augumentation(phase, width=512, height=512, min_area=0., min_visibility=
list_transforms.extend([
albu.CenterCrop(p=0.2, height=height, width=width)
])
if(phase == 'show'):
return albu.Compose(list_transforms)

list_transforms.extend([
ToTensor()
])
if(phase=='test'):
return albu.Compose(list_transforms)
return albu.Compose(list_transforms, bbox_params=albu.BboxParams(format='pascal_voc', min_area=min_area,
min_visibility=min_visibility, label_fields=['category_id']))

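For reference, a minimal usage sketch (not part of the commit) of how the two special phases are meant to be used together: 'show' keeps the image as a plain ndarray for drawing, 'test' adds ToTensor for the network, and only the remaining phases carry bbox_params. The image path is a placeholder.

import cv2
from datasets import get_augumentation   # import path as used in demo.py below

img = cv2.imread('example.jpg')                       # placeholder image path

show_pipe = get_augumentation(phase='show')           # ndarray out, no ToTensor, no bbox_params
test_pipe = get_augumentation(phase='test')           # tensor out, no bbox_params

show_image = show_pipe(image=img)['image']            # suitable for cv2.rectangle / cv2.imwrite
model_input = test_pipe(image=img)['image']           # ready for unsqueeze(0) and the model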
63 changes: 25 additions & 38 deletions demo.py
@@ -5,30 +5,7 @@
from torchvision import transforms
import numpy as np
import skimage
class Resizer(object):
"""Convert ndarrays in sample to Tensors."""

def __call__(self, image, side=512):
rows, cols, cns = image.shape

scale = float(side)/float(max(rows, cols))
# resize the image with the computed scale
image = skimage.transform.resize(image, (int(round(rows*scale)), int(round((cols*scale)))))
rows, cols, cns = image.shape

pad_w = side-rows
pad_h = side-cols
new_image = np.zeros((rows + pad_w, cols + pad_h, cns)).astype(np.float32)
new_image[:rows, :cols, :] = image.astype(np.float32)
return torch.from_numpy(new_image)
class Normalizer(object):

def __init__(self):
self.mean = np.array([[[0.485, 0.456, 0.406]]])
self.std = np.array([[[0.229, 0.224, 0.225]]])

def __call__(self, image):
return (image.astype(np.float32)-self.mean)/self.std
from datasets import get_augumentation

class Detect(object):
"""
@@ -38,10 +15,8 @@ def __init__(self, weights, num_class=21):
super(Detect, self).__init__()
self.weights = weights
self.device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
self.transform = transforms.Compose([
Normalizer(),
Resizer()
])
self.transform = get_augumentation(phase='test')
self.show_transform = get_augumentation(phase='show')
self.model = EfficientDet(num_classes=num_class, is_training=False)
self.model = self.model.to(self.device)
if(self.weights is not None):
@@ -53,17 +28,29 @@ def __init__(self, weights, num_class=21):

def process(self, file_name):
img = cv2.imread(file_name)
cv2.imwrite('kaka.png', img)
img = self.transform(img)

show_aug = self.show_transform(image = img)
show_image = show_aug['image']
augmentation = self.transform(image = img)
img = augmentation['image']
img = img.to(self.device)
img = img.unsqueeze(0).permute(0, 3, 1, 2)
scores, classification, transformed_anchors = self.model(img)
print('scores: ', scores)
scores = scores.detach().cpu().numpy()
idxs = np.where(scores>0.1)
return idxs
img = img.unsqueeze(0)

with torch.no_grad():
scores, classification, transformed_anchors = self.model(img)
for i in range(transformed_anchors.size(0)):
bbox = transformed_anchors[i, :]
x1 = int(bbox[0])
y1 = int(bbox[1])
x2 = int(bbox[2])
y2 = int(bbox[3])
print(x1, x2, y1, y2)
color = (255, 0, 0)
thickness = 2
cv2.rectangle(show_image, (x1, y1), (x2, y2), color, thickness)
cv2.imwrite('output.png', show_image)

if __name__=='__main__':
detect = Detect(weights = './weights/checkpoint_87.pth')
output = detect.process('/root/data/VOCdevkit/VOC2007/JPEGImages/001234.jpg')
detect = Detect(weights = './weights/checkpoint_100.pth')
output = detect.process('/root/data/VOCdevkit/VOC2007/JPEGImages/003476.jpg')
print('output: ', output)
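A hedged follow-up sketch (not in the commit): because the top-k change in models/efficientdet.py below keeps a fixed number of anchors regardless of confidence, the demo can draw low-score boxes; filtering the post-NMS outputs by score is one option. The checkpoint and image paths are placeholders.

import cv2
import torch
from demo import Detect                                       # assumed import of the class above

detector = Detect(weights='./weights/checkpoint_100.pth')     # placeholder checkpoint path
img = cv2.imread('example.jpg')                               # placeholder image path

show_image = detector.show_transform(image=img)['image']
tensor = detector.transform(image=img)['image'].unsqueeze(0).to(detector.device)

with torch.no_grad():
    scores, classes, boxes = detector.model(tensor)           # post-NMS scores, labels, boxes

score_threshold = 0.3                                         # hypothetical cut-off
for score, box in zip(scores, boxes):
    if score.item() < score_threshold:
        continue
    x1, y1, x2, y2 = [int(v) for v in box]
    cv2.rectangle(show_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
cv2.imwrite('filtered_output.png', show_image)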
28 changes: 25 additions & 3 deletions models/efficientdet.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn

import math
from models.efficientnet import EfficientNet
from models.bifpn_v2 import BIFPN
from models.module import RegressionModel, ClassificationModel, Anchors, ClipBoxes, BBoxTransform
@@ -24,7 +24,22 @@ def __init__(self,
self.anchors = Anchors()
self.regressBoxes = BBoxTransform()
self.clipBoxes = ClipBoxes()


for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
prior = 0.01

self.classificationModel.output.weight.data.fill_(0)
self.classificationModel.output.bias.data.fill_(-math.log((1.0-prior)/prior))
self.regressionModel.output.weight.data.fill_(0)
self.regressionModel.output.bias.data.fill_(0)
self.freeze_bn()

def forward(self, inputs):
features = self.efficientnet(inputs)
features = self.BIFPN(features[2:])
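The classification-head bias above is the standard focal-loss initialisation: with prior = 0.01, the bias is chosen so that the initial sigmoid output equals the prior, keeping the loss from being dominated by easy background anchors at the start of training. A quick stand-alone check (not part of the commit):

import math

prior = 0.01
bias = -math.log((1.0 - prior) / prior)        # the value written into the classification output bias
initial_prob = 1.0 / (1.0 + math.exp(-bias))   # sigmoid(bias)
print(initial_prob)                            # ~0.01: every anchor starts near the prior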
@@ -37,7 +52,9 @@ def forward(self, inputs):
transformed_anchors = self.regressBoxes(anchors, regression)
transformed_anchors = self.clipBoxes(transformed_anchors, inputs)
scores = torch.max(classification, dim=2, keepdim=True)[0]
scores_over_thresh = (scores>0.05)[0, :, 0]
scores_over_thresh = torch.topk(scores, k=3, dim=1)[1][0, :, 0]

# scores_over_thresh = (scores>0.05)[0, :, 0]
if scores_over_thresh.sum() == 0:
print('No boxes to NMS')
# no boxes to NMS, just return
@@ -48,3 +65,8 @@ def forward(self, inputs):
anchors_nms_idx = nms(transformed_anchors[0, :, :], scores[0, :, 0], iou_threshold = 0.5)
nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
def freeze_bn(self):
'''Freeze BatchNorm layers.'''
for layer in self.modules():
if isinstance(layer, nn.BatchNorm2d):
layer.eval()
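To illustrate the top-k fix above with a toy example (not part of the commit): a fixed confidence threshold can select nothing early in training, while torch.topk always hands a constant number of candidate anchors to NMS.

import torch

scores = torch.tensor([[[0.02], [0.90], [0.01], [0.40]]])     # toy (batch, anchors, 1) score tensor

mask = (scores > 0.05)[0, :, 0]                               # old behaviour: boolean mask, may be empty
topk_idx = torch.topk(scores, k=3, dim=1)[1][0, :, 0]         # new behaviour: indices of the 3 best anchors

print(mask.nonzero().flatten())                               # tensor([1, 3])
print(topk_idx)                                               # tensor([1, 3, 0])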
13 changes: 10 additions & 3 deletions train.py
@@ -35,11 +35,11 @@
help='Checkpoint state_dict file to resume training from')
parser.add_argument('--start_iter', default=0, type=int,
help='Resume training at this iter')
parser.add_argument('--num_workers', default=4, type=int,
parser.add_argument('--num_workers', default=12, type=int,
help='Number of workers used in dataloading')
parser.add_argument('--cuda', default=True, type=bool,
help='Use CUDA to train model')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float,
help='Momentum value for optim')
@@ -79,8 +79,11 @@ def train():


model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])

optimizer = optim.AdamW(model.parameters(), lr=args.lr)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

criterion = FocalLoss()
model.train()
iteration = 0
@@ -98,6 +101,7 @@ def train():

loss = classification_loss + regression_loss
if bool(loss == 0):
print('loss equal zero(0)')
continue
optimizer.zero_grad()
loss.backward()
@@ -108,7 +112,10 @@
if(iteration%100==0):
print('Epoch/Iteration: {}/{}, classification: {}, regression: {}, totol_loss: {}'.format(epoch, iteration, classification_loss.item(), regression_loss.item(), np.mean(total_loss)))
iteration+=1
scheduler.step(np.mean(total_loss))
torch.save(model.state_dict(), './weights/checkpoint_{}.pth'.format(epoch))
model.eval()
torch.save(model.state_dict(), './weights/final_weight.pth')

if __name__ == '__main__':
train()

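For context on the scheduler wiring added above, a minimal sketch (not the project's training script): ReduceLROnPlateau receives one metric per epoch through step() and scales the learning rate down after patience epochs without improvement. The tiny model and random batches are placeholders.

import numpy as np
import torch
import torch.optim as optim

model = torch.nn.Linear(10, 4)                                   # stand-in for EfficientDet
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

for epoch in range(5):
    epoch_losses = []
    for _ in range(100):                                         # stand-in for the dataloader loop
        loss = model(torch.randn(8, 10)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    scheduler.step(np.mean(epoch_losses))                        # epoch-level metric, as in train.py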