-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Relax linear indexing requirement _slightly_ #216
base: master
Are you sure you want to change the base?
Relax linear indexing requirement _slightly_ #216
Conversation
Isn't it much faster to use a |
Well, currently you end up with every Personally I'd rather take slightly slower AD with ReverseDiff than AD with ReverseDiff that completely breaks the expectation of the user and functionality. |
And, in the |
Codecov ReportBase: 84.48% // Head: 84.51% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #216 +/- ##
==========================================
+ Coverage 84.48% 84.51% +0.03%
==========================================
Files 18 18
Lines 1921 1925 +4
==========================================
+ Hits 1623 1627 +4
Misses 298 298
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
I wonder if one could just define function TrackedArray(sol::ODESolution)
ODESolution(TrackedArray(sol.u), sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, sol.interp, sol.dense, sol.tslocation, sol.destats, sol.alg_choice, sol.retcode)
end (possibly one has to handle |
Something like that might be possible? EDIT: Just function ReverseDiff.track(::DiffEqBase.ODESolution, tp::Vector{ReverseDiff.AbstractInstruction}=ReverseDiff.InstructionTape())
DiffEqBase.ODESolution(
ReverseDiff.track(sol.u, tp), # But this won't work because `sol.u` is a `Vector{<:Vector}`.
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.destats,
sol.alg_choice,
sol.retcode
)
end |
But regardless of the discussion related to |
I'm still not sure if disabling the check should be called a feature. I think it would be great though if ReverseDiff would suppport From a practical perspective, if you would want to implement BTW could implementing the 3-arg outer constructor (with the derivative information and tape) fix the ODESolution issue? |
Currently
TrackedArray
requires the input to satisfyIndexStyle(x) === IndexLinear()
since ReverseDiff currently only has the capability of tracking, well, arrays supporting linear indexing.But supporting linear indexing and having
IndexStyle(x) === IndexLinear()
are, IIUC, two different things: you can support linear indexing while still havingIndexStyle(x) === IndexCartesian()
, i.e. linear indexing is not the most efficient indexing.For example,
DifferentialEquations.DESolution
supports linear indexing but hasIndexStyle(x) === IndexCartesian()
.Currently, this means that DiffEq has to hack around this constraint by converting into a
Matrix
, completely losing all the information related to theDESolution
.This PR adds a method
supports_linear_indexing
which gives arrays such asDESolution
a way to tell ReverseDiff that it supports linear indexing even though it's not maybe the most efficient way to index in the array.I honestly don't know 100% if this is the way to go, but it seems to do the trick locally (and seem to compute the correct gradients) so figured I'd make a PR to maybe at least get a discussion going.