-
Notifications
You must be signed in to change notification settings - Fork 1
/
network_models.py
46 lines (36 loc) · 1.19 KB
/
network_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import tensorflow as tf
from baselines.a2c.utils import ortho_init
from baselines.common.models import register
def mlp_norm(norm, num_layers=2, num_hidden=64, activation=tf.tanh):
    """Return a builder for a dense Keras network with normalization layers.

    Args:
        norm: Zero-argument callable that returns a layer; it is invoked as
            ``norm()`` and the resulting layer is applied to the hidden
            tensor (e.g. a Keras normalization layer class).
        num_layers: Number of ``Dense`` layers to stack.
        num_hidden: Number of units in each ``Dense`` layer.
        activation: Activation passed to each ``Dense`` layer.

    Returns:
        A function ``network_fn(input_shape)`` that builds a
        ``tf.keras.Model`` from a ``tf.keras.Input`` of ``input_shape`` to
        the output of the final layer.
    """
    def network_fn(input_shape):
        """Build the ``tf.keras.Model`` for inputs of ``input_shape``."""
        x_input = tf.keras.Input(shape=input_shape)
        h = x_input
        for i in range(num_layers):
            h = tf.keras.layers.Dense(
                units=num_hidden,
                kernel_initializer=ortho_init(np.sqrt(2)),
                name='mlp_fc{}'.format(i),
                activation=activation,
            )(h)
            # A fresh normalization layer follows every Dense layer, i.e. it
            # acts on the already-activated output.
            # NOTE(review): assumes normalization is per layer rather than
            # once after the whole stack -- confirm against intended design.
            h = norm()(h)
        return tf.keras.Model(inputs=[x_input], outputs=[h])

    return network_fn
@register("mlp_batchnorm")
def mlp_batchnorm(num_layers=2, num_hidden=64, activation=tf.tanh):
    """Return an ``mlp_norm`` network builder using batch normalization.

    Args:
        num_layers: Number of dense layers, forwarded to ``mlp_norm``.
        num_hidden: Units per dense layer, forwarded to ``mlp_norm``.
        activation: Dense-layer activation, forwarded to ``mlp_norm``.

    Returns:
        The ``network_fn`` produced by ``mlp_norm`` with
        ``tf.keras.layers.BatchNormalization`` as the normalization layer.
    """
    norm_layer = tf.keras.layers.BatchNormalization
    return mlp_norm(
        norm=norm_layer,
        num_layers=num_layers,
        num_hidden=num_hidden,
        activation=activation,
    )
@register("mlp_layernorm")
def mlp_layernorm(num_layers=2, num_hidden=64, activation=tf.tanh):
    """Return an ``mlp_norm`` network builder using layer normalization.

    Args:
        num_layers: Number of dense layers, forwarded to ``mlp_norm``.
        num_hidden: Units per dense layer, forwarded to ``mlp_norm``.
        activation: Dense-layer activation, forwarded to ``mlp_norm``.

    Returns:
        The ``network_fn`` produced by ``mlp_norm`` with
        ``tf.keras.layers.LayerNormalization`` as the normalization layer.
    """
    norm_layer = tf.keras.layers.LayerNormalization
    return mlp_norm(
        norm=norm_layer,
        num_layers=num_layers,
        num_hidden=num_hidden,
        activation=activation,
    )