Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve variable groups; Make variable groups and factor groups hashable; Fix mypy errors #51

Merged
merged 11 commits into from
Aug 11, 2021
36 changes: 20 additions & 16 deletions examples/sanity_check_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# Custom Imports
import pgmax.fg.nodes as nodes # isort:skip
import pgmax.interface.datatypes as interface_datatypes # isort:skip
import pgmax.fg.groups as groups # isort:skip

# Standard Package Imports
import matplotlib.pyplot as plt # isort:skip
Expand Down Expand Up @@ -165,26 +165,26 @@ def create_valid_suppression_config_arr(suppression_diameter):
# We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
# attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to
# the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location
grid_vars_group = interface_datatypes.NDVariableArray(3, (2, M - 1, N - 1))
grid_vars_group = groups.NDVariableArray(3, (2, M - 1, N - 1))

# Make a group of additional variables for the edges of the grid
extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
additional_keys = tuple(extra_row_keys + extra_col_keys)
additional_keys_group = interface_datatypes.GenericVariableGroup(3, additional_keys)
additional_keys_group = groups.GenericVariableGroup(3, additional_keys)

# Combine these two VariableGroups into one CompositeVariableGroup
composite_grid_group = interface_datatypes.CompositeVariableGroup(
(("grid_vars", grid_vars_group), ("additional_vars", additional_keys_group))
composite_grid_group = groups.CompositeVariableGroup(
dict(grid_vars=grid_vars_group, additional_vars=additional_keys_group)
)


# %%
# Subclass FactorGroup into the 3 different groups that appear in this problem


@dataclass
class FourFactorGroup(interface_datatypes.EnumerationFactorGroup):
@dataclass(frozen=True, eq=False)
class FourFactorGroup(groups.EnumerationFactorGroup):
num_rows: int
num_cols: int
factor_configs_log_potentials: Optional[np.ndarray] = None
Expand Down Expand Up @@ -235,8 +235,8 @@ def connected_variables(
return ret_list


@dataclass
class VertSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup):
@dataclass(frozen=True, eq=False)
class VertSuppressionFactorGroup(groups.EnumerationFactorGroup):
num_rows: int
num_cols: int
suppression_diameter: int
Expand Down Expand Up @@ -270,8 +270,8 @@ def connected_variables(
return ret_list


@dataclass
class HorzSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup):
@dataclass(frozen=True, eq=False)
class HorzSuppressionFactorGroup(groups.EnumerationFactorGroup):
num_rows: int
num_cols: int
suppression_diameter: int
Expand Down Expand Up @@ -307,22 +307,22 @@ def connected_variables(
# %%
# Now, we instantiate the four factors
four_factors_group = FourFactorGroup(
var_group=composite_grid_group,
variable_group=composite_grid_group,
factor_configs=valid_configs_non_supp,
num_rows=M,
num_cols=N,
)
# Next, we instantiate all the vertical suppression variables
vert_suppression_group = VertSuppressionFactorGroup(
var_group=composite_grid_group,
variable_group=composite_grid_group,
factor_configs=valid_configs_supp,
num_rows=M,
num_cols=N,
suppression_diameter=SUPPRESSION_DIAMETER,
)
# Next, we instantiate all the horizontal suppression variables
horz_suppression_group = HorzSuppressionFactorGroup(
var_group=composite_grid_group,
variable_group=composite_grid_group,
factor_configs=valid_configs_supp,
num_rows=M,
num_cols=N,
Expand Down Expand Up @@ -424,13 +424,17 @@ def get_evidence(
for row in range(M):
for col in range(N):
try:
bp_values[i, row, col] = map_message_dict[composite_grid_group["grid_vars", i, row, col]] # type: ignore
bp_values[i, row, col] = map_message_dict[
composite_grid_group["grid_vars", i, row, col]
]
bu_evidence[i, row, col, :] = var_evidence_dict[
grid_vars_group[i, row, col]
]
except ValueError:
try:
bp_values[i, row, col] = map_message_dict[composite_grid_group["additional_vars", i, row, col]] # type: ignore
bp_values[i, row, col] = map_message_dict[
composite_grid_group["additional_vars", i, row, col]
]
bu_evidence[i, row, col, :] = var_evidence_dict[
additional_keys_group[i, row, col]
]
Expand Down
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ warn_redundant_casts = True
warn_return_any = True
warn_unused_configs = True
warn_unused_ignores = True
allow_redefinition = True
StannisZhou marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 4 additions & 6 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
import jax.numpy as jnp
import numpy as np

import pgmax.bp.infer as infer
import pgmax.fg.fg_utils as fg_utils
import pgmax.fg.nodes as nodes
import pgmax.interface.datatypes as interface_datatypes
from pgmax import utils
from pgmax.bp import infer
from pgmax.fg import fg_utils, groups, nodes


@dataclass(frozen=True, eq=False)
Expand All @@ -35,15 +33,15 @@ class FactorGraph:
for that particular variable should be placed.
"""

factor_groups: Tuple[interface_datatypes.FactorGroup, ...]
factor_groups: Tuple[groups.FactorGroup, ...]

def __post_init__(self):
self.factors = sum(
[factor_group.factors for factor_group in self.factor_groups], ()
)
self.variables = sum(
[
factor_group.var_group.get_all_vars()
factor_group.variable_group.variables
for factor_group in self.factor_groups
],
(),
Expand Down
Loading