From 5a0860220f672d3cb15e28487bcc5bfdb9ba7101 Mon Sep 17 00:00:00 2001 From: "v.toandm2" Date: Sat, 7 Dec 2019 01:04:23 +0700 Subject: [PATCH] Change learning rate, freeze bn, add demo, fix topk --- datasets/augmentation.py | 5 ++++ demo.py | 63 ++++++++++++++++------------------------ models/efficientdet.py | 28 ++++++++++++++++-- train.py | 13 +++++++-- 4 files changed, 65 insertions(+), 44 deletions(-) diff --git a/datasets/augmentation.py b/datasets/augmentation.py index 80cdf81..fa3da16 100644 --- a/datasets/augmentation.py +++ b/datasets/augmentation.py @@ -25,9 +25,14 @@ def get_augumentation(phase, width=512, height=512, min_area=0., min_visibility= list_transforms.extend([ albu.CenterCrop(p=0.2, height=height, width=width) ]) + if(phase == 'show'): + return albu.Compose(list_transforms) + list_transforms.extend([ ToTensor() ]) + if(phase=='test'): + return albu.Compose(list_transforms) return albu.Compose(list_transforms, bbox_params=albu.BboxParams(format='pascal_voc', min_area=min_area, min_visibility=min_visibility, label_fields=['category_id'])) diff --git a/demo.py b/demo.py index 906f858..c877eaf 100644 --- a/demo.py +++ b/demo.py @@ -5,30 +5,7 @@ from torchvision import transforms import numpy as np import skimage -class Resizer(object): - """Convert ndarrays in sample to Tensors.""" - - def __call__(self, image, side=512): - rows, cols, cns = image.shape - - scale = float(side)/float(max(rows, cols)) - # resize the image with the computed scale - image = skimage.transform.resize(image, (int(round(rows*scale)), int(round((cols*scale))))) - rows, cols, cns = image.shape - - pad_w = side-rows - pad_h = side-cols - new_image = np.zeros((rows + pad_w, cols + pad_h, cns)).astype(np.float32) - new_image[:rows, :cols, :] = image.astype(np.float32) - return torch.from_numpy(new_image) -class Normalizer(object): - - def __init__(self): - self.mean = np.array([[[0.485, 0.456, 0.406]]]) - self.std = np.array([[[0.229, 0.224, 0.225]]]) - - def __call__(self, image): - return (image.astype(np.float32)-self.mean)/self.std +from datasets import get_augumentation class Detect(object): """ @@ -38,10 +15,8 @@ def __init__(self, weights, num_class=21): super(Detect, self).__init__() self.weights = weights self.device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') - self.transform = transforms.Compose([ - Normalizer(), - Resizer() - ]) + self.transform = get_augumentation(phase='test') + self.show_transform = get_augumentation(phase='show') self.model = EfficientDet(num_classes=num_class, is_training=False) self.model = self.model.to(self.device) if(self.weights is not None): @@ -53,17 +28,29 @@ def __init__(self, weights, num_class=21): def process(self, file_name): img = cv2.imread(file_name) - cv2.imwrite('kaka.png', img) - img = self.transform(img) + + show_aug = self.show_transform(image = img) + show_image = show_aug['image'] + augmentation = self.transform(image = img) + img = augmentation['image'] img = img.to(self.device) - img = img.unsqueeze(0).permute(0, 3, 1, 2) - scores, classification, transformed_anchors = self.model(img) - print('scores: ', scores) - scores = scores.detach().cpu().numpy() - idxs = np.where(scores>0.1) - return idxs + img = img.unsqueeze(0) + + with torch.no_grad(): + scores, classification, transformed_anchors = self.model(img) + for i in range(transformed_anchors.size(0)): + bbox = transformed_anchors[i, :] + x1 = int(bbox[0]) + y1 = int(bbox[1]) + x2 = int(bbox[2]) + y2 = int(bbox[3]) + print(x1, x2, y1, y2) + color = (255, 0, 0) + thickness = 2 + cv2.rectangle(show_image, (x1, y1), (x2, y2), color, thickness) + cv2.imwrite('output.png', show_image) if __name__=='__main__': - detect = Detect(weights = './weights/checkpoint_87.pth') - output = detect.process('/root/data/VOCdevkit/VOC2007/JPEGImages/001234.jpg') + detect = Detect(weights = './weights/checkpoint_100.pth') + output = detect.process('/root/data/VOCdevkit/VOC2007/JPEGImages/003476.jpg') print('output: ', output) diff --git a/models/efficientdet.py b/models/efficientdet.py index 09c34d6..fc01b7d 100644 --- a/models/efficientdet.py +++ b/models/efficientdet.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn - +import math from models.efficientnet import EfficientNet from models.bifpn_v2 import BIFPN from models.module import RegressionModel, ClassificationModel, Anchors, ClipBoxes, BBoxTransform @@ -24,7 +24,22 @@ def __init__(self, self.anchors = Anchors() self.regressBoxes = BBoxTransform() self.clipBoxes = ClipBoxes() - + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + prior = 0.01 + + self.classificationModel.output.weight.data.fill_(0) + self.classificationModel.output.bias.data.fill_(-math.log((1.0-prior)/prior)) + self.regressionModel.output.weight.data.fill_(0) + self.regressionModel.output.bias.data.fill_(0) + self.freeze_bn() + def forward(self, inputs): features = self.efficientnet(inputs) features = self.BIFPN(features[2:]) @@ -37,7 +52,9 @@ def forward(self, inputs): transformed_anchors = self.regressBoxes(anchors, regression) transformed_anchors = self.clipBoxes(transformed_anchors, inputs) scores = torch.max(classification, dim=2, keepdim=True)[0] - scores_over_thresh = (scores>0.05)[0, :, 0] + scores_over_thresh = torch.topk(scores, k=3, dim=1)[1][0, :, 0] + + # scores_over_thresh = (scores>0.05)[0, :, 0] if scores_over_thresh.sum() == 0: print('No boxes to NMS') # no boxes to NMS, just return @@ -48,3 +65,8 @@ def forward(self, inputs): anchors_nms_idx = nms(transformed_anchors[0, :, :], scores[0, :, 0], iou_threshold = 0.5) nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1) return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]] + def freeze_bn(self): + '''Freeze BatchNorm layers.''' + for layer in self.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.eval() \ No newline at end of file diff --git a/train.py b/train.py index a36f06f..9698432 100644 --- a/train.py +++ b/train.py @@ -35,11 +35,11 @@ help='Checkpoint state_dict file to resume training from') parser.add_argument('--start_iter', default=0, type=int, help='Resume training at this iter') -parser.add_argument('--num_workers', default=4, type=int, +parser.add_argument('--num_workers', default=12, type=int, help='Number of workers used in dataloading') parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model') -parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, +parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, help='Momentum value for optim') @@ -79,8 +79,11 @@ def train(): model = model.cuda() + model = torch.nn.DataParallel(model, device_ids=[0, 1]) - optimizer = optim.AdamW(model.parameters(), lr=args.lr) + optimizer = optim.Adam(model.parameters(), lr=args.lr) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True) + criterion = FocalLoss() model.train() iteration = 0 @@ -98,6 +101,7 @@ def train(): loss = classification_loss + regression_loss if bool(loss == 0): + print('loss equal zero(0)') continue optimizer.zero_grad() loss.backward() @@ -108,7 +112,10 @@ def train(): if(iteration%100==0): print('Epoch/Iteration: {}/{}, classification: {}, regression: {}, totol_loss: {}'.format(epoch, iteration, classification_loss.item(), regression_loss.item(), np.mean(total_loss))) iteration+=1 + scheduler.step(np.mean(total_loss)) torch.save(model.state_dict(), './weights/checkpoint_{}.pth'.format(epoch)) + model.eval() + torch.save(model.state_dict(), './weights/final_weight.pth') if __name__ == '__main__': train()