Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
keynmol committed Sep 13, 2024
0 parents commit 6feefd3
Show file tree
Hide file tree
Showing 32 changed files with 1,652 additions and 0 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: CI
on:
push:
branches: ["main"]
tags: ["v*"]
pull_request:
branches: ["*"]

jobs:
build:
strategy:
fail-fast: false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: coursier/cache-action@v6
- uses: VirtusLab/scala-cli-setup@main
with:
power: true

- name: Check formatting
run: make code-check || echo "Run `make pre-ci`"

- name: Test
run: make test

# - name: Check documentation compiles and runs
# run: make check-docs && make run-example

- name: Publish snapshot
if: github.ref == 'refs/heads/main'
run: make publish-snapshot
env:
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}

- name: Publish
if: startsWith(github.ref, 'refs/tags/v')
run: make publish
env:
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
PGP_SECRET: ${{ secrets.PGP_SECRET }}
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}

33 changes: 33 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

.class
.scala-build
.metals
.bsp
scalajs-frontend.js
*.semanticdb
db.json
*.map
3 changes: 3 additions & 0 deletions .scalafix.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
OrganizeImports.groupedImports = Merge
OrganizeImports.targetDialect = Scala3
OrganizeImports.removeUnused = true
9 changes: 9 additions & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

version = "3.8.3"

runner.dialect = scala3
rewrite.scala3.insertEndMarkerMinLines = 10
rewrite.scala3.removeOptionalBraces = true
rewrite.scala3.convertToNewSyntax = true
align.preset = more

28 changes: 28 additions & 0 deletions EnumFeatureful.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package genovese

import scala.reflect.ClassTag

final case class EnumFeatureful[T](values: IArray[T]) extends Featureful[T]:
private val lookup = values.zipWithIndex.toMap
override lazy val features: IArray[Feature[?]] = IArray(
Feature.IntCategory(List.tabulate(values.length)(identity))
)
override def toFeatures(value: T)(using
RuntimeChecks
): IArray[NormalisedFloat] =
features.map:
case f @ Feature.IntCategory(_) =>
f.toNormalisedFloat(lookup(value))

override def fromFeatures(
fv: IArray[NormalisedFloat]
): T =
features(0): @unchecked match
case f @ Feature.IntCategory(_) =>
val idx =
f.fromNormalisedFloat(NormalisedFloat.applyUnsafe(fv(0)))
values(idx)

override def toString(): String = s"EnumFeatureful(${values.mkString(",")})"
end EnumFeatureful

10 changes: 10 additions & 0 deletions Evaluator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package genovese

trait Evaluator:
def evaluate(population: Population, fitness: Vec => NormalisedFloat): Evaluated

object Evaluator extends EvaluatorCompanion

object SequentialEvaluator extends Evaluator:
override def evaluate(population: Population, fitness: Vec => NormalisedFloat): Evaluated =
Evaluated(population.map(v => v -> fitness(v)))
13 changes: 13 additions & 0 deletions EventHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package genovese

trait EventHandler:
def handle[T](te: TrainingEvent[T], data: T | Null): TrainingInstruction
val allowed: Set[TrainingEvent[?]] = TrainingEvent.all

object EventHandler:
val None = new EventHandler:
override def handle[T](
te: TrainingEvent[T],
data: T | Null
): TrainingInstruction = TrainingInstruction.Continue
override val allowed: Set[TrainingEvent[?]] = Set.empty
116 changes: 116 additions & 0 deletions Feature.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package genovese

private class CategoryCalc(
val steps: Map[Int, Float],
val stepSize: Float
)
private object CategoryCalc:
def apply(alts: Seq[Int]) =
val cnt = alts.size
val step = 1.0f / cnt

new CategoryCalc(
List.tabulate(alts.size)(idx => idx -> step * idx).toMap,
step
)

enum Feature[T]:
case FloatRange(from: Float, to: Float) extends Feature[Float]
case IntRange(from: Int, to: Int) extends Feature[Int]
case NormalisedFloatRange extends Feature[NormalisedFloat]
case Bool extends Feature[Boolean]
case StringCategory(alts: List[String]) extends Feature[String]
case IntCategory(alts: List[Int]) extends Feature[Int]
case Optional(feature: Feature[T]) extends Feature[Option[T]]

private lazy val categorySteps: CategoryCalc | Null =
this match
case StringCategory(alts) =>
CategoryCalc(alts.indices)
case IntCategory(alts) =>
CategoryCalc(alts)
case _ => null

def toNormalisedFloat(value: this.type#T)(using
RuntimeChecks
): NormalisedFloat =
import NormalisedFloat.*
this match
case FloatRange(from, to) =>
runtimeChecks.floatBoundsCheck(from, value, to)
NormalisedFloat((value - from) / (to - from))

case IntRange(from, to) =>
runtimeChecks.intBoundsCheck(from, value, to)
NormalisedFloat((value - from).toFloat / (to - from).toFloat)

case Bool =>
if value then ONE else ZERO

case NormalisedFloatRange =>
value

case StringCategory(alts) =>
val idxof = alts.indexOf(value)

runtimeChecks.check(
Option.when(idxof == -1)(
s"Value of [$value] not found among alternatives [${alts.mkString(", ")}]"
)
)

