Commit: Fix tests

BatyLeo committed Dec 23, 2024
1 parent d214802 commit f084090
Showing 8 changed files with 363 additions and 231 deletions.
484 changes: 299 additions & 185 deletions docs/Manifest.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/InferOpt.jl
@@ -89,7 +89,7 @@ export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl
export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl
export RegularizedFrankWolfe

export Perturbed
export PerturbedOracle
export PerturbedAdditive
export PerturbedMultiplicative
export LinearPerturbed
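For downstream code, the visible effect of this hunk is a rename of the exported layer: `Perturbed` becomes `PerturbedOracle`, and the `LinearPerturbed` constructor (see the next file) is folded into `PerturbedOracle` itself. A minimal sketch of the post-commit API, using a hypothetical one-hot argmax as a stand-in optimizer (not from the repository):

```julia
using InferOpt

# Hypothetical toy optimizer: one-hot argmax over θ.
function my_argmax(θ::AbstractVector; kwargs...)
    y = zero(θ)
    y[argmax(θ)] = 1
    return y
end

# Pre-commit call sites used `Perturbed`/`LinearPerturbed`; post-commit they
# go through `PerturbedOracle` or its convenience constructors.
layer = PerturbedAdditive(my_argmax; ε=1.0, nb_samples=10)
θ = randn(5)
ŷ = layer(θ)  # Monte-Carlo smoothed output, differentiable w.r.t. θ
```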
64 changes: 35 additions & 29 deletions src/layers/perturbed/perturbed.jl
@@ -1,5 +1,5 @@
"""
Perturbed{D,F} <: AbstractOptimizationLayer
PerturbedOracle{D,F} <: AbstractOptimizationLayer
Differentiable perturbation of a black box optimizer of type `F`, with perturbation of type `D`.
@@ -10,23 +10,25 @@ There are three different available constructors that behave differently in the
- [`PerturbedAdditive`](@ref)
- [`PerturbedMultiplicative`](@ref)
"""
struct Perturbed{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer
struct PerturbedOracle{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer
reinforce::Reinforce{t,variance_reduction,F,D,G,R,S}
end

function (perturbed::Perturbed)(θ::AbstractArray; kwargs...)
function (perturbed::PerturbedOracle)(θ::AbstractArray; kwargs...)
return perturbed.reinforce(θ; kwargs...)
end

function get_maximizer(perturbed::Perturbed)
function get_maximizer(perturbed::PerturbedOracle)
return perturbed.reinforce.f
end

function compute_probability_distribution(perturbed::Perturbed, θ::AbstractArray; kwargs...)
function compute_probability_distribution(
perturbed::PerturbedOracle, θ::AbstractArray; kwargs...
)
return empirical_distribution(perturbed.reinforce, θ; kwargs...)
end

function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation})
function Base.show(io::IO, perturbed::PerturbedOracle{<:AbstractPerturbation})
(; reinforce) = perturbed
nb_samples = reinforce.nb_samples
ε = reinforce.dist_constructor.ε
@@ -36,24 +38,36 @@ function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation})
f = reinforce.f
return print(
io,
"Perturbed($f, ε=, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)",
"PerturbedOracle($f, ε=, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)",
)
end

"""
    PerturbedOracle(maximizer, dist_constructor; kwargs...)

Constructor for [`PerturbedOracle`](@ref): wraps `maximizer` in a `Reinforce` gradient estimator with perturbation distribution constructor `dist_constructor`, forwarding the sampling options (`nb_samples`, `variance_reduction`, `threaded`, `seed`, `rng`).
"""
function LinearPerturbed(
function PerturbedOracle(
maximizer,
dist_constructor,
dist_logdensity_grad=nothing;
g=nothing,
h=nothing,
dist_constructor;
dist_logdensity_grad=nothing,
nb_samples=1,
variance_reduction=true,
threaded=false,
seed=nothing,
rng=Random.default_rng(),
kwargs...,
)
linear_maximizer = LinearMaximizer(; maximizer, g, h)
return Perturbed(
Reinforce(linear_maximizer, dist_constructor, dist_logdensity_grad; kwargs...)
return PerturbedOracle(
Reinforce(
maximizer,
dist_constructor,
dist_logdensity_grad;
nb_samples,
variance_reduction,
threaded,
seed,
rng,
kwargs...,
),
)
end

@@ -69,26 +83,22 @@ function PerturbedAdditive(
seed=nothing,
threaded=false,
rng=Random.default_rng(),
g=identity_kw,
h=zero ∘ eltype_kw,
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
(η, θ) -> ((η .- θ) ./ ε^2,)
else
nothing
end,
)
dist_constructor = AdditivePerturbation(perturbation_dist, float(ε))
return LinearPerturbed(
return PerturbedOracle(
maximizer,
dist_constructor,
dist_logdensity_grad;
dist_constructor;
dist_logdensity_grad,
nb_samples,
variance_reduction,
seed,
threaded,
rng,
g,
h,
)
end

@@ -104,25 +114,21 @@ function PerturbedMultiplicative(
seed=nothing,
threaded=false,
rng=Random.default_rng(),
g=identity_kw,
h=zero ∘ eltype_kw,
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
(η, θ) -> (inv.(ε^2 .* θ) .* (η .- θ),)
else
nothing
end,
)
dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε))
return LinearPerturbed(
return PerturbedOracle(
maximizer,
dist_constructor,
dist_logdensity_grad;
dist_constructor;
dist_logdensity_grad,
nb_samples,
variance_reduction,
seed,
threaded,
rng,
g,
h,
)
end
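Taken together, the reworked generic constructor matches the construction exercised by the updated `test/perturbed.jl` further down: `PerturbedOracle` now takes the oracle and a distribution constructor θ ↦ D(θ) positionally, with `dist_logdensity_grad` moved to a keyword. A sketch adapted from that test (identity oracle, Gaussian perturbation):

```julia
using InferOpt
using Distributions, LinearAlgebra

ε = 1.0
p(θ) = MvNormal(θ, ε^2 * I)  # distribution constructor: θ ↦ N(θ, ε²I)
oracle(η) = η                # identity oracle, for illustration only

# Generic constructor (post-commit signature).
po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0)

# Gaussian additive convenience constructor; should agree with `po` in distribution.
pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0)

θ = randn(10)
po(θ), pa(θ)
```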
9 changes: 6 additions & 3 deletions src/losses/fenchel_young_loss.jl
@@ -59,7 +59,10 @@ function fenchel_young_loss_and_grad(
end

function fenchel_young_loss_and_grad(
fyl::FenchelYoungLoss{<:Perturbed}, θ::AbstractArray, y_true::AbstractArray; kwargs...
fyl::FenchelYoungLoss{<:PerturbedOracle},
θ::AbstractArray,
y_true::AbstractArray;
kwargs...,
)
(; optimization_layer) = fyl
maximizer = get_maximizer(optimization_layer)
@@ -82,7 +85,7 @@ end
## Specific overrides for perturbed layers

function fenchel_young_F_and_first_part_of_grad(
perturbed::Perturbed{<:AdditivePerturbation}, θ::AbstractArray; kwargs...
perturbed::PerturbedOracle{<:AdditivePerturbation}, θ::AbstractArray; kwargs...
)
(; reinforce) = perturbed
maximizer = get_maximizer(perturbed)
@@ -98,7 +101,7 @@ end
end

function fenchel_young_F_and_first_part_of_grad(
perturbed::Perturbed{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs...
perturbed::PerturbedOracle{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs...
)
(; reinforce) = perturbed
maximizer = get_maximizer(perturbed)
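These overrides are reached whenever a `FenchelYoungLoss` wraps a perturbed layer. A hedged sketch of the call path, assuming the usual rrule wiring is unchanged (Zygote for the gradient, `my_argmax` again a hypothetical stand-in):

```julia
using InferOpt, Zygote

function my_argmax(θ::AbstractVector; kwargs...)
    y = zero(θ)
    y[argmax(θ)] = 1
    return y
end

fyl = FenchelYoungLoss(PerturbedAdditive(my_argmax; ε=1.0, nb_samples=10))

θ = randn(5)
y_true = my_argmax(randn(5))
l = fyl(θ, y_true)                              # scalar loss value
g = Zygote.gradient(θ -> fyl(θ, y_true), θ)[1]  # routes through fenchel_young_loss_and_grad
```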
4 changes: 4 additions & 0 deletions src/utils/linear_maximizer.jl
@@ -18,6 +18,10 @@ function Base.show(io::IO, f::LinearMaximizer)
return print(io, "LinearMaximizer($maximizer, $g, $h)")
end

function LinearMaximizer(maximizer; g=identity_kw, h=zero ∘ eltype_kw)
return LinearMaximizer(maximizer, g, h)
end

# Callable calls the wrapped maximizer
function (f::LinearMaximizer)(θ::AbstractArray; kwargs...)
return f.maximizer(θ; kwargs...)
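The added positional constructor is what the reworked tests rely on: `LinearMaximizer(max_pricing; g, h)` instead of the keyword-only `LinearMaximizer(; maximizer=max_pricing, g, h)`. A sketch with hypothetical `g` and `h` (the modeled objective being θᵀg(y) + h(y)):

```julia
using InferOpt

# Hypothetical stand-ins for the test fixtures max_pricing, g, h.
max_pricing(θ; kwargs...) = Float64.(θ .> 0)
g(y; kwargs...) = 2 .* y
h(y; kwargs...) = -sum(y)

f = LinearMaximizer(max_pricing; g, h)  # new positional form added in this hunk
θ = randn(5)
y = f(θ)                      # forwards the call to the wrapped maximizer
v = objective_value(f, θ, y)  # θᵀ g(y) + h(y), as used in the tests below
```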
24 changes: 13 additions & 11 deletions test/learning_generalized_maximizer.jl
@@ -13,7 +13,7 @@

@test y == [1 0 1; 0 1 0; 1 1 1]

generalized_maximizer = LinearMaximizer(; maximizer=max_pricing, g, h)
generalized_maximizer = LinearMaximizer(max_pricing; g, h)

@test generalized_maximizer(θ; instance) == y

@@ -29,8 +29,8 @@ end

true_encoder = encoder_factory()

perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h)
maximizer = InferOpt.get_maximizer(perturbed)
maximizer = LinearMaximizer(max_pricing; g, h)
perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10)
function cost(y; instance)
return -objective_value(maximizer, true_encoder(instance), y; instance)
end
@@ -55,8 +55,8 @@ end

true_encoder = encoder_factory()

perturbed = PerturbedMultiplicative(max_pricing; ε=1.0, nb_samples=10, g, h)
maximizer = InferOpt.get_maximizer(perturbed)
maximizer = LinearMaximizer(max_pricing; g, h)
perturbed = PerturbedMultiplicative(maximizer; ε=1.0, nb_samples=10)
function cost(y; instance)
return -objective_value(maximizer, true_encoder(instance), y; instance)
end
@@ -80,8 +80,10 @@ end

true_encoder = encoder_factory()

perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h)
maximizer = InferOpt.get_maximizer(perturbed)
maximizer = LinearMaximizer(max_pricing; g, h)
@info maximizer g h
perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10)
@info perturbed
function cost(y; instance)
return -objective_value(maximizer, true_encoder(instance), y; instance)
end
@@ -106,8 +108,8 @@ end

true_encoder = encoder_factory()

perturbed = PerturbedMultiplicative(max_pricing; ε=0.1, nb_samples=10, g, h)
maximizer = InferOpt.get_maximizer(perturbed)
maximizer = LinearMaximizer(max_pricing; g, h)
perturbed = PerturbedMultiplicative(maximizer; ε=0.1, nb_samples=10)
function cost(y; instance)
return -objective_value(maximizer, true_encoder(instance), y; instance)
end
@@ -180,7 +182,7 @@ end

true_encoder = encoder_factory()

generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
generalized_maximizer = LinearMaximizer(max_pricing; g, h)
function cost(y; instance)
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
end
@@ -207,7 +209,7 @@ end

true_encoder = encoder_factory()

generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
generalized_maximizer = LinearMaximizer(max_pricing; g, h)
function cost(y; instance)
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
end
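Across these test cases the refactor is the same: the maximizer is wrapped in a `LinearMaximizer` first (previously `GeneralizedMaximizer`, or `g`/`h` passed straight to the perturbed constructor), and the perturbed layer is built from that wrapper. A condensed sketch with hypothetical fixtures replacing `max_pricing`, `encoder_factory`, and `instance`:

```julia
using InferOpt

# Hypothetical fixtures standing in for the test helpers.
max_pricing(θ; instance) = Float64.(θ .> 0)
g(y; kwargs...) = y
h(y; kwargs...) = zero(eltype(y))

maximizer = LinearMaximizer(max_pricing; g, h)
perturbed = PerturbedMultiplicative(maximizer; ε=0.1, nb_samples=10)

θ, instance = rand(5), nothing
ŷ = perturbed(θ; instance)  # keyword arguments reach the wrapped maximizer
cost(y; instance) = -objective_value(maximizer, θ, y; instance)
```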
5 changes: 4 additions & 1 deletion test/paths.jl
@@ -155,7 +155,10 @@ end
maximizer=identity_kw,
loss=FenchelYoungLoss(
PerturbedAdditive(
shortest_path_maximizer; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1)
shortest_path_maximizer;
ε=1.0,
nb_samples=5,
perturbation_dist=LogNormal(0, 1),
),
),
error_function=mse_kw,
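The fix here is a keyword rename: `perturbation` becomes `perturbation_dist`. Anywhere a non-Gaussian perturbation is requested, the call now looks like the sketch below (a generic maximizer standing in for `shortest_path_maximizer`):

```julia
using InferOpt, Distributions

my_maximizer(θ; kwargs...) = Float64.(θ .== maximum(θ))

layer = PerturbedAdditive(
    my_maximizer;
    ε=1.0,
    nb_samples=5,
    perturbation_dist=LogNormal(0, 1),  # keyword was `perturbation` before this commit
)
ŷ = layer(randn(5))
```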
2 changes: 1 addition & 1 deletion test/perturbed.jl
@@ -50,7 +50,7 @@ end
p(θ) = MvNormal(θ, ε^2 * I)
oracle(η) = η

po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) # TODO: fix this
po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0)
pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0)

θ = randn(10)
