For educational purposes, I have been implementing an RNN by hand and wanted to be fancy and use accumulate instead of recursion or a for loop. But I run into an error when one of the operands in accumulate is a tuple.
I have carved out an MWE, which looks like this:
using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

function f(α, h, x)
    o = accumulate(x, init = h) do h, x
        α * h + x
    end
end

function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]
While computing the gradient of f succeeds, computing the gradient of g crashes with
Zygote is constructing tangents that enter the decumulate pullback via wrap_chainrules_output. In this case it's hitting the method for Union{Tuple,NamedTuple}, which is interesting, because I think it should be using the method for Tuple.
I think this could be fixed by making sure wrap_chainrules_output returns a StructuralTangent... or at least if in Zygote I do:
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
    xp = map(wrap_chainrules_input, dxs)
    # This produces Tangent{Any} since it does not get to see the primal, `x`.
    # ChainRulesCore.Tangent{Any, typeof(xp)}(xp) -- comment this out and replace by line below
    ChainRulesCore.StructuralTangent{typeof(xp)}(xp)
end
Not certain this is relevant, but notice the similarity to this:
julia> accumulate(=>, (1,2,3))
(1, 1 => 2, (1 => 2) => 3)

julia> accumulate(=>, [1,2,3])
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Pair{Int64, Int64}
and that this gradient works with x::Tuple:
julia> gradient(α -> sum(sum(g(α, h, Tuple(x)))), 1f0)[1]
15.059713f0

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]  # with x::Vector as above
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})
Thanks for your help.