forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_util_test.py
378 lines (314 loc) · 13 KB
/
tree_util_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util
import jax.numpy as jnp
def _dummy_func(*args, **kwargs):
return
# Two-field namedtuple used as a pytree container in the test trees.
ATuple = collections.namedtuple("ATuple", ["foo", "bar"])


class ANamedTupleSubclass(ATuple):
  """Subclass of `ATuple` that adds no fields or methods of its own."""
class AnObject:
  """Plain three-attribute container used as a custom pytree node in tests.

  Instances compare equal when their `x`, `y` and `z` attributes compare
  equal, and hash as the tuple `(x, y, z)`.
  """

  def __init__(self, x, y, z):
    self.x = x
    self.y = y
    self.z = z

  def __eq__(self, other):
    # Comparing against an unrelated type used to raise AttributeError on the
    # missing x/y/z attributes; returning NotImplemented lets Python fall back
    # to its default handling so `==` yields False and `!=` yields True.
    if not isinstance(other, AnObject):
      return NotImplemented
    return self.x == other.x and self.y == other.y and self.z == other.z

  def __hash__(self):
    return hash((self.x, self.y, self.z))

  def __repr__(self):
    return f"AnObject({self.x},{self.y},{self.z})"
def _flatten_an_object(obj):
  """Split an `AnObject` into children `(x, y)` and auxiliary data `z`."""
  children = (obj.x, obj.y)
  return children, obj.z


def _unflatten_an_object(aux, children):
  """Rebuild an `AnObject` from auxiliary data `z` and children `(x, y)`."""
  return AnObject(children[0], children[1], aux)


# Register AnObject as a pytree node: x and y are children, z is static data.
tree_util.register_pytree_node(AnObject, _flatten_an_object,
                               _unflatten_an_object)
