Skip to content

Commit

Permalink
Bug fix in YOLOLayer.forward
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren committed Jun 1, 2018
1 parent a606b62 commit b7cb7c4
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 78 deletions.
8 changes: 4 additions & 4 deletions config/yolov3.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[net]
# Testing
batch=1
subdivisions=1
# Training
#batch=16
#batch=1
#subdivisions=1
# Training
batch=16
subdivisions=1
width=416
height=416
channels=3
Expand Down
86 changes: 38 additions & 48 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def __init__(self, anchors, num_classes, image_dim):
self.class_scale = 1
self.seen = 0

self.mse_loss = nn.MSELoss(size_average=False)
self.ce_loss = nn.CrossEntropyLoss(size_average=False)
self.mse_loss = nn.MSELoss()
self.bce_loss = nn.BCELoss()
self.bce_logits_loss = nn.BCEWithLogitsLoss()

def forward(self, x, targets=None):
bs = x.size(0)
Expand All @@ -101,57 +102,48 @@ def forward(self, x, targets=None):
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

if x.is_cuda:
self.mse_loss = self.mse_loss.cuda()
self.ce_loss = self.ce_loss.cuda()

prediction = x.view(bs, self.bbox_attrs * self.num_anchors, g_dim * g_dim)
prediction = prediction.transpose(1, 2).contiguous()
prediction = prediction.view(bs, g_dim * g_dim * self.num_anchors, self.bbox_attrs)
prediction = x.view(bs, self.num_anchors, self.bbox_attrs, g_dim, g_dim).permute(0, 1, 3, 4, 2).contiguous()

# Get outputs
x = torch.sigmoid(prediction[:, :, 0]) # Center x
y = torch.sigmoid(prediction[:, :, 1]) # Center y
w = prediction[:, :, 2] # Width
h = prediction[:, :, 3] # Height
conf = torch.sigmoid(prediction[:, :, 4]) # Conf
pred_cls = torch.sigmoid(prediction[:, :, 5:]) # Cls pred.

# Get x and y offsets for each grid
grid = np.arange(g_dim)
a, b = np.meshgrid(grid, grid)
x_offset = FloatTensor(a).view(-1, 1)
y_offset = FloatTensor(b).view(-1, 1)
x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1, self.num_anchors).view(-1, 2).unsqueeze(0)

# Scale anchors
x = torch.sigmoid(prediction[..., 0]) # Center x
y = torch.sigmoid(prediction[..., 1]) # Center y
w = prediction[..., 2] # Width
h = prediction[..., 3] # Height
conf = torch.sigmoid(prediction[..., 4]) # Conf
pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.

# Calculate offsets for each grid
grid_x = torch.linspace(0, g_dim-1, g_dim).repeat(g_dim,1).repeat(bs*self.num_anchors, 1, 1).view(x.shape).type(FloatTensor)
grid_y = torch.linspace(0, g_dim-1, g_dim).repeat(g_dim,1).t().repeat(bs*self.num_anchors, 1, 1).view(y.shape).type(FloatTensor)
scaled_anchors = [(a_w / stride, a_h / stride) for a_w, a_h in self.anchors]
anchors = FloatTensor(scaled_anchors).repeat(g_dim * g_dim, 1).unsqueeze(0)
anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, g_dim*g_dim).view(w.shape)
anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, g_dim*g_dim).view(h.shape)

# Add offset and scale with anchors
pred_boxes = FloatTensor(prediction[:, :, :4].shape)
pred_boxes[:, :, 0] = x.data + x_y_offset[:, :, 0]
pred_boxes[:, :, 1] = y.data + x_y_offset[:, :, 1]
pred_boxes[:, :, 2] = torch.exp(w.data) * anchors[:, :, 0]
pred_boxes[:, :, 3] = torch.exp(h.data) * anchors[:, :, 1]
pred_boxes = FloatTensor(prediction[..., :4].shape)
pred_boxes[..., 0] = x.data + grid_x
pred_boxes[..., 1] = y.data + grid_y
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * anchor_h

self.seen += prediction.size(0)

# Training
if targets is not None:

nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = build_targets(pred_boxes[:, :, :4].cpu().data,
if x.is_cuda:
self.mse_loss = self.mse_loss.cuda()
self.bce_loss = self.bce_loss.cuda()

nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = build_targets(pred_boxes.cpu().data,
targets.cpu().data,
scaled_anchors,
self.num_anchors,
self.num_classes,
g_dim,
g_dim,
self.noobject_scale,
self.object_scale,
self.ignore_thres,
self.image_dim,
self.seen)
self.ignore_thres)


nProposals = int((conf > 0.25).sum().item())
Expand All @@ -161,18 +153,16 @@ def forward(self, x, targets=None):
tw = Variable(tw.type(FloatTensor), requires_grad=False)
th = Variable(th.type(FloatTensor), requires_grad=False)
tconf = Variable(tconf.type(FloatTensor), requires_grad=False)
tcls = Variable(tcls[cls_mask == 1].type(LongTensor), requires_grad=False)
tcls = Variable(tcls[cls_mask == 1].type(FloatTensor), requires_grad=False)
coord_mask = Variable(coord_mask.type(FloatTensor), requires_grad=False)
conf_mask = Variable(conf_mask.type(FloatTensor), requires_grad=False)

pred_cls = pred_cls[cls_mask.view(bs, -1) == 1]

loss_x = self.coord_scale * self.mse_loss(x.view_as(tx)*coord_mask, tx*coord_mask) / 2
loss_y = self.coord_scale * self.mse_loss(y.view_as(ty)*coord_mask, ty*coord_mask) / 2
loss_w = self.coord_scale * self.mse_loss(w.view_as(tw)*coord_mask, tw*coord_mask) / 2
loss_h = self.coord_scale * self.mse_loss(h.view_as(th)*coord_mask, th*coord_mask) / 2
loss_conf = self.mse_loss(conf.view_as(tconf)*conf_mask, tconf*conf_mask)
loss_cls = self.class_scale * self.ce_loss(pred_cls, tcls)
loss_x = self.coord_scale * self.mse_loss(x[coord_mask == 1], tx[coord_mask == 1]) / 2
loss_y = self.coord_scale * self.mse_loss(y[coord_mask == 1], ty[coord_mask == 1]) / 2
loss_w = self.coord_scale * self.mse_loss(w[coord_mask == 1], tw[coord_mask == 1]) / 2
loss_h = self.coord_scale * self.mse_loss(h[coord_mask == 1], th[coord_mask == 1]) / 2
loss_conf = self.bce_loss(conf_mask[conf_mask == 1], tconf[conf_mask == 1])
loss_cls = self.class_scale * self.bce_loss(pred_cls[cls_mask == 1], tcls)
loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

print('%d: nGT %d, recall %d, AP %.2f%% proposals %d, loss: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, 100*float(nCorrect/nGT), nProposals, loss_x.item(), loss_y.item(), loss_w.item(), loss_h.item(), loss_conf.item(), loss_cls.item(), loss.item()))
Expand All @@ -181,7 +171,7 @@ def forward(self, x, targets=None):

else:
# If not in training phase return predictions
output = torch.cat((pred_boxes * stride, conf.unsqueeze(-1), pred_cls), -1)
output = torch.cat((pred_boxes.view(bs, -1, 4) * stride, conf.view(bs, -1, 1), pred_cls.view(bs, -1, self.num_classes)), -1)
return output


Expand Down Expand Up @@ -300,4 +290,4 @@ def save_weights(self, path, cutoff=-1):
# Load conv weights
conv_layer.weight.data.cpu().numpy().tofile(fp)

fp.close()
fp.close()
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

# Initiate model
model = Darknet(opt.model_config_path)
model.load_weights(opt.weights_path)
#model.load_weights(opt.weights_path)
#model.apply(weights_init_normal)

if cuda:
Expand Down
54 changes: 29 additions & 25 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,21 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):

return output

def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, img_dim, seen):
def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, grid_dim, ignore_thres):
nB = target.size(0)
nA = num_anchors
nC = num_classes
dim = grid_dim
anchor_step = len(anchors)/num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx = torch.zeros(nB, nA, nH, nW)
ty = torch.zeros(nB, nA, nH, nW)
tw = torch.zeros(nB, nA, nH, nW)
th = torch.zeros(nB, nA, nH, nW)
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)

