Skip to content

Commit

Permalink
lazy calc offset_per_key, length_per_key (#1395)
Browse files Browse the repository at this point in the history
Summary:

Copy of D49308714 + jit unwrap optional

Reviewed By: jiaqizhai

Differential Revision:
D49338062

Privacy Context Container: L1138451
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Sep 16, 2023
1 parent cdd9f20 commit a82727d
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,7 @@ def __init__(
self._offsets: Optional[torch.Tensor] = offsets

self._stride_per_key_per_rank: List[List[int]] = []
self._stride_per_key: List[int] = []
self._variable_stride_per_key: bool = False
self._stride: int = -1

Expand All @@ -1092,6 +1093,7 @@ def __init__(
self._stride = 0
elif all(s == self.stride_per_key()[0] for s in self.stride_per_key()):
self._stride = self.stride_per_key()[0]
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
else:
if torch.jit.is_tracing():
stride = _maybe_compute_stride_kjt_scripted(
Expand All @@ -1101,6 +1103,7 @@ def __init__(
stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets)
self._stride = stride
self._stride_per_key_per_rank = [[stride]] * len(self._keys)
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]

# lazy fields
self._length_per_key: Optional[List[int]] = length_per_key
Expand Down Expand Up @@ -1370,7 +1373,7 @@ def stride(self) -> int:
return self._stride

def stride_per_key(self) -> List[int]:
return [sum(stride) for stride in self._stride_per_key_per_rank]
return self._stride_per_key

def stride_per_key_per_rank(self) -> List[List[int]]:
return self._stride_per_key_per_rank
Expand All @@ -1387,6 +1390,9 @@ def _key_indices(self) -> Dict[str, int]:
return _index_per_key

def length_per_key(self) -> List[int]:
if self._length_per_key is not None:
return torch.jit._unwrap_optional(self._length_per_key)

_length_per_key = _maybe_compute_length_per_key(
keys=self._keys,
stride=self.stride(),
Expand All @@ -1403,6 +1409,9 @@ def length_per_key_or_none(self) -> Optional[List[int]]:
return self._length_per_key

def offset_per_key(self) -> List[int]:
if self._offset_per_key is not None:
return torch.jit._unwrap_optional(self._offset_per_key)

_length_per_key, _offset_per_key = _maybe_compute_offset_per_key(
keys=self._keys,
stride=self.stride(),
Expand Down

0 comments on commit a82727d

Please sign in to comment.