Skip to content

Commit

Permalink
support registering IterableDataset and build it from config.
Browse files Browse the repository at this point in the history
Summary:
Fix facebookresearch#4105

Pull Request resolved: facebookresearch#4180

Reviewed By: zhanghang1989

Differential Revision: D36045130

fbshipit-source-id: f2aa3bfdebf476737b6deec8eac93ef3043964b8
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed May 1, 2022
1 parent 0ad20f1 commit a5f2845
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
3 changes: 2 additions & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ C++ compilation errors from NVCC / NVRTC, or "Unsupported gpu architecture"
<br/>
A few possibilities:

1. Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in `python collect_env.py`.
1. Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in `python collect_env.py`
(download from [here](./detectron2/utils/collect_env.py)).
When they are inconsistent, you need to either install a different build of PyTorch (or build by yourself)
to match your local CUDA installation, or install a different version of CUDA to match PyTorch.

Expand Down
44 changes: 29 additions & 15 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ def get_detection_dataset_dicts(
names = [names]
assert len(names), names
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]

if isinstance(dataset_dicts[0], torchdata.Dataset):
if len(dataset_dicts) > 1:
# ConcatDataset does not work for iterable style dataset.
# We could support concat for iterable as well, but it's often
# not a good idea to concat iterables anyway.
return torchdata.ConcatDataset(dataset_dicts)
return dataset_dicts[0]

for dataset_name, dicts in zip(names, dataset_dicts):
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

Expand All @@ -250,9 +259,6 @@ def get_detection_dataset_dicts(
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
]

if isinstance(dataset_dicts[0], torchdata.Dataset):
return torchdata.ConcatDataset(dataset_dicts)

dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

has_instances = "annotations" in dataset_dicts[0]
Expand Down Expand Up @@ -351,18 +357,24 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
if sampler is None:
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
logger = logging.getLogger(__name__)
logger.info("Using training sampler {}".format(sampler_name))
if sampler_name == "TrainingSampler":
sampler = TrainingSampler(len(dataset))
elif sampler_name == "RepeatFactorTrainingSampler":
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset, cfg.DATALOADER.REPEAT_THRESHOLD
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
elif sampler_name == "RandomSubsetTrainingSampler":
sampler = RandomSubsetTrainingSampler(len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO)
if isinstance(dataset, torchdata.IterableDataset):
logger.info("Not using any sampler since the dataset is IterableDataset.")
sampler = None
else:
raise ValueError("Unknown training sampler: {}".format(sampler_name))
logger.info("Using training sampler {}".format(sampler_name))
if sampler_name == "TrainingSampler":
sampler = TrainingSampler(len(dataset))
elif sampler_name == "RepeatFactorTrainingSampler":
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset, cfg.DATALOADER.REPEAT_THRESHOLD
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
elif sampler_name == "RandomSubsetTrainingSampler":
sampler = RandomSubsetTrainingSampler(
len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
)
else:
raise ValueError("Unknown training sampler: {}".format(sampler_name))

return {
"dataset": dataset,
Expand Down Expand Up @@ -461,7 +473,9 @@ def _test_loader_from_config(cfg, dataset_name, mapper=None):
"dataset": dataset,
"mapper": mapper,
"num_workers": cfg.DATALOADER.NUM_WORKERS,
"sampler": InferenceSampler(len(dataset)),
"sampler": InferenceSampler(len(dataset))
if not isinstance(dataset, torchdata.IterableDataset)
else None,
}


Expand Down
19 changes: 18 additions & 1 deletion tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from iopath.common.file_io import LazyPath

from detectron2 import model_zoo
from detectron2.config import instantiate
from detectron2.config import get_cfg, instantiate
from detectron2.data import (
DatasetCatalog,
DatasetFromList,
MapDataset,
ToIterableDataset,
Expand Down Expand Up @@ -112,6 +113,22 @@ def test_build_iterable_dataloader_train(self):
dl = build_detection_train_loader(dataset=ds, **kwargs)
next(iter(dl))

def test_build_iterable_dataloader_from_cfg(self):
cfg = get_cfg()

class MyData(torch.utils.data.IterableDataset):
def __iter__(self):
while True:
yield 1

cfg.DATASETS.TRAIN = ["iter_data"]
DatasetCatalog.register("iter_data", lambda: MyData())
dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
next(iter(dl))

dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x)
next(iter(dl))

def _check_is_range(self, data_loader, N):
# check that data_loader produces range(N)
data = list(iter(data_loader))
Expand Down

0 comments on commit a5f2845

Please sign in to comment.