Skip to content

Commit

Permalink
lazy calc offset_per_key, length_per_key (#1395)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1395

Differential Revision:
D49338062

Privacy Context Container: L1138451
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Sep 15, 2023
1 parent cdd9f20 commit 206737a
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,9 @@ def _key_indices(self) -> Dict[str, int]:
return _index_per_key

def length_per_key(self) -> List[int]:
if self._length_per_key is not None:
return torch.jit._unwrap_optional(self._length_per_key)

_length_per_key = _maybe_compute_length_per_key(
keys=self._keys,
stride=self.stride(),
Expand All @@ -1403,6 +1406,9 @@ def length_per_key_or_none(self) -> Optional[List[int]]:
return self._length_per_key

def offset_per_key(self) -> List[int]:
if self._offset_per_key is not None:
return torch.jit._unwrap_optional(self._offset_per_key)

_length_per_key, _offset_per_key = _maybe_compute_offset_per_key(
keys=self._keys,
stride=self.stride(),
Expand Down

0 comments on commit 206737a

Please sign in to comment.