diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index aac83258..5015cff9 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -117,17 +117,16 @@ def run_bp( msgs = init_msgs else: msgs = self.get_init_msgs(msgs_context) + + wiring = jax.device_put(self.wiring) evidence = self.get_evidence(evidence_data, evidence_context) - edges_num_states = jax.device_put(self.wiring.edges_num_states) - var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges) - factor_configs_edge_states = jax.device_put( - self.wiring.factor_configs_edge_states - ) - max_msg_size = int(jnp.max(edges_num_states)) + max_msg_size = int(jnp.max(wiring.edges_num_states)) # Normalize the messages to ensure the maximum value is 0. - msgs = infer.normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size) - num_val_configs = int(factor_configs_edge_states[-1, 0]) + msgs = infer.normalize_and_clip_msgs( + msgs, wiring.edges_num_states, max_msg_size + ) + num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) @jax.jit def message_passing_step(msgs, _): @@ -135,12 +134,12 @@ def message_passing_step(msgs, _): vtof_msgs = infer.pass_var_to_fac_messages( msgs, evidence, - var_states_for_edges, + wiring.var_states_for_edges, ) # Compute new factor to variable messages by message passing ftov_msgs = infer.pass_fac_to_var_messages( vtof_msgs, - factor_configs_edge_states, + wiring.factor_configs_edge_states, num_val_configs, ) # Use the results of message passing to perform damping and @@ -151,7 +150,7 @@ def message_passing_step(msgs, _): # them. msgs = infer.normalize_and_clip_msgs( msgs, - edges_num_states, + wiring.edges_num_states, max_msg_size, ) return msgs, None @@ -180,8 +179,6 @@ def decode_map_states( # NOTE: Having to regenerate the evidence here is annoying - there must be a better way to handle evidence and # message initialization. evidence = self.get_evidence(evidence_data, evidence_context) - # TODO: Once issue #20 is resolved, just grab this from self.wiring instead of - # casting it to a jnp.array var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges) final_var_states = evidence.at[var_states_for_edges].add(msgs) var_to_map_dict = {} diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 7680f16e..8f538676 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -1,8 +1,9 @@ """A module containing classes that specify the components of a Factor Graph.""" -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Mapping, Tuple, Union +import jax import jax.numpy as jnp import numpy as np @@ -18,7 +19,7 @@ class Variable: num_states: int -@utils.register_pytree_node_dataclass +@jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class EnumerationWiring: """Wiring for enumeration factors. @@ -40,7 +41,15 @@ class EnumerationWiring: def __post_init__(self): for field in self.__dataclass_fields__: - getattr(self, field).flags.writeable = False + if isinstance(getattr(self, field), np.ndarray): + getattr(self, field).flags.writeable = False + + def tree_flatten(self): + return jax.tree_util.tree_flatten(asdict(self)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(**aux_data.unflatten(children)) @dataclass(frozen=True, eq=False) diff --git a/pgmax/utils.py b/pgmax/utils.py index 74c448f4..cba3f375 100644 --- a/pgmax/utils.py +++ b/pgmax/utils.py @@ -1,10 +1,7 @@ """A module containing helper functions useful while constructing Factor Graphs.""" -import dataclasses import functools -from typing import Any, Callable - -import jax +from typing import Callable def cached_property(func: Callable) -> property: @@ -17,23 +14,3 @@ def cached_property(func: Callable) -> property: Decorated cached property """ return property(functools.lru_cache(None)(func)) - - -def register_pytree_node_dataclass(cls: Any) -> Any: - """Decorator to register a dataclass as a pytree - - Args: - cls: A dataclass to be registered as a pytree - - Returns: - The registered dataclass - """ - - def _flatten(obj): - jax.tree_flatten(dataclasses.asdict(obj)) - - def _unflatten(d, children): - cls(**d.unflatten(children)) - - jax.tree_util.register_pytree_node(cls, _flatten, _unflatten) - return cls