Skip to content

Commit

Permalink
add shared storage in windows (apache#8967)
Browse files Browse the repository at this point in the history
* add shared storage in windows

* fix

* lint

* fix

* fix

* fix

* fix process.h
  • Loading branch information
yajiedesign authored and piiswrong committed Dec 18, 2017
1 parent 5858d62 commit df25378
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 13 deletions.
1 change: 1 addition & 0 deletions amalgamation/amalgamation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

if platform.system() != 'Windows':
blacklist.append('windows.h')
blacklist.append('process.h')

def pprint(lst):
for item in lst:
Expand Down
10 changes: 3 additions & 7 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import os
import sys
import warnings
import numpy as np

from . import sampler as _sampler
Expand All @@ -52,7 +50,7 @@ class ConnectionWrapper(object):
NDArray via shared memory."""

def __init__(self, conn):
# Keep a reference to the connection object being wrapped; the other
# methods of this wrapper delegate to it.
self.conn = conn
self._conn = conn

def send(self, obj):
"""Send object"""
Expand All @@ -67,7 +65,8 @@ def recv(self):

def __getattr__(self, name):
"""Emulate the wrapped connection by forwarding attribute lookups to it."""
return getattr(self.conn, name)
# Read '_conn' straight from the instance dict (defaulting to None) rather
# than through normal attribute access, so that a not-yet-set '_conn' does
# not route back into __getattr__ itself.
attr = self.__dict__.get('_conn', None)
return getattr(attr, name)


class Queue(multiprocessing.queues.Queue):
Expand Down Expand Up @@ -188,9 +187,6 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
"not be specified if batch_sampler is specified.")

self._batch_sampler = batch_sampler
if num_workers > 0 and os.name == 'nt':
warnings.warn("DataLoader does not support num_workers > 0 on Windows yet.")
num_workers = 0
self._num_workers = num_workers
if batchify_fn is None:
if num_workers > 0:
Expand Down
75 changes: 69 additions & 6 deletions src/storage/cpu_shared_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#else
#include <Windows.h>
#include <process.h>
#endif // _WIN32

#include <unordered_map>
Expand Down Expand Up @@ -64,6 +67,9 @@ class CPUSharedStorageManager final : public StorageManager {
for (const auto& kv : pool_) {
FreeImpl(kv.second);
}
#ifdef _WIN32
CheckAndRealFree();
#endif
}

void Alloc(Storage::Handle* handle) override;
Expand Down Expand Up @@ -91,11 +97,18 @@ class CPUSharedStorageManager final : public StorageManager {
private:
static constexpr size_t alignment_ = 16;

std::mutex mutex_;
std::recursive_mutex mutex_;
std::mt19937 rand_gen_;
std::unordered_map<void*, Storage::Handle> pool_;
#ifdef _WIN32
std::unordered_map<void*, Storage::Handle> is_free_;
std::unordered_map<void*, HANDLE> map_handle_map_;
#endif

void FreeImpl(const Storage::Handle& handle);
#ifdef _WIN32
void CheckAndRealFree();
#endif

std::string SharedHandleToString(int shared_pid, int shared_id) {
std::stringstream name;
Expand All @@ -106,14 +119,44 @@ class CPUSharedStorageManager final : public StorageManager {
}; // class CPUSharedStorageManager

void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
int fid = -1;
bool is_new = false;
size_t size = handle->size + alignment_;
void* ptr = nullptr;
#ifdef _WIN32
LOG(FATAL) << "Shared memory is not supported on Windows yet.";
void *ptr = nullptr;
#ifdef _WIN32
CheckAndRealFree();
HANDLE map_handle = nullptr;
uint32_t error = 0;
if (handle->shared_id == -1 && handle->shared_pid == -1) {
is_new = true;
handle->shared_pid = _getpid();
for (int i = 0; i < 10; ++i) {
handle->shared_id = dis(rand_gen_);
auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
map_handle = CreateFileMapping(INVALID_HANDLE_VALUE,
NULL, PAGE_READWRITE, 0, size, filename.c_str());
if ((error = GetLastError()) == ERROR_SUCCESS) {
break;;
}
}
} else {
auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
map_handle = OpenFileMapping(FILE_MAP_READ | FILE_MAP_WRITE,
FALSE, filename.c_str());
error = GetLastError();
}

if (error != ERROR_SUCCESS && map_handle == nullptr) {
LOG(FATAL) << "Failed to open shared memory. CreateFileMapping failed with error "
<< error;
}

ptr = MapViewOfFile(map_handle, FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
CHECK_NE(ptr, (void *)0)
<< "Failed to map shared memory. MapViewOfFile failed with error " << GetLastError();
map_handle_map_[ptr] = map_handle;
#else
if (handle->shared_id == -1 && handle->shared_pid == -1) {
is_new = true;
Expand Down Expand Up @@ -153,7 +196,7 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
int count = DecrementRefCount(handle);
CHECK_GE(count, 0);
#ifdef _WIN32
LOG(FATAL) << "Shared memory is not supported on Windows yet.";
is_free_[handle.dptr] = handle;
#else
CHECK_EQ(munmap(static_cast<char*>(handle.dptr) - alignment_,
handle.size + alignment_), 0)
Expand All @@ -169,6 +212,26 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
#endif // _WIN32
}

#ifdef _WIN32
/*!
 * \brief Walk the deferred-free table and release every mapping whose
 *        counter has dropped to zero.
 *
 * Each entry of is_free_ holds a handle whose dptr sits alignment_ bytes
 * past the start of a mapped view. The std::atomic<int> stored at the start
 * of that view is read; when it is zero the view is unmapped, the matching
 * file-mapping HANDLE from map_handle_map_ is closed, and both table entries
 * are removed. Entries with a non-zero counter are left for a later call.
 * Any Win32 failure is reported through CHECK_NE.
 */
inline void CPUSharedStorageManager::CheckAndRealFree() {
  std::lock_guard<std::recursive_mutex> guard(mutex_);
  auto pending = is_free_.begin();
  while (pending != is_free_.end()) {
    // Start of the mapped view; it doubles as the key of map_handle_map_.
    char* const base = static_cast<char*>(pending->second.dptr) - alignment_;
    void* const view = base;
    // Counter located in the first bytes of the view (presumably the shared
    // reference count updated by DecrementRefCount -- confirm against caller).
    std::atomic<int>* const ref_count = reinterpret_cast<std::atomic<int>*>(base);

    if (ref_count->load() != 0) {
      // Still non-zero: keep the entry and look at the next one.
      ++pending;
      continue;
    }

    CHECK_NE(UnmapViewOfFile(view), 0)
      << "Failed to UnmapViewOfFile shared memory ";
    CHECK_NE(CloseHandle(map_handle_map_[view]), 0)
      << "Failed to CloseHandle shared memory ";
    map_handle_map_.erase(view);
    // erase() hands back the iterator following the removed element.
    pending = is_free_.erase(pending);
  }
}
#endif // _WIN32
} // namespace storage
} // namespace mxnet

Expand Down

0 comments on commit df25378

Please sign in to comment.