Rework regularized layers #73
```diff
@@ -1,5 +1,5 @@
 """
-    RegularizedGeneric{M,RF,RG}
+    FrankWolfeOptimizer{M,RF,RG,FWK}

 Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}`.

@@ -9,7 +9,7 @@ Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytop
 Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedGeneric`.

 # Fields
-- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
+- `linear_maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
 - `Ω::RF`: regularization function `Ω(y)`
 - `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)`
 - `frank_wolfe_kwargs::FWK`: keyword arguments passed to the Frank-Wolfe algorithm

@@ -32,30 +32,14 @@ Some values you can tune:

 See the documentation of FrankWolfe.jl for details.
 """
-struct RegularizedGeneric{M,RF,RG,FWK}
-    maximizer::M
+struct FrankWolfeOptimizer{M,RF,RG,FWK}
+    linear_maximizer::M
     Ω::RF
     Ω_grad::RG
     frank_wolfe_kwargs::FWK
 end

-function Base.show(io::IO, regularized::RegularizedGeneric)
-    (; maximizer, Ω, Ω_grad) = regularized
-    return print(io, "RegularizedGeneric($maximizer, $Ω, $Ω_grad)")
-end
-
-@traitimpl IsRegularized{RegularizedGeneric}
-
-function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray)
-    return regularized.Ω(y)
-end
-
-"""
-    (regularized::RegularizedGeneric)(θ; kwargs...)
-
-Apply `compute_probability_distribution(regularized, θ, kwargs...)` and return the expectation.
-"""
-function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...)
-    probadist = compute_probability_distribution(regularized, θ; kwargs...)
-    return compute_expectation(probadist)
+function Base.show(io::IO, optimizer::FrankWolfeOptimizer)
+    (; linear_maximizer, Ω, Ω_grad) = optimizer
+    return print(io, "FrankWolfeOptimizer($linear_maximizer, $Ω, $Ω_grad)")
 end
```

Review comment on `struct FrankWolfeOptimizer{M,RF,RG,FWK}`: I would rather call this …
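For concreteness, here is a hedged sketch of how the renamed struct might be instantiated. Everything below except `FrankWolfeOptimizer` and its field order is an illustrative assumption: `one_hot_argmax` is a stand-in linear maximization oracle over the unit simplex, the half-square-norm functions are one possible choice of `Ω` and `Ω_grad`, and `max_iteration` is a FrankWolfe.jl keyword passed through `frank_wolfe_kwargs`.

```julia
# Illustrative sketch only: the oracle and regularization below are
# assumptions for this example, not part of the PR.
function one_hot_argmax(θ::AbstractVector; kwargs...)
    e = zero(θ)                    # zero vector with θ's eltype and size
    e[argmax(θ)] = one(eltype(θ))  # best vertex of the unit simplex
    return e
end

half_square_norm(y) = sum(abs2, y) / 2  # Ω(y) = ½‖y‖²
half_square_norm_grad(y) = y            # ∇Ω(y) = y

optimizer = FrankWolfeOptimizer(
    one_hot_argmax,          # linear_maximizer: θ -> argmax_{x ∈ C} θᵀx
    half_square_norm,        # Ω
    half_square_norm_grad,   # Ω_grad
    (; max_iteration=100),   # frank_wolfe_kwargs, forwarded to FrankWolfe.jl
)
```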
Another file was deleted in this PR (contents not shown in the diff).
```diff
@@ -0,0 +1,46 @@
+"""
+    optimizer: θ ⟼ argmax θᵀy - Ω(y)
+"""
+struct Regularized{O,R}
+    Ω::R
+    optimizer::O
+end
+
+function Base.show(io::IO, regularized::Regularized)
+    (; optimizer, Ω) = regularized
+    return print(io, "Regularized($optimizer, $Ω)")
+end
+
+function (regularized::Regularized)(θ::AbstractArray; kwargs...)
+    return regularized.optimizer(θ; kwargs...)
+end
+
+function compute_regularization(regularized::Regularized, y::AbstractArray)
+    return regularized.Ω(y)
+end
+
+# Specific constructors
+
+"""
+    TODO
+"""
+function SparseArgmax()
+    return Regularized(sparse_argmax_regularization, sparse_argmax)
+end
+
+"""
+    TODO
+"""
+function SoftArgmax()
+    return Regularized(soft_argmax_regularization, soft_argmax)
+end
+
+"""
+    TODO
+"""
+function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
+    # TODO : add a warning if DifferentiableFrankWolfe is not imported ?
+    return Regularized(
+        Ω, FrankWolfeOptimizer(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs)
+    )
+end
```

Review comment on `struct Regularized{O,R}`: Do we also need the linear maximizer as a field for when the layer is called outside of training?

Review comment on `optimizer::O`: I would rather call this …

Review comment on the `# TODO` warning line: Good idea
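A minimal usage sketch of the new wrapper, assuming the constructors from this file and the `soft_argmax` function defined elsewhere in the package; the input `θ` is an arbitrary example.

```julia
layer = SoftArgmax()  # Regularized(soft_argmax_regularization, soft_argmax)

θ = [1.0, 2.0, 3.0]
ŷ = layer(θ)                          # forwards to soft_argmax(θ): a probability vector
Ω = compute_regularization(layer, ŷ)  # negative Shannon entropy of ŷ
```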
```diff
@@ -10,8 +10,12 @@ function soft_argmax(z::AbstractVector; kwargs...)
     return s
 end

-@traitimpl IsRegularized{typeof(soft_argmax)}
+# @traitimpl IsRegularized{typeof(soft_argmax)}

-function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real}
+# function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real}
+#     return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R)
+# end
+
+function soft_argmax_regularization(y::AbstractVector{R}) where {R<:Real}
     return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R)
 end
```

Review comment on the commented-out `@traitimpl` line: In the trash
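The helpers `isprobadist` and `negative_shannon_entropy` come from the package's utilities and are not shown in this diff; the definitions below are plausible sketches written out only as an assumption, for readability.

```julia
# Assumed sketches of the helpers referenced above (not the package's actual code):
isprobadist(y::AbstractVector) = all(>=(0), y) && sum(y) ≈ 1
negative_shannon_entropy(y::AbstractVector) = sum(x -> x <= 0 ? zero(x) : x * log(x), y)
```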
```diff
@@ -10,11 +10,15 @@ function sparse_argmax(z::AbstractVector; kwargs...)
     return p
 end

-@traitimpl IsRegularized{typeof(sparse_argmax)}
+# @traitimpl IsRegularized{typeof(sparse_argmax)}

-function compute_regularization(
-    ::typeof(sparse_argmax), y::AbstractVector{R}
-) where {R<:Real}
+# function compute_regularization(
+#     ::typeof(sparse_argmax), y::AbstractVector{R}
+# ) where {R<:Real}
+#     return isprobadist(y) ? half_square_norm(y) : typemax(R)
+# end
+
+function sparse_argmax_regularization(y::AbstractVector{R}) where {R<:Real}
    return isprobadist(y) ? half_square_norm(y) : typemax(R)
 end
```

Review comment on the commented-out `@traitimpl` line: In the trash
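The half square norm is the natural `Ω` here because `sparse_argmax(z)` is the Euclidean projection of `z` onto the probability simplex, i.e. `argmax_{y ∈ Δ} zᵀy - ½‖y‖²`. A hedged sketch of the referenced helper and of the regularization's behavior (the helper body is an assumption):

```julia
# Assumed sketch of the helper referenced above:
half_square_norm(y::AbstractVector) = sum(abs2, y) / 2  # ½‖y‖²

sparse_argmax_regularization([0.25, 0.75])  # 0.3125: a valid distribution
sparse_argmax_regularization([1.0, 1.0])    # Inf: not a probability distribution
```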
Review comment: Maybe we should use `linear_maximizer` throughout InferOpt?