Skip to content

Commit

Permalink
Add jax CNN model snippet files
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Nov 4, 2023
1 parent b6b7011 commit 30f67c4
Show file tree
Hide file tree
Showing 4 changed files with 409 additions and 0 deletions.
97 changes: 97 additions & 0 deletions docs/source/snippets/categorical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,103 @@ def compute(self, inputs, role):
unnormalized_log_prob=True)
# [end-cnn-functional-torch]

# [start-cnn-setup-jax]
import flax.linen as nn

from skrl.models.jax import Model, CategoricalMixin


# define the model
class CNN(CategoricalMixin, Model):
def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
CategoricalMixin.__init__(self, unnormalized_log_prob)

def setup(self):
self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
self.fc1 = nn.Dense(512)
self.fc2 = nn.Dense(16)
self.fc3 = nn.Dense(64)
self.fc4 = nn.Dense(32)
self.fc5 = nn.Dense(self.num_actions)

def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = self.conv1(x)
x = nn.relu(x)
x = self.conv2(x)
x = nn.relu(x)
x = self.conv3(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = self.fc1(x)
x = nn.relu(x)
x = self.fc2(x)
x = nn.tanh(x)
x = self.fc3(x)
x = nn.tanh(x)
x = self.fc4(x)
x = nn.tanh(x)
x = self.fc5(x)
return x, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
unnormalized_log_prob=True)

# initialize model's state dict
policy.init_state_dict("policy")
# [end-cnn-setup-jax]

# [start-cnn-compact-jax]
import flax.linen as nn

from skrl.models.jax import Model, CategoricalMixin


# define the model
class CNN(CategoricalMixin, Model):
def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
CategoricalMixin.__init__(self, unnormalized_log_prob)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = nn.Dense(512)(x)
x = nn.relu(x)
x = nn.Dense(16)(x)
x = nn.tanh(x)
x = nn.Dense(64)(x)
x = nn.tanh(x)
x = nn.Dense(32)(x)
x = nn.tanh(x)
x = nn.Dense(self.num_actions)(x)
return x, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
unnormalized_log_prob=True)

# initialize model's state dict
policy.init_state_dict("policy")
# [end-cnn-compact-jax]

# =============================================================================

# [start-rnn-sequential-torch]
Expand Down
101 changes: 101 additions & 0 deletions docs/source/snippets/deterministic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,107 @@ def compute(self, inputs, role):
clip_actions=False)
# [end-cnn-functional-torch]

# [start-cnn-setup-jax]
import jax.numpy as jnp
import flax.linen as nn

from skrl.models.jax import Model, DeterministicMixin


# define the model
class CNN(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def setup(self):
self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
self.fc1 = nn.Dense(512)
self.fc2 = nn.Dense(16)
self.fc3 = nn.Dense(64)
self.fc4 = nn.Dense(32)
self.fc5 = nn.Dense(1)

def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = self.conv1(x)
x = nn.relu(x)
x = self.conv2(x)
x = nn.relu(x)
x = self.conv3(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = self.fc1(x)
x = nn.relu(x)
x = self.fc2(x)
x = nn.tanh(x)
x = jnp.concatenate([x, inputs["taken_actions"]], axis=-1)
x = self.fc3(x)
x = nn.tanh(x)
x = self.fc4(x)
x = nn.tanh(x)
x = self.fc5(x)
return x, {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=False)

# initialize model's state dict
critic.init_state_dict("critic")
# [end-cnn-setup-jax]

# [start-cnn-compact-jax]
import jax.numpy as jnp
import flax.linen as nn

from skrl.models.jax import Model, DeterministicMixin


# define the model
class CNN(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = nn.Dense(512)(x)
x = nn.relu(x)
x = nn.Dense(16)(x)
x = nn.tanh(x)
x = jnp.concatenate([x, inputs["taken_actions"]], axis=-1)
x = nn.Dense(64)(x)
x = nn.tanh(x)
x = nn.Dense(32)(x)
x = nn.tanh(x)
x = nn.Dense(1)(x)
return x, {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=False)

# initialize model's state dict
critic.init_state_dict("critic")
# [end-cnn-compact-jax]

# =============================================================================

# [start-rnn-sequential-torch]
Expand Down
112 changes: 112 additions & 0 deletions docs/source/snippets/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,118 @@ def compute(self, inputs, role):
reduction="sum")
# [end-cnn-functional-torch]

# [start-cnn-setup-jax]
import jax.numpy as jnp
import flax.linen as nn

from skrl.models.jax import Model, GaussianMixin


# define the model
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def setup(self):
self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
self.fc1 = nn.Dense(512)
self.fc2 = nn.Dense(16)
self.fc3 = nn.Dense(64)
self.fc4 = nn.Dense(32)
self.fc5 = nn.Dense(self.num_actions)

self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))

def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = self.conv1(x)
x = nn.relu(x)
x = self.conv2(x)
x = nn.relu(x)
x = self.conv3(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = self.fc1(x)
x = nn.relu(x)
x = self.fc2(x)
x = nn.tanh(x)
x = self.fc3(x)
x = nn.tanh(x)
x = self.fc4(x)
x = nn.tanh(x)
x = self.fc5(x)
return nn.tanh(x), self.log_std_parameter, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")

# initialize model's state dict
policy.init_state_dict("policy")
# [end-cnn-setup-jax]

# [start-cnn-compact-jax]
import jax.numpy as jnp
import flax.linen as nn

from skrl.models.jax import Model, GaussianMixin


# define the model
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = nn.Dense(512)(x)
x = nn.relu(x)
x = nn.Dense(16)(x)
x = nn.tanh(x)
x = nn.Dense(64)(x)
x = nn.tanh(x)
x = nn.Dense(32)(x)
x = nn.tanh(x)
x = nn.Dense(self.num_actions)(x)
log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
return nn.tanh(x), log_std_parameter, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")

# initialize model's state dict
policy.init_state_dict("policy")
# [end-cnn-compact-jax]

# =============================================================================

# [start-rnn-sequential-torch]
Expand Down
Loading

0 comments on commit 30f67c4

Please sign in to comment.