From 37cdaf4ad57ab4e7dd9ef13dbed7b29aa939d061 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 6 Nov 2019 15:53:25 -0800
Subject: [PATCH] fixing batchnorm 1d input (#590)

---
 apex/parallel/optimized_sync_batchnorm.py      |  4 ++--
 .../synced_batchnorm/test_batchnorm1d.py       | 18 ++++++++++++++++++
 .../distributed/synced_batchnorm/unit_test.sh  |  1 +
 3 files changed, 21 insertions(+), 2 deletions(-)
 create mode 100644 tests/distributed/synced_batchnorm/test_batchnorm1d.py

diff --git a/apex/parallel/optimized_sync_batchnorm.py b/apex/parallel/optimized_sync_batchnorm.py
index 56fd9784c..65cf5eabf 100644
--- a/apex/parallel/optimized_sync_batchnorm.py
+++ b/apex/parallel/optimized_sync_batchnorm.py
@@ -71,7 +71,7 @@ def forward(self, input, z = None):
         # if input.dim() == 2, we switch to channel_last for efficient memory accessing
         channel_last = self.channel_last if input.dim() != 2 else True
 
-        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
+        if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
@@ -82,4 +82,4 @@ def forward(self, input, z = None):
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:
                     exponential_average_factor = self.momentum
-            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
+            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)
diff --git a/tests/distributed/synced_batchnorm/test_batchnorm1d.py b/tests/distributed/synced_batchnorm/test_batchnorm1d.py
new file mode 100644
index 000000000..f35ac4739
--- /dev/null
+++ b/tests/distributed/synced_batchnorm/test_batchnorm1d.py
@@ -0,0 +1,18 @@
+import torch
+import apex
+
+model = apex.parallel.SyncBatchNorm(4).cuda()
+model.weight.data.uniform_()
+model.bias.data.uniform_()
+data = torch.rand((8,4)).cuda()
+
+model_ref = torch.nn.BatchNorm1d(4).cuda()
+model_ref.load_state_dict(model.state_dict())
+data_ref = data.clone()
+
+output = model(data)
+output_ref = model_ref(data_ref)
+
+assert(output.allclose(output_ref))
+assert(model.running_mean.allclose(model_ref.running_mean))
+assert(model.running_var.allclose(model_ref.running_var))
diff --git a/tests/distributed/synced_batchnorm/unit_test.sh b/tests/distributed/synced_batchnorm/unit_test.sh
index 281c84d47..fdf45dc60 100755
--- a/tests/distributed/synced_batchnorm/unit_test.sh
+++ b/tests/distributed/synced_batchnorm/unit_test.sh
@@ -1,5 +1,6 @@
 python python_single_gpu_unit_test.py
 python single_gpu_unit_test.py
+python test_batchnorm1d.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
 #beware, you need a system with at least 4 gpus to test group_size
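
The patch makes forward() pass the locally computed channel_last flag (forced to True for 2-d input) to SyncBatchnormFunction.apply instead of self.channel_last, and the new test_batchnorm1d.py exercises the training path for such input. A possible follow-up check, shown below as a sketch and not part of this patch, is the eval-mode path for the same 2-d input (assumes a single-GPU setup like the in-tree test); the output should still match torch.nn.BatchNorm1d:

    import torch
    import apex

    # Sketch only (not part of the patch): eval-mode check for a 2-d (BatchNorm1d-style) input.
    model = apex.parallel.SyncBatchNorm(4).cuda().eval()
    model_ref = torch.nn.BatchNorm1d(4).cuda().eval()
    model_ref.load_state_dict(model.state_dict())

    data = torch.rand((8, 4)).cuda()
    with torch.no_grad():
        # In eval mode both layers normalize with the stored running statistics,
        # so the outputs should agree within allclose tolerances.
        assert model(data).allclose(model_ref(data))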