diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index 388517367..0470d0dde 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -2,8 +2,8 @@ module BracketingNonlinearSolveChainRulesCoreExt using BracketingNonlinearSolve: bracketingnonlinear_solve_up, CommonSolve, SciMLBase using CommonSolve: solve -using SciMLBase: IntervalNonlinearProblem -using ForwardDiff +using SciMLBase: IntervalNonlinearProblem, unwrapped_f +using ForwardDiff: derivative, gradient using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk function ChainRulesCore.rrule( @@ -13,16 +13,16 @@ function ChainRulesCore.rrule( ) out = solve(prob, alg) u = out.u - f = SciMLBase.unwrapped_f(prob.f) + f = unwrapped_f(prob.f) function ∇bracketingnonlinear_solve_up(Δ) Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ # Δ = dg/du Δ isa Tangent ? delu = Δ.u : delu = Δ - λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ delu) + λ = only(derivative(u -> f(u, p), only(u)) \ delu) if p isa Number - dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), p) + dgdp = -λ * derivative(p -> f(u, p), p) else - dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p) + dgdp = -λ * gradient(p -> f(u, p), p) end return (NoTangent(), NoTangent(), NoTangent(), dgdp, NoTangent(), @@ -31,4 +31,4 @@ function ChainRulesCore.rrule( return out, ∇bracketingnonlinear_solve_up end -end \ No newline at end of file +end diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveForwardDiffExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveForwardDiffExt.jl index 09616b5a2..f531df777 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveForwardDiffExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveForwardDiffExt.jl @@ -1,7 +1,7 @@ module BracketingNonlinearSolveForwardDiffExt using CommonSolve: CommonSolve -using ForwardDiff: ForwardDiff, Dual +using ForwardDiff: Dual using NonlinearSolveBase: nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution using SciMLBase: SciMLBase, IntervalNonlinearProblem