Rework regularized layers #73

Closed · wants to merge 1 commit
21 changes: 13 additions & 8 deletions ext/InferOptFrankWolfeExt.jl
@@ -1,12 +1,11 @@
module InferOptFrankWolfeExt

using DifferentiableFrankWolfe: DiffFW, LinearMaximizationOracleWithKwargs
using InferOpt: InferOpt, RegularizedGeneric, FixedAtomsProbabilityDistribution
using InferOpt:
InferOpt, Regularized, FixedAtomsProbabilityDistribution, FrankWolfeOptimizer
using InferOpt: compute_expectation, compute_probability_distribution
using LinearAlgebra: dot

## Forward pass

function InferOpt.compute_probability_distribution(
dfw::DiffFW, θ::AbstractArray; frank_wolfe_kwargs=NamedTuple()
)
@@ -23,26 +22,32 @@ Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probability_distribution`.
Keyword arguments are passed to the underlying linear maximizer.
"""
function InferOpt.compute_probability_distribution(
regularized::RegularizedGeneric, θ::AbstractArray; kwargs...
optimizer::FrankWolfeOptimizer, θ::AbstractArray; kwargs...
)
(; maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = optimizer
f(y, θ) = Ω(y) - dot(θ, y)
f_grad1(y, θ) = Ω_grad(y) - θ
lmo = LinearMaximizationOracleWithKwargs(maximizer, kwargs)
lmo = LinearMaximizationOracleWithKwargs(linear_maximizer, kwargs)
dfw = DiffFW(f, f_grad1, lmo)
probadist = compute_probability_distribution(dfw, θ; frank_wolfe_kwargs)
return probadist
end

function InferOpt.compute_probability_distribution(
regularized::Regularized{<:FrankWolfeOptimizer}, θ::AbstractArray; kwargs...
)
return compute_probability_distribution(regularized.optimizer, θ; kwargs...)
end

"""
(optimizer::FrankWolfeOptimizer)(θ; kwargs...)

Apply `compute_probability_distribution(optimizer, θ)` and return the expectation.

Keyword arguments are passed to the underlying linear maximizer.
"""
function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...)
probadist = compute_probability_distribution(regularized, θ; kwargs...)
function (optimizer::FrankWolfeOptimizer)(θ::AbstractArray; kwargs...)
probadist = compute_probability_distribution(optimizer, θ; kwargs...)
return compute_expectation(probadist)
end
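For orientation, here is how the renamed pieces fit together after this change. This is a minimal usage sketch, not code from the diff: it assumes `DifferentiableFrankWolfe` and `FrankWolfe` are loaded so the extension is active, uses the unexported `InferOpt.FrankWolfeOptimizer` directly, and picks an illustrative `max_iteration` value.

```julia
using DifferentiableFrankWolfe, FrankWolfe, InferOpt

# Regularized argmax over the simplex with Ω(y) = ½‖y‖², ∇Ω(y) = y
optimizer = InferOpt.FrankWolfeOptimizer(
    one_hot_argmax,       # linear maximization oracle, implicitly defines the polytope
    half_square_norm,     # Ω
    identity,             # ∇Ω
    (max_iteration=10,),  # forwarded to the Frank-Wolfe algorithm
)

θ = randn(5)
# Distribution over polytope vertices computed by Frank-Wolfe
probadist = InferOpt.compute_probability_distribution(optimizer, θ)
# Calling the optimizer returns the expectation of that distribution
ŷ = optimizer(θ)
```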

11 changes: 7 additions & 4 deletions src/InferOpt.jl
@@ -17,11 +17,11 @@ include("plus_identity/plus_identity.jl")

include("interpolation/interpolation.jl")

include("regularized/isregularized.jl")
include("regularized/regularized_utils.jl")
include("regularized/soft_argmax.jl")
include("regularized/sparse_argmax.jl")
include("regularized/regularized_generic.jl")
include("regularized/regularized.jl")
include("regularized/frank_wolfe_optimizer.jl")

include("perturbed/abstract_perturbed.jl")
include("perturbed/additive.jl")
@@ -54,9 +54,8 @@ export Interpolation
export half_square_norm
export shannon_entropy, negative_shannon_entropy
export one_hot_argmax, ranking
export IsRegularized
export soft_argmax, sparse_argmax
export RegularizedGeneric
export Regularized

export PerturbedAdditive
export PerturbedMultiplicative
@@ -71,4 +70,8 @@ export StructuredSVMLoss

export ImitationLoss, get_y_true

export SparseArgmax, SoftArgmax

export RegularizedFrankWolfe

end
4 changes: 2 additions & 2 deletions src/fenchel_young/fenchel_young.jl
@@ -24,9 +24,9 @@ function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...)
return l
end

@traitfn function fenchel_young_loss_and_grad(
function fenchel_young_loss_and_grad(
fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs...
) where {P; IsRegularized{P}}
) where {P<:Regularized}
(; predictor) = fyl
ŷ = predictor(θ; kwargs...)
Ωy_true = compute_regularization(predictor, y_true)
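The remainder of this function is collapsed in the diff; for reference, the quantity it computes is the Fenchel-Young loss and its gradient in θ. Below is a minimal sketch of the formula only: the helper name and the explicit `Ω` argument are hypothetical, whereas the PR's actual code obtains the regularization through `compute_regularization(predictor, ·)`.

```julia
using LinearAlgebra: dot

# Fenchel-Young loss for a regularized predictor ŷ(θ) = argmax_y {θᵀy - Ω(y)}:
#   L(θ, y_true) = Ω(y_true) - Ω(ŷ) + θᵀ(ŷ - y_true)
#   ∇_θ L(θ, y_true) = ŷ - y_true
function fenchel_young_loss_and_grad_sketch(predictor, Ω, θ, y_true; kwargs...)
    ŷ = predictor(θ; kwargs...)
    l = Ω(y_true) - Ω(ŷ) + dot(θ, ŷ - y_true)
    return l, ŷ - y_true
end
```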
src/regularized/frank_wolfe_optimizer.jl (renamed from src/regularized/regularized_generic.jl)
@@ -1,5 +1,5 @@
"""
RegularizedGeneric{M,RF,RG}
FrankWolfeOptimizer{M,RF,RG,FWK}

Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}`.

@@ -9,7 +9,7 @@ Relies on the Frank-Wolfe algorithm to minimize a convex objective on a polytope.
Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedFrankWolfe`.

# Fields
- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
- `linear_maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
Review comment (Member): Maybe we should use linear_maximizer throughout InferOpt?

- `Ω::RF`: regularization function `Ω(y)`
- `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)`
- `frank_wolfe_kwargs::FWK`: keyword arguments passed to the Frank-Wolfe algorithm
@@ -32,30 +32,14 @@ Some values you can tune:

See the documentation of FrankWolfe.jl for details.
"""
struct RegularizedGeneric{M,RF,RG,FWK}
maximizer::M
struct FrankWolfeOptimizer{M,RF,RG,FWK}
Review comment (Member): I would rather call this FrankWolfeConcaveMaximizer

linear_maximizer::M
Ω::RF
Ω_grad::RG
frank_wolfe_kwargs::FWK
end

function Base.show(io::IO, regularized::RegularizedGeneric)
(; maximizer, Ω, Ω_grad) = regularized
return print(io, "RegularizedGeneric($maximizer, $Ω, $Ω_grad)")
end

@traitimpl IsRegularized{RegularizedGeneric}

function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray)
return regularized.Ω(y)
end

"""
(regularized::RegularizedGeneric)(θ; kwargs...)

Apply `compute_probability_distribution(regularized, θ, kwargs...)` and return the expectation.
"""
function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...)
probadist = compute_probability_distribution(regularized, θ; kwargs...)
return compute_expectation(probadist)
function Base.show(io::IO, optimizer::FrankWolfeOptimizer)
(; linear_maximizer, Ω, Ω_grad) = optimizer
return print(io, "RegularizedGeneric($linear_maximizer, $Ω, $Ω_grad)")
end
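The "values you can tune" mentioned in the collapsed docstring are standard FrankWolfe.jl keyword arguments forwarded through `frank_wolfe_kwargs`. A sketch of a typical configuration with illustrative values (not taken from the PR):

```julia
using FrankWolfe

# Hypothetical settings forwarded to the Frank-Wolfe solver
frank_wolfe_kwargs = (
    max_iteration=20,                   # number of Frank-Wolfe iterations
    epsilon=1e-4,                       # dual gap tolerance for early stopping
    line_search=FrankWolfe.Adaptive(),  # step-size rule
    verbose=false,
)
```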
23 changes: 0 additions & 23 deletions src/regularized/isregularized.jl

This file was deleted.

46 changes: 46 additions & 0 deletions src/regularized/regularized.jl
@@ -0,0 +1,46 @@
"""
    optimizer: θ ⟼ argmax_y {θᵀy - Ω(y)}
"""
struct Regularized{O,R}
Review comment (Member): Do we also need the linear maximizer as a field for when the layer is called outside of training? It would make sense to me to modify the behavior of Perturbed as well so that the standard forward pass just calls the naked linear maximizer.

Ω::R
optimizer::O
Review comment (Member): I would rather call this concave_maximizer to differentiate from (linear_)maximizer used elsewhere

end

function Base.show(io::IO, regularized::Regularized)
(; optimizer, Ω) = regularized
return print(io, "Regularized($optimizer, $Ω)")
end

function (regularized::Regularized)(θ::AbstractArray; kwargs...)
return regularized.optimizer(θ; kwargs...)
end

function compute_regularization(regularized::Regularized, y::AbstractArray)
return regularized.Ω(y)
end

# Specific constructors

"""
TODO
"""
function SparseArgmax()
return Regularized(sparse_argmax_regularization, sparse_argmax)
end

"""
TODO
"""
function SoftArgmax()
return Regularized(soft_argmax_regularization, soft_argmax)
end

"""
TODO
"""
function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
# TODO : add a warning if DifferentiableFrankWolfe is not imported ?
Review comment (Member): Good idea

return Regularized(
Ω, FrankWolfeOptimizer(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs)
)
end
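To see how these constructors are meant to be called, here is a small usage sketch consistent with the tests further down; the input values and the commented Frank-Wolfe keyword are illustrative, not part of the PR.

```julia
using InferOpt

θ = randn(5)
y_true = [0.0, 0.0, 1.0, 0.0, 0.0]

# Regularized layers used as predictors inside a Fenchel-Young loss
fyl = FenchelYoungLoss(SparseArgmax())
fyl(θ, y_true)

fyl_soft = FenchelYoungLoss(SoftArgmax())
fyl_soft(θ, y_true)

# Frank-Wolfe variant (requires `import DifferentiableFrankWolfe` beforehand):
# RegularizedFrankWolfe(one_hot_argmax, half_square_norm, identity, (max_iteration=10,))
```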
8 changes: 6 additions & 2 deletions src/regularized/soft_argmax.jl
@@ -10,8 +10,12 @@ function soft_argmax(z::AbstractVector; kwargs...)
return s
end

@traitimpl IsRegularized{typeof(soft_argmax)}
# @traitimpl IsRegularized{typeof(soft_argmax)}
Review comment (Member): In the trash


function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real}
# function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real}
# return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R)
# end

function soft_argmax_regularization(y::AbstractVector)
return isprobadist(y) ? negative_shannon_entropy(y) : typemax(eltype(y))
end
12 changes: 8 additions & 4 deletions src/regularized/sparse_argmax.jl
@@ -10,11 +10,15 @@ function sparse_argmax(z::AbstractVector; kwargs...)
return p
end

@traitimpl IsRegularized{typeof(sparse_argmax)}
# @traitimpl IsRegularized{typeof(sparse_argmax)}
Review comment (Member): In the trash


function compute_regularization(
::typeof(sparse_argmax), y::AbstractVector{R}
) where {R<:Real}
# function compute_regularization(
# ::typeof(sparse_argmax), y::AbstractVector{R}
# ) where {R<:Real}
# return isprobadist(y) ? half_square_norm(y) : typemax(R)
# end

function sparse_argmax_regularization(y::AbstractVector)
return isprobadist(y) ? half_square_norm(y) : typemax(eltype(y))
end

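As a quick illustration of the difference between the two regularizations above (output values are approximate and not part of the diff):

```julia
using InferOpt

z = [1.0, 1.8, 0.5]
sparse_argmax(z)  # ≈ [0.1, 0.9, 0.0]: simplex projection, can set entries exactly to zero
soft_argmax(z)    # softmax of z: every entry strictly positive
```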
14 changes: 7 additions & 7 deletions test/argmax.jl
@@ -52,7 +52,7 @@ end
PipelineLossImitation;
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=sparse_argmax,
maximizer=SparseArgmax(),
loss=mse,
error_function=hamming_distance,
)
@@ -67,7 +67,7 @@ end
PipelineLossImitation;
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=soft_argmax,
maximizer=SoftArgmax(),
loss=mse,
error_function=hamming_distance,
)
@@ -112,7 +112,7 @@ end
PipelineLossImitation;
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=RegularizedGeneric(
maximizer=RegularizedFrankWolfe(
one_hot_argmax,
half_square_norm,
identity,
@@ -133,7 +133,7 @@ end
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=FenchelYoungLoss(sparse_argmax),
loss=FenchelYoungLoss(SparseArgmax()),
error_function=hamming_distance,
)
end
@@ -148,7 +148,7 @@ end
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=FenchelYoungLoss(soft_argmax),
loss=FenchelYoungLoss(SoftArgmax()),
error_function=hamming_distance,
)
end
@@ -194,7 +194,7 @@ end
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=FenchelYoungLoss(
RegularizedGeneric(
RegularizedFrankWolfe(
one_hot_argmax,
half_square_norm,
identity,
@@ -259,7 +259,7 @@ end
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=Pushforward(
RegularizedGeneric(
RegularizedFrankWolfe(
one_hot_argmax,
half_square_norm,
identity,
4 changes: 2 additions & 2 deletions test/imitation_loss.jl
@@ -63,7 +63,7 @@ end
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=FenchelYoungLoss(sparse_argmax),
loss=FenchelYoungLoss(SparseArgmax()),
error_function=hamming_distance,
true_encoder,
verbose=false,
@@ -98,7 +98,7 @@ end
instance_dim=5,
true_maximizer=one_hot_argmax,
maximizer=identity,
loss=FenchelYoungLoss(soft_argmax),
loss=FenchelYoungLoss(SoftArgmax()),
error_function=hamming_distance,
true_encoder,
verbose=false,
12 changes: 6 additions & 6 deletions test/paths.jl
@@ -88,7 +88,7 @@ end
)
end

@testitem "Paths - imit - MSE RegularizedGeneric" default_imports = false begin
@testitem "Paths - imit - MSE RegularizedFrankWolfe" default_imports = false begin
include("InferOptTestUtils/InferOptTestUtils.jl")
using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random
Random.seed!(63)
@@ -97,7 +97,7 @@ end
PipelineLossImitation;
instance_dim=(5, 5),
true_maximizer=shortest_path_maximizer,
maximizer=RegularizedGeneric(
maximizer=RegularizedFrankWolfe(
shortest_path_maximizer,
half_square_norm,
identity,
@@ -143,7 +143,7 @@ end
)
end

@testitem "Paths - imit - FYL RegularizedGeneric" default_imports = false begin
@testitem "Paths - imit - FYL RegularizedFrankWolfe" default_imports = false begin
include("InferOptTestUtils/InferOptTestUtils.jl")
using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random
Random.seed!(63)
@@ -154,7 +154,7 @@ end
true_maximizer=shortest_path_maximizer,
maximizer=identity,
loss=FenchelYoungLoss(
RegularizedGeneric(
RegularizedFrankWolfe(
shortest_path_maximizer,
half_square_norm,
identity,
@@ -210,7 +210,7 @@ end
)
end

@testitem "Paths - exp - Pushforward RegularizedGeneric" default_imports = false begin
@testitem "Paths - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin
include("InferOptTestUtils/InferOptTestUtils.jl")
using DifferentiableFrankWolfe,
FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random
@@ -224,7 +224,7 @@ end
true_maximizer=shortest_path_maximizer,
maximizer=identity,
loss=Pushforward(
RegularizedGeneric(
RegularizedFrankWolfe(
shortest_path_maximizer,
half_square_norm,
identity,