Skip to content

Commit

Permalink
Merge pull request #59 from darsnack/one-cycle
Browse files Browse the repository at this point in the history
Add `OneCycle` constructor
  • Loading branch information
darsnack authored Mar 5, 2024
2 parents 65b64b1 + 20495b4 commit 463aab2
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 50 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,27 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi

<tr><td>

[`OneCycle(nsteps, maxval)`](https://fluxml.ai/ParameterSchedulers.jl/api/complex.html#ParameterSchedulers.OneCycle)

</td>
<td>

[One cycle cosine](https://arxiv.org/abs/1708.07120)

</td>
<td> Complex </td>
<td style="text-align:center">

```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = OneCycle(10, 1.0) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Triangle(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle)

</td>
Expand Down
44 changes: 22 additions & 22 deletions docs/src/cheatsheet.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ If you are coming from PyTorch or Tensorflow, the following table should help yo
PyTorch typically wraps an optimizer as the first argument, but we ignore that functionality in the table. To wrap a Flux.jl optimizer with a schedule from the rightmost column, use [`ParameterSchedulers.Scheduler`](@ref).
The variable `lr` in the middle/rightmost column refers to the initial learning rate of the optimizer.

| PyTorch | Tensorflow | ParameterSchedulers.jl |
|:-------------------------------------------------------------------------------|:------------------------------------------------------|:------------------------------------------------------|
| `LambdaLR(_, lr_lambda)` | N/A | `lr_lambda` |
| `MultiplicativeLR(_, lr_lambda)` | N/A | N/A |
| `StepLR(_, step_size, gamma)` | `ExponentialDecay(lr, step_size, gamma, True)` | `Step(lr, gamma, step_size)` |
| `MultiStepLR(_, milestones, gamma)` | N/A | `Step(lr, gamma, milestones)` |
| `ConstantLR(_, factor, total_iters)` | N/A | `Sequence(lr * factor => total_iters, lr => nepochs)` |
| PyTorch | Tensorflow | ParameterSchedulers.jl |
|:-------------------------------------------------------------------------------|:------------------------------------------------------|:--------------------------------------------------------|
| `LambdaLR(_, lr_lambda)` | N/A | `lr_lambda` |
| `MultiplicativeLR(_, lr_lambda)` | N/A | N/A |
| `StepLR(_, step_size, gamma)` | `ExponentialDecay(lr, step_size, gamma, True)` | `Step(lr, gamma, step_size)` |
| `MultiStepLR(_, milestones, gamma)` | N/A | `Step(lr, gamma, milestones)` |
| `ConstantLR(_, factor, total_iters)` | N/A | `Sequence(lr * factor => total_iters, lr => nepochs)` |
| `LinearLR(_, start_factor, end_factor, total_iters)` | N/A | `Sequence(Triangle(lr * start_factor, lr * end_factor, 2 * total_iters) => total_iters, lr => nepochs)` |
| `ExponentialLR(_, gamma)` | `ExponentialDecay(lr, 1, gamma, False)` | `Exp(lr, gamma)` |
| N/A | `ExponentialDecay(lr, steps, gamma, False)` | `Interpolator(Exp(lr, gamma), steps)` |
| `CosineAnnealingLR(_, T_max, eta_min)` | `CosineDecay(lr, T_max, eta_min)` | `CosAnneal(lr, eta_min, T_0, false)` |
| `CosineAnnealingRestarts(_, T_0, 1, eta_min)` | `CosineDecayRestarts(lr, T_0, 1, 1, eta_min)` | `CosAnneal(lr, eta_min, T_0)` |
| `CosineAnnealingRestarts(_, T_0, T_mult, eta_min)` | `CosineDecayRestarts(lr, T_0, T_mult, 1, alpha)` | See [below](@ref "Cosine annealing variants") |
| N/A | `CosineDecayRestarts(lr, T_0, T_mult, m_mul, alpha)` | See [below](@ref "Cosine annealing variants") |
| `SequentialLR(_, schedulers, milestones)` | N/A | `Sequence(schedulers, milestones)` |
| `ReduceLROnPlateau(_, mode, factor, patience, threshold, 'abs', 0)` | N/A | See [below](@ref "`ReduceLROnPlateau` style schedules") |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular', _, None)` | N/A | `Triangle(base_lr, max_lr, step_size)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular2', _, None)` | N/A | `TriangleDecay2(base_lr, max_lr, step_size)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'exp_range', gamma, None)` | N/A | `TriangleExp(base_lr, max_lr, step_size, gamma)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, _, _, scale_fn)` | N/A | See [Arbitrary looping schedules](@ref) |
| N/A | `InverseTimeDecay(lr, 1, decay_rate, False)` | `Inv(lr, decay_rate, 1)` |
| N/A | `InverseTimeDecay(lr, decay_step, decay_rate, False)` | `Interpolator(Inv(lr, decay_rate, 1), decay_step)` |
| N/A | `PolynomialDecay(lr, decay_steps, 0, power, False)` | `Poly(lr, power, decay_steps)` |
| `ExponentialLR(_, gamma)` | `ExponentialDecay(lr, 1, gamma, False)` | `Exp(lr, gamma)` |
| N/A | `ExponentialDecay(lr, steps, gamma, False)` | `Interpolator(Exp(lr, gamma), steps)` |
| `CosineAnnealingLR(_, T_max, eta_min)` | `CosineDecay(lr, T_max, eta_min)` | `CosAnneal(lr, eta_min, T_0, false)` |
| `CosineAnnealingRestarts(_, T_0, 1, eta_min)` | `CosineDecayRestarts(lr, T_0, 1, 1, eta_min)` | `CosAnneal(lr, eta_min, T_0)` |
| `CosineAnnealingRestarts(_, T_0, T_mult, eta_min)` | `CosineDecayRestarts(lr, T_0, T_mult, 1, alpha)` | See [below](@ref "Cosine annealing variants") |
| N/A | `CosineDecayRestarts(lr, T_0, T_mult, m_mul, alpha)` | See [below](@ref "Cosine annealing variants") |
| `SequentialLR(_, schedulers, milestones)` | N/A | `Sequence(schedulers, milestones)` |
| `ReduceLROnPlateau(_, mode, factor, patience, threshold, 'abs', 0)` | N/A | See [below](@ref "`ReduceLROnPlateau` style schedules") |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular', _, None)` | N/A | `Triangle(base_lr, max_lr, step_size)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular2', _, None)` | N/A | `TriangleDecay2(base_lr, max_lr, step_size)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'exp_range', gamma, None)` | N/A | `TriangleExp(base_lr, max_lr, step_size, gamma)` |
| `CyclicLR(_, base_lr, max_lr, step_size, step_size, _, _, scale_fn)` | N/A | See [Arbitrary looping schedules](@ref) |
| N/A | `InverseTimeDecay(lr, 1, decay_rate, False)` | `Inv(lr, decay_rate, 1)` |
| N/A | `InverseTimeDecay(lr, decay_step, decay_rate, False)` | `Interpolator(Inv(lr, decay_rate, 1), decay_step)` |
| N/A | `PolynomialDecay(lr, decay_steps, 0, power, False)` | `Poly(lr, power, decay_steps)` |

## Cosine annealing variants

Expand Down
18 changes: 17 additions & 1 deletion docs/src/tutorials/warmup-schedules.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
# Warm-up Schedules

A popular technique for scheduling learning rates is "warming-up" the optimizer by ramping the learning rate up from zero to the "true" initial learning rate, then starting the "real" schedule. This is easily implementable with ParameterSchedulers.jl using [`Sequence`](@ref).
A popular technique for scheduling learning rates is "warming-up" the optimizer by ramping the learning rate up from near zero to the "true" initial learning rate, then starting the "real" schedule. This is easily implementable with ParameterSchedulers.jl using [`Sequence`](@ref).

## One cycle cosine schedule

A one-cycle cosine schedule is the most popular warm-up schedule. It ramps up once and ramps down once using a cosine waveform. Since this schedule is so common, we provide a convenience constructor in ParameterSchedulers.jl, [`OneCycle`](@ref).

```@example onecycle
using ParameterSchedulers
using UnicodePlots
nsteps = 10
maxval = 1f-1
onecycle = OneCycle(10, 1f-1; percent_start = 0.4)
t = 1:nsteps |> collect
lineplot(t, onecycle.(t); border = :none)
```

## Linear ramp

Expand Down
2 changes: 1 addition & 1 deletion src/ParameterSchedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export Triangle, TriangleDecay2, TriangleExp,
CosAnneal

include("complex.jl")
export Sequence, Loop, Interpolator, Shifted, ComposedSchedule
export Sequence, Loop, Interpolator, Shifted, ComposedSchedule, OneCycle

include("utils.jl")

Expand Down
52 changes: 52 additions & 0 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ Base.eltype(::Type{<:Constant{T}}) where T = T

(schedule::Constant)(t) = schedule.value

"""
    Shortened{T}
    Shortened(schedule, nsteps)

Wrap `schedule` so it behaves identically for steps `1:nsteps`, but
throws a `BoundsError` when evaluated past step `nsteps`.
"""
struct Shortened{T} <: AbstractSchedule{true}
    schedule::T
    nsteps::Int
end

Base.eltype(::Type{Shortened{T}}) where T = eltype(T)
Base.length(s::Shortened) = s.nsteps

# Delegate to the wrapped schedule, guarding against out-of-range steps.
function (s::Shortened)(t)
    t > s.nsteps && throw(BoundsError(s, t))
    return s.schedule(t)
end

Base.iterate(s::Shortened, state...) = iterate(s.schedule, state...)

"""
Sequence{T, S}
Sequence(schedules, step_sizes)
Expand Down Expand Up @@ -228,3 +249,34 @@ function (composition::ComposedSchedule)(t)

return s(t)
end

"""
    OneCycle(nsteps, maxval;
             startval = maxval / 25,
             endval = maxval / 1f5,
             percent_start = 0.25)

Create a one-cycle cosine schedule over `nsteps` steps: warm up from `startval`
to `maxval` over the first `ceil(Int, percent_start * nsteps)` steps, then
anneal back down to `endval` over the remaining steps
(see [Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)).

Evaluating the schedule beyond `nsteps` throws a `BoundsError`.

# Throws
- `ArgumentError`: if `percent_start` is not strictly between 0 and 1.
"""
function OneCycle(nsteps, maxval;
                  startval = maxval / 25,
                  endval = maxval / 1f5,
                  percent_start = 0.25)
    # Validate at the API boundary with a real exception instead of `@assert`,
    # which may be disabled at higher optimization levels.
    0 < percent_start < 1 ||
        throw(ArgumentError("percent_start must be in (0, 1), got $percent_start"))

    warmup = ceil(Int, nsteps * percent_start)
    warmdown = nsteps - warmup

    return Sequence(
        # Warm-up phase: the shift starts the cosine at its trough so the
        # schedule rises startval -> maxval over `warmup` steps.
        Shifted(CosAnneal(l0 = maxval,
                          l1 = startval,
                          period = warmup,
                          restart = false), warmup + 1) => warmup,
        # Warm-down phase: standard cosine annealing maxval -> endval,
        # bounded to `warmdown` steps so indexing past nsteps errors.
        Shortened(CosAnneal(l0 = maxval,
                            l1 = endval,
                            period = warmdown,
                            restart = false), warmdown) => warmdown
    )
end
12 changes: 11 additions & 1 deletion src/cyclic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Triangle(range::T, offset::T, period::S) where {T, S} = Triangle{T, S}(range, of
# Keyword constructor: accepts endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`) and `period`.
function Triangle(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:Triangle, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return Triangle(abs(l0 - l1), min(l0, l1), kwargs.period)
end

Expand Down Expand Up @@ -53,11 +54,13 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)))`
"""
# Positional form: composes a Triangle whose amplitude is replaced each step by
# an Exp(range, 1/2) schedule stepped once per period via Interpolator —
# presumably halving the amplitude every period (verify against Exp semantics).
function TriangleDecay2(range::T, offset, period) where T
    # parameters override (range, offset, period) of the composed Triangle at each step
    parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period)

    return ComposedSchedule(Triangle(range, offset, period), parameters)
end
# Keyword constructor: accepts endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`) and `period`.
function TriangleDecay2(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:TriangleDecay2, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return TriangleDecay2(abs(l0 - l1), min(l0, l1), kwargs.period)
end

Expand Down Expand Up @@ -85,6 +88,7 @@ TriangleExp(range, offset, period, decay) =
# Keyword constructor: endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`), `period`, and `decay`.
function TriangleExp(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:TriangleExp, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return TriangleExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay)
end

Expand Down Expand Up @@ -112,6 +116,7 @@ Sin(range::T, offset::T, period::S) where {T, S} = Sin{T, S}(range, offset, peri
# Keyword constructor: accepts endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`) and `period`.
function Sin(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:Sin, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return Sin(abs(l0 - l1), min(l0, l1), kwargs.period)
end

Expand All @@ -137,11 +142,13 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)).
"""
# Positional form: composes a Sin schedule whose amplitude is replaced each step
# by an Exp(range, 1/2) schedule stepped once per period via Interpolator —
# presumably halving the amplitude every period (verify against Exp semantics).
function SinDecay2(range::T, offset, period) where T
    # parameters override (range, offset, period) of the composed Sin at each step
    parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period)

    return ComposedSchedule(Sin(range, offset, period), parameters)
end
# Keyword constructor: accepts endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`) and `period`.
function SinDecay2(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:SinDecay2, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return SinDecay2(abs(l0 - l1), min(l0, l1), kwargs.period)
end

Expand All @@ -167,6 +174,7 @@ SinExp(range, offset, period, decay) =
# Keyword constructor: endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`), `period`, and `decay`.
function SinExp(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:SinExp, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1

    # store as (range, offset) so the endpoints may be given in either order
    return SinExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay)
end

Expand Down Expand Up @@ -200,7 +208,9 @@ CosAnneal(range, offset, period) = CosAnneal(range, offset, period, true)
# Keyword constructor: endpoints `l0`/`l1` (or deprecated `λ0`/`λ1`), `period`,
# and optional `restart`. The rendered diff fused the pre-change and post-change
# return lines; this keeps only the post-change version.
function CosAnneal(; kwargs...)
    # translate deprecated keyword names onto their current equivalents
    kwargs = depkwargs(:CosAnneal, kwargs, :λ0 => :l0, :λ1 => :l1)
    l0, l1 = kwargs.l0, kwargs.l1
    # `restart` is optional; defaults to true, matching the positional
    # constructor `CosAnneal(range, offset, period) = CosAnneal(..., true)`
    restart = get(kwargs, :restart, true)

    return CosAnneal(abs(l0 - l1), min(l0, l1), kwargs.period, restart)
end

Base.eltype(::Type{<:CosAnneal{T}}) where T = T
Expand Down
2 changes: 1 addition & 1 deletion src/decay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Base.eltype(::Type{<:Poly{T}}) where T = T
Base.length(schedule::Poly) = schedule.max_iter

# Evaluate the polynomial decay at step `t`. The rendered diff fused the old
# guard (which misused `BoundsError(::String)`) with the fixed one; this keeps
# only the fixed guard, which carries the schedule and offending index.
function (schedule::Poly)(t)
    (t <= length(schedule)) || throw(BoundsError(schedule, t))
    return schedule.start * (1 - (t - 1) / schedule.max_iter)^schedule.degree
end

Expand Down
46 changes: 46 additions & 0 deletions test/complex.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
@testset "Constant" begin
    v = rand()
    s = ParameterSchedulers.Constant(v)
    # a constant schedule yields the same value at every step
    @test all(s(t) == v for t in 1:1000)
end

@testset "Shortened" begin
    base_schedule = Exp(0.9, 10.0)
    nsteps = rand(1:100)
    schedule = ParameterSchedulers.Shortened(base_schedule, nsteps)

    # iterator traits forward to the wrapped schedule, except the finite length
    @test Base.IteratorEltype(typeof(schedule)) == Base.IteratorEltype(typeof(base_schedule))
    @test Base.IteratorSize(typeof(schedule)) == Base.HasLength()
    @test length(schedule) == nsteps
    # matches the wrapped schedule inside bounds, errors beyond them
    @test all(schedule(t) == base_schedule(t) for t in 1:nsteps)
    # NOTE: the original `rand((nsteps + 1:100))` drew from an empty range when
    # nsteps == 100, making the test flaky; always pick a step past nsteps instead
    @test_throws BoundsError schedule(nsteps + rand(1:100))
end

@testset "Sequence" begin
schedules = (log, sqrt)
step_sizes = (rand(1:10), rand(1:10))
Expand Down Expand Up @@ -83,3 +101,31 @@ end
@test log(2) == next!(stateful_s)
@test log(2) == next!(stateful_s)
end

@testset "OneCycle" begin
    # Closed-form reference built from the shared `_cycle`/`_cos` test helpers.
    function reference(t, nsteps, startval, maxval, endval, pct)
        nwarm = ceil(Int, pct * nsteps)
        ncool = nsteps - nwarm

        if t > nsteps
            return endval
        elseif t <= nwarm
            return _cycle(startval, maxval, _cos(t + nwarm, nwarm))
        else
            return _cycle(maxval, endval, _cos(t - nwarm, ncool))
        end
    end

    nsteps = 50
    maxval = 1f-1

    # default keyword arguments
    s = OneCycle(nsteps, maxval)
    @test all(s(t) == reference(t, nsteps, maxval / 25, maxval, maxval / 1f5, 0.25)
              for t in 1:nsteps)
    @test_throws BoundsError s(nsteps + 1)

    # custom start/end values and warm-up fraction
    startval = 1f-4
    endval = 1f-2
    pct = 0.3
    s = OneCycle(nsteps, maxval; startval, endval, percent_start = pct)
    @test all(s(t) == reference(t, nsteps, startval, maxval, endval, pct)
              for t in 1:nsteps)
end
8 changes: 1 addition & 7 deletions test/cyclic.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
_cycle(l0, l1, x) = abs(l0 - l1) * x + min(l0, l1)
_tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period)))
_sin(t, period) = abs(sin(π * (t - 1) / period))
_cos(t, period) = (1 + cos(π * (t - 1) / period)) / 2
_cosrestart(t, period) = (1 + cos(π * mod(t - 1, period) / period)) / 2

@testset "Triangle" begin
l0 = 0.5 * rand()
l1 = 0.5 * rand() + 1
Expand Down Expand Up @@ -103,4 +97,4 @@ end
@test Base.IteratorSize(typeof(s)) == Base.IsInfinite()
@test axes(s) == (OneToInf(),)
end
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ using Test

using InfiniteArrays: OneToInf

# Closed-form references for the cyclic schedules under test. The scraped text
# dropped the `(π` from the trig formulas, leaving unbalanced parentheses;
# restored here to the standard forms.
_cycle(λ0, λ1, x) = abs(λ0 - λ1) * x + min(λ0, λ1)                       # interpolate between endpoints by x ∈ [0, 1]
_tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period)))          # triangle wave in [0, 1]
_sin(t, period) = abs(sin(π * (t - 1) / period))                          # rectified sine wave in [0, 1]
_cos(t, period) = (1 + cos(π * (t - 1) / period)) / 2                     # cosine annealing, no restarts
_cosrestart(t, period) = (1 + cos(π * mod(t - 1, period) / period)) / 2   # cosine annealing with restarts

# Each schedule family lives in its own test file, wrapped in a testset.
@testset "Decay" begin
    include("decay.jl")
end
Expand Down
17 changes: 0 additions & 17 deletions toc.md

This file was deleted.

0 comments on commit 463aab2

Please sign in to comment.