Skip to content

Commit

Permalink
Merge pull request scipy#4473 from WarrenWeckesser/sosfilt-zi-shape
Browse files Browse the repository at this point in the history
BUG: signal: Fix validation of the zi shape in sosfilt.
  • Loading branch information
rgommers committed Jan 28, 2015
2 parents 8e3a683 + 08515aa commit 0f50512
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
35 changes: 18 additions & 17 deletions scipy/signal/signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,11 +2386,12 @@ def sosfilt(sos, x, axis=-1, zi=None):
linear filter. The filter is applied to each subarray along
this axis. Default is -1.
zi : array_like, optional
Initial conditions for the cascaded filter delays. It is a (at least
2D) vector of shape ``(n_sections, ..., 2)``, with middle dimensions
equal to those of the input shape (without the filtered axis).
If `zi` is None or is not given then initial rest is assumed. Note
that these initial conditions are *not* the same as the initial
Initial conditions for the cascaded filter delays. It is a (at
least 2D) vector of shape ``(n_sections, ..., 2, ...)``, where
``..., 2, ...`` denotes the shape of `x`, but with ``x.shape[axis]``
replaced by 2. If `zi` is None or is not given then initial rest
(i.e. all zeros) is assumed.
Note that these initial conditions are *not* the same as the initial
conditions given by `lfiltic` or `lfilter_zi`.
Returns
Expand Down Expand Up @@ -2434,6 +2435,7 @@ def sosfilt(sos, x, axis=-1, zi=None):
>>> plt.show()
"""
x = np.asarray(x)

sos = atleast_2d(sos)
if sos.ndim != 2:
Expand All @@ -2443,19 +2445,18 @@ def sosfilt(sos, x, axis=-1, zi=None):
if m != 6:
raise ValueError('sos array must be shape (n_sections, 6)')

if zi is not None:
use_zi = True
zi = np.array(zi)
x_zi_shape = np.delete(np.array(x.shape), axis)
proper_shape = (zi.ndim >= 2 and
zi.shape[0] == n_sections and zi.shape[-1] == 2 and
np.array_equal(zi.shape[1:-1], x_zi_shape))
if not proper_shape:
raise ValueError('sos initial states must be shape '
'(n_sections, ..., 2)')
use_zi = zi is not None
if use_zi:
zi = np.asarray(zi)
x_zi_shape = list(x.shape)
x_zi_shape[axis] = 2
x_zi_shape = tuple([n_sections] + x_zi_shape)
if zi.shape != x_zi_shape:
raise ValueError('Invalid zi shape. With axis=%r, an input with '
'shape %r, and an sos array with %d sections, zi '
'must have shape %r.' %
(axis, x.shape, n_sections, x_zi_shape))
zf = zeros_like(zi)
else:
use_zi = False

for section in range(n_sections):
if use_zi:
Expand Down
40 changes: 40 additions & 0 deletions scipy/signal/tests/test_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,46 @@ def test_initial_conditions(self):
assert_allclose(y[0, 0], np.ones(8))
assert_allclose(zf[:, 0, 0, :], zi)

def test_initial_conditions_3d_axis1(self):
# Test the use of zi when sosfilt is applied to axis 1 of a 3-d input.

# Input array is x.
np.random.seed(159)
x = np.random.randint(0, 5, size=(2, 15, 3))

# Design a filter in SOS format.
sos = signal.butter(6, 0.35, output='sos')
nsections = sos.shape[0]

# Filter along this axis.
axis = 1

# Initial conditions, all zeros.
shp = list(x.shape)
shp[axis] = 2
shp = [nsections] + shp
z0 = np.zeros(shp)

# Apply the filter to x.
yf, zf = sosfilt(sos, x, axis=axis, zi=z0)

# Apply the filter to x in two stages.
y1, z1 = sosfilt(sos, x[:, :5, :], axis=axis, zi=z0)
y2, z2 = sosfilt(sos, x[:, 5:, :], axis=axis, zi=z1)

# y should equal yf, and z2 should equal zf.
y = np.concatenate((y1, y2), axis=axis)
assert_allclose(y, yf, rtol=1e-10, atol=1e-13)
assert_allclose(z2, zf, rtol=1e-10, atol=1e-13)

def test_bad_zi_shape(self):
# The shape of zi is checked before using any values in the
# arguments, so np.empty is fine for creating the arguments.
x = np.empty((3, 15, 3))
sos = np.empty((4, 6))
zi = np.empty((4, 3, 3, 2)) # Correct shape is (4, 3, 2, 3)
assert_raises(ValueError, sosfilt, sos, x, zi=zi, axis=1)

def test_sosfilt_zi(self):
sos = signal.butter(6, 0.2, output='sos')
zi = sosfilt_zi(sos)
Expand Down

0 comments on commit 0f50512

Please sign in to comment.