
Commit

Support empty features
rossant committed Aug 4, 2020
1 parent 947b486 · commit 8c25e22
Showing 2 changed files with 31 additions and 5 deletions.
phy/apps/base.py (16 changes: 15 additions & 1 deletion)
@@ -329,7 +329,16 @@ def _get_feature_view_spike_ids(self, cluster_id=None, load_all=False):
         """Return some or all spikes belonging to a given cluster."""
         if cluster_id is None:
             spike_ids = self.get_background_spike_ids(self.n_spikes_features_background)
-        else:
+        # Compute features on the fly from spike waveforms.
+        elif self.model.features is None and self.model.spike_waveforms is not None:
+            spike_ids = self.get_spike_ids(cluster_id)
+            assert len(spike_ids)
+            spike_ids = np.intersect1d(spike_ids, self.model.spike_waveforms.spike_ids)
+            if len(spike_ids) == 0:
+                logger.debug("empty spikes for cluster %s", str(cluster_id))
+                return spike_ids
+        # Retrieve features from the self.model.features array.
+        elif self.model.features is not None:
             # Load all spikes from the cluster if load_all is True.
             n = self.n_spikes_features if not load_all else None
             spike_ids = self.get_spike_ids(cluster_id, n=n)
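
Note on the hunk above: when features are not stored (model.features is None) they are computed on the fly from spike waveforms, so only spikes whose waveforms were actually extracted can be used, hence the np.intersect1d() call, whose result can legitimately be empty. A minimal, self-contained sketch of that empty case, using made-up spike ids rather than phy's real data:

    import numpy as np

    # Hypothetical ids: spikes assigned to one cluster, and spikes for which
    # waveforms were extracted (the only ones usable for on-the-fly features).
    cluster_spike_ids = np.array([10, 42, 97])
    waveform_spike_ids = np.array([3, 7, 55])  # no overlap with this cluster

    # Same call as in the diff above: the intersection may come back empty.
    spike_ids = np.intersect1d(cluster_spike_ids, waveform_spike_ids)
    print(len(spike_ids))  # 0 -- downstream code must handle an empty selection
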
@@ -341,13 +350,16 @@ def _get_feature_view_spike_times(self, cluster_id=None, load_all=False):
     def _get_feature_view_spike_times(self, cluster_id=None, load_all=False):
         """Return the times of some or all spikes belonging to a given cluster."""
         spike_ids = self._get_feature_view_spike_ids(cluster_id, load_all=load_all)
+        if len(spike_ids) == 0:
+            return
         spike_times = self._get_spike_times_reordered(spike_ids)
         return Bunch(
             data=spike_times,
             spike_ids=spike_ids,
             lim=(0., self.model.duration))
 
     def _get_spike_features(self, spike_ids, channel_ids):
+        assert len(spike_ids)
         data = self.model.get_features(spike_ids, channel_ids)
         assert data.shape[:2] == (len(spike_ids), len(channel_ids))
         # Replace NaN values by zeros.
@@ -361,6 +373,8 @@ def _get_spike_features(self, spike_ids, channel_ids):
     def _get_features(self, cluster_id=None, channel_ids=None, load_all=False):
         """Return the features of a given cluster on specified channels."""
         spike_ids = self._get_feature_view_spike_ids(cluster_id, load_all=load_all)
+        if len(spike_ids) == 0:
+            return Bunch()
         # Use the best channels only if a cluster is specified and
         # channels are not specified.
         if cluster_id is not None and channel_ids is None:
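
Across both files the guiding convention is to return an empty container rather than None, so that callers can simply test truthiness. Bunch is phy's dict-like container; assuming it behaves like a plain dict subclass (an empty one is falsy), _get_features() returning Bunch() is exactly what makes guards such as "if not bunch:" in feature.py work. A small sketch with a stand-in class, not phy's own definition:

    class Bunch(dict):
        """Stand-in for phy's Bunch: a dict whose keys also read as attributes."""
        def __getattr__(self, key):
            return self[key]

    def get_features(has_data):
        # Mirrors the new behaviour of _get_features(): empty Bunch, not None.
        return Bunch(data=[1.0, 2.0]) if has_data else Bunch()

    bunch = get_features(False)
    if not bunch:  # an empty dict subclass is falsy
        print('nothing to plot')  # same guard as _plot_points() below
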
phy/cluster/views/feature.py (20 changes: 16 additions & 4 deletions)
@@ -218,6 +218,8 @@ def _get_axis_bounds(self, dim, bunch):
         return (-self._lim, +self._lim)
 
     def _plot_points(self, bunch, clu_idx=None):
+        if not bunch:
+            return
         cluster_id = self.cluster_ids[clu_idx] if clu_idx is not None else None
         for i, j, dim_x, dim_y in self._iter_subplots():
             px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
@@ -278,6 +280,8 @@ def _plot_axes(self):
         self.canvas.update_visual(self.line_visual)
 
     def _get_lim(self, bunchs):
+        if not bunchs:  # pragma: no cover
+            return 1
         m, M = min(bunch.data.min() for bunch in bunchs), max(bunch.data.max() for bunch in bunchs)
         M = max(abs(m), abs(M))
         return M
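
For reference, _get_lim() reduces to a symmetric bound: the largest absolute feature value across all clusters, so every subplot can use the range (-M, +M), with a fallback of 1 when there is nothing to plot. A self-contained sketch of that computation, using plain NumPy arrays instead of Bunch objects:

    import numpy as np

    def get_lim(arrays):
        # Symmetric axis bound from the global extrema, as in FeatureView._get_lim().
        if not arrays:
            return 1  # fallback used when every cluster came back empty
        m = min(a.min() for a in arrays)
        M = max(a.max() for a in arrays)
        return max(abs(m), abs(M))

    print(get_lim([np.array([-3.0, 2.0]), np.array([0.5, 1.5])]))  # 3.0
    print(get_lim([]))                                             # 1
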
@@ -303,15 +307,18 @@ def get_clusters_data(self, fixed_channels=None, load_all=None):
         # choose the first cluster's best channels.
         c = self.channel_ids if fixed_channels else None
         bunchs = [self.features(cluster_id, channel_ids=c) for cluster_id in self.cluster_ids]
+        bunchs = [b for b in bunchs if b]
+        if not bunchs:  # pragma: no cover
+            return []
         for cluster_id, bunch in zip(self.cluster_ids, bunchs):
             bunch.cluster_id = cluster_id
 
         # Choose the channels based on the first selected cluster.
-        channel_ids = list(bunchs[0].channel_ids) if bunchs else []
+        channel_ids = list(bunchs[0].get('channel_ids', [])) if bunchs else []
         common_channels = list(channel_ids)
         # Intersection (with order kept) of channels belonging to all clusters.
         for bunch in bunchs:
-            common_channels = [c for c in bunch.channel_ids if c in common_channels]
+            common_channels = [c for c in bunch.get('channel_ids', []) if c in common_channels]
         # The selected channels will be (1) the channels common to all clusters, followed
         # by (2) remaining channels from the first cluster (excluding those already selected
         # in (1)).
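
The channel bookkeeping above follows the comment in the hunk: keep the channels common to all clusters (taking an order-preserving intersection), then append the first cluster's remaining channels. A sketch of that selection with hypothetical channel-id lists (the concatenation step itself lives outside this hunk):

    def select_channels(per_cluster_channels):
        # One list of channel ids per selected cluster; empty input allowed.
        first = per_cluster_channels[0] if per_cluster_channels else []
        common = list(first)
        for channels in per_cluster_channels:
            # Order-preserving intersection, as in get_clusters_data().
            common = [c for c in channels if c in common]
        # (1) common channels, then (2) the first cluster's remaining channels.
        return common + [c for c in first if c not in common]

    print(select_channels([[3, 1, 7, 2], [1, 2, 9], [2, 1, 5]]))  # [2, 1, 3, 7]
    print(select_channels([]))                                    # []
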
@@ -328,9 +335,9 @@ def get_clusters_data(self, fixed_channels=None, load_all=None):
         # Channel labels.
         self.channel_labels = {}
         for d in bunchs:
-            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
+            chl = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
             self.channel_labels.update({
-                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)})
+                channel_id: chl[i] for i, channel_id in enumerate(d.get('channel_ids', []))})
 
         return bunchs
 
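The switch from attribute access (d.channel_ids) to d.get('channel_ids', []) in this hunk and the previous one makes the channel-label bookkeeping tolerate bunches that lack a channel_ids entry instead of raising. A tiny illustration with plain dicts standing in for Bunch objects:

    # A populated bunch and a bare one, using plain dicts as stand-ins.
    full = {'channel_ids': [4, 7, 11], 'channel_labels': ['4', '7', '11']}
    bare = {}

    for d in (full, bare):
        # d.channel_ids (or d['channel_ids']) would fail on the bare bunch;
        # .get() with a default degrades gracefully to an empty label list.
        labels = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
        print(labels)  # ['4', '7', '11'] then []
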
@@ -346,6 +353,9 @@ def plot(self, **kwargs):
 
         # Get the clusters data.
         bunchs = self.get_clusters_data(fixed_channels=fixed_channels)
+        bunchs = [b for b in bunchs if b]
+        if not bunchs:
+            return
         self._lim = self._get_lim(bunchs)
 
         # Get the background data.
@@ -473,6 +483,8 @@ def on_request_split(self, sender=None):
         for cluster_id in self.cluster_ids:
             # Load all spikes.
             bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
+            if not bunch:
+                continue
             px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
             py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
             points = np.c_[px.data, py.data]
