Provide deterministic builds (#427)
josevalim committed Dec 21, 2022
1 parent 4ffbe77 commit c57e37f
Showing 9 changed files with 187 additions and 194 deletions.
181 changes: 65 additions & 116 deletions lib/axon.ex
@@ -240,9 +240,9 @@ defmodule Axon do
to inference function except:
* `:name` - layer name.
* `:op_name` - layer operation for inspection and building parameter map.
* `:mode` - if the layer should run only on `:inference` or `:train`. Defaults to `:both`
* `:op_name` - layer operation for inspection and building parameter
map.
Note this means your layer should not use these as input options,
as they will always be dropped during inference compilation.
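As an illustrative sketch (not part of this diff), the new `:mode` option can tag a custom layer so it only runs during training; the input name, shape, and layer body below are made up:

    input = Axon.input("features", shape: {nil, 8})

    # Hypothetical custom layer; because it is tagged mode: :train,
    # it is dropped entirely when the model is built for inference.
    scaled =
      Axon.layer(
        fn x, _opts -> Nx.multiply(x, 0.9) end,
        [input],
        name: "scale",
        op_name: :scale,
        mode: :train
      )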
@@ -268,22 +268,23 @@ defmodule Axon do
params = Enum.reverse(params)
args = Enum.reverse(args)

{mode, opts} = Keyword.pop(opts, :mode, :both)
{name, opts} = Keyword.pop(opts, :name)
{op_name, layer_opts} = Keyword.pop(opts, :op_name, :custom)

{id, name} = unique_identifiers(op_name, name)

axon_node = make_node(id, op, name, op_name, inputs, params, args, layer_opts)
{op_name, opts} = Keyword.pop(opts, :op_name, :custom)
name = name(op_name, name)

id = System.unique_integer([:positive, :monotonic])
axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts)
%Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
end

defp make_node(id, op, name, op_name, inputs, params, args, layer_opts) do
defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts) do
{:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
Process.info(self(), :current_stacktrace)

%Axon.Node{
id: id,
mode: mode,
name: name,
parent: inputs,
parameters: params,
@@ -340,10 +341,7 @@ defmodule Axon do
initializer = validate_initializer!(opts[:initializer])
type = opts[:type] || {:f, 32}

id = System.unique_integer([:positive, :monotonic])

%Axon.Parameter{
id: id,
name: name,
shape: shape,
type: type,
@@ -1399,7 +1397,8 @@ defmodule Axon do
layer(dropout, [x, key_state],
name: opts[:name],
rate: opts[:rate],
op_name: dropout
op_name: dropout,
mode: :train
)
end
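A rough usage sketch of the effect (the tiny model below is made up): dropout layers are now tagged `mode: :train`, so they are compiled away in `:inference` mode and applied in `:train` mode:

    model =
      Axon.input("x", shape: {nil, 4})
      |> Axon.dropout(rate: 0.5)

    # Dropout is skipped here...
    {init_fn, predict_fn} = Axon.build(model, mode: :inference)

    # ...and applied here.
    {train_init_fn, train_fn} = Axon.build(model, mode: :train)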

@@ -2174,8 +2173,9 @@ defmodule Axon do
def lstm(%Axon{} = x, units, opts)
when is_integer(units) and units > 0 and is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
c = rnn_state(x, units, :lstm, opts[:name], "c", recurrent_initializer)
h = rnn_state(x, units, :lstm, opts[:name], "h", recurrent_initializer)
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
c = rnn_state(x, units, :lstm, opts[:name], "c", recurrent_initializer, seed)
h = rnn_state(x, units, :lstm, opts[:name], "h", recurrent_initializer, seed)
lstm(x, {c, h}, units, opts)
end
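A usage sketch of the new `:seed` option (the input shape and seed value are illustrative); the same option is threaded through `gru/3` and `conv_lstm/3` below:

    input = Axon.input("sequence", shape: {nil, 16, 32})

    # An explicit seed makes the recurrent-state initialization reproducible;
    # omitting it falls back to :erlang.system_time/0 as before.
    {output_sequence, _hidden_state} = Axon.lstm(input, 64, seed: 42)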

@@ -2372,7 +2372,8 @@ defmodule Axon do
when is_integer(units) and units > 0
when is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
h = rnn_state(x, units, :gru, opts[:name], "h", recurrent_initializer)
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
h = rnn_state(x, units, :gru, opts[:name], "h", recurrent_initializer, seed)
gru(x, {h}, units, opts)
end

@@ -2549,8 +2550,9 @@ defmodule Axon do
def conv_lstm(%Axon{} = x, units, opts)
when is_integer(units) and units > 0 and is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
c = rnn_state(x, units, :conv_lstm, opts[:name], "c", recurrent_initializer)
h = rnn_state(x, units, :conv_lstm, opts[:name], "h", recurrent_initializer)
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
c = rnn_state(x, units, :conv_lstm, opts[:name], "c", recurrent_initializer, seed)
h = rnn_state(x, units, :conv_lstm, opts[:name], "h", recurrent_initializer, seed)
conv_lstm(x, {c, h}, units, opts)
end

@@ -2727,9 +2729,12 @@ defmodule Axon do
{output_sequence, {new_c, new_h}}
end

defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer) do
defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer, seed) do
initializer = initializer || :glorot_uniform
key = Nx.Random.key(:erlang.system_time()) |> Nx.backend_copy(Nx.Defn.Expr)
key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend)

key_state =
param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end)

name =
case parent_name do
@@ -2742,7 +2747,7 @@ defmodule Axon do
"#{parent_name}_#{state_name}_hidden_state"
end

fun = fn inputs, opts ->
fun = fn inputs, key, _opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)

case initializer do
@@ -2758,15 +2763,15 @@ defmodule Axon do
fun.(shape, {:f, 32})

arity == 3 ->
fun.(shape, {:f, 32}, opts[:key])
fun.(shape, {:f, 32}, key)

true ->
raise ArgumentError, "bad arity for initializer"
end
end
end

layer(fun, [x], name: name, op_name: :recurrent_state, key: key)
layer(fun, [x, key_state], name: name, op_name: :recurrent_state)
end
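For illustration only: the key is now derived once from the caller-provided seed and stored as a layer parameter, so two builds with the same seed produce the same key:

    # Same seed, same key tensor, regardless of when the model is constructed.
    key = Nx.Random.key(42) |> Nx.backend_copy(Nx.BinaryBackend)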

@doc """
@@ -2881,51 +2886,44 @@ defmodule Axon do
the update process.
"""
@doc type: :model
def freeze(%Axon{} = model, fun_or_predicate \\ :all) do
parameters_per_layer =
reduce_nodes(model, [], fn %Axon.Node{parameters: params}, acc ->
layer_params =
Enum.reduce(params, [], fn param, inner_acc ->
[param | inner_acc]
end)
def freeze(model, fun_or_predicate \\ :all) do
freeze(model, fun_or_predicate, true)
end

[layer_params | acc]
end)
defp freeze(%Axon{output: id, nodes: nodes} = axon, fun_or_predicate, flag) do
{nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())

parameters_to_freeze =
nodes =
case fun_or_predicate do
:all ->
List.flatten(parameters_per_layer)
freeze_nodes(nodes, flag)

[{:up, n}] ->
parameters_per_layer
|> Enum.reverse()
|> Enum.take(n)
|> List.flatten()
{pre, post} = Enum.split(nodes, n)
freeze_nodes(pre, flag) ++ post

[{:down, n}] ->
parameters_per_layer
|> Enum.reverse()
|> Enum.drop(n)
|> List.flatten()
{pre, post} = Enum.split(nodes, -n)
pre ++ freeze_nodes(post, flag)

fun ->
parameters_per_layer
|> List.flatten()
|> Enum.filter(fun)
Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node ->
%{
axon_node
| parameters:
Enum.map(params, fn p ->
if fun.(p), do: %{p | frozen: flag}, else: p
end)
}
end)
end

map_nodes(model, fn %Axon.Node{parameters: params} = axon_node ->
frozen_params =
Enum.map(params, fn %{id: param_id} = v ->
if Enum.any?(parameters_to_freeze, fn %{id: id} -> param_id == id end) do
%{v | frozen: true}
else
v
end
end)
%{axon | nodes: Map.new(nodes, fn %{id: id} = node -> {id, node} end)}
end

%{axon_node | parameters: frozen_params}
defp freeze_nodes(nodes, flag) do
Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node ->
%{axon_node | parameters: Enum.map(params, fn p -> %{p | frozen: flag} end)}
end)
end

@@ -2960,52 +2958,8 @@ defmodule Axon do
the update process.
"""
@doc type: :model
def unfreeze(%Axon{} = model, fun_or_predicate \\ :all) do
parameters_per_layer =
reduce_nodes(model, [], fn %Axon.Node{parameters: params}, acc ->
layer_params =
Enum.reduce(params, [], fn param, inner_acc ->
[param | inner_acc]
end)

[layer_params | acc]
end)

parameters_to_freeze =
case fun_or_predicate do
:all ->
List.flatten(parameters_per_layer)

[{:up, n}] ->
parameters_per_layer
|> Enum.reverse()
|> Enum.take(n)
|> List.flatten()

[{:down, n}] ->
parameters_per_layer
|> Enum.reverse()
|> Enum.drop(n)
|> List.flatten()

fun ->
parameters_per_layer
|> List.flatten()
|> Enum.filter(fun)
end

map_nodes(model, fn %Axon.Node{parameters: params} = axon_node ->
frozen_params =
Enum.map(params, fn %{id: param_id} = v ->
if Enum.any?(parameters_to_freeze, fn %{id: id} -> param_id == id end) do
%{v | frozen: false}
else
v
end
end)

%{axon_node | parameters: frozen_params}
end)
def unfreeze(model, fun_or_predicate \\ :all) do
freeze(model, fun_or_predicate, false)
end
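Both `freeze/2` and `unfreeze/2` now funnel into the private `freeze/3` above with a boolean flag. A usage sketch, with a made-up model:

    model =
      Axon.input("data", shape: {nil, 784})
      |> Axon.dense(128, activation: :relu)
      |> Axon.dense(10, activation: :softmax)

    # Freeze everything, only the first n layers, or by predicate...
    all_frozen = Axon.freeze(model)
    first_frozen = Axon.freeze(model, up: 1)
    kernels_frozen = Axon.freeze(model, fn %Axon.Parameter{name: name} -> name == "kernel" end)

    # ...and unfreeze/2 flips the same flag back.
    thawed = Axon.unfreeze(all_frozen)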

@doc """
@@ -3401,7 +3355,7 @@ defmodule Axon do
## Options
* `:mode` - one of `:inference` or `:training`. Forwarded to layers
* `:mode` - one of `:inference` or `:train`. Forwarded to layers
to control differences in compilation at training or inference time.
Defaults to `:inference`
@@ -3480,7 +3434,7 @@ defmodule Axon do
## Options
* `:mode` - one of `:inference` or `:training`. Forwarded to layers
* `:mode` - one of `:inference` or `:train`. Forwarded to layers
to control differences in compilation at training or inference time.
Defaults to `:inference`
@@ -3544,7 +3498,7 @@ defmodule Axon do
## Options
* `:mode` - one of `:inference` or `:training`. Forwarded to layers
* `:mode` - one of `:inference` or `:train`. Forwarded to layers
to control differences in compilation at training or inference time.
Defaults to `:inference`
@@ -3636,7 +3590,7 @@ defmodule Axon do
@doc type: :model
def serialize(%Axon{output: id, nodes: nodes}, params, opts \\ []) do
Logger.warning(
"Attempting to serialize an Axon model. Serialiation is discouraged" <>
"Attempting to serialize an Axon model. Serialization is discouraged" <>
" and will be deprecated, then removed in future releases. You should" <>
" keep your model definitions as code and serialize your parameters using" <>
" `Nx.serialize/2`."
@@ -3782,27 +3736,22 @@ defmodule Axon do
# Names are generated lazily at inspect, initialization, and compile
# time, so for name we return a function which takes `op` and `op_count`
# and returns a unique name for the given model.
defp unique_identifiers(type, nil) do
id = System.unique_integer([:positive, :monotonic])

name = fn op, op_counts ->
defp name(type, nil) do
fn op, op_counts ->
count = op_counts[op] || 0
Atom.to_string(type) <> "_#{count}"
end

{id, name}
end

defp unique_identifiers(_type, name_fn) when is_function(name_fn, 2) do
id = System.unique_integer([:positive, :monotonic])
{id, name_fn}
defp name(_type, name_fn) when is_function(name_fn, 2) do
name_fn
end

defp unique_identifiers(_type, name) when is_binary(name) do
{System.unique_integer([:positive, :monotonic]), fn _, _ -> name end}
defp name(_type, name) when is_binary(name) do
fn _, _ -> name end
end

defp unique_identifiers(_, name) do
defp name(_type, name) do
raise ArgumentError,
"expected layer name to be a binary, a function or nil, " <>
"got: #{inspect(name)}"
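As before, `:name` accepts a string, a 2-arity function, or `nil`; the new `name/2` helper only normalizes these into a naming function and no longer mints an id. An illustrative sketch:

    model =
      Axon.input("x", shape: {nil, 2})
      |> Axon.dense(4, name: "hidden")
      |> Axon.dense(1, name: fn op, op_counts -> "#{op}_#{op_counts[op] || 0}" end)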
