Skip to content

Commit

Permalink
Fix tf tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Feb 21, 2022
1 parent 632c0f7 commit d9a0b73
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_activations_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
class TestTFActivations(unittest.TestCase):

def test_gelu_10(self):
x = tf.constant([-100, -1, -0.1, 0, 0.1, 1.0, 100])
x = tf.constant([-100, -1., -0.1, 0, 0.1, 1.0, 100.])
gelu = get_tf_activation("gelu")
gelu10 = get_tf_activation("gelu_10")

y_gelu = gelu(x)
y_gelu_10 = gelu10(x)

clipped_mask = tf.where(y_gelu_10 < 10.0, 1, 0)
clipped_mask = tf.where(y_gelu_10 < 10.0, 1., 0.)

self.assertEqual(tf.math.max(y_gelu_10).numpy().item(), 10.0)
self.assertEqual(tf.math.reduce_max(y_gelu_10).numpy().item(), 10.0)
self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))

def test_get_activation(self):
Expand Down

0 comments on commit d9a0b73

Please sign in to comment.