diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml
--- a/lib/BracketingNonlinearSolve/Project.toml
+++ b/lib/BracketingNonlinearSolve/Project.toml
@@ -16,12 +16,15 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"}
 
 [weakdeps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 
 [extensions]
+BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "ForwardDiff"]
 BracketingNonlinearSolveForwardDiffExt = "ForwardDiff"
 
 [compat]
 Aqua = "0.8.9"
+ChainRulesCore = "1.24"
 CommonSolve = "0.2.4"
 ConcreteStructs = "0.2.3"
 ExplicitImports = "1.10.1"
@@ -34,6 +37,7 @@ SciMLBase = "2.69"
 Test = "1.10"
 TestItemRunner = "1"
+Zygote = "0.6.69, 0.7"
 julia = "1.10"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
@@ -42,6 +46,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"]
+test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner", "Zygote"]
diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl
new file mode 100644
--- /dev/null
+++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl
@@ -0,0 +1,41 @@
+module BracketingNonlinearSolveChainRulesCoreExt
+
+using BracketingNonlinearSolve: bracketingnonlinear_solve_up, CommonSolve, SciMLBase
+using CommonSolve: solve
+using SciMLBase: IntervalNonlinearProblem
+using ForwardDiff
+using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk
+
+# Reverse-mode rule for `bracketingnonlinear_solve_up`. The primal is an ordinary
+# `solve`; the pullback differentiates the residual `f(u, p)` at the returned root
+# with ForwardDiff and sends a cotangent to `p` only (every other input gets
+# `NoTangent()`).
+function ChainRulesCore.rrule(
+    ::typeof(bracketingnonlinear_solve_up),
+    prob::IntervalNonlinearProblem,
+    sensealg, p, alg, args...; kwargs...
+)
+    # Forward `args`/`kwargs` so the solve run under AD honours the caller's options.
+    out = solve(prob, alg, args...; kwargs...)
+    u = out.u
+    f = SciMLBase.unwrapped_f(prob.f)
+    function ∇bracketingnonlinear_solve_up(Δ)
+        Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
+        # Cotangent of the root: the `u` field of a structural tangent, else Δ itself.
+        delu = Δ isa Tangent ? Δ.u : Δ
+        # λ = (∂f/∂u) \ delu, so that the cotangent of `p` is -λ * ∂f/∂p.
+        λ = only(ForwardDiff.derivative(x -> f(x, p), only(u)) \ delu)
+        if p isa Number
+            dgdp = -λ * ForwardDiff.derivative(q -> f(u, q), p)
+        else
+            dgdp = -λ * ForwardDiff.gradient(q -> f(u, q), p)
+        end
+        # Tangent order: the function, `prob`, `sensealg`, `p`, `alg`, then `args...`.
+        return (NoTangent(), NoTangent(), NoTangent(),
+            dgdp, NoTangent(),
+            ntuple(_ -> NoTangent(), length(args))...)
+    end
+    return out, ∇bracketingnonlinear_solve_up
+end
+
+end
diff --git a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
--- a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
+++ b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
@@ -24,10 +24,28 @@ include("ridder.jl")
 function CommonSolve.solve(prob::IntervalNonlinearProblem; kwargs...)
     return CommonSolve.solve(prob, ITP(); kwargs...)
 end
+
 function CommonSolve.solve(prob::IntervalNonlinearProblem, nothing, args...; kwargs...)
     return CommonSolve.solve(prob, ITP(), args...; kwargs...)
 end
 
