Skip to content

Commit

Permalink
correct casing
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Mar 18, 2019
1 parent e23c8dd commit ca68bf9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 10 additions & 10 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ end


"""
MaxOut(over)
Maxout(over)
`MaxOut` is a neural network layer, which has a number of internal layers,
`Maxout` is a neural network layer, which has a number of internal layers,
which all have the same input, and the max out returns the elementwise maximium
of the internal layers' outputs.
Expand All @@ -142,30 +142,30 @@ In Proceedings of the 30th International Conference on International Conference
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct MaxOut{FS<:Tuple}
struct Maxout{FS<:Tuple}
over::FS
end

"""
MaxOut(f, n_alts, args...; kwargs...)
Maxout(f, n_alts, args...; kwargs...)
Constructs a MaxOut layer over `n_alts` instances of the layer given by `f`.
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
All other arguements (`args` & `kwargs`) are passed to the constructor `f`.
For example the following example which
will construct a `MaxOut` layer over 4 dense linear layers,
will construct a `Maxout` layer over 4 dense linear layers,
each identical in structure (784 inputs, 128 outputs).
```julia
insize = 784
outsie = 128
MaxOut(Dense, 4, insize, outsize)
Maxout(Dense, 4, insize, outsize)
```
"""
function MaxOut(f, n_alts, args...; kwargs...)
function Maxout(f, n_alts, args...; kwargs...)
over = Tuple(f(args...; kwargs...) for _ in 1:n_alts)
return MaxOut(over)
return Maxout(over)
end

# Apply the Maxout layer: evaluate every internal layer on `input` and
# reduce with the elementwise maximum of their outputs.
function (mo::Maxout)(input::AbstractArray)
    return mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
10 changes: 5 additions & 5 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ using Test, Random
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end

@testset "MaxOut" begin
# Note that the normal common usage of MaxOut is as per the docstring
@testset "Maxout" begin
# Note that the normal common usage of Maxout is as per the docstring
# These are abnormal constructors used for testing purposes

@testset "Constructor" begin
mo = MaxOut(() -> identity, 4)
mo = Maxout(() -> identity, 4)
input = rand(40)
@test mo(input) == input
end

@testset "simple alternatives" begin
mo = MaxOut((x -> x, x -> 2x, x -> 0.5x))
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
input = rand(40)
@test mo(input) == 2*input
end

@testset "complex alternatives" begin
mo = MaxOut((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
input = [3.0 2.0]
target = [0.5, 0.7].*input
@test mo(input) == target
Expand Down

0 comments on commit ca68bf9

Please sign in to comment.