Skip to content

Commit

Permalink
RNN interface
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Feb 2, 2018
1 parent b1c5786 commit 10145c8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
50 changes: 50 additions & 0 deletions src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,53 @@ function forwardInference(rnn::RNNDesc{T}, x, h, c = nothing) where T
return y, hout, cout
end
end

# Interface

import ..Flux: Flux, relu
import ..Flux.Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims

function copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end

CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}

function copyparams!(m::CuRNN, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi))
copy_transpose!(Wh, Flux.data(m.Wh))
copy_transpose!(d.bias, Flux.data(m.b))
return
end

function RNNDesc(m::Union{CuRNN{T},CuGRU{T},CuLSTM{T}}) where {T}
h, i = size(m.Wi)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end

const descs = WeakKeyDict()
desc(rnn) = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))

function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T
d = desc(m)
copyparams!(m, d)
y, h = forwardInference(d, Flux.data(x), Flux.data(h))
return h, y
end
6 changes: 3 additions & 3 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ flip(f, xs) = reverse(f.(reverse(xs)))

# Vanilla RNN

struct RNNCell{F,A,V}
mutable struct RNNCell{F,A,V}
σ::F
Wi::A
Wh::A
Expand Down Expand Up @@ -112,7 +112,7 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

# LSTM

struct LSTMCell{A,V}
mutable struct LSTMCell{A,V}
Wi::A
Wh::A
b::V
Expand Down Expand Up @@ -161,7 +161,7 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

# GRU

struct GRUCell{A,V}
mutable struct GRUCell{A,V}
Wi::A
Wh::A
b::V
Expand Down

0 comments on commit 10145c8

Please sign in to comment.