+# Entry point for every bracketing algorithm. `sensealg` is split off from the keyword
+# arguments and passed positionally, together with `prob.p`, to
+# `bracketingnonlinear_solve_up`, which is the function the ChainRulesCore extension
+# attaches its `rrule` to.
+function CommonSolve.solve(prob::IntervalNonlinearProblem,
+        alg::AbstractBracketingAlgorithm, args...; sensealg = nothing, kwargs...)
+    return bracketingnonlinear_solve_up(
+        prob, sensealg, prob.p, alg, args...; kwargs...)
+end
+
+# Runs the algorithm-specific `SciMLBase.__solve`. `sensealg` and `p` are not used here;
+# they are explicit positional arguments so that AD rules can dispatch on this call.
+function bracketingnonlinear_solve_up(
+        prob::IntervalNonlinearProblem, sensealg, p, alg, args...; kwargs...)
+    return SciMLBase.__solve(prob, alg, args...; kwargs...)
+end
+
 @setup_workload begin
     for T in (Float32, Float64)
         prob_brack = IntervalNonlinearProblem{false}(
diff --git a/lib/BracketingNonlinearSolve/src/alefeld.jl b/lib/BracketingNonlinearSolve/src/alefeld.jl
index 6880f8c95..86807986a 100644
--- a/lib/BracketingNonlinearSolve/src/alefeld.jl
+++ b/lib/BracketingNonlinearSolve/src/alefeld.jl
@@ -8,7 +8,7 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
 """
 struct Alefeld <: AbstractBracketingAlgorithm end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::Alefeld, args...;
         maxiters = 1000, abstol = nothing, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/src/bisection.jl b/lib/BracketingNonlinearSolve/src/bisection.jl
index 91c17a775..5f056abbc 100644
--- a/lib/BracketingNonlinearSolve/src/bisection.jl
+++ b/lib/BracketingNonlinearSolve/src/bisection.jl
@@ -19,7 +19,7 @@ A common bisection method.
     exact_right::Bool = false
 end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::Bisection, args...;
         maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/src/brent.jl b/lib/BracketingNonlinearSolve/src/brent.jl
index 7baebc90c..6199bf29a 100644
--- a/lib/BracketingNonlinearSolve/src/brent.jl
+++ b/lib/BracketingNonlinearSolve/src/brent.jl
@@ -5,7 +5,7 @@ Left non-allocating Brent method.
 """
 struct Brent <: AbstractBracketingAlgorithm end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::Brent, args...;
         maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/src/falsi.jl b/lib/BracketingNonlinearSolve/src/falsi.jl
index 3074a5eb4..a2bdbde1f 100644
--- a/lib/BracketingNonlinearSolve/src/falsi.jl
+++ b/lib/BracketingNonlinearSolve/src/falsi.jl
@@ -5,7 +5,7 @@ A non-allocating regula falsi method.
 """
 struct Falsi <: AbstractBracketingAlgorithm end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::Falsi, args...;
         maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/src/itp.jl b/lib/BracketingNonlinearSolve/src/itp.jl
index cbf5818bf..c733dc25f 100644
--- a/lib/BracketingNonlinearSolve/src/itp.jl
+++ b/lib/BracketingNonlinearSolve/src/itp.jl
@@ -56,7 +56,7 @@ function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10)
     return ITP(scaled_k1, k2, n0)
 end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::ITP, args...;
         maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/src/muller.jl b/lib/BracketingNonlinearSolve/src/muller.jl
index 7b89236a0..1e321969b 100644
--- a/lib/BracketingNonlinearSolve/src/muller.jl
+++ b/lib/BracketingNonlinearSolve/src/muller.jl
@@ -27,7 +27,7 @@ end
 
 Muller() = Muller(nothing)
 
-function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...;
+function SciMLBase.__solve(prob::IntervalNonlinearProblem, alg::Muller, args...;
         abstol = nothing, maxiters = 1000, kwargs...)
     @assert !SciMLBase.isinplace(prob) "`Muller` only supports out-of-place problems."
     xᵢ₋₂, xᵢ = prob.tspan
diff --git a/lib/BracketingNonlinearSolve/src/ridder.jl b/lib/BracketingNonlinearSolve/src/ridder.jl
index 9192897c5..9e38d25b6 100644
--- a/lib/BracketingNonlinearSolve/src/ridder.jl
+++ b/lib/BracketingNonlinearSolve/src/ridder.jl
@@ -5,7 +5,7 @@ A non-allocating ridder method.
 """
 struct Ridder <: AbstractBracketingAlgorithm end
 
-function CommonSolve.solve(
+function SciMLBase.__solve(
         prob::IntervalNonlinearProblem, alg::Ridder, args...;
         maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
 )
diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl
new file mode 100644
--- /dev/null
+++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl
@@ -0,0 +1,25 @@
+@testitem "Simple Adjoint Test" tags=[:adjoint] begin
+    using ForwardDiff, Zygote, BracketingNonlinearSolve
+
+    # Residual u^2 - p[1]; for p[1] = 2 its root in (1, 3) is sqrt(2).
+    ff(u, p) = u^2 .- p[1]
+
+    # Loss u(p)^2, which equals p[1] at the exact root.
+    function solve_nlprob(p)
+        prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p)
+        sol = solve(prob, Bisection())
+        res = sol isa AbstractArray ? sol : sol.u
+        return sum(abs2, res)
+    end
+
+    p = [2.0, 2.0]
+
+    ∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
+    ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
+    @test ∂p_zygote ≈ ∂p_forwarddiff
+    # Analytic check: d(u^2)/dp = [1, 0] because u^2 == p[1] at the root.
+    @test ∂p_zygote ≈ [1.0, 0.0]
+
+    # Scalar parameter: exercises the `p isa Number` branch of the rrule.
+    @test only(Zygote.gradient(solve_nlprob, 2.0)) ≈ 1.0
+end