Skip to content

Commit

Permalink
Corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
HuguesTHOMAS committed Apr 29, 2020
1 parent 342abc4 commit 99f0c8b
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions utils/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug

return

def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=True):
"""
Test method for slam segmentation models
"""
Expand Down Expand Up @@ -502,12 +502,13 @@ def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=
if test_loader.dataset.set == 'validation':

# Insert false columns for ignored labels
frame_probs_uint8_bis = frame_probs_uint8.copy()
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
if label_value in test_loader.dataset.ignored_labels:
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
frame_probs_uint8_bis = np.insert(frame_probs_uint8_bis, l_ind, 0, axis=1)

# Predicted labels
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8_bis,
axis=1)].astype(np.int32)

# Save some of the frame pots
Expand All @@ -528,6 +529,15 @@ def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=
[frame_points[:, :3], frame_labels, frame_preds],
['x', 'y', 'z', 'gt', 'pre'])

# Also Save lbl probabilities
probpath = join(test_path, folder, filename[:-4] + '_probs.ply')
lbl_names = [test_loader.dataset.label_to_names[l]
for l in test_loader.dataset.label_values
if l not in test_loader.dataset.ignored_labels]
write_ply(probpath,
[frame_points[:, :3], frame_probs_uint8],
['x', 'y', 'z'] + lbl_names)

# keep frame preds in memory
all_f_preds[s_ind][f_ind] = frame_preds
all_f_labels[s_ind][f_ind] = frame_labels
Expand Down Expand Up @@ -575,8 +585,8 @@ def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=
last_display = t[-1]
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item()
current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num
pot_num = torch.sum(test_loader.dataset.potentials > min_pot + 0.5).type(torch.int32).item()
current_num = pot_num + (i + 1 - config.validation_size) * config.val_batch_num
print(message.format(test_epoch, i,
100 * i / config.validation_size,
1000 * (mean_dt[0]),
Expand Down

0 comments on commit 99f0c8b

Please sign in to comment.