NormalisedFloat(
categorySteps.nn.steps(
idxof
) + categorySteps.nn.stepSize / 2
)

case IntCategory(alts) =>
val idxof = alts.indexOf(value)

runtimeChecks.check(
Option.when(idxof == -1)(
s"Value of [$value] not found among alternatives [${alts.mkString(", ")}]"
)
)

NormalisedFloat(
categorySteps.nn.steps(
idxof
) + categorySteps.nn.stepSize / 2
)
case Optional(what) =>
value match
case None => NormalisedFloat(0.25f)
case Some(value) =>
NormalisedFloat(what.toNormalisedFloat(value) / 2 + 0.5f)

end match
end toNormalisedFloat

def fromNormalisedFloat(value: NormalisedFloat): this.type#T =
this match
case FloatRange(from, to) =>
from + value * (to - from)
case IntRange(from, to) =>
(from + value * (to - from)).toInt
case Bool => value >= 0.5f
case NormalisedFloatRange => value

case StringCategory(alts) =>
val modulo = (value / categorySteps.nn.stepSize).toInt
alts(modulo.min(alts.length - 1))

case IntCategory(alts) =>
val modulo = (value / categorySteps.nn.stepSize).toInt
alts(modulo.min(alts.length - 1))
case Optional(feature) =>
if value < 0.5f then None
else
Some(
feature.fromNormalisedFloat(
NormalisedFloat.applyUnsafe(2 * (value - 0.5f))
)
)

end Feature
102 changes: 102 additions & 0 deletions Feature.test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package genovese

import munit.*
import org.scalacheck.Prop.*
import org.scalacheck.Gen

class FeatureTest extends FunSuite, ScalaCheckSuite:
import NormalisedFloat.*
given RuntimeChecks = RuntimeChecks.Full

val genNormalisedFloat = Gen.chooseNum(0.0f, 1.0f).map(NormalisedFloat(_))

object gen:
val bool = Gen.const(Feature.Bool)
val normalisedFloat = Gen.const(Feature.NormalisedFloatRange)

val stringCategory: Gen[Feature.StringCategory] =
for
n <- Gen.chooseNum(1, 10)
alts = List.tabulate(n)(i => s"alt$i")
yield Feature.StringCategory(alts)

val intCategory: Gen[Feature.IntCategory] =
for
n <- Gen.chooseNum(1, 10)
alts <- Gen.listOfN(n, Gen.posNum[Int])
yield Feature.IntCategory(alts)

val feature = Gen.oneOf(bool, stringCategory, intCategory, normalisedFloat)
end gen

test("Bool: toNormalisedFloat basics"):
val bool = Feature.Bool

assertEquals(bool.toNormalisedFloat(true), 1.0f)
assertEquals(bool.toNormalisedFloat(false), 0.0f)

property("Bool: fromNormalisedFloat base cases"):
val bool = Feature.Bool

assertEquals(bool.fromNormalisedFloat(ZERO), false)
assertEquals(bool.fromNormalisedFloat(ONE), true)

property("Bool: fromNormalisedFloat range"):
val bool = Feature.Bool

val g = Gen.chooseNum(0.0f, 1.0f).map(NormalisedFloat.apply)

forAll(g): x =>
bool.fromNormalisedFloat(x) == (x >= 0.5f)

property("IntCategory"):
val g =
for
n <- Gen.chooseNum(1, 10)
flt <- genNormalisedFloat
alts <- Gen.listOfN(n, Gen.posNum[Int])
chosen <- Gen.chooseNum(0, n - 1).map(alts.apply)
yield (Feature.IntCategory(alts), alts, chosen, flt)

forAll(g): (feature, alts, chosen, float) =>
feature.fromNormalisedFloat(
feature.toNormalisedFloat(chosen)
) == chosen
&&
alts.contains(feature.fromNormalisedFloat(float))

property("Optional"):
forAll(gen.feature, genNormalisedFloat): (feat, float) =>
val o = Feature.Optional(feat)
((float <= 0.5f) && o.fromNormalisedFloat(float) == None) ||
(float > 0.5f) && o.fromNormalisedFloat(float).isDefined

property("StringCategory"):
val g =
for
n <- Gen.chooseNum(1, 10)
flt <- genNormalisedFloat
alts = List.tabulate(n)(i => s"alt$i")
chosen <- Gen.chooseNum(0, n - 1).map(alts.apply)
yield (Feature.StringCategory(alts), alts, chosen, flt)

forAll(g): (feature, alts, chosen, float) =>
feature.fromNormalisedFloat(
feature.toNormalisedFloat(chosen)
) == chosen
&&
alts.contains(feature.fromNormalisedFloat(float))

property("FloatRange: normalisation roundtrip"):
val gen = for
min <- Gen.double.map(_.toFloat)
len <- Gen.posNum[Float]
num <- Gen.chooseNum(min, min + len)
yield (Feature.FloatRange(min, min + len), num)

forAll(gen): (float, num) =>
Math.abs(
float.fromNormalisedFloat(float.toNormalisedFloat(num)) - num
) <= 0.0001

end FeatureTest
Loading

0 comments on commit 6feefd3

Please sign in to comment.