Skip to content

Commit

Permalink
4th commit
Browse files Browse the repository at this point in the history
  • Loading branch information
NJUyued committed Sep 11, 2022
1 parent 99c9305 commit 675cc08
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions datasets/ssl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,13 @@ def get_data(self):
if self.name=='stl10':
if self.train:
dset = getattr(torchvision.datasets, self.name.upper())
dset_lb = dset(self.data_dir, split='train', folds=self.fold, download=True)
if self.fold in [0,1,2,3,4]:
dset_lb = dset(self.data_dir, split='train', folds=self.fold, download=True)
dset_lb_ulb = dset(self.data_dir, split='train+unlabeled', folds=self.fold, download=True)
else:
dset_lb = dset(self.data_dir, split='train', download=True)
dset_lb_ulb = dset(self.data_dir, split='unlabeled', download=True)
data_lb, targets_lb = dset_lb.data, dset_lb.labels
dset_lb_ulb = dset(self.data_dir, split='train+unlabeled', folds=self.fold, download=True)
data_lb_ulb, targets_lb_ulb = dset_lb_ulb.data, dset_lb_ulb.labels
return data_lb, targets_lb, data_lb_ulb, targets_lb_ulb
else:
Expand Down Expand Up @@ -185,7 +189,7 @@ def get_ssl_dset(self, num_labels, index=None, include_lb_to_ulb=True,
lbs = []
for c in range(num_classes):
idx = np.where(targets_lb == c)[0]
idx = np.random.choice(idx, len(idx), False) if num_labels==1000 else np.random.choice(idx, samples_per_class, False)
idx = np.random.choice(idx, len(idx), False) if num_labels==1000 or num_labels==5000 else np.random.choice(idx, samples_per_class, False)
temp_data = data_lb[idx]
temp_lb = targets_lb[idx]
lb_data.extend(temp_data)
Expand Down

0 comments on commit 675cc08

Please sign in to comment.