Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for criteo validation and test sets #670

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions torchrec/datasets/test_utils/criteo_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def _create_dataset_npys(
generate_dense: bool = True,
generate_sparse: bool = True,
generate_labels: bool = True,
dense: Optional[np.ndarray] = None,
sparse: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
) -> Generator[Tuple[str, ...], None, None]:
with tempfile.TemporaryDirectory() as tmpdir:

Expand All @@ -108,31 +111,34 @@ def _create_dataset_npys(

if generate_dense:
dense_path = os.path.join(tmpdir, filename + "_dense.npy")
dense = np.random.random((num_rows, INT_FEATURE_COUNT)).astype(
np.float32
)
if dense is None:
dense = np.random.random((num_rows, INT_FEATURE_COUNT)).astype(
np.float32
)
np.save(dense_path, dense)
paths.append(dense_path)

if generate_sparse:
sparse_path = os.path.join(tmpdir, filename + "_sparse.npy")
sparse = np.random.randint(
cls.CAT_VAL_RANGE[0],
cls.CAT_VAL_RANGE[1] + 1,
size=(num_rows, CAT_FEATURE_COUNT),
dtype=np.int32,
)
if sparse is None:
sparse = np.random.randint(
cls.CAT_VAL_RANGE[0],
cls.CAT_VAL_RANGE[1] + 1,
size=(num_rows, CAT_FEATURE_COUNT),
dtype=np.int32,
)
np.save(sparse_path, sparse)
paths.append(sparse_path)

if generate_labels:
labels_path = os.path.join(tmpdir, filename + "_labels.npy")
labels = np.random.randint(
cls.LABEL_VAL_RANGE[0],
cls.LABEL_VAL_RANGE[1] + 1,
size=(num_rows, 1),
dtype=np.int32,
)
if labels is None:
labels = np.random.randint(
cls.LABEL_VAL_RANGE[0],
cls.LABEL_VAL_RANGE[1] + 1,
size=(num_rows, 1),
dtype=np.int32,
)
np.save(labels_path, labels)
paths.append(labels_path)

Expand Down
70 changes: 60 additions & 10 deletions torchrec/datasets/tests/test_criteo.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,25 +346,50 @@ def _validate_batch(
)

def _test_dataset(
self, rows_per_file: List[int], batch_size: int, world_size: int
self,
rows_per_file: List[int],
batch_size: int,
world_size: int,
stage: str = "train",
) -> None:
with contextlib.ExitStack() as stack:
num_rows = sum(rows_per_file)
if stage == "train":
dense, sparse, labels = None, None, None
else:
dense = np.mgrid[0:num_rows, 0:INT_FEATURE_COUNT][0]
sparse = np.mgrid[0:num_rows, 0:CAT_FEATURE_COUNT][0]
labels = np.ones((num_rows, 1))
files = [
stack.enter_context(self._create_dataset_npys(num_rows=num_rows))
stack.enter_context(
self._create_dataset_npys(
num_rows=num_rows, dense=dense, sparse=sparse, labels=labels
)
)
for num_rows in rows_per_file
]
hashes = [i + 1 for i in range(CAT_FEATURE_COUNT)]

total_rows = sum(rows_per_file)

incomplete_last_batch_size = total_rows // world_size % batch_size
num_batches = total_rows // world_size // batch_size + (
if stage == "train":
dataset_start = 0
dataset_len = num_rows
elif stage == "val":
dataset_start = 0
dataset_len = num_rows // 2 + num_rows % 2
else:
dataset_start = num_rows // 2 + num_rows % 2
dataset_len = num_rows // 2

incomplete_last_batch_size = dataset_len // world_size % batch_size
num_batches = dataset_len // world_size // batch_size + (
incomplete_last_batch_size != 0
)

lens = []
samples_counts = []
for rank in range(world_size):
datapipe = InMemoryBinaryCriteoIterDataPipe(
stage="train",
stage=stage,
dense_paths=[f[0] for f in files],
sparse_paths=[f[1] for f in files],
labels_paths=[f[2] for f in files],
Expand All @@ -374,25 +399,50 @@ def _test_dataset(
hashes=hashes,
)
datapipe_len = len(datapipe)
self.assertEqual(datapipe_len, num_batches)

len_ = 0
for x in datapipe:
samples_count = 0
for batch in datapipe:
if stage in ["val", "test"] and len_ == 0 and rank == 0:
self.assertEqual(
batch.dense_features[0, 0].item(),
dataset_start,
)
if len_ < num_batches - 1 or incomplete_last_batch_size == 0:
self._validate_batch(x, batch_size=batch_size)
self._validate_batch(batch, batch_size=batch_size)
else:
self._validate_batch(x, batch_size=incomplete_last_batch_size)
self._validate_batch(
batch, batch_size=incomplete_last_batch_size
)
len_ += 1
samples_count += batch.dense_features.shape[0]

# Check that dataset __len__ matches true length.
self.assertEqual(datapipe_len, len_)
lens.append(len_)
self.assertEqual(samples_count, dataset_len // world_size)
samples_counts.append(samples_count)

# Ensure all ranks' datapipes return the same number of batches.
self.assertEqual(len(set(lens)), 1)
self.assertEqual(len(set(samples_counts)), 1)

def test_dataset_small_files(self) -> None:
    """Many tiny files: 20 one-row files, batch size 4, two ranks."""
    rows_per_file = [1 for _ in range(20)]
    self._test_dataset(rows_per_file, batch_size=4, world_size=2)

def test_dataset_random_sized_files(self) -> None:
    """100 files with random row counts (seeded so the run is reproducible)."""
    random.seed(0)
    sizes = [random.randint(1, 100) for _ in range(100)]
    self._test_dataset(sizes, batch_size=16, world_size=3)

def test_dataset_val_and_test_sets(self) -> None:
    """Run every stage through the batching edge cases.

    Each (rows_per_file, batch_size, world_size) tuple covers one
    divisibility scenario: batch_size evenly dividing the per-rank
    dataset length, a single incomplete batch, and full batches
    followed by a trailing incomplete batch — with both even and odd
    total row counts.
    """
    cases = [
        # batch_size evenly divides dataset_len
        ([100], 1, 2),
        ([101], 1, 2),
        # first and only batch is incomplete
        ([100], 32, 8),
        ([101], 32, 8),
        # full batches followed by an incomplete last batch
        ([10000], 128, 8),
        ([10001], 128, 8),
    ]
    for stage in ("train", "val", "test"):
        for rows_per_file, batch_size, world_size in cases:
            self._test_dataset(
                rows_per_file, batch_size, world_size, stage=stage
            )