-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Roman Trusov
committed
Sep 15, 2017
1 parent
aafc66e
commit e15d73a
Showing
1 changed file
with
88 additions
and
206 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,207 +1,89 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import math | ||
|
||
|
||
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3 (optionally dilated), 1x1 expand.

    The block maps ``inplanes`` input channels to ``planes * expansion``
    output channels and adds a shortcut connection before the final ReLU.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch. When ``None`` the input itself is used, so it
            must already have the same shape as the block output.
        dilation: Dilation of the 3x3 convolution.
    """

    # Ratio of output channels to ``planes``; also read by ResNet._make_layer.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        # Derive the output width from the class attribute rather than a
        # literal 4, so the block and whoever builds its shortcut always agree.
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # padding == dilation keeps the 3x3 convolution size-preserving
        # at stride 1 for any dilation rate.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))`` for the input feature map ``x``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input only when a shortcut module was supplied.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
||
|
||
class ResNet(nn.Module):
    """Dilated ResNet trunk that returns the outputs of its last two stages.

    The stem (stride-2 convolution + stride-2 max-pool) and ``layer2``
    (stride 2) reduce the resolution; ``layer3`` and ``layer4`` keep stride 1
    and use dilation 2 and 4 instead.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``
            positionally plus a ``dilation`` keyword.
        layers: Sequence of four ints, the number of blocks in each stage.
    """

    def __init__(self, block, layers):
        # Initialise the Module machinery before touching any attribute.
        super().__init__()
        # Running count of channels entering the next block; updated by
        # _make_layer as stages are built.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        self._init_weights()

    def _init_weights(self):
        """Initialise conv weights from N(0, 2 / fan_out) and batch norms to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        A 1x1 conv + batch-norm shortcut is created for the first block
        whenever the stride or the channel count changes across it.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # NOTE(review): the first block of a stage is built without the
        # stage's dilation; only the following blocks receive it. Kept as-is
        # because changing it alters the network's outputs -- confirm intent.
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes, dilation=dilation)
                      for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(layer4_output, layer3_output)`` for the image batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        return x, x_3
|
||
|
||
class PSPModule(nn.Module):
    """Pyramid pooling: pool at several grid sizes, upsample, concatenate, fuse.

    Args:
        features: Number of channels of the incoming feature map.
        out_features: Number of channels produced by the fusing 1x1 conv.
        sizes: Output grid sizes of the adaptive average pools, one pyramid
            level per entry.
    """

    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        # One pooled branch per pyramid level, registered as submodules.
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        # Every branch keeps ``features`` channels, and the un-pooled input is
        # concatenated as well, hence ``len(sizes) + 1`` groups going in.
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    @staticmethod
    def _make_stage(features, size):
        """Return one level: adaptive average pool to ``size x size``, then a 1x1 conv."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        """Fuse the pyramid levels of ``feats``; spatial size is preserved."""
        h, w = feats.size(2), feats.size(3)
        # Bring every pooled level back to the input resolution so that all
        # branches can be concatenated along the channel axis.
        priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)
|
||
|
||
class PSPUpsample(nn.Module):
    """Double the spatial resolution, then refine with conv + batch norm + PReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        refine = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
        ]
        self.conv = nn.Sequential(*refine)

    def forward(self, x):
        """Bilinearly upsample ``x`` to twice its height and width, then refine it."""
        target = (2 * x.size(2), 2 * x.size(3))
        enlarged = F.upsample(input=x, size=target, mode='bilinear')
        return self.conv(enlarged)
|
||
|
||
class PSPNetGradual(nn.Module):
    """Segmentation-only PSPNet: ResNet trunk, pyramid pooling, three 2x upsampling steps.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        sizes: Pyramid pooling grid sizes forwarded to ``PSPModule``.
    """

    def __init__(self, n_classes=18, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.feats = ResNet(Bottleneck, [3, 4, 23, 3])
        self.psp = PSPModule(2048, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.LogSoftmax()
        )

    def forward(self, x):
        """Return the output of ``self.final`` (1x1 conv + LogSoftmax) for ``x``."""
        # The trunk returns (layer4, layer3) features; only the deepest map is
        # used here. Passing the whole tuple on to the pyramid module would
        # fail as soon as it asks for the feature map's size.
        f, _ = self.feats(x)
        p = self.psp(f)
        p = self.drop_1(p)

        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)
        p = self.drop_2(p)

        return self.final(p)
