Merging in fused adam optimizer, additional DDP features tested in 18.10 #60

Merged · 17 commits · Oct 29, 2018
Changes from 1 commit
Optional cpp extension build
definitelynotmcarilli committed Oct 29, 2018
commit e8651fd047e41720195383f3b7c270f12d3cc286
16 changes: 11 additions & 5 deletions apex/parallel/distributed.py
@@ -1,6 +1,12 @@
 import torch
 # from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import apex_C
+try:
+    from apex_C import flatten
+    from apex_C import unflatten
+except ImportError:
+    print("Apex was built without --cpp_ext; falling back to Python flatten and unflatten")
+    from torch._utils import _flatten_dense_tensors as flatten
+    from torch._utils import _unflatten_dense_tensors as unflatten
 import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
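
The fallback keeps apex.parallel importable with or without the compiled extension, because torch._utils exposes the same flatten/unflatten interface. A minimal round-trip sketch of that interface, using only the Python fallback (tensor shapes are illustrative):

import torch
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

# Flatten a bucket of same-dtype tensors into one contiguous 1-D buffer,
# then split it back into pieces shaped like the originals.
bucket = [torch.randn(3), torch.randn(2, 2)]
coalesced = flatten(bucket)              # 1-D tensor with 3 + 4 elements
restored = unflatten(coalesced, bucket)  # tuple of tensors matching bucket's shapes
assert all(torch.equal(a, b) for a, b in zip(bucket, restored))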
@@ -11,7 +17,7 @@
 # apply_dist_call requires that tensors in 'bucket' are all the same type.
 def apply_flat_dist_call(bucket, call, extra_args=None):

-    coalesced = apex_C.flatten(bucket)
+    coalesced = flatten(bucket)

     if extra_args is not None:
         call(coalesced, *extra_args)
@@ -21,7 +27,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
     if call is dist.all_reduce:
         coalesced /= dist.get_world_size()

-    for buf, synced in zip(bucket, apex_C.unflatten(coalesced, bucket)):
+    for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
        buf.copy_(synced)

 def split_half_float_double(tensors):
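
As a usage illustration (not part of this diff): apply_flat_dist_call issues a single collective over the flattened bucket and copies the result back into each tensor. A sketch, assuming the function is imported from apex.parallel.distributed and a process group has already been initialized:

import torch
import torch.distributed as dist
from apex.parallel.distributed import apply_flat_dist_call  # assumed import path

# Assumes dist.init_process_group(...) has been called beforehand.
bucket = [torch.randn(10), torch.randn(5)]

# One flattened all_reduce for the whole bucket; results are copied back into
# each tensor and averaged by world size, as in the code above.
apply_flat_dist_call(bucket, dist.all_reduce)

# extra_args is forwarded to the collective, e.g. src=0 for a broadcast.
apply_flat_dist_call(bucket, dist.broadcast, extra_args=(0,))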
@@ -331,7 +337,7 @@ def allreduce_hook(*unused):
                 wrapper(param)

     def allreduce_bucket(self, bucket):
-        tensor = apex_C.flatten(bucket)
+        tensor = flatten(bucket)

         tensor_to_allreduce = tensor

@@ -359,7 +365,7 @@ def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
                                    "allreduce buffer. This is almost certainly an error.")
             self.allreduce_buffers[bucket_idx] = allreduced
         else:
-            for buf, synced in zip(bucket, apex_C.unflatten(allreduced, bucket)):
+            for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
                 buf.copy_(synced)


18 changes: 10 additions & 8 deletions setup.py
@@ -18,14 +18,16 @@
 cmdclass = {}
 ext_modules = []

-# if "--cuda_ext" in sys.argv:
-from torch.utils.cpp_extension import CppExtension, BuildExtension
-# sys.argv.remove("--cuda_ext")
-cmdclass['build_ext'] = BuildExtension
-ext_modules.append(
-    CppExtension('apex_C',
-                 ['csrc/flatten_unflatten.cpp',]))
-
+if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
+    from torch.utils.cpp_extension import BuildExtension
+    cmdclass['build_ext'] = BuildExtension
+
+if "--cpp_ext" in sys.argv:
+    from torch.utils.cpp_extension import CppExtension
+    sys.argv.remove("--cpp_ext")
+    ext_modules.append(
+        CppExtension('apex_C',
+                     ['csrc/flatten_unflatten.cpp',]))

 setup(
     name='apex',
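
With this change, the apex_C extension is compiled only when setup.py is invoked with --cpp_ext, and BuildExtension is registered for either --cpp_ext or --cuda_ext. A hypothetical post-install check (not part of this PR) to confirm which path an environment ended up on:

# If apex was installed with e.g. `python setup.py install --cpp_ext`, the
# compiled extension should import; otherwise apex.parallel falls back to the
# torch._utils implementations (printing the notice shown in distributed.py).
try:
    import apex_C
    print("apex_C C++ extension available:", apex_C.__file__)
except ImportError:
    print("apex_C not built; using the Python flatten/unflatten fallback")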