From bb461130b39c87c67b7bc3a3f375a53ec486b49b Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 12:00:17 -0700 Subject: [PATCH 01/11] Composite within composite --- pgmax/interface/datatypes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index db219f61..3521f74e 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools from dataclasses import dataclass from types import MappingProxyType @@ -93,7 +95,9 @@ class CompositeVariableGroup: """ - key_vargroup_pairs: Tuple[Tuple[Any, VariableGroup], ...] + key_vargroup_pairs: Tuple[ + Tuple[Any, Union[VariableGroup, CompositeVariableGroup]], ... + ] def __post_init__(self): """Initialize a private, immuable mapping from keys to VariableGroups.""" From 8356d712bb3016d22aba4899781ef9351e8f77a5 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 21:12:02 -0700 Subject: [PATCH 02/11] Update variable groups --- examples/sanity_check_example.py | 8 +- pgmax/fg/graph.py | 5 +- pgmax/interface/datatypes.py | 220 +++++++++++++++---------------- 3 files changed, 114 insertions(+), 119 deletions(-) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index 4ab60095..778731b3 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -175,7 +175,7 @@ def create_valid_suppression_config_arr(suppression_diameter): # Combine these two VariableGroups into one CompositeVariableGroup composite_grid_group = interface_datatypes.CompositeVariableGroup( - (("grid_vars", grid_vars_group), ("additional_vars", additional_keys_group)) + dict(grid_vars=grid_vars_group, additional_vars=additional_keys_group) ) @@ -307,14 +307,14 @@ def connected_variables( # %% # Now, we instantiate the four factors four_factors_group = FourFactorGroup( - var_group=composite_grid_group, + variable_group=composite_grid_group, factor_configs=valid_configs_non_supp, num_rows=M, num_cols=N, ) # Next, we instantiate all the vertical suppression variables vert_suppression_group = VertSuppressionFactorGroup( - var_group=composite_grid_group, + variable_group=composite_grid_group, factor_configs=valid_configs_supp, num_rows=M, num_cols=N, @@ -322,7 +322,7 @@ def connected_variables( ) # Next, we instantiate all the horizontal suppression variables horz_suppression_group = HorzSuppressionFactorGroup( - var_group=composite_grid_group, + variable_group=composite_grid_group, factor_configs=valid_configs_supp, num_rows=M, num_cols=N, diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 304a881b..209615a5 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -42,10 +42,7 @@ def __post_init__(self): [factor_group.factors for factor_group in self.factor_groups], () ) self.variables = sum( - [ - factor_group.var_group.get_all_vars() - for factor_group in self.factor_groups - ], + [factor_group.var_group.variables for factor_group in self.factor_groups], (), ) diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index 3521f74e..3a97c35e 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -1,50 +1,45 @@ -from __future__ import annotations - import itertools -from dataclasses import dataclass +from dataclasses import dataclass, field from types import MappingProxyType -from typing import Any, Dict, List, Mapping, Tuple, Union +from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union import numpy as np import pgmax.fg.nodes as nodes -@dataclass +@dataclass(frozen=True, eq=False) class VariableGroup: """Base class to represent a group of variables. All variables in the group are assumed to have the same size. Additionally, the variables are indexed by a "key", and can be retrieved by direct indexing (even indexing - a list of keys) of the VariableGroup. - - Args: - variable_size: the number of states that the variable can be in. + a sequence of keys) of the VariableGroup. """ - variable_size: int + _keys_to_vars: Mapping[Any, nodes.Variable] = field(init=False) def __post_init__(self) -> None: """Initialize a private, immuable mapping from keys to Variables.""" - self._key_to_var: Mapping[Any, nodes.Variable] = MappingProxyType( - self._generate_vars() + object.__setattr__( + self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) ) - def __getitem__(self, key) -> Union[nodes.Variable, List[nodes.Variable]]: + def __getitem__(self, key: Any) -> Union[nodes.Variable, List[nodes.Variable]]: """Given a key, retrieve the associated Variable. Args: - key: a single key corresponding to a single variable, or a list of such keys + key: a single key corresponding to a single variable, or a sequence of such keys Returns: a single variable if the "key" argument is a single key. Otherwise, returns a list of variables corresponding to each key in the "key" argument. """ - if type(key) is list: + if isinstance(key, Sequence): vars_list: List[nodes.Variable] = [] for k in key: - var = self._key_to_var.get(k) + var = self._keys_to_vars.get(k) if var is None: raise ValueError( f"The key {k} is not present in the VariableGroup {type(self)}; please ensure " @@ -53,7 +48,7 @@ def __getitem__(self, key) -> Union[nodes.Variable, List[nodes.Variable]]: vars_list.append(var) return vars_list else: - var = self._key_to_var.get(key) + var = self._keys_to_vars.get(key) if var is None: raise ValueError( f"The key {key} is not present in in the VariableGroup {type(self)}; please ensure " @@ -61,7 +56,7 @@ def __getitem__(self, key) -> Union[nodes.Variable, List[nodes.Variable]]: ) return var - def _generate_vars(self) -> Dict[Any, nodes.Variable]: + def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: @@ -71,17 +66,18 @@ def _generate_vars(self) -> Dict[Any, nodes.Variable]: "Please subclass the VariableGroup class and override this method" ) - def get_all_vars(self) -> Tuple[nodes.Variable, ...]: + @property + def variables(self) -> Tuple[nodes.Variable, ...]: """Function to return a tuple of all variables in the group. Returns: tuple of all variable that are part of this VariableGroup """ - return tuple(self._key_to_var.values()) + return tuple(self._keys_to_vars.values()) -@dataclass -class CompositeVariableGroup: +@dataclass(frozen=True, eq=False) +class CompositeVariableGroup(VariableGroup): """A class to encapsulate a collection of instantiated VariableGroups. This class enables users to wrap various different VariableGroups and then index @@ -90,74 +86,123 @@ class CompositeVariableGroup: by the key to be indexed within the VariableGroup. Args: - key_vargroup_pairs: a tuple of tuples where each inner tuple is a (key, VariableGroup) - pair + variable_group_container: A container containing multiple variable groups. + Supported containers include mapping and sequence. + For a mapping, the keys of the mapping are used to index the variable groups. + For a sequence, the indices of the sequence are used to index the variable groups. """ - key_vargroup_pairs: Tuple[ - Tuple[Any, Union[VariableGroup, CompositeVariableGroup]], ... + variable_group_container: Union[ + Mapping[Any, VariableGroup], Sequence[VariableGroup] ] - def __post_init__(self): - """Initialize a private, immuable mapping from keys to VariableGroups.""" - key_vargroup_dict: Dict[Any, VariableGroup] = {} - for key_vargroup_tuple in self.key_vargroup_pairs: - key, vargroup = key_vargroup_tuple - key_vargroup_dict[key] = vargroup - self._key_to_vargroup: Mapping[Any, VariableGroup] = MappingProxyType( - key_vargroup_dict - ) - - def __getitem__(self, key) -> Union[nodes.Variable, List[nodes.Variable]]: + def __getitem__(self, key: Any) -> Union[nodes.Variable, List[nodes.Variable]]: """Given a key, retrieve the associated Variable from the associated VariableGroup. Args: - key: a single key corresponding to a single Variable within a VariableGroup, or a list + key: a single key corresponding to a single Variable within a VariableGroup, or a sequence of such keys Returns: a single variable if the "key" argument is a single key. Otherwise, returns a list of variables corresponding to each key in the "key" argument. """ - if type(key) is list: + if isinstance(key, Sequence): vars_list: List[nodes.Variable] = [] for k in key: - var_group = self._key_to_vargroup.get(k[0]) - if var_group is None: - raise ValueError( - f"The key {key[0]} is not present in the CompositeVariableGroup {type(self)}; please ensure " - "it's been added to the VariableGroup before trying to query it." - ) - vars_list.append(var_group[k[1:]]) # type: ignore + variable_group = self.variable_group_container[k[0]] + vars_list.append(variable_group[k[1:]]) # type: ignore return vars_list else: - var_group = self._key_to_vargroup.get(key[0]) - if var_group is None: + variable_group = self.variable_group_container[key[0]] + if variable_group is None: raise ValueError( f"The key {key[0]} is not present in the CompositeVariableGroup {type(self)}; please ensure " "it's been added to the VariableGroup before trying to query it." ) - return var_group[key[1:]] + return variable_group[key[1:]] - def get_all_vars(self) -> Tuple[nodes.Variable, ...]: + @property + def variables(self) -> Tuple[nodes.Variable, ...]: """Function to return a tuple of all variables from all VariableGroups in this group. Returns: tuple of all variable that are part of this VariableGroup """ - return sum( - [var_group.get_all_vars() for var_group in self._key_to_vargroup.values()], - (), - ) + if isinstance(self.variable_group_container, Mapping): + variables = sum( + [ + variable_group.variables + for variable_group in self.variable_group_container.values() + ], + (), + ) + else: + variables = sum( + [ + variable_group.variables + for variable_group in self.variable_group_container + ], + (), + ) + + return variables -@dataclass +@dataclass(frozen=True, eq=False) +class NDVariableArray(VariableGroup): + """Subclass of VariableGroup for n-dimensional grids of variables. + + Args: + shape: a tuple specifying the size of each dimension of the grid (similar to + the notion of a NumPy ndarray shape) + """ + + variable_size: int + shape: Tuple[int, ...] + + def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: + """Function that generates a dictionary mapping keys to variables. + + Returns: + a dictionary mapping all possible keys to different variables. + """ + keys_to_vars: Dict[Tuple[int, ...], nodes.Variable] = {} + for key in itertools.product(*[list(range(k)) for k in self.shape]): + keys_to_vars[key] = nodes.Variable(self.variable_size) + return keys_to_vars + + +@dataclass(frozen=True, eq=False) +class GenericVariableGroup(VariableGroup): + """A generic variable group that contains a set of variables of the same size + + Returns: + a dictionary mapping all possible keys to different variables. + """ + + variable_size: int + key_tuple: Tuple[Any, ...] + + def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: + """Function that generates a dictionary mapping keys to variables. + + Returns: + a dictionary mapping all possible keys to different variables. + """ + keys_to_vars: Dict[Tuple[Any, ...], nodes.Variable] = {} + for key in self.key_tuple: + keys_to_vars[key] = nodes.Variable(self.variable_size) + return keys_to_vars + + +@dataclass(frozen=True, eq=False) class FactorGroup: """Base class to represent a group of factors. Args: - var_group: either a VariableGroup or - if the elements of more than one VariableGroup + variable_group: either a VariableGroup or - if the elements of more than one VariableGroup are connected to this FactorGroup - then a CompositeVariableGroup. This holds all the variables that are connected to this FactorGroup @@ -165,7 +210,7 @@ class FactorGroup: ValueError: if the connected_variables() method returns an empty list """ - var_group: Union[CompositeVariableGroup, VariableGroup] + variable_group: Union[CompositeVariableGroup, VariableGroup] def __post_init__(self) -> None: """Initializes a tuple of all the factors contained within this FactorGroup.""" @@ -189,7 +234,7 @@ def connected_variables( ) -@dataclass +@dataclass(frozen=True, eq=False) class EnumerationFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors. @@ -232,14 +277,14 @@ def __post_init__(self) -> None: self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( [ nodes.EnumerationFactor( - tuple(self.var_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore + tuple(self.variable_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore ) for keys_list in connected_var_keys_for_factors ] ) -@dataclass +@dataclass(frozen=True, eq=False) class PairwiseFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors where each factor connects to two different variables. @@ -284,13 +329,13 @@ def __post_init__(self) -> None: if not ( self.log_potential_matrix.shape == ( - self.var_group[fac_list[0]].num_states, # type: ignore - self.var_group[fac_list[1]].num_states, # type: ignore + self.variable_group[fac_list[0]].num_states, # type: ignore + self.variable_group[fac_list[1]].num_states, # type: ignore ) ): raise ValueError( "self.log_potential_matrix must have shape" - + f"{(self.var_group[fac_list[0]].num_states, self.var_group[fac_list[1]].num_states)} " # type: ignore + + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " # type: ignore + f"based on the return value of self.connected_variables(). Instead, it has shape {self.log_potential_matrix.shape}" ) self.factor_configs = np.array( @@ -307,55 +352,8 @@ def __post_init__(self) -> None: self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( [ nodes.EnumerationFactor( - tuple(self.var_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore + tuple(self.variable_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore ) for keys_list in connected_var_keys_for_factors ] ) - - -@dataclass -class NDVariableArray(VariableGroup): - """Concrete subclass of VariableGroup for n-dimensional grids of variables. - - Args: - shape: a tuple specifying the size of each dimension of the grid (similar to - the notion of a NumPy ndarray shape) - """ - - shape: Tuple[int, ...] - - def _generate_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. - - Returns: - a dictionary mapping all possible keys to different variables. - """ - key_to_var_mapping: Dict[Tuple[int, ...], nodes.Variable] = {} - for key in itertools.product(*[list(range(k)) for k in self.shape]): - key_to_var_mapping[key] = nodes.Variable(self.variable_size) - return key_to_var_mapping - - -@dataclass -class GenericVariableGroup(VariableGroup): - """A generic variable group that contains a set of variables of the same size - - This is an overriden function from the parent class. - - Returns: - a dictionary mapping all possible keys to different variables. - """ - - key_tuple: Tuple[Any, ...] - - def _generate_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. - - Returns: - a dictionary mapping all possible keys to different variables. - """ - key_to_var_mapping: Dict[Tuple[Any, ...], nodes.Variable] = {} - for key in self.key_tuple: - key_to_var_mapping[key] = nodes.Variable(self.variable_size) - return key_to_var_mapping From 87cb3f79e37c973f4a2d889ac0d9c315c6cfb6f7 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:09:57 -0700 Subject: [PATCH 03/11] Update mypy setup --- mypy.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy.ini b/mypy.ini index 5d287cff..f878d7cc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -9,3 +9,4 @@ warn_redundant_casts = True warn_return_any = True warn_unused_configs = True warn_unused_ignores = True +allow_redefinition = True From 52dbb089a179fd0dbb0448e4e71d24a385171572 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:16:17 -0700 Subject: [PATCH 04/11] More fixes --- examples/sanity_check_example.py | 8 +- pgmax/interface/datatypes.py | 129 +++++++++++++++++++------------ 2 files changed, 86 insertions(+), 51 deletions(-) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index 778731b3..e6034cd7 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -424,13 +424,17 @@ def get_evidence( for row in range(M): for col in range(N): try: - bp_values[i, row, col] = map_message_dict[composite_grid_group["grid_vars", i, row, col]] # type: ignore + bp_values[i, row, col] = map_message_dict[ + composite_grid_group["grid_vars", i, row, col] + ] bu_evidence[i, row, col, :] = var_evidence_dict[ grid_vars_group[i, row, col] ] except ValueError: try: - bp_values[i, row, col] = map_message_dict[composite_grid_group["additional_vars", i, row, col]] # type: ignore + bp_values[i, row, col] = map_message_dict[ + composite_grid_group["additional_vars", i, row, col] + ] bu_evidence[i, row, col, :] = var_evidence_dict[ additional_keys_group[i, row, col] ] diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index 3a97c35e..9ebd0916 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -1,7 +1,8 @@ import itertools +import typing from dataclasses import dataclass, field from types import MappingProxyType -from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Sequence, Tuple, Union import numpy as np @@ -25,18 +26,26 @@ def __post_init__(self) -> None: self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) ) - def __getitem__(self, key: Any) -> Union[nodes.Variable, List[nodes.Variable]]: + @typing.overload + def __getitem__(self, key: Hashable) -> nodes.Variable: + pass + + @typing.overload + def __getitem__(self, key: List) -> List[nodes.Variable]: + pass + + def __getitem__(self, key): """Given a key, retrieve the associated Variable. Args: - key: a single key corresponding to a single variable, or a sequence of such keys + key: a single key corresponding to a single variable, or a list of such keys Returns: a single variable if the "key" argument is a single key. Otherwise, returns a list of variables corresponding to each key in the "key" argument. """ - if isinstance(key, Sequence): + if isinstance(key, List): vars_list: List[nodes.Variable] = [] for k in key: var = self._keys_to_vars.get(k) @@ -66,6 +75,10 @@ def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: "Please subclass the VariableGroup class and override this method" ) + @property + def keys(self) -> Tuple[Any, ...]: + return tuple(self._keys_to_vars.keys()) + @property def variables(self) -> Tuple[nodes.Variable, ...]: """Function to return a tuple of all variables in the group. @@ -97,57 +110,72 @@ class CompositeVariableGroup(VariableGroup): Mapping[Any, VariableGroup], Sequence[VariableGroup] ] - def __getitem__(self, key: Any) -> Union[nodes.Variable, List[nodes.Variable]]: + def __post_init__(self): + if (not isinstance(self.variable_group_container, Mapping)) and ( + not isinstance(self.variable_group_container, Sequence) + ): + raise ValueError( + f"variable_group_container needs to be a mapping or a sequence. Got {type(self.variable_group_container)}" + ) + + @typing.overload + def __getitem__(self, key: Hashable) -> nodes.Variable: + pass + + @typing.overload + def __getitem__(self, key: List) -> List[nodes.Variable]: + pass + + def __getitem__(self, key): """Given a key, retrieve the associated Variable from the associated VariableGroup. Args: - key: a single key corresponding to a single Variable within a VariableGroup, or a sequence + key: a single key corresponding to a single Variable within a VariableGroup, or a list of such keys Returns: a single variable if the "key" argument is a single key. Otherwise, returns a list of variables corresponding to each key in the "key" argument. """ - if isinstance(key, Sequence): - vars_list: List[nodes.Variable] = [] - for k in key: - variable_group = self.variable_group_container[k[0]] - vars_list.append(variable_group[k[1:]]) # type: ignore - return vars_list + if isinstance(key, List): + keys_list = key else: - variable_group = self.variable_group_container[key[0]] - if variable_group is None: + keys_list = [key] + + vars_list = [] + for key in keys_list: + if len(key) < 2: raise ValueError( - f"The key {key[0]} is not present in the CompositeVariableGroup {type(self)}; please ensure " - "it's been added to the VariableGroup before trying to query it." + "The key needs to have at least 2 elements to index from a composite variable group." ) - return variable_group[key[1:]] - @property - def variables(self) -> Tuple[nodes.Variable, ...]: - """Function to return a tuple of all variables from all VariableGroups in this group. + variable_group = self.variable_group_container[key[0]] + if len(key) == 2: + vars_list.append(variable_group[key[1]]) + else: + vars_list.append(variable_group[key[1:]]) - Returns: - tuple of all variable that are part of this VariableGroup - """ + return vars_list[0] + + def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: + keys_to_vars = {} if isinstance(self.variable_group_container, Mapping): - variables = sum( - [ - variable_group.variables - for variable_group in self.variable_group_container.values() - ], - (), - ) + container_keys = self.variable_group_container.keys() else: - variables = sum( - [ - variable_group.variables - for variable_group in self.variable_group_container - ], - (), - ) + container_keys = set(range(len(self.variable_group_container))) + + for container_key in container_keys: + for variable_group_key in self.variable_group_container[container_key].keys: + if isinstance(variable_group_key, tuple): + keys_to_vars[ + (container_key,) + variable_group_key + ] = self.variable_group_container[container_key][variable_group_key] + else: + keys_to_vars[ + (container_key, variable_group_key) + ] = self.variable_group_container[container_key][variable_group_key] - return variables + return keys_to_vars @dataclass(frozen=True, eq=False) @@ -264,20 +292,21 @@ class EnumerationFactorGroup(FactorGroup): def __post_init__(self) -> None: """Initializes a tuple of all the factors contained within this FactorGroup.""" connected_var_keys_for_factors = self.connected_variables() - if ( - not hasattr(self, "factor_configs_log_potentials") - or hasattr(self, "factor_configs_log_potentials") - and self.factor_configs_log_potentials is None # type: ignore - ): + if getattr(self, "factor_configs_log_potentials", None) is None: factor_configs_log_potentials = np.zeros( self.factor_configs.shape[0], dtype=float ) else: - factor_configs_log_potentials = self.factor_configs_log_potentials # type: ignore + factor_configs_log_potentials = getattr( + self, "factor_configs_log_potentials" + ) + self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( [ nodes.EnumerationFactor( - tuple(self.variable_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore + tuple(self.variable_group[keys_list]), + self.factor_configs, + factor_configs_log_potentials, ) for keys_list in connected_var_keys_for_factors ] @@ -329,13 +358,13 @@ def __post_init__(self) -> None: if not ( self.log_potential_matrix.shape == ( - self.variable_group[fac_list[0]].num_states, # type: ignore - self.variable_group[fac_list[1]].num_states, # type: ignore + self.variable_group[fac_list[0]].num_states, + self.variable_group[fac_list[1]].num_states, ) ): raise ValueError( "self.log_potential_matrix must have shape" - + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " # type: ignore + + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " + f"based on the return value of self.connected_variables(). Instead, it has shape {self.log_potential_matrix.shape}" ) self.factor_configs = np.array( @@ -352,7 +381,9 @@ def __post_init__(self) -> None: self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( [ nodes.EnumerationFactor( - tuple(self.variable_group[keys_list]), self.factor_configs, factor_configs_log_potentials # type: ignore + tuple(self.variable_group[keys_list]), + self.factor_configs, + factor_configs_log_potentials, ) for keys_list in connected_var_keys_for_factors ] From d4476b55234d161b609c36784117497682c769e1 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:20:08 -0700 Subject: [PATCH 05/11] Revert frozen for factor groups --- pgmax/interface/datatypes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index 9ebd0916..01c50196 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -225,7 +225,7 @@ def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: return keys_to_vars -@dataclass(frozen=True, eq=False) +@dataclass class FactorGroup: """Base class to represent a group of factors. @@ -262,7 +262,7 @@ def connected_variables( ) -@dataclass(frozen=True, eq=False) +@dataclass class EnumerationFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors. @@ -313,7 +313,7 @@ def __post_init__(self) -> None: ) -@dataclass(frozen=True, eq=False) +@dataclass class PairwiseFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors where each factor connects to two different variables. From 5840c1e5d69503c145a2d114623f59d7d0d9276b Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:31:52 -0700 Subject: [PATCH 06/11] Make things run --- pgmax/fg/graph.py | 5 ++++- pgmax/interface/datatypes.py | 21 ++++++++++++++------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 209615a5..14eac8d6 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -42,7 +42,10 @@ def __post_init__(self): [factor_group.factors for factor_group in self.factor_groups], () ) self.variables = sum( - [factor_group.var_group.variables for factor_group in self.factor_groups], + [ + factor_group.variable_group.variables + for factor_group in self.factor_groups + ], (), ) diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index 01c50196..af37bb51 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -118,6 +118,10 @@ def __post_init__(self): f"variable_group_container needs to be a mapping or a sequence. Got {type(self.variable_group_container)}" ) + object.__setattr__( + self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) + ) + @typing.overload def __getitem__(self, key: Hashable) -> nodes.Variable: pass @@ -143,19 +147,22 @@ def __getitem__(self, key): keys_list = [key] vars_list = [] - for key in keys_list: - if len(key) < 2: + for curr_key in keys_list: + if len(curr_key) < 2: raise ValueError( "The key needs to have at least 2 elements to index from a composite variable group." ) - variable_group = self.variable_group_container[key[0]] - if len(key) == 2: - vars_list.append(variable_group[key[1]]) + variable_group = self.variable_group_container[curr_key[0]] + if len(curr_key) == 2: + vars_list.append(variable_group[curr_key[1]]) else: - vars_list.append(variable_group[key[1:]]) + vars_list.append(variable_group[curr_key[1:]]) - return vars_list[0] + if isinstance(key, List): + return vars_list + else: + return vars_list[0] def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: keys_to_vars = {} From c5e926857659bccc14f6a2ddcd19fe9a6f44666b Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:45:24 -0700 Subject: [PATCH 07/11] Make everything frozen --- examples/sanity_check_example.py | 6 +++--- pgmax/interface/datatypes.py | 27 +++++++++++++++++---------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index e6034cd7..ae733589 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -183,7 +183,7 @@ def create_valid_suppression_config_arr(suppression_diameter): # Subclass FactorGroup into the 3 different groups that appear in this problem -@dataclass +@dataclass(frozen=True, eq=False) class FourFactorGroup(interface_datatypes.EnumerationFactorGroup): num_rows: int num_cols: int @@ -235,7 +235,7 @@ def connected_variables( return ret_list -@dataclass +@dataclass(frozen=True, eq=False) class VertSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup): num_rows: int num_cols: int @@ -270,7 +270,7 @@ def connected_variables( return ret_list -@dataclass +@dataclass(frozen=True, eq=False) class HorzSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup): num_rows: int num_cols: int diff --git a/pgmax/interface/datatypes.py b/pgmax/interface/datatypes.py index af37bb51..f91997a8 100644 --- a/pgmax/interface/datatypes.py +++ b/pgmax/interface/datatypes.py @@ -7,6 +7,7 @@ import numpy as np import pgmax.fg.nodes as nodes +from pgmax.utils import cached_property @dataclass(frozen=True, eq=False) @@ -232,7 +233,7 @@ def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: return keys_to_vars -@dataclass +@dataclass(frozen=True, eq=False) class FactorGroup: """Base class to represent a group of factors. @@ -269,7 +270,7 @@ def connected_variables( ) -@dataclass +@dataclass(frozen=True, eq=False) class EnumerationFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors. @@ -296,7 +297,8 @@ class EnumerationFactorGroup(FactorGroup): factor_configs: np.ndarray - def __post_init__(self) -> None: + @cached_property + def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: """Initializes a tuple of all the factors contained within this FactorGroup.""" connected_var_keys_for_factors = self.connected_variables() if getattr(self, "factor_configs_log_potentials", None) is None: @@ -308,7 +310,7 @@ def __post_init__(self) -> None: self, "factor_configs_log_potentials" ) - self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( + return tuple( [ nodes.EnumerationFactor( tuple(self.variable_group[keys_list]), @@ -320,7 +322,7 @@ def __post_init__(self) -> None: ) -@dataclass +@dataclass(frozen=True, eq=False) class PairwiseFactorGroup(FactorGroup): """Base class to represent a group of EnumerationFactors where each factor connects to two different variables. @@ -352,8 +354,7 @@ class PairwiseFactorGroup(FactorGroup): log_potential_matrix: np.ndarray - def __post_init__(self) -> None: - """Initializes a tuple of all the factors contained within this FactorGroup.""" + def __post_init(self) -> None: connected_var_keys_for_factors = self.connected_variables() for fac_list in connected_var_keys_for_factors: @@ -374,18 +375,24 @@ def __post_init__(self) -> None: + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " + f"based on the return value of self.connected_variables(). Instead, it has shape {self.log_potential_matrix.shape}" ) - self.factor_configs = np.array( + + @cached_property + def factor_configs(self) -> np.ndarray: + return np.array( np.meshgrid( np.arange(self.log_potential_matrix.shape[0]), np.arange(self.log_potential_matrix.shape[1]), ) ).T.reshape((-1, 2)) + @cached_property + def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: + """Initializes a tuple of all the factors contained within this FactorGroup.""" + connected_var_keys_for_factors = self.connected_variables() factor_configs_log_potentials = self.log_potential_matrix[ self.factor_configs[:, 0], self.factor_configs[:, 1] ] - - self.factors: Tuple[nodes.EnumerationFactor, ...] = tuple( + return tuple( [ nodes.EnumerationFactor( tuple(self.variable_group[keys_list]), From c022f6ac3962a95a6f081e949d25b4de42c5a0a9 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:47:28 -0700 Subject: [PATCH 08/11] Reorg --- examples/sanity_check_example.py | 14 +++++++------- pgmax/{interface/datatypes.py => fg/groups.py} | 0 pgmax/interface/__init__.py | 1 - 3 files changed, 7 insertions(+), 8 deletions(-) rename pgmax/{interface/datatypes.py => fg/groups.py} (100%) delete mode 100644 pgmax/interface/__init__.py diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index ae733589..88960d94 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -21,7 +21,7 @@ # Custom Imports import pgmax.fg.nodes as nodes # isort:skip -import pgmax.interface.datatypes as interface_datatypes # isort:skip +import pgmax.fg.groups as groups # isort:skip # Standard Package Imports import matplotlib.pyplot as plt # isort:skip @@ -165,16 +165,16 @@ def create_valid_suppression_config_arr(suppression_diameter): # We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one # attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to # the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location -grid_vars_group = interface_datatypes.NDVariableArray(3, (2, M - 1, N - 1)) +grid_vars_group = groups.NDVariableArray(3, (2, M - 1, N - 1)) # Make a group of additional variables for the edges of the grid extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] additional_keys = tuple(extra_row_keys + extra_col_keys) -additional_keys_group = interface_datatypes.GenericVariableGroup(3, additional_keys) +additional_keys_group = groups.GenericVariableGroup(3, additional_keys) # Combine these two VariableGroups into one CompositeVariableGroup -composite_grid_group = interface_datatypes.CompositeVariableGroup( +composite_grid_group = groups.CompositeVariableGroup( dict(grid_vars=grid_vars_group, additional_vars=additional_keys_group) ) @@ -184,7 +184,7 @@ def create_valid_suppression_config_arr(suppression_diameter): @dataclass(frozen=True, eq=False) -class FourFactorGroup(interface_datatypes.EnumerationFactorGroup): +class FourFactorGroup(groups.EnumerationFactorGroup): num_rows: int num_cols: int factor_configs_log_potentials: Optional[np.ndarray] = None @@ -236,7 +236,7 @@ def connected_variables( @dataclass(frozen=True, eq=False) -class VertSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup): +class VertSuppressionFactorGroup(groups.EnumerationFactorGroup): num_rows: int num_cols: int suppression_diameter: int @@ -271,7 +271,7 @@ def connected_variables( @dataclass(frozen=True, eq=False) -class HorzSuppressionFactorGroup(interface_datatypes.EnumerationFactorGroup): +class HorzSuppressionFactorGroup(groups.EnumerationFactorGroup): num_rows: int num_cols: int suppression_diameter: int diff --git a/pgmax/interface/datatypes.py b/pgmax/fg/groups.py similarity index 100% rename from pgmax/interface/datatypes.py rename to pgmax/fg/groups.py diff --git a/pgmax/interface/__init__.py b/pgmax/interface/__init__.py deleted file mode 100644 index 51753285..00000000 --- a/pgmax/interface/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""A container package for a new user-facing interface for easily specifying Factor Graphs""" From 7554d49b8254275b71612c7e24ea18e9b3e4ab99 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 22:49:08 -0700 Subject: [PATCH 09/11] Fix crash --- pgmax/fg/graph.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 14eac8d6..db12b2bb 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -8,11 +8,9 @@ import jax.numpy as jnp import numpy as np -import pgmax.bp.infer as infer -import pgmax.fg.fg_utils as fg_utils -import pgmax.fg.nodes as nodes -import pgmax.interface.datatypes as interface_datatypes from pgmax import utils +from pgmax.bp import infer +from pgmax.fg import fg_utils, groups, nodes @dataclass(frozen=True, eq=False) @@ -35,7 +33,7 @@ class FactorGraph: for that particular variable should be placed. """ - factor_groups: Tuple[interface_datatypes.FactorGroup, ...] + factor_groups: Tuple[groups.FactorGroup, ...] def __post_init__(self): self.factors = sum( From d6e27959d8da2cb7e8df46eddf27317df9cd1cdd Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 10 Aug 2021 23:09:17 -0700 Subject: [PATCH 10/11] Docstrings and fixes --- pgmax/fg/groups.py | 69 ++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index f91997a8..f6ec9196 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -17,12 +17,15 @@ class VariableGroup: All variables in the group are assumed to have the same size. Additionally, the variables are indexed by a "key", and can be retrieved by direct indexing (even indexing a sequence of keys) of the VariableGroup. + + Attributes: + _keys_to_vars: A private, immutable mapping from keys to variables """ _keys_to_vars: Mapping[Any, nodes.Variable] = field(init=False) def __post_init__(self) -> None: - """Initialize a private, immuable mapping from keys to Variables.""" + """Initialize a private, immutable mapping from keys to variables.""" object.__setattr__( self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) ) @@ -47,24 +50,25 @@ def __getitem__(self, key): """ if isinstance(key, List): - vars_list: List[nodes.Variable] = [] - for k in key: - var = self._keys_to_vars.get(k) - if var is None: - raise ValueError( - f"The key {k} is not present in the VariableGroup {type(self)}; please ensure " - "it's been added to the VariableGroup before trying to query it." - ) - vars_list.append(var) - return vars_list + keys_list = key else: - var = self._keys_to_vars.get(key) + keys_list = [key] + + vars_list = [] + for curr_key in keys_list: + var = self._keys_to_vars.get(curr_key) if var is None: raise ValueError( - f"The key {key} is not present in in the VariableGroup {type(self)}; please ensure " + f"The key {curr_key} is not present in the VariableGroup {type(self)}; please ensure " "it's been added to the VariableGroup before trying to query it." ) - return var + + vars_list.append(var) + + if isinstance(key, List): + return vars_list + else: + return vars_list[0] def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: """Function that generates a dictionary mapping keys to variables. @@ -78,6 +82,11 @@ def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: @property def keys(self) -> Tuple[Any, ...]: + """Function to return a tuple of all keys in the group. + + Returns: + tuple of all keys that are part of this VariableGroup + """ return tuple(self._keys_to_vars.keys()) @property @@ -105,6 +114,8 @@ class CompositeVariableGroup(VariableGroup): For a mapping, the keys of the mapping are used to index the variable groups. For a sequence, the indices of the sequence are used to index the variable groups. + Attributes: + _keys_to_vars: A private, immutable mapping from keys to variables """ variable_group_container: Union[ @@ -166,6 +177,11 @@ def __getitem__(self, key): return vars_list[0] def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: + """Function that generates a dictionary mapping keys to variables. + + Returns: + a dictionary mapping all possible keys to different variables. + """ keys_to_vars = {} if isinstance(self.variable_group_container, Mapping): container_keys = self.variable_group_container.keys() @@ -191,6 +207,7 @@ class NDVariableArray(VariableGroup): """Subclass of VariableGroup for n-dimensional grids of variables. Args: + variable_size: The size of the variables in this variable group shape: a tuple specifying the size of each dimension of the grid (similar to the notion of a NumPy ndarray shape) """ @@ -214,8 +231,10 @@ def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: class GenericVariableGroup(VariableGroup): """A generic variable group that contains a set of variables of the same size - Returns: - a dictionary mapping all possible keys to different variables. + Args: + variable_size: The size of the variables in this variable group + key_tuple: A tuple of all keys in this variable group + """ variable_size: int @@ -299,7 +318,7 @@ class EnumerationFactorGroup(FactorGroup): @cached_property def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - """Initializes a tuple of all the factors contained within this FactorGroup.""" + """Returns a tuple of all the factors contained within this FactorGroup.""" connected_var_keys_for_factors = self.connected_variables() if getattr(self, "factor_configs_log_potentials", None) is None: factor_configs_log_potentials = np.zeros( @@ -356,7 +375,6 @@ class PairwiseFactorGroup(FactorGroup): def __post_init(self) -> None: connected_var_keys_for_factors = self.connected_variables() - for fac_list in connected_var_keys_for_factors: if len(fac_list) != 2: raise ValueError( @@ -377,26 +395,23 @@ def __post_init(self) -> None: ) @cached_property - def factor_configs(self) -> np.ndarray: - return np.array( + def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: + """Returns a tuple of all the factors contained within this FactorGroup.""" + factor_configs = np.array( np.meshgrid( np.arange(self.log_potential_matrix.shape[0]), np.arange(self.log_potential_matrix.shape[1]), ) ).T.reshape((-1, 2)) - - @cached_property - def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - """Initializes a tuple of all the factors contained within this FactorGroup.""" - connected_var_keys_for_factors = self.connected_variables() factor_configs_log_potentials = self.log_potential_matrix[ - self.factor_configs[:, 0], self.factor_configs[:, 1] + factor_configs[:, 0], factor_configs[:, 1] ] + connected_var_keys_for_factors = self.connected_variables() return tuple( [ nodes.EnumerationFactor( tuple(self.variable_group[keys_list]), - self.factor_configs, + factor_configs, factor_configs_log_potentials, ) for keys_list in connected_var_keys_for_factors From 25288da74a512bc559a660219bbffc7f850bfc55 Mon Sep 17 00:00:00 2001 From: Guangyao Zhou Date: Wed, 11 Aug 2021 10:53:46 -0700 Subject: [PATCH 11/11] Update pgmax/fg/groups.py Co-authored-by: Nishanth Kumar --- pgmax/fg/groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index f6ec9196..33be74ce 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -373,7 +373,7 @@ class PairwiseFactorGroup(FactorGroup): log_potential_matrix: np.ndarray - def __post_init(self) -> None: + def __post_init__(self) -> None: connected_var_keys_for_factors = self.connected_variables() for fac_list in connected_var_keys_for_factors: if len(fac_list) != 2: