Skip to content

Commit

Permalink
Basic support for nested types
Browse files Browse the repository at this point in the history
  • Loading branch information
prolativ committed May 27, 2022
1 parent 984f3c3 commit 2517a1d
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 104 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ThisBuild / scalaVersion := "3.1.2"
// ThisBuild / scalaVersion := "3.2.0-RC0-bin-SNAPSHOT" // experimental code completions
// ThisBuild / scalaVersion := "3.1.2"
ThisBuild / scalaVersion := "3.2.0-RC0-bin-SNAPSHOT" // experimental code completions

val sparkVersion = "3.2.0"

Expand Down
30 changes: 18 additions & 12 deletions src/main/scala/org/virtuslab/example/HelloSpark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package org.virtuslab.example
import scala3encoders.given

import org.apache.spark.sql.SparkSession
import org.virtuslab.typedframes.types.StructType.SCons
import org.virtuslab.typedframes.types.StringType

// case class JustInt(int: Int) // int - forbidden field name in a case class? - check at compiletime
case class JustInt(i: Int)
Expand Down Expand Up @@ -41,17 +43,18 @@ object HellSpark {

val untypedInts = Seq(1, 2, 3, 4).toDF("int")
untypedInts.show()
val typedInts = untypedInts.typed[JustInt]

import org.virtuslab.typedframes.given
val typedInts = untypedInts.typed[JustInt]

val ints = Seq(1, 2, 3, 4).toTypedDF.withColumn["i"]
val ints = Seq(1, 2, 3, 4).toTypedDF("i")
ints.show()

val strings = Seq("abc", "def").toTypedDF.withColumn("abc")
val strings = Seq("abc", "def").toTypedDF("ab")
strings.show()

strings.select($.abc).show()
strings.select($.ab, $.ab.named("abcde")).select($.abcde).show()

import types.{DataType, StructType}

val foos = Seq(
Foo("aaaa", 1, 10),
Expand All @@ -65,22 +68,25 @@ object HellSpark {
foos.select(($.b + $.b)).show()
foos.select($.b, $.b).show()

val afterSelect = foos.select($.a, ($.b + $.b).named["bb"])
val afterSelect = foos.select(($.b + $.b).named("i"), $.a.named("str"))

afterSelect.show()


afterSelect.select($.bb.named["bbb"]).show()
println(afterSelect.collect[Baz1]())

// // afterSelect.select($.bc.named["bbb"]).show() // <- This won't compile


val persons = Seq(
Person(1, Name("William", "Shakespeare"))
).toTypedDF

persons.select($.name).show()

// TODOs:

// val persons = Seq(
// Person(1, Name("William", "Shakespeare"))
// ).toTypedDF
// persons.select($.name.first).show()

//persons.select($.name.first).show()
spark.stop()
}
}
9 changes: 5 additions & 4 deletions src/main/scala/org/virtuslab/typedframes/Column.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package org.virtuslab.typedframes

import org.apache.spark.sql.{ Column => UntypedColumn}
import types.DataType
import Internals.Name

object TypedColumnOpaqueScope:
opaque type TypedColumn[N <: Name, T] = UntypedColumn // Should be covariant for T?
opaque type TypedColumn[N <: Name, T <: DataType] = UntypedColumn // Should be covariant for T?

def NamedTypedColumn[N <: Name, T](underlying: UntypedColumn): TypedColumn[N, T] = underlying
def UnnamedTypedColumn[T](underlying: UntypedColumn): TypedColumn[Nothing, T] = underlying
extension [N <: Name, T](inline tc: TypedColumn[N, T])
def NamedTypedColumn[N <: Name, T <: DataType](underlying: UntypedColumn): TypedColumn[N, T] = underlying
def UnnamedTypedColumn[T <: DataType](underlying: UntypedColumn): TypedColumn[Nothing, T] = underlying
extension [N <: Name, T <: DataType](inline tc: TypedColumn[N, T])
inline def underlying: UntypedColumn = tc
inline def named[N1 <: Name](using v: ValueOf[N1]): TypedColumn[N1, T] =
NamedTypedColumn[N1, T](underlying.as(v.value))
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/org/virtuslab/typedframes/ColumnOps.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package org.virtuslab.typedframes

import types.*
import Internals.Name

extension [N1 <: Name](col1: TypedColumn[N1, Int])
inline def +[N2 <: Name](col2: TypedColumn[N2, Int]) = UnnamedTypedColumn[Int](col1.underlying + col2.underlying)
extension [N1 <: Name](col1: TypedColumn[N1, IntegerType])
inline def +[N2 <: Name](col2: TypedColumn[N2, IntegerType]) = UnnamedTypedColumn[IntegerType](col1.underlying + col2.underlying)

// More operations can be added easily
11 changes: 6 additions & 5 deletions src/main/scala/org/virtuslab/typedframes/DataFrame.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package org.virtuslab.typedframes

import org.apache.spark.sql.{ DataFrame => UntypedDataFrame, Encoder, SparkSession /* cdscsdc */ }
import org.apache.spark.sql.{ DataFrame => UntypedDataFrame, Encoder, SparkSession }
import types.{DataType, StructType}

object TypedDataFrameOpaqueScope:
opaque type TypedDataFrame[+S <: FrameSchema] = UntypedDataFrame
opaque type TypedDataFrame[+S <: StructType] = UntypedDataFrame
extension (inline df: UntypedDataFrame)
inline def typed[A](using schema: FrameSchema.Provider[A]): TypedDataFrame[schema.Schema] = df // TODO: Check schema at runtime? Check if names of columns match?
inline def withSchema[S <: FrameSchema]: TypedDataFrame[S] = df // TODO? make it private[typedframes]
inline def typed[A](using encoder: DataType.StructEncoder[A]): TypedDataFrame[encoder.Encoded] = df // TODO: Check schema at runtime? Check if names of columns match?
inline def withSchema[S <: StructType]: TypedDataFrame[S] = df // TODO? make it private[typedframes]

extension [S <: FrameSchema](tdf: TypedDataFrame[S])
extension [S <: StructType](tdf: TypedDataFrame[S])
inline def untyped: UntypedDataFrame = tdf

export TypedDataFrameOpaqueScope.*
40 changes: 12 additions & 28 deletions src/main/scala/org/virtuslab/typedframes/DataFrameBuilders.scala
Original file line number Diff line number Diff line change
@@ -1,45 +1,29 @@
package org.virtuslab.typedframes

import scala.quoted._
import org.apache.spark.sql.{ DataFrame => UntypedDataFrame, Encoder, SparkSession /* cdscsdc */ }
import org.apache.spark.sql
import org.apache.spark.sql.{ DataFrame => UntypedDataFrame, SparkSession }
import types.{ DataType, StructType }
import Internals.Name
import TypedDataFrameOpaqueScope.*

object TypedDataFrameBuilders:
//TODO: More inlining?
class UnnamedColumnBuilder[A <: Int | String](spark: SparkSession, seq: Seq[A]):
transparent inline def withColumn[N <: Name]: TypedDataFrame[FrameSchema] = ${toTypedDFWithNameImpl[N, A]('seq, 'spark)}
transparent inline def withColumn[N <: Name](columnName: N): TypedDataFrame[FrameSchema] = withColumn[N]
given primitiveTypeBuilderOps: {} with
extension [A <: Int | String](inline seq: Seq[A])(using typeEncoder: DataType.Encoder[A], spark: SparkSession) // TODO: Add more primitive types
transparent inline def toTypedDF[N <: Name](name: N): TypedDataFrame[StructType] = ${toTypedDFWithNameImpl[N, A, typeEncoder.Encoded]('seq, 'spark)}

given foo1: {} with
extension [A <: Int | String](inline seq: Seq[A])(using spark: SparkSession) // TODO: Add more primitive types
// TODO decide on/unify naming
// transparent inline def toTypedDF[N <: Name]: TypedDataFrame[FrameSchema] = ${toTypedDFWithNameImpl[N, A]('seq, 'spark)}
inline def toTypedDF: UnnamedColumnBuilder[A] = new UnnamedColumnBuilder[A](spark, seq)

// given foo2: {} with
// extension [A <: Int | String](inline seq: Seq[A])(using spark: SparkSession) // TODO: Add more primitive types
// // transparent inline def toTypedDF[N <: Name](columnName: N): TypedDataFrame[FrameSchema] = seq.toTypedDF[N]
// transparent inline def toTypedDFNamed[N <: Name](columnName: N): TypedDataFrame[FrameSchema] = seq.toTypedDF[N]

private def toTypedDFWithNameImpl[N <: Name : Type, A : Type](using Quotes)(seq: Expr[Seq[A]], spark: Expr[SparkSession]): Expr[TypedDataFrame[FrameSchema/* TableSchema */]] =
private def toTypedDFWithNameImpl[N <: Name : Type, A : Type, E <: DataType : Type](using Quotes)(seq: Expr[Seq[A]], spark: Expr[SparkSession]): Expr[TypedDataFrame[StructType/* TableSchema */]] =
'{
val s = $spark
given Encoder[A] = ${ Expr.summon[Encoder[A]].get }
given sql.Encoder[A] = ${ Expr.summon[sql.Encoder[A]].get }
import s.implicits.*
localSeqToDatasetHolder($seq).toDF(valueOf[N]).withSchema[FrameSchema.WithSingleColumn[N, A]]
localSeqToDatasetHolder($seq).toDF(valueOf[N]).withSchema[StructType.WithSingleColumn[N, E]]
}

