Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

indirect solve prototype #121

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ version = "0.4.1"
[deps]
AMD = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Expand Down
1 change: 1 addition & 0 deletions src/Clarabel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ module Clarabel
#KKT solvers and solver level kktsystem
include("./kktsolvers/kktsolver_defaults.jl")
include("./kktsolvers/kktsolver_directldl.jl")
include("./kktsolvers/kktsolver_indirect.jl")
include("./kktsystem.jl")

# printing and top level solver
Expand Down
10 changes: 8 additions & 2 deletions src/info_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,14 @@ function print_settings(settings::Settings, T::DataType)
set = settings
@printf("\nsettings:\n")

if(set.direct_kkt_solver)
@printf(" linear algebra: direct / %s, precision: %s\n", set.direct_solve_method, get_precision_string(T))
if(set.kkt_solver_method == :directldl)
@printf(" KKT solve method: direct / %s, precision: %s\n",
set.direct_solve_method,
get_precision_string(T))
else
@printf(" KKT solve method: %s, precision: %s\n",
set.kkt_solver_method,
get_precision_string(T))
end

@printf(" max iter = %i, time limit = %f, max step = %.3f\n",
Expand Down
10 changes: 9 additions & 1 deletion src/kktsolvers/direct-ldl/directldl_defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@ abstract type AbstractDirectLDLSolver{T <: AbstractFloat} end

const DirectLDLSolversDict = Dict{Symbol, UnionAll}()

function _get_ldlsolver_type(s::Symbol)
try
return DirectLDLSolversDict[s]
catch
throw(error("Unsupported direct LDL linear solver :", s))
end
end

# Any new LDL solver type should provide implementations of all
# of the following and add itself to the DirectLDLSolversDict

# register type, .e.g
# register type, e.g.
# DirectLDLSolversDict[:qdldl] = QDLDLDirectLDLSolver

# return either :triu or :tril
Expand Down
16 changes: 16 additions & 0 deletions src/kktsolvers/kktsolver_defaults.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
const KKTSolversDict = Dict{Symbol, UnionAll}()

function _get_kktsolver_type(s::Symbol)
try
return KKTSolversDict[s]
catch
throw(error("Unsupported kkt solver method:", s))
end
end

# Any new AbstractKKTSolver sub type should provide implementations of
# of the following and add itself to the KKTSolversDict

# register type, e.g.
# KKTSolversDict[:directldl] = DirectLDLKKTSolver

#update matrix data and factor
function kktsolver_update!(linsys::AbstractKKTSolver{T},cones::CompositeCone{T}) where{T}
Expand All @@ -17,6 +32,7 @@ end
#solve and assign LHS
function kktsolver_solve!(
kktsolver::AbstractKKTSolver{T},
cones::CompositeCone{T},
x::Union{Nothing,AbstractVector{T}},
z::Union{Nothing,AbstractVector{T}}
) where{T}
Expand Down
12 changes: 4 additions & 8 deletions src/kktsolvers/kktsolver_directldl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,10 @@ end

DirectLDLKKTSolver(args...) = DirectLDLKKTSolver{DefaultFloat}(args...)

function _get_ldlsolver_type(s::Symbol)
try
return DirectLDLSolversDict[s]
catch
throw(error("Unsupported direct LDL linear solver :", s))
end
end
KKTSolversDict[:directldl] = DirectLDLKKTSolver


function _fill_Dsigns!(Dsigns,m,n,p)
function _fill_Dsigns!(Dsigns::Vector{Int64}, m::Int64, n::Int64, p::Int64)

Dsigns .= 1

Expand Down Expand Up @@ -336,6 +331,7 @@ end

function kktsolver_solve!(
kktsolver::DirectLDLKKTSolver{T},
cones::CompositeCone{T},
lhsx::Union{Nothing,AbstractVector{T}},
lhsz::Union{Nothing,AbstractVector{T}}
) where {T}
Expand Down
181 changes: 181 additions & 0 deletions src/kktsolvers/kktsolver_indirect.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
using LinearOperators, IterativeSolvers
# -------------------------------------
# Generic Indirect KKTSolver
# -------------------------------------

mutable struct IndirectKKTSolver{T} <: AbstractKKTSolver{T}

# problem dimensions
m::Int; n::Int;

# Left and right hand sides for solves
x1::Vector{T}
x2::ConicVector{T}
b1::Vector{T}
b2::ConicVector{T}

# two work vectors are required for multiplying
# through W and for intermediate products A*x
work1::ConicVector{T}
work2::ConicVector{T}

# internal (shallow) copies of problem data.
# could be mapped here to some other format
P::Symmetric{T,SparseMatrixCSC{T,Int}}
A::AbstractMatrix{T}

# block diagonal data for the lower RHS
H::Vector{Vector{T}}

#settings just points back to the main solver settings.
#Required since there is no separate KKT settings container
settings::Settings{T}


function IndirectKKTSolver{T}(P,A,cones,m,n,settings) where {T}

#LHS/RHS/work
x1 = Vector{T}(undef,n)
b1 = Vector{T}(undef,n)
x2 = ConicVector{T}(cones)
b2 = ConicVector{T}(cones)
work1 = ConicVector{T}(cones)
work2 = ConicVector{T}(cones)

#lower RHS block elements
#PJG: wiill only work for diagonal blocks
nblocks = numel.(cones.cones)
H = map(n -> zeros(T,n), nblocks)

return new(m,n,x1,x2,b1,b2,
work1,work2,Symmetric(P),A,H,settings)
end

end

IndirectKKTSolver(args...) = IndirectKKTSolver{DefaultFloat}(args...)

KKTSolversDict[:indirect] = IndirectKKTSolver


function kktsolver_update!(
kktsolver::IndirectKKTSolver{T},
cones::CompositeCone{T}
) where {T}

get_Hs!(cones,kktsolver.H)

#PJG: development optimism
is_success = true
return is_success
end


function kktsolver_setrhs!(
kktsolver::IndirectKKTSolver{T},
rhsx::AbstractVector{T},
rhsz::AbstractVector{T}
) where {T}

kktsolver.b1 .= rhsx
kktsolver.b2.vec .= rhsz

return nothing
end


function kktsolver_getlhs!(
kktsolver::IndirectKKTSolver{T},
lhsx::Union{Nothing,AbstractVector{T}},
lhsz::Union{Nothing,AbstractVector{T}}
) where {T}

isnothing(lhsx) || (@views lhsx .= kktsolver.x1)
isnothing(lhsz) || (@views lhsz .= kktsolver.x2.vec)

return nothing
end


function kktsolver_solve!(
kktsolver::IndirectKKTSolver{T},
cones::CompositeCone{T},
lhsx::Union{Nothing,AbstractVector{T}},
lhsz::Union{Nothing,AbstractVector{T}}
) where {T}

(P,A) = (kktsolver.P, kktsolver.A)
work1 = kktsolver.work1
work2 = kktsolver.work2
(x1,x2) = (kktsolver.x1,kktsolver.x2)
(b1,b2) = (kktsolver.b1,kktsolver.b2)
H = kktsolver.H

_indirect_solve_kkt(cones,x1,x2,P,A,H,b1,b2,work1,work2)

#PJG: development optimism
is_success = true

if is_success
kktsolver_getlhs!(kktsolver,lhsx,lhsz)
end

return is_success
end


function _indirect_solve_kkt(cones,x1,x2,P,A,H,b1,b2,work1,work2)

# Here we should put our solver for the system
# [P A'][x1] = [b1]
# [A -H ][x2] [b2]
#
# I will assume here that :
#
# 1) we want to solve by condensing to a PSD form
# and doing an indirect solve of (P + A'*H^{-1}*A)x = r
#
# 2) The matrix H = W^TW is symmetric, sign definite,
# and diagonal (i.e. nonnegative cones only).

# Diagonal H is not really necessary since I only need
# to compute products H^-1*b for an indirect solver,
# but in this prototype I have formed H and its inverse
# directly just for testing.
#
# for an actual indirect method based on condensing we only need to
# compute products y = H^{-1}b = (W^TW)^{-1}b. For symmetric cones
# we should be able to do something like:
#
# mul_Winv!(cones,:T,work1,b2,one(T),zero(T)) #work1 = (W^T)^{-1}*b2
# mul_Winv!(cones,:N,work2,work1,one(T),zero(T)) #work2 = W^{-1}*work1
#
# At present the above won't compile because the CompositeCone
# container type only implements the general nonsymmetric cone
# interface, i.e. no mul_W or mul_Winv is available without
# some hackery.

# Here for simplicity I will instead assume that H is diagonal and only
# has one block (e.g. a single nonnegative cone constraint)

@assert(length(H) == 1)

H = Diagonal(H[1])
Hinv = inv(H)

# Solve (P + A'*H^{-1}*A)*b1 = x1. Indirect method goes here
# Should produce same as ...
# x1 .= (P + A'*Hinv*A)\(b1 + A'*Hinv*b2);

A = LinearOperator(A);
Hinv = LinearOperator(Hinv);
P = LinearOperator(P);
M = P + (A'*Hinv*A);
x1 .= cg(M,b1 + A'*(Hinv*b2.vec));

# backsolve for x2.
x2 .= Hinv*(A*x1 - b2);

end


20 changes: 12 additions & 8 deletions src/kktsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ mutable struct DefaultKKTSystem{T} <: AbstractKKTSystem{T}
#basic problem dimensions
(m, n) = (data.m, data.n)

#create the linear solver. Always LDL for now
kktsolver = DirectLDLKKTSolver{T}(data.P,data.A,cones,m,n,settings)

#which KKT Solver should I use?
kktsolverT = _get_kktsolver_type(settings.kkt_solver_method)
kktsolver = kktsolverT{T}(data.P,data.A,cones,m,n,settings)

#the LHS constant part of the reduced solve
x1 = Vector{T}(undef,n)
Expand Down Expand Up @@ -68,20 +70,21 @@ function kkt_update!(
is_success || return is_success

#calculate KKT solution for constant terms
is_success = _kkt_solve_constant_rhs!(kktsystem,data)
is_success = _kkt_solve_constant_rhs!(kktsystem,cones,data)

return is_success
end

function _kkt_solve_constant_rhs!(
kktsystem::DefaultKKTSystem{T},
cones::CompositeCone{T},
data::DefaultProblemData{T}
) where {T}

@. kktsystem.workx = -data.q;

kktsolver_setrhs!(kktsystem.kktsolver, kktsystem.workx, data.b)
is_success = kktsolver_solve!(kktsystem.kktsolver, kktsystem.x2, kktsystem.z2)
is_success = kktsolver_solve!(kktsystem.kktsolver, cones, kktsystem.x2, kktsystem.z2)

return is_success

Expand All @@ -90,6 +93,7 @@ end

function kkt_solve_initial_point!(
kktsystem::DefaultKKTSystem{T},
cones::CompositeCone{T},
variables::DefaultVariables{T},
data::DefaultProblemData{T}
) where{T}
Expand All @@ -101,7 +105,7 @@ function kkt_solve_initial_point!(
kktsystem.workx .= zero(T)
kktsystem.workz .= data.b
kktsolver_setrhs!(kktsystem.kktsolver, kktsystem.workx, kktsystem.workz)
is_success = kktsolver_solve!(kktsystem.kktsolver, variables.x, variables.s)
is_success = kktsolver_solve!(kktsystem.kktsolver, cones, variables.x, variables.s)

if !is_success return is_success end

Expand All @@ -111,13 +115,13 @@ function kkt_solve_initial_point!(
kktsystem.workz .= zero(T)

kktsolver_setrhs!(kktsystem.kktsolver, kktsystem.workx, kktsystem.workz)
is_success = kktsolver_solve!(kktsystem.kktsolver, nothing, variables.z)
is_success = kktsolver_solve!(kktsystem.kktsolver, cones, nothing, variables.z)
else
# QP initialization
@. kktsystem.workx = -data.q
@. kktsystem.workz = data.b
kktsolver_setrhs!(kktsystem.kktsolver, kktsystem.workx, kktsystem.workz)
is_success = kktsolver_solve!(kktsystem.kktsolver, variables.x, variables.z)
is_success = kktsolver_solve!(kktsystem.kktsolver, cones, variables.x, variables.z)
@. variables.s = -variables.z
end

Expand Down Expand Up @@ -162,7 +166,7 @@ function kkt_solve!(
#---------------------------------------------------
#this solves the variable part of reduced KKT system
kktsolver_setrhs!(kktsystem.kktsolver, workx, workz)
is_success = kktsolver_solve!(kktsystem.kktsolver,x1,z1)
is_success = kktsolver_solve!(kktsystem.kktsolver,cones,x1,z1)

if !is_success return false end

Expand Down
Loading