Skip to content

Commit

Permalink
[SPARK-32457][ML] logParam thresholds in DT/GBT/FM/LR/MLP
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
logParam `thresholds` in DT/GBT/FM/LR/MLP

### Why are the changes needed?
param `thresholds` is logged in NB/RF, but not in other ProbabilisticClassifier

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
existing testsuites

Closes apache#29257 from zhengruifeng/instr.logParams_add_thresholds.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Huaxin Gao <[email protected]>
  • Loading branch information
zhengruifeng authored and huaxingao committed Jul 27, 2020
1 parent c114066 commit f7542d3
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed)
maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds)

val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,7 @@ class FMClassifier @Since("3.0.0") (
def setSeed(value: Long): this.type = set(seed, value)

override protected def train(
dataset: Dataset[_]
): FMClassificationModel = instrumented { instr =>

dataset: Dataset[_]): FMClassificationModel = instrumented { instr =>
val numClasses = 2
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
Expand All @@ -190,7 +188,7 @@ class FMClassifier @Since("3.0.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
miniBatchFraction, initStd, maxIter, stepSize, tol, solver)
miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
instr.logNumClasses(numClasses)

val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class GBTClassifier @Since("1.4.0") (
instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol)
checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol, thresholds)
instr.logNumClasses(numClasses)

val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ class LogisticRegression @Since("1.2.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
fitIntercept, blockSize)
probabilityCol, regParam, elasticNetParam, standardization, threshold, thresholds, maxIter,
tol, fitIntercept, blockSize)

val instances = extractInstances(dataset)
.setName("training instances")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, layers, maxIter,
tol, blockSize, solver, stepSize, seed)
tol, blockSize, solver, stepSize, seed, thresholds)

val myLayers = $(layers)
val labels = myLayers.last
Expand Down

0 comments on commit f7542d3

Please sign in to comment.