Skip to content

Commit

Permalink
revert to 0.2 for now
Browse files Browse the repository at this point in the history
  • Loading branch information
BatyLeo committed Dec 24, 2024
1 parent c5c1bfd commit c2615cd
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
ChainRulesCore = "1"
DensityInterface = "0.4.0"
DifferentiableExpectations = "0.2"
DifferentiableFrankWolfe = "0.3"
DifferentiableFrankWolfe = "0.2"
Distributions = "0.25"
DocStringExtensions = "0.9.3"
LinearAlgebra = "<0.0.1,1"
Expand Down
10 changes: 0 additions & 10 deletions src/utils/linear_maximizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ end
# default is oracles of the form argmax_y θᵀy
# Fallback objective for plain argmax-style oracles of the form argmax_y θᵀy:
# the objective value is simply the inner product θᵀy.
function objective_value(::Any, θ, y; kwargs...)
    return dot(θ, y)
end
# Default post-processing `g` is the identity: the oracle output is returned unchanged.
function apply_g(::Any, y; kwargs...)
    return y
end
# apply_h(::Any, y; kwargs...) = zero(eltype(y)) is not needed

"""
$TYPEDSIGNATURES
Expand All @@ -65,12 +64,3 @@ Applies the function `g` of the LinearMaximizer `f` to `y`.
function apply_g(f::LinearMaximizer, y; kwargs...)
    # Delegate to the post-processing callable `g` stored in the maximizer,
    # forwarding any keyword arguments untouched.
    g = f.g
    return g(y; kwargs...)
end

# """
# $TYPEDSIGNATURES

# Applies the function `h` of the LinearMaximizer `f` to `y`.
# """
# function apply_h(f::LinearMaximizer, y; kwargs...)
# return f.h(y; kwargs...)
# end
13 changes: 5 additions & 8 deletions src/utils/pushforward.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""
Pushforward <: AbstractLayer
$TYPEDEF
Differentiable pushforward of a probabilistic optimization layer with an arbitrary post-processing function.
`Pushforward` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost.
# Fields
- `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer
- `post_processing`: callable
See also: `FixedAtomsProbabilityDistribution`.
$TYPEDFIELDS
"""
# Pairs a probabilistic optimization layer with a post-processing callable.
# The functor method defined below outputs the expectation of `post_processing`
# under the distribution induced by `optimization_layer`.
struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer
"probabilistic optimization layer"
optimization_layer::O
"callable"
post_processing::P
end

Expand All @@ -22,13 +21,11 @@ function Base.show(io::IO, pushforward::Pushforward)
end

"""
(pushforward::Pushforward)(θ; kwargs...)
$TYPEDSIGNATURES
Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`.
This function is differentiable, even if `pushforward.post_processing` isn't.
See also: `compute_expectation`.
"""
function (pushforward::Pushforward)(θ::AbstractArray; kwargs...)
(; optimization_layer, post_processing) = pushforward
Expand Down
16 changes: 8 additions & 8 deletions src/utils/some_functions.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""
positive_part(x)
$TYPEDSIGNATURES
Compute `max(x,0)`.
Compute `max(x, 0)`.
"""
# Positive part max(x, 0), written with `ifelse` so the result type always
# matches `typeof(x)` (both branches use `zero(x)`, keeping it type-stable).
positive_part(x) = ifelse(x >= zero(x), x, zero(x))

"""
isproba(x)
$TYPEDSIGNATURES
Check whether `x ∈ [0,1]`.
"""
# Predicate: true iff `x` lies in the closed unit interval [0, 1].
isproba(x::Real) = (zero(x) <= x) && (x <= one(x))

"""
isprobadist(p)
$TYPEDSIGNATURES
Check whether the elements of `p` are nonnegative and sum to 1.
"""
# Predicate: true iff every element of `p` is in [0, 1] and the elements sum
# (approximately, via `isapprox`/`≈` to tolerate float round-off) to 1.
# NOTE(review): the `≈` operator was dropped in transit; restored here.
isprobadist(p::AbstractVector{R}) where {R<:Real} = all(isproba, p) && sum(p) ≈ one(R)

"""
half_square_norm(x)
$TYPEDSIGNATURES
Compute the squared Euclidean norm of `x` and divide it by 2.
"""
Expand All @@ -29,7 +29,7 @@ function half_square_norm(x::AbstractArray)
end

"""
shannon_entropy(p)
$TYPEDSIGNATURES
Compute the Shannon entropy of a probability distribution: `H(p) = -∑ pᵢlog(pᵢ)`.
"""
Expand All @@ -46,7 +46,7 @@ end
# Negation of the Shannon entropy, useful as a convex regularizer.
function negative_shannon_entropy(p::AbstractVector)
    return -shannon_entropy(p)
end

"""
one_hot_argmax(z)
$TYPEDSIGNATURES
One-hot encoding of the argmax function.
"""
Expand All @@ -57,7 +57,7 @@ function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real}
end

"""
ranking(θ[; rev])
$TYPEDSIGNATURES
Compute the vector `r` such that `rᵢ` is the rank of `θᵢ` in `θ`.
"""
Expand Down

0 comments on commit c2615cd

Please sign in to comment.