Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Implement Adjoints for solution of IntervalNonlinearProblems #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b5e2ce2
make solve algorithms use __solve
jClugstor May 21, 2025
5205bdf
add bracketingnonlinear_solve_up
jClugstor May 21, 2025
ec41473
add extensions
jClugstor May 21, 2025
ef32112
add Bracketing ChainRulesCoreExt
jClugstor May 21, 2025
1610a51
better error message to make sure problem constructor adjoints exist
jClugstor May 21, 2025
845ec7f
add weakdeps
jClugstor May 21, 2025
b8eb03f
add test
jClugstor May 21, 2025
a878e8e
fix test
jClugstor May 21, 2025
db469c1
use SciMLBase instead
jClugstor May 22, 2025
d52842f
use gradient, p might not be scalar
jClugstor May 22, 2025
ff43257
add zygote as trigger for chainrulescore extension
jClugstor May 22, 2025
634b1c3
account for both derivative and gradient
jClugstor May 22, 2025
f593cc4
old docstring
jClugstor May 22, 2025
53cd09a
add ForwardDiff trigger, more using
jClugstor May 22, 2025
bac45ad
get rid of unnecessary Zygote
jClugstor May 22, 2025
77b9e5a
fix adjoint test
jClugstor May 22, 2025
0ed1191
don't need diffeqbase ext stuff
jClugstor May 22, 2025
7d48d7d
load bracketing nonlinear solve in test
jClugstor May 22, 2025
f458c96
fix project.toml
jClugstor May 22, 2025
50ce860
add Zygote to test deps
jClugstor May 22, 2025
d70f9b3
test should use Bisection
jClugstor May 22, 2025
7f3db45
account for Thunks, non tangent types
jClugstor May 22, 2025
671d23a
fix test
jClugstor May 22, 2025
789e04b
make imports explicit, add ompat bounds
jClugstor May 22, 2025
d8b82af
Update lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChain…
ChrisRackauckas May 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion 8 lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"}

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"


[extensions]
BracketingNonlinearSolveForwardDiffExt = "ForwardDiff"
BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "ForwardDiff"]

[compat]
Aqua = "0.8.9"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.10.1"
Expand All @@ -34,6 +38,7 @@ SciMLBase = "2.69"
Test = "1.10"
TestItemRunner = "1"
julia = "1.10"
Zygote = "0.6.69, 0.7"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand All @@ -42,6 +47,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"]
test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner", "Zygote"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module BracketingNonlinearSolveChainRulesCoreExt

using BracketingNonlinearSolve: bracketingnonlinear_solve_up, CommonSolve, SciMLBase
using CommonSolve: solve
using SciMLBase: IntervalNonlinearProblem
using ForwardDiff
using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk

function ChainRulesCore.rrule(
::typeof(bracketingnonlinear_solve_up),
prob::IntervalNonlinearProblem,
sensealg, p, alg, args...; kwargs...
)
out = solve(prob, alg)
u = out.u
f = SciMLBase.unwrapped_f(prob.f)
function ∇bracketingnonlinear_solve_up(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
# Δ = dg/du
Δ isa Tangent ? delu = Δ.u : delu = Δ
λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ delu)
if p isa Number
dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), p)
else
dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p)
end
return (NoTangent(), NoTangent(), NoTangent(),
dgdp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
return out, ∇bracketingnonlinear_solve_up
end

end
12 changes: 12 additions & 0 deletions 12 lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ include("ridder.jl")
function CommonSolve.solve(prob::IntervalNonlinearProblem; kwargs...)
return CommonSolve.solve(prob, ITP(); kwargs...)
end

function CommonSolve.solve(prob::IntervalNonlinearProblem, nothing, args...; kwargs...)
return CommonSolve.solve(prob, ITP(), args...; kwargs...)
end

function CommonSolve.solve(prob::IntervalNonlinearProblem,
alg::AbstractBracketingAlgorithm, args...; sensealg = nothing, kwargs...)
return bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, prob.p, alg, args...; kwargs...)
end


function bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, p, alg, args...; kwargs...)
return SciMLBase.__solve(prob, alg, args...; kwargs...)
end


@setup_workload begin
for T in (Float32, Float64)
prob_brack = IntervalNonlinearProblem{false}(
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/alefeld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
"""
struct Alefeld <: AbstractBracketingAlgorithm end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::Alefeld, args...;
maxiters = 1000, abstol = nothing, kwargs...
)
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ A common bisection method.
exact_right::Bool = false
end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::Bisection, args...;
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
)
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Left non-allocating Brent method.
"""
struct Brent <: AbstractBracketingAlgorithm end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::Brent, args...;
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
)
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A non-allocating regula falsi method.
"""
struct Falsi <: AbstractBracketingAlgorithm end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::Falsi, args...;
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
)
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10)
return ITP(scaled_k1, k2, n0)
end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::ITP, args...;
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
)
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/muller.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

Muller() = Muller(nothing)

function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...;
function SciMLBase.__solve(prob::IntervalNonlinearProblem, alg::Muller, args...;
abstol = nothing, maxiters = 1000, kwargs...)
@assert !SciMLBase.isinplace(prob) "`Muller` only supports out-of-place problems."
xᵢ₋₂, xᵢ = prob.tspan
Expand Down
2 changes: 1 addition & 1 deletion 2 lib/BracketingNonlinearSolve/src/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A non-allocating ridder method.
"""
struct Ridder <: AbstractBracketingAlgorithm end

function CommonSolve.solve(
function SciMLBase.__solve(
prob::IntervalNonlinearProblem, alg::Ridder, args...;
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
)
Expand Down
18 changes: 18 additions & 0 deletions 18 lib/BracketingNonlinearSolve/test/adjoint_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
using ForwardDiff, Zygote, BracketingNonlinearSolve

ff(u, p) = u^2 .- p[1]

function solve_nlprob(p)
prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p)
sol = solve(prob, Bisection())
res = sol isa AbstractArray ? sol : sol.u
return sum(abs2, res)
end

p = [2.0, 2.0]

∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
@test ∂p_zygote ≈ ∂p_forwarddiff
end
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.