|
||
|
||
class PSPNet(nn.Module):
    """PSPNet with a segmentation head and an auxiliary image-level classifier.

    Args:
        n_classes: Number of classes for both heads.
        sizes: Pyramid pooling grid sizes forwarded to ``PSPModule``.
    """

    def __init__(self, n_classes=18, sizes=(1, 2, 3, 6)):
        super().__init__()
        # Backbone and pyramid pooling.
        self.feats = ResNet(Bottleneck, [3, 4, 23, 3])
        self.psp = PSPModule(2048, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        # Decoder: three successive 2x upsampling steps.
        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        seg_head = [nn.Conv2d(64, n_classes, kernel_size=1), nn.LogSoftmax()]
        self.final = nn.Sequential(*seg_head)

        # Auxiliary head operating on globally pooled layer3 features.
        cls_head = [nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, n_classes)]
        self.classifier = nn.Sequential(*cls_head)

    def forward(self, x):
        """Return ``(segmentation_output, classification_logits)`` for ``x``."""
        # class_f has 1024 channels and is 8x downsampled
        f, class_f = self.feats(x)
        p = self.drop_1(self.psp(f))

        # Each decoder step is followed by the same light dropout layer.
        for upsample in (self.up_1, self.up_2, self.up_3):
            p = self.drop_2(upsample(p))

        pooled = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1))
        auxiliary = pooled.view(-1, 1024)

        return self.final(p), self.classifier(auxiliary)
from torch import nn | ||
from torch import optim | ||
from torch.autograd import Variable | ||
from torch.utils.data import DataLoader | ||
|
||
from pspnet import PSPNet | ||
|
||
import logging | ||
import click | ||
import os | ||
import numpy as np | ||
|
||
# Timestamp every log record and show INFO-level messages and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
||
|
||
def weights_log(class_freq):
    """Turn class frequencies into normalised weights ``log(1 + 1/freq)``.

    The result is divided by its own sum, so the weights add up to one.
    """
    inverse = 1 / class_freq
    raw = torch.log1p(inverse)
    return raw / raw.sum()
|
||
|
||
def lr_poly(base_lr, epoch, max_epoch, power):
    """Return the "poly" learning rate ``base_lr * (1 - epoch / max_epoch) ** power``.

    The result is floored at ``1e-5``.

    Args:
        base_lr: Learning rate at ``epoch == 0``.
        epoch: Current position in the schedule (same unit as ``max_epoch``).
        max_epoch: Position at which the schedule reaches zero.
        power: Exponent of the polynomial decay.
    """
    # Clamp the remaining fraction at zero. Past ``max_epoch`` a negative base
    # raised to a fractional power is NaN, and raised to an even power grows
    # again; in both cases the schedule should simply stay at the floor.
    remaining = max(0.0, 1. - epoch / max_epoch)
    return max(0.00001, base_lr * np.power(remaining, power))
|
||
|
||
def build_network(snapshot):
    """Create a data-parallel PSPNet on the GPU, optionally restoring a snapshot.

    Args:
        snapshot: Path to a saved state dict whose file name ends in
            ``_<epoch>``, or ``None`` to start from freshly initialised weights.

    Returns:
        A ``(net, epoch)`` tuple; ``epoch`` is the number parsed from the
        snapshot file name, or 0 when no snapshot is given.

    Raises:
        ValueError: If the snapshot file name has no ``_<integer>`` suffix.
    """
    epoch = 0
    net = nn.DataParallel(PSPNet())
    if snapshot is not None:
        # Split on the last underscore only, so file names that themselves
        # contain underscores still yield the trailing epoch number.
        _, epoch_tag = os.path.basename(snapshot).rsplit('_', 1)
        epoch = int(epoch_tag)
        # NOTE(review): torch.load unpickles the file -- only load snapshots
        # from a trusted source.
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    net = net.cuda()
    return net, epoch
