Skip to content

Commit

Permalink
add filter method for PredictionClust (#16)
Browse files Browse the repository at this point in the history
* add filter

* fix unit test

* update required packages

Co-authored-by: Michel Lang <[email protected]>
  • Loading branch information
be-marc and mllg committed Feb 14, 2022
1 parent 2e78d94 commit de94e10
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ URL: https://mlr3cluster.mlr-org.com,
BugReports: https://github.com/mlr-org/mlr3cluster/issues
Depends:
R (>= 3.1.0),
mlr3 (>= 0.10.0)
mlr3 (>= 0.13.0)
Imports:
backports (>= 1.1.10),
checkmate,
clue,
clusterCrit,
data.table,
mlr3misc (>= 0.4.0),
mlr3misc (>= 0.9.4),
paradox,
R6,
stats
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(as_task_clust,data.frame)
S3method(as_task_clust,formula)
S3method(c,PredictionDataClust)
S3method(check_prediction_data,PredictionDataClust)
S3method(filter_prediction_data,PredictionDataClust)
S3method(is_missing_prediction_data,PredictionDataClust)
export(LearnerClust)
export(LearnerClustAP)
Expand Down
16 changes: 16 additions & 0 deletions R/PredictionDataClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,19 @@ c.PredictionDataClust = function(..., keep_duplicates = TRUE) {

set_class(result, "PredictionDataClust")
}

#' @export
filter_prediction_data.PredictionDataClust = function(pdata, row_ids) {
keep = pdata$row_ids %in% row_ids
pdata$row_ids = pdata$row_ids[keep]

if (!is.null(pdata$partition)) {
pdata$partition = pdata$partition[keep]
}

if (!is.null(pdata$prob)) {
pdata$prob = pdata$prob[keep,, drop = FALSE]
}

pdata
}
12 changes: 12 additions & 0 deletions tests/testthat/test_PredictionClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,15 @@ test_that("Internally constructed Prediction", {
expect_prediction(p)
expect_prediction_clust(p)
})

test_that("filter works", {
task = tsk("usarrests")
lrn = mlr_learners$get("clust.featureless")
lrn$param_set$values = list(num_clusters = 1L)
p = lrn$train(task)$predict(task)
pdata = p$data

pdata = filter_prediction_data(pdata, row_ids = 1:3)
expect_set_equal(pdata$row_ids, 1:3)
expect_integer(pdata$partition, len = 3)
})

0 comments on commit de94e10

Please sign in to comment.