pred_boxes = pred_boxes.view(nB, nH, nW, nA, -1)
conf_mask = torch.ones(nB, nA, dim, dim)
coord_mask = torch.zeros(nB, nA, dim, dim)
cls_mask = torch.zeros(nB, nA, dim, dim)
tx = torch.zeros(nB, nA, dim, dim)
ty = torch.zeros(nB, nA, dim, dim)
tw = torch.zeros(nB, nA, dim, dim)
th = torch.zeros(nB, nA, dim, dim)
tconf = torch.zeros(nB, nA, dim, dim)
tcls = torch.zeros(nB, nA, dim, dim, num_classes)

for b in range(nB):
# Get sample predictions
Expand All @@ -136,14 +135,14 @@ def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW,
if target[b, t, 1] == 0:
break
# Scale the target box coordinates by the grid size
gx = target[b, t, 1] * nW
gy = target[b, t, 2] * nH
gw = target[b, t, 3] * nW
gh = target[b, t, 4] * nH
gx = target[b, t, 1] * dim
gy = target[b, t, 2] * dim
gw = target[b, t, 3] * dim
gh = target[b, t, 4] * dim
cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0)
cur_ious = torch.max(cur_ious, bbox_iou(cur_pred_boxes.data, cur_gt_boxes.data, x1y1x2y2=False))
# Zero the confidence mask wherever the best IoU with a ground-truth box exceeds the threshold
conf_mask[b][cur_ious.view_as(conf_mask[b])>sil_thresh] = 0
conf_mask[b][cur_ious.view_as(conf_mask[b]) > ignore_thres] = 0

nGT = 0
nCorrect = 0
Expand All @@ -153,10 +152,10 @@ def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW,
continue
nGT = nGT + 1
# Scale the target box coordinates by the grid size
gx = target[b, t, 1] * nW
gy = target[b, t, 2] * nH
gw = target[b, t, 3] * nW
gh = target[b, t, 4] * nH
gx = target[b, t, 1] * dim
gy = target[b, t, 2] * dim
gw = target[b, t, 3] * dim
gh = target[b, t, 4] * dim
# Get grid box indices
gi = int(gx)
gj = int(gy)
Expand All @@ -171,23 +170,28 @@ def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW,
best_iou = anch_ious[best_n]
# Get the ground truth box and corresponding best prediction
gt_box = torch.FloatTensor(np.array([gx, gy, gw, gh])).unsqueeze(0)
pred_box = pred_boxes[b, gj, gi, best_n].unsqueeze(0)
pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0)

# Masks
coord_mask[b][best_n][gj][gi] = 1
cls_mask[b][best_n][gj][gi] = 1
conf_mask[b][best_n][gj][gi] = object_scale
conf_mask[b][best_n][gj][gi] = 1
# Coordinates
tx[b][best_n][gj][gi] = gx - gi
ty[b][best_n][gj][gi] = gy - gj
# Width and height
tw[b][best_n][gj][gi] = math.log(gw/anchors[best_n][0] + 1e-8)
th[b][best_n][gj][gi] = math.log(gh/anchors[best_n][1] + 1e-8)
# Calculate iou between ground truth and best matching prediction
iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False) # best_iou # gw * img_dim / a_w * g_dim
iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False)
tconf[b][best_n][gj][gi] = iou
tcls[b][best_n][gj][gi] = target[b, t, 0]
tcls[b][best_n][gj][gi] = to_categorical(int(target[b, t, 0]), num_classes)

if iou > 0.5:
nCorrect = nCorrect + 1

return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls

def to_categorical(y, num_classes):
    """One-hot encode class index (or indices) ``y``.

    Args:
        y: Class index used to select a row of the identity matrix. The
            caller in this file passes a plain ``int``; anything NumPy
            accepts as an index into the first axis (e.g. a list of ints)
            selects one row per index.
        num_classes: Length of each one-hot vector.

    Returns:
        A ``torch`` tensor of dtype uint8 holding row(s) ``y`` of the
        ``num_classes x num_classes`` identity matrix, i.e. a vector with a
        single 1 at position ``y`` for a scalar index.

    Raises:
        IndexError: If ``y`` is out of range for ``num_classes``.
    """
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    identity = np.eye(num_classes, dtype='uint8')
    return torch.from_numpy(identity[y])

0 comments on commit b7cb7c4

Please sign in to comment.