Avoid device-to-host copy in ∇getindex!
#801
Can this package extension live in GPUArrays or one of the other packages this is an extension on? It's our general policy not to have rules that are applicable to specific packages in ChainRules.jl, but rather to have rules only for Base and the standard libraries. This case is a little less clear since it overloading |
Would it be fine to remove |
Is the problem that it uses other packages and adds them as weak dependencies? |
That is a fair point, we did apparently already make an exception for |
Could this Gordian knot be cut by creating an API in CRC which both ChainRules and the GPU array packages could use and overload, respectively? Maybe something adjacent to the current accumulation functionality ( |
Since this is basically BTW, I'm not entirely sure if it is OK that an extension for |
This seems to be a major performance improvement. Could it be resolved and merged soon? |
Force-pushed from 368e211 to e23e8d1.
Implement a kernel for accumulation in `∇getindex!` for generic index types.
Fall back to the old implementation if the KernelAbstractions backend does not support atomics (currently this is Metal.jl).
This is implemented as an extension which is triggered by `Atomix`, `KernelAbstractions`, and `GPUArrays` (which every GPU backend has as dependencies). This way we won't need to manually specify every GPU backend.
Closes #800.
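For readers unfamiliar with the approach, here is a minimal sketch of an atomic accumulation kernel of the kind described above. It is not the code from this PR: the names `scatter_add_kernel!` and `scatter_add!` are hypothetical, only the one-dimensional integer-index case is handled, and it assumes `KernelAbstractions` and `Atomix` are installed. It runs on the CPU backend, so no GPU is needed to try it.

```julia
using KernelAbstractions   # assumed installed; provides @kernel, @index, @Const
using Atomix               # assumed installed; provides Atomix.@atomic

# Hypothetical kernel (not the PR's): one work-item per gradient entry.
@kernel function scatter_add_kernel!(dx, @Const(dy), @Const(idx))
    i = @index(Global, Linear)
    # Atomic update, so duplicate indices accumulate instead of overwriting
    # each other when several work-items hit the same element of dx.
    Atomix.@atomic dx[idx[i]] += dy[i]
end

# Hypothetical wrapper: accumulate dy into dx at positions idx,
# launching the kernel on whatever backend dx lives on.
function scatter_add!(dx, dy, idx)
    backend = get_backend(dx)
    scatter_add_kernel!(backend)(dx, dy, idx; ndrange = length(idx))
    KernelAbstractions.synchronize(backend)
    return dx
end

idx = [1, 1, 3]            # index 1 appears twice
dy  = [1.0, 2.0, 3.0]
dx  = scatter_add!(zeros(3), dy, idx)

# Serial reference: the accumulation the kernel is meant to reproduce.
ref = zeros(3)
for (i, j) in enumerate(idx)
    ref[j] += dy[i]
end
```

Because the whole update happens on the device, nothing has to be copied back to the host. On a backend without atomics, the `Atomix.@atomic` update is not available, which is why the description mentions a fallback to the previous implementation.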
Benchmarking:
Before:
Now: