diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index ed6027f5ca..f0a2f08fee 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -57,6 +57,7 @@ But in contrast to the layers described in the other sections are not readily gr
 Maxout
 SkipConnection
 Parallel
+Bilinear
 ```
 
 ## Normalisation & Regularisation
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 141a4f4c3d..38a5a1eef9 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -249,7 +249,83 @@ function Base.show(io::IO, b::SkipConnection)
 end
 
 """
-    Parallel(connection, layers...)
+    Bilinear(in1, in2, out)
+
+Creates a Bilinear layer, which operates on two inputs at the same time.
+It has parameters `W` and `b`, and given vectors `x`, `y` its output is of the form
+
+    z[i] = σ(x' * W[i,:,:] * y + b[i])
+
+If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
+with `B` a Bilinear layer of the appropriate size.
+
+If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`.
+The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
+which is accepted as the input to a `Chain`.
+
+```julia
+# using Bilinear to generate interactions, on one input
+x = randn(Float32, 11, 7)
+B = Bilinear(11, 11, 3)
+size(B(x)) == (3, 7)
+
+# using Bilinear on two data streams at once, as a tuple
+x = randn(Float32, 10, 9)
+y = randn(Float32, 2, 9)
+m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
+size(m((x, y))) == (1, 9)
+
+# using Bilinear as the recombinator in a SkipConnection
+x = randn(Float32, 10, 9)
+sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
+size(sc(x)) == (5, 9)
+```
+"""
+struct Bilinear{A,B,S}
+  W::A
+  b::B
+  σ::S
+end
+
+@functor Bilinear
+
+Bilinear(W, b) = Bilinear(W, b, identity)
+
+function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
+                  initW = glorot_uniform, initb = zeros)
+  return Bilinear(initW(out, in1, in2), initb(out), σ)
+end
+
+function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
+  W, b, σ = a.W, a.b, a.σ
+
+  d_z, d_x, d_y = size(W)
+  d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
+  size(x,2) == size(y,2) || throw(DimensionMismatch("data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))"))
+
+  # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
+  Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))
+
+  # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
+  Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
+  Z = reshape(Wyx, (d_z, :))
+
+  # @einsum out[o,s] := σ(Z[o,s] + b[o])
+  σ.(Z .+ b)
+end
+
+(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
+(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :,1), reshape(y, :,1)))
+(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])
+
+function Base.show(io::IO, l::Bilinear)
+  print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+"""
+    Parallel(connection, layers...)
 
 Create a `Parallel` layer that passes an input array to each path in
 `layers`, reducing the output with `connection`.
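The `reshape`/`batched_mul` pipeline in the forward pass above computes the same thing as the per-column formula in the docstring. A minimal sanity check of that equivalence (a sketch, assuming this patch is applied so that `Flux.Bilinear` exists; the sizes are illustrative):

```julia
using Flux

B = Flux.Bilinear(4, 5, 3)   # B.W has size (3, 4, 5)
x = randn(Float32, 4, 7)
y = randn(Float32, 5, 7)
z = B(x, y)                  # batched_mul path, size (3, 7)

# compare each entry against z[i,k] = σ(x[:,k]' * W[i,:,:] * y[:,k] + b[i])
for k in 1:7, i in 1:3
  @assert z[i, k] ≈ B.σ(x[:, k]' * B.W[i, :, :] * y[:, k] + B.b[i])
end
```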
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index 6a7dcd1b05..185a9e8f72 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -178,3 +178,17 @@ end
   gs = gradient(() -> sum(l(ip)), Flux.params(l))
   @test l.b ∉ gs.params
 end
+
+@testset "Two-streams Bilinear" begin
+  x = zeros(Float32,10,9) |> gpu
+  y = zeros(Float32,2,9) |> gpu
+  b = Flux.Bilinear(10, 2, 3) |> gpu
+  @test size(b(x,y)) == (3,9)
+  @test sum(abs2, b(x,y)) ≈ 0f0
+  gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
+  b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
+  gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
+  for (pgpu, pcpu) in zip(params(b), params(b_cpu))
+    @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
+  end
+end
\ No newline at end of file
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index c04d1f97d5..073182c03c 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -112,6 +112,31 @@ import Flux: activations
     end
   end
 
+  @testset "Bilinear" begin
+    @testset "SkipConnection recombinator" begin
+      d = Dense(10, 10)
+      b = Flux.Bilinear(10, 10, 5)
+      x = randn(Float32,10,9)
+      sc = SkipConnection(d, b)
+      @test size(sc(x)) == (5,9)
+    end
+
+    @testset "Two-streams zero sum" begin
+      x = zeros(Float32,10,9)
+      y = zeros(Float32,2,9)
+      b = Flux.Bilinear(10, 2, 3)
+      @test size(b(x,y)) == (3,9)
+      @test sum(abs2, b(x,y)) == 0f0
+    end
+
+    @testset "Inner interactions" begin
+      x = randn(Float32,11,7)
+      b = Flux.Bilinear(11, 11, 3)
+      @test size(b(x)) == (3,7)
+      @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
+    end
+  end
+
  @testset "Parallel" begin
    @testset "zero sum" begin
      input = randn(10, 10, 10, 10)
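The tests above cover output shapes and CPU/GPU gradient agreement. For an end-to-end picture, a hypothetical two-stream training loop using the tuple-input form (the data, optimiser, and loop length are illustrative, not part of this PR):

```julia
using Flux

m = Chain(Flux.Bilinear(10, 2, 3), Dense(3, 1))
x = randn(Float32, 10, 64)      # first stream
y = randn(Float32, 2, 64)       # second stream
target = randn(Float32, 1, 64)

ps = params(m)
opt = Descent(0.1)
for _ in 1:100
  # tuple input (x, y) is accepted by the Chain, as documented above
  gs = gradient(() -> sum(abs2, m((x, y)) .- target), ps)
  Flux.update!(opt, ps, gs)
end
```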