Skip to content

Commit

Permalink
semantic_kitti_metric.py support to print best acc/iou (PaddlePaddle#391
Browse files Browse the repository at this point in the history
)
  • Loading branch information
zhangyk0314 committed Jul 19, 2023
1 parent fe09257 commit e5d0dc6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 4 additions & 3 deletions paddle3d/apis/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
self.eval_dataloader = _dataloader_build_fn(
val_dataset, self.model) if val_dataset else None
self.val_dataset = val_dataset
self.eval_metric_obj = self.val_dataset.metric

self.profiler_options = profiler_options
self.resume = resume
Expand Down Expand Up @@ -471,11 +472,11 @@ def evaluate(self) -> float:
if self.val_dataset is None:
raise RuntimeError('No evaluation dataset specified!')
msg = 'evaluate on validate dataset'
metric_obj = self.val_dataset.metric

for idx, sample in self.logger.enumerate(self.eval_dataloader, msg=msg):
result = validation_step(self.model, sample)
metric_obj.update(predictions=result, ground_truths=sample)
self.eval_metric_obj.update(
predictions=result, ground_truths=sample)

metrics = metric_obj.compute(verbose=True)
metrics = self.eval_metric_obj.compute(verbose=True)
return metrics
10 changes: 10 additions & 0 deletions paddle3d/datasets/semantic_kitti/semantic_kitti_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, num_classes: int, ignore: List[int] = None):
[n for n in range(self.num_classes) if n not in ignore],
dtype="int64")

self.best_acc_avg = 0
self.best_iou_avg = 0

# reset the class counters
self.reset()

Expand Down Expand Up @@ -115,6 +118,10 @@ def getacc(self):
def compute(self, verbose=False) -> dict:
m_accuracy = self.getacc()
m_jaccard, class_jaccard = self.getIoU()
if m_accuracy > self.best_acc_avg:
self.best_acc_avg = m_accuracy
if m_jaccard > self.best_iou_avg:
self.best_iou_avg = m_jaccard

if verbose:
logger.info("Acc avg {:.3f}".format(float(m_accuracy)))
Expand All @@ -129,5 +136,8 @@ def compute(self, verbose=False) -> dict:
SemanticKITTIDataset.LEARNING_MAP_INV[i]],
jacc=float(jacc)))

logger.info("Best Acc avg {:.3f}".format(float(self.best_acc_avg)))
logger.info("Best IoU avg {:.3f}".format(float(self.best_iou_avg)))

return dict(
mean_acc=m_accuracy, mean_iou=m_jaccard, class_iou=class_jaccard)

0 comments on commit e5d0dc6

Please sign in to comment.