Skip to content

Commit

Permalink
Avoid type promotions due to hard-coded tolerances (#438)
Browse files Browse the repository at this point in the history
* Avoid type promotions due to hard-coded tolerances

...in the `AlefieldPotraShi` solver.

* Fix a unit handling problem
  • Loading branch information
jmert authored Aug 26, 2024
1 parent 5794200 commit 53670e8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/Bracketing/alefeld_potra_shi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ function update_state(
options,
l=NullTracks(),
) where {T,S}
μ, λ = 0.5, 0.7
atol, rtol = options.xabstol, options.xreltol
μ, λ = oftype(rtol, 0.5), oftype(rtol, 0.7)
tols = (; λ=λ, atol=atol, rtol=rtol)

a::T, b::T, d::T, ee::T = o.xn0, o.xn1, o.d, o.ee
Expand Down Expand Up @@ -187,7 +187,7 @@ struct A2425{K} <: AbstractAlefeldPotraShi end
function calculateΔ(::A2425{K}, F::Callable_Function, c₀::T, ps) where {K,T}
a, b, d, ee = ps.a, ps.b, ps.d, ps.ee
fa, fb, fd, fee = ps.fa, ps.fb, ps.fd, ps.fee
tols ==0.7, atol=ps.atol, rtol=ps.rtol)
tols ==oftype(ps.rtol, 0.7), atol=ps.atol, rtol=ps.rtol)

c = a
for k in 1:K
Expand Down Expand Up @@ -236,7 +236,7 @@ fncalls_per_step(::A57{K}) where {K} = K - 1
function calculateΔ(::A57{K}, F::Callable_Function, c₀::T, ps) where {K,T}
a, b, d, ee = ps.a, ps.b, ps.d, ps.ee
fa, fb, fd, fee = ps.fa, ps.fb, ps.fd, ps.fee
tols ==0.7, atol=ps.atol, rtol=ps.rtol)
tols ==oftype(ps.rtol, 0.7), atol=ps.atol, rtol=ps.rtol)
c, fc = a, fa

for k in 1:K
Expand Down
2 changes: 1 addition & 1 deletion src/convergence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ function is_small_Δx(
state::AbstractUnivariateZeroState,
options,
)
δ = abs(state.xn1 - state.xn0)
δ = _unitless(abs(state.xn1 - state.xn0))
δₐ, δᵣ = options.xabstol, options.xreltol
Δₓ = max(_unitless(δₐ), _unitless(abs(state.xn1)) * δᵣ)
Δₓ = sqrt(sqrt(sqrt((abs(_unitless(Δₓ)))))) # faster than x^(1/8)
Expand Down
8 changes: 7 additions & 1 deletion test/test_composable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ using ForwardDiff
@testset "find zero(s) with Unitful" begin
s = u"s"
m = u"m"
g = 9.8 * m / s^2
g = (9 + 8//10) * m / s^2
v0 = 10m / s
y0 = 16m
y(t) = -g * t^2 + v0 * t + y0

for order in orders
@test find_zero(y, 1.8s, order) 1.886053370668014s
@test find_zero(y, 1.8f0s, order) isa typeof(1.88f0s)
end

for M in [Roots.Bisection(), Roots.A42(), Roots.AlefeldPotraShi()]
@test find_zero(y, (1.8s, 1.9s), M) 1.886053370668014s
@test find_zero(y, (1.8f0s, 1.9f0s), M) isa typeof(1.88f0s)
end

xrts = find_zeros(y, 0s, 10s)
Expand All @@ -44,6 +46,10 @@ using ForwardDiff
# issue #434
xzs1 = find_zeros(x -> cos(x / 1u"m"), -1.6u"m", 2u"m")
@test length(xzs1) == 2 && maximum(xzs1) 1.5707963267948966 * u"m"

FX = ZeroProblem(y, (0f0s, 2f0s))
prob = Roots.init(FX, Roots.AlefeldPotraShi())
@test Roots.is_small_Δx(prob.M, prob.state, prob.options) isa Bool # does not throw
end

# Polynomials
Expand Down

0 comments on commit 53670e8

Please sign in to comment.