Added Bilinear layer #1009

Merged: 37 commits into FluxML:master, Feb 11, 2021

Conversation

bhvieira (Contributor)

A basic implementation inspired by https://pytorch.org/docs/stable/nn.html#bilinear

I haven't exported it, because I think this layer is a bit more esoteric than the others.

It basically computes interactions between two sets of inputs.

I thought about augmenting it to also include the non-interaction terms (this can easily be done, e.g. by augmenting the data with a row of ones), but for now it simply mirrors PyTorch's layer.

I had to use splatting, vcat(x...) and hcat(x...), in the forward pass. I wanted to avoid it, but with reduce I couldn't get gradients. I think this can be improved, though.
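
For readers skimming the thread, here is a minimal sketch of what the layer computes, written as a plain comprehension rather than the PR's actual code (the function name and argument layout are illustrative assumptions):

using LinearAlgebra: dot

# z[k, n] = σ( x[:, n]' * W[k, :, :] * y[:, n] + b[k] )
# with W of size out × in1 × in2, and x, y holding one sample per column.
function bilinear_naive(W::AbstractArray{<:Real,3}, b::AbstractVector, σ,
                        x::AbstractMatrix, y::AbstractMatrix)
    z = [dot(x[:, n], W[k, :, :] * y[:, n]) for k in axes(W, 1), n in axes(x, 2)]
    return σ.(z .+ b)
end

# e.g. bilinear_naive(randn(7, 5, 4), zeros(7), tanh, randn(5, 3), randn(4, 3))  # 7×3 output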

@mcabbott (Member)

Seems like a reason to want FluxML/NNlib.jl#100.

But even without that, I think it's not hard to avoid the splats and go almost 100x faster. Some scribbles here: https://gist.github.com/mcabbott/29cc74f287a95724d6f561f4ed285624
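
For reference, the einsum approach from the gist boils the whole contraction down to one line. A hedged sketch using Einsum.jl (the sizes and index names here are made up, not the gist's exact code):

using Einsum

W, x, y = randn(7, 5, 4), randn(5, 3), randn(4, 3)
# Z[o, s] = Σ_{i,j} W[o, i, j] * x[i, s] * y[j, s]
@einsum Z[o, s] := W[o, i, j] * x[i, s] * y[j, s]
size(Z)  # (7, 3)

OMEinsum's @ein, mentioned further down, accepts essentially the same line.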

@bhvieira (Contributor, Author) commented Jan 31, 2020

But even without that, I think it's not hard to avoid the splats, and go almost 100x faster. Some scribbles here: https://gist.github.com/mcabbott/29cc74f287a95724d6f561f4ed285624

Cool stuff, I didn't know about @einsum (as a physicist, I'm especially pleased with it!), really concise. Feel free to push your code to my branch!

@bhvieira changed the title from "Added Billinear layer" to "Added Bilinear layer" on Jan 31, 2020
@mcabbott (Member) commented Feb 2, 2020

I updated the gist; now OMEinsum's @ein is the fastest way. (Curiously, PyTorch's layer is quite a bit slower, and using their einsum brings it down to only twice the time.) And I believe it will work on GPUs too, which the others may not? I haven't tried, as mine is broken.

However, I'm not sure that Flux wants to depend on that package, so I'm not sure what the best answer is.

@bhvieira (Contributor, Author) commented Feb 2, 2020

If we can't add that dependency, we can just fall back to your previous implementation. Thanks for the commit, @mcabbott!

Review thread on src/layers/basic.jl (outdated, resolved)
@bhvieira (Contributor, Author) commented Feb 2, 2020

@mcabbott where does eachcol come from?

@bhvieira (Contributor, Author) commented Feb 2, 2020

Oh, is it from DataFrames.jl? I think we could define our own here; DataFrames.jl would be a big dependency to include.

Edit: It's in Base, but I'm on Julia 1.0 right now, that's why I can't see it.

Review thread on src/layers/basic.jl (outdated, resolved)
@bhvieira (Contributor, Author) commented Feb 2, 2020

No, it's in Base, but perhaps only since 1.1? It is one line, though.

Yeah, I tend to stick to the LTS versions; I'll check it.
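
For anyone else stuck on an older Julia, a compat sketch along the lines of "it is one line" (the name here is made up for illustration, not necessarily the exact Base definition):

# Roughly what Base.eachcol (Julia ≥ 1.1) provides: a lazy iterator of column views.
eachcol_compat(A::AbstractMatrix) = (view(A, :, j) for j in axes(A, 2))

# e.g. collect(eachcol_compat(rand(3, 2)))  # two 3-element column views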

@bhvieira (Contributor, Author) commented Feb 3, 2020

The only remaining error is about reduce/hcat, probably its adjoint, as it appears in the gradient call in the last test.

I went ahead and removed its type annotations, but it'd be better to put other, suitable ones in place.
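
For context, the two spellings under discussion build the same matrix; the trouble at the time was only the AD rule (adjoint) for the reduce form. A tiny illustration:

xs = [rand(3, 2) for _ in 1:4]      # a vector of matrices
hcat(xs...) == reduce(hcat, xs)     # true: identical 3×8 result, no splatting in the latter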

@DhairyaLGandhi (Member)

eachcol, I think, is Julia v1.1+ only, so it will fail on earlier versions

@bhvieira (Contributor, Author) commented Feb 10, 2020

I got a new implementation working now, using Zygote.Buffer. It's faster than the previous one (edit: I'm not so sure it's faster now, but I liked the idea of reusing Zygote machinery where possible nonetheless), and the code now uses it:

#current
@btime b($x...); #  48.201 μs (609 allocations: 103.31 KiB)
@btime gradient(() -> sum(abs2.(b($x...))), params(b)); #  11.179 ms (62819 allocations: 8.26 MiB)

#previous
@btime b($x...); #  53.300 μs (1022 allocations: 69.66 KiB)
@btime gradient(() -> sum(abs2.(b($x...))), params(b)); #  11.262 ms (62819 allocations: 8.26 MiB)
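
For readers unfamiliar with it, Zygote.Buffer is an array-like scratch space that can be mutated inside a differentiated function and then copied back to a plain array. A hedged sketch of what a Buffer-based forward pass could look like (an illustration with assumed argument layout, not necessarily the PR's exact code):

using Zygote
using LinearAlgebra: dot

function bilinear_buffer(W, b, σ, x::AbstractMatrix, y::AbstractMatrix)
    out = Zygote.Buffer(x, size(W, 1), size(x, 2))   # mutable, AD-friendly scratch array
    for n in axes(x, 2), k in axes(W, 1)
        out[k, n] = dot(x[:, n], W[k, :, :] * y[:, n])
    end
    return σ.(copy(out) .+ b)                        # copy(::Buffer) returns an ordinary array
end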

@arnavs commented Feb 26, 2020

Actually, I'm having trouble chaining this to other layers.

I think you might need to add a multi-arg chain, e.g.

Flux.applychain(fs::Tuple, x, y) = Flux.applychain(Base.tail(fs), first(fs)(x, y))
(c::Chain)(x, y) = Flux.applychain(c.layers, x, y)

or something.

I see this error:

MethodError: no method matching (::Chain{Tuple{Bilinear{Array{Float32,3},Array{Float64,1},typeof(tanh)},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}})(::Array{Float64,1}, ::Array{Float64,1})
Closest candidates are:
  Any(::Any) at /Users/arnavsood/.julia/packages/Flux/2i5P1/src/layers/basic.jl:32

Stacktrace:
 [1] top-level scope at In[99]:1

@bhvieira (Contributor, Author)

@dhairyagandhi96 can this one be merged?

@arnavs commented Feb 27, 2020

@bhvieira Any reason to not add the chain methods? They were what I needed to get this to work.

@bhvieira (Contributor, Author)

@arnavs I think you deleted a comment or something? I didn't see your suggestion for some reason. I might look into it, but I could also do it in another PR.

@arnavs commented Feb 27, 2020

Yeah, I'd made an earlier comment just asking if this was ready to merge, and then a follow-up with the bug report.

Thanks for looking into it. Basically, we just need chains to act on two arguments; otherwise you can't use a bilinear layer as the first in a chain. Those two lines work for me, but perhaps there are better ways.

@bhvieira (Contributor, Author) commented Mar 1, 2020

I fixed that issue without touching Chain, @arnavs; see if it works for you now 🙂
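
For the record, one way to let a two-input layer sit at the front of an ordinary single-argument Chain, sketched with a toy stand-in layer (an illustration of the idea, not necessarily this PR's exact fix): accept the two inputs packed in a tuple and unpack them inside the layer.

struct PairSum end                       # hypothetical two-input toy layer
(l::PairSum)(x, y) = x .+ y              # the "real" two-argument forward pass
(l::PairSum)((x, y)::Tuple) = l(x, y)    # tuple method: a plain Chain can now call l((x, y))

# usage sketch: Chain(PairSum(), sum)((rand(3), rand(3)))  # the first layer unpacks the tuple itself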

@mcabbott (Member) commented Mar 2, 2020

Note that batched_mul is now merged (FluxML/NNlib.jl#100) and has a gradient (FluxML/Zygote.jl#531). It's not yet hooked up for CuArrays, but surely will be. I think this PR ought to use that instead of hacking its own version. I'm not sure how you are timing things, but if my updated gist is correct, then using this is about 1000 times faster.

@bhvieira (Contributor, Author) commented Mar 2, 2020

@mcabbott Gosh, this never stops, haha. It's cool that we can rely on batched_mul, but I wouldn't call using Buffer 'hacking' by any means. Anyway, can you open a PR against my branch again? Checks are failing because I probably did something wrong, so I'll look into that when I have time later today.

@mcabbott (Member) commented Mar 2, 2020

Sorry, no insult intended if that came off wrong; I'm guilty of earlier hacks. But this is a common operation which, like *, should ideally be outsourced to the professionals, and now at last we can easily do so.

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
    W, b, σ = a.W, a.b, a.σ

    d_z, d_x, d_y = size(W)
    d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
    size(x,2) == size(y,2) || throw(DimensionMismatch("data inputs must agree on number of columns"))

    # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
    Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

    # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
    Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
    Z = reshape(Wyx, (d_z, :))

    # @einsum out[o,s] := σ(Z[o,s] + b[o])
    σ.(Z .+ b)
end
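
To make the reshapes easier to verify, here is a hedged sanity check against a per-sample loop; it assumes only the fields a.W, a.b, a.σ used above and is not part of the PR:

using LinearAlgebra: dot

# Naive reference: z[o, s] = σ(x[:, s]' * W[o, :, :] * y[:, s] + b[o])
naive_bilinear(a, x, y) = a.σ.(
    [dot(x[:, s], a.W[o, :, :] * y[:, s]) for o in axes(a.W, 1), s in axes(x, 2)] .+ a.b
)

# Expectation: a(x, y) ≈ naive_bilinear(a, x, y), up to floating-point error.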

Review thread on src/layers/basic.jl (outdated, resolved)
@bhvieira (Contributor, Author) commented Mar 4, 2020

With the timely PRs by @mcabbott, I think we are set here and the functionality is better than ever. Is there anything else you think we should do here @dhairyagandhi96?

@bhvieira (Contributor, Author) commented Mar 4, 2020

Btw, should it be exported? Similarly "uncommon" functionalities aren't exported, so I did not include it, but I can add it if you deem it useful.

Review threads on src/layers/basic.jl (outdated, resolved)
@CarloLucibello (Member)

Looks good! I would leave it unexported.

Review thread on test/cuda/layers.jl (outdated, resolved)
@bhvieira (Contributor, Author) commented Feb 7, 2021

Would @test_nowarn not return gs_gpu to the local scope, @CarloLucibello?

@CarloLucibello (Member)

Could be. I didn't even know it existed, though. I'll just remove the test.

Review threads on test/cuda/layers.jl (outdated, resolved)
@CarloLucibello (Member)

I really hope this goes green; this commit-suggestion thing is becoming painful 😅

CarloLucibello previously approved these changes Feb 9, 2021
@CarloLucibello (Member)

victory!

bors r+

@bhvieira (Contributor, Author) commented Feb 9, 2021

@CarloLucibello thanks for the efforts, haha. I had no idea a simple equality test between GPU and CPU would take so much. Are GPU gradients stored as GPU arrays? Perhaps if we moved them back to the CPU it would've worked.

@CarloLucibello (Member)

bors r+

@CarloLucibello (Member)

bors r-

@CarloLucibello (Member)

bors r+

@CarloLucibello (Member)

@DhairyaLGandhi maybe you should just merge manually here

@CarloLucibello (Member)

bors r+

bors bot added a commit that referenced this pull request Feb 11, 2021
1009: Added Bilinear layer r=CarloLucibello a=bhvieira

Co-authored-by: Bruno Hebling Vieira <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
@DhairyaLGandhi merged commit 3bc42f2 into FluxML:master on Feb 11, 2021
@bors bot (Contributor) commented Feb 11, 2021

This PR was included in a batch that successfully built, but then failed to merge into master (it was a non-fast-forward update). It will be automatically retried.