given foo3: {} with
extension [A](inline seq: Seq[A])(using schema: FrameSchema.Provider[A])(using encoder: Encoder[A], spark: SparkSession)
inline def toTypedDF: TypedDataFrame[schema.Schema] =
given structTypeBuilderOps: {} with
extension [A](inline seq: Seq[A])(using typeEncoder: DataType.StructEncoder[A], runtimeEncoder: sql.Encoder[A], spark: SparkSession)
inline def toTypedDF: TypedDataFrame[typeEncoder.Encoded] =
import spark.implicits.*
seq.toDF(/* Should we explicitly pass columns here? */).typed

// given foo4: {} with
// extension [A](inline seq: Seq[A])(using schema: FrameSchema.Provider[A])(using encoder: Encoder[A], spark: SparkSession)
// inline def crash: TypedDataFrame[schema.Schema] =
// import spark.implicits.*
// seq.toDF(/* Should we explicitly pass columns here? */).typed

export TypedDataFrameBuilders.given
9 changes: 5 additions & 4 deletions src/main/scala/org/virtuslab/typedframes/DataFrameOps.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package org.virtuslab.typedframes

import org.apache.spark.sql.Encoder
import org.apache.spark.sql
import types.{DataType, StructType}

extension [S <: FrameSchema](inline tdf: TypedDataFrame[S])
extension [S <: StructType](inline tdf: TypedDataFrame[S])
inline def show(): Unit = tdf.untyped.show()

// TODO: check schema conformance instead of equality
inline def collect[T]()(using e: Encoder[T], fsf: FrameSchema.Provider[T], ev: fsf.Schema =:= S): List[T] =
tdf.untyped.as[T].collect.toList
inline def collect[A]()(using typeEncoder: DataType.StructEncoder[A], runtimeEncoder: sql.Encoder[A], eq: typeEncoder.Encoded =:= S): List[A] =
tdf.untyped.as[A].collect.toList
75 changes: 42 additions & 33 deletions src/main/scala/org/virtuslab/typedframes/FrameSchema.scala
Original file line number Diff line number Diff line change
@@ -1,46 +1,55 @@
package org.virtuslab.typedframes
// package org.virtuslab.typedframes

import scala.quoted._
import scala.deriving.Mirror
import scala.compiletime.erasedValue
import Internals.{Name, NameLike}
// import scala.quoted._
// import scala.deriving.Mirror
// import scala.compiletime.erasedValue
// import types.DataType
// import Internals.{Name, Name.Subtype}

sealed trait FrameSchema:
//private def elems
override def toString =
val elems = FrameSchema.elems(this).map((label, typeName) => s"${label}: ${typeName}").mkString(", ")
s"FrameSchema {${elems}}"
// sealed trait FrameSchema:
// //private def elems
// override def toString =
// val elems = FrameSchema.elems(this).map((label, typeName) => s"${label}: ${typeName}").mkString(", ")
// s"FrameSchema {${elems}}"


object FrameSchema:
object SNil extends FrameSchema//:
type SNil = SNil.type
final case class SCons[N <: Name, H, T <: FrameSchema](headLabel: N, headTypeName: String, tail: T) extends FrameSchema//:
private def elems(schema: FrameSchema): List[(String, String)] = schema match
case SNil => Nil
case SCons(headLabel, headTypeName, tail) => (headLabel, headTypeName) :: elems(tail)
// object FrameSchema:
// object SNil extends FrameSchema//:
// type SNil = SNil.type
// final case class SCons[N <: Name, H <: DataType, T <: FrameSchema](headLabel: N, headTypeName: String, tail: T) extends FrameSchema//:
// private def elems(schema: FrameSchema): List[(String, String)] = schema match
// case SNil => Nil
// case SCons(headLabel, headTypeName, tail) => (headLabel, headTypeName) :: elems(tail)

type FromLabelsAndTypes[Ls <: Tuple, Ts <: Tuple] <: FrameSchema = Ls match
case NameLike[elemLabel] *: elemLabels => Ts match
case elemType *: elemTypes =>
SCons[elemLabel, elemType, FromLabelsAndTypes[elemLabels, elemTypes]]
case EmptyTuple => SNil
// type FromLabelsAndTypes[Ls <: Tuple, Ts <: Tuple] <: FrameSchema = Ls match
// case Name.Subtype[elemLabel] *: elemLabels => Ts match
// case elemType *: elemTypes =>
// SCons[elemLabel, elemType, FromLabelsAndTypes[elemLabels, elemTypes]]
// case EmptyTuple => SNil

type WithSingleColumn[N <: Name, ColType] = FromLabelsAndTypes[N *: EmptyTuple, ColType *: EmptyTuple]
// type WithSingleColumn[N <: Name, ColType <: DataType] = FromLabelsAndTypes[N *: EmptyTuple, ColType *: EmptyTuple]

trait Provider[A]:
type Schema <: FrameSchema
// trait Provider[A]:
// type Schema <: FrameSchema

object Provider:
transparent inline given fromMirror[A](using m: Mirror.ProductOf[A]): Provider[A] = new Provider[A]:
type Schema = FromLabelsAndTypes[m.MirroredElemLabels, m.MirroredElemTypes]
// object Provider:
// transparent inline given fromMirror[A](using m: Mirror.ProductOf[A]): Provider[A] = new Provider[A]:
// type Schema = FromLabelsAndTypes[m.MirroredElemLabels, m.MirroredElemTypes]

inline def of[A](using p: Provider[A]): p.Schema = instance[p.Schema]
// inline def of[A](using p: Provider[A]): p.Schema = instance[p.Schema]

inline def instance[Schema <: FrameSchema]: Schema = inline erasedValue[Schema] match
case _: SNil.type => SNil.asInstanceOf[Schema]
case _: SCons[headLabel, headType, tail] =>
(new SCons[headLabel, headType, tail](valueOf[headLabel], getTypeName[headType], instance[tail])).asInstanceOf[Schema]
// inline def instance[Schema <: FrameSchema]: Schema = inline erasedValue[Schema] match
// case _: SNil.type => SNil.asInstanceOf[Schema]
// case _: SCons[headLabel, headType, tail] =>
// (new SCons[headLabel, headType, tail](valueOf[headLabel], getTypeName[headType], instance[tail])).asInstanceOf[Schema]


/////////

//type FrameSchema = types.StructType


/////////////////////


// TODO: Conformance should be recursive
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/org/virtuslab/typedframes/Internals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ package org.virtuslab.typedframes

private object Internals:
type Name = String & Singleton
type NameLike[T <: Name] = T
object Name:
type Subtype[T <: Name] = T
15 changes: 8 additions & 7 deletions src/main/scala/org/virtuslab/typedframes/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package org.virtuslab.typedframes

import scala.quoted.*
import org.apache.spark.sql.{ Column => UntypedColumn}
import types.{ DataType, StructType }
import Internals.Name

trait SelectCtx extends SparkOpCtx:
type CtxOut <: SelectionView

extension [S <: FrameSchema](inline tdf: TypedDataFrame[S])(using svp: SelectionView.Provider[S])
extension [S <: StructType](inline tdf: TypedDataFrame[S])(using svp: SelectionView.Provider[S])
inline def select[T](f: SelectCtx { type CtxOut = svp.View } ?=> T)(using mc: MergeColumns[T]): TypedDataFrame[mc.MergedSchema] =
val ctx: SelectCtx { type CtxOut = svp.View } = new SelectCtx {
type CtxOut = svp.View
Expand All @@ -18,7 +19,7 @@ extension [S <: FrameSchema](inline tdf: TypedDataFrame[S])(using svp: Selection
(tdf.untyped.select(columns*)).withSchema[mc.MergedSchema]

trait MergeColumns[T]:
type MergedSchema <: FrameSchema
type MergedSchema <: StructType
def columns(t: T): List[UntypedColumn]

object MergeColumns:
Expand All @@ -34,18 +35,18 @@ object MergeColumns:

def mergeTupleColumnsImpl[T <: Tuple : Type](using Quotes): Expr[MergeColumns[T]] =
'{
type s = FrameSchema.FromLabelsAndTypes[ColumnsNames[T], ColumnsTypes[T]]
type s = StructType.FromLabelsAndTypes[ColumnsNames[T], ColumnsTypes[T]]
new MergeColumns[T] {
type MergedSchema = s
def columns(t: T): List[UntypedColumn] = t.toList.map(col => col.asInstanceOf[TypedColumn[Name, Any]].underlying) //List(t.underlying)
def columns(t: T): List[UntypedColumn] = t.toList.map(col => col.asInstanceOf[TypedColumn[Name, DataType]].underlying) //List(t.underlying)
}.asInstanceOf[MergeColumns[T] {type MergedSchema = s}]
}

transparent inline given mergeSingleColumn[N <: Name, A]: MergeColumns[TypedColumn[N, A]] = ${ mergeSingleColumnImpl[N, A] }
transparent inline given mergeSingleColumn[N <: Name, A <: DataType]: MergeColumns[TypedColumn[N, A]] = ${ mergeSingleColumnImpl[N, A] }

def mergeSingleColumnImpl[N <: Name : Type, A : Type](using Quotes): Expr[MergeColumns[TypedColumn[N, A]]] =
def mergeSingleColumnImpl[N <: Name : Type, A <: DataType : Type](using Quotes): Expr[MergeColumns[TypedColumn[N, A]]] =
'{
type s = FrameSchema.WithSingleColumn[N, A]
type s = StructType.WithSingleColumn[N, A]
new MergeColumns[TypedColumn[N, A]] {
type MergedSchema = s
def columns(t: TypedColumn[N, A]): List[UntypedColumn] = List(t.underlying)
Expand Down
13 changes: 7 additions & 6 deletions src/main/scala/org/virtuslab/typedframes/SelectionView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ package org.virtuslab.typedframes

import scala.quoted._
import org.apache.spark.sql.functions.col
import types.{DataType, StructType}

class SelectionView extends Selectable:
def selectDynamic(name: String): TypedColumn[Nothing, Any] =
UnnamedTypedColumn[Any](col(name))
def selectDynamic(name: String): TypedColumn[Nothing, DataType] =
UnnamedTypedColumn[DataType](col(name))

object SelectionView:
trait Provider[A <: FrameSchema]:
trait Provider[A <: StructType]:
type View <: SelectionView
def view: View

object Provider:
transparent inline given selectionViewFor[A <: FrameSchema]: Provider[A] = ${selectionViewForImpl[A]}
transparent inline given selectionViewFor[A <: StructType]: Provider[A] = ${selectionViewForImpl[A]}

def selectionViewForImpl[A <: FrameSchema : Type](using Quotes): Expr[SelectionView.Provider[A]] =
def selectionViewForImpl[A <: StructType : Type](using Quotes): Expr[SelectionView.Provider[A]] =
import quotes.reflect.*

selectionView(TypeRepr.of[SelectionView], Type.of[A]).asType match
Expand All @@ -28,8 +29,8 @@ object SelectionView:
}

private def selectionView(using Quotes)(base: quotes.reflect.TypeRepr, frameType: Type[?]): quotes.reflect.TypeRepr =
import FrameSchema.*
import quotes.reflect.*
import StructType.{SNil, SCons}

frameType match
case '[SNil] => base
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/org/virtuslab/typedframes/ShowType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def showTypeImpl[A : Type](using quotes: Quotes): Expr[Unit] =
println("*********")
println(TypeRepr.of[A].show)
println(TypeRepr.of[A].widen.show)
println(TypeRepr.of[A].dealias.show)
'{()}


Expand Down
Loading

0 comments on commit 2517a1d

Please sign in to comment.