From 155e66c45d03697ecf20d67d4666855ebc8c29a9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 20:08:04 -0500 Subject: [PATCH 1/7] Add OneCycle constructor --- docs/src/tutorials/warmup-schedules.md | 18 +++++++++++++++++- src/ParameterSchedulers.jl | 2 +- src/complex.jl | 25 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/docs/src/tutorials/warmup-schedules.md b/docs/src/tutorials/warmup-schedules.md index 89de866..df7c163 100644 --- a/docs/src/tutorials/warmup-schedules.md +++ b/docs/src/tutorials/warmup-schedules.md @@ -1,6 +1,22 @@ # Warm-up Schedules -A popular technique for scheduling learning rates is "warming-up" the optimizer by ramping the learning rate up from zero to the "true" initial learning rate, then starting the "real" schedule. This is easily implementable with ParameterSchedulers.jl using [`Sequence`](@ref). +A popular technique for scheduling learning rates is "warming-up" the optimizer by ramping the learning rate up from near zero to the "true" initial learning rate, then starting the "real" schedule. This is easily implementable with ParameterSchedulers.jl using [`Sequence`](@ref). + +## One cycle cosine schedule + +A one-cycle cosine schedule is the most popular warm-up schedule. It ramps up once and ramps down once using a cosine waveform. Since this schedule is so common, we provide a convenience constructor in ParameterSchedulers.jl, [`OneCycle`](@ref). + +```@example onecycle +using ParameterSchedulers +using UnicodePlots + +nsteps = 10 +maxval = 1f-1 +onecycle = OneCycle(10, 1f-1; percent_start = 0.4) + +t = 1:nsteps |> collect +lineplot(t, onecycle.(t); border = :none) +``` ## Linear ramp diff --git a/src/ParameterSchedulers.jl b/src/ParameterSchedulers.jl index 9f2094d..ecc81cf 100644 --- a/src/ParameterSchedulers.jl +++ b/src/ParameterSchedulers.jl @@ -16,7 +16,7 @@ export Triangle, TriangleDecay2, TriangleExp, CosAnneal include("complex.jl") -export Sequence, Loop, Interpolator, Shifted, ComposedSchedule +export Sequence, Loop, Interpolator, Shifted, ComposedSchedule, OneCycle include("utils.jl") diff --git a/src/complex.jl b/src/complex.jl index 4c04779..2a8b728 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -228,3 +228,28 @@ function (composition::ComposedSchedule)(t) return s(t) end + +""" + OneCycle(nsteps, maxval; + startval = maxval / 25, + endval = maxval / 1f5, + percent_start = 0.25) + +Creates a one-cycle cosine schedule over `nsteps` steps warming up from `startval` +up to `maxval` for `ceil(percent_start * nsteps)`, then back to `endval` +(see [Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)). +""" +function OneCycle(nsteps, maxval; + startval = max / 25, + endval = max / 1f5, + percent_start = 0.25) + @assert 0 < percent_start < 1 + + warmup = ceil(Int, nsteps * percent_start) + warmdown = nsteps - warmup + + return Sequence( + Sin(λ0=maxval, λ1=startval, period=2*warmup) => warmup, + Shifted(Sin(λ0=maxval, λ1=endval, period=2*warmdown), warmdown + 1) => warmdown + ) +end From 9db2d657c2e12e643028f8a15f8433c9add567be Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 20:21:01 -0500 Subject: [PATCH 2/7] Add OneCycle tests --- test/complex.jl | 27 ++++++++++++++ test/cyclic.jl | 96 +++++++++++++++++++++++------------------------- test/runtests.jl | 6 +++ 3 files changed, 78 insertions(+), 51 deletions(-) diff --git a/test/complex.jl b/test/complex.jl index 55d5b10..3714799 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -83,3 +83,30 @@ end @test log(2) == next!(stateful_s) @test log(2) == next!(stateful_s) end + +@testset "OneCycle" begin + function onecycle(t, nsteps, startval, maxval, endval, pct) + warmup = ceil(Int, pct * nsteps) + warmdown = nsteps - warmup + + if t > nsteps + return endval + elseif t <= warmup + return _cycle(startval, maxval, _sin(t, 2 * warmup)) + else + return _cycle(maxval, endval, _cos(t, 2 * warmdown)) + end + end + + nsteps = 50 + maxval = 1f-1 + s = OneCycle(nsteps, maxval) + @test all(s(t) == onecycle(t, nsteps, maxval / 25, maxval, maxval / 1f5, 0.25) + for t in 1:(2 * nsteps)) + startval = 1f-4 + endval = 1f-2 + pct = 0.3 + s = OneCycle(nsteps, maxval; startval, endval, percent_start = pct) + @test all(s(t) == onecycle(t, nsteps, startval, maxval, endval, pct) + for t in 1:(2 * nsteps)) +end diff --git a/test/cyclic.jl b/test/cyclic.jl index 7e57563..1c357b8 100644 --- a/test/cyclic.jl +++ b/test/cyclic.jl @@ -1,106 +1,100 @@ -_cycle(l0, l1, x) = abs(l0 - l1) * x + min(l0, l1) -_tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period))) -_sin(t, period) = abs(sin(π * (t - 1) / period)) -_cos(t, period) = (1 + cos(π * (t - 1) / period)) / 2 -_cosrestart(t, period) = (1 + cos(π * mod(t - 1, period) / period)) / 2 - @testset "Triangle" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 period = rand(1:10) - s = Triangle(l0 = l0, l1 = l1, period = period) - @test s == Triangle(abs(l0 - l1), min(l0, l1), period) - @test [_cycle(l0, l1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) + s = Triangle(λ0 = λ0, λ1 = λ1, period = period) + @test s == Triangle(abs(λ0 - λ1), min(λ0, λ1), period) + @test [_cycle(λ0, λ1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleDecay2" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 period = rand(1:10) - s = TriangleDecay2(l0 = l0, l1 = l1, period = period) - @test s == TriangleDecay2(abs(l0 - l1), min(l0, l1), period) - @test [_cycle(l0, l1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = TriangleDecay2(λ0 = λ0, λ1 = λ1, period = period) + @test s == TriangleDecay2(abs(λ0 - λ1), min(λ0, λ1), period) + @test [_cycle(λ0, λ1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleExp" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 - decay = rand() + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 + γ = rand() period = rand(1:10) - s = TriangleExp(l0 = l0, l1 = l1, period = period, decay = decay) - @test s == TriangleExp(abs(l0 - l1), min(l0, l1), period, decay) - @test [_cycle(l0, l1, _tri(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = TriangleExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) + @test s == TriangleExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) + @test [_cycle(λ0, λ1, _tri(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() end @testset "Sin" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 period = rand(1:10) - s = Sin(l0 = l0, l1 = l1, period = period) - @test s == Sin(abs(l0 - l1), min(l0, l1), period) - @test [_cycle(l0, l1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) + s = Sin(λ0 = λ0, λ1 = λ1, period = period) + @test s == Sin(abs(λ0 - λ1), min(λ0, λ1), period) + @test [_cycle(λ0, λ1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinDecay2" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 period = rand(1:10) - s = SinDecay2(l0 = l0, l1 = l1, period = period) - @test s == SinDecay2(abs(l0 - l1), min(l0, l1), period) - @test [_cycle(l0, l1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = SinDecay2(λ0 = λ0, λ1 = λ1, period = period) + @test s == SinDecay2(abs(λ0 - λ1), min(λ0, λ1), period) + @test [_cycle(λ0, λ1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinExp" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 - decay = rand() + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 + γ = rand() period = rand(1:10) - s = SinExp(l0 = l0, l1 = l1, period = period, decay = decay) - @test s == SinExp(abs(l0 - l1), min(l0, l1), period, decay) - @test [_cycle(l0, l1, _sin(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = SinExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) + @test s == SinExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) + @test [_cycle(λ0, λ1, _sin(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "CosAnneal" begin - l0 = 0.5 * rand() - l1 = 0.5 * rand() + 1 + λ0 = 0.5 * rand() + λ1 = 0.5 * rand() + 1 period = rand(1:10) @testset for (restart, f) in ((true, _cosrestart), (false, _cos)) - s = CosAnneal(l0 = l0, l1 = l1, period = period, restart = restart) - @test s == CosAnneal(abs(l0 - l1), min(l0, l1), period, restart) - @test [_cycle(l0, l1, f(t, period)) for t in 1:100] ≈ s.(1:100) + s = CosAnneal(λ0 = λ0, λ1 = λ1, period = period, restart = restart) + @test s == CosAnneal(abs(λ0 - λ1), min(λ0, λ1), period, restart) + @test [_cycle(λ0, λ1, f(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(l0) + @test eltype(s) == eltype(λ0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e62bcea..4e4aac2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,12 @@ using Test using InfiniteArrays: OneToInf +_cycle(λ0, λ1, x) = abs(λ0 - λ1) * x + min(λ0, λ1) +_tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period))) +_sin(t, period) = abs(sin(π * (t - 1) / period)) +_cos(t, period) = (1 + cos(π * (t - 1) / period)) / 2 +_cosrestart(t, period) = (1 + cos(π * mod(t - 1, period) / period)) / 2 + @testset "Decay" begin include("decay.jl") end From 778d6c5c49dc1b78c08c8ebb263551392593cfb4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 20:25:05 -0500 Subject: [PATCH 3/7] Fix typo in OneCycle constructor --- src/complex.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index 2a8b728..cd6d20f 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -240,8 +240,8 @@ up to `maxval` for `ceil(percent_start * nsteps)`, then back to `endval` (see [Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)). """ function OneCycle(nsteps, maxval; - startval = max / 25, - endval = max / 1f5, + startval = maxval / 25, + endval = maxval / 1f5, percent_start = 0.25) @assert 0 < percent_start < 1 From 8ae49b289b35a47d1d0c6cf35f3ca1c41892b304 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 5 Mar 2024 08:48:08 -0500 Subject: [PATCH 4/7] Make OneCycle throw BoundsError for invalid access. - Add new Shortened schedule to cap the length of an arbitrary schedule. - Fix BoundsError for Poly --- src/complex.jl | 25 ++++++++++++++++++++++++- src/decay.jl | 2 +- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index cd6d20f..ba25a22 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -55,6 +55,27 @@ Base.eltype(::Type{<:Constant{T}}) where T = T (schedule::Constant)(t) = schedule.value +""" + Shortened{T} + Shortened(schedule, nsteps) + +A schedule that mimics `schedule` but throws a `BoundsError` if accessed +beyond `nsteps`. +""" +struct Shortened{T} <: AbstractSchedule{true} + schedule::T + nsteps::Int +end + +Base.eltype(::Type{Shortened{T}}) where T = eltype(T) +Base.length(schedule::Shortened) = schedule.nsteps + +function (schedule::Shortened)(t) + (t <= length(schedule)) || throw(BoundsError(schedule, t)) + return schedule.schedule(t) +end +Base.iterate(schedule::Shortened, state...) = iterate(schedule.schedule, state...) + """ Sequence{T, S} Sequence(schedules, step_sizes) @@ -250,6 +271,8 @@ function OneCycle(nsteps, maxval; return Sequence( Sin(λ0=maxval, λ1=startval, period=2*warmup) => warmup, - Shifted(Sin(λ0=maxval, λ1=endval, period=2*warmdown), warmdown + 1) => warmdown + Shortened(Shifted(Sin(λ0=maxval, λ1=endval, period=2*warmdown), + warmdown + 1), + warmdown) => warmdown ) end diff --git a/src/decay.jl b/src/decay.jl index 07d7abe..6109985 100644 --- a/src/decay.jl +++ b/src/decay.jl @@ -111,7 +111,7 @@ Base.eltype(::Type{<:Poly{T}}) where T = T Base.length(schedule::Poly) = schedule.max_iter function (schedule::Poly)(t) - (t <= length(schedule)) || throw(BoundsError("Cannot index Poly for t > max_iter")) + (t <= length(schedule)) || throw(BoundsError(schedule, t)) return schedule.start * (1 - (t - 1) / schedule.max_iter)^schedule.degree end From c383e0b1095960541423c2ddaa34f007824eab65 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 5 Mar 2024 09:31:23 -0500 Subject: [PATCH 5/7] Working implementation based on CosAnneal --- src/complex.jl | 12 ++++++++---- test/complex.jl | 27 +++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index ba25a22..8cf3625 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -270,9 +270,13 @@ function OneCycle(nsteps, maxval; warmdown = nsteps - warmup return Sequence( - Sin(λ0=maxval, λ1=startval, period=2*warmup) => warmup, - Shortened(Shifted(Sin(λ0=maxval, λ1=endval, period=2*warmdown), - warmdown + 1), - warmdown) => warmdown + Shifted(CosAnneal(λ0 = maxval, + λ1 = startval, + period = warmup, + restart = false), warmup + 1) => warmup, + Shortened(CosAnneal(λ0 = maxval, + λ1 = endval, + period = warmdown, + restart = false), warmdown) => warmdown ) end diff --git a/test/complex.jl b/test/complex.jl index 3714799..b47165e 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,3 +1,21 @@ +@testset "Constant" begin + value = rand() + schedule = ParameterSchedulers.Constant(value) + @test all(value == schedule(t) for t in 1:1000) +end + +@testset "Shortened" begin + base_schedule = Exp(0.9, 10.0) + nsteps = rand(1:100) + schedule = ParameterSchedulers.Shortened(base_schedule, nsteps) + + @test Base.IteratorEltype(typeof(schedule)) == Base.IteratorEltype(typeof(base_schedule)) + @test Base.IteratorSize(typeof(schedule)) == Base.HasLength() + @test length(schedule) == nsteps + @test all(schedule(t) == base_schedule(t) for t in 1:nsteps) + @test_throws BoundsError schedule(rand((nsteps + 1:100))) +end + @testset "Sequence" begin schedules = (log, sqrt) step_sizes = (rand(1:10), rand(1:10)) @@ -92,9 +110,9 @@ end if t > nsteps return endval elseif t <= warmup - return _cycle(startval, maxval, _sin(t, 2 * warmup)) + return _cycle(startval, maxval, _cos(t + warmup, warmup)) else - return _cycle(maxval, endval, _cos(t, 2 * warmdown)) + return _cycle(maxval, endval, _cos(t - warmup, warmdown)) end end @@ -102,11 +120,12 @@ end maxval = 1f-1 s = OneCycle(nsteps, maxval) @test all(s(t) == onecycle(t, nsteps, maxval / 25, maxval, maxval / 1f5, 0.25) - for t in 1:(2 * nsteps)) + for t in 1:nsteps) + @test_throws BoundsError s(nsteps + 1) startval = 1f-4 endval = 1f-2 pct = 0.3 s = OneCycle(nsteps, maxval; startval, endval, percent_start = pct) @test all(s(t) == onecycle(t, nsteps, startval, maxval, endval, pct) - for t in 1:(2 * nsteps)) + for t in 1:nsteps) end From bd4f4b7e74a079a6dcc7ed038f6ef697f5d6d2d7 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 5 Mar 2024 09:35:56 -0500 Subject: [PATCH 6/7] Rebase for kwarg name changes --- src/complex.jl | 8 ++--- src/cyclic.jl | 9 ++++++ test/cyclic.jl | 88 +++++++++++++++++++++++++------------------------- 3 files changed, 57 insertions(+), 48 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index 8cf3625..356d93b 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -270,12 +270,12 @@ function OneCycle(nsteps, maxval; warmdown = nsteps - warmup return Sequence( - Shifted(CosAnneal(λ0 = maxval, - λ1 = startval, + Shifted(CosAnneal(l0 = maxval, + l1 = startval, period = warmup, restart = false), warmup + 1) => warmup, - Shortened(CosAnneal(λ0 = maxval, - λ1 = endval, + Shortened(CosAnneal(l0 = maxval, + l1 = endval, period = warmdown, restart = false), warmdown) => warmdown ) diff --git a/src/cyclic.jl b/src/cyclic.jl index df73021..a82688f 100644 --- a/src/cyclic.jl +++ b/src/cyclic.jl @@ -26,6 +26,7 @@ Triangle(range::T, offset::T, period::S) where {T, S} = Triangle{T, S}(range, of function Triangle(; kwargs...) kwargs = depkwargs(:Triangle, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return Triangle(abs(l0 - l1), min(l0, l1), kwargs.period) end @@ -53,11 +54,13 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)) """ function TriangleDecay2(range::T, offset, period) where T parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period) + return ComposedSchedule(Triangle(range, offset, period), parameters) end function TriangleDecay2(; kwargs...) kwargs = depkwargs(:TriangleDecay2, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return TriangleDecay2(abs(l0 - l1), min(l0, l1), kwargs.period) end @@ -85,6 +88,7 @@ TriangleExp(range, offset, period, decay) = function TriangleExp(; kwargs...) kwargs = depkwargs(:TriangleExp, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return TriangleExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay) end @@ -112,6 +116,7 @@ Sin(range::T, offset::T, period::S) where {T, S} = Sin{T, S}(range, offset, peri function Sin(; kwargs...) kwargs = depkwargs(:Sin, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return Sin(abs(l0 - l1), min(l0, l1), kwargs.period) end @@ -137,11 +142,13 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)). """ function SinDecay2(range::T, offset, period) where T parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period) + return ComposedSchedule(Sin(range, offset, period), parameters) end function SinDecay2(; kwargs...) kwargs = depkwargs(:SinDecay2, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return SinDecay2(abs(l0 - l1), min(l0, l1), kwargs.period) end @@ -167,6 +174,7 @@ SinExp(range, offset, period, decay) = function SinExp(; kwargs...) kwargs = depkwargs(:SinExp, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return SinExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay) end @@ -200,6 +208,7 @@ CosAnneal(range, offset, period) = CosAnneal(range, offset, period, true) function CosAnneal(; kwargs...) kwargs = depkwargs(:CosAnneal, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + return CosAnneal(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.restart) end diff --git a/test/cyclic.jl b/test/cyclic.jl index 1c357b8..e727bb7 100644 --- a/test/cyclic.jl +++ b/test/cyclic.jl @@ -1,99 +1,99 @@ @testset "Triangle" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = Triangle(λ0 = λ0, λ1 = λ1, period = period) - @test s == Triangle(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) + s = Triangle(l0 = l0, l1 = l1, period = period) + @test s == Triangle(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleDecay2" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = TriangleDecay2(λ0 = λ0, λ1 = λ1, period = period) - @test s == TriangleDecay2(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = TriangleDecay2(l0 = l0, l1 = l1, period = period) + @test s == TriangleDecay2(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleExp" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 - γ = rand() + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 + decay = rand() period = rand(1:10) - s = TriangleExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) - @test s == TriangleExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - @test [_cycle(λ0, λ1, _tri(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = TriangleExp(l0 = l0, l1 = l1, period = period, decay = decay) + @test s == TriangleExp(abs(l0 - l1), min(l0, l1), period, decay) + @test [_cycle(l0, l1, _tri(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() end @testset "Sin" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = Sin(λ0 = λ0, λ1 = λ1, period = period) - @test s == Sin(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) + s = Sin(l0 = l0, l1 = l1, period = period) + @test s == Sin(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinDecay2" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = SinDecay2(λ0 = λ0, λ1 = λ1, period = period) - @test s == SinDecay2(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = SinDecay2(l0 = l0, l1 = l1, period = period) + @test s == SinDecay2(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinExp" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 - γ = rand() + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 + decay = rand() period = rand(1:10) - s = SinExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) - @test s == SinExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - @test [_cycle(λ0, λ1, _sin(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = SinExp(l0 = l0, l1 = l1, period = period, decay = decay) + @test s == SinExp(abs(l0 - l1), min(l0, l1), period, decay) + @test [_cycle(l0, l1, _sin(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "CosAnneal" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) @testset for (restart, f) in ((true, _cosrestart), (false, _cos)) - s = CosAnneal(λ0 = λ0, λ1 = λ1, period = period, restart = restart) - @test s == CosAnneal(abs(λ0 - λ1), min(λ0, λ1), period, restart) - @test [_cycle(λ0, λ1, f(t, period)) for t in 1:100] ≈ s.(1:100) + s = CosAnneal(l0 = l0, l1 = l1, period = period, restart = restart) + @test s == CosAnneal(abs(l0 - l1), min(l0, l1), period, restart) + @test [_cycle(l0, l1, f(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end From 20495b49bd181285e94f79591817aa38c4e71b90 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 5 Mar 2024 09:56:33 -0500 Subject: [PATCH 7/7] Update docs for OneCycle on main page --- README.md | 21 ++++++++++++++++++++ docs/src/cheatsheet.md | 44 +++++++++++++++++++++--------------------- src/cyclic.jl | 3 ++- toc.md | 17 ---------------- 4 files changed, 45 insertions(+), 40 deletions(-) delete mode 100644 toc.md diff --git a/README.md b/README.md index 5af53bf..866857b 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,27 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi +[`OneCycle(nsteps, maxval)`](https://fluxml.ai/ParameterSchedulers.jl/api/complex.html#ParameterSchedulers.OneCycle) + + + + +[One cycle cosine](https://arxiv.org/abs/1708.07120) + + + Complex + + +```@example +using UnicodePlots, ParameterSchedulers # hide +t = 1:10 |> collect # hide +s = OneCycle(10, 1.0) # hide +lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide +``` + + + + [`Triangle(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle) diff --git a/docs/src/cheatsheet.md b/docs/src/cheatsheet.md index c6eb5b4..7aa69f2 100644 --- a/docs/src/cheatsheet.md +++ b/docs/src/cheatsheet.md @@ -6,29 +6,29 @@ If you are coming from PyTorch or Tensorflow, the following table should help yo PyTorch typically wraps an optimizer as the first argument, but we ignore that functionality in the table. To wrap a Flux.jl optimizer with a schedule from the rightmost column, use [`ParameterSchedulers.Scheduler`](@ref). The variable `lr` in the middle/rightmost column refers to the initial learning rate of the optimizer. -| PyTorch | Tensorflow | ParameterSchedulers.jl | -|:-------------------------------------------------------------------------------|:------------------------------------------------------|:------------------------------------------------------| -| `LambdaLR(_, lr_lambda)` | N/A | `lr_lambda` | -| `MultiplicativeLR(_, lr_lambda)` | N/A | N/A | -| `StepLR(_, step_size, gamma)` | `ExponentialDecay(lr, step_size, gamma, True)` | `Step(lr, gamma, step_size)` | -| `MultiStepLR(_, milestones, gamma)` | N/A | `Step(lr, gamma, milestones)` | -| `ConstantLR(_, factor, total_iters)` | N/A | `Sequence(lr * factor => total_iters, lr => nepochs)` | +| PyTorch | Tensorflow | ParameterSchedulers.jl | +|:-------------------------------------------------------------------------------|:------------------------------------------------------|:--------------------------------------------------------| +| `LambdaLR(_, lr_lambda)` | N/A | `lr_lambda` | +| `MultiplicativeLR(_, lr_lambda)` | N/A | N/A | +| `StepLR(_, step_size, gamma)` | `ExponentialDecay(lr, step_size, gamma, True)` | `Step(lr, gamma, step_size)` | +| `MultiStepLR(_, milestones, gamma)` | N/A | `Step(lr, gamma, milestones)` | +| `ConstantLR(_, factor, total_iters)` | N/A | `Sequence(lr * factor => total_iters, lr => nepochs)` | | `LinearLR(_, start_factor, end_factor, total_iters)` | N/A | `Sequence(Triangle(lr * start_factor, lr * end_factor, 2 * total_iters) => total_iters, lr => nepochs)` | -| `ExponentialLR(_, gamma)` | `ExponentialDecay(lr, 1, gamma, False)` | `Exp(lr, gamma)` | -| N/A | `ExponentialDecay(lr, steps, gamma, False)` | `Interpolator(Exp(lr, gamma), steps)` | -| `CosineAnnealingLR(_, T_max, eta_min)` | `CosineDecay(lr, T_max, eta_min)` | `CosAnneal(lr, eta_min, T_0, false)` | -| `CosineAnnealingRestarts(_, T_0, 1, eta_min)` | `CosineDecayRestarts(lr, T_0, 1, 1, eta_min)` | `CosAnneal(lr, eta_min, T_0)` | -| `CosineAnnealingRestarts(_, T_0, T_mult, eta_min)` | `CosineDecayRestarts(lr, T_0, T_mult, 1, alpha)` | See [below](@ref "Cosine annealing variants") | -| N/A | `CosineDecayRestarts(lr, T_0, T_mult, m_mul, alpha)` | See [below](@ref "Cosine annealing variants") | -| `SequentialLR(_, schedulers, milestones)` | N/A | `Sequence(schedulers, milestones)` | -| `ReduceLROnPlateau(_, mode, factor, patience, threshold, 'abs', 0)` | N/A | See [below](@ref "`ReduceLROnPlateau` style schedules") | -| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular', _, None)` | N/A | `Triangle(base_lr, max_lr, step_size)` | -| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular2', _, None)` | N/A | `TriangleDecay2(base_lr, max_lr, step_size)` | -| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'exp_range', gamma, None)` | N/A | `TriangleExp(base_lr, max_lr, step_size, gamma)` | -| `CyclicLR(_, base_lr, max_lr, step_size, step_size, _, _, scale_fn)` | N/A | See [Arbitrary looping schedules](@ref) | -| N/A | `InverseTimeDecay(lr, 1, decay_rate, False)` | `Inv(lr, decay_rate, 1)` | -| N/A | `InverseTimeDecay(lr, decay_step, decay_rate, False)` | `Interpolator(Inv(lr, decay_rate, 1), decay_step)` | -| N/A | `PolynomialDecay(lr, decay_steps, 0, power, False)` | `Poly(lr, power, decay_steps)` | +| `ExponentialLR(_, gamma)` | `ExponentialDecay(lr, 1, gamma, False)` | `Exp(lr, gamma)` | +| N/A | `ExponentialDecay(lr, steps, gamma, False)` | `Interpolator(Exp(lr, gamma), steps)` | +| `CosineAnnealingLR(_, T_max, eta_min)` | `CosineDecay(lr, T_max, eta_min)` | `CosAnneal(lr, eta_min, T_0, false)` | +| `CosineAnnealingRestarts(_, T_0, 1, eta_min)` | `CosineDecayRestarts(lr, T_0, 1, 1, eta_min)` | `CosAnneal(lr, eta_min, T_0)` | +| `CosineAnnealingRestarts(_, T_0, T_mult, eta_min)` | `CosineDecayRestarts(lr, T_0, T_mult, 1, alpha)` | See [below](@ref "Cosine annealing variants") | +| N/A | `CosineDecayRestarts(lr, T_0, T_mult, m_mul, alpha)` | See [below](@ref "Cosine annealing variants") | +| `SequentialLR(_, schedulers, milestones)` | N/A | `Sequence(schedulers, milestones)` | +| `ReduceLROnPlateau(_, mode, factor, patience, threshold, 'abs', 0)` | N/A | See [below](@ref "`ReduceLROnPlateau` style schedules") | +| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular', _, None)` | N/A | `Triangle(base_lr, max_lr, step_size)` | +| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'triangular2', _, None)` | N/A | `TriangleDecay2(base_lr, max_lr, step_size)` | +| `CyclicLR(_, base_lr, max_lr, step_size, step_size, 'exp_range', gamma, None)` | N/A | `TriangleExp(base_lr, max_lr, step_size, gamma)` | +| `CyclicLR(_, base_lr, max_lr, step_size, step_size, _, _, scale_fn)` | N/A | See [Arbitrary looping schedules](@ref) | +| N/A | `InverseTimeDecay(lr, 1, decay_rate, False)` | `Inv(lr, decay_rate, 1)` | +| N/A | `InverseTimeDecay(lr, decay_step, decay_rate, False)` | `Interpolator(Inv(lr, decay_rate, 1), decay_step)` | +| N/A | `PolynomialDecay(lr, decay_steps, 0, power, False)` | `Poly(lr, power, decay_steps)` | ## Cosine annealing variants diff --git a/src/cyclic.jl b/src/cyclic.jl index a82688f..abce164 100644 --- a/src/cyclic.jl +++ b/src/cyclic.jl @@ -208,8 +208,9 @@ CosAnneal(range, offset, period) = CosAnneal(range, offset, period, true) function CosAnneal(; kwargs...) kwargs = depkwargs(:CosAnneal, kwargs, :λ0 => :l0, :λ1 => :l1) l0, l1 = kwargs.l0, kwargs.l1 + restart = get(kwargs, :restart, true) - return CosAnneal(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.restart) + return CosAnneal(abs(l0 - l1), min(l0, l1), kwargs.period, restart) end Base.eltype(::Type{<:CosAnneal{T}}) where T = T diff --git a/toc.md b/toc.md deleted file mode 100644 index c573714..0000000 --- a/toc.md +++ /dev/null @@ -1,17 +0,0 @@ -[Introduction](README.md) - -[Schedule cheatsheet](docs/cheatsheet.md) - -# Tutorials - -* [Getting started](docs/tutorials/getting-started.md) -* [Basic schedules](docs/tutorials/basic-schedules.md) -* [Optimizers](docs/tutorials/optimizers.md) -* [Complex schedules](docs/tutorials/complex-schedules.md) -* [Warmup schedules](docs/tutorials/warmup-schedules.md) - -[Interface](docs/interfaces/generic.md) - ----- - -[API Reference](docstrings.md)