# Owner(s): ["module: dynamo"]
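# Smoke tests for dynamo's handling of torchrec's KeyedJaggedTensor (KJT);
# everything below is skipped when torchrec is not installed.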
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


class BucketizeMod(torch.nn.Module):
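    """Bucketize each feature's jagged weights against fixed boundaries and
    gather a learned per-bucket weight, returning a KJT with the new weights."""
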
def __init__(self, feature_boundaries: Dict[str, List[float]]):
super().__init__()
self.bucket_w = torch.nn.ParameterDict()
self.boundaries_dict = {}
for key, boundaries in feature_boundaries.items():
self.bucket_w[key] = torch.nn.Parameter(
torch.empty([len(boundaries) + 1]).fill_(1.0),
requires_grad=True,
)
buf = torch.tensor(boundaries, requires_grad=False)
self.register_buffer(
f"{key}_boundaries",
buf,
persistent=False,
)
            self.boundaries_dict[key] = buf

def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
weights_list = []
for key, boundaries in self.boundaries_dict.items():
jt = features[key]
bucketized = torch.bucketize(jt.weights(), boundaries)
            # The exact hashing scheme doesn't matter for this test, so the
            # internal index-hash op stays commented out and the raw bucket
            # index is used directly.
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
weights_list.append(weights)
return KeyedJaggedTensor(
keys=features.keys(),
values=features.values(),
weights=torch.cat(weights_list),
lengths=features.lengths(),
offsets=features.offsets(),
stride=features.stride(),
length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
# Treat KJT length ints as unbacked, size-like symbols during tracing rather
# than specializing on them.
@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class TorchRecTests(TestCase):
def test_pooled(self):
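        # Compile a toy pooled-embedding lookup over a KJT and check that
        # batches with different jagged lengths all reuse a single graph.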
tables = [
(nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
(nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
(nn.EmbeddingBag(2000, 8), ["b2"]),
]
embedding_groups = {
"a": ["a0", "a1"],
"b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

@torch.compile(backend=counter, fullgraph=True, dynamic=True)
def f(id_list_features: KeyedJaggedTensor):
id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
pooled_embeddings = {}
# TODO: run feature processor
for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    jt = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        jt.values(), jt.offsets()
                    )

pooled_embeddings_by_group = {}
for group_name, group_embedding_names in embedding_groups.items():
group_embeddings = [
pooled_embeddings[name] for name in group_embedding_names
]
pooled_embeddings_by_group[group_name] = torch.cat(
group_embeddings, dim=1
)
            return pooled_embeddings_by_group

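        # RandomRecDataset yields batches whose .sparse_features is a
        # KeyedJaggedTensor.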
dataset = RandomRecDataset(
keys=["a0", "a1", "b0", "b1", "b2"],
batch_size=4,
hash_size=2000,
ids_per_feature=3,
num_dense=0,
)
        di = iter(dataset)

        # Unsynced KJTs (without cached per-key length metadata) should work
d1 = next(di).sparse_features.unsync()
d2 = next(di).sparse_features.unsync()
d3 = next(di).sparse_features.unsync()
r1 = f(d1)
r2 = f(d2)
r3 = f(d3)
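        # One frame for all three batches: dynamic shapes avoid recompilation.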
self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # Synced KJTs (with cached per-key length metadata) should work too
d1 = next(di).sparse_features.sync()
d2 = next(di).sparse_features.sync()
d3 = next(di).sparse_features.sync()
r1 = f(d1)
r2 = f(d2)
r3 = f(d3)
        self.assertEqual(counter.frame_count, 1)

        # export only works with unsynced KJTs
gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
gm.print_readable()
self.assertEqual(gm(d1), r1)
self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

def test_bucketize(self):
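        # Export BucketizeMod through dynamo as an aten-level graph; the
        # jagged weights have data-dependent sizes.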
mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
features = KeyedJaggedTensor.from_lengths_sync(
keys=["f1"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

def f(x):
            # Calling to_dict() is a trick to populate the KJT's computed
            # cache and to mark the resulting values as size-like in the
            # ShapeEnv.
x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

@unittest.expectedFailure
def test_simple(self):
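        # sync() eagerly computes and caches the per-key lengths consumed by
        # length_per_key() below.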
jag_tensor1 = KeyedJaggedTensor(
values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
keys=["index_0", "index_1"],
lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
).sync()
# ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True)
def f(jag_tensor):
# The indexing here requires more symbolic reasoning
# and doesn't work right now
return jag_tensor["index_0"].values().sum()
f(jag_tensor1)
        self.assertEqual(counter.frame_count, 1)

jag_tensor2 = KeyedJaggedTensor(
values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
keys=["index_0", "index_1"],
lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
).sync()
f(jag_tensor2)
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

run_tests()