Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replaced mathematical kwargs in scheduler constructors #60

Merged
merged 1 commit into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterSchedulers"
uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
authors = ["Kyle Daruwalla"]
version = "0.4.0"
version = "0.4.1"

[deps]
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
Expand Down
58 changes: 29 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ParameterSchedulers.jl provides common machine learning (ML) schedulers for hype
using Flux, ParameterSchedulers
using ParameterSchedulers: Scheduler

opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8))
opt = Scheduler(Momentum, Exp(start = 1e-2, decay = 0.8))
```

## Available Schedules
Expand All @@ -30,12 +30,12 @@ You can read [this paper](https://arxiv.org/abs/1908.06477) for more information
<tbody>
<tr><td>

[`Step(;λ, γ, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step)
[`Step(; start, decay, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step)

</td>
<td>

Exponential decay by `γ` every step in `step_sizes`
Exponential decay by `decay` every step in `step_sizes`

</td>
<td> Decay </td>
Expand All @@ -44,19 +44,19 @@ Exponential decay by `γ` every step in `step_sizes`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Step(λ = 1.0, γ = 0.8, step_sizes = [2, 3, 2]) # hide
s = Step(start = 1.0, decay = 0.8, step_sizes = [2, 3, 2]) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Exp(;λ, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp)
[`Exp(start, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp)

</td>
<td>

Exponential decay by `γ` every iteration
Exponential decay by `decay` every iteration

</td>
<td> Decay </td>
Expand All @@ -65,14 +65,14 @@ Exponential decay by `γ` every iteration
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Exp(λ = 1.0, γ = 0.5) # hide
s = Exp(start = 1.0, decay = 0.5) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`CosAnneal(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal)
[`CosAnneal(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal)

</td>
<td>
Expand All @@ -86,14 +86,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = CosAnneal(λ0 = 0.0, λ1 = 1.0, period = 4) # hide
s = CosAnneal(l0 = 0.0, l1 = 1.0, period = 4) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Triangle(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle)
[`Triangle(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle)

</td>
<td>
Expand All @@ -107,14 +107,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Triangle(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = Triangle(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`TriangleDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2)
[`TriangleDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2)

</td>
<td>
Expand All @@ -128,19 +128,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = TriangleDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = TriangleDecay2(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`TriangleExp(;λ0, λ1, period, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp)
[`TriangleExp(l0, l1, period, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp)

</td>
<td>

[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `γ`
[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `decay`

</td>
<td> Cyclic </td>
Expand All @@ -149,19 +149,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = TriangleExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide
s = TriangleExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Poly(;λ, p, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly)
[`Poly(start, degree, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly)

</td>
<td>

Polynomial decay at degree `p`
Polynomial decay at degree `degree`.

</td>
<td> Decay </td>
Expand All @@ -170,19 +170,19 @@ Polynomial decay at degree `p`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Poly(λ = 1.0, p = 2, max_iter = t[end]) # hide
s = Poly(start = 1.0, degree = 2, max_iter = t[end]) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Inv(;λ, γ, p)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv)
[`Inv(start, decay, degree)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv)

</td>
<td>

Inverse decay at rate `(1 + tγ)^p`
Inverse decay at rate `(1 + t * decay)^degree`

</td>
<td> Decay </td>
Expand All @@ -191,14 +191,14 @@ Inverse decay at rate `(1 + tγ)^p`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Inv(λ = 1.0, p = 2, γ = 0.8) # hide
s = Inv(start = 1.0, degree = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Sin(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin)
[`Sin(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin)

</td>
<td>
Expand All @@ -212,14 +212,14 @@ Sine function
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Sin(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = Sin(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`SinDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2)
[`SinDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2)

</td>
<td>
Expand All @@ -233,19 +233,19 @@ Sine function with half the amplitude every `period`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = SinDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = SinDecay2(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`SinExp(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp)
[`SinExp(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp)

</td>
<td>

Sine function with exponential amplitude decay at rate `γ`
Sine function with exponential amplitude decay at rate `decay`

</td>
<td> Cyclic </td>
Expand All @@ -254,7 +254,7 @@ Sine function with exponential amplitude decay at rate `γ`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = SinExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide
s = SinExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>
Expand Down
47 changes: 24 additions & 23 deletions docs/src/tutorials/basic-schedules.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@ using ParameterSchedulers # hide

A decay schedule is defined by the following formula:
```math
s(t) = \lambda g(t)
s(t) = l \times g(t)
```
where ``s(t)`` is the schedule output, ``\lambda`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.
where ``s(t)`` is the schedule output, ``l`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.

For example, here is an exponential decay schedule:
```@example decay-schedules
expdecay(γ, t) = γ^(t - 1)
s = Exp(λ = 0.1, γ = 0.8)
println("λ g(1) == s(1): ",
0.1 * expdecay(0.8, 1) == s(1))
expdecay(decay, t) = decay^(t - 1)
s = Exp(start = 0.1, decay = 0.8)
println("l g(1) == s(1): ", 0.1 * expdecay(0.8, 1) == s(1))
```

As you can see above, [`Exp`](@ref) is a type of decay schedule. Below is a list of all the decay schedules implemented, and the parameters and decay functions for each one.

| Schedule | Parameters | Decay Function |
|:---------------|:-----------------------|:---------------|
| [`Step`](@ref) | `λ`, `γ`, `step_sizes` | ``g(t) = \gamma^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \text{step\_sizes}_j < t \leq \sum_{j = 1}^i \text{step\_sizes}_j`` |
| [`Exp`](@ref) | `λ`, `γ` | ``g(t) = \gamma^{t - 1}`` |
| [`Poly`](@ref) | `λ`, `p`, `max_iter` | ``g(t) = \frac{1}{\left(1 - (t - 1) / \text{max\_iter}\right)^p}`` |
| [`Inv`](@ref) | `λ`, `γ`, `p` | ``g(t) = \frac{1}{(1 + (t - 1) \gamma)^p}`` |
| [`Step`](@ref) | `start`, `decay`, `step_sizes` | ``g(t) = \texttt{decay}^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \texttt{step\_sizes}_j < t \leq \sum_{j = 1}^i \texttt{step\_sizes}_j`` |
| [`Exp`](@ref) | `start`, `decay` | ``g(t) = \texttt{decay}^{t - 1}`` |
| [`Poly`](@ref) | `start`, `degree`, `max_iter` | ``g(t) = \dfrac{1}{\left(\dfrac{1 - (t - 1)}{\texttt{max\_iter}}\right)^\texttt{degree}}`` |
| [`Inv`](@ref) | `start`, `decay`, `degree` | ``g(t) = \dfrac{1}{\left(1 + \texttt{decay} \times (t - 1) \right)^\texttt{degree}}`` |

## Cyclic schedules

Expand All @@ -39,27 +38,29 @@ using ParameterSchedulers #hide

A cyclic schedule exhibits periodic behavior, and it is described by the following formula:
```math
s(t) = |\lambda_0 - \lambda_1| g(t) + \min (\lambda_0, \lambda_1)
s(t) = |l_0 - l_1| g(t) + \min (l_0, l_1)
```
where ``s(t)`` is the schedule output, ``\lambda_0`` and ``\lambda_1`` are the range endpoints, and ``g(t)`` is the cycle function. Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.
where ``s(t)`` is the schedule output, ``l_0`` and ``l_1`` are the range endpoints, and ``g(t)`` is the cycle function. Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.

For example, here is triangular wave schedule:
```@example cyclic-schedules
tricycle(period, t) = (2 / π) * abs(asin(sin(π * (t - 1) / period)))
s = Triangle(λ0 = 0.1, λ1 = 0.4, period = 2)
println("abs(λ0 - λ1) g(1) + min(λ0, λ1) == s(1): ",
abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1))
s = Triangle(l0 = 0.1, l1 = 0.4, period = 2)
println(
"abs(l0 - l1) * g(1) + min(l0, l1) == s(1): ",
abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1)
)
```

[`Triangle`](@ref) (used in the above example) is a type of cyclic schedule. Below is a list of all the cyclic schedules implemented, and the parameters and cycle functions for each one.

| Schedule | Parameters | Cycle Function |
|:-------------------------|:-----------------------------------------|:---------------|
| [`Triangle`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{2}{\pi} \left| \arcsin \left( \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right) \right|`` |
| [`TriangleDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Triangle}}(t)`` |
| [`TriangleExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Triangle}}(t)`` |
| [`Sin`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \left| \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right|`` |
| [`SinDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Sin}}(t)`` |
| [`SinExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Sin}}(t)`` |
| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == true` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \text{period})}{\text{period}}\right) \right)`` |
| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == false` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\text{period}}\right) \right)`` |
| [`Triangle`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{2}{\pi} \left\| \arcsin (\sin (\frac{\pi (t - 1)}{\text{period}})) \right\| `` |
| [`TriangleDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Triangle}}(t)`` |
| [`TriangleExp`](@ref) | `l0`, `l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Triangle}}(t)`` |
| [`Sin`](@ref) | `l0`, `l1`, `period` | ``g(t) = \left\| \sin \left(\frac{\pi (t - 1)}{\texttt{period}} \right) \right\|`` |
| [`SinDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Sin}}(t)`` |
| [`SinExp`](@ref) | `l0`, `l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Sin}}(t)`` |
| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = true` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \texttt{period})}{\texttt{period}}\right) \right)`` |
| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = false` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\texttt{period}}\right) \right)`` |
4 changes: 2 additions & 2 deletions docs/src/tutorials/complex-schedules.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Let's take the notion of arbitrary schedules one step further, and instead defin
```@example complex-schedules
using UnicodePlots

s = Loop(Exp(λ = 0.1, γ = 0.4), 10)
s = Loop(Exp(start = 0.1, decay = 0.4), 10)
t = 1:25 |> collect
lineplot(t, s.(t); border = :none)
```
Expand All @@ -32,7 +32,7 @@ lineplot(t, s.(t); border = :none)
Finally, we might concatenate sequences of schedules, applying each one for a given length, then switch to the next schedule in the order. A [`Sequence`](@ref) schedule lets us do this. For example, we can start with a triangular schedule, then switch to a more conservative exponential schedule half way through training.
```@example complex-schedules
nepochs = 50
s = Sequence([Triangle(λ0 = 0.0, λ1 = 0.5, period = 5), Exp(λ = 0.5, γ = 0.5)],
s = Sequence([Triangle(l0 = 0.0, l1 = 0.5, period = 5), Exp(start = 0.5, decay = 0.5)],
[nepochs ÷ 2, nepochs ÷ 2])

t = 1:nepochs |> collect
Expand Down
7 changes: 4 additions & 3 deletions docs/src/tutorials/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ParameterSchedulers # hide

All schedules types in ParameterSchedulers.jl behave as callable iterators. For example, we can call the simple exponential decay schedule ([`Exp`](@ref)) below at a specific iteration:
```@example getting-started
s = Exp(λ = 0.1, γ = 0.8)
s = Exp(start = 0.1, decay = 0.8)
println("s(1): ", s(1))
println("s(5): ", s(5))
```
Expand Down Expand Up @@ -42,8 +42,9 @@ println("s: ", next!(stateful_s))

Also note that `Stateful` cannot be called (or iterated with `Base.iterate`):
```@example getting-started
try stateful_s(1)
try
stateful_s(1)
catch e
println(e)
end
```
```
Loading
Loading