-
-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: CUDNN RNNs #161
WIP: CUDNN RNNs #161
Conversation
10145c8
to
c3589c2
Compare
Eesh, this isn't even done yet (we still need to expose sequences properly). The basic RNN API is working though, so it's ready to go for basic stuff. |
I cannot get tests passed. INFO: Testing Flux/CUDNN
R = Flux.RNN: Error During Test
Got an exception of type CuArrays.CUDNN.CUDNNError outside of a @test
CUDNNError(code 3, CUDNN_STATUS_BAD_PARAM)
Stacktrace:
[1] macro expansion at /home/iblis/.julia/v0.6/CuArrays/src/dnn/error.jl:19 [inlined]
[2] cudnnRNNForward(::Flux.CUDA.RNNDesc{Float32}, ::Int64, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,2}, ::CuArrays.CUDNN.TensorDe
sc, ::CuArray{Float32,2}, ::Ptr{Void}, ::Ptr{Void}, ::CuArrays.CUDNN.FilterDesc, ::CuArray{Float32,1}, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuAr
ray{Float32,2}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,2}, ::Ptr{Void}, ::Ptr{Void}, ::CuArray{UInt8,1}, ::CuArray{UInt8,1}) at /home/iblis
/.julia/v0.6/Flux/src/cuda/cudnn.jl:131 Test Summary: | Pass Error Total
Flux | 109 1 110
Throttle | 11 11
Jacobian | 1 1
Initialization | 14 14
Tracker | 29 29
Dropout | 8 8
BatchNorm | 10 10
losses | 5 5
Optimise | 8 8
Training Loop | 1 1
CuArrays | 4 4
RNN | 15 1 16
R = Flux.RNN | 1 1
R = Flux.GRU | 7 7
R = Flux.LSTM | 8 8
|
I'm on 6 so probably just a version issue. Will debug when I can get 7 handy. |
@iblis17 Not sure what's happening here, but this passes on my CUDNN 7 machine. julia> Pkg.test("Flux")
INFO: Testing Flux
INFO: Testing Flux/GPU
INFO: Testing Flux/CUDNN
Test Summary: | Pass Total
Flux | 116 116
INFO: Flux tests passed
julia> CuArrays.CUDNN.CUDNN_VERSION
7005 |
😱 julia> using CuArrays
julia> CuArrays.CUDNN.CUDNN_VERSION
7005 |
Hmmm. If you want to debug it a bit, it might be worth trying out Knet's implementation and seeing if that works. If it does, it should be straightforward (if tedious) to figure out what we're calling differently. |
Mostly there now. Once the back passes are wrapped, the main item will be to rewrite our existing RNNs as one big matmul. This should have some CPU perf benefits anyway.