From 81fb59acbe7b043a06cbb80ad1ef5f72a2d04338 Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Thu, 2 Nov 2023 18:30:41 +0100 Subject: [PATCH 1/9] Change typeof(x) <: y to x isa y --- gen/generate.jl | 6 +++--- src/common_interface/function_types.jl | 2 +- src/common_interface/integrator_types.jl | 4 ++-- src/common_interface/integrator_utils.jl | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gen/generate.jl b/gen/generate.jl index ba291bc1..9611a747 100644 --- a/gen/generate.jl +++ b/gen/generate.jl @@ -86,7 +86,7 @@ function wrap_sundials_api(expr::Expr) if occursin(r"UserDataB?$", func_name) # replace Ptr{Void} with Any to allow passing Julia objects through user data for (i, arg_expr) in enumerate(expr.args[2].args[1].args) - if !(typeof(arg_expr) <: Symbol) && + if !(arg_expr isa Symbol) && arg_expr.args[1] in values(ctor_return_type) if arg_expr.args[2] == :(Ptr{Cvoid}) arg_expr.args[2] = Any @@ -96,7 +96,7 @@ function wrap_sundials_api(expr::Expr) end end end - if !(typeof(expr) <: Symbol) && length(expr.args) > 1 && + if !(expr isa Symbol) && length(expr.args) > 1 && (expr.args[2].args[1].args[2].args[2] == :libsundials_sunlinsol || expr.args[2].args[1].args[2].args[2] == :libsundials_sunmatrix || expr.args[2].args[1].args[2].args[2] == :libsundials_sunnonlinsol) @@ -124,7 +124,7 @@ function wrap_sundials_api(expr::Expr) # 2) expr for local var definition, nothing if not required # 3) expr for low-level wrapper call # if 1)==3), then no wrapping is required - if typeof(expr) <: Symbol + if expr isa Symbol arg_name_expr = expr arg_type_expr = Any else diff --git a/src/common_interface/function_types.jl b/src/common_interface/function_types.jl index c6e8e353..d824b90f 100644 --- a/src/common_interface/function_types.jl +++ b/src/common_interface/function_types.jl @@ -146,7 +146,7 @@ function massmat(t::Float64, tmp1::N_Vector, tmp2::N_Vector, tmp3::N_Vector) - if typeof(mmf.mass_matrix) <: Array + if mmf.mass_matrix isa Array M = convert(Matrix, _M) M .= mmf.mass_matrix else diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index af43121d..6be8ca72 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -220,7 +220,7 @@ DiffEqBase.postamble!(integrator::AbstractSundialsIntegrator) = nothing tstop = first(integrator.opts.tstops) set_stop_time(integrator, tstop) integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) + if !(integrator.opts.callback.continuous_callbacks isa Tuple{}) integrator.uprev .= integrator.u end solver_step(integrator, tstop) @@ -231,7 +231,7 @@ DiffEqBase.postamble!(integrator::AbstractSundialsIntegrator) = nothing end else integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) + if !(integrator.opts.callback.continuous_callbacks isa Tuple{}) integrator.uprev .= integrator.u end if !isempty(integrator.opts.tstops) diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index 505255c8..1f2d077b 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -6,7 +6,7 @@ function handle_callbacks!(integrator) continuous_modified = false discrete_modified = false saved_in_cb = false - if !(typeof(continuous_callbacks) <: Tuple{}) + if !(continuous_callbacks isa Tuple{}) time, upcrossing, event_occured, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator, continuous_callbacks...) if event_occured @@ -22,7 +22,7 @@ function handle_callbacks!(integrator) integrator.vector_event_last_time = 1 end end - if !(typeof(discrete_callbacks) <: Tuple{}) + if !(discrete_callbacks isa Tuple{}) discrete_modified, saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator, discrete_callbacks...) end From 7eb8a120aed3b18d62208437026892c0964e69ae Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 15:11:56 -0500 Subject: [PATCH 2/9] `finalize` callbacks see also https://github.com/SciML/OrdinaryDiffEq.jl/pull/2061 --- src/common_interface/solve.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index c1786400..6f51b8e6 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -1441,6 +1441,7 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = handle_tstop!(integrator) end + DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) tend = integrator.t if integrator.opts.save_end && (isempty(integrator.sol.t) || integrator.sol.t[end] != tend) From 242dccec3541590535f8ed14ef670f66ecfc12d6 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 15:13:26 -0500 Subject: [PATCH 3/9] add test --- test/common_interface/callbacks.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/common_interface/callbacks.jl b/test/common_interface/callbacks.jl index 4f080cab..4f5c9939 100644 --- a/test/common_interface/callbacks.jl +++ b/test/common_interface/callbacks.jl @@ -54,9 +54,11 @@ function condition2(u, t, integrator) get_du(integrator)[1] > 0 end affect2!(integrator) = terminate!(integrator) -cb = DiscreteCallback(condition2, affect2!) +times_finalize_called = 0 +cb = DiscreteCallback(condition2, affect2!, finalize=(args...)->times_finalize_called+=1) sol = solve(prob, CVODE_BDF(); callback = cb) @test sol.t[end] < 3.5 +@test times_finalize_called = 1 condition3(u, t, integrator) = u[2] affect3!(integrator) = terminate!(integrator) From bdf3f7f3ac690bee3ae7bf6ac855e42f9479f5d9 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 15:43:38 -0500 Subject: [PATCH 4/9] typo --- test/common_interface/callbacks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/common_interface/callbacks.jl b/test/common_interface/callbacks.jl index 4f5c9939..ff540e8c 100644 --- a/test/common_interface/callbacks.jl +++ b/test/common_interface/callbacks.jl @@ -55,7 +55,7 @@ function condition2(u, t, integrator) end affect2!(integrator) = terminate!(integrator) times_finalize_called = 0 -cb = DiscreteCallback(condition2, affect2!, finalize=(args...)->times_finalize_called+=1) +cb = DiscreteCallback(condition2, affect2!, finalize=(args...)->global times_finalize_called+=1) sol = solve(prob, CVODE_BDF(); callback = cb) @test sol.t[end] < 3.5 @test times_finalize_called = 1 From c385febde7df4cc8a9ec0519cb8c6d1e343543da Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 15:53:51 -0500 Subject: [PATCH 5/9] typo 2 --- test/common_interface/callbacks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/common_interface/callbacks.jl b/test/common_interface/callbacks.jl index ff540e8c..a2e6a378 100644 --- a/test/common_interface/callbacks.jl +++ b/test/common_interface/callbacks.jl @@ -58,7 +58,7 @@ times_finalize_called = 0 cb = DiscreteCallback(condition2, affect2!, finalize=(args...)->global times_finalize_called+=1) sol = solve(prob, CVODE_BDF(); callback = cb) @test sol.t[end] < 3.5 -@test times_finalize_called = 1 +@test times_finalize_called == 1 condition3(u, t, integrator) = u[2] affect3!(integrator) = terminate!(integrator) From 70654b6d9fa482559ce2b7cf2497e4029215d292 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 20 Nov 2023 17:45:08 -0500 Subject: [PATCH 6/9] make the saveat and tstop handling match OrdinaryDiffEq more closely and remove bug where we would pretend to hit tstops when not actually doing so --- src/common_interface/integrator_utils.jl | 6 +-- src/common_interface/solve.jl | 53 ++++++++---------------- 2 files changed, 20 insertions(+), 39 deletions(-) diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index 1f2d077b..91454034 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -46,7 +46,7 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, uType = typeof(integrator.sol.prob.u0) # The call to first is an overload of Base.first implemented in DataStructures while !isempty(integrator.opts.saveat) && - integrator.tdir * first(integrator.opts.saveat) < integrator.tdir * integrator.t + first(integrator.opts.saveat) <= integrator.tdir * integrator.t saved = true curt = pop!(integrator.opts.saveat) @@ -123,13 +123,13 @@ end function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator, t) integrator.tdir * (t - integrator.t) < 0 && error("Tried to add a tstop that is behind the current time. This is strictly forbidden") - push!(integrator.opts.tstops, t) + push!(integrator.opts.tstops, integrator.tdir * t) end function DiffEqBase.add_saveat!(integrator::AbstractSundialsIntegrator, t) integrator.tdir * (t - integrator.t) < 0 && error("Tried to add a saveat that is behind the current time. This is strictly forbidden") - push!(integrator.opts.saveat, t) + push!(integrator.opts.saveat, integrator.tdir * t) end DiffEqBase.get_tmp_cache(integrator::AbstractSundialsIntegrator) = (integrator.tmp,) diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 6f51b8e6..4e8580f5 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -936,42 +936,25 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i end # function solve function tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) - if isempty(tstops) # TODO: Specialize more - tstops_vec = [tspan[2]] - else - tstops_vec = vec(collect(tType, - Iterators.filter(x -> tdir * tspan[1] < tdir * x ≤ - tdir * tspan[end], - Iterators.flatten((tstops, tspan[end]))))) - end + tstops_internal = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward()) + saveat_internal = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward()) - if tdir > 0 - tstops_internal = DataStructures.BinaryMinHeap(tstops_vec) - else - tstops_internal = DataStructures.BinaryMaxHeap(tstops_vec) + t0, tf = tspan + tdir_t0 = tdir * t0 + tdir_tf = tdir * tf + + for t in tstops + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) end + push!(tstops_internal, tdir_tf) if saveat isa Number - if (tspan[1]:saveat:tspan[end])[end] == tspan[end] - saveat_vec = convert(Vector{tType}, - collect(tType, (tspan[1] + saveat):saveat:tspan[end])) - else - saveat_vec = convert(Vector{tType}, - collect(tType, - (tspan[1] + saveat):saveat:(tspan[end] - saveat))) - end - elseif isempty(saveat) - saveat_vec = saveat - else - saveat_vec = vec(collect(tType, - Iterators.filter(x -> tdir * tspan[1] < tdir * x < - tdir * tspan[end], saveat))) + saveat = (t0:tdir*abs(saveat):tf)[2:end] end - - if tdir > 0 - saveat_internal = DataStructures.BinaryMinHeap(saveat_vec) - else - saveat_internal = DataStructures.BinaryMaxHeap(saveat_vec) + for t in saveat + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) end tstops_internal, saveat_internal @@ -1409,11 +1392,9 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = uType = eltype(integrator.sol.u) iters = Ref(Clong(-1)) while !isempty(integrator.opts.tstops) - # Sundials can have floating point issues approaching a tstop if - # there is a modifying event each # The call to first is an overload of Base.first implemented in DataStructures - while integrator.tdir * (integrator.t - first(integrator.opts.tstops)) < -1e6eps() - tstop = first(integrator.opts.tstops) + while integrator.tdir * integrator.t < first(integrator.opts.tstops) + tstop = integrator.tdir * first(integrator.opts.tstops) set_stop_time(integrator, tstop) integrator.tprev = integrator.t if !(integrator.opts.callback.continuous_callbacks isa Tuple{}) @@ -1490,7 +1471,7 @@ end function handle_tstop!(integrator::AbstractSundialsIntegrator) tstops = integrator.opts.tstops if !isempty(tstops) - if integrator.tdir * (integrator.t - first(integrator.opts.tstops)) > -1e6eps() + if integrator.tdir * integrator.t < first(integrator.opts.tstops) pop!(tstops) t = integrator.t integrator.just_hit_tstop = true From 7dac95dd96eedb9610c06ba8c53b1d5994f156b2 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 21 Nov 2023 10:39:03 -0500 Subject: [PATCH 7/9] fix --- src/common_interface/integrator_utils.jl | 3 +-- src/common_interface/solve.jl | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index 91454034..87642bfe 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -48,8 +48,7 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, while !isempty(integrator.opts.saveat) && first(integrator.opts.saveat) <= integrator.tdir * integrator.t saved = true - curt = pop!(integrator.opts.saveat) - + curt = integrator.tdir * pop!(integrator.opts.saveat) tmp = integrator(curt) save_value!(integrator.sol.u, tmp, uType, integrator.opts.save_idxs, false) diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 4e8580f5..e1adc7d8 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -1470,12 +1470,13 @@ end function handle_tstop!(integrator::AbstractSundialsIntegrator) tstops = integrator.opts.tstops - if !isempty(tstops) - if integrator.tdir * integrator.t < first(integrator.opts.tstops) + if !isempty(tstops) && integrator.tdir * integrator.t >= first(tstops) + pop!(tstops) + # If we passed multiple tstops at once (possible if Sundials ignores us or we had redundant tstops) + while !isempty(tstops) && integrator.tdir * integrator.t >= first(tstops) pop!(tstops) - t = integrator.t - integrator.just_hit_tstop = true end + integrator.just_hit_tstop = true end end From dad5580d55f1a30f21c84483a1a40120902e8cfc Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 10 Dec 2023 09:26:51 -0500 Subject: [PATCH 8/9] Try bumping lower bounds and see if it's all working --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 92855280..c245a498 100644 --- a/Project.toml +++ b/Project.toml @@ -17,12 +17,12 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164" [compat] -CEnum = "0.2, 0.3, 0.4, 0.5" +CEnum = "0.5" DataStructures = "0.18" DiffEqBase = "6.122" PrecompileTools = "1" -Reexport = "0.2, 1.0" -SciMLBase = "1.92, 2" +Reexport = "1.0" +SciMLBase = "2" Sundials_jll = "5.2" julia = "1.6" From fc4326937183f62c794385126b19f9c184428429 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 10 Dec 2023 09:38:49 -0500 Subject: [PATCH 9/9] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c245a498..2aa0b977 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Sundials" uuid = "c3572dad-4567-51f8-b174-8c6c989267f4" authors = ["Chris Rackauckas "] -version = "4.20.1" +version = "4.21.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"