Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
fix pdb LRU cache (#1545)
Browse files Browse the repository at this point in the history
  • Loading branch information
clonker committed Mar 1, 2022
1 parent 8c2bc84 commit f6a7a7d
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pyemma/coordinates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
'assign_to_centers',
]

_string_types = str
_string_types = (str, Path)

# ==============================================================================
#
Expand Down
16 changes: 10 additions & 6 deletions pyemma/coordinates/data/feature_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@


from copy import copy
from pathlib import Path

import mdtraj
import numpy as np

from pyemma._base.serialization.serialization import SerializableMixIn
from pyemma.coordinates.data._base.datasource import DataSource, EncapsulatedIterator
from pyemma.coordinates.data._base.random_accessible import RandomAccessStrategy
from pyemma.coordinates.data.featurization.featurizer import MDFeaturizer
from pyemma.coordinates.data.util.reader_utils import file_suffix
from pyemma.coordinates.data.util.traj_info_cache import TrajInfo
from pyemma.coordinates.util import patches
from pyemma.util.annotators import deprecated, fix_docs
Expand Down Expand Up @@ -91,15 +94,15 @@ def __init__(self, trajectories, topologyfile=None, chunksize=1000, featurizer=N
super(FeatureReader, self).__init__(chunksize=chunksize)
self._is_reader = True
self.topfile = topologyfile
self.filenames = copy(trajectories) # this is modified in-place in mdtraj.load
if not isinstance(trajectories, (list, tuple)):
trajectories = [trajectories]
self.filenames = copy([str(traj) for traj in trajectories]) # this is modified in-place in mdtraj.load
self._return_traj_obj = False

self._is_random_accessible = all(
(f.endswith(FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS)
for f in self.filenames)
)
self._is_random_accessible = all(file_suffix(f) in FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS
for f in self.filenames)
# check we have at least mdtraj-1.6.1 to efficiently seek xtc, trr formats
if any(f.endswith('.xtc') or f.endswith('.trr') for f in trajectories):
if any(file_suffix(f) == '.xtc' or file_suffix(f) == '.trr' for f in trajectories):
from distutils.version import LooseVersion
xtc_trr_random_accessible = True if LooseVersion(mdtraj.version.version) >= LooseVersion('1.6.1') else False
self._is_random_accessible &= xtc_trr_random_accessible
Expand Down Expand Up @@ -128,6 +131,7 @@ def trajfiles(self):
return self.filenames

def _get_traj_info(self, filename):
filename = str(filename) if isinstance(filename, Path) else filename
with mdtraj.open(filename, mode='r') as fh:
try:
length = len(fh)
Expand Down
5 changes: 3 additions & 2 deletions pyemma/coordinates/data/featurization/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


import warnings
from pathlib import Path

from pyemma._base.loggable import Loggable
from pyemma._base.serialization.serialization import SerializableMixIn
Expand Down Expand Up @@ -68,8 +69,8 @@ def topologyfile(self):
@topologyfile.setter
def topologyfile(self, topfile):
self._topologyfile = topfile
if isinstance(topfile, str):
self.topology = load_topology_cached(topfile)
if isinstance(topfile, (Path, str)):
self.topology = load_topology_cached(str(topfile))
self._topologyfile = topfile
elif isinstance(topfile, mdtraj.Topology):
self.topology = topfile
Expand Down
4 changes: 3 additions & 1 deletion pyemma/coordinates/data/h5_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def __init__(self, filenames, selection='/*', chunk_size=5000, **kw):
# and the interface of the cache does not allow for such a mapping (1:1 relation filename:(dimension, len)).
from pyemma.util.contexts import settings
with settings(use_trajectory_lengths_cache=False):
self.filenames = filenames
if not isinstance(filenames, (list, tuple)):
filenames = [filenames]
self.filenames = [str(fname) for fname in filenames]

# we need to override the ntraj attribute to be equal with the itraj_counter to respect all data sets.
self._ntraj = self._itraj_counter
Expand Down
5 changes: 3 additions & 2 deletions pyemma/coordinates/data/numpy_filereader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


import functools
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -57,12 +58,12 @@ def __init__(self, filenames, chunksize=1000, mmap_mode='r'):
filenames = [filenames]

for f in filenames:
if not f.endswith('.npy'):
if Path(f).suffix != '.npy':
raise ValueError('given file "%s" is not supported by this'
' reader, since it does not end with .npy' % f)

self.mmap_mode = mmap_mode
self.filenames = filenames
self.filenames = [str(fname) for fname in filenames]

def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None):
return NPYIterator(self, skip=skip, chunk=chunk, stride=stride,
Expand Down
6 changes: 4 additions & 2 deletions pyemma/coordinates/data/py_csv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import csv
import os
from math import ceil
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -202,8 +203,9 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#',

if isinstance(filenames, (tuple, list)):
n = len(filenames)
elif isinstance(filenames, str):
elif isinstance(filenames, (str, Path)):
n = 1
filenames = [filenames]
else:
raise TypeError("'filenames' argument has to be list, tuple or string")
self._comments = PyCSVReader.__parse_args(comments, '#', n)
Expand All @@ -216,7 +218,7 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#',

self._skip = np.zeros(n, dtype=int)
# invoke filename setter
self.filenames = filenames
self.filenames = [str(fname) for fname in filenames]

@staticmethod
def __parse_args(arg, default, n):
Expand Down
28 changes: 17 additions & 11 deletions pyemma/coordinates/data/util/reader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@
from numpy import vstack
import mdtraj as md
import numpy as np
import os


def file_suffix(path):
    r"""Return the file-name suffix (extension, including the leading dot) of *path*.

    *path* may be anything a :class:`pathlib.Path` can be constructed from
    (``str``, ``os.PathLike``, ...); an object that already is a ``Path`` is
    used as-is.
    """
    p = path if isinstance(path, Path) else Path(path)
    return p.suffix


def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
r"""
Expand Down Expand Up @@ -51,15 +57,14 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
return FragmentedTrajectoryReader(input_files, topology, chunksize, featurizer)

# normal trajectories
if (isinstance(input_files, str)
if (isinstance(input_files, (Path, str))
or (isinstance(input_files, (list, tuple))
and (any(isinstance(item, str) for item in input_files)
and (any(isinstance(item, (Path, str)) for item in input_files)
or len(input_files) == 0))):
reader = None
# check: if single string create a one-element list
if isinstance(input_files, str):
if isinstance(input_files, (Path, str)):
input_list = [input_files]
elif len(input_files) > 0 and all(isinstance(item, str) for item in input_files):
elif len(input_files) > 0 and all(isinstance(item, (Path, str)) for item in input_files):
input_list = input_files
else:
if len(input_files) == 0:
Expand All @@ -68,20 +73,21 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
raise ValueError("The passed list did not exclusively contain strings or was a list of lists "
"(fragmented trajectory).")

# TODO: this does not handle suffixes like .xyz.gz (rare)
_, suffix = os.path.splitext(input_list[0])
# convert to list of paths
input_list = [Path(f) for f in input_list]

suffix = str(suffix)
# TODO: this does not handle suffixes like .xyz.gz (rare)
suffix = input_list[0].suffix

# check: do all files have the same file type? If not: raise ValueError.
if all(item.endswith(suffix) for item in input_list):
if all(item.suffix == suffix for item in input_list):

# do all the files exist? If not: Raise value error
all_exist = True
from six import StringIO
err_msg = StringIO()
for item in input_list:
if not os.path.isfile(item):
if not item.is_file():
err_msg.write('\n' if err_msg.tell() > 0 else "")
err_msg.write('File %s did not exist or was no file' % item)
all_exist = False
Expand Down
11 changes: 6 additions & 5 deletions pyemma/coordinates/data/util/traj_info_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import warnings
from io import BytesIO
from logging import getLogger
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -180,8 +181,10 @@ def _handle_csv(self, reader, filename, length):

def __getitem__(self, filename_reader_tuple):
filename, reader = filename_reader_tuple
if isinstance(filename, Path):
filename = str(filename)
abs_path = os.path.abspath(filename)
key = self._get_file_hash_v2(filename)
key = self.compute_file_hash(abs_path)
try:
info = self._database.get(key)
if not isinstance(info, TrajInfo):
Expand Down Expand Up @@ -226,15 +229,13 @@ def _get_file_hash(self, filename):
hash_value ^= hash(data)
return str(hash_value)

def _get_file_hash_v2(self, filename):
@staticmethod
def compute_file_hash(filename):
statinfo = os.stat(filename)
# now read the first megabyte and hash it
with open(filename, mode='rb') as fh:
data = fh.read(1024)

if sys.version_info > (3,):
long = int

hasher = hashlib.md5()
hasher.update(os.path.basename(filename).encode('utf-8'))
hasher.update(str(statinfo.st_mtime).encode('ascii'))
Expand Down
26 changes: 23 additions & 3 deletions pyemma/coordinates/tests/test_traj_info_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
@author: marscher
'''


import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile

import os
Expand All @@ -35,7 +35,7 @@
from pyemma.coordinates.data.py_csv_reader import PyCSVReader
from pyemma.coordinates.data.util.traj_info_backends import SqliteDB
from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache
from pyemma.coordinates.tests.util import create_traj
from pyemma.coordinates.tests.util import create_traj, get_top
from pyemma.datasets import get_bpti_test_data
from pyemma.util import config
from pyemma.util.contexts import settings
Expand Down Expand Up @@ -268,6 +268,26 @@ def test_max_n_entries(self):
self.assertLessEqual(self.db.num_entries, max_entries)
self.assertGreater(self.db.num_entries, 0)

    def test_cache_miss_same_filename(self):
        # Regression test for issue #1541: overwriting a file in place (same
        # absolute path, new content) must not serve stale cached trajectory
        # info / topology that belongs to the old content.
        tmpdir = None
        try:
            # stage copies of a pdb + xtc under fresh names in a temp dir
            fname_pdb = os.path.basename(pdbfile)
            fname_xtc = os.path.basename(xtcfiles[0])
            tmpdir = Path(tempfile.mkdtemp())
            shutil.copyfile(pdbfile, tmpdir / fname_pdb)
            shutil.copyfile(xtcfiles[0], tmpdir / fname_xtc)
            # first read populates the caches for these file names
            _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb)
            shutil.copyfile(get_top(), tmpdir / fname_pdb)  # overwrite pdb with a different topology

            # write a brand-new trajectory under the *same* xtc file name
            t = mdtraj.load(tmpdir / fname_pdb)
            t.xyz = np.zeros(shape=(400, 3, 3))
            t.time = np.arange(len(t.xyz))
            t.save(tmpdir / fname_xtc, force_overwrite=True)
            # second read must pick up the new content (cache miss), not the
            # entries cached for the old files; would raise before the fix
            _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb)
        finally:
            # tmpdir may still be None if mkdtemp itself failed
            shutil.rmtree(tmpdir, ignore_errors=True)

def test_max_size(self):
data = [np.random.random((150, 10)) for _ in range(150)]
max_size = 1
Expand Down
6 changes: 4 additions & 2 deletions pyemma/coordinates/util/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,19 @@
from mdtraj.utils import in_units_of
from mdtraj.utils.validation import cast_indices

from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache

TrajData = namedtuple("traj_data", ('xyz', 'unitcell_lengths', 'unitcell_angles', 'box'))


@lru_cache(maxsize=32)
def _load(top_file, file_hash):
    """Load and memoize a topology, keyed on ``(top_file, file_hash)``.

    Parameters
    ----------
    top_file : str
        Path of the topology file to parse with ``load_topology``.
    file_hash : str
        Content hash of the file (``TrajectoryInfoCache.compute_file_hash``).
        Deliberately unused in the body: it only participates in the LRU
        cache key, so overwriting a file in place under the same name
        invalidates the cached topology instead of serving a stale one.
        (Renamed from ``hash`` to avoid shadowing the builtin; the parameter
        is private and only called positionally.)

    Returns
    -------
    mdtraj.Topology
        The parsed topology, possibly served from the cache.
    """
    return load_topology(top_file)


def load_topology_cached(top_file):
if isinstance(top_file, str):
return _load(top_file)
return _load(top_file, TrajectoryInfoCache.compute_file_hash(top_file))
if isinstance(top_file, Topology):
return top_file
if isinstance(top_file, Trajectory):
Expand Down

0 comments on commit f6a7a7d

Please sign in to comment.