Skip to content

Commit

Permalink
Type fixing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622933698
  • Loading branch information
Sonnet Contributor authored and copybara-github committed Apr 8, 2024
1 parent 6cc140e commit 6d59725
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sonnet/src/moving_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ============================================================================
"""Exponential moving average for Sonnet."""

from typing import Optional
from typing import Optional, cast

from sonnet.src import metrics
from sonnet.src import once
Expand Down Expand Up @@ -61,8 +61,8 @@ def __init__(self, decay: types.FloatLike, name: Optional[str] = None):
self._counter = tf.Variable(
0, trainable=False, dtype=tf.int64, name="counter")

self._hidden = None
self.average = None
self._hidden: tf.Variable = cast(tf.Variable, None)
self.average: tf.Variable = cast(tf.Variable, None)

def update(self, value: tf.Tensor):
"""Applies EMA to the value given."""
Expand All @@ -82,8 +82,10 @@ def value(self) -> tf.Tensor:
def reset(self):
"""Resets the EMA."""
self._counter.assign(tf.zeros_like(self._counter))
self._hidden.assign(tf.zeros_like(self._hidden))
self.average.assign(tf.zeros_like(self.average))
if self._hidden is not None:
self._hidden.assign(tf.zeros_like(self._hidden))
if self.average is not None:
self.average.assign(tf.zeros_like(self.average))

@once.once
def initialize(self, value: tf.Tensor):
Expand Down

0 comments on commit 6d59725

Please sign in to comment.