Skip to content

Commit

Permalink
style.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Feb 21, 2022
1 parent d9a0b73 commit 2bdfc0c
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/test_activations_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@

@require_tf
class TestTFActivations(unittest.TestCase):

def test_gelu_10(self):
x = tf.constant([-100, -1., -0.1, 0, 0.1, 1.0, 100.])
x = tf.constant([-100, -1.0, -0.1, 0, 0.1, 1.0, 100.0])
gelu = get_tf_activation("gelu")
gelu10 = get_tf_activation("gelu_10")

y_gelu = gelu(x)
y_gelu_10 = gelu10(x)

clipped_mask = tf.where(y_gelu_10 < 10.0, 1., 0.)
clipped_mask = tf.where(y_gelu_10 < 10.0, 1.0, 0.0)

self.assertEqual(tf.math.reduce_max(y_gelu_10).numpy().item(), 10.0)
self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
Expand Down

0 comments on commit 2bdfc0c

Please sign in to comment.