forked from apache/mahout
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAHOUT-1976 Canopy Clustering closes apache#314
- Loading branch information
1 parent
9b4eabb
commit c29496c
Showing
15 changed files
with
604 additions
and
2 deletions.
There are no files selected for viewing
26 changes: 26 additions & 0 deletions
26
flink/src/test/scala/org/apache/mahout/flinkbindings/standard/ClusteringSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mahout.flinkbindings.standard | ||
|
||
import org.apache.mahout.flinkbindings.DistributedFlinkSuite | ||
import org.apache.mahout.math.algorithms.ClusteringSuiteBase | ||
import org.scalatest.FunSuite | ||
|
||
/**
 * Flink-bindings clustering suite: mixes the test cases declared in
 * [[org.apache.mahout.math.algorithms.ClusteringSuiteBase]] into a `FunSuite`
 * together with `DistributedFlinkSuite`.
 */
class ClusteringSuite
  extends FunSuite
    with DistributedFlinkSuite
    with ClusteringSuiteBase
|
25 changes: 25 additions & 0 deletions
25
h2o/src/test/scala/org/apache/mahout/math/algorithms/ClusteringSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mahout.math.algorithms | ||
|
||
import org.apache.mahout.h2obindings.test.DistributedH2OSuite | ||
import org.scalatest.FunSuite | ||
|
||
/**
 * H2O-bindings clustering suite: mixes the test cases declared in
 * [[ClusteringSuiteBase]] into a `FunSuite` together with `DistributedH2OSuite`.
 */
class ClusteringSuite
  extends FunSuite
    with DistributedH2OSuite
    with ClusteringSuiteBase
|
158 changes: 158 additions & 0 deletions
158
math-scala/src/main/scala/org/apache/mahout/math/algorithms/clustering/Canopy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/** | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.mahout.math.algorithms.clustering | ||
|
||
|
||
|
||
import org.apache.mahout.math.algorithms.common.distance.{DistanceMetric, DistanceMetricSelector} | ||
import org.apache.mahout.math._ | ||
import org.apache.mahout.math.drm._ | ||
import org.apache.mahout.math.drm.RLikeDrmOps._ | ||
import org.apache.mahout.math.function.VectorFunction | ||
import org.apache.mahout.math.scalabindings._ | ||
import org.apache.mahout.math.scalabindings.RLikeOps._ | ||
import org.apache.mahout.math.{Matrix, Vector} | ||
|
||
|
||
/**
 * Clustering model holding a set of canopy centers. [[cluster]] labels every input row
 * with the row index of the closest center.
 *
 * @param canopies matrix of canopy centers, one center per row
 * @param dm       name of the distance metric; must be a key of
 *                 `DistanceMetricSelector.namedMetricLookup`
 */
class CanopyClusteringModel(canopies: Matrix, dm: Symbol) extends ClusteringModel {

  val canopyCenters = canopies
  val distanceMetric = dm

  /**
   * Assigns each row of `input` to its nearest canopy center.
   *
   * @param input rows to label
   * @return a single-column DRM with the same keys as `input`; each value is the row index
   *         (in `canopyCenters`) of the nearest center, or -1 when no center compares as
   *         closer than infinity (no centers at all, or every distance is NaN)
   */
  def cluster[K](input: DrmLike[K]): DrmLike[K] = {

    implicit val ctx = input.context
    implicit val ktag = input.keyClassTag

    val bcCanopies = drmBroadcast(canopyCenters)
    // The metric travels as its numeric code inside a one-element vector and is resolved
    // back to a DistanceMetric inside the block function.
    val bcDM = drmBroadcast(dvec(DistanceMetricSelector.namedMetricLookup(distanceMetric)))

    input.mapBlock(1) {
      case (keys, block: Matrix) => {
        val outputMatrix = new DenseMatrix(block.nrow, 1)
        val localCanopies: Matrix = bcCanopies.value
        // Loop-invariant: resolve the metric once per block rather than once per row.
        val metric = DistanceMetricSelector.select(bcDM.value.get(0))

        for (i <- 0 until block.nrow) {
          // Start from +infinity (not a large magic constant) so that arbitrarily large
          // distances still select a center.
          val nearest = (0 until localCanopies.nrow)
            .foldLeft((-1, Double.PositiveInfinity)) { (best, r) =>
              val dist = metric.distance(localCanopies(r, ::), block(i, ::))
              if (dist < best._2) (r, dist) else best
            }
          outputMatrix(i, ::) = dvec(nearest._1)
        }
        keys -> outputMatrix
      }
    }
  }
}
|
||
|
||
/**
 * Fitter for canopy clustering.
 *
 * Hyperparameters (all optional):
 *  - `'t1` loose distance used when building canopies per block (default 0.5)
 *  - `'t2` tight distance used when building canopies per block (default 0.1)
 *  - `'t3` loose distance used when merging canopies across blocks (default: `t1`)
 *  - `'t4` tight distance used when merging canopies across blocks (default: `t2`)
 *  - `'distanceMeasure` metric name, a key of `DistanceMetricSelector.namedMetricLookup`
 *    (default `'Cosine`)
 */
