Skip to content

Commit

Permalink
* Use Columns instead of Row for grouping (named) columns
Browse files Browse the repository at this point in the history
* Merge `.select` and `.selectRow` back into `.select` supporting both functionalities
  • Loading branch information
prolativ committed May 7, 2023
1 parent 8b34b56 commit 6a3bad8
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 94 deletions.
38 changes: 33 additions & 5 deletions src/main/Column.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
package org.virtuslab.iskra

import org.apache.spark.sql.{ Column => UntypedColumn}
import scala.quoted.*

import org.apache.spark.sql.{Column => UntypedColumn}
import types.DataType

/** A group of one or more columns whose labels and data types are tracked at the type level.
  *
  * @tparam Schema
  *   type-level description of the grouped columns; in this file a single labeled column uses
  *   `(L := T) *: EmptyTuple`, so this is presumably always a tuple of `label := type` entries
  * @param underlyingColumns
  *   the untyped Spark columns backing this group
  */
sealed trait NamedColumns[Schema](val underlyingColumns: Seq[UntypedColumn])

/** Factory for combining several [[NamedColumns]] values into a single one.
  *
  * The schema of the result is computed at compile time by the macro in `applyImpl`, which is
  * why `apply` is `transparent inline`: callers see the precise `NamedColumns[...]` type rather
  * than `NamedColumns[?]`.
  */
object Columns:
  /** Merges the given column groups into one group.
    *
    * @param columns
    *   the column groups to merge; they must be listed explicitly at the call site
    *   (a sequence spliced with `*` cannot be inspected by the macro and is rejected
    *   with a compilation error)
    * @return
    *   a `NamedColumns` whose schema type is built from the schemas of all arguments and whose
    *   underlying columns are the flattened underlying columns of all arguments, in order
    */
  transparent inline def apply(inline columns: NamedColumns[?]*): NamedColumns[?] = ${ applyImpl('columns) }

  /** Macro implementation of [[apply]]. */
  private def applyImpl(columns: Expr[Seq[NamedColumns[?]]])(using Quotes): Expr[NamedColumns[?]] =
    import quotes.reflect.*

    // For every argument keep both the expression yielding its untyped columns
    // and the schema type it was declared with.
    val columnValuesWithTypes = columns match
      case Varargs(colExprs) =>
        colExprs.map { arg =>
          arg match
            case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
        }
      case _ =>
        // Without this case a non-literal vararg would crash macro expansion with a MatchError;
        // report a proper, positioned compilation error instead.
        report.errorAndAbort(
          "Columns(...) requires its arguments to be listed explicitly; passing a sequence with `*` is not supported",
          columns
        )

    val columnsValues = columnValuesWithTypes.map(_._1)
    val columnsTypes = columnValuesWithTypes.map(_._2)

    val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)

    schemaTpe match
      case '[s] =>
        '{
          // Each argument contributes a Seq of columns, hence the flatten.
          val cols = ${ Expr.ofSeq(columnsValues) }.flatten
          new NamedColumns[s](cols) {}
        }

class Column[+T <: DataType](val untyped: UntypedColumn):

inline def name(using v: ValueOf[Name]): Name = v.value
Expand All @@ -29,11 +58,10 @@ object Column:
inline def &&[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.And[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Or[T1, T2]): Column[op.Out] = op(col1, col2)

trait NamedColumn[T <: DataType]:
self: Column[T] =>

@annotation.showAsInfix
class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn) extends Column[T](untyped) with NamedColumn[T]
class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn)
extends Column[T](untyped)
with NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))

@annotation.showAsInfix
trait /[+Prefix <: Name, +Suffix <: Name]
Expand Down
1 change: 0 additions & 1 deletion src/main/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ trait DataFrame:
object DataFrame:
export Aliasing.dataFrameAliasingOps
export Select.dataFrameSelectOps
export SelectRow.dataFrameSelectRowOps
export Join.dataFrameJoinOps
export GroupBy.dataFrameGroupByOps
export Where.dataFrameWhereOps
Expand Down
8 changes: 4 additions & 4 deletions src/main/FrameSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ object FrameSchema:
case '[label := column] => true
case _ => false

def schemaTypeFromColumnTypes(colTypes: Seq[Type[?]])(using Quotes): Type[? <: Tuple] =
def schemaTypeFromColumnsTypes(colTypes: Seq[Type[?]])(using Quotes): Type[? <: Tuple] =
colTypes match
case Nil => Type.of[EmptyTuple]
case '[headTpe] :: tail =>
schemaTypeFromColumnTypes(tail) match
case '[TupleSubtype[tailTpe]] => Type.of[headTpe *: tailTpe]
case '[TupleSubtype[headTpes]] :: tail =>
schemaTypeFromColumnsTypes(tail) match
case '[TupleSubtype[tailTpes]] => Type.of[Tuple.Concat[headTpes, tailTpes]]
25 changes: 0 additions & 25 deletions src/main/Row.scala

This file was deleted.

