Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derivation for sum that consists only of singletons #19

Merged
merged 5 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions src/core/wisteria.SimpleSumDerivationMethods.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
Wisteria, version [unreleased]. Copyright 2024 Jon Pretty, Propensive OÜ.

The primary distribution site is: https://propensive.com/

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied. See the License for the specific language governing permissions
and limitations under the License.
*/

package wisteria

import anticipation.*
import rudiments.*
import vacuous.*
import contingency.*
import fulminate.*

import scala.deriving.*
import scala.compiletime.*

trait SimpleSumDerivationMethods[TypeclassType[_]]:
protected transparent inline def complement[DerivationType, VariantType](sum: DerivationType)
(using variantIndex: Int & VariantIndex[VariantType],
reflection: SumReflection[DerivationType])
: Optional[VariantType] =

type Labels = reflection.MirroredElemLabels
type Variants = reflection.MirroredElemTypes
val size: Int = valueOf[Tuple.Size[reflection.MirroredElemTypes]]

fold[DerivationType, Variants, Labels](sum, size, 0, false)(index == reflection.ordinal(sum)):
[VariantType2 <: DerivationType] => field =>
if index == variantIndex then field.asInstanceOf[VariantType] else Unset

protected inline def variantLabels[DerivationType]
(using reflection: SumReflection[DerivationType])
: List[Text] =

constValueTuple[reflection.MirroredElemLabels].toList.map(_.toString.tt)

protected transparent inline def delegateFromName[DerivationType](label: Text)
(using reflection: SumReflection[DerivationType], requirement: ContextRequirement)
: DerivationType =

type Labels = reflection.MirroredElemLabels
type Variants = reflection.MirroredElemTypes

val size: Int = valueOf[Tuple.Size[reflection.MirroredElemTypes]]
val variantLabel = label

// Here label comes from context of fold's predicate
fold[DerivationType, Variants, Labels](variantLabel, size, 0, true)(label == variantLabel)
.vouch(using Unsafe)

protected transparent inline def delegateFromIndex[DerivationType](index: Int)
(using reflection: SumReflection[DerivationType], requirement: ContextRequirement)
: DerivationType =

type Labels = reflection.MirroredElemLabels
type Variants = reflection.MirroredElemTypes

val size: Int = valueOf[Tuple.Size[reflection.MirroredElemTypes]]
val variantIndex = index

// Here label comes from context of fold's predicate
fold[DerivationType, Variants, Labels](variantIndex.toString().tt, size, 0, true)(index == variantIndex)
.vouch(using Unsafe)

protected transparent inline def variant[DerivationType](sum: DerivationType)
(using reflection: SumReflection[DerivationType], requirement: ContextRequirement)
[ResultType]
(inline lambda: [VariantType <: DerivationType] =>
VariantType =>
(label: Text,
index: Int & VariantIndex[VariantType]) ?=>
ResultType)
: ResultType =

type Labels = reflection.MirroredElemLabels
type Variants = reflection.MirroredElemTypes

val size: Int = valueOf[Tuple.Size[reflection.MirroredElemTypes]]

fold[DerivationType, Variants, Labels](sum, size, 0, false)(index == reflection.ordinal(sum)):
[VariantType <: DerivationType] => variant => lambda[VariantType](variant)
.vouch(using Unsafe)

private transparent inline def fold[DerivationType, VariantsType <: Tuple, LabelsType <: Tuple]
(inline inputLabel: Text, size: Int, index: Int, fallible: Boolean)
(using reflection: SumReflection[DerivationType], requirement: ContextRequirement)
(inline predicate: (label: Text, index: Int & VariantIndex[DerivationType]) ?=> Boolean)
: Optional[DerivationType] =

inline erasedValue[VariantsType] match
case _: (variantType *: variantsType) =>
inline erasedValue[LabelsType] match
case _: (labelType *: moreLabelsType) =>
type VariantType = variantType & DerivationType
if index >= size then Unset else
(valueOf[labelType].asMatchable: @unchecked) match
case label: String =>
val index2: Int & VariantIndex[DerivationType] = VariantIndex[DerivationType](index)

