-
Notifications
You must be signed in to change notification settings - Fork 744
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
39f99ac
commit 2089b02
Showing
4 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .repeat_aug import RepeatAugSampler | ||
|
||
__all__ = ['RepeatAugSampler'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import math | ||
from typing import Iterator, Optional, Sized | ||
|
||
import torch | ||
from mmcls.registry import DATA_SAMPLERS | ||
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed | ||
from torch.utils.data import Sampler | ||
|
||
|
||
@DATA_SAMPLERS.register_module() | ||
class RepeatAugSampler(Sampler): | ||
"""Sampler that restricts data loading to a subset of the dataset for | ||
distributed, with repeated augmentation. It ensures that different each | ||
augmented version of a sample will be visible to a different process (GPU). | ||
Heavily based on torch.utils.data.DistributedSampler. | ||
This sampler was taken from | ||
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py | ||
Used in | ||
Copyright (c) 2015-present, Facebook, Inc. | ||
Args: | ||
dataset (Sized): The dataset. | ||
shuffle (bool): Whether shuffle the dataset or not. Defaults to True. | ||
num_repeats (int): The repeat times of every sample. Defaults to 3. | ||
seed (int, optional): Random seed used to shuffle the sampler if | ||
:attr:`shuffle=True`. This number should be identical across all | ||
processes in the distributed group. Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
dataset: Sized, | ||
shuffle: bool = True, | ||
num_repeats: int = 3, | ||
seed: Optional[int] = None): | ||
rank, world_size = get_dist_info() | ||
self.rank = rank | ||
self.world_size = world_size | ||
|
||
self.dataset = dataset | ||
self.shuffle = shuffle | ||
if not self.shuffle and is_main_process(): | ||
from mmengine.logging import MMLogger | ||
logger = MMLogger.get_current_instance() | ||
logger.warning('The RepeatAugSampler always picks a ' | ||
'fixed part of data if `shuffle=False`.') | ||
|
||
if seed is None: | ||
seed = sync_random_seed() | ||
self.seed = seed | ||
self.epoch = 0 | ||
self.num_repeats = num_repeats | ||
|
||
# The number of repeated samples in the rank | ||
self.num_samples = math.ceil( | ||
len(self.dataset) * num_repeats / world_size) | ||
# The total number of repeated samples in all ranks. | ||
self.total_size = self.num_samples * world_size | ||
# The number of selected samples in the rank | ||
self.num_selected_samples = math.ceil(len(self.dataset) / world_size) | ||
|
||
def __iter__(self) -> Iterator[int]: | ||
"""Iterate the indices.""" | ||
# deterministically shuffle based on epoch and seed | ||
if self.shuffle: | ||
g = torch.Generator() | ||
g.manual_seed(self.seed + self.epoch) | ||
indices = torch.randperm(len(self.dataset), generator=g).tolist() | ||
else: | ||
indices = list(range(len(self.dataset))) | ||
|
||
# produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] | ||
indices = [x for x in indices for _ in range(self.num_repeats)] | ||
# add extra samples to make it evenly divisible | ||
padding_size = self.total_size - len(indices) | ||
indices += indices[:padding_size] | ||
assert len(indices) == self.total_size | ||
|
||
# subsample per rank | ||
indices = indices[self.rank:self.total_size:self.world_size] | ||
assert len(indices) == self.num_samples | ||
|
||
# return up to num selected samples | ||
return iter(indices[:self.num_selected_samples]) | ||
|
||
def __len__(self) -> int: | ||
"""The number of samples in this rank.""" | ||
return self.num_selected_samples | ||
|
||
def set_epoch(self, epoch: int) -> None: | ||
"""Sets the epoch for this sampler. | ||
When :attr:`shuffle=True`, this ensures all replicas use a different | ||
random ordering for each epoch. Otherwise, the next iteration of this | ||
sampler will yield the same ordering. | ||
Args: | ||
epoch (int): Epoch number. | ||
""" | ||
self.epoch = epoch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import math | ||
from unittest import TestCase | ||
from unittest.mock import patch | ||
|
||
import torch | ||
from mmengine.logging import MMLogger | ||
|
||
from mmocr.datasets import RepeatAugSampler | ||
|
||
file = 'mmocr.datasets.samplers.repeat_aug.' | ||
|
||
|
||
class MockDist: | ||
|
||
def __init__(self, dist_info=(0, 1), seed=7): | ||
self.dist_info = dist_info | ||
self.seed = seed | ||
|
||
def get_dist_info(self): | ||
return self.dist_info | ||
|
||
def sync_random_seed(self): | ||
return self.seed | ||
|
||
def is_main_process(self): | ||
return self.dist_info[0] == 0 | ||
|
||
|
||
class TestRepeatAugSampler(TestCase): | ||
|
||
def setUp(self): | ||
self.data_length = 100 | ||
self.dataset = list(range(self.data_length)) | ||
|
||
@patch(file + 'get_dist_info', return_value=(0, 1)) | ||
def test_non_dist(self, mock): | ||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False) | ||
self.assertEqual(sampler.world_size, 1) | ||
self.assertEqual(sampler.rank, 0) | ||
self.assertEqual(sampler.total_size, self.data_length * 3) | ||
self.assertEqual(sampler.num_samples, self.data_length * 3) | ||
self.assertEqual(sampler.num_selected_samples, self.data_length) | ||
self.assertEqual(len(sampler), sampler.num_selected_samples) | ||
indices = [x for x in range(self.data_length) for _ in range(3)] | ||
self.assertEqual(list(sampler), indices[:self.data_length]) | ||
|
||
logger = MMLogger.get_current_instance() | ||
with self.assertLogs(logger, 'WARN') as log: | ||
sampler = RepeatAugSampler(self.dataset, shuffle=False) | ||
self.assertIn('always picks a fixed part', log.output[0]) | ||
|
||
@patch(file + 'get_dist_info', return_value=(2, 3)) | ||
@patch(file + 'is_main_process', return_value=False) | ||
def test_dist(self, mock1, mock2): | ||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False) | ||
self.assertEqual(sampler.world_size, 3) | ||
self.assertEqual(sampler.rank, 2) | ||
self.assertEqual(sampler.num_samples, self.data_length) | ||
self.assertEqual(sampler.total_size, self.data_length * 3) | ||
self.assertEqual(sampler.num_selected_samples, | ||
math.ceil(self.data_length / 3)) | ||
self.assertEqual(len(sampler), sampler.num_selected_samples) | ||
indices = [x for x in range(self.data_length) for _ in range(3)] | ||
self.assertEqual( | ||
list(sampler), indices[2::3][:sampler.num_selected_samples]) | ||
|
||
logger = MMLogger.get_current_instance() | ||
with patch.object(logger, 'warning') as mock_log: | ||
sampler = RepeatAugSampler(self.dataset, shuffle=False) | ||
mock_log.assert_not_called() | ||
|
||
@patch(file + 'get_dist_info', return_value=(0, 1)) | ||
@patch(file + 'sync_random_seed', return_value=7) | ||
def test_shuffle(self, mock1, mock2): | ||
# test seed=None | ||
sampler = RepeatAugSampler(self.dataset, seed=None) | ||
self.assertEqual(sampler.seed, 7) | ||
|
||
# test random seed | ||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0) | ||
sampler.set_epoch(10) | ||
g = torch.Generator() | ||
g.manual_seed(10) | ||
indices = torch.randperm(len(self.dataset), generator=g).tolist() | ||
indices = [x for x in indices | ||
for _ in range(3)][:sampler.num_selected_samples] | ||
self.assertEqual(list(sampler), indices) | ||
|
||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42) | ||
sampler.set_epoch(10) | ||
g = torch.Generator() | ||
g.manual_seed(42 + 10) | ||
indices = torch.randperm(len(self.dataset), generator=g).tolist() | ||
indices = [x for x in indices | ||
for _ in range(3)][:sampler.num_selected_samples] | ||
self.assertEqual(list(sampler), indices) |