Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into kitti
Browse files Browse the repository at this point in the history
  • Loading branch information
thangvubk committed Aug 20, 2022
2 parents 49ed66f + 481ec4c commit 4ae2386
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions softgroup/model/softgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance
@force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores'))
def instance_loss(self, cls_scores, mask_scores, iou_scores, proposals_idx, proposals_offset,
instance_labels, instance_pointnum, instance_cls, instance_batch_idxs):
if proposals_idx.size(0) == 0 or (instance_cls != self.ignore_label).sum() == 0:
cls_loss = cls_scores.sum() * 0
mask_loss = mask_scores.sum() * 0
iou_score_loss = iou_scores.sum() * 0
return dict(cls_loss=cls_loss, mask_loss=mask_loss, iou_score_loss=iou_score_loss)

losses = {}
proposals_idx = proposals_idx[:, 1].cuda()
proposals_offset = proposals_offset.cuda()
Expand Down Expand Up @@ -370,8 +376,12 @@ def forward_grouping(self,
if proposals_idx.size(0) > 0:
proposals_idx_list.append(proposals_idx)
proposals_offset_list.append(proposals_offset)
proposals_idx = torch.cat(proposals_idx_list, dim=0)
proposals_offset = torch.cat(proposals_offset_list)
if len(proposals_idx_list) > 0:
proposals_idx = torch.cat(proposals_idx_list, dim=0)
proposals_offset = torch.cat(proposals_offset_list)
else:
proposals_idx = torch.zeros((0, 2), dtype=torch.int32)
proposals_offset = torch.zeros((0, ), dtype=torch.int32)
return proposals_idx, proposals_offset

def forward_instance(self, inst_feats, inst_map):
Expand All @@ -393,6 +403,9 @@ def forward_instance(self, inst_feats, inst_map):
@force_fp32(apply_to=('semantic_scores', 'cls_scores', 'iou_scores', 'mask_scores'))
def get_instances(self, scan_id, proposals_idx, semantic_scores, cls_scores, iou_scores,
mask_scores):
if proposals_idx.size(0) == 0:
return []

num_instances = cls_scores.size(0)
num_points = semantic_scores.size(0)
cls_scores = cls_scores.softmax(1)
Expand Down Expand Up @@ -467,6 +480,17 @@ def clusters_voxelization(self,
scale,
spatial_shape,
rand_quantize=False):
if clusters_idx.size(0) == 0:
# create dummpy tensors
coords = torch.tensor(
[[0, 0, 0, 0], [0, spatial_shape - 1, spatial_shape - 1, spatial_shape - 1]],
dtype=torch.int,
device='cuda')
feats = feats[0:2]
voxelization_feats = spconv.SparseConvTensor(feats, coords, [spatial_shape] * 3, 1)
inp_map = feats.new_zeros((1, ), dtype=torch.long)
return voxelization_feats, inp_map

batch_idx = clusters_idx[:, 0].cuda().long()
c_idxs = clusters_idx[:, 1].cuda()
feats = feats[c_idxs.long()]
Expand Down

0 comments on commit 4ae2386

Please sign in to comment.