Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typos #66

Merged
merged 3 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions escnn/nn/modules/conv/r3convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self,
Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
:math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
field types.
Then :class:`~escnn.nn.R2Conv` guarantees an equivariant mapping
Then :class:`~escnn.nn.R3Conv` guarantees an equivariant mapping

.. math::
\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^3
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self,
convolution module.

During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights
before calling :func:`torch.nn.functional.conv2d`.
before calling :func:`torch.nn.functional.conv3d`.
When :meth:`~torch.nn.Module.eval()` is called, the filter is built with the current trained weights and stored
for future reuse such that no overhead of expanding the kernel remains.

Expand Down Expand Up @@ -145,10 +145,10 @@ def __init__(self,

~.weights (torch.Tensor): the learnable parameters which are used to expand the kernel
~.filter (torch.Tensor): the convolutional kernel obtained by expanding the parameters
in :attr:`~escnn.nn.R2Conv.weights`
in :attr:`~escnn.nn.R3Conv.weights`
~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
the parameters in :attr:`~escnn.nn.R2Conv.bias`
the parameters in :attr:`~escnn.nn.R3Conv.bias`

"""

Expand Down Expand Up @@ -340,7 +340,7 @@ def shrink(t: GeometricTensor, s) -> GeometricTensor:

def export(self):
r"""
Export this module to a normal PyTorch :class:`torch.nn.Conv2d` module and set to "eval" mode.
Export this module to a normal PyTorch :class:`torch.nn.Conv3d` module and set to "eval" mode.

"""

Expand Down
2 changes: 1 addition & 1 deletion escnn/nn/modules/pooling/pointwise_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def export(self):
if self.d == 2:
return torch.nn.MaxPool2d(self.kernel_size, self.stride, self.padding, self.dilation).eval()
elif self.d == 3:
return torch.nn.MaxPool2d(self.kernel_size, self.stride, self.padding, self.dilation).eval()
return torch.nn.MaxPool3d(self.kernel_size, self.stride, self.padding, self.dilation).eval()
else:
raise NotImplementedError

Expand Down
16 changes: 8 additions & 8 deletions test/group/test_disentangle_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,49 +294,49 @@ def test_restrict_rr_cyclic_odd_cyclic_odd(self):

def test_restrict_o2_flips(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (0., 1)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_o2_dihedral_even(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (0., 6)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_o2_dihedral_odd(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (0., 3)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_o2_so2(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (None, -1)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_o2_cyclic_even(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (None, 4)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_o2_cyclic_odd(self):
dg = O2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = (None, 3)
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_so2_cyclic_even(self):
dg = SO2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = 8
self.check_disentangle(dg.restrict_representation(sg_id, repr))

def test_restrict_so2_cyclic_odd(self):
dg = SO2(10)
repr = directsum(dg.irreps)
repr = directsum(dg.irreps())
sg_id = 7
self.check_disentangle(dg.restrict_representation(sg_id, repr))

Expand Down