Skip to content

Commit

Permalink
Enable generalized sharing capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 15, 2019
1 parent 7ae38bc commit 1acbfdd
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
3 changes: 2 additions & 1 deletion lenskit/algorithms/item_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba import njit, prange, objmode

from lenskit import util, matrix, DataWarning
from lenskit.sharing import in_share_context
from lenskit.util.accum import kvp_minheap_insert, kvp_minheap_sort
from . import Predictor

Expand Down Expand Up @@ -602,7 +603,7 @@ def _count_viable_targets(self, targets, rated):

def __getstate__(self):
state = dict(self.__dict__)
if '_sim_inv_' in state:
if '_sim_inv_' in state and not in_share_context():
del state['_sim_inv_']
return state

Expand Down
4 changes: 3 additions & 1 deletion lenskit/batch/_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd

from .. import util
from ..sharing import sharing_mode

_logger = logging.getLogger(__name__)
_rec_context = None
Expand Down Expand Up @@ -97,7 +98,8 @@ def predict(algo, pairs, *, n_jobs=None, **kwargs):
path = pathlib.Path(path)
os.close(fd)
_logger.debug('pre-serializing algorithm %s to %s', algo, path)
dump(algo, path)
with sharing_mode():
dump(algo, path)
algo = _AlgoKey('file', path)

nusers = pairs['user'].nunique()
Expand Down
4 changes: 3 additions & 1 deletion lenskit/batch/_recommend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..algorithms import Recommender
from .. import util
from ..sharing import sharing_mode

_logger = logging.getLogger(__name__)
_AlgoKey = namedtuple('AlgoKey', ['type', 'data'])
Expand Down Expand Up @@ -128,7 +129,8 @@ def recommend(algo, users, n, candidates=None, *, n_jobs=None, dask_result=False
path = pathlib.Path(path)
os.close(fd)
_logger.debug('pre-serializing algorithm %s to %s', rec_algo, path)
dump(rec_algo, path)
with sharing_mode():
dump(rec_algo, path)
rec_algo = _AlgoKey('file', path)

_logger.info('recommending with %s for %d users (n_jobs=%s)', astr, len(users), n_jobs)
Expand Down
31 changes: 31 additions & 0 deletions lenskit/sharing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Support for sharing and saving models and data structures.
"""

from contextlib import contextmanager

__save_mode = 'save'


@contextmanager
def sharing_mode():
"""
Context manager to tell models that pickling will be used for cross-process
sharing, not model persistence.
"""
global __save_mode
old = __save_mode
__save_mode = 'share'
try:
yield
finally:
__save_mode = old


def in_share_context():
"""
Query whether sharing mode is active. If ``True``, we are currently in a
:fun:`sharing_mode` context, which means model pickling will be used for
cross-process sharing.
"""
return __save_mode == 'share'

0 comments on commit 1acbfdd

Please sign in to comment.