Default type for initializer? #564

Closed
chengchingwen opened this issue Jan 22, 2019 · 5 comments

Comments

@chengchingwen
Member

I noticed that the newest release makes the default initializers produce Float32 instead of Float64. Does that mean that from now on we should assume all floats are 32-bit?

@MikeInnes
Member

Layers should still work with 64 bit input (feel free to open an issue if you see errors with that) but yeah, it can be a good idea to convert data and such to 32 bit.
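
A minimal sketch of that conversion in plain Julia, assuming the input is an ordinary `Array` (the names `x64` and `x32` are only illustrative):

```julia
# `rand` without a type argument produces Float64 values.
x64 = rand(10, 5)

# Broadcasting the Float32 constructor converts every element.
x32 = Float32.(x64)

# Check the element types before and after the conversion.
println(eltype(x64), " -> ", eltype(x32))
```

Doing this once when the data is loaded keeps Float64 values from reaching the Float32 parameters later on.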

@KristofferC
Contributor

Ref #501

@chengchingwen
Member Author

@MikeInnes I can't find an example right now, but there are situations where we accidentally mix 64-bit and 32-bit floats, which causes a type conversion that goes unnoticed. I only noticed because my model threw an error after I updated Flux. Maybe we could have a method that accepts a dtype argument in the initializer or the constructor?

@MikeInnes
Member

I'd like to have some tooling for changing the dtype of models, yeah. cc @dhairyagandhi96
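
As a rough illustration of what such tooling could look like, here is a sketch in plain Julia. `ToyDense` and `change_eltype` are hypothetical names made up for this example; they are not part of Flux, and a real version would need to walk arbitrary nested models.

```julia
# Hypothetical stand-in for a dense layer; this is not Flux's actual type.
struct ToyDense{M<:AbstractMatrix, V<:AbstractVector, F}
    W::M
    b::V
    σ::F
end

# Hypothetical helper: rebuild the layer with its parameters converted
# to element type T, leaving the activation function untouched.
change_eltype(::Type{T}, l::ToyDense) where {T<:AbstractFloat} =
    ToyDense(T.(l.W), T.(l.b), l.σ)

layer64 = ToyDense(randn(3, 2), zeros(3), identity)  # Float64 parameters
layer32 = change_eltype(Float32, layer64)            # Float32 parameters

println(eltype(layer32.W), " ", eltype(layer32.b))
```

The helper returns a new layer rather than mutating the old one, so the Float64 original stays available for comparison.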

@chengchingwen
Member Author

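Currently I overload the methods like this:

```julia
# Assumes Flux's `Dense`, `param`, `glorot_uniform` and `glorot_normal`
# (as of this release) are in scope and open for extension.

# Initializers that take the element type as their first argument.
glorot_uniform(T::Type, dims...) = (rand(T, dims...) .- T(0.5)) .* sqrt(T(24.0)/sum(dims))
glorot_normal(T::Type, dims...) = randn(T, dims...) .* sqrt(T(2.0)/sum(dims))

# Dense constructor that forwards `dtype` to both initializers.
function Dense(dtype::Type, in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros)
  return Dense(param(initW(dtype, out, in)), param(initb(dtype, out)), σ)
end
```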

but I'm not sure what the best strategy is. Float32 is somewhat troublesome, since any Float64 that appears will cause the output to be converted to 64-bit, even on the GPU. It also causes a large performance problem if the types are not stable. I end up manually converting almost every float that appears in my code.
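
To make the promotion concrete, here is a small sketch in plain Julia (no Flux required). The variable names and the `scale` function are only illustrative:

```julia
x = rand(Float32, 4)

# A Float64 literal silently promotes the whole result to Float64.
y_promoted = x .* 0.5

# A Float32 literal (the `f0` suffix) keeps the result in Float32.
y_literal = x .* 0.5f0

# Converting the constant to the array's element type also works,
# and stays correct if `x` later becomes Float64 or Float16.
y_generic = x .* eltype(x)(0.5)

# Integer constants do not widen a Float32 array.
y_int = x .* 2

# The same idea inside a function: derive constants from the type parameter T.
scale(x::AbstractArray{T}) where {T<:AbstractFloat} = x .* sqrt(T(2) / length(x))

println(eltype(y_promoted), " ", eltype(y_literal), " ", eltype(scale(x)))
```

Writing constants as `T(...)` or `eltype(x)(...)` is one way to avoid converting each float by hand, at the cost of threading the type through every function that uses them.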
