Commit e15d73a

a stub for training script
Roman Trusov committed Sep 15, 2017
1 parent aafc66e commit e15d73a
Showing 1 changed file with 88 additions and 206 deletions.
train.py: 294 changes (88 additions, 206 deletions)
@@ -1,207 +1,89 @@

Removed: the previous contents of train.py (the PSPNet model definition):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        return x, x_3


class PSPModule(nn.Module):
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    @staticmethod
    def _make_stage(features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)


class PSPUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        h, w = 2 * x.size(2), 2 * x.size(3)
        p = F.upsample(input=x, size=(h, w), mode='bilinear')
        return self.conv(p)


class PSPNetGradual(nn.Module):
    def __init__(self, n_classes=18, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.feats = ResNet(Bottleneck, [3, 4, 23, 3])
        self.psp = PSPModule(2048, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.LogSoftmax()
        )

    def forward(self, x):
        # ResNet.forward returns (layer4 output, layer3 output); only the former is used here
        f, _ = self.feats(x)
        p = self.psp(f)
        p = self.drop_1(p)

        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)
        p = self.drop_2(p)

        return self.final(p)


class PSPNet(nn.Module):
    def __init__(self, n_classes=18, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.feats = ResNet(Bottleneck, [3, 4, 23, 3])
        self.psp = PSPModule(2048, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.LogSoftmax()
        )

        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        f, class_f = self.feats(x)  # class_f has 1024 channels and is 8x downsampled
        p = self.psp(f)
        p = self.drop_1(p)

        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)
        p = self.drop_2(p)

        auxiliary = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1)).view(-1, 1024)

        return self.final(p), self.classifier(auxiliary)
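
A quick sanity check on the model above, as a minimal sketch (not part of this commit): it assumes the classes are importable from pspnet.py, which the new script below relies on, and it uses the same 2017-era API as this file (Variable, F.upsample), so newer PyTorch versions will only emit deprecation warnings. Because the dilated ResNet downsamples by 8x and the three PSPUpsample stages each upsample by 2x, the input height and width should be divisible by 8.

import torch
from torch.autograd import Variable

from pspnet import PSPNet  # assumption: the classes above live in pspnet.py

net = PSPNet(n_classes=18)
x = Variable(torch.randn(2, 3, 192, 192))  # H and W divisible by 8
seg, cls = net(x)
print(seg.size())  # torch.Size([2, 18, 192, 192]): per-pixel log-probabilities
print(cls.size())  # torch.Size([2, 18]): auxiliary image-level logits

The second output comes from the 1024-channel layer3 features and feeds the auxiliary classification loss that the training stub below weights with --alpha.
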
Added: the new train.py, a stub training script (the model classes are now imported from pspnet.py):

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

from pspnet import PSPNet

import logging
import click
import os
import numpy as np

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)


def weights_log(class_freq):
    # log-based class weights: rare classes get larger weights, normalized to sum to 1
    weights = torch.log1p(1 / class_freq)
    return weights / torch.sum(weights)


def lr_poly(base_lr, epoch, max_epoch, power):
    # "poly" schedule with a 1e-5 floor; e.g. base_lr=0.01, power=0.9 gives roughly 0.0054 halfway through
    return max(0.00001, base_lr * np.power(1. - epoch / max_epoch, power))


def build_network(snapshot):
    epoch = 0
    net = PSPNet()
    net = nn.DataParallel(net)
    if snapshot is not None:
        # snapshots are named "PSPNet_<epoch>", see torch.save() in train()
        _, epoch = os.path.basename(snapshot).split('_')
        epoch = int(epoch)
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    net = net.cuda()
    return net, epoch


@click.command()
@click.option('--data-path', type=str, help='Path to dataset with directories imgs/ maps/')
@click.option('--models-path', type=str, help='Path for storing model snapshots')
@click.option('--snapshot', type=str, default=None, help='Path to pretrained weights')
@click.option('--crop_x', type=int, default=200)
@click.option('--crop_y', type=int, default=300)
@click.option('--batch-size', type=int, default=1)
@click.option('--alpha', type=float, default=5.0, help='Coefficient for classification loss term')
@click.option('--epochs', type=int, default=20, help='Number of training epochs to run')
@click.option('--gpu', type=str, default='0')
@click.option('--start-lr', type=float, default=0.01)
@click.option('--lr-power', type=float, default=0.9)
def train(data_path, models_path, snapshot, crop_x, crop_y, batch_size, alpha, epochs, start_lr, lr_power, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    net, starting_epoch = build_network(snapshot)
    steps = 0

    for epoch in range(starting_epoch, starting_epoch + epochs):

        # You have to load all this stuff by yourself
        # class_weights is simply a 1d normalized Tensor
        # n_images is used to calculate the "poly" LR
        # (a sketch of one possible setup is given after this script)

        loader, class_weights, n_images = None, None, None

        n_images *= epochs
        seg_criterion = nn.NLLLoss2d(weight=class_weights.cuda())
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        for x, y, y_cls in loader:
            steps += batch_size
            lr = lr_poly(start_lr, steps, n_images, lr_power)
            # a fresh optimizer is built every step so that the poly-decayed LR takes effect
            optimizer = optim.Adam(net.parameters(), lr=lr)
            optimizer.zero_grad()
            x = Variable(x).cuda()
            y = Variable(y).cuda()
            y_cls = Variable(y_cls).cuda()
            out, out_cls = net(x)
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(out_cls, y_cls)
            loss = seg_loss + alpha * cls_loss
            epoch_losses.append(loss.data[0])
            logging.info(
                'Step {4}/{5} : Seg loss = {0:0.5f}, Cls loss = {1:0.5f}, Total = {2:0.5f}, LR = {3:0.5f}'.format(
                    seg_loss.data[0], cls_loss.data[0], loss.data[0], lr, steps, n_images))
            loss.backward()
            optimizer.step()
        logging.info('Epoch = {0}, Loss = {1:0.5f}'.format(epoch, np.mean(epoch_losses)))
        torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch)])))


if __name__ == '__main__':
    train()
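
The loader, class_weights and n_images placeholders in train() are left as None on purpose. Below is a minimal sketch of one way to fill them; RandomSegDataset and its random tensors are stand-ins (assumptions, not part of this commit) for real code that would read the imgs/ and maps/ directories and honor the --crop_x/--crop_y options, and weights_log is the helper defined above.

import torch
from torch.utils.data import Dataset, DataLoader


class RandomSegDataset(Dataset):
    """Stand-in dataset yielding (image, per-pixel labels, image-level class vector)."""

    def __init__(self, n_items=16, n_classes=18, crop=(192, 192)):
        self.n_items, self.n_classes, self.crop = n_items, n_classes, crop

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        h, w = self.crop
        x = torch.randn(3, h, w)                               # image tensor
        y = torch.LongTensor(h, w).random_(0, self.n_classes)  # segmentation map for NLLLoss2d
        y_cls = torch.zeros(self.n_classes)                    # multi-hot class presence for BCEWithLogitsLoss
        for c in range(self.n_classes):
            y_cls[c] = float((y == c).sum() > 0)
        return x, y, y_cls


# in place of the None placeholders inside train():
dataset = RandomSegDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)  # or batch_size=batch_size from --batch-size
class_freq = torch.ones(18) / 18             # placeholder per-class pixel frequencies
class_weights = weights_log(class_freq)      # 1-D normalized tensor, as the comment asks for
n_images = len(dataset)                      # drives the "poly" schedule via lr_poly

With real values in place, lr_poly decays the learning rate from --start-lr towards the 1e-5 floor over n_images * epochs processed images, and a snapshot named PSPNet_<epoch> is written to --models-path after every epoch, which is the naming build_network expects when resuming from --snapshot.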
