Merge pull request FluxML#1009 from bhvieira/billinear
Added Bilinear layer
DhairyaLGandhi committed Feb 11, 2021
2 parents ddb5e9c + f4a60c7 commit 3bc42f2
Showing 4 changed files with 117 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -57,6 +57,7 @@ But in contrast to the layers described in the other sections are not readily gr
Maxout
SkipConnection
Parallel
Bilinear
```

## Normalisation & Regularisation
78 changes: 77 additions & 1 deletion src/layers/basic.jl
@@ -249,7 +249,83 @@ function Base.show(io::IO, b::SkipConnection)
end

"""
Parallel(connection, layers...)
Bilinear(in1, in2, out)

Creates a Bilinear layer, which operates on two inputs at the same time.
It has parameters `W` and `b`, and its output given vectors `x`, `y` is of the form

z[i] = σ(x' * W[i,:,:] * y + b[i])

If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
where `B` is a Bilinear layer of appropriate size.
If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`.
The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
which is accepted as the input to a `Chain`.
```julia
# using Bilinear to generate interactions, on one input
x = randn(Float32, 11, 7)
B = Bilinear(11, 11, 3)
size(B(x)) == (3, 7)
# using Bilinear on two data streams at once, as a tuple
x = randn(Float32, 10, 9)
y = randn(Float32, 2, 9)
m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
size(m((x, y))) == (1, 9)
# using Bilinear as the recombinator in a SkipConnection
x = randn(Float32, 10, 9)
sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
size(sc(x)) == (5, 9)
```
"""
struct Bilinear{A,B,S}
W::A
b::B
σ::S
end

@functor Bilinear

Bilinear(W, b) = Bilinear(W, b, identity)

function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros)
return Bilinear(initW(out, in1, in2), initb(out), σ)
end
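Editor's sketch (assumed sizes, not part of the diff): the constructor builds a 3-dimensional weight laid out as `W[out, in1, in2]`, plus one bias entry per output row.

```julia
using Flux

layer = Flux.Bilinear(4, 5, 2, tanh; initW = Flux.glorot_uniform, initb = zeros)
size(layer.W) == (2, 4, 5)   # (out, in1, in2)
length(layer.b) == 2         # one bias per output
```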

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
W, b, σ = a.W, a.b, a.σ

d_z, d_x, d_y = size(W)
d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
size(x,2) == size(y,2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))"))

# @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

# @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
Z = reshape(Wyx, (d_z, :))

# @einsum out[o,s] := σ(Z[o,s] + b[o])
σ.(Z .+ b)
end
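The reshapes above implement the einsum comments without an explicit loop. As a sanity check, here is an editor's sketch (arbitrary sizes, not part of the diff) comparing the batched path against the naive triple sum `Z[o,s] = Σᵢⱼ W[o,i,j] * x[i,s] * y[j,s]`:

```julia
using NNlib: batched_mul

W = randn(Float32, 3, 4, 5)   # (d_z, d_x, d_y)
x = randn(Float32, 4, 7)      # (d_x, batch)
y = randn(Float32, 5, 7)      # (d_y, batch)

# Contract over j, then batch-contract over i, as in the method above.
Wy = reshape(reshape(W, (:, 5)) * y, (3, 4, :))              # Wy[o,i,s] = Σⱼ W[o,i,j] * y[j,s]
Z  = reshape(batched_mul(Wy, reshape(x, (4, 1, :))), (3, :)) # Z[o,s]   = Σᵢ Wy[o,i,s] * x[i,s]

# Naive reference computation of the same einsum.
Zref = [sum(W[o,i,j] * x[i,s] * y[j,s] for i in 1:4, j in 1:5) for o in 1:3, s in 1:7]
Z ≈ Zref   # true
```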

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :,1), reshape(y, :,1)))
(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])
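A short editor's sketch of the convenience methods just defined: the single-argument and tuple forms both reduce to the two-argument call.

```julia
using Flux

B = Flux.Bilinear(6, 6, 2)
v = randn(Float32, 6)

B(v) == B(v, v)        # one input: y defaults to x
B((v, v)) == B(v, v)   # tuple input, as passed along a Chain
```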

function Base.show(io::IO, l::Bilinear)
print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end

"""
Parallel(connection, layers...)
Create a `Parallel` layer that passes an input array to each path in
`layers`, reducing the output with `connection`.
14 changes: 14 additions & 0 deletions test/cuda/layers.jl
@@ -178,3 +178,17 @@ end
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
end

@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
b = Flux.Bilinear(10, 2, 3) |> gpu
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) ≈ 0f0
gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
for (pgpu, pcpu) in zip(params(b), params(b_cpu))
@test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
end
end
25 changes: 25 additions & 0 deletions test/layers/basic.jl
@@ -112,6 +112,31 @@ import Flux: activations
end
end

@testset "Bilinear" begin
@testset "SkipConnection recombinator" begin
d = Dense(10, 10)
b = Flux.Bilinear(10, 10, 5)
x = randn(Float32,10,9)
sc = SkipConnection(d, b)
@test size(sc(x)) == (5,9)
end

@testset "Two-streams zero sum" begin
x = zeros(Float32,10,9)
y = zeros(Float32,2,9)
b = Flux.Bilinear(10, 2, 3)
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) == 0f0
end

@testset "Inner interactions" begin
x = randn(Float32,11,7)
b = Flux.Bilinear(11, 11, 3)
@test size(b(x)) == (3,7)
@test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
end
end

@testset "Parallel" begin
@testset "zero sum" begin
input = randn(10, 10, 10, 10)
