Skip to content

Commit

Permalink
Fix immutability and caching; Speed up wiring compilation (#25)
Browse files Browse the repository at this point in the history
* Makes immutability and caching play well

* Docstrings updated

* Uses regular np.concatenate which is faster
  • Loading branch information
StannisZhou authored Jul 22, 2021
1 parent 0534235 commit 19b8fd4
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 70 deletions.
7 changes: 3 additions & 4 deletions pgmax/fg/fg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np

from pgmax import utils
from pgmax.fg.nodes import EnumerationWiring


Expand Down Expand Up @@ -39,11 +38,11 @@ def concatenate_enumeration_wirings(
)

return EnumerationWiring(
edges_num_states=utils.concatenate_arrays(
edges_num_states=np.concatenate(
[wiring.edges_num_states for wiring in wirings]
),
var_states_for_edges=utils.concatenate_arrays(
var_states_for_edges=np.concatenate(
[wiring.var_states_for_edges for wiring in wirings]
),
factor_configs_edge_states=utils.concatenate_arrays(factor_configs_edge_states),
factor_configs_edge_states=np.concatenate(factor_configs_edge_states, axis=0),
)
6 changes: 3 additions & 3 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import numpy as np

import pgmax.bp.infer as infer
from pgmax import utils
from pgmax.fg import fg_utils, nodes
from pgmax.utils import cached_property


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class FactorGraph:
"""Base class to represent a factor graph.
Expand Down Expand Up @@ -42,7 +42,7 @@ def __post_init__(self):
)
self.num_var_states = vars_num_states_cumsum[-1]

@cached_property
@utils.cached_property
def wiring(self) -> nodes.EnumerationWiring:
"""Function to compile wiring for belief propagation..
Expand Down
68 changes: 43 additions & 25 deletions pgmax/fg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Variable:


@utils.register_pytree_node_dataclass
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class EnumerationWiring:
"""Wiring for enumeration factors.
Expand All @@ -36,8 +36,12 @@ class EnumerationWiring:
var_states_for_edges: Union[np.ndarray, jnp.ndarray]
factor_configs_edge_states: Union[np.ndarray, jnp.ndarray]

def __post_init__(self):
    """Mark every NumPy array field of the wiring as read-only.

    The fields are annotated as ``Union[np.ndarray, jnp.ndarray]``. Only
    NumPy arrays are known to expose a settable ``flags.writeable``, so any
    other value (for example a JAX array handed in when the dataclass is
    rebuilt from a pytree) is left untouched instead of raising here.
    """
    for field in self.__dataclass_fields__:
        value = getattr(self, field)
        # Guard on the concrete type: non-NumPy arrays may not support
        # toggling the writeable flag.
        if isinstance(value, np.ndarray):
            value.flags.writeable = False


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class EnumerationFactor:
"""An enumeration factor
Expand All @@ -51,9 +55,7 @@ class EnumerationFactor:
configs: np.ndarray

def __post_init__(self):
if self.configs.flags.writeable:
raise ValueError("Configurations need to be immutable.")

self.configs.flags.writeable = False
if not np.issubdtype(self.configs.dtype, np.integer):
raise ValueError(
f"Configurations should be integers. Got {self.configs.dtype}."
Expand All @@ -70,6 +72,39 @@ def __post_init__(self):
).all():
raise ValueError("Invalid configurations for given variables")

@utils.cached_property
def edges_num_states(self) -> np.ndarray:
    """Collect the state count of the variable attached to each edge.

    Returns:
        Integer array of shape (num_edges,), with one entry per element of
        ``self.variables`` holding that variable's ``num_states``.
    """
    num_states_per_edge = [var.num_states for var in self.variables]
    return np.array(num_states_per_edge, dtype=int)

@utils.cached_property
def factor_configs_edge_states(self) -> np.ndarray:
    """Pair every factor configuration index with its edge-state indices.

    Returns:
        Two-column array with one row per (configuration, edge) entry of
        ``self.configs``:
            column 0 holds the index of the configuration (row of
            ``self.configs``)
            column 1 holds the matching edge-state index, i.e. the state in
            ``self.configs`` shifted by the total number of states of all
            preceding edges
    """
    # Offset of each edge's first state: exclusive cumulative sum of the
    # per-edge state counts.
    state_counts_cumsum = self.edges_num_states.cumsum()
    edge_offsets = np.insert(state_counts_cumsum, 0, 0)[:-1]

    # Repeat each configuration index once per entry in its row.
    config_indices = np.repeat(
        np.arange(self.configs.shape[0]), self.configs.shape[1]
    )
    # Broadcast the per-edge offsets over all configurations, then flatten
    # row by row so the entries line up with ``config_indices``.
    edge_state_indices = (self.configs + edge_offsets[None]).flatten()

    return np.stack([config_indices, edge_state_indices], axis=1)

def compile_wiring(
self, vars_to_starts: Mapping[Variable, int]
) -> EnumerationWiring:
Expand All @@ -83,31 +118,14 @@ def compile_wiring(
Returns:
Enumeration wiring for the enumeration factor
"""
if not hasattr(self, "_edges_num_states"):
self._edges_num_states = np.array(
[variable.num_states for variable in self.variables], dtype=int
)

var_states_for_edges = utils.concatenate_arrays(
var_states_for_edges = np.concatenate(
[
np.arange(variable.num_states) + vars_to_starts[variable]
for variable in self.variables
]
)
if not hasattr(self, "_factor_configs_edge_states"):
configs = self.configs.copy()
configs.flags.writeable = True
edges_starts = np.insert(self._edges_num_states.cumsum(), 0, 0)[:-1]
self._factor_configs_edge_states = np.stack(
[
np.repeat(np.arange(configs.shape[0]), configs.shape[1]),
(configs + edges_starts[None]).flatten(),
],
axis=1,
)

return EnumerationWiring(
edges_num_states=self._edges_num_states,
edges_num_states=self.edges_num_states,
var_states_for_edges=var_states_for_edges,
factor_configs_edge_states=self._factor_configs_edge_states,
factor_configs_edge_states=self.factor_configs_edge_states,
)
52 changes: 14 additions & 38 deletions pgmax/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import dataclasses
from typing import Any, Sequence
import functools
from typing import Any, Callable

import jax
import numpy as np


def cached_property(func: Callable) -> property:
    """Customized cached property decorator

    The decorated method is evaluated at most once per instance; later
    accesses return the stored value. Results are kept in a weak-keyed
    mapping owned by the property, so nothing is written onto the instance
    (frozen dataclasses keep working) and the cache does not keep instances
    alive: an entry disappears once its instance is garbage collected.

    Instances must be hashable and weak-referenceable.

    Args:
        func: Member function to be decorated

    Returns:
        Decorated cached property
    """
    # Imported locally so the decorator does not depend on a module-level
    # import being present.
    import weakref

    # One cache per decorated method, keyed weakly on the instance. A plain
    # functools.lru_cache here would hold a strong reference to every
    # instance for the lifetime of the process.
    cache = weakref.WeakKeyDictionary()

    @functools.wraps(func)
    def getter(self):
        try:
            return cache[self]
        except KeyError:
            value = cache[self] = func(self)
            return value

    return property(getter)


def register_pytree_node_dataclass(cls: Any) -> Any:
Expand All @@ -23,39 +35,3 @@ def _unflatten(d, children):

jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)
return cls


def concatenate_arrays(arrays: Sequence[np.ndarray]) -> np.ndarray:
    """Join a list of arrays end to end along their leading axis

    The output buffer is allocated up front and each input is copied into its
    own slice. The trailing shape and the dtype of the result are taken from
    the first array in the list.

    Args:
        arrays: A list of numpy arrays to be concatenated

    Returns:
        The concatenated array
    """
    sizes = np.array([arr.shape[0] for arr in arrays], dtype=int)
    # offsets[i] is where arrays[i] begins; offsets[-1] is the total length.
    offsets = np.insert(sizes.cumsum(), 0, 0)
    first = arrays[0]
    out = np.zeros((offsets[-1],) + first.shape[1:], dtype=first.dtype)
    for idx, arr in enumerate(arrays):
        out[offsets[idx] : offsets[idx + 1]] = arr

    return out


class cached_property(object):
    """Descriptor (non-data) for building an attribute on-demand on first use.

    On the first access through an instance the wrapped factory is called and
    its result is stored on the instance under the factory's name. Because
    this is a non-data descriptor, the stored instance attribute shadows it
    on every later lookup, so the factory runs at most once per instance.
    """

    def __init__(self, factory):
        # Name under which the computed value is stored on the instance.
        self._attr_name = factory.__name__
        self._factory = factory
        self.__doc__ = factory.__doc__

    def __get__(self, instance, owner=None):
        # Accessed on the class itself: there is nothing to compute, so hand
        # back the descriptor instead of calling the factory with None.
        if instance is None:
            return self
        # Build the attribute.
        attr = self._factory(instance)
        # Cache the value; hide ourselves. object.__setattr__ bypasses any
        # overridden __setattr__ (e.g. the one generated for frozen
        # dataclasses), which would otherwise reject the assignment.
        object.__setattr__(instance, self._attr_name, attr)
        return attr

0 comments on commit 19b8fd4

Please sign in to comment.