4 changes: 2 additions & 2 deletions src/main/SchemaView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ object StructSchemaView:
def frameAliasViewsByName(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] =
import quotes.reflect.*
allPrefixedColumns(schemaType).groupBy(_._1).map { (frameName, values) =>
val columnTypes = values.map(_._2)
frameName -> refineType(TypeRepr.of[AliasedSchemaView], columnTypes)
val columnsTypes = values.map(_._2)
frameName -> refineType(TypeRepr.of[AliasedSchemaView], columnsTypes)
}.toList

def unambiguousColumns(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] =
Expand Down
18 changes: 9 additions & 9 deletions src/main/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,27 @@ object Select:

given selectOps: {} with
extension [View <: SchemaView](select: Select[View])
transparent inline def apply(inline columns: View ?=> NamedColumn[?]*): StructDataFrame[?] =
${ selectApplyImpl[View]('select, 'columns) }
transparent inline def apply(inline columns: View ?=> NamedColumns[?]*): StructDataFrame[?] =
${ applyImpl[View]('select, 'columns) }

private def selectApplyImpl[View <: SchemaView : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[Seq[View ?=> NamedColumn[?]]]) =
private def applyImpl[View <: SchemaView : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[Seq[View ?=> NamedColumns[?]]]) =
import quotes.reflect.*

val cols = columns match
val columnValuesWithTypes = columns match
case Varargs(colExprs) =>
colExprs.map { arg =>
val reduced = Term.betaReduce('{$arg(using ${ select }.view)}.asTerm).get
reduced.asExpr match
case '{ $value: v } => ('{ ${ value }.asInstanceOf[Column[?]].untyped }, Type.of[v])
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnValues = cols.map(_._1)
val columnTypes = cols.map(_._2)
val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val schemaTpe = FrameSchema.schemaTypeFromColumnTypes(columnTypes)
val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
schemaTpe match
case '[s] =>
'{
val cols = ${ Expr.ofSeq(columnValues) }
val cols = ${ Expr.ofSeq(columnsValues) }.flatten
StructDataFrame[s](${ select }.underlying.select(cols*))
}
38 changes: 0 additions & 38 deletions src/main/SelectRow.scala

This file was deleted.

4 changes: 2 additions & 2 deletions src/main/StructDataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ object StructDataFrame:
case '[EmptyTuple] => Seq.empty
case '[head *: tail] => allColumns(Type.of[head]) ++ allColumns(Type.of[tail])

private def showColumns(columnTypes: Seq[Type[?]])(using Quotes): String =
val columns = columnTypes.map {
private def showColumns(columnsTypes: Seq[Type[?]])(using Quotes): String =
val columns = columnsTypes.map {
case '[label := dataType] =>
val shortDataType = Type.show[dataType].split("\\.").last
s"${Type.show[label]} := ${shortDataType}"
Expand Down
10 changes: 5 additions & 5 deletions src/main/WithColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ object WithColumns:

given withColumnsApply: {} with
extension [Schema, View <: SchemaView](withColumns: WithColumns[Schema, View])
transparent inline def apply[Columns](columns: View ?=> Columns): StructDataFrame[?] =
${ applyImpl[Schema, View, Columns]('withColumns, 'columns) }
transparent inline def apply[Cols](columns: View ?=> Cols): StructDataFrame[?] =
${ applyImpl[Schema, View, Cols]('withColumns, 'columns) }

def applyImpl[Schema : Type, View <: SchemaView : Type, Columns : Type](
def applyImpl[Schema : Type, View <: SchemaView : Type, Cols : Type](
withColumns: Expr[WithColumns[Schema, View]],
columns: Expr[View ?=> Columns]
columns: Expr[View ?=> Cols]
)(using Quotes): Expr[StructDataFrame[?]] =
import quotes.reflect.*
Type.of[Columns] match
Type.of[Cols] match
case '[name := colType] =>
val label = Expr(Type.valueOfConstant[name].get.toString)
'{
Expand Down
2 changes: 1 addition & 1 deletion src/main/api/api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export types.{
}
export UntypedOps.typed
export org.virtuslab.iskra.$
export org.virtuslab.iskra.{Column, DataFrame, ClassDataFrame, StructDataFrame, Row, UntypedColumn, UntypedDataFrame, :=, /}
export org.virtuslab.iskra.{Column, Columns, DataFrame, ClassDataFrame, NamedColumns, StructDataFrame, UntypedColumn, UntypedDataFrame, :=, /}

object functions:
export org.virtuslab.iskra.functions.{lit, when}
Expand Down
4 changes: 2 additions & 2 deletions src/test/example/Workers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ import functions.lit
workers.as("subordinates")
.leftJoin(supervisions).on($.subordinates.id === $.subordinateId)
.leftJoin(workers.as("supervisors")).on($.supervisorId === $.supervisors.id)
.selectRow {
.select {
val salary = (lit(4732) + $.subordinates.yearsInCompany * lit(214)).as("salary")
val supervisor = ($.supervisors.firstName ++ lit(" ") ++ $.supervisors.lastName).as("supervisor")
Row($.subordinates.firstName, $.subordinates.lastName, supervisor, salary)
Columns($.subordinates.firstName, $.subordinates.lastName, supervisor, salary)
}
.where($.salary > lit(5000))
.show()
Expand Down

0 comments on commit 6a3bad8

Please sign in to comment.