forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
scipy_signal_test.py
110 lines (94 loc) · 4.61 KB
/
scipy_signal_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest, parameterized
import numpy as np
from jax import lax
from jax import test_util as jtu
import jax.scipy.signal as jsp_signal
import scipy.signal as osp_signal
from jax.config import config
config.parse_flags_with_absl()

# Shape sets for the N-d convolution tests. In a single test case both operands
# are drawn from the same set, so they always have equal rank.
onedim_shapes = [(n,) for n in (1, 2, 5, 10)]
twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]
threedim_shapes = [(n, n, 2) for n in (2, 3, 4, 5)]

# Union of jtu's floating, integer and complex dtype lists.
default_dtypes = (jtu.dtypes.floating
                  + jtu.dtypes.integer
                  + jtu.dtypes.complex)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.signal implementations.

  Each test compares a ``jax.scipy.signal`` function against its
  ``scipy.signal`` counterpart via ``_CheckAgainstNumpy``. The JAX function
  is also run through ``_CompileAndCheck``.
  """

  # Per-dtype tolerances shared by the convolution tests. Half/single
  # precision real and complex types get 1e-2; double precision gets 1e-12.
  _CONV_TOL = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
               np.complex64: 1e-2, np.complex128: 1e-12}

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve', 'correlate']
      for dtype in default_dtypes
      for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
      for xshape in shapeset
      for yshape in shapeset))
  def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    """Compare N-d ``convolve``/``correlate`` with SciPy on equal-rank inputs."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    # Request the highest lax precision; presumably so accelerator results
    # stay close to SciPy's reference values.
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = self._CONV_TOL
    # Only values are compared (check_dtypes=False); presumably SciPy and JAX
    # may promote some input dtypes differently.
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      # Leading "_" keeps generated names consistent with the other tests.
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve2d', 'correlate2d']
      for dtype in default_dtypes
      for xshape in twodim_shapes
      for yshape in twodim_shapes))
  def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    """Compare ``convolve2d``/``correlate2d`` with SciPy on 2-d inputs."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = self._CONV_TOL
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, detrend_type, bp),
       "shape": shape, "dtype": dtype, "axis": axis,
       "detrend_type": detrend_type, "bp": bp}
      for shape in [(5,), (4, 5), (3, 4, 5)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.integer
      for axis in [0, -1]
      for detrend_type in ['constant', 'linear']
      for bp in [0, [0, 2]]))
  def testDetrend(self, shape, dtype, axis, detrend_type, bp):
    """Compare ``detrend`` with SciPy across axes, trend types and breakpoints.

    ``detrend_type`` is forwarded as the ``type`` keyword. It is named
    differently here only to avoid shadowing the ``type`` builtin.
    """
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    osp_fun = partial(osp_signal.detrend, axis=axis, type=detrend_type, bp=bp)
    jsp_fun = partial(jsp_signal.detrend, axis=axis, type=detrend_type, bp=bp)
    tol = {np.float32: 1e-5, np.float64: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
if __name__ == "__main__":
  # Hand JAX's custom test loader to absltest's runner.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)