Order nested Dual tags by provable containment (supersedes #807) #820
Draft
ChrisRackauckas-Claude wants to merge 2 commits into
Fixes SciML/OrdinaryDiffEq.jl#3381 and supersedes SciML/OrdinaryDiffEq.jl#3587. Also fixes NonlinearSolve.jl master and supersedes SciML/NonlinearSolve.jl#932. Supersedes JuliaDiff#724 and is a better solution to JuliaDiff#714.

The crux of the issue is that ForwardDiff.jl's tagging system is designed around each tag being used only once: the function is created, the derivative function is called, and the tag for that derivative is set as a type of the function being differentiated, so it is unique. This works with nested differentiation because you usually call the inner function before the outer function, or only ever call the combination, so the tag ordering is set correctly. Mixing tagging with precompilation breaks this, because the outer tag can be precompiled before the inner tag. That inverts the tag ordering, and the type promotion mechanism, which is tied to the tag ordering, then gets confused. This seems pretty fundamental: the ordering is a useful property, the core property used to prevent perturbation confusion, but it means the interaction between nested differentiation and precompilation produces odd bugs.

I tried working around this downstream (SciML/OrdinaryDiffEq.jl#3587) but it was very nasty. Basically, you had to make sure dual numbers never automatically converted Float64s, since the conversion could sometimes target the inner type instead of the outer type, and it wouldn't do the normal conversion of first to the inner type and then to the wrapping outer type, because doing so required the outer type to postdate the inner type.

But this shows that the bug only manifests with nested types. And if you have nested types, you know you don't have perturbation confusion if one tag is nested deeper than the other tag, because there are not the same number of partials. So when the tag depths differ, you can use an alternative tag ordering, since you have already proven the perturbations can't be confused. In that case, the more deeply nested tag is always chosen to be `<` the less deeply nested tag. With that added, tag nesting works in these cases with precompilation. I think this captures the true crux of the problem and solves it at its core.
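The ordering hazard described above can be sketched with a toy counter. This is a minimal illustration, assuming a count that is assigned when a `@generated` method is first compiled for a type; `TOY_COUNT`, `toy_tagcount`, and the `Tag*` marker types are hypothetical names, not ForwardDiff's API. Julia's manual asks generator bodies to avoid observing or mutating global state, and the sketch deliberately breaks that rule to show what happens when a count depends on compilation order.

```julia
# Toy model only: these names are illustrative, not ForwardDiff's API.
const TOY_COUNT = Threads.Atomic{UInt}(0)

# The generator body runs when a method instance is first compiled for `T`.
# The counter value it reads is spliced into that method as a literal.
@generated function toy_tagcount(::Type{T}) where {T}
    n = Threads.atomic_add!(TOY_COUNT, UInt(1))  # returns the previous value
    return :($n)
end

# Hypothetical marker types standing in for two pairs of tags.
struct TagA end
struct TagB end
struct TagC end
struct TagD end

# First pair: A is compiled before B, so A receives the smaller count.
count_a = toy_tagcount(TagA)
count_b = toy_tagcount(TagB)

# Second pair: D is compiled before C, so the relative order flips.
count_d = toy_tagcount(TagD)
count_c = toy_tagcount(TagC)

println("A before B by count: ", count_a < count_b)
println("C before D by count: ", count_c < count_d)
```

Nothing about the types themselves decides the result here; only the order in which each method instance happened to be compiled does, which is why an ordering built purely on such counts can come out inverted.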
Replace the depth-only fast path with a containment proof: a tag T1 is ordered outside a tag T2 only when T2 provably appears inside T1's type structure, either in the seeded value type V or captured by the function type F (closure/struct fields and type parameters). When neither tag nests the other, ordering falls back to tagcount exactly as before.

This fixes both counterexamples found in review: the depth-only path regressed 3-level nesting with a constant-seeded innermost derivative (depth understated when nesting lives in F), and missed the same-depth case where both tags have V === Float64 and the outer perturbation enters through a closure capture.

The containment walk runs entirely at compile time via a @generated function, so the steady-state cost of ≺ is unchanged.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #820      +/-   ##
==========================================
- Coverage   90.74%   90.65%   -0.09%
==========================================
  Files          11       11
  Lines        1070     1092      +22
==========================================
+ Hits          971      990      +19
- Misses         99      102       +3
Member
I'm on vacation and will think this through more carefully when I'm back at the end of July.
This is the updated version of #807, superseding it: same goal (fix #714's precompilation-inverted tag ordering, and with it SciML/OrdinaryDiffEq.jl#3381 and SciML/NonlinearSolve.jl#932), but with the depth-only fast path replaced by the containment proof discussed in #807's review thread. Opened as a separate PR only because #807's head branch can't be pushed to from this account; @ChrisRackauckas can instead pull this into #807's branch to keep the discussion in one place.
Note: please treat this as pending until reviewed by @ChrisRackauckas.
The problem
`tagcount` is a `@generated` function whose literal is baked at first compile, so precompilation can bake an outer and an inner tag's counts in inverted order (#714). Tag ordering (`≺`) is a pure `tagcount` comparison, so nested differentiation then composes Duals in the wrong nesting order: the stochastic `DualMismatchError` / `MethodError: no method matching Float64(::Dual{...})` class of failures seen downstream.

The fix: order by provable containment, fall back to tagcount
#807's original depth-only fast path had two confirmed counterexamples (thanks @devmotion, confirmed by @riftsim-jarod): `tagdepth` read nesting off `V` only, so it both regressed a 3-level case where nesting lives in the closure type `F`, and missed the same-depth case where both tags have `V === Float64`.

This version replaces the depth heuristic with a containment proof.
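As a rough illustration of what a containment check over a tag's type structure can look like, here is a minimal sketch. It is not the PR's implementation: `ToyTag`, `ToyDual`, `appears_in`, `toy_precedes`, and `make_closure` are hypothetical names, the recursion runs at runtime rather than inside a `@generated` function, and the counts are passed in as a plain dictionary. It assumes the convention that the containing tag composes outside the tag it contains.

```julia
# Toy stand-ins; these are NOT ForwardDiff's types.
struct ToyTag{F,V} end
struct ToyDual{T,V}
    value::V
    partial::V
end

# True when `needle` provably occurs in `haystack`'s type parameters or
# concrete field types. Returns false at depth 32 or on an already visited
# type, so the caller falls back to counts instead of guessing.
function appears_in(haystack, needle, depth::Int = 0, seen::Set{Any} = Set{Any}())
    haystack === needle && return true
    depth >= 32 && return false
    if haystack isa Union
        return appears_in(haystack.a, needle, depth + 1, seen) ||
               appears_in(haystack.b, needle, depth + 1, seen)
    elseif haystack isa UnionAll
        return appears_in(Base.unwrap_unionall(haystack), needle, depth + 1, seen)
    end
    haystack isa DataType || return false
    haystack in seen && return false           # cycle guard
    push!(seen, haystack)
    for p in haystack.parameters               # type parameters, e.g. F and V
        p isa Type && appears_in(p, needle, depth + 1, seen) && return true
    end
    if isconcretetype(haystack)                # fields, e.g. closure captures
        for ft in fieldtypes(haystack)
            appears_in(ft, needle, depth + 1, seen) && return true
        end
    end
    return false
end

# Containment decides when it is provable; otherwise compare the counts.
function toy_precedes(T1, T2, counts::Dict)
    T1 === T2 && return false
    appears_in(T2, T1) && return true          # T2 contains T1: T1 is inside
    appears_in(T1, T2) && return false         # T1 contains T2: T1 is outside
    return counts[T1] < counts[T2]
end

# A closure capturing a ToyDual tagged with TagA; TagB is built from its type,
# with a plain Float64 value type, so the nesting is visible only through F.
make_closure(x) = y -> x.value * y
TagA = ToyTag{typeof(sin),Float64}
captured = ToyDual{TagA,Float64}(0.5, 1.0)
TagB = ToyTag{typeof(make_closure(captured)),Float64}

counts = Dict{Any,Int}(TagA => 2, TagB => 1)   # deliberately inverted counts
println(toy_precedes(TagA, TagB, counts))
```

In this sketch both tags have `Float64` as their value type, so a check that looked only at `V` would see nothing; the relationship is found through the closure's captured field, and the inverted counts are never consulted.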
`containstag(T1, T2)` walks `T1`'s type structure at compile time (the seeded value type `V`, plus the function type `F`'s type parameters and concrete field types, i.e. closure captures) and checks whether `T2` genuinely appears inside it. If it does, `T1` was necessarily created while `T2`'s derivative was already in progress, so `Dual{T1}` must compose outside `Dual{T2}`, regardless of what `tagcount` says.

When neither tag provably nests the other, behavior is exactly master's. The walk is a `@generated` function (cycle-guarded, depth-bounded at 32), so `≺` still folds to a constant at compile time. This is the "prove that one tag genuinely nests the other" safety check from the #807 thread, extended through `F` as well as `V`, which is what also fixes the same-`V` case (the nesting there is only visible through the callable's fields).

Verification (all run locally)
12.0: #807's regression gone
cos(0.5) - 0.5·sin(0.5): fixed
`AutoFiniteDiff` reference (9 digits); FD Jacobian vs analytic: 3e-8

Regression tests added to `test/ConfusionTest.jl`, covering containment ordering under explicitly inverted tagcounts in both shapes (`V`-nesting and `F`-capture) plus the 3-level constant-seed case. The `F`-capture test throws `DualMismatchError` on master and passes here.

Two caveats for review:
`containstag` only unwraps `Union`/`UnionAll` and walks concrete field types; pathological/truncated cases fall back to tagcount (master behavior) rather than guessing.
The "Julia pre" CI failures are pre-existing: all six pre jobs fail identically on the latest master push run (JET `@test_opt` in `test/QATest.jl` flagging Julia 1.13-rc1 Base broadcast-inference regressions; reproduced locally on unmodified master).

🤖 Generated with Claude Code