diff --git a/examples/tutorial.jl b/examples/tutorial.jl
index f6838ab..876ca1e 100644
--- a/examples/tutorial.jl
+++ b/examples/tutorial.jl
@@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient op
 
 encoder = deepcopy(initial_encoder)
 opt = Flux.Adam();
+opt_state = Flux.setup(opt, encoder)
 losses = Float64[]
 for epoch in 1:100
     l = 0.0
     for (x, y) in zip(X_train, Y_train)
-        grads = gradient(Flux.params(encoder)) do
-            l += loss(encoder(x), y; directions=queen_directions)
+        grads = Flux.gradient(encoder) do m
+            l += loss(m(x), y; directions=queen_directions)
         end
-        Flux.update!(opt, Flux.params(encoder), grads)
+        Flux.update!(opt_state, encoder, grads[1])
     end
     push!(losses, l)
 end;
diff --git a/test/argmax.jl b/test/argmax.jl
index 4075c51..c3c2a57 100644
--- a/test/argmax.jl
+++ b/test/argmax.jl
@@ -116,7 +116,7 @@ end
         one_hot_argmax;
         Ω=half_square_norm,
         Ω_grad=identity_kw,
-        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
     ),
     loss=mse_kw,
     error_function=hamming_distance,
@@ -198,7 +198,7 @@ end
             one_hot_argmax;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
     ),
     error_function=hamming_distance,
@@ -263,7 +263,7 @@ end
             one_hot_argmax;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         cost,
     ),
diff --git a/test/paths.jl b/test/paths.jl
index c38176b..0681e8a 100644
--- a/test/paths.jl
+++ b/test/paths.jl
@@ -101,7 +101,7 @@ end
         shortest_path_maximizer;
         Ω=half_square_norm,
         Ω_grad=identity_kw,
-        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
     ),
     loss=mse_kw,
     error_function=mse_kw,
@@ -177,7 +177,7 @@ end
             shortest_path_maximizer;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
     ),
     error_function=mse_kw,
@@ -247,7 +247,7 @@ end
             shortest_path_maximizer;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         cost,
     ),
diff --git a/test/ranking.jl b/test/ranking.jl
index 122ed99..63954a5 100644
--- a/test/ranking.jl
+++ b/test/ranking.jl
@@ -101,7 +101,7 @@ end
         ranking;
         Ω=half_square_norm,
         Ω_grad=identity_kw,
-        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+        frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
     ),
     loss=mse_kw,
     error_function=hamming_distance,
@@ -170,7 +170,7 @@ end
             ranking;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
     ),
     error_function=hamming_distance,
@@ -303,7 +303,7 @@ end
             ranking;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         cost,
     ),