
Merge pull request #2476 from IvyZX:perturb
PiperOrigin-RevId: 476490270
Flax Authors committed Sep 23, 2022
2 parents b344022 + 8c0a60a commit cc88a73
Showing 2 changed files with 67 additions and 0 deletions.
44 changes: 44 additions & 0 deletions flax/linen/module.py
@@ -1462,6 +1462,50 @@ def __call__(self, x):
    self.scope.put_variable(col, name, xs)
    return True

  def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
    """Add a zero-valued variable ('perturbation') to the intermediate value.

    The gradient of `value` would be the same as the gradient of this
    perturbation variable. Therefore, if you define your loss function with
    both params and perturbations as standalone arguments, you can get the
    intermediate gradients of `value` by running `jax.grad` on the perturbation
    argument.

    Note: this is an experimental API and may be tweaked later for better
    performance and usability. At its current stage, it creates extra dummy
    variables that occupy extra memory space. Use it only to debug gradients
    in training.

    Example::

      import jax
      import jax.numpy as jnp
      import flax.linen as nn

      class Foo(nn.Module):
        @nn.compact
        def __call__(self, x):
          x = nn.Dense(3)(x)
          x = self.perturb('dense3', x)
          return nn.Dense(2)(x)

      def loss(params, perturbations, inputs, targets):
        variables = {'params': params, 'perturbations': perturbations}
        preds = model.apply(variables, inputs)
        return jnp.square(preds - targets).mean()

      x = jnp.ones((2, 9))
      y = jnp.ones((2, 2))
      model = Foo()
      variables = model.init(jax.random.PRNGKey(0), x)
      intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
      print(intm_grads['dense3'])  # ==> [[-1.456924   -0.44332537  0.02422847]
                                   #      [-1.456924   -0.44332537  0.02422847]]
    """
    value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
    return value

  def tabulate(
      self,
      rngs: Union[PRNGKey, RNGSequences],
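For context on the new API, here is a minimal standalone sketch (not part of this diff; the module name `MLP` and perturbation name `hidden` are illustrative) of what `Module.perturb` adds to the variable dict at init time: each perturbation is a zero-valued array shaped like the intermediate it wraps.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(4)(x)
        x = self.perturb('hidden', x)  # zero-valued variable, same shape as x
        return nn.Dense(2)(x)

    x = jnp.ones((3, 8))
    variables = MLP().init(jax.random.PRNGKey(0), x)
    # The 'perturbations' collection holds zeros_like copies of the wrapped intermediates.
    print(jax.tree_util.tree_map(jnp.shape, variables['perturbations']))
    # e.g. {'hidden': (3, 4)}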
23 changes: 23 additions & 0 deletions tests/linen/linen_module_test.py
@@ -1417,6 +1417,29 @@ def __call__(self, x):
    _, state = Foo().apply({}, 1, capture_intermediates=fn)
    self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}})

  def test_perturb(self):
    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(10)(x)
        x = self.perturb('before_multiply', x)
        x = 4 * x
        x = self.perturb('after_multiply', x)
        return x

    def loss(params, perturbations, inputs, targets):
      variables = {'params': params, 'perturbations': perturbations}
      preds = Foo().apply(variables, inputs)
      return jnp.square(preds - targets).mean()

    x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10, ))
    y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10, ))
    variables = Foo().init(jax.random.PRNGKey(0), x)
    pred = Foo().apply(variables, x)
    intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
    # The activation is scaled by 4, so the gradient flowing back through it is also scaled by 4.
    self.assertTrue(all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']))

  def test_functional_apply(self):

    class Foo(nn.Module):
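As a further illustration (not part of this diff; the module name `Probe`, collection name `grad_probes`, and variable name `dense_out` are hypothetical), the `collection` argument added in this change lets the zero-valued probes live under a custom collection name instead of the default `'perturbations'`:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Probe(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(5)(x)
        # Store the zero-valued probe under 'grad_probes' instead of 'perturbations'.
        x = self.perturb('dense_out', x, collection='grad_probes')
        return nn.Dense(1)(x)

    def loss(params, probes, inputs):
      preds = Probe().apply({'params': params, 'grad_probes': probes}, inputs)
      return jnp.square(preds).mean()

    x = jnp.ones((4, 3))
    variables = Probe().init(jax.random.PRNGKey(0), x)
    probe_grads = jax.grad(loss, argnums=1)(variables['params'], variables['grad_probes'], x)
    # probe_grads['dense_out'] has shape (4, 5): the gradient of the loss w.r.t. the Dense(5) output.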
