Skip to content

Commit

Permalink
transforms: augmentation level from cmdline
Browse files Browse the repository at this point in the history
  • Loading branch information
akors committed May 3, 2023
1 parent 7d2eb36 commit 3e7eed0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
7 changes: 4 additions & 3 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def make_datasets(datadir: str="./data/", years=["2012"], augment_level: int=0):
ds_train_list = list()
ds_val_list = list()


tr_train = transforms.make_transforms(
transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD, augment=(augment_level > 0))
transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD, augment_level=augment_level)

# validation transforms dont get data augmentation
tr_val = transforms.make_transforms(
transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD, augment=False)
transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD, augment_level=0)

for year in years:
# create datasets with our transforms. assume they're already downloaded
Expand Down
9 changes: 6 additions & 3 deletions transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def do(self, x):
x = torchvision.datapoints.Mask(torch.where(x <= 20, x, 0))
return x

def make_transforms(mean, std, augment=False):
def make_transforms(mean, std, augment_level=0):
# apply anti-aliasing for resize operations, this will be skipped automagically for masks of type
# torchvision.datapoints.Mask
antialias = True
Expand All @@ -82,11 +82,14 @@ def make_transforms(mean, std, augment=False):

oplist.append(T.ToImageTensor())

if not augment:
if not augment_level:
oplist.append(T.Resize(size=256, antialias=antialias))
oplist.append(T.CenterCrop(256))
else:

if augment_level >= 1:
oplist.append(T.RandomResizedCrop(size=256, scale=(0.3, 1.0), ratio=(1,1), antialias=antialias))

assert augment_level <= 1, "Augmentation level "+str(augment_level)+"?? What is this, the future??"

oplist.append(T.ConvertImageDtype(torch.float32))
oplist.append(T.Normalize(mean=mean, std=std))
Expand Down

0 comments on commit 3e7eed0

Please sign in to comment.