
RNNCell, LSTMCell and GRUCell are implemented as mutable structs, but never do mutation #1089

Closed
AzamatB opened this issue Mar 18, 2020 · 11 comments

@AzamatB
Contributor

AzamatB commented Mar 18, 2020

It seems like RNNCell, LSTMCell and GRUCell, which are defined as mutable structs in src/layers/recurrent.jl, never actually use any mutation in the forward pass. Instead, state mutation is handled by the Recur struct, which copies over the hidden state of its cell during construction and never reads or writes the cell's hidden state afterwards.

Given this, what is the reason these recurrent cells are defined as mutable structs?
Should we change them to be immutable to reap the performance benefits that come with immutability?
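
For reference, the relevant pattern in src/layers/recurrent.jl at the time looked roughly like the following (a paraphrased sketch, not the verbatim source). All mutation happens on Recur.state; the cell itself is only ever read:

mutable struct Recur{T}
  cell::T
  init
  state
end

# the cell's h field is copied into both init and state at construction
Recur(m, h = hidden(m)) = Recur(m, h, h)

function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)  # the cell is a pure function of (state, input)
  m.state = h                    # only the wrapper is mutated
  return y
end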

@DrChainsaw
Contributor

You might also want to take a look at the h member of said structs. It seems to be used only to initialize the hidden and init state of Recur when creating it. If this is the case, it seems a bit wasteful and confusing to keep it as a member of the cells.

@bhvieira
Contributor

The hidden state dimensions change depending on the size of the batch.
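
For example (a hedged illustration, assuming the Flux API of the time; sizes are arbitrary):

using Flux

m = RNN(10, 5)
size(m.state)               # (5,): the initial hidden state is a length-5 vector
x = randn(Float32, 10, 30)  # a batch of 30 inputs
m(x)
size(m.state)               # (5, 30): broadcasting promoted the state to a matrix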

@AzamatB
Contributor Author

AzamatB commented Apr 16, 2020

@bhvieira yes, but that change is handled as part of the Recur struct, specifically via its state field. The cells themselves are never mutated, which is what this issue is about.

@bhvieira
Contributor

Hmm, I see. The cells themselves are never mutated; only their parameter values are updated, just like any other layer's. I'll run some tests. I remember asking something similar in the Slack channel, though.

@bhvieira
Contributor

Alright, it works and I have no clue why it's mutable, really. It could be immutable. See the code below:

using Flux
using Flux: glorot_uniform
import Flux: hidden

# Immutable version of Flux's RNNCell: a plain struct instead of a mutable struct
struct RNNCell2{F,A,V}
  σ::F
  Wi::A
  Wh::A
  b::V
  h::V
end

RNNCell2(in::Integer, out::Integer, σ = tanh;
        init = glorot_uniform) =
  RNNCell2(σ, init(out, in), init(out, out),
          init(out), Flux.zeros(out))

# Forward pass: a pure function of (state, input) returning the new state
# twice, as (state, output). Identical to the mutable RNNCell; no mutation.
function (m::RNNCell2)(h, x)
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  return h, h
end

hidden(m::RNNCell2) = m.h

Flux.@functor RNNCell2

RNN2(a...; ka...) = Flux.Recur(RNNCell2(a...; ka...))

x = randn(Float32, 10, 30);
l1 = RNN(10, 10);
l2 = RNN2(10, 10);

# gradients through the built-in mutable RNN and the immutable RNN2
gs1 = gradient(() -> sum(abs2, l1.([x for _ in 1:10])[end]), params(l1));
gs2 = gradient(() -> sum(abs2, l2.([x for _ in 1:10])[end]), params(l2));
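
A quick sanity check (hedged; it assumes the setup above) that the immutable version behaves like the built-in one:

Flux.reset!(l1); Flux.reset!(l2)
size(l1(x)) == size(l2(x))                # both produce a (10, 30) output for this batch
length(params(l1)) == length(params(l2))  # both collect the same parameter arrays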

@bhvieira
Contributor

They were turned mutable only in #161

@bhvieira
Contributor

> You might also want to take a look at the h member of said structs. It seems to be used only to initialize the hidden and init state of Recur when creating it. If this is the case, it seems a bit wasteful and confusing to keep it as a member of the cells.

You can optimize the initial state, though. But the way things are set up right now doesn't allow that by default. And then you get duplication: the initial state appears both in the cell and in Recur.init.

@mkschleg
Contributor

Narrowing it down further to the commit: 9a6fcf0#diff-d486393fe3ae37696de565e0fbd70386

Looks like it had to do with connecting the CUDA and Flux APIs for RNNs. I'll try to test later whether making them immutable breaks this interface. Nothing jumps out at me here, but I've only glanced at it.

@AzamatB
Contributor Author

AzamatB commented Apr 16, 2020

> You can optimize the initial state, though. But the way things are set up right now doesn't allow that by default.

All you need to do to train the initial state right now is to call reset! as part of your forward pass. This sets Recur.state to Recur.init, which is marked as trainable, so it will get updated by the optimizer.
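
A minimal sketch of that pattern (hedged; the model and data here are made up for illustration):

using Flux

m = RNN(10, 10)
θ = params(m)  # includes Recur.init, so the initial state is trainable

function loss(xs)
  Flux.reset!(m)  # sets m.state = m.init inside the forward pass
  sum(abs2, m.(xs)[end])
end

xs = [randn(Float32, 10, 30) for _ in 1:10]
gs = gradient(() -> loss(xs), θ)  # gs now holds a gradient for m.init as well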

> And then you get duplication: the initial state appears both in the cell and in Recur.init.

Yes, having both Cell.h and Recur.init is redundant.

jeremiedb pushed a commit to jeremiedb/Flux.jl that referenced this issue Oct 22, 2020
check that CUDNN drop solves for too many wrappers - FluxML#1259
@CarloLucibello
Member

This is fixed on master.

@AzamatB
Contributor Author

AzamatB commented Dec 27, 2020

Fixed in #1367

@mcabbott mcabbott added the RNN label Oct 5, 2022