Skip to content

Commit

Permalink
MAHOUT-1913: Clean Up of VCL bindings closes apache#290
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmusselman committed Feb 27, 2017
1 parent 7883ebc commit 9ed9b4a
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.mahout.math.backend

import org.apache.mahout.logging._
import org.apache.mahout.math.backend.jvm.JvmBackend
import org.apache.mahout.math.scalabindings.{MMul, _}
import org.apache.mahout.math.scalabindings.{MMBinaryFunc, MMul, _}

import scala.collection._
import scala.reflect.{ClassTag, classTag}
Expand All @@ -28,7 +28,7 @@ final object RootSolverFactory extends SolverFactory {

import org.apache.mahout.math.backend.incore._

implicit val logger = getLog(RootSolverFactory.getClass)
private implicit val logger = getLog(RootSolverFactory.getClass)

private val solverTagsToScan =
classTag[MMulSolver] ::
Expand All @@ -43,42 +43,39 @@ final object RootSolverFactory extends SolverFactory {

}

////////////////////////////////////////////////////////////

// TODO: MAHOUT-1909: lazy initialze the map. Query backends. Build resolution rules.
// TODO: MAHOUT-1909: Cache Modular Backend solvers after probing
// That is, lazily initialize the map, query backends, and build resolution rules.
override protected[backend] val solverMap = new mutable.HashMap[ClassTag[_], Any]()
validateMap()

validateMap()

// default is JVM
// Default solver is JVM
var clazz: MMBinaryFunc = MMul

// eventually match on implicit Classtag . for now. just take as is.
// this is a bit hacky, Shoud not be doing onlytry/catch here..
// TODO: Match on implicit Classtag

def getOperator[C: ClassTag]: MMBinaryFunc = {

try {
// TODO: fix logging properties so that we're not mimicing as we are here.
println("[INFO] Creating org.apache.mahout.viennacl.opencl.GPUMMul solver")
logger.info("Creating org.apache.mahout.viennacl.opencl.GPUMMul solver")
clazz = Class.forName("org.apache.mahout.viennacl.opencl.GPUMMul$").getField("MODULE$").get(null).asInstanceOf[MMBinaryFunc]
println("[INFO] Successfully created org.apache.mahout.viennacl.opencl.GPUMMul solver")
logger.info("Successfully created org.apache.mahout.viennacl.opencl.GPUMMul solver")

} catch {
case x: Exception =>
println("[WARN] Unable to create class GPUMMul: attempting OpenMP version")
// println(x.getMessage)
logger.warn("Unable to create class GPUMMul: attempting OpenMP version")
try {
// attempt to instantiate the OpenMP version, assuming we’ve
// Attempt to instantiate the OpenMP version, assuming we’ve
// created a separate OpenMP-only module (none exist yet)
println("[INFO] Creating org.apache.mahout.viennacl.openmp.OMPMMul solver")
logger.info("Creating org.apache.mahout.viennacl.openmp.OMPMMul solver")
clazz = Class.forName("org.apache.mahout.viennacl.openmp.OMPMMul$").getField("MODULE$").get(null).asInstanceOf[MMBinaryFunc]
println("[INFO] Successfully created org.apache.mahout.viennacl.openmp.OMPMMul solver")
logger.info("Successfully created org.apache.mahout.viennacl.openmp.OMPMMul solver")

} catch {
case xx: Exception =>
println(xx.getMessage)
// fall back to JVM Dont need to Dynamicly assign MMul is in the same package.
println("[INFO] Unable to create class OMPMMul: falling back to java version")
logger.error(xx.getMessage)
// Fall back to JVM; don't need to dynamically assign since MMul is in the same package.
logger.info("Unable to create class OMPMMul: falling back to java version")
clazz = MMul
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class SimilarityAnalysisSuite extends FunSuite with MahoutSuite with Distributed
(0.0, 0.0, 0.6795961471815897, 0.0, 4.498681156950466))


test("Cross-occurrence [A'A], [B'A] boolbean data using LLR") {
test("Cross-occurrence [A'A], [B'A] boolean data using LLR") {
val a = dense(
(1, 1, 0, 0, 0),
(0, 0, 1, 1, 0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import scala.collection.JavaConversions._

object OMPMMul extends MMBinaryFunc {

private final implicit val log = getLog(OMPMMul.getClass)
private implicit val log = getLog(OMPMMul.getClass)

override def apply(a: Matrix, b: Matrix, r: Option[Matrix]): Matrix = {

Expand Down Expand Up @@ -209,7 +209,7 @@ object OMPMMul extends MMBinaryFunc {

@inline
private def jvmRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("jvmRWRW")
log.info("Using jvmRWRW method")
// A bit hackish: currently, this relies a bit on the fact that like produces RW(?)
val bclone = b.like(b.ncol, b.nrow).t
for (brow b) bclone(brow.index(), ::) := brow
Expand All @@ -221,12 +221,12 @@ object OMPMMul extends MMBinaryFunc {
}

private def jvmCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("jvmCWCW")
log.info("Using jvmCWCW method")
jvmRWRW(b.t, a.t, r.map(_.t)).t
}

private def jvmCWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("jvmCWRW")
log.info("Using jvmCWRW method")
// This is a primary contender with Outer Prod sum algo.
// Here, we force-reorient both matrices and run RWCW.
// A bit hackish: currently, this relies a bit on the fact that clone always produces RW(?)
Expand All @@ -240,22 +240,24 @@ object OMPMMul extends MMBinaryFunc {

// left is Sparse right is any
private def ompSparseRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("ompSparseRWRW")
log.info("Using ompSparseRWRW method")
val mxR = r.getOrElse(b.like(a.nrow, b.ncol))

// make sure that the matrix is not empty. VCL {{compressed_matrix}}s must
// hav nnz > 0
// this method is horribly inefficent. however there is a difference between
// getNumNonDefaultElements() and getNumNonZeroElements() which we do not always
// have access to created MAHOUT-1882 for this
/* Make sure that the matrix is not empty. VCL {{compressed_matrix}}s must
have nnz > 0
N.B. This method is horribly inefficent. However there is a difference between
getNumNonDefaultElements() and getNumNonZeroElements() which we do not always
have access to. We created MAHOUT-1882 for this.
*/

val hasElementsA = a.zSum() > 0.0
val hasElementsB = b.zSum() > 0.0

// A has a sparse matrix structure of unknown size. We do not want to
// simply convert it to a Dense Matrix which may result in an OOM error.
// If it is empty use JVM MMul, since we can not convert it to a VCL CSR Matrix.
if (!hasElementsA) {
println("Matrix a has zero elements can not convert to CSR")
log.warn("Matrix a has zero elements can not convert to CSR")
return MMul(a, b, r)
}

Expand All @@ -268,7 +270,7 @@ object OMPMMul extends MMBinaryFunc {
val oclC = new DenseRowMatrix(prod(oclA, oclB))
val mxC = fromVclDenseRM(oclC)
ms = System.currentTimeMillis() - ms
debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")
log.debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")

oclA.close()
oclB.close()
Expand All @@ -278,7 +280,7 @@ object OMPMMul extends MMBinaryFunc {
} else {
// Fall back to JVM based MMul if either matrix is sparse and empty
if (!hasElementsA || !hasElementsB) {
println("Matrix a or b has zero elements can not convert to CSR")
log.warn("Matrix a or b has zero elements can not convert to CSR")
return MMul(a, b, r)
}

Expand All @@ -289,7 +291,7 @@ object OMPMMul extends MMBinaryFunc {
val oclC = new CompressedMatrix(prod(oclA, oclB))
val mxC = fromVclCompressedMatrix(oclC)
ms = System.currentTimeMillis() - ms
debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")
log.debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")

oclA.close()
oclB.close()
Expand All @@ -302,15 +304,15 @@ object OMPMMul extends MMBinaryFunc {

//sparse %*% dense
private def ompSparseRowRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("ompSparseRowRWRW")
log.info("Using ompSparseRowRWRW method")
val hasElementsA = a.zSum() > 0

// A has a sparse matrix structure of unknown size. We do not want to
// simply convert it to a Dense Matrix which may result in an OOM error.
// If it is empty fall back to JVM MMul, since we can not convert it
// to a VCL CSR Matrix.
if (!hasElementsA) {
println("Matrix a has zero elements can not convert to CSR")
log.warn("Matrix a has zero elements can not convert to CSR")
return MMul(a, b, r)
}

Expand All @@ -321,7 +323,7 @@ object OMPMMul extends MMBinaryFunc {
val oclC = new DenseRowMatrix(prod(oclA, oclB))
val mxC = fromVclDenseRM(oclC)
ms = System.currentTimeMillis() - ms
debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")
log.debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")

oclA.close()
oclB.close()
Expand All @@ -339,7 +341,6 @@ object OMPMMul extends MMBinaryFunc {
private def jvmSparseRowRWCW(a: Matrix, b: Matrix, r: Option[Matrix]) =
ompSparseRowRWRW(a, b cloned, r)


private def jvmSparseRowCWRW(a: Matrix, b: Matrix, r: Option[Matrix]) =
ompSparseRowRWRW(a cloned, b, r)

Expand All @@ -356,7 +357,7 @@ object OMPMMul extends MMBinaryFunc {
ompSparseRWRW(a cloned, b cloned, r)

private def jvmDiagRW(diagm:Matrix, b:Matrix, r:Option[Matrix] = None):Matrix = {
println("jvmDiagRW")
log.info("Using jvmDiagRW method")
val mxR = r.getOrElse(b.like(diagm.nrow, b.ncol))

for (del diagm.diagv.nonZeroes())
Expand All @@ -366,7 +367,7 @@ object OMPMMul extends MMBinaryFunc {
}

private def jvmDiagCW(diagm: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("jvmDiagCW")
log.info("Using jvmDiagCW method")
val mxR = r.getOrElse(b.like(diagm.nrow, b.ncol))
for (bcol b.t) mxR(::, bcol.index()) := bcol * diagm.diagv
mxR
Expand All @@ -378,20 +379,19 @@ object OMPMMul extends MMBinaryFunc {
private def jvmRWDiag(a: Matrix, diagm: Matrix, r: Option[Matrix] = None) =
jvmDiagCW(diagm, a.t, r.map {_.t}).t


/** Dense column-wise AA' */
private def jvmDCWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = {
// a.t must be equiv. to b. Cloning must rewrite to row-wise.
ompDRWAAt(a.cloned,null,r)
}

/** Dense Row-wise AA' */
// we probably will not want to use this for the actual release unless A is cached already
// We probably will not want to use this for the actual release unless A is cached already
// but adding for testing purposes.
private def ompDRWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = {
// a.t must be equiv to b.
println("executing on OMP")
debug("AAt computation detected; passing off to OMP")
log.info("Executing on OMP")
log.debug("AAt computation detected; passing off to OMP")

// Check dimensions if result is supplied.
require(r.forall(mxR mxR.nrow == a.nrow && mxR.ncol == a.nrow))
Expand All @@ -406,7 +406,7 @@ object OMPMMul extends MMBinaryFunc {

val mxC = fromVclDenseRM(oclC)
ms = System.currentTimeMillis() - ms
debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")
log.debug(s"ViennaCL/OpenMP multiplication time: $ms ms.")

oclA.close()
//oclApr.close()
Expand All @@ -418,9 +418,9 @@ object OMPMMul extends MMBinaryFunc {
}

private def jvmOuterProdSum(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
println("jvmOuterProdSum")
// This may be already laid out for outer product computation, which may be faster than reorienting
// both matrices? need to check.
log.info("Using jvmOuterProdSum method")
// Need to check whether this is already laid out for outer product computation, which may be faster than
// reorienting both matrices.
val (m, n) = (a.nrow, b.ncol)

// Prefer col-wise result iff a is dense and b is sparse. In all other cases default to row-wise.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ViennaCLSuiteOMP extends FunSuite with Matchers {
info(s"ViennaCL/cpu/OpenMP Sparse multiplication time: $ms ms.")

val ompMxC = fromVclCompressedMatrix(ompC)
(mxC - ompMxC).norm / mxC.nrow / mxC.ncol should be < 1e-16
(mxC - ompMxC).norm / mxC.nrow / mxC.ncol should be < 1e-10

ompA.close()
ompB.close()
Expand Down Expand Up @@ -192,7 +192,7 @@ class ViennaCLSuiteOMP extends FunSuite with Matchers {

ms = System.currentTimeMillis() - ms
info(s"ViennaCL/cpu/OpenMP dense matrix %*% dense vector multiplication time: $ms ms.")
(ompDvecC.toColMatrix - mDvecC.toColMatrix).norm / s should be < 1e-16
(ompDvecC.toColMatrix - mDvecC.toColMatrix).norm / s should be < 1e-10

ompMxA.close()
ompVecB.close()
Expand Down
Loading

0 comments on commit 9ed9b4a

Please sign in to comment.