Skip to content

Commit

Permalink
safer gpu free
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 13, 2021
1 parent 9e1182e commit 77c9350
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqFlux"
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
authors = ["Chris Rackauckas <[email protected]>"]
version = "1.36.0"
version = "1.36.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
16 changes: 6 additions & 10 deletions src/fast_layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)

y = f.σ.(r)

if typeof(f.σ) <: typeof(tanh) || typeof(f.σ) <: typeof(identity)
ifgpufree(r)
end

function FastDense_adjoint(ȳ)
if typeof(f.σ) <: typeof(tanh)
zbar =.* (1 .- y.^2)
Expand All @@ -96,7 +92,7 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
pbar = typeof(bbar) <: AbstractVector ?
vec(vcat(vec(Wbar),bbar)) :
vec(vcat(vec(Wbar),sum(bbar,dims=2)))
ifgpufree(Wbar); ifgpufree(bbar)
ifgpufree(Wbar); ifgpufree(bbar); ifgpufree(r)
nothing,xbar,pbar
end
y,FastDense_adjoint
Expand All @@ -123,7 +119,7 @@ struct StaticDense{out,in,bias,F,F2} <: FastLayer
initial_params::F2
function StaticDense(in::Integer, out::Integer, σ = identity;
bias::Bool = true, initW = Flux.glorot_uniform, initb = Flux.zeros)
temp = ((bias == true ) ? vcat(vec(initW(out, in)),initb(out)) : vcat(vec(initW(out, in))) )
temp = ((bias == true ) ? vcat(vec(initW(out, in)),initb(out)) : vcat(vec(initW(out, in))) )
initial_params() = temp
new{out,in,bias,typeof(σ),typeof(initial_params)}(σ,initial_params)
end
Expand All @@ -135,17 +131,17 @@ function param2Wb(f::StaticDense{out,in,bias}, p) where {out,in,bias}
W = @inbounds convert(SMatrix{out,in},_W)
b = @inbounds SVector{out}(_b)
return W, b
else
else
_W = @view p[1:(out*in)]
W = @inbounds convert(SMatrix{out,in},_W)
return W
end
end
function (f::StaticDense{out,in,bias})(x,p) where {out,in,bias}
if bias == true
if bias == true
W, b = param2Wb(f, p)
return f.σ.(W*x .+ b)
else
else
W = param2Wb(f,p)
return f.σ.(W*x)
end
Expand All @@ -154,7 +150,7 @@ ZygoteRules.@adjoint function (f::StaticDense{out,in,bias})(x,p) where {out,in,b
if bias == true
W, b = param2Wb(f, p)
r = W*x .+ b
else
else
W = param2Wb(f,p)
r = W*x
end
Expand Down

0 comments on commit 77c9350

Please sign in to comment.