Skip to content

Commit

Permalink
Changes to Docs (#142)
Browse files Browse the repository at this point in the history
* Docs

* Comments

* Typo
  • Loading branch information
antoine-dedieu authored May 4, 2022
1 parent c62b624 commit 107dd8e
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 57 deletions.
19 changes: 10 additions & 9 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from pgmax import fgraph, fgroup, infer, vgroup

# %% [markdown]
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module containing core functions to perform LBP.
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module contains functions to perform LBP.
#
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) which has been trained on MNIST digits.

# %%
# Load parameters
Expand All @@ -50,9 +50,9 @@
fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])

# %% [markdown]
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with possibly different number of states: this class shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.fgraph.FactorGraph.html#pgmax.fgraph.fgraph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.vgroup.VarGroup.html#pgmax.vgroup.vgroup.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.vgroup.VarGroup.html#pgmax.vgroup.vgroup.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.fgraph.FactorGraph.html#pgmax.fgraph.fgraph.FactorGraph). We efficiently add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)

# %%
# Create unary factors
Expand Down Expand Up @@ -86,13 +86,14 @@


# %% [markdown]
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.enum.EnumFactorGroup.html#pgmax.fgroup.enum.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.enum.PairwiseFactorGroup.html#pgmax.fgroup.enum.PairwiseFactorGroup).
#
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) (e.g. `factor_configs` and `log_potential_matrix` here).
#
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
#
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)s as below. This approach is not recommended as it can be much slower than using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)s.
#
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
# ~~~python
# from pgmax import factor
# import itertools
Expand Down Expand Up @@ -176,7 +177,7 @@
# ~~~python
# bp = infer.BP(fg.bp_state, temperature=T)
# ~~~
# where the arguments of the `this_bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
# where the arguments of `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
#
# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:

Expand Down
2 changes: 1 addition & 1 deletion pgmax/factor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A sub-package defining factors containing different types of factors."""
"""A sub-package defining different types of factors."""

import collections
from typing import Callable, OrderedDict, Type
Expand Down
2 changes: 1 addition & 1 deletion pgmax/factor/factor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A module containing classes that specify the basic components of a factor."""
"""A module containing the base classes for factors in a factor graph."""

from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple, Union
Expand Down
2 changes: 1 addition & 1 deletion pgmax/factor/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class ORFactor(LogicalFactor):
@dataclass(frozen=True, eq=False)
class ANDFactor(LogicalFactor):
"""An AND factor of the form (p1,...,pn, c)
where p1,...,pn are the parents variables and c is the child variable.
where p1,...,pn are the parents variables and c is the child variable.
An AND factor is defined as:
F(p1, p2, ..., pn, c) = 0 <=> c = AND(p1, p2, ..., pn)
Expand Down
78 changes: 38 additions & 40 deletions pgmax/fgraph/fgraph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from __future__ import annotations

"""A module containing the core class to specify a Factor Graph."""
"""A module containing the core class to build a factor graph."""

import collections
import copy
Expand Down Expand Up @@ -28,6 +26,43 @@
from pgmax.utils import cached_property


@dataclass(frozen=True, eq=False)
class FactorGraphState:
"""FactorGraphState.
Args:
variable_groups: VarGroups in the FactorGraph.
vars_to_starts: Maps variables to their starting indices in the flat evidence array.
flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
contains evidence to the variable.
num_var_states: Total number of variable states.
total_factor_num_states: Size of the flat ftov messages array.
factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
log_potentials: Flat log potentials array concatenated for each factor type.
wiring: Wiring derived for each factor type.
"""

variable_groups: Sequence[vgroup.VarGroup]
vars_to_starts: Mapping[Tuple[int, int], int]
num_var_states: int
total_factor_num_states: int
factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int]
log_potentials: OrderedDict[type, Union[None, np.ndarray]]
wiring: OrderedDict[type, factor.Wiring]

def __post_init__(self):
for field in self.__dataclass_fields__:
if isinstance(getattr(self, field), np.ndarray):
getattr(self, field).flags.writeable = False

if isinstance(getattr(self, field), Mapping):
object.__setattr__(self, field, MappingProxyType(getattr(self, field)))


@dataclass
class FactorGraph:
"""Class for representing a factor graph.
Expand Down Expand Up @@ -294,40 +329,3 @@ def bp_state(self) -> Any:
ftov_msgs=bp_state.FToVMessages(fg_state=self.fg_state),
evidence=bp_state.Evidence(fg_state=self.fg_state),
)


@dataclass(frozen=True, eq=False)
class FactorGraphState:
"""FactorGraphState.
Args:
variable_groups: VarGroups in the FactorGraph.
vars_to_starts: Maps variables to their starting indices in the flat evidence array.
flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
contains evidence to the variable.
num_var_states: Total number of variable states.
total_factor_num_states: Size of the flat ftov messages array.
factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
log_potentials: Flat log potentials array concatenated for each factor type.
wiring: Wiring derived for each factor type.
"""

variable_groups: Sequence[vgroup.VarGroup]
vars_to_starts: Mapping[Tuple[int, int], int]
num_var_states: int
total_factor_num_states: int
factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int]
log_potentials: OrderedDict[type, None | np.ndarray]
wiring: OrderedDict[type, factor.Wiring]

def __post_init__(self):
for field in self.__dataclass_fields__:
if isinstance(getattr(self, field), np.ndarray):
getattr(self, field).flags.writeable = False

if isinstance(getattr(self, field), Mapping):
object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
2 changes: 1 addition & 1 deletion pgmax/fgroup/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A sub-package defining factor groups and containing different types of factor groups."""
"""A sub-package defining different types of groups of factors."""

from .enum import EnumFactorGroup, PairwiseFactorGroup
from .fgroup import FactorGroup, SingleFactorGroup
Expand Down
2 changes: 1 addition & 1 deletion pgmax/fgroup/fgroup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A module containing the base classes for factor groups in a Factor Graph."""
"""A module containing the base classes for factor groups in a factor graph."""

import inspect
from dataclasses import dataclass, field
Expand Down
2 changes: 1 addition & 1 deletion pgmax/infer/bp_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Defines container classes for belief propagation states, and for the relevant flat arrays used in belief propagation."
"A module defining container classes for belief propagation states."

import functools
from dataclasses import asdict, dataclass
Expand Down
2 changes: 1 addition & 1 deletion pgmax/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A module containing helper functions useful while constructing Factor Graphs."""
"""A module containing helper functions."""

import functools
from typing import Callable
Expand Down
2 changes: 1 addition & 1 deletion pgmax/vgroup/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A sub-package defining variable groups and containing different types of variable groups."""
"""A sub-package defining different types of groups of variables."""

from .varray import NDVarArray
from .vdict import VarDict
Expand Down

0 comments on commit 107dd8e

Please sign in to comment.