Skip to content

Commit

Permalink
MAINT: cluster: fixes to _vq
Browse files Browse the repository at this point in the history
* intc not int32 for C int
* _vq would not propagate exceptions due to void
* don't rely on C99 function sqrtf
  • Loading branch information
larsmans committed Oct 7, 2015
1 parent b31514e commit 7898c14
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions scipy/cluster/_vq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ Translated to Cython by David Warde-Farley, October 2009.
cimport cython
import numpy as np
cimport numpy as np
from cluster_blas cimport *
from cluster_blas cimport f_dgemm, f_sgemm

cdef extern from "math.h":
float sqrtf(float num)
double sqrt(double num)
from libc.math cimport sqrt

ctypedef np.float64_t float64_t
ctypedef np.float32_t float32_t
Expand All @@ -34,13 +32,6 @@ DEF NFEATURES_CUTOFF=5
np.import_array()


cdef inline vq_type _sqrt(vq_type x):
if vq_type is float32_t:
return sqrtf(x)
else:
return sqrt(x)


cdef inline vq_type vec_sqr(int n, vq_type *p):
cdef vq_type result = 0.0
cdef int i
Expand All @@ -66,9 +57,9 @@ cdef inline void cal_M(int nobs, int ncodes, int nfeat, vq_type *obs,
&alpha, code_book, &nfeat, obs, &nfeat, &beta, M, &ncodes)


cdef void _vq(vq_type *obs, vq_type *code_book,
int ncodes, int nfeat, int nobs,
int32_t *codes, vq_type *low_dist):
cdef int _vq(vq_type *obs, vq_type *code_book,
int ncodes, int nfeat, int nobs,
int32_t *codes, vq_type *low_dist) except -1:
"""
The underlying function (template) of _vq.vq.
Expand All @@ -93,7 +84,7 @@ cdef void _vq(vq_type *obs, vq_type *code_book,
# Naive algorithm is prefered when nfeat is small
if nfeat < NFEATURES_CUTOFF:
_vq_small_nf(obs, code_book, ncodes, nfeat, nobs, codes, low_dist)
return
return 0

cdef np.npy_intp i, j
cdef vq_type *p_obs
Expand Down Expand Up @@ -137,10 +128,12 @@ cdef void _vq(vq_type *obs, vq_type *code_book,

# dist_sqr may be negative due to float point errors
if low_dist[i] > 0:
low_dist[i] = _sqrt(low_dist[i])
low_dist[i] = sqrt(low_dist[i])
else:
low_dist[i] = 0

return 0


cdef void _vq_small_nf(vq_type *obs, vq_type *code_book,
int ncodes, int nfeat, int nobs,
Expand Down Expand Up @@ -178,7 +171,7 @@ cdef void _vq_small_nf(vq_type *obs, vq_type *code_book,
codes[i] = j
low_dist[i] = dist_sqr

low_dist[i] = _sqrt(low_dist[i])
low_dist[i] = sqrt(low_dist[i])

# Update the offset of the current observation
obs_offset += nfeat
Expand Down Expand Up @@ -280,7 +273,7 @@ cdef np.ndarray _update_cluster_means(vq_type *obs, int32_t *labels,
cdef np.ndarray[int, ndim=1] obs_count

# Calculate the sums the numbers of obs in each cluster
obs_count = np.zeros(nc, np.int32)
obs_count = np.zeros(nc, np.intc)
obs_p = obs
for i in range(nobs):
label = labels[i]
Expand Down Expand Up @@ -371,4 +364,3 @@ def update_cluster_means(np.ndarray obs, np.ndarray labels, int nc):
obs.shape[0], nc, nfeat)

return cb, has_members

0 comments on commit 7898c14

Please sign in to comment.