Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add perturb() to allow capturing intermediate gradients #2476

Merged
merged 1 commit into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,50 @@ def __call__(self, x):
self.scope.put_variable(col, name, xs)
return True

def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
"""Add an zero-value variable ('perturbation') to the intermediate value.
The gradient of `value` would be the same as the gradient of this
perturbation variable. Therefore, if you define your loss function with
both params and perturbations as standalone arguments, you can get the
intermediate gradients of `value` by running `jax.grad` on the perturbation
argument.
Note: this is an experimental API and may be tweaked later for better
performance and usability.
At its current stage, it creates extra dummy variables that occupies extra
memory space. Use it only to debug gradients in training.
Example::
import jax
import jax.numpy as jnp
import flax.linen as nn
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(3)(x)
x = self.perturb('dense3', x)
return nn.Dense(2)(x)
def loss(params, perturbations, inputs, targets):
variables = {'params': params, 'perturbations': perturbations}
preds = model.apply(variables, inputs)
return jnp.square(preds - targets).mean()
x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
print(intm_grads['dense3']) # ==> [[-1.456924 -0.44332537 0.02422847]
# [-1.456924 -0.44332537 0.02422847]]
"""
value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
return value

def tabulate(
self,
rngs: Union[PRNGKey, RNGSequences],
Expand Down
23 changes: 23 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,29 @@ def __call__(self, x):
_, state = Foo().apply({}, 1, capture_intermediates=fn)
self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}})

def test_perturb(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(10)(x)
x = self.perturb('before_multiply', x)
x = 4 * x
x = self.perturb('after_multiply', x)
return x

def loss(params, perturbations, inputs, targets):
variables = {'params': params, 'perturbations': perturbations}
preds = Foo().apply(variables, inputs)
return jnp.square(preds - targets).mean()

x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10, ))
y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10, ))
variables = Foo().init(jax.random.PRNGKey(0), x)
pred = Foo().apply(variables, x)
intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
# activation * 4 so reverse gradient also * 4
self.assertTrue(all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']))

def test_functional_apply(self):

class Foo(nn.Module):
Expand Down