Fix 'TypeError' of rnnt_loss_pruned function. #924

Merged
merged 7 commits into from Feb 27, 2022
Changes from all commits
135 changes: 120 additions & 15 deletions k2/python/k2/mutual_information.py
@@ -18,18 +18,120 @@
import torch
import _k2
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union
from typing import Tuple, Optional, Sequence, Union, List


class MutualInformationRecursionFunction(torch.autograd.Function):
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired.
"""

@staticmethod
def forward(
ctx,
px: torch.Tensor,
py: torch.Tensor,
[Review thread on this hunk]
Collaborator: Oh, this function lacks documentation about its arguments.
Contributor Author: Done, feel free to modify it if necessary.
Collaborator: Thanks! Looks great to me.
Collaborator: Could you fix the CI errors about the style issue?
Contributor Author: Done.

pxy_grads: List[Optional[torch.Tensor]],
boundary: Optional[torch.Tensor] = None,
return_grad: bool = False,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor:
"""
Computing mutual information between two sequences of real vectors.
Args:
px:
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]``, where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
following log odds ratio (ignoring the ``b`` index on the right
to make the notation more compact)::

px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]

This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in
practice be of the form::

log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))

where ``N`` is the number of terms that the sum over ``t'``
included, which might include some or all of the other sequences as
well as this one.

Note:
We don't require ``px`` and ``py`` to be contiguous, but the
code assumes, for optimization purposes, that the ``T`` axis
has stride 1.

py:
A torch.Tensor of the same dtype as ``px``, with shape
``[B][S+1][T]``, representing::

py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]

This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
are contiguous.

pxy_grads:
A list into which the gradients of ``px`` and ``py`` are
written when ``return_grad == True``; it is left unchanged
when ``return_grad == False``.

See `this PR <https://github.com/k2-fsa/k2/pull/924>`_ for more
information about why we added this parameter.

Note:
The length of the list must be 2; the first element holds
the gradient of ``px`` and the second one the gradient of
``py``.

boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``, with
``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the
``x`` and ``y`` sequences respectively, and can be used if not
all sequences are of the same length.

return_grad:
Whether to return the gradients of ``px`` and ``py``. These
gradients stand for the occupation probabilities and are the
output of the backward pass with a ``fake gradient``; the
``fake gradient`` is the same as the gradient you'd get if you
did ``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful for implementing the pruned version of the
RNN-T loss.

Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of
the mutual information between the b'th pair of sequences. This is
defined by the following recursion on ``p[b,s,t]`` (where ``p``
is of shape ``[B,S+1,T+1]``), representing a mutual information
between sub-sequences of lengths ``s`` and ``t``::

p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)

where we handle edge cases by treating quantities with negative
indexes as **-infinity**. The extension to cases where the
boundaries are specified should be obvious; it just works on
shorter sequences with offsets into ``px`` and ``py``.
"""
(B, S, T1) = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
@@ -57,21 +159,24 @@ def forward(
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _k2.mutual_information_backward(
px, py, boundary, p, ans_grad
)
px, py, boundary, p, ans_grad)
ctx.save_for_backward(px_grad, py_grad)
return ans, px_grad, py_grad
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
pxy_grads[1] = py_grad

return ans

@staticmethod
def backward(
ctx, ans_grad: Tensor, dummy_px_grad: Tensor, dummy_py_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
ctx, ans_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
(px_grad, py_grad) = ctx.saved_tensors
(B,) = ans_grad.shape
ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1)
px_grad *= ans_grad
py_grad *= ans_grad
return (px_grad, py_grad, None, None)
return (px_grad, py_grad, None, None, None)
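
As a reading aid for the recursion described in the docstring above, here is a
minimal, unoptimized PyTorch sketch of the same computation (not part of this
PR; the real implementation lives in the C++/CUDA code behind ``_k2``). It
ignores the optional ``boundary`` argument and all gradient handling, and the
function name is made up for illustration.

import torch


def mutual_information_recursion_ref(px: torch.Tensor,
                                     py: torch.Tensor) -> torch.Tensor:
    # Naive O(B*S*T) reference for the docstring recursion; clarity only.
    # px: [B, S, T+1], py: [B, S+1, T]; returns p[:, S, T] of shape [B].
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = float("-inf")
    p = torch.full((B, S + 1, T + 1), neg_inf,
                   dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            from_x = (p[:, s - 1, t] + px[:, s - 1, t] if s > 0
                      else p.new_full((B,), neg_inf))
            from_y = (p[:, s, t - 1] + py[:, s, t - 1] if t > 0
                      else p.new_full((B,), neg_inf))
            p[:, s, t] = torch.logaddexp(from_x, from_y)
    return p[:, S, T]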


def mutual_information_recursion(
@@ -179,10 +284,10 @@ def mutual_information_recursion(
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()

scores, px_grad, py_grad = MutualInformationRecursionFunction.apply(
px, py, boundary, return_grad
)
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores
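
The change above is the core of the fix: instead of returning
``(ans, px_grad, py_grad)`` from ``forward()``, the gradients are written into
a caller-supplied list and only the score Tensor is returned. The toy Function
below is a self-contained sketch of that pattern; it is illustrative only, not
k2 code, and all names in it are made up.

import torch
from typing import List, Optional


class SquareWithGrad(torch.autograd.Function):
    # Toy example of the same pattern: extra tensors the caller wants
    # (here dy/dx) are written into a mutable list argument instead of
    # being returned from forward().

    @staticmethod
    def forward(ctx, x: torch.Tensor,
                out_grads: List[Optional[torch.Tensor]]) -> torch.Tensor:
        grad = 2 * x                 # what autograd.grad(y.sum(), x) would give
        ctx.save_for_backward(grad)
        assert len(out_grads) == 1
        out_grads[0] = grad          # hand the tensor back through the list
        return x * x                 # forward() returns a single Tensor

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor):
        (grad,) = ctx.saved_tensors
        return ans_grad * grad, None  # None for the non-Tensor list argument


x = torch.randn(3, requires_grad=True)
out_grads: List[Optional[torch.Tensor]] = [None]
y = SquareWithGrad.apply(x, out_grads)
(dy_dx,) = out_grads                 # available without calling backward()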


@@ -281,9 +386,9 @@ def joint_mutual_information_recursion(
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

(px_grad, py_grad) = _k2.mutual_information_backward(
px_tot, py_tot, boundary, p, ans_grad
)
(px_grad,
py_grad) = _k2.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)

px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
2 changes: 1 addition & 1 deletion k2/python/k2/rnnt_loss.py
@@ -1019,7 +1019,7 @@ def rnnt_loss_smoothed(
boundary: Optional[Tensor] = None,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Tensor:
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
[Review comment on this line]
@pkufool (Collaborator), Feb 24, 2022: I think the current return type is
right; rnnt_loss_smoothed may return grads if return_grad equals True.
Edit: Sorry, I got it wrong.

"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition.

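
For context on the widened return annotation of ``rnnt_loss_smoothed``: with
the default ``return_grad=False`` the function returns a plain loss Tensor,
while ``return_grad=True`` also returns the ``(px_grad, py_grad)`` pair. The
call below is only an illustrative sketch; the tensor shapes and all argument
names other than ``return_grad`` are assumptions about the k2 API, not taken
from this diff.

import torch
import k2

B, S, T, C = 2, 5, 8, 10               # batch, symbols, frames, vocab (assumed)
lm = torch.randn(B, S + 1, C)          # assumed language-model term
am = torch.randn(B, T, C)              # assumed acoustic-model term
symbols = torch.randint(1, C, (B, S))  # assumed symbol ids
termination_symbol = 0

# Default: a plain Tensor comes back.
loss = k2.rnnt_loss_smoothed(lm, am, symbols, termination_symbol)

# With return_grad=True the Union return type applies:
loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol, return_grad=True)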