Skip to content

Commit

Permalink
Add reverse module for ComputeKJTToJTDict to combine jt_dict to kjt (pytorch#1399)
Browse files Browse the repository at this point in the history

Summary:

so that we don't need to fx-wrap KeyedJaggedTensor.from_jt_dict(jt_dict) manually everywhere. Also, based on this we can do graph pattern matching to cancel (ComputeKJTToJTDict, ComputeJTDictToKJT) pairs during publish to save compute cycles. (see next diff in the stack)

Reviewed By: houseroad, YazhiGao

Differential Revision:
D49423522

Privacy Context Container: 314155190942957
  • Loading branch information
wpc authored and facebook-github-bot committed Sep 20, 2023
1 parent c1602e4 commit 1d44fad
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
30 changes: 30 additions & 0 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,36 @@ def forward(
)


class ComputeJTDictToKJT(torch.nn.Module):
    """Module wrapper around ``KeyedJaggedTensor.from_jt_dict``: combines a dict
    of ``JaggedTensor``s back into a single ``KeyedJaggedTensor``.

    This is the reverse of ``ComputeKJTToJTDict``.

    Example:
        Passing in a jt_dict::

            {
                "Feature0": JaggedTensor([[V0,V1],None,V2]),
                "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]),
            }

        returns a kjt with content::

            #                0       1        2  <-- dim_1
            # "Feature0"   [V0,V1]  None    [V2]
            # "Feature1"   [V3]     [V4]    [V5,V6,V7]
            #   ^
            #  dim_0
    """

    def forward(self, jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
        """Combine the given dict of JaggedTensors into one KeyedJaggedTensor.

        Args:
            jt_dict: mapping from feature name to its JaggedTensor.

        Returns:
            KeyedJaggedTensor built from the dict entries.
        """
        return KeyedJaggedTensor.from_jt_dict(jt_dict)


@torch.fx.wrap
def _maybe_compute_kjt_to_jt_dict(
stride: int,
Expand Down
32 changes: 32 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
ComputeJTDictToKJT,
ComputeKJTToJTDict,
JaggedTensor,
jt_is_equal,
Expand Down Expand Up @@ -707,6 +708,37 @@ def test_pytree(self) -> None:
self.assertTrue(torch.equal(j0.weights(), j1.weights()))
self.assertTrue(torch.equal(j0.values(), j1.values()))

def test_compute_jt_dict_to_kjt_module(self) -> None:
    """Round-trips a KJT through to_dict() and ComputeJTDictToKJT, then checks
    that the per-key lengths, weights, and values survive unchanged."""
    compute_jt_dict_to_kjt = ComputeJTDictToKJT()
    values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])
    keys = ["index_0", "index_1"]
    # 6 offsets -> 3 bags per key: index_0 gets [0,2,2,3], index_1 gets [3,4,5,8]
    offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

    jag_tensor = KeyedJaggedTensor(
        values=values,
        keys=keys,
        offsets=offsets,
        weights=weights,
    )
    jag_tensor_dict = jag_tensor.to_dict()
    kjt = compute_jt_dict_to_kjt(jag_tensor_dict)
    j0 = kjt["index_0"]
    j1 = kjt["index_1"]

    self.assertTrue(isinstance(j0, JaggedTensor))
    # Fixed: originally this line re-checked j0 instead of j1.
    self.assertTrue(isinstance(j1, JaggedTensor))
    self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1])))
    self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5])))
    self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0])))
    self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3])))
    self.assertTrue(
        torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5]))
    )
    self.assertTrue(
        torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
    )

def test_from_jt_dict(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])
Expand Down

0 comments on commit 1d44fad

Please sign in to comment.