Adds unit tests #41

Merged: 29 commits, Aug 25, 2021
b055f93
created few unit tests for bp_utils
Jul 30, 2021
faef39b
Merge branch 'vicariousinc:master' into unit_tests
NishanthJKumar Jul 30, 2021
7b25f33
Merge branch 'vicariousinc:master' into unit_tests
NishanthJKumar Aug 2, 2021
f12001e
includes first E2E test for pgmax
Aug 2, 2021
e17f697
adds new tests
Aug 2, 2021
0835ef8
Merge branch 'master' of github.com:vicariousinc/PGMax into vicarious…
Aug 13, 2021
407ccee
Merge branch 'vicariousinc-master' into unit_tests
Aug 13, 2021
2c7ab3c
updates sanity check and E2E test, but test currently fails
Aug 13, 2021
99501b4
adds coverage dependency to poetry
NishanthJKumar Aug 18, 2021
e0460e7
Merge branch 'unit_tests' into master
NishanthJKumar Aug 18, 2021
acaf801
Merge pull request #7 from NishanthJKumar/master
NishanthJKumar Aug 18, 2021
a6a32ff
small update to test_pgmax
NishanthJKumar Aug 18, 2021
f1811aa
adds codecov to CI
NishanthJKumar Aug 19, 2021
bf78603
test codecov upload with github secret
NishanthJKumar Aug 19, 2021
944297d
updates ci and poetry deps
NishanthJKumar Aug 19, 2021
e2afc5f
fixes command error in ci
NishanthJKumar Aug 19, 2021
46188df
fixes codecov yaml to be parseable
NishanthJKumar Aug 19, 2021
c5aa087
Merge branch 'master' of github.com:vicariousinc/PGMax into vicarious…
NishanthJKumar Aug 19, 2021
889479c
Merge branch 'vicariousinc-master' into unit_tests
NishanthJKumar Aug 19, 2021
cbac54a
updates ci to only run build on PR's again!
NishanthJKumar Aug 19, 2021
db8f120
testing codecov from personal fork
NishanthJKumar Aug 20, 2021
8472fd4
updates tests
NishanthJKumar Aug 21, 2021
43ccf99
adds simple e2e test inspired by the heretic model
NishanthJKumar Aug 22, 2021
2ab9fbe
includes tests reaching 95% coverage of pgmax
NishanthJKumar Aug 22, 2021
d75050e
adds more tests to get to 97% overall coverage
NishanthJKumar Aug 24, 2021
ed293af
updates tests to 100% coverage
NishanthJKumar Aug 24, 2021
401f629
addresses first major round of comments on #41
NishanthJKumar Aug 25, 2021
e894d9d
Update tests/fg/test_groups.py
NishanthJKumar Aug 25, 2021
52fdfa0
addresses second round of comments on #41
NishanthJKumar Aug 25, 2021
7 changes: 7 additions & 0 deletions .coveragerc
@@ -0,0 +1,7 @@
[report]
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain if tests don't hit defensive assertion code:
raise NotImplementedError
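
For context, here is a minimal sketch (not part of this PR) of the kind of code these exclusions target: lines marked with the standard pragma and defensive NotImplementedError branches are dropped from the coverage report.

# Illustrative only: code that the .coveragerc exclusions above would omit
# from coverage reporting.
def evidence_default(mode: str, num_states: int):
    if mode == "zeros":
        return [0.0] * num_states
    # Defensive branch: excluded by the "raise NotImplementedError" rule above.
    raise NotImplementedError(f"mode {mode} is not yet implemented")

def debug_dump(state):  # pragma: no cover
    # Debug helper excluded from coverage via the standard pragma.
    print(state)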
14 changes: 10 additions & 4 deletions .github/workflows/ci.yaml
@@ -63,16 +63,22 @@ jobs:
- name: Install library
run: poetry install --no-interaction
#----------------------------------------------
# run test suite
# run test suite with coverage
#----------------------------------------------
- name: Test with pytest
- name: Test with coverage
run: |
poetry run pytest
poetry run pytest --cov=pgmax --cov-report=xml
#----------------------------------------------
# upload coverage report to codecov
#----------------------------------------------
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v2
with:
verbose: true # optional (default = false)
#----------------------------------------------
# test docs build only on PR
#----------------------------------------------
- name: Test docs build
if: ${{ github.event_name == 'pull_request' }}
run: |
cd docs
poetry run make html
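
To reproduce the CI coverage step locally, a hedged sketch (assuming pytest and pytest-cov are installed in the project's poetry environment, as the workflow above implies; the file name run_coverage.py is hypothetical):

# run_coverage.py -- mirrors "poetry run pytest --cov=pgmax --cov-report=xml".
# Run it with: poetry run python run_coverage.py
import sys

import pytest

if __name__ == "__main__":
    # pytest-cov supplies the --cov/--cov-report flags; the resulting
    # coverage.xml is the report the codecov-action step uploads.
    sys.exit(pytest.main(["--cov=pgmax", "--cov-report=xml"]))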
6 changes: 4 additions & 2 deletions .gitignore
@@ -1,5 +1,3 @@
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -131,3 +129,7 @@ dmypy.json

# Pyre type checker
.pyre/
.ruby-version

# VSCode settings
.vscode/
2 changes: 2 additions & 0 deletions .isort.cfg
@@ -0,0 +1,2 @@
[settings]
profile=black
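
For reference, a sketch of the import layout that isort's black profile enforces (and that the example-script diffs below adopt): standard library, third-party, then local imports, each block separated by a blank line. The section comments are annotations only; isort does not insert them.

# Standard library
from timeit import default_timer as timer
from typing import Any, List, Tuple

# Third-party
import jax
import jax.numpy as jnp
import numpy as np

# Local package
import pgmax.fg.graph as graph
import pgmax.fg.groups as groups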
4 changes: 4 additions & 0 deletions codecov.yaml
@@ -0,0 +1,4 @@
ignore:
- "docs/**/*"
- "tests/**/*"
- "examples/**/*"
33 changes: 16 additions & 17 deletions examples/heretic_example.py
@@ -8,32 +8,31 @@
# format_version: '1.3'
# jupytext_version: 1.11.4
# kernelspec:
# display_name: 'Python 3.8.5 64-bit (''pgmax-JcKb81GE-py3.8'': poetry)'
# name: python3
# display_name: 'Python 3.7.11 64-bit (''pgmax-zIh0MZVc-py3.7'': venv)'
# name: python371164bitpgmaxzih0mzvcpy37venve540bb1b5cdf4292a3f5a12c4904cc40
# ---

from timeit import default_timer as timer
from typing import Any, List, Tuple

import jax
import jax.numpy as jnp

# %%
# %matplotlib inline
# fmt: off

# Standard Package Imports
import matplotlib.pyplot as plt # isort:skip
import numpy as np # isort:skip
import jax # isort:skip
import jax.numpy as jnp # isort:skip
from typing import Any, Tuple, List # isort:skip
from timeit import default_timer as timer # isort:skip
import matplotlib.pyplot as plt
import numpy as np

# Custom Imports
import pgmax.fg.groups as groups # isort:skip
import pgmax.fg.graph as graph # isort:skip
import pgmax.fg.graph as graph

# fmt: on
# Custom Imports
import pgmax.fg.groups as groups

# %% [markdown]
# # Setup Variables

# %%
# %% tags=[]
# Define some global constants
im_size = (30, 30)
prng_key = jax.random.PRNGKey(42)
@@ -100,7 +99,7 @@
# %% [markdown]
# # Add all Factors to graph via constructing FactorGroups

# %%
# %% tags=[]
def binary_connected_variables(
num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
):
@@ -182,7 +181,7 @@ def custom_flatten_ordering(Mdown, Mup):
# %% [markdown]
# # Run Belief Propagation and Retrieve MAP Estimate

# %%
# %% tags=[]
# Run BP
bp_start_time = timer()
final_msgs = fg.run_bp(
23 changes: 10 additions & 13 deletions examples/sanity_check_example.py
@@ -14,24 +14,21 @@

# %%
# %matplotlib inline
# fmt: off
import os
from timeit import default_timer as timer
from typing import Any, Dict, List, Tuple

# Standard Package Imports
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng
from scipy import sparse
from scipy.ndimage import gaussian_filter

import pgmax.fg.graph as graph

# Custom Imports
import pgmax.fg.groups as groups # isort:skip

# Standard Package Imports
import matplotlib.pyplot as plt # isort:skip
import numpy as np # isort:skip
from numpy.random import default_rng # isort:skip
from scipy import sparse # isort:skip
from scipy.ndimage import gaussian_filter # isort:skip
from typing import Any, Dict, Tuple, List # isort:skip
from timeit import default_timer as timer # isort:skip

# fmt: on
import pgmax.fg.groups as groups

# %% [markdown]
# ## Setting up Image
4 changes: 2 additions & 2 deletions pgmax/bp/bp_utils.py
@@ -18,8 +18,8 @@ def segment_max_opt(
Args:
data: Array of shape (a_len,) where a_len is an arbitrary integer.
segments_lengths: Array of shape (num_segments,), where num_segments <= a_len.
segments_lengths.sum() should yield a_len.
segments_lengths: Array of shape (num_segments,), where 0 < num_segments <= a_len.
segments_lengths.sum() should yield a_len, and all elements must be > 0.
max_segment_length: The maximum value in segments_lengths.
Returns:
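
A small reference sketch of the contract the updated docstring describes; this is not PGMax's jitted implementation, only an illustration of the documented shapes, with one maximum produced per consecutive segment of data.

import numpy as np

def segment_max_reference(data: np.ndarray, segments_lengths: np.ndarray) -> np.ndarray:
    """Max of each consecutive segment of data; returns shape (num_segments,)."""
    assert segments_lengths.sum() == data.shape[0]
    assert np.all(segments_lengths > 0)
    boundaries = np.cumsum(segments_lengths)[:-1]  # split points between segments
    return np.array([segment.max() for segment in np.split(data, boundaries)])

data = np.array([1.0, 5.0, 2.0, 7.0, 3.0])
segments_lengths = np.array([2, 3])
print(segment_max_reference(data, segments_lengths))  # [5. 7.]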
12 changes: 7 additions & 5 deletions pgmax/fg/graph.py
@@ -104,10 +104,12 @@ def add_factors(
**kwargs: optional mapping of keyword arguments. If specified, and if there
is no "factor_factory" key specified as part of this mapping, then these
args are taken to specify the arguments to be used to instantiate an
EnumerationFactor. If there is a "factor_factory" key, then these args
are taken to specify the arguments to be used to construct the class
specified by the "factor_factory" argument. Note that either *args or
**kwargs must be specified.
EnumerationFactor (specify a kwarg with the key 'keys' to indicate the
indices of variables to be indexed to create the EnumerationFactor).
If there is a "factor_factory" key, then these args are taken to specify
the arguments to be used to construct the class specified by the
"factor_factory" argument. Note that either *args or **kwargs must be
specified.
"""
factor_factory = kwargs.pop("factor_factory", None)
if factor_factory is not None:
@@ -168,7 +170,7 @@ def evidence(self) -> np.ndarray:
if self.evidence_default_mode == "zeros":
evidence = np.zeros(self.num_var_states)
elif self.evidence_default_mode == "random":
evidence = np.random.gumbel(self.num_var_states)
evidence = np.random.gumbel(size=self.num_var_states)
else:
raise NotImplementedError(
f"evidence_default_mode {self.evidence_default_mode} is not yet implemented"
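
The size= fix above matters because NumPy's gumbel(loc=0.0, scale=1.0, size=None) treats a lone positional argument as loc, so the old call drew a single shifted scalar rather than one sample per variable state. A quick check:

import numpy as np

num_var_states = 4
# Positional argument is interpreted as loc: a single (shifted) scalar sample.
scalar_sample = np.random.gumbel(num_var_states)
# With the keyword, we get one Gumbel sample per variable state, as intended.
evidence = np.random.gumbel(size=num_var_states)
print(np.shape(scalar_sample), evidence.shape)  # () (4,)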
50 changes: 20 additions & 30 deletions pgmax/fg/groups.py
@@ -2,7 +2,7 @@
import typing
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Dict, Hashable, List, Mapping, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

@@ -12,7 +12,7 @@

@dataclass(frozen=True, eq=False)
class VariableGroup:
"""Base class to represent a group of variables.
"""Class to represent a group of variables.

All variables in the group are assumed to have the same size. Additionally, the
variables are indexed by a "key", and can be retrieved by direct indexing (even indexing
@@ -32,11 +32,11 @@ def __post_init__(self) -> None:

@typing.overload
def __getitem__(self, key: Hashable) -> nodes.Variable:
pass
"""This function is a typing overload and is overwritten by the implemented __getitem__"""

@typing.overload
def __getitem__(self, key: List) -> List[nodes.Variable]:
pass
"""This function is a typing overload and is overwritten by the implemented __getitem__"""

def __getitem__(self, key):
"""Given a key, retrieve the associated Variable.
@@ -133,24 +133,17 @@ class CompositeVariableGroup(VariableGroup):
]

def __post_init__(self):
if (not isinstance(self.variable_group_container, Mapping)) and (
not isinstance(self.variable_group_container, Sequence)
):
raise ValueError(
f"variable_group_container needs to be a mapping or a sequence. Got {type(self.variable_group_container)}"
)

object.__setattr__(
self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars())
)

@typing.overload
def __getitem__(self, key: Hashable) -> nodes.Variable:
pass
"""This function is a typing overload and is overwritten by the implemented __getitem__"""

@typing.overload
def __getitem__(self, key: List) -> List[nodes.Variable]:
pass
"""This function is a typing overload and is overwritten by the implemented __getitem__"""

def __getitem__(self, key):
"""Given a key, retrieve the associated Variable from the associated VariableGroup.
@@ -213,7 +206,7 @@ def get_vars_to_evidence(

Args:
evidence: A mapping or a sequence of evidences.
The type of evidence should match that of self.variable_group_container
The type of evidence should match that of self.variable_group_container.

Returns:
a dictionary mapping all possible variables to the corresponding evidence
@@ -344,7 +337,7 @@ def get_vars_to_evidence(

if evidence[key].shape != (self.variable_size,):
raise ValueError(
f"Variable {key} expect an evidence array of shape "
f"Variable {key} expects an evidence array of shape "
f"({(self.variable_size,)})."
f"Got {evidence[key].shape}."
)

@dataclass(frozen=True, eq=False)
class FactorGroup:
"""Base class to represent a group of factors.
"""Class to represent a group of factors.

Args:
variable_group: either a VariableGroup or - if the elements of more than one VariableGroup
are connected to this FactorGroup - then a CompositeVariableGroup. This holds
all the variables that are connected to this FactorGroup
connected_var_keys: A list of tuples of tuples, where each innermost tuple contains a
key variable_group. Each list within the outer list is taken to contain the keys of variables
connected_var_keys: A list of lists of tuples, where each innermost tuple contains a
key into variable_group. Each list within the outer list is taken to contain the keys of variables
neighboring a particular factor to be added.

Raises:

@dataclass(frozen=True, eq=False)
class EnumerationFactorGroup(FactorGroup):
"""Base class to represent a group of EnumerationFactors.
"""Class to represent a group of EnumerationFactors.

All factors in the group are assumed to have the same set of valid configurations and
the same potential function. Note that the log potential function is assumed to be
Attributes:
factors: a tuple of all the factors belonging to this group. These are constructed
internally by invoking the _get_connected_var_keys_for_factors method.
factor_configs_log_potentials: Can be specified by an inheriting class, or just left
unspecified (equivalent to specifying None). If specified, must have (num_val_configs,).
and contain the log of the potential value for every possible configuration.
If none, it is assumed the log potential is uniform 0 and such an array is automatically
initialized.

factor_configs_log_potentials: Optional ndarray of shape (num_val_configs,).
If specified, must contain the log of the potential value for every possible
configuration. If left unspecified, it is assumed the log potential is uniform
0 and such an array is automatically initialized.
"""

factor_configs: np.ndarray
factor_configs_log_potentials: Optional[np.ndarray] = None

@cached_property
def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
"""Returns a tuple of all the factors contained within this FactorGroup."""
if getattr(self, "factor_configs_log_potentials", None) is None:
if self.factor_configs_log_potentials is None:
factor_configs_log_potentials = np.zeros(
self.factor_configs.shape[0], dtype=float
)
else:
factor_configs_log_potentials = getattr(
self, "factor_configs_log_potentials"
)
factor_configs_log_potentials = self.factor_configs_log_potentials

return tuple(
[
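
A minimal standalone sketch of the pattern this chunk adopts (the field names mirror the diff, but the class itself is illustrative): declaring factor_configs_log_potentials as an optional dataclass field means the attribute always exists, so the earlier getattr probing is no longer needed.

from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass(frozen=True, eq=False)
class ExampleFactorGroup:
    factor_configs: np.ndarray
    factor_configs_log_potentials: Optional[np.ndarray] = None

    def resolved_log_potentials(self) -> np.ndarray:
        # Uniform log potential of 0 when nothing was supplied.
        if self.factor_configs_log_potentials is None:
            return np.zeros(self.factor_configs.shape[0], dtype=float)
        return self.factor_configs_log_potentials

group = ExampleFactorGroup(factor_configs=np.array([[0, 0], [1, 1]]))
print(group.resolved_log_potentials())  # [0. 0.]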

@dataclass(frozen=True, eq=False)
class PairwiseFactorGroup(FactorGroup):
"""Base class to represent a group of EnumerationFactors where each factor connects to
"""Class to represent a group of EnumerationFactors where each factor connects to
two different variables.

All factors in the group are assumed to be such that all possible configuration of the two
Expand Down
Loading