Improve compilation times #3295

Open
charleskawczynski opened this issue Sep 12, 2024 · 4 comments

@charleskawczynski
Member

charleskawczynski commented Sep 12, 2024

Compilation times are pretty long, and we should probably see what low-hanging fruit there is.

I used SnoopCompile to see what inference looks like when we run the driver (which took over 20 minutes to reach the end of the first call to step!):

# julia --project=examples

using SnoopCompileCore
# Record type-inference timing for everything compiled while the driver runs.
tinf = @snoop_inference begin
    empty!(ARGS)
    push!(ARGS, "--config_file", "config/model_configs/diagnostic_edmfx_trmm_stretched_box.yml")
    push!(ARGS, "--job_id", "diagnostic_edmfx_trmm_stretched_box")
    include("examples/hybrid/driver.jl")
end;

using SnoopCompile
# staleinstances(tinf) # quite a few invalidations
# using AbstractTrees
# print_tree(tinf) # prints way too much info

# Render the inference timing as a flame graph and view it interactively.
using FlameGraphs
fg = flamegraph(tinf)
using ProfileView
ProfileView.view(fg)
[Image: Screenshot 2024-09-12 at 1 11 31 PM]

The two large blocks are build_cache (left side), which mostly spends time in set_precomputed_quantities!, and step_u! (right side).

@charleskawczynski
Member Author

Looks like there's still room for improvement:

[Image: Screenshot 2024-09-12 at 7 49 09 PM]

@charleskawczynski
Member Author

charleskawczynski commented Sep 13, 2024

Looking around a little bit, I could see the compiler being uncertain about the types in this function:

function edmfx_sgs_diffusive_flux_tendency!(
    Yₜ,
    Y,
    p,
    t,
    turbconv_model::DiagnosticEDMFX,
)

    FT = Spaces.undertype(axes(Y.c))
    (; dt, params) = p
    turbconv_params = CAP.turbconv_params(params)
    c_d = CAP.tke_diss_coeff(turbconv_params)
    (; sfc_conditions) = p.precomputed
    (; ᶜu, ᶜh_tot, ᶜspecific, ᶜtke⁰, ᶜmixing_length) = p.precomputed
    (; ᶜK_u, ᶜK_h, ρatke_flux) = p.precomputed
    ᶠgradᵥ = Operators.GradientC2F()

    if p.atmos.edmfx_sgs_diffusive_flux
        ᶠρaK_h = p.scratch.ᶠtemp_scalar
        @. ᶠρaK_h = ᶠinterp(Y.c.ρ) * ᶠinterp(ᶜK_h)
        ᶠρaK_u = p.scratch.ᶠtemp_scalar
        @. ᶠρaK_u = ᶠinterp(Y.c.ρ) * ᶠinterp(ᶜK_u)

        # energy
        ᶜdivᵥ_ρe_tot = Operators.DivergenceF2C(
            top = Operators.SetValue(C3(FT(0))),
            bottom = Operators.SetValue(sfc_conditions.ρ_flux_h_tot),
        )
        @. Yₜ.c.ρe_tot -= ᶜdivᵥ_ρe_tot(-(ᶠρaK_h * ᶠgradᵥ(ᶜh_tot)))

        if use_prognostic_tke(turbconv_model)
            # turbulent transport (diffusive flux)
            # boundary condition for the diffusive flux
            ᶜdivᵥ_ρatke = Operators.DivergenceF2C(
                top = Operators.SetValue(C3(FT(0))),
                bottom = Operators.SetValue(ρatke_flux),
            )
            @. Yₜ.c.sgs⁰.ρatke -=
                ᶜdivᵥ_ρatke(-(ᶠρaK_u * ᶠgradᵥ(ᶜtke⁰))) +
                tke_dissipation(Y.c.sgs⁰.ρatke, ᶜtke⁰, ᶜmixing_length, c_d, dt)
        end

        if !(p.atmos.moisture_model isa DryModel)
            # specific humidity
            ᶜρχₜ_diffusion = p.scratch.ᶜtemp_scalar
            ᶜdivᵥ_ρq_tot = Operators.DivergenceF2C(
                top = Operators.SetValue(C3(FT(0))),
                bottom = Operators.SetValue(sfc_conditions.ρ_flux_q_tot),
            )
            @. ᶜρχₜ_diffusion =
                ᶜdivᵥ_ρq_tot(-(ᶠρaK_h * ᶠgradᵥ(ᶜspecific.q_tot)))
            @. Yₜ.c.ρq_tot -= ᶜρχₜ_diffusion
            @. Yₜ.c.ρ -= ᶜρχₜ_diffusion
        end

        # momentum
        ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
        compute_strain_rate_face!(ᶠstrain_rate, ᶜu)
        @. Yₜ.c.uₕ -= C12(ᶜdivᵥ(-(2 * ᶠρaK_u * ᶠstrain_rate)) / Y.c.ρ)
        # apply boundary condition for momentum flux
        ᶜdivᵥ_uₕ = Operators.DivergenceF2C(
            top = Operators.SetValue(C3(FT(0)) ⊗ C12(FT(0), FT(0))),
            bottom = Operators.SetValue(sfc_conditions.ρ_flux_uₕ),
        )
        @. Yₜ.c.uₕ -= ᶜdivᵥ_uₕ(-(FT(0) * ᶠgradᵥ(Y.c.uₕ))) / Y.c.ρ
    end

    # TODO: Add tracer flux

    return nothing
end

For example, what is the type of ᶠρaK_h in @. ᶜρχₜ_diffusion = ᶜdivᵥ_ρq_tot(-(ᶠρaK_h * ᶠgradᵥ(ᶜspecific.q_tot)))? That depends on p.atmos.edmfx_sgs_diffusive_flux: it is either undefined or has the same type as p.scratch.ᶠtemp_scalar. Is edmfx_sgs_diffusive_flux a compile-time parameter? No, so the compiler does not know the type of ᶠρaK_h at compile time, and this could badly hurt inference.
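A suspicion like this can be checked by asking the compiler what it infers. The snippet below is only a toy sketch, not ClimaAtmos code: `pick_scratch` is a hypothetical stand-in for a value whose type depends on a runtime Bool, and it uses nothing beyond Base.

```julia
# Toy stand-in (hypothetical, not from ClimaAtmos): the result type depends
# on a runtime Bool, so inference can only narrow it to a Union.
pick_scratch(flag::Bool) = flag ? zeros(Float64, 3) : nothing

# Ask inference what `pick_scratch` returns when all it knows is `flag::Bool`.
inferred = only(Base.return_types(pick_scratch, (Bool,)))

println(inferred)                  # a Union of the two branch types
println(isconcretetype(inferred))  # false: not a single concrete type
```

On the real function, `@code_warntype` applied to a call would be the analogous check, though the output for a function of that size may be long.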

I think one thing we can do is convert the runtime bools into types, so that this is better inferred.
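A minimal sketch of what that could look like, under assumptions: the names `SGSDiffusiveFlux`, `NoSGSDiffusiveFlux`, `ToyAtmos`, `sgs_flux_mode`, and `toy_tendency!` are hypothetical and are not ClimaAtmos API. The Bool is converted to a singleton type once at setup, and the tendency function then dispatches on that type instead of branching at runtime.

```julia
# Hypothetical singleton types replacing a runtime Bool flag.
abstract type AbstractSGSDiffusiveFlux end
struct SGSDiffusiveFlux <: AbstractSGSDiffusiveFlux end
struct NoSGSDiffusiveFlux <: AbstractSGSDiffusiveFlux end

# The flag's type becomes part of the container's type.
struct ToyAtmos{F <: AbstractSGSDiffusiveFlux}
    sgs_diffusive_flux::F
end

# Convert a Bool (e.g. parsed from a config file) once, at setup time.
# This one call is not type-stable, but it happens outside the hot path.
sgs_flux_mode(flag::Bool) = flag ? SGSDiffusiveFlux() : NoSGSDiffusiveFlux()

# Forward to one method per mode; the choice is resolved by dispatch.
toy_tendency!(Yₜ, Y, atmos::ToyAtmos) =
    toy_tendency!(Yₜ, Y, atmos.sgs_diffusive_flux)
toy_tendency!(Yₜ, Y, ::NoSGSDiffusiveFlux) = nothing
function toy_tendency!(Yₜ, Y, ::SGSDiffusiveFlux)
    @. Yₜ -= 2 * Y  # placeholder for the real flux computation
    return nothing
end

Y = [1.0, 2.0]
Yₜ = zeros(2)
toy_tendency!(Yₜ, Y, ToyAtmos(sgs_flux_mode(true)))
println(Yₜ)
```

This mirrors the existing `p.atmos.moisture_model isa DryModel` check in the function above, where the model choice is already encoded in a type. Whether it actually reduces compile time here would need to be measured.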

@Sbozzolo
Member

Wow, the big orange rectangle disappearing with a three line change!

@charleskawczynski
Member Author

> Wow, the big orange rectangle disappearing with a three line change!

Yeah, but it's a bit suspicious to me. I don't know why the other large platforms appeared. The results could be non-deterministic 🙁.
