fixing batchnorm 1d input (NVIDIA#590)
jjsjann123 authored and mcarilli committed Nov 6, 2019
1 parent 606c3dc commit 37cdaf4
Showing 3 changed files with 21 additions and 2 deletions.
4 changes: 2 additions & 2 deletions apex/parallel/optimized_sync_batchnorm.py
@@ -71,7 +71,7 @@ def forward(self, input, z = None):
         # if input.dim() == 2, we switch to channel_last for efficient memory accessing
         channel_last = self.channel_last if input.dim() != 2 else True
 
-        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
+        if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
@@ -82,4 +82,4 @@ def forward(self, input, z = None):
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:
                     exponential_average_factor = self.momentum
-            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
+            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)
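For context, here is a minimal standalone sketch (illustrative only, not part of the commit) of the flag resolution both hunks above rely on: a 2-D (N, C) input is always handled as channel-last, so the inference fall-back check and the SyncBatchnormFunction.apply call must read the local channel_last variable rather than the constructor-time self.channel_last, which may still be False for such input. The helper name resolve_channel_last below is hypothetical.

# Illustrative sketch, not apex code: mirrors
#   channel_last = self.channel_last if input.dim() != 2 else True
def resolve_channel_last(input_dim, configured_channel_last):
    return configured_channel_last if input_dim != 2 else True

# A layer built with channel_last=False still routes a BatchNorm1d-style
# 2-D input through the channel-last path:
assert resolve_channel_last(input_dim=2, configured_channel_last=False) is True
# Higher-dimensional inputs keep the configured setting:
assert resolve_channel_last(input_dim=4, configured_channel_last=False) is False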
18 changes: 18 additions & 0 deletions tests/distributed/synced_batchnorm/test_batchnorm1d.py
@@ -0,0 +1,18 @@
import torch
import apex

model = apex.parallel.SyncBatchNorm(4).cuda()
model.weight.data.uniform_()
model.bias.data.uniform_()
data = torch.rand((8,4)).cuda()

model_ref = torch.nn.BatchNorm1d(4).cuda()
model_ref.load_state_dict(model.state_dict())
data_ref = data.clone()

output = model(data)
output_ref = model_ref(data_ref)

assert(output.allclose(output_ref))
assert(model.running_mean.allclose(model_ref.running_mean))
assert(model.running_var.allclose(model_ref.running_var))
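The test above exercises the training path on a single GPU. As an illustrative follow-up only (not part of the commit), the same comparison could be repeated in eval() mode, since the condition fixed in the first hunk gates inference-time behaviour; this sketch assumes the same CUDA setup and reuses the model, model_ref, and data variables from the script above.

# Illustrative extension, not part of the commit: with matching running stats
# (already asserted above), eval-mode outputs should agree as well.
model.eval()
model_ref.eval()
with torch.no_grad():
    eval_output = model(data)
    eval_output_ref = model_ref(data_ref)
assert eval_output.allclose(eval_output_ref)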
1 change: 1 addition & 0 deletions tests/distributed/synced_batchnorm/unit_test.sh
@@ -1,5 +1,6 @@
 python python_single_gpu_unit_test.py
 python single_gpu_unit_test.py
+python test_batchnorm1d.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
 #beware, you need a system with at least 4 gpus to test group_size<world_size