Skip to content

Commit

Permalink
Fix dataclass pytree registration bug (#34)
Browse files Browse the repository at this point in the history
* Fix dataclass pytree registration

* Simplify
  • Loading branch information
StannisZhou authored Jul 30, 2021
1 parent c379dbc commit 40efc27
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 40 deletions.
23 changes: 10 additions & 13 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,30 +117,29 @@ def run_bp(
msgs = init_msgs
else:
msgs = self.get_init_msgs(msgs_context)

wiring = jax.device_put(self.wiring)
evidence = self.get_evidence(evidence_data, evidence_context)
edges_num_states = jax.device_put(self.wiring.edges_num_states)
var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges)
factor_configs_edge_states = jax.device_put(
self.wiring.factor_configs_edge_states
)
max_msg_size = int(jnp.max(edges_num_states))
max_msg_size = int(jnp.max(wiring.edges_num_states))

# Normalize the messages to ensure the maximum value is 0.
msgs = infer.normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size)
num_val_configs = int(factor_configs_edge_states[-1, 0])
msgs = infer.normalize_and_clip_msgs(
msgs, wiring.edges_num_states, max_msg_size
)
num_val_configs = int(wiring.factor_configs_edge_states[-1, 0])

@jax.jit
def message_passing_step(msgs, _):
# Compute new variable to factor messages by message passing
vtof_msgs = infer.pass_var_to_fac_messages(
msgs,
evidence,
var_states_for_edges,
wiring.var_states_for_edges,
)
# Compute new factor to variable messages by message passing
ftov_msgs = infer.pass_fac_to_var_messages(
vtof_msgs,
factor_configs_edge_states,
wiring.factor_configs_edge_states,
num_val_configs,
)
# Use the results of message passing to perform damping and
Expand All @@ -151,7 +150,7 @@ def message_passing_step(msgs, _):
# them.
msgs = infer.normalize_and_clip_msgs(
msgs,
edges_num_states,
wiring.edges_num_states,
max_msg_size,
)
return msgs, None
Expand Down Expand Up @@ -180,8 +179,6 @@ def decode_map_states(
# NOTE: Having to regenerate the evidence here is annoying - there must be a better way to handle evidence and
# message initialization.
evidence = self.get_evidence(evidence_data, evidence_context)
# TODO: Once issue #20 is resolved, just grab this from self.wiring instead of
# casting it to a jnp.array
var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges)
final_var_states = evidence.at[var_states_for_edges].add(msgs)
var_to_map_dict = {}
Expand Down
15 changes: 12 additions & 3 deletions pgmax/fg/nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""A module containing classes that specify the components of a Factor Graph."""

from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Mapping, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -18,7 +19,7 @@ class Variable:
num_states: int


@utils.register_pytree_node_dataclass
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True, eq=False)
class EnumerationWiring:
"""Wiring for enumeration factors.
Expand All @@ -40,7 +41,15 @@ class EnumerationWiring:

def __post_init__(self):
for field in self.__dataclass_fields__:
getattr(self, field).flags.writeable = False
if isinstance(getattr(self, field), np.ndarray):
getattr(self, field).flags.writeable = False

def tree_flatten(self):
return jax.tree_util.tree_flatten(asdict(self))

@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**aux_data.unflatten(children))


@dataclass(frozen=True, eq=False)
Expand Down
25 changes: 1 addition & 24 deletions pgmax/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""A module containing helper functions useful while constructing Factor Graphs."""

import dataclasses
import functools
from typing import Any, Callable

import jax
from typing import Callable


def cached_property(func: Callable) -> property:
Expand All @@ -17,23 +14,3 @@ def cached_property(func: Callable) -> property:
Decorated cached property
"""
return property(functools.lru_cache(None)(func))


def register_pytree_node_dataclass(cls: Any) -> Any:
"""Decorator to register a dataclass as a pytree
Args:
cls: A dataclass to be registered as a pytree
Returns:
The registered dataclass
"""

def _flatten(obj):
jax.tree_flatten(dataclasses.asdict(obj))

def _unflatten(d, children):
cls(**d.unflatten(children))

jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)
return cls

0 comments on commit 40efc27

Please sign in to comment.