class CanopyClustering extends ClusteringFitter {

  var t1: Double = _ // loose distance
  var t2: Double = _ // tight distance
  var t3: Double = _ // loose distance for the merge step
  var t4: Double = _ // tight distance for the merge step
  var distanceMeasure: Symbol = _

  /**
   * Reads the hyperparameters listed on the class into the corresponding fields.
   *
   * @param hyperparameters hyperparameter name -> value; unknown keys are ignored
   * @throws IllegalArgumentException if a threshold is not numeric or the distance measure
   *                                  is not a `Symbol`
   */
  def setStandardHyperparameters(hyperparameters: Map[Symbol, Any] = Map('foo -> None)): Unit = {
    t1 = doubleParam(hyperparameters, 't1, 0.5)
    t2 = doubleParam(hyperparameters, 't2, 0.1)
    t3 = doubleParam(hyperparameters, 't3, t1)
    t4 = doubleParam(hyperparameters, 't4, t2)

    distanceMeasure = hyperparameters.get('distanceMeasure) match {
      case None => 'Cosine
      case Some(s: Symbol) => s
      case Some(other) =>
        throw new IllegalArgumentException(
          s"Hyperparameter 'distanceMeasure must be a Symbol, got: $other")
    }
  }

  /**
   * Looks up a numeric hyperparameter. Any `Number` (Int, Long, Float, Double, ...) is
   * accepted; an erased cast to `Map[Symbol, Double]` would fail at unboxing for
   * anything but a boxed `Double`.
   */
  private def doubleParam(hyperparameters: Map[Symbol, Any], name: Symbol, default: Double): Double =
    hyperparameters.get(name) match {
      case None => default
      case Some(n: java.lang.Number) => n.doubleValue()
      case Some(other) =>
        throw new IllegalArgumentException(s"Hyperparameter $name must be numeric, got: $other")
    }

  /**
   * Finds canopy centers for `input` and wraps them in a [[CanopyClusteringModel]].
   *
   * @param input           rows to cluster
   * @param hyperparameters see the class documentation
   * @return the fitted model
   */
  def fit[K](input: DrmLike[K],
             hyperparameters: (Symbol, Any)*): CanopyClusteringModel = {

    setStandardHyperparameters(hyperparameters.toMap)
    implicit val ctx = input.context
    implicit val ktag = input.keyClassTag

    val dmNumber = DistanceMetricSelector.namedMetricLookup(distanceMeasure)

    // Thresholds and the metric code are shipped to the workers as one broadcast vector:
    // (t1, t2, t3, t4, metric code).
    val distanceBC = drmBroadcast(dvec(t1, t2, t3, t4, dmNumber))
    val canopies = input.allreduceBlock(
      {
        // Find canopy centers within each block
        case (keys, block: Matrix) => {
          val t1_local = distanceBC.value.get(0)
          val t2_local = distanceBC.value.get(1)
          val dm = distanceBC.value.get(4)
          CanopyFn.findCenters(block, DistanceMetricSelector.select(dm), t1_local, t2_local)
        }
      }, {
        // Merge the centers of two partial results, collapsing centers that are close enough.
        // Both sides must take part: passing only `oldM` would silently drop every canopy
        // found in `newM`.
        case (oldM: Matrix, newM: Matrix) => {
          val t3_local = distanceBC.value.get(2)
          val t4_local = distanceBC.value.get(3)
          val dm = distanceBC.value.get(4)
          CanopyFn.findCenters(oldM rbind newM, DistanceMetricSelector.select(dm), t3_local, t4_local)
        }
      })

    val model = new CanopyClusteringModel(canopies, distanceMeasure)
    model.summary = s"""CanopyClusteringModel\n${canopies.nrow} Clusters\n${distanceMeasure} distance metric used for calculating distances\nCanopy centers stored in model.canopies where row n coresponds to canopy n"""
    model
  }

}
|
||
/** Canopy center selection on an in-core matrix. */
object CanopyFn extends Serializable {

  /**
   * Picks canopy centers from the rows of `block`.
   *
   * Repeatedly takes the first row not yet assigned to a canopy as a new center, then scans
   * the remaining rows: a row closer than `t2` to the center is assigned and removed from
   * further consideration; a row closer than `t1` is assigned (so it never becomes a center)
   * but is still compared against later centers.
   *
   * @param block           candidate rows, one point per row; not modified
   * @param distanceMeasure metric used to compare a row with a center
   * @param t1              loose distance threshold
   * @param t2              tight distance threshold
   * @return matrix whose rows are the selected canopy centers, in selection order
   */
  def findCenters(block: Matrix, distanceMeasure: DistanceMetric, t1: Double, t2: Double): Matrix = {
    // Work on a private copy: "removed" rows are marked by zeroing them out, which must not
    // leak into the caller's data.
    val work = block.cloned
    val rowAssignedToCanopy = Array.fill(work.nrow) { false }
    val clusterBuf = scala.collection.mutable.ListBuffer.empty[org.apache.mahout.math.Vector]

    while (rowAssignedToCanopy.contains(false)) {
      val rowIndexOfNextUncanopiedVector = rowAssignedToCanopy.indexOf(false)
      val center = work(rowIndexOfNextUncanopiedVector, ::).cloned
      clusterBuf += center
      work(rowIndexOfNextUncanopiedVector, ::) = svec(Nil, cardinality = work.ncol)
      rowAssignedToCanopy(rowIndexOfNextUncanopiedVector) = true

      for (i <- 0 until work.nrow) {
        // All-zero rows are treated as removed and skipped.
        if (work(i, ::).getNumNonZeroElements > 0) {
          val d = distanceMeasure.distance(work(i, ::), center)
          if (d < t2) {
            // Within the tight threshold: assign and remove from further consideration.
            rowAssignedToCanopy(i) = true
            work(i, ::) = svec(Nil, cardinality = work.ncol)
          } else if (d < t1) {
            // Within the loose threshold only: assign, but keep the row comparable.
            rowAssignedToCanopy(i) = true
          }
        }
      }
    }
    dense(clusterBuf)
  }
}
45 changes: 45 additions & 0 deletions
45
math-scala/src/main/scala/org/apache/mahout/math/algorithms/clustering/ClusteringModel.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/** | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.mahout.math.algorithms.clustering | ||
|
||
import org.apache.mahout.math.algorithms.{UnsupervisedFitter, UnsupervisedModel} | ||
import org.apache.mahout.math.drm.DrmLike | ||
|
||
/** An unsupervised model that can map the rows of a DRM to cluster assignments. */
trait ClusteringModel extends UnsupervisedModel {

  /**
   * Produces cluster assignments for `input`.
   *
   * @param input rows to cluster
   * @return a DRM with the same key type as `input`
   */
  def cluster[K](input: DrmLike[K]): DrmLike[K]
}
|
||
/** An unsupervised fitter that produces a [[ClusteringModel]]. */
trait ClusteringFitter extends UnsupervisedFitter {

  /** Holds the most recent model produced by a call to `fitCluster`. */
  var model: ClusteringModel = _

  /**
   * Fits a clustering model to `input`.
   *
   * @param input           rows to fit on
   * @param hyperparameters name -> value pairs interpreted by the concrete fitter
   */
  def fit[K](input: DrmLike[K],
             hyperparameters: (Symbol, Any)*): ClusteringModel

  /**
   * Fits a model on `input`, stores it in `model`, and immediately uses it to cluster
   * the same `input`.
   */
  def fitCluster[K](input: DrmLike[K],
                    hyperparameters: (Symbol, Any)*): DrmLike[K] = {
    val fitted = this.fit(input, hyperparameters: _*)
    model = fitted
    model.cluster(input)
  }
}
48 changes: 48 additions & 0 deletions
48
...la/src/main/scala/org/apache/mahout/math/algorithms/common/distance/DistanceMetrics.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/** | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.mahout.math.algorithms.common.distance | ||
|
||
import org.apache.mahout.math.function.Functions | ||
import org.apache.mahout.math.{CardinalityException, Vector} | ||
|
||
/** A serializable function measuring the distance between two vectors. */
trait DistanceMetric extends Serializable {

  /** Returns the distance between `v1` and `v2` under this metric. */
  def distance(v1: Vector, v2: Vector): Double
}
|
||
|
||
/**
 * Maps distance metric names to numeric codes and numeric codes back to metric
 * implementations.
 */
object DistanceMetricSelector extends Serializable {

  /** Metric name -> numeric code understood by [[select]]. */
  val namedMetricLookup = Map('Chebyshev -> 1.0, 'Cosine -> 2.0)

  /**
   * Resolves a numeric metric code to its implementation.
   *
   * @param dm one of the codes in `namedMetricLookup`
   * @throws MatchError if `dm` is not a known code
   */
  def select(dm: Double): DistanceMetric =
    if (dm == 1.0) Chebyshev
    else if (dm == 2.0) Cosine
    else throw new MatchError(dm)
}
|
||
/** Chebyshev distance: the largest absolute element-wise difference. */
object Chebyshev extends DistanceMetric {

  /**
   * @throws CardinalityException if the two vectors differ in size
   */
  def distance(v1: Vector, v2: Vector): Double = {
    val leftSize = v1.size
    val rightSize = v2.size
    if (leftSize == rightSize) {
      // Map each pair of elements with MINUS, fold the results with MAX_ABS.
      v1.aggregate(v2, Functions.MAX_ABS, Functions.MINUS)
    } else {
      throw new CardinalityException(leftSize, rightSize)
    }
  }
}
|
||
/** Cosine distance: one minus the cosine of the angle between the two vectors. */
object Cosine extends DistanceMetric {

  def distance(v1: Vector, v2: Vector): Double = {
    val dotProduct = v1.dot(v2)
    val normProduct = Math.sqrt(v1.getLengthSquared) * Math.sqrt(v2.getLengthSquared)
    1.0 - dotProduct / normProduct
  }
}
48 changes: 48 additions & 0 deletions
48
math-scala/src/test/scala/org/apache/mahout/math/algorithms/ClusteringSuiteBase.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/** | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.mahout.math.algorithms | ||
|
||
import org.apache.mahout.math.algorithms.preprocessing._ | ||
import org.apache.mahout.math.drm.drmParallelize | ||
import org.apache.mahout.math.scalabindings.{dense, sparse, svec} | ||
import org.apache.mahout.math.scalabindings.RLikeOps._ | ||
import org.apache.mahout.test.DistributedMahoutSuite | ||
import org.scalatest.{FunSuite, Matchers} | ||
|
||
import org.apache.mahout.test.DistributedMahoutSuite | ||
|
||
/**
 * Backend-independent clustering tests; mixed into a `FunSuite` by each engine's
 * concrete suite.
 */
trait ClusteringSuiteBase extends DistributedMahoutSuite with Matchers {

  this: FunSuite =>

  test("canopy test") {
    val drmA = drmParallelize(dense((1.0, 1.2, 1.3, 1.4), (1.1, 1.5, 2.5, 1.0), (6.0, 5.2, -5.2, 5.3), (7.0,6.0, 5.0, 5.0), (10.0, 1.0, 20.0, -10.0)))

    import org.apache.mahout.math.algorithms.clustering.CanopyClustering

    val model = new CanopyClustering().fit(drmA, 't1 -> 6.5, 't2 -> 5.5, 'distanceMeasure -> 'Chebyshev)
    val myAnswer = model.cluster(drmA).collect

    // Expected canopy index for each input row, in row order.
    val correctAnswer = Seq(0.0, 0.0, 1.0, 0.0, 2.0)

    val epsilon = 1E-6

    // Compare assignments row by row. Comparing only the norms of the two answers would
    // accept any permutation of the labels and any answer with a smaller norm.
    myAnswer.nrow should be (correctAnswer.length)
    myAnswer.ncol should be (1)
    for (i <- correctAnswer.indices) {
      myAnswer(i, 0) should be (correctAnswer(i) +- epsilon)
    }
  }
}
25 changes: 25 additions & 0 deletions
25
spark/src/test/scala/org/apache/mahout/math/algorithms/ClusteringSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mahout.math.algorithms | ||
|
||
import org.apache.mahout.sparkbindings.test.DistributedSparkSuite | ||
import org.scalatest.FunSuite | ||
|
||
/**
 * Spark-bindings clustering suite: mixes the test cases declared in
 * [[ClusteringSuiteBase]] into a `FunSuite` together with `DistributedSparkSuite`.
 */
class ClusteringSuite
  extends FunSuite
    with DistributedSparkSuite
    with ClusteringSuiteBase
|
Oops, something went wrong.