Added Bilinear layer #1009

Merged Feb 11, 2021 (37 commits)

Changes from 34 commits

Commits
c2d3bf7 Added the Billinear layer (bhvieira, Jan 28, 2020)
a114435 Added some tests (bhvieira, Jan 28, 2020)
92a99d9 Forgot the activation function (bhvieira, Jan 28, 2020)
cbef3a4 typo (bhvieira, Jan 28, 2020)
10f8e24 Change order, error out fast (bhvieira, Jan 28, 2020)
c6172d7 Splatting instead of reduce, solves grads (bhvieira, Jan 29, 2020)
b4d587e Fix tests try 1 (bhvieira, Jan 29, 2020)
19d520a Fixed the error (bhvieira, Jan 29, 2020)
3d3665a faster Bilinear implementation (mcabbott, Feb 2, 2020)
ee09f08 Update src/layers/basic.jl (bhvieira, Feb 2, 2020)
17450f3 Non-essential explicit call to Zygote (bhvieira, Feb 2, 2020)
7830a8d Fixed forward pass of bilinear (bhvieira, Feb 2, 2020)
e670646 Removed type annotations of reducehcat (bhvieira, Feb 3, 2020)
7f891db New implementation based on Zygote.Buffer (bhvieira, Feb 10, 2020)
dc78890 Added expanding args for two data streams (bhvieira, Mar 1, 2020)
5ea755f Restricted Tuples (bhvieira, Mar 1, 2020)
3306e65 Missing end in test (bhvieira, Mar 2, 2020)
27e3f52 use batched_mul (mcabbott, Mar 2, 2020)
d99f2f3 tidy up docstring (mcabbott, Mar 2, 2020)
5545011 indent two spaces (mcabbott, Mar 2, 2020)
5c2496e Specified julia in docstring (bhvieira, Mar 7, 2020)
997a3a6 Apply suggestions from code review (bhvieira, Apr 2, 2020)
d11efb6 Update basic.jl (bhvieira, Apr 2, 2020)
bf064bf Added Bilinear to docs (bhvieira, Dec 9, 2020)
4ba3fb7 Stop splatting NTuple2 in Bilinear (bhvieira, Jan 5, 2021)
e0f6549 Upper case exception text (bhvieira, Jan 5, 2021)
4aab530 Update src/layers/basic.jl (bhvieira, Jan 10, 2021)
b849d60 Wrong testset added by mistake (bhvieira, Jan 10, 2021)
5aa8c82 Merge branch 'master' into billinear (mcabbott, Jan 28, 2021)
58da113 First bilinear gpu test (bhvieira, Feb 7, 2021)
67d016b Update test/cuda/layers.jl (CarloLucibello, Feb 7, 2021)
82d61a0 Update test/cuda/layers.jl (CarloLucibello, Feb 7, 2021)
cde19ed Update test/cuda/layers.jl (CarloLucibello, Feb 7, 2021)
9eac5b4 Missing parentheses in cuda test (bhvieira, Feb 7, 2021)
ca283f7 Update test/cuda/layers.jl (CarloLucibello, Feb 7, 2021)
793b92c Update test/cuda/layers.jl (CarloLucibello, Feb 8, 2021)
f4a60c7 Merge branch 'master' into billinear (bhvieira, Feb 11, 2021)
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -50,6 +50,7 @@
Maxout
SkipConnection
Parallel
Bilinear
```

## Normalisation & Regularisation
78 changes: 77 additions & 1 deletion src/layers/basic.jl
@@ -246,7 +246,83 @@ function Base.show(io::IO, b::SkipConnection)
end

"""
Bilinear(in1, in2, out)

Creates a Bilinear layer, which operates on two inputs at the same time.
It has parameters `W` and `b`, and given vectors `x` and `y` its output has the form

    z[i] = σ(x' * W[i,:,:] * y + b[i])

If `x` and `y` are matrices, then each column of the output `z = B(x, y)` has this form,
where `B` is a Bilinear layer of appropriate size.

If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`.
The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
which is accepted as the input to a `Chain`.

```julia
# using Bilinear to generate interactions, on one input
x = randn(Float32, 11, 7)
B = Bilinear(11, 11, 3)
size(B(x)) == (3, 7)

# using Bilinear on two data streams at once, as a tuple
x = randn(Float32, 10, 9)
y = randn(Float32, 2, 9)
m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
size(m((x, y))) == (1, 9)

# using Bilinear as the recombinator in a SkipConnection
x = randn(Float32, 10, 9)
sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
size(sc(x)) == (5, 9)
```
"""
struct Bilinear{A,B,S}
  W::A
  b::B
  σ::S
end

@functor Bilinear

Bilinear(W, b) = Bilinear(W, b, identity)

function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
                  initW = glorot_uniform, initb = zeros)
  return Bilinear(initW(out, in1, in2), initb(out), σ)
end

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
  W, b, σ = a.W, a.b, a.σ

  d_z, d_x, d_y = size(W)
  d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("Number of rows in data must match W"))
  size(x,2) == size(y,2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))"))

  # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
  Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

  # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
  Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
  Z = reshape(Wyx, (d_z, :))

  # σ and bias applied elementwise: out[o,s] = σ(Z[o,s] + b[o])
  σ.(Z .+ b)
end
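
# Editor's sketch (illustrative, not part of this PR): a naive loop computing
# the same contraction as the reshape/batched_mul path above, handy for
# checking the reshapes by hand. `naive_bilinear` is a hypothetical helper,
# not used or exported anywhere.
function naive_bilinear(W, b, σ, x::AbstractMatrix, y::AbstractMatrix)
  # z[o,s] = Σ_{i,j} W[o,i,j] * x[i,s] * y[j,s] + b[o]
  Z = [sum(W[o,i,j] * x[i,s] * y[j,s] for i in axes(x,1), j in axes(y,1)) + b[o]
       for o in axes(W,1), s in axes(x,2)]
  σ.(Z)  # should match a(x, y) up to floating-point roundoff
end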

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :, 1), reshape(y, :, 1)))
(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])

function Base.show(io::IO, l::Bilinear)
  print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end
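
# Editor's note (illustrative): with the `show` method above, a layer such as
# `Flux.Bilinear(10, 2, 3, tanh)` prints as `Bilinear(10, 2, 3, tanh)`, and
# the activation is omitted from the output when it is `identity`.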

"""
Parallel(connection, layers...)

Create a 'Parallel' layer that passes an input array to each path in
`layers`, reducing the output with `connection`.
14 changes: 13 additions & 1 deletion test/cuda/layers.jl
@@ -122,4 +122,16 @@
  @test sum(l(ip)) ≈ 0.f0
  gs = gradient(() -> sum(l(ip)), Flux.params(l))
  @test l.b ∉ gs.params
end

@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
b = Flux.Bilinear(10, 2, 3) |> gpu
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) ≈ 0f0
@test_nowarn gs_gpu = gradient(() -> sum(abs2.(b(x,y))), params(b))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
@test gs_cpu ≈ Array(gs_gpu)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end
25 changes: 25 additions & 0 deletions test/layers/basic.jl
@@ -107,6 +107,31 @@ import Flux: activations
  end
end

@testset "Bilinear" begin
@testset "SkipConnection recombinator" begin
d = Dense(10, 10)
b = Flux.Bilinear(10, 10, 5)
x = randn(Float32,10,9)
sc = SkipConnection(d, b)
@test size(sc(x)) == (5,9)
end

@testset "Two-streams zero sum" begin
x = zeros(Float32,10,9)
y = zeros(Float32,2,9)
b = Flux.Bilinear(10, 2, 3)
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) == 0f0
end

@testset "Inner interactions" begin
x = randn(Float32,11,7)
b = Flux.Bilinear(11, 11, 3)
@test size(b(x)) == (3,7)
@test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
end
end
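
# Editor's sketch (not part of this PR): the vector methods are not exercised
# above; a minimal check, under the same Flux/params setup as the tests above,
# could look like this.
@testset "Vector inputs (sketch)" begin
  b = Flux.Bilinear(5, 5, 3)
  x, y = randn(Float32, 5), randn(Float32, 5)
  @test size(b(x, y)) == (3,)
  @test b(x) ≈ b(x, x)  # one-argument form reuses x for both streams
end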

@testset "Parallel" begin
@testset "zero sum" begin
input = randn(10, 10, 10, 10)