Skip to content

Commit

Permalink
Add validation SQL to Spark Timeline pipeline
Browse files Browse the repository at this point in the history
Summary:
1. Add an option to validate Spark data pipeline via custom SQL.
2. Remove links to FB internal documentation from public scaladoc.

Reviewed By: czxttkl

Differential Revision: D19873339

fbshipit-source-id: e0ce5058bb0f6bb597ea8c1873b58ae2468e94d8
  • Loading branch information
Tian Tian Metzgar authored and facebook-github-bot committed Mar 19, 2020
1 parent efe1f83 commit 1a09f10
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 12 deletions.
31 changes: 31 additions & 0 deletions preprocessing/src/main/scala/com/facebook/spark/rl/Helper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ package com.facebook.spark.rl
import org.apache.spark.sql._
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

/** Root of the failure hierarchy for Timeline pipeline execution steps. */
sealed abstract class TimelineExecutionFailure

/** A failure detected while running the user-supplied validation SQL. */
sealed abstract class ValidationFailure extends TimelineExecutionFailure

/** The validation query ran, but its output did not meet expectations.
  *
  * @param message human-readable description of which expectation failed
  * @param result the DataFrame produced by the validation query, kept for debugging
  */
final case class UnexpectedOutput(message: String, result: DataFrame) extends ValidationFailure

/** The validation query itself threw; the original exception is preserved.
  *
  * @param e the exception raised while executing or checking the query
  */
final case class ExceptionDuringValidation(e: Throwable) extends ValidationFailure