if predicate(using label.tt, index2)
then
summonInline[SingletonFactory[VariantType]].create
else
fold
[DerivationType, variantsType, moreLabelsType]
(inputLabel, size, index + 1, fallible)
(predicate)

case _ =>
inline if fallible
then raise(VariantError[DerivationType](inputLabel))(Unset)(using summonInline[Errant[VariantError]])
else throw Panic(msg"Should be unreachable")

private transparent inline def fold[DerivationType, VariantsType <: Tuple, LabelsType <: Tuple]
(inline sum: DerivationType, size: Int, index: Int, fallible: Boolean)
(using reflection: SumReflection[DerivationType], requirement: ContextRequirement)
(inline predicate: (label: Text, index: Int & VariantIndex[DerivationType]) ?=> Boolean)
[ResultType]
(inline lambda: [VariantType <: DerivationType] =>
VariantType =>
(label: Text,
index: Int & VariantIndex[VariantType]) ?=>
ResultType)
: Optional[ResultType] =

inline erasedValue[VariantsType] match
case _: (variantType *: variantsType) => inline erasedValue[LabelsType] match
case _: (labelType *: moreLabelsType) =>
type VariantType = variantType & DerivationType
if index >= size then Unset else
(valueOf[labelType].asMatchable: @unchecked) match
case label: String =>
val index2: Int & VariantIndex[DerivationType] = VariantIndex[DerivationType](index)

if predicate(using label.tt, index2)
then
val index3: Int & VariantIndex[VariantType] = VariantIndex[VariantType](index)
val variant: VariantType = sum.asInstanceOf[VariantType]
lambda[VariantType](variant)(using label.tt, index3)
else
fold
[DerivationType, variantsType, moreLabelsType]
(sum, size, index + 1, fallible)
(predicate)
(lambda)

case _ =>
inline if fallible
then raise(VariantError[DerivationType]("".tt))(Unset)(using summonInline[Errant[VariantError]])
else throw Panic(msg"Should be unreachable")

inline def split[DerivationType: SumReflection]: TypeclassType[DerivationType]
17 changes: 17 additions & 0 deletions src/core/wisteria.SingletonFactory.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package wisteria

import scala.deriving.*
import scala.compiletime.*

trait SingletonFactory[T] {
def create: T
}

inline given singletonFactory[T <: Product](using reflection: ProductReflection[T]): SingletonFactory[T] =
compiletime.summonFrom:
case given (reflection.MirroredMonoType <:< Singleton) =>
new SingletonFactory[T]:
override def create: T = reflection.fromProduct(EmptyTuple)
case _ =>
inline val typeName: String = erasedValue[reflection.MirroredLabel]
error("Cannot derive sumSimpleVariant for '"+typeName+"'")
65 changes: 65 additions & 0 deletions src/test/tests2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
Wisteria, version [unreleased]. Copyright 2024 Jon Pretty, Propensive OÜ.

The primary distribution site is: https://propensive.com/

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied. See the License for the specific language governing permissions
and limitations under the License.
*/

package wisteria

import anticipation.*
import rudiments.*
import vacuous.*
import contingency.*
import fulminate.*
import contingency.errorHandlers.throwUnsafely

import scala.deriving.*
import scala.compiletime.*

trait SimpleSumDerivation[TypeclassType[_]] extends SimpleSumDerivationMethods[TypeclassType]:
inline given derived[DerivationType](using reflection: SumReflection[DerivationType])
: TypeclassType[DerivationType] =
split[DerivationType](using reflection)

trait Show[T] {
def show(value: T): String
}

object Show extends SimpleSumDerivation[Show] {
inline def split[DerivationType: SumReflection]: Show[DerivationType] = value =>
variant(value):
[VariantType <: DerivationType] => variant =>
s"label($label), index($index)"
}

trait Read[T] {
def read(value: String): T
}

object Read extends SimpleSumDerivation[Read] {
inline def split[DerivationType: SumReflection]: Read[DerivationType] = input =>
delegateFromName(input)
}

enum Test:
case First
case Second
case Third

// @main
def prototypeTest(): Unit =
val showImpl = summon[Show[Test]]
val readImpl = summon[Read[Test]]

val value = readImpl.read("Second")
println(showImpl.show(value))
Loading