|
||
|
||
def _loss_value(loss):
    """Return a one-element loss as a plain number on both old and new torch."""
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


@click.command()
@click.option('--data-path', type=str, help='Path to dataset with directories imgs/ maps/')
@click.option('--models-path', type=str, help='Path for storing model snapshots')
@click.option('--snapshot', type=str, default=None, help='Path to pretrained weights')
@click.option('--crop_x', type=int, default=200)
@click.option('--crop_y', type=int, default=300)
@click.option('--batch-size', type=int, default=1)
@click.option('--alpha', type=float, default=5.0, help='Coefficient for classification loss term')
@click.option('--epochs', type=int, default=20, help='Number of training epochs to run')
@click.option('--gpu', type=str, default='0')
@click.option('--start-lr', type=float, default=0.01)
@click.option('--lr-power', type=float, default=0.9)
def train(data_path, models_path, snapshot, crop_x, crop_y, batch_size, alpha, epochs, start_lr, lr_power, gpu):
    """Train PSPNet with a joint segmentation + classification loss.

    The total loss is ``seg_loss + alpha * cls_loss``; the learning rate
    follows the "poly" schedule over ``n_images * epochs`` samples, and a
    snapshot named ``PSPNet_<epoch>`` is written to ``models_path`` after
    every epoch.

    ``data_path``, ``crop_x`` and ``crop_y`` are accepted but not used by this
    function; data loading is left to the user (see the placeholder below).

    NOTE(review): a resumed run starts at the epoch number parsed from the
    snapshot name and therefore overwrites that epoch's snapshot -- confirm
    whether resuming at the following epoch is intended.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if models_path:
        # Make sure the snapshot directory exists before any training work.
        os.makedirs(models_path, exist_ok=True)
    net, starting_epoch = build_network(snapshot)
    # A single optimizer for the whole run: re-creating Adam on every step
    # would throw away its running moment estimates each time.
    optimizer = optim.Adam(net.parameters(), lr=start_lr)
    steps = 0

    for epoch in range(starting_epoch, starting_epoch + epochs):

        # You have to load all this stuff by yourself
        # class_weights is simply a 1d normalized Tensor
        # n_images is used to calculate the "poly" LR

        loader, class_weights, n_images = None, None, None
        if loader is None or class_weights is None or n_images is None:
            raise NotImplementedError(
                'Data loading is not implemented: provide loader, class_weights and n_images')

        total_images = n_images * epochs
        # Both criteria receive CUDA inputs, so the weights must live there too.
        class_weights = class_weights.cuda()
        seg_criterion = nn.NLLLoss2d(weight=class_weights)
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        for x, y, y_cls in loader:
            steps += batch_size
            lr = lr_poly(start_lr, steps, total_images, lr_power)
            for group in optimizer.param_groups:
                group['lr'] = lr
            optimizer.zero_grad()
            x = Variable(x).cuda()
            y = Variable(y).cuda()
            y_cls = Variable(y_cls).cuda()
            out, out_cls = net(x)
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(out_cls, y_cls)
            loss = seg_loss + alpha * cls_loss
            loss_value = _loss_value(loss)
            # Recorded so the per-epoch mean below has something to average.
            epoch_losses.append(loss_value)
            logging.info(
                'Step {4}/{5} : Seg loss = {0:0.5f}, Cls loss = {1:0.5f}, Total = {2:0.5f}, LR = {3:0.5f}'.format(
                    _loss_value(seg_loss), _loss_value(cls_loss), loss_value, lr, steps, total_images))
            loss.backward()
            optimizer.step()
        logging.info('Epoch = {0}, Loss = {1:0.5f}'.format(epoch, np.mean(epoch_losses)))
        torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch)])))
|
||
|
||
# Script entry point: hand control to the click command, which parses argv.
if __name__ == "__main__":
    train()