object Helper {

private val log = LoggerFactory.getLogger(this.getClass.getName)
Expand Down Expand Up @@ -74,4 +79,30 @@ object Helper {
case e: Throwable => log.error(e.toString())
}

/** Executes a user-supplied validation query and checks its result.
  *
  * The query is expected to return exactly one row whose every column is the
  * boolean TRUE. Any other shape of result is reported as an
  * [[UnexpectedOutput]]; an exception raised while running the query is
  * reported as an [[ExceptionDuringValidation]].
  *
  * @param sqlContext context used to execute the query
  * @param validationSql the validation query text
  * @return None on success, or Some(failure) describing what went wrong
  */
def validateTimeline(sqlContext: SQLContext, validationSql: String): Option[ValidationFailure] =
  try {
    log.info(s"Executing validation query: ${validationSql}")
    val validationDf = sqlContext.sql(validationSql)
    // count() triggers a full Spark job; evaluate it once instead of twice.
    val rowCount = validationDf.count()

    val maybeError =
      if (rowCount == 0) {
        Some("query did not return any results.")
      } else if (rowCount > 1) {
        Some("query returned more than one row.")
      } else if (validationDf.first().toSeq.forall(_ == true)) {
        // Exactly one row and every column is TRUE: validation passed.
        None
      } else {
        Some("query returned one or more non-TRUE results.")
      }

    maybeError.map(msg => UnexpectedOutput(msg, validationDf))
  } catch {
    // NonFatal: wrap recoverable errors, but let truly fatal ones
    // (OutOfMemoryError, InterruptedException, ...) propagate.
    case NonFatal(e) => Some(ExceptionDuringValidation(e))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ case class MultiStepTimelineConfiguration(
* mdp_id, state_features, action, reward, next_state_features, next_action,
* sequence_number, sequence_number_ordinal, time_diff, possible_next_actions.
* Shuffles the results.
* Reference:
* https://our.intern.facebook.com/intern/wiki/Reinforcement-learning/
*
* Args:
* input_table: string, input table name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ object Preprocessor {
inputDf.createOrReplaceTempView(timelineConfig.inputTableName)

Timeline.run(sparkSession.sqlContext, timelineConfig)

val query = if (timelineConfig.actionDiscrete) {
Query.getDiscreteQuery(queryConfig)
} else {
Expand Down
33 changes: 24 additions & 9 deletions preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ case class TimelineConfiguration(
percentileFunction: String = "percentile_approx",
rewardColumns: List[String] = Constants.DEFAULT_REWARD_COLUMNS,
extraFeatureColumns: List[String] = Constants.DEFAULT_EXTRA_FEATURE_COLUMNS,
timeWindowLimit: Option[Long] = None
timeWindowLimit: Option[Long] = None,
validationSql: Option[String] = None
)

/**
Expand All @@ -34,8 +35,6 @@ case class TimelineConfiguration(
* mdp_id, state_features, action, reward, next_state_features, next_action,
* sequence_number, sequence_number_ordinal, time_diff, possible_next_actions.
* Shuffles the results.
* Reference:
* https://our.intern.facebook.com/intern/wiki/Reinforcement-learning/
*
* Args:
* input_table: string, input table name
Expand Down Expand Up @@ -118,11 +117,18 @@ case class TimelineConfiguration(
* possible_next_actions ( ARRAY<STRING> OR ARRAY<MAP<BIGINT,DOUBLE>> )
* A list of actions that were possible at the next step.
*
* config.validationSql (Option[String], default None).
 * A SQL query used to validate the Timeline pipeline output table. The query's
 * result must contain exactly one row, and every column of that row must be the boolean TRUE.
* Ex: select if((select count(*) from {config.outputTableName} where mdp_id<0) == 0, TRUE, FALSE)
*/
object Timeline {

private val log = LoggerFactory.getLogger(this.getClass.getName)
def run(sqlContext: SQLContext, config: TimelineConfiguration): Unit = {
def run(
sqlContext: SQLContext,
config: TimelineConfiguration
): Unit = {
var filterTerminal = "HAVING next_state_features IS NOT NULL";
if (config.addTerminalStateRow) {
filterTerminal = "";
Expand Down Expand Up @@ -324,14 +330,23 @@ object Timeline {
.withColumn(next_col_name, coalesce(df(next_col_name), empty_placeholder))
}

val finalTableName = "finalTable"
df.createOrReplaceTempView(finalTableName)
val stagingTable = "stagingTable_" + config.outputTableName
if (sqlContext.tableNames.contains(stagingTable)) {
log.warn("RL ValidationSql staging table name collision occurred, name: " + stagingTable)
}
df.createOrReplaceTempView(stagingTable)

val maybeError = config.validationSql.flatMap { query =>
Helper.validateTimeline(sqlContext, query.replace("{config.outputTableName}", stagingTable))
}

assert(maybeError.isEmpty, "validationSql validation failure: " + maybeError)

val insertCommand = s"""
val insertCommandOutput = s"""
INSERT OVERWRITE TABLE ${config.outputTableName} PARTITION(ds='${config.endDs}')
SELECT * FROM ${finalTableName}
SELECT * FROM ${stagingTable}
""".stripMargin
sqlContext.sql(insertCommand)
sqlContext.sql(insertCommandOutput)
}

def mdpLengthThreshold(sqlContext: SQLContext, config: TimelineConfiguration): Option[Double] =
Expand Down
146 changes: 145 additions & 1 deletion preprocessing/src/test/scala/com/facebook/spark/rl/TimelineTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,151 @@ class TimelineTest extends PipelineTester {
assert(firstRow.getAs[Map[String, Double]](14) == Map("Widgets" -> 10.0))
}

test("two-state-continuous-mdp-validation-success") {
  val sqlCtx = sqlContext
  import sqlCtx.implicits._

  // Configuration: the validation query checks that the output table holds
  // exactly one row, which the single two-step MDP below should satisfy.
  val config = TimelineConfiguration(
    startDs = "2018-01-01",
    endDs = "2018-01-01",
    addTerminalStateRow = false,
    actionDiscrete = false,
    inputTableName = "some_rl_input_4",
    outputTableName = "some_rl_timeline_4",
    evalTableName = null,
    numOutputShards = 1,
    validationSql =
      Some("select if((select count(*) from {config.outputTableName}) == 1, TRUE, FALSE)")
  )

  // Drop any output table left over from a previous run.
  Helper.destroyTrainingTable(sqlContext, config.outputTableName)

  // Two consecutive states of one MDP. Tuple fields map to the columns listed
  // in columnNames below.
  val fakeRows = List(
    (
      "2018-01-01",
      "mdp1",
      1,
      1.0,
      Map(1001L -> 0.3, 1002L -> 0.5),
      0.8,
      Map(1L -> 1.0),
      List(Map(1001L -> 0.3, 1002L -> 0.5), Map(1001L -> 0.6, 1002L -> 0.2)),
      Map("Widgets" -> 10.0)
    ), // First state
    (
      "2018-01-01",
      "mdp1",
      11,
      0.2,
      Map.empty[Long, Double],
      0.7,
      Map(2L -> 1.0),
      List(),
      Map("Widgets" -> 20.0)
    ) // Second state
  )
  val columnNames = Seq(
    "ds",
    "mdp_id",
    "sequence_number",
    "reward",
    "action",
    "action_probability",
    "state_features",
    "possible_actions",
    "metrics"
  )
  sqlCtx.sparkContext
    .parallelize(fakeRows)
    .toDF(columnNames: _*)
    .createOrReplaceTempView(config.inputTableName)

  // Run the pipeline; the validation assert inside Timeline.run must not fire.
  Timeline.run(sqlContext, config)
}

test("two-state-continuous-mdp-validation-failure") {
  val action_discrete: Boolean = false
  val sqlCtx = sqlContext
  import sqlCtx.implicits._
  val sparkContext = sqlCtx.sparkContext

  val outputTableName = "some_rl_timeline_4"
  // Setup configuration. The validation query asserts the output table is
  // empty (count == 0), which cannot hold because the input below produces
  // output rows, so Timeline.run is expected to fail its validation assert.
  // Note: plain string literal here — the previous s-interpolator prefix was
  // a no-op (no interpolation) and inconsistent with the success test.
  val config = TimelineConfiguration(
    startDs = "2018-01-01",
    endDs = "2018-01-01",
    addTerminalStateRow = false,
    actionDiscrete = action_discrete,
    inputTableName = "some_rl_input_4",
    outputTableName = outputTableName,
    evalTableName = null,
    numOutputShards = 1,
    validationSql =
      Some("select if((select count(*) from {config.outputTableName}) == 0, TRUE, FALSE)")
  )

  // Destroy the previous schema so the pipeline starts from a clean table.
  Helper.destroyTrainingTable(
    sqlContext,
    s"${config.outputTableName}"
  )

  // Create fake input data: two consecutive states of a single MDP.
  val rl_input = sparkContext
    .parallelize(
      List(
        (
          "2018-01-01",
          "mdp1",
          1,
          1.0,
          Map(1001L -> 0.3, 1002L -> 0.5),
          0.8,
          Map(1L -> 1.0),
          List(Map(1001L -> 0.3, 1002L -> 0.5), Map(1001L -> 0.6, 1002L -> 0.2)),
          Map("Widgets" -> 10.0)
        ), // First state
        (
          "2018-01-01",
          "mdp1",
          11,
          0.2,
          Map.empty[Long, Double],
          0.7,
          Map(2L -> 1.0),
          List(),
          Map("Widgets" -> 20.0)
        ) // Second state
      )
    )
    .toDF(
      "ds",
      "mdp_id",
      "sequence_number",
      "reward",
      "action",
      "action_probability",
      "state_features",
      "possible_actions",
      "metrics"
    )
  rl_input.createOrReplaceTempView(config.inputTableName)

  // The pipeline must abort: validateTimeline reports a failure and the
  // assert inside Timeline.run raises AssertionError.
  intercept[AssertionError] {
    Timeline.run(sqlContext, config)
  }
}

test("two-state-continuous-mdp-sparse-action") {
val action_discrete: Boolean = false
val extraFeatureColumns: List[String] =
Expand Down Expand Up @@ -1357,5 +1502,4 @@ class TimelineTest extends PipelineTester {
assert(thirdRow.getAs[Map[Long, Double]](10) == Map())
assert(thirdRow.getAs[Map[Long, List[Map[Long, Double]]]](11) == Map())
}

}

0 comments on commit 1a09f10

Please sign in to comment.