-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 6feefd3
Showing
32 changed files
with
1,652 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
name: CI | ||
on: | ||
push: | ||
branches: ["main"] | ||
tags: ["v*"] | ||
pull_request: | ||
branches: ["*"] | ||
|
||
jobs: | ||
build: | ||
strategy: | ||
fail-fast: false | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- uses: coursier/cache-action@v6 | ||
- uses: VirtusLab/scala-cli-setup@main | ||
with: | ||
power: true | ||
|
||
- name: Check formatting | ||
run: make code-check || echo "Run `make pre-ci`" | ||
|
||
- name: Test | ||
run: make test | ||
|
||
# - name: Check documentation compiles and runs | ||
# run: make check-docs && make run-example | ||
|
||
- name: Publish snapshot | ||
if: github.ref == 'refs/heads/main' | ||
run: make publish-snapshot | ||
env: | ||
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} | ||
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} | ||
|
||
- name: Publish | ||
if: startsWith(github.ref, 'refs/tags/v') | ||
run: make publish | ||
env: | ||
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} | ||
PGP_SECRET: ${{ secrets.PGP_SECRET }} | ||
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} | ||
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Logs | ||
logs | ||
*.log | ||
npm-debug.log* | ||
yarn-debug.log* | ||
yarn-error.log* | ||
pnpm-debug.log* | ||
lerna-debug.log* | ||
|
||
node_modules | ||
dist | ||
dist-ssr | ||
*.local | ||
|
||
# Editor directories and files | ||
.vscode/* | ||
!.vscode/extensions.json | ||
.idea | ||
.DS_Store | ||
*.suo | ||
*.ntvs* | ||
*.njsproj | ||
*.sln | ||
*.sw? | ||
|
||
.class | ||
.scala-build | ||
.metals | ||
.bsp | ||
scalajs-frontend.js | ||
*.semanticdb | ||
db.json | ||
*.map |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
OrganizeImports.groupedImports = Merge | ||
OrganizeImports.targetDialect = Scala3 | ||
OrganizeImports.removeUnused = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
version = "3.8.3" | ||
|
||
runner.dialect = scala3 | ||
rewrite.scala3.insertEndMarkerMinLines = 10 | ||
rewrite.scala3.removeOptionalBraces = true | ||
rewrite.scala3.convertToNewSyntax = true | ||
align.preset = more | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package genovese | ||
|
||
import scala.reflect.ClassTag | ||
|
||
final case class EnumFeatureful[T](values: IArray[T]) extends Featureful[T]: | ||
private val lookup = values.zipWithIndex.toMap | ||
override lazy val features: IArray[Feature[?]] = IArray( | ||
Feature.IntCategory(List.tabulate(values.length)(identity)) | ||
) | ||
override def toFeatures(value: T)(using | ||
RuntimeChecks | ||
): IArray[NormalisedFloat] = | ||
features.map: | ||
case f @ Feature.IntCategory(_) => | ||
f.toNormalisedFloat(lookup(value)) | ||
|
||
override def fromFeatures( | ||
fv: IArray[NormalisedFloat] | ||
): T = | ||
features(0): @unchecked match | ||
case f @ Feature.IntCategory(_) => | ||
val idx = | ||
f.fromNormalisedFloat(NormalisedFloat.applyUnsafe(fv(0))) | ||
values(idx) | ||
|
||
override def toString(): String = s"EnumFeatureful(${values.mkString(",")})" | ||
end EnumFeatureful | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package genovese | ||
|
||
trait Evaluator: | ||
def evaluate(population: Population, fitness: Vec => NormalisedFloat): Evaluated | ||
|
||
object Evaluator extends EvaluatorCompanion | ||
|
||
object SequentialEvaluator extends Evaluator: | ||
override def evaluate(population: Population, fitness: Vec => NormalisedFloat): Evaluated = | ||
Evaluated(population.map(v => v -> fitness(v))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package genovese | ||
|
||
trait EventHandler: | ||
def handle[T](te: TrainingEvent[T], data: T | Null): TrainingInstruction | ||
val allowed: Set[TrainingEvent[?]] = TrainingEvent.all | ||
|
||
object EventHandler: | ||
val None = new EventHandler: | ||
override def handle[T]( | ||
te: TrainingEvent[T], | ||
data: T | Null | ||
): TrainingInstruction = TrainingInstruction.Continue | ||
override val allowed: Set[TrainingEvent[?]] = Set.empty |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
package genovese | ||
|
||
private class CategoryCalc( | ||
val steps: Map[Int, Float], | ||
val stepSize: Float | ||
) | ||
private object CategoryCalc: | ||
def apply(alts: Seq[Int]) = | ||
val cnt = alts.size | ||
val step = 1.0f / cnt | ||
|
||
new CategoryCalc( | ||
List.tabulate(alts.size)(idx => idx -> step * idx).toMap, | ||
step | ||
) | ||
|
||
enum Feature[T]: | ||
case FloatRange(from: Float, to: Float) extends Feature[Float] | ||
case IntRange(from: Int, to: Int) extends Feature[Int] | ||
case NormalisedFloatRange extends Feature[NormalisedFloat] | ||
case Bool extends Feature[Boolean] | ||
case StringCategory(alts: List[String]) extends Feature[String] | ||
case IntCategory(alts: List[Int]) extends Feature[Int] | ||
case Optional(feature: Feature[T]) extends Feature[Option[T]] | ||
|
||
private lazy val categorySteps: CategoryCalc | Null = | ||
this match | ||
case StringCategory(alts) => | ||
CategoryCalc(alts.indices) | ||
case IntCategory(alts) => | ||
CategoryCalc(alts) | ||
case _ => null | ||
|
||
def toNormalisedFloat(value: this.type#T)(using | ||
RuntimeChecks | ||
): NormalisedFloat = | ||
import NormalisedFloat.* | ||
this match | ||
case FloatRange(from, to) => | ||
runtimeChecks.floatBoundsCheck(from, value, to) | ||
NormalisedFloat((value - from) / (to - from)) | ||
|
||
case IntRange(from, to) => | ||
runtimeChecks.intBoundsCheck(from, value, to) | ||
NormalisedFloat((value - from).toFloat / (to - from).toFloat) | ||
|
||
case Bool => | ||
if value then ONE else ZERO | ||
|
||
case NormalisedFloatRange => | ||
value | ||
|
||
case StringCategory(alts) => | ||
val idxof = alts.indexOf(value) | ||
|
||
runtimeChecks.check( | ||
Option.when(idxof == -1)( | ||
s"Value of [$value] not found among alternatives [${alts.mkString(", ")}]" | ||
) | ||
) | ||
|
||
NormalisedFloat( | ||
categorySteps.nn.steps( | ||
idxof | ||
) + categorySteps.nn.stepSize / 2 | ||
) | ||
|
||
case IntCategory(alts) => | ||
val idxof = alts.indexOf(value) | ||
|
||
runtimeChecks.check( | ||
Option.when(idxof == -1)( | ||
s"Value of [$value] not found among alternatives [${alts.mkString(", ")}]" | ||
) | ||
) | ||
|
||
NormalisedFloat( | ||
categorySteps.nn.steps( | ||
idxof | ||
) + categorySteps.nn.stepSize / 2 | ||
) | ||
case Optional(what) => | ||
value match | ||
case None => NormalisedFloat(0.25f) | ||
case Some(value) => | ||
NormalisedFloat(what.toNormalisedFloat(value) / 2 + 0.5f) | ||
|
||
end match | ||
end toNormalisedFloat | ||
|
||
def fromNormalisedFloat(value: NormalisedFloat): this.type#T = | ||
this match | ||
case FloatRange(from, to) => | ||
from + value * (to - from) | ||
case IntRange(from, to) => | ||
(from + value * (to - from)).toInt | ||
case Bool => value >= 0.5f | ||
case NormalisedFloatRange => value | ||
|
||
case StringCategory(alts) => | ||
val modulo = (value / categorySteps.nn.stepSize).toInt | ||
alts(modulo.min(alts.length - 1)) | ||
|
||
case IntCategory(alts) => | ||
val modulo = (value / categorySteps.nn.stepSize).toInt | ||
alts(modulo.min(alts.length - 1)) | ||
case Optional(feature) => | ||
if value < 0.5f then None | ||
else | ||
Some( | ||
feature.fromNormalisedFloat( | ||
NormalisedFloat.applyUnsafe(2 * (value - 0.5f)) | ||
) | ||
) | ||
|
||
end Feature |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package genovese | ||
|
||
import munit.* | ||
import org.scalacheck.Prop.* | ||
import org.scalacheck.Gen | ||
|
||
class FeatureTest extends FunSuite, ScalaCheckSuite: | ||
import NormalisedFloat.* | ||
given RuntimeChecks = RuntimeChecks.Full | ||
|
||
val genNormalisedFloat = Gen.chooseNum(0.0f, 1.0f).map(NormalisedFloat(_)) | ||
|
||
object gen: | ||
val bool = Gen.const(Feature.Bool) | ||
val normalisedFloat = Gen.const(Feature.NormalisedFloatRange) | ||
|
||
val stringCategory: Gen[Feature.StringCategory] = | ||
for | ||
n <- Gen.chooseNum(1, 10) | ||
alts = List.tabulate(n)(i => s"alt$i") | ||
yield Feature.StringCategory(alts) | ||
|
||
val intCategory: Gen[Feature.IntCategory] = | ||
for | ||
n <- Gen.chooseNum(1, 10) | ||
alts <- Gen.listOfN(n, Gen.posNum[Int]) | ||
yield Feature.IntCategory(alts) | ||
|
||
val feature = Gen.oneOf(bool, stringCategory, intCategory, normalisedFloat) | ||
end gen | ||
|
||
test("Bool: toNormalisedFloat basics"): | ||
val bool = Feature.Bool | ||
|
||
assertEquals(bool.toNormalisedFloat(true), 1.0f) | ||
assertEquals(bool.toNormalisedFloat(false), 0.0f) | ||
|
||
property("Bool: fromNormalisedFloat base cases"): | ||
val bool = Feature.Bool | ||
|
||
assertEquals(bool.fromNormalisedFloat(ZERO), false) | ||
assertEquals(bool.fromNormalisedFloat(ONE), true) | ||
|
||
property("Bool: fromNormalisedFloat range"): | ||
val bool = Feature.Bool | ||
|
||
val g = Gen.chooseNum(0.0f, 1.0f).map(NormalisedFloat.apply) | ||
|
||
forAll(g): x => | ||
bool.fromNormalisedFloat(x) == (x >= 0.5f) | ||
|
||
property("IntCategory"): | ||
val g = | ||
for | ||
n <- Gen.chooseNum(1, 10) | ||
flt <- genNormalisedFloat | ||
alts <- Gen.listOfN(n, Gen.posNum[Int]) | ||
chosen <- Gen.chooseNum(0, n - 1).map(alts.apply) | ||
yield (Feature.IntCategory(alts), alts, chosen, flt) | ||
|
||
forAll(g): (feature, alts, chosen, float) => | ||
feature.fromNormalisedFloat( | ||
feature.toNormalisedFloat(chosen) | ||
) == chosen | ||
&& | ||
alts.contains(feature.fromNormalisedFloat(float)) | ||
|
||
property("Optional"): | ||
forAll(gen.feature, genNormalisedFloat): (feat, float) => | ||
val o = Feature.Optional(feat) | ||
((float <= 0.5f) && o.fromNormalisedFloat(float) == None) || | ||
(float > 0.5f) && o.fromNormalisedFloat(float).isDefined | ||
|
||
property("StringCategory"): | ||
val g = | ||
for | ||
n <- Gen.chooseNum(1, 10) | ||
flt <- genNormalisedFloat | ||
alts = List.tabulate(n)(i => s"alt$i") | ||
chosen <- Gen.chooseNum(0, n - 1).map(alts.apply) | ||
yield (Feature.StringCategory(alts), alts, chosen, flt) | ||
|
||
forAll(g): (feature, alts, chosen, float) => | ||
feature.fromNormalisedFloat( | ||
feature.toNormalisedFloat(chosen) | ||
) == chosen | ||
&& | ||
alts.contains(feature.fromNormalisedFloat(float)) | ||
|
||
property("FloatRange: normalisation roundtrip"): | ||
val gen = for | ||
min <- Gen.double.map(_.toFloat) | ||
len <- Gen.posNum[Float] | ||
num <- Gen.chooseNum(min, min + len) | ||
yield (Feature.FloatRange(min, min + len), num) | ||
|
||
forAll(gen): (float, num) => | ||
Math.abs( | ||
float.fromNormalisedFloat(float.toNormalisedFloat(num)) - num | ||
) <= 0.0001 | ||
|
||
end FeatureTest |
Oops, something went wrong.