@tree_util.register_pytree_node_class
class Special:
  """Two-attribute pytree node registered through the class decorator.

  Both `x` and `y` are children; no auxiliary data is carried.
  """

  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return f"Special(x={self.x}, y={self.y})"

  def __eq__(self, other):
    # Only instances of exactly the same class can be equal.
    if type(self) is not type(other):
      return False
    return (self.x, self.y) == (other.x, other.y)

  def tree_flatten(self):
    """Return `(children, aux_data)` with children `(x, y)` and no aux data."""
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Build an instance from the flattened children; `aux_data` is unused."""
    return cls(*children)
@tree_util.register_pytree_node_class
class FlatCache:
  """Pytree node that keeps a structure in flattened (leaves, treedef) form.

  The structured value is rebuilt on demand through the `structured`
  property. Equality, hashing and `repr` are all defined in terms of that
  structured value.

  Args:
    structured: the value to wrap. May be None when `leaves` and `treedef`
      are supplied instead.
    leaves: optional pre-flattened leaves; only used when `treedef` is given.
    treedef: optional tree definition matching `leaves`. When None, both
      `leaves` and `treedef` are computed by flattening `structured`.
  """

  def __init__(self, structured, *, leaves=None, treedef=None):
    if treedef is None:
      leaves, treedef = tree_util.tree_flatten(structured)
    self._structured = structured
    self.treedef = treedef
    self.leaves = leaves

  def __hash__(self):
    # Raises TypeError when the structured value itself is unhashable.
    return hash(self.structured)

  def __eq__(self, other):
    # Comparing against an unrelated type used to raise AttributeError on the
    # missing `structured` attribute; defer to Python's default handling.
    if not isinstance(other, FlatCache):
      return NotImplemented
    return self.structured == other.structured

  def __repr__(self):
    return f"FlatCache({self.structured!r})"

  @property
  def structured(self):
    """The wrapped value, rebuilt from leaves and treedef if not cached."""
    if self._structured is None:
      self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
    return self._structured

  def tree_flatten(self):
    """Return the stored leaves as children and the treedef as aux data."""
    return self.leaves, self.treedef

  @classmethod
  def tree_unflatten(cls, meta, data):
    """Build an instance from a treedef (`meta`) and children (`data`)."""
    if not tree_util.all_leaves(data):
      # Some children are themselves containers: rebuild the full structure
      # and flatten it again so leaves and treedef agree with each other.
      data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
    # Use `cls` rather than a hard-coded class so subclasses round-trip as
    # themselves.
    return cls(None, leaves=data, treedef=meta)
# Key/value pairs shared by the ordered-dict and default-dict test trees.
_DICT_ITEMS = [("foo", 34), ("baz", 101), ("something", -42)]

# Pytrees exercised by the parameterized tests. Every entry is a one-element
# argument tuple so the collection can be splatted into
# `parameterized.parameters`. Order matters: the leading entries are paired
# positionally with TREE_STRINGS.
TREES = (
    (None,),
    ((None,),),
    ((),),
    ([()],),
    ((1, 2),),
    (((1, "foo"), ["bar", (3, None, 7)]),),
    ([3],),
    ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
    ([AnObject(3, None, [4, "foo"])],),
    (Special(2, 3.0),),
    ({"a": 1, "b": 2},),
    (collections.OrderedDict(_DICT_ITEMS),),
    (collections.defaultdict(dict, _DICT_ITEMS),),
    (ANamedTupleSubclass(foo="hello", bar=3.5),),
    (FlatCache(None),),
    (FlatCache(1),),
    (FlatCache({"a": [1, 2]}),),
)
# Expected `str(treedef)` for the leading entries of TREES, in the same order.
TREE_STRINGS = (
    "PyTreeDef(None)",
    "PyTreeDef((None,))",
    "PyTreeDef(())",
    "PyTreeDef([()])",
    "PyTreeDef((*, *))",
    "PyTreeDef(((*, *), [*, (*, None, *)]))",
    "PyTreeDef([*])",
    "PyTreeDef([*, CustomNode(namedtuple[<class '__main__.ATuple'>], [(*, "
    "CustomNode(namedtuple[<class '__main__.ATuple'>], [*, None])), {'baz': "
    "*}])])",
    "PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])",
    "PyTreeDef(CustomNode(<class '__main__.Special'>[None], [*, *]))",
    "PyTreeDef({'a': *, 'b': *})",
)

# The module name inside the class reprs depends on how the tests are run
# (pytest expects "tree_util_test.ATuple"), so each expected string is turned
# into a regex: everything is escaped literally and the "__main__" module name
# is replaced by a wildcard.
STRS = [re.escape(raw).replace("__main__", ".*") for raw in TREE_STRINGS]
TREE_STRINGS = STRS
# Non-container values, each wrapped in a one-element argument tuple so the
# collection can be splatted into `parameterized.parameters`.
LEAVES = tuple((leaf,) for leaf in ("foo", 0.1, 1, object()))
class TreeTest(jtu.JaxTestCase):
  """Tests for flattening, unflattening and transforming pytrees."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    """Flattening and then unflattening reproduces the input."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    """`flatten_up_to` with a tree's own treedef round-trips the tree."""
    _, treedef = tree_util.tree_flatten(inputs)
    leaves = treedef.flatten_up_to(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    """`Partial` objects keep func, args and keywords across a round trip."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    for attr in ("func", "args", "keywords"):
      self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    """`_process_pytree` followed by `build_tree` reproduces the input."""
    leaves, treedef = _process_pytree(tuple, inputs)
    rebuilt = tree_util.build_tree(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  def testChildren(self):
    """`children()` returns the treedefs of the direct subtrees."""
    _, parent = tree_util.tree_flatten(((1, 2, 3), (4,)))
    _, first_child = tree_util.tree_flatten((0, 0, 0))
    _, second_child = tree_util.tree_flatten((7,))
    self.assertEqual([first_child, second_child], parent.children())

  def testTreedefTupleFromChildren(self):
    """`treedef_tuple` of a treedef's children matches the original counts."""
    # https://github.com/google/jax/issues/7377
    # TODO(frostig): remove after the minimum jaxlib is 0.1.70 or newer.
    if jax.lib._xla_extension_version < 29:
      self.skipTest("fixed in future jaxlib")
    nested = ((1, 2, (3, 4)), (5,))
    leaves, original_def = tree_util.tree_flatten(nested)
    rebuilt_def = tree_util.treedef_tuple(original_def.children())
    self.assertEqual(original_def.num_leaves, len(leaves))
    self.assertEqual(original_def.num_leaves, rebuilt_def.num_leaves)
    self.assertEqual(original_def.num_nodes, rebuilt_def.num_nodes)

  def testFlattenUpTo(self):
    """`flatten_up_to` stops at the leaf positions of the reference treedef."""
    _, treedef = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    deeper = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
    flattened = treedef.flatten_up_to(deeper)
    self.assertEqual(flattened, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    """`tree_multimap` pairs each leaf of the first tree with a subtree."""
    first = ((1, 2), (3, 4, 5))
    second = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(lambda *xs: tuple(xs), first, second)
    expected = (
        ((1, [3]), (2, None)),
        ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])),
    )
    self.assertEqual(zipped, expected)

  def testTreeMultimapWithIsLeafArgument(self):
    """`is_leaf` makes `tree_multimap` treat lists in the first tree as leaves."""
    first = ((1, 2), [3, 4, 5])
    second = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(
        lambda *xs: tuple(xs), first, second,
        is_leaf=lambda n: isinstance(n, list))
    expected = (
        ((1, [3]), (2, None)),
        ([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])),
    )
    self.assertEqual(zipped, expected)

  def testFlattenIsLeaf(self):
    """The `is_leaf` predicate controls where `tree_flatten` stops."""
    pairs = [(1, 2), (3, 4), (5, 6)]

    found, _ = tree_util.tree_flatten(pairs, is_leaf=lambda t: False)
    self.assertEqual(found, [1, 2, 3, 4, 5, 6])

    found, _ = tree_util.tree_flatten(
        pairs, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(found, pairs)

    found, _ = tree_util.tree_flatten(
        pairs, is_leaf=lambda t: isinstance(t, list))
    self.assertEqual(found, [pairs])

    found, _ = tree_util.tree_flatten(pairs, is_leaf=lambda t: True)
    self.assertEqual(found, [pairs])

    deep = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    found, _ = tree_util.tree_flatten(
        deep, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(found, [(1,), (2,), (3,)])

  @parameterized.parameters(*TREES)
  def testRoundtripIsLeaf(self, tree):
    """Round trip still works when tuples are treated as leaves."""
    leaves, treedef = tree_util.tree_flatten(
        tree, is_leaf=lambda t: isinstance(t, tuple))
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, tree)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    """Flattened leaves are all leaves; a list wrapping the tree is not."""
    flat = tree_util.tree_leaves(tree)
    self.assertTrue(tree_util.all_leaves(flat))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    """A list holding a single leaf value consists only of leaves."""
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    """Composing treedefs multiplies leaf counts and nests node counts."""
    outer = tree_util.tree_structure(tree)
    inner = tree_util.tree_structure(["*", "*", "*"])
    composed_def = outer.compose(inner)

    num_leaves = outer.num_leaves * inner.num_leaves
    self.assertEqual(composed_def.num_leaves, num_leaves)

    # Expected count: the outer tree's non-leaf nodes, plus one full copy of
    # the inner tree for every outer leaf.
    interior_nodes = outer.num_nodes - outer.num_leaves
    num_nodes = interior_nodes + (inner.num_nodes * outer.num_leaves)
    self.assertEqual(composed_def.num_nodes, num_nodes)

    ones = [1] * num_leaves
    composed = tree_util.tree_unflatten(composed_def, ones)
    self.assertEqual(ones, tree_util.tree_leaves(composed))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    """Transposing a tree of 3-lists yields a 3-list of trees."""
    outer = tree_util.tree_structure(tree)
    if outer.num_leaves == 0:
      self.skipTest("Skipping empty tree")
    inner = tree_util.tree_structure([1, 1, 1])
    tripled = tree_util.tree_map(lambda x: [x, x, x], tree)
    transposed = tree_util.tree_transpose(outer, inner, tripled)
    self.assertEqual(transposed, [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    """A wrong outer treedef makes `tree_transpose` raise TypeError."""
    data = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, data)

  def testTransposeMismatchInner(self):
    """A wrong inner treedef makes `tree_transpose` raise TypeError."""
    data = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2})
    inner = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, data)

  def testTransposeWithCustomObject(self):
    """`tree_transpose` works when the outer tree is a `FlatCache`."""
    outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner = tree_util.tree_structure([1, 2])
    nested = FlatCache({"a": [3, 4], "b": [5, 6]})
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    actual = tree_util.tree_transpose(outer, inner, nested)
    self.assertEqual(expected, actual)

  @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
  def testStringRepresentation(self, tree, correct_string):
    """Checks that the string representation of a tree works."""
    structure = tree_util.tree_structure(tree)
    self.assertRegex(str(structure), correct_string)
class RavelUtilTest(jtu.JaxTestCase):
  """Tests for `flatten_util.ravel_pytree` over several leaf dtypes."""

  def _check_ravel_roundtrip(self, tree, expected_dtype):
    """Ravel `tree`, check the flat dtype, then check exact unraveling."""
    flat, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(flat.dtype, expected_dtype)
    restored = unravel(flat)
    # Zero tolerances: the round trip must reproduce the values exactly.
    self.assertAllClose(tree, restored, atol=0., rtol=0.)

  def testFloats(self):
    leaves = [jnp.array([3.], jnp.float32),
              jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    self._check_ravel_roundtrip(leaves, jnp.float32)

  def testInts(self):
    leaves = [jnp.array([3], jnp.int32),
              jnp.array([[1, 2], [3, 4]], jnp.int32)]
    self._check_ravel_roundtrip(leaves, jnp.int32)

  def testMixedFloatInt(self):
    leaves = [jnp.array([3], jnp.int32),
              jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    self._check_ravel_roundtrip(
        leaves, jnp.promote_types(jnp.float32, jnp.int32))

  def testMixedIntBool(self):
    leaves = [jnp.array([0], jnp.bool_),
              jnp.array([[1, 2], [3, 4]], jnp.int32)]
    self._check_ravel_roundtrip(
        leaves, jnp.promote_types(jnp.bool_, jnp.int32))

  def testMixedFloatComplex(self):
    leaves = [jnp.array([1.], jnp.float32),
              jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
    self._check_ravel_roundtrip(
        leaves, jnp.promote_types(jnp.float32, jnp.complex64))

  def testEmpty(self):
    # An empty tree ravels to float32 by convention.
    self._check_ravel_roundtrip([], jnp.float32)
if __name__ == "__main__":
  # Run every test in this module with JAX's test loader when the file is
  # executed as a script.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)