From 64361adef294de7dfced4229604c585f59b89aff Mon Sep 17 00:00:00 2001 From: drawfish Date: Tue, 22 Feb 2022 14:42:13 +0800 Subject: [PATCH 1/6] Fix 'TypeError' of rnnt_loss_simple function. Fix 'TypeError' exception when calling rnnt_loss_simple(..., return_grad=False) at validation steps. --- k2/python/k2/mutual_information.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 6a61d8f1f..30062cd9e 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -53,7 +53,7 @@ def forward( ans = _k2.mutual_information_forward(px, py, boundary, p) - px_grad, py_grad = None, None + px_grad, py_grad = torch.Tensor(), torch.Tensor() if return_grad or px.requires_grad or py.requires_grad: ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _k2.mutual_information_backward( From a7cc32c891a9abe64406076c7012c99d8c34ee68 Mon Sep 17 00:00:00 2001 From: gzchenduisheng Date: Wed, 23 Feb 2022 17:22:05 +0800 Subject: [PATCH 2/6] Fix 'MutualInformationRecursionFunction.forward()' return type check error for pytorch < 1.10.x --- k2/python/k2/mutual_information.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 30062cd9e..9f41e0102 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -27,6 +27,7 @@ def forward( ctx, px: torch.Tensor, py: torch.Tensor, + pxy_grads: list, boundary: Optional[torch.Tensor] = None, return_grad: bool = False, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -53,25 +54,29 @@ def forward( ans = _k2.mutual_information_forward(px, py, boundary, p) - px_grad, py_grad = torch.Tensor(), torch.Tensor() + px_grad, py_grad = None, None if return_grad or px.requires_grad or py.requires_grad: ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _k2.mutual_information_backward( px, py, boundary, p, ans_grad ) ctx.save_for_backward(px_grad, py_grad) - return ans, px_grad, py_grad + assert len(pxy_grads) == 2 + pxy_grads[0] = px_grad + pxy_grads[1] = py_grad + + return ans @staticmethod def backward( - ctx, ans_grad: Tensor, dummy_px_grad: Tensor, dummy_py_grad: Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, None, None]: + ctx, ans_grad: Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: (px_grad, py_grad) = ctx.saved_tensors (B,) = ans_grad.shape ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1) px_grad *= ans_grad py_grad *= ans_grad - return (px_grad, py_grad, None, None) + return (px_grad, py_grad, None, None, None) def mutual_information_recursion( @@ -179,10 +184,11 @@ def mutual_information_recursion( # The following assertions are for efficiency assert px.is_contiguous() assert py.is_contiguous() - - scores, px_grad, py_grad = MutualInformationRecursionFunction.apply( - px, py, boundary, return_grad + pxy_grads = [None, None] + scores = MutualInformationRecursionFunction.apply( + px, py, pxy_grads, boundary, return_grad ) + px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores From 2eeabe6c366e68cbf95a6f8f03d2aad1fdde25d3 Mon Sep 17 00:00:00 2001 From: gzchenduisheng Date: Wed, 23 Feb 2022 17:41:32 +0800 Subject: [PATCH 3/6] Modify return type. 
--- k2/python/k2/mutual_information.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 9f41e0102..76247c4fc 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -18,7 +18,7 @@ import torch import _k2 from torch import Tensor -from typing import Tuple, Optional, Sequence, Union +from typing import Tuple, Optional, Sequence, Union, List class MutualInformationRecursionFunction(torch.autograd.Function): @@ -27,10 +27,10 @@ def forward( ctx, px: torch.Tensor, py: torch.Tensor, - pxy_grads: list, + pxy_grads: List[Optional[torch.Tensor]], boundary: Optional[torch.Tensor] = None, return_grad: bool = False, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor: (B, S, T1) = px.shape T = T1 - 1 assert py.shape == (B, S + 1, T) From b208a4ff8e1324d65870535e3aaced8b0818c09b Mon Sep 17 00:00:00 2001 From: gzchenduisheng Date: Wed, 23 Feb 2022 18:38:17 +0800 Subject: [PATCH 4/6] Add documents about class MutualInformationRecursionFunction. --- k2/python/k2/mutual_information.py | 95 ++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 76247c4fc..0498ecfba 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -22,6 +22,12 @@ class MutualInformationRecursionFunction(torch.autograd.Function): + """A recursion that is useful in computing mutual information between two + sequences of real vectors, but may be useful more generally in + sequence-to-sequence tasks where monotonic alignment between pairs of + sequences is desired. + """ + @staticmethod def forward( ctx, @@ -31,6 +37,95 @@ def forward( boundary: Optional[torch.Tensor] = None, return_grad: bool = False, ) -> torch.Tensor: + """ + Computing mutual information between two sequences of real vectors. + Args: + px: + A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``, + where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence + (including representations of ``EOS`` symbols but not ``BOS`` symbols), + and ``S`` is the length of the ``y`` sequence (including representations + of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information + application, ``px[b][s][t]`` would represent the following log odds + ratio; ignoring the b index on the right to make the notation more + compact:: + + px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ] + + This expression also implicitly includes the log-probability of + choosing to generate an ``x`` value as opposed to a ``y`` value. In + practice it might be computed as ``a + b``, where ``a`` is the log + probability of choosing to extend the sequence of length ``(s,t)`` + with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice + be of the form:: + + log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t')) + + where ``N`` is the number of terms that the sum over ``t'`` included, + which might include some or all of the other sequences as well as this + one. + + Note: + we don't require ``px`` and py to be contiguous, but the + code assumes for optimization purposes that the ``T`` axis has + stride 1. 
+ + py: + A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``, + representing:: + + py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ] + + This function does not treat ``x`` and ``y`` differently; the only + difference is that for optimization purposes we assume the last axis + (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are + contiguous. + + pxy_grads: + A List to store the return grads of ``px`` and ``py`` if return_grad == True. + Remain unchanged if return_grad == False. + + See `this PR ` for more information about + why we add this parameter. + + Note: + the length of the list must be 2, where the first element represents the grads + of ``px`` and the second one represents the grads of ``py``. + + boundary: + If supplied, a torch.LongTensor of shape ``[B][4]``, where each + row contains ``[s_begin, t_begin, s_end, t_end]``, + with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T`` + (this implies that empty sequences are allowed). + If not supplied, the values ``[0, 0, S, T]`` will be assumed. + These are the beginning and one-past-the-last positions in the ``x`` and + ``y`` sequences respectively, and can be used if not all sequences are + of the same length. + + return_grad: + Whether to return grads of ``px`` and ``py``, this grad standing for the + occupation probability is the output of the backward with a + ``fake gradient`` the ``fake gradient`` is the same as the gradient + you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``. + This is useful to implement the pruned version of rnnt loss. + + Returns: + Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual + information between the b'th pair of sequences. This is defined by + the following recursion on ``p[b,s,t]`` (where ``p`` is of shape + ``[B,S+1,T+1]``), representing a mutual information between sub-sequences + of lengths ``s`` and ``t``:: + + p[b,0,0] = 0.0 + p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], + p[b,s,t-1] + py[b,s,t-1]) + (if s > 0 or t > 0) + + where we handle edge cases by treating quantities with negative indexes + as **-infinity**. The extension to cases where the boundaries are + specified should be obvious; it just works on shorter sequences with + offsets into ``px`` and ``py``. + """ (B, S, T1) = px.shape T = T1 - 1 assert py.shape == (B, S + 1, T) From 0697c64d2cc69cae6be2da6172038c79f5468855 Mon Sep 17 00:00:00 2001 From: gzchenduisheng Date: Wed, 23 Feb 2022 19:30:27 +0800 Subject: [PATCH 5/6] Formated code style. --- k2/python/k2/mutual_information.py | 102 +++++++++++++++-------------- 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 0498ecfba..402e5bd4a 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -27,7 +27,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function): sequence-to-sequence tasks where monotonic alignment between pairs of sequences is desired. """ - + @staticmethod def forward( ctx, @@ -41,13 +41,15 @@ def forward( Computing mutual information between two sequences of real vectors. Args: px: - A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``, - where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence - (including representations of ``EOS`` symbols but not ``BOS`` symbols), - and ``S`` is the length of the ``y`` sequence (including representations - of ``EOS`` symbols but not ``BOS`` symbols). 
In the mutual information - application, ``px[b][s][t]`` would represent the following log odds - ratio; ignoring the b index on the right to make the notation more + A torch.Tensor of some floating point type, with shape + ``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the + length of the ``x`` sequence (including representations of + ``EOS`` symbols but not ``BOS`` symbols), and ``S`` is the + length of the ``y`` sequence (including representations of + ``EOS`` symbols but not ``BOS`` symbols). In the mutual + information application, ``px[b][s][t]`` would represent the + following log odds ratio; ignoring the b index on the right + to make the notation more compact:: px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ] @@ -56,14 +58,14 @@ def forward( choosing to generate an ``x`` value as opposed to a ``y`` value. In practice it might be computed as ``a + b``, where ``a`` is the log probability of choosing to extend the sequence of length ``(s,t)`` - with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice - be of the form:: + with an ``x`` as opposed to a ``y`` value; and ``b`` might in + practice be of the form:: log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t')) - where ``N`` is the number of terms that the sum over ``t'`` included, - which might include some or all of the other sequences as well as this - one. + where ``N`` is the number of terms that the sum over ``t'`` + included, which might include some or all of the other sequences as + well as this one. Note: we don't require ``px`` and py to be contiguous, but the @@ -71,60 +73,64 @@ def forward( stride 1. py: - A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``, - representing:: + A torch.Tensor of the same dtype as ``px``, with shape + ``[B][S+1][T]``, representing:: py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ] This function does not treat ``x`` and ``y`` differently; the only difference is that for optimization purposes we assume the last axis - (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are - contiguous. - + (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` + are contiguous. + pxy_grads: - A List to store the return grads of ``px`` and ``py`` if return_grad == True. + A List to store the return grads of ``px`` and ``py`` + if return_grad == True. Remain unchanged if return_grad == False. - - See `this PR ` for more information about - why we add this parameter. - + + See `this PR ` for more + information about why we add this parameter. + Note: - the length of the list must be 2, where the first element represents the grads - of ``px`` and the second one represents the grads of ``py``. - + the length of the list must be 2, where the first element + represents the grads of ``px`` and the second one represents + the grads of ``py``. + boundary: If supplied, a torch.LongTensor of shape ``[B][4]``, where each row contains ``[s_begin, t_begin, s_end, t_end]``, with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T`` (this implies that empty sequences are allowed). If not supplied, the values ``[0, 0, S, T]`` will be assumed. - These are the beginning and one-past-the-last positions in the ``x`` and - ``y`` sequences respectively, and can be used if not all sequences are + These are the beginning and one-past-the-last positions in the + ``x`` and ``y`` sequences respectively, and can be used if not + all sequences are of the same length. 
- + return_grad: - Whether to return grads of ``px`` and ``py``, this grad standing for the - occupation probability is the output of the backward with a + Whether to return grads of ``px`` and ``py``, this grad standing + for the occupation probability is the output of the backward with a ``fake gradient`` the ``fake gradient`` is the same as the gradient - you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``. + you'd get if you did + ``torch.autograd.grad((scores.sum()), [px, py])``. This is useful to implement the pruned version of rnnt loss. Returns: - Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual - information between the b'th pair of sequences. This is defined by - the following recursion on ``p[b,s,t]`` (where ``p`` is of shape - ``[B,S+1,T+1]``), representing a mutual information between sub-sequences - of lengths ``s`` and ``t``:: + Returns a torch.Tensor of shape ``[B]``, containing the log of + the mutual information between the b'th pair of sequences. This is + defined by the following recursion on ``p[b,s,t]`` (where ``p`` + is of shape ``[B,S+1,T+1]``), representing a mutual information + between sub-sequences of lengths ``s`` and ``t``:: p[b,0,0] = 0.0 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) (if s > 0 or t > 0) - where we handle edge cases by treating quantities with negative indexes - as **-infinity**. The extension to cases where the boundaries are - specified should be obvious; it just works on shorter sequences with - offsets into ``px`` and ``py``. + where we handle edge cases by treating quantities with negative + indexes as **-infinity**. The extension to cases where the + boundaries are specified should be obvious; it just works on + shorter sequences with offsets into ``px`` and ``py``. """ (B, S, T1) = px.shape T = T1 - 1 @@ -153,8 +159,7 @@ def forward( if return_grad or px.requires_grad or py.requires_grad: ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _k2.mutual_information_backward( - px, py, boundary, p, ans_grad - ) + px, py, boundary, p, ans_grad) ctx.save_for_backward(px_grad, py_grad) assert len(pxy_grads) == 2 pxy_grads[0] = px_grad @@ -280,9 +285,8 @@ def mutual_information_recursion( assert px.is_contiguous() assert py.is_contiguous() pxy_grads = [None, None] - scores = MutualInformationRecursionFunction.apply( - px, py, pxy_grads, boundary, return_grad - ) + scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads, + boundary, return_grad) px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores @@ -382,9 +386,9 @@ def joint_mutual_information_recursion( # actual derivative w.r.t. the total probs. ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype) - (px_grad, py_grad) = _k2.mutual_information_backward( - px_tot, py_tot, boundary, p, ans_grad - ) + (px_grad, + py_grad) = _k2.mutual_information_backward(px_tot, py_tot, boundary, p, + ans_grad) px_grad = px_grad.reshape(1, B, -1) py_grad = py_grad.reshape(1, B, -1) From acdbb2aeb2bb5edb83fec7e2eb4d4b2c35fd9021 Mon Sep 17 00:00:00 2001 From: drawfish Date: Thu, 24 Feb 2022 12:18:25 +0800 Subject: [PATCH 6/6] Fix rnnt_loss_smoothed return type. 
--- k2/python/k2/rnnt_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 6805ae2a4..6aad71cb5 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -1019,7 +1019,7 @@ def rnnt_loss_smoothed( boundary: Optional[Tensor] = None, reduction: Optional[str] = "mean", return_grad: bool = False, -) -> Tensor: +) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: """A simple case of the RNN-T loss, where the 'joiner' network is just addition.
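
The pattern that patches 2-5 converge on (forward() returning a single Tensor, with the optional px_grad/py_grad handed back to the caller through the mutable pxy_grads list) is what sidesteps the forward() return-type check that raises on PyTorch < 1.10.x. The toy autograd Function below is a minimal, self-contained sketch of that workaround; SquareWithGrad and every other name in it are hypothetical illustrations, not part of k2 or of these patches::

    import torch
    from typing import List, Optional


    class SquareWithGrad(torch.autograd.Function):
        """Toy example: ans = x * x, optionally handing d(ans.sum())/dx back
        to the caller through a list, the way pxy_grads is used above."""

        @staticmethod
        def forward(
            ctx,
            x: torch.Tensor,
            x_grads: List[Optional[torch.Tensor]],
            return_grad: bool = False,
        ) -> torch.Tensor:
            ans = x * x
            x_grad = None
            if return_grad or x.requires_grad:
                # Gradient of ans.sum() w.r.t. x, i.e. backprop with a
                # "fake gradient" of ones, mirroring ans_grad = torch.ones(B).
                x_grad = 2 * x
                ctx.save_for_backward(x_grad)
            assert len(x_grads) == 1
            x_grads[0] = x_grad  # extra output travels through the list ...
            return ans           # ... so forward() returns only a Tensor.

        @staticmethod
        def backward(ctx, ans_grad: torch.Tensor):
            (x_grad,) = ctx.saved_tensors
            # One entry per forward() argument: x, x_grads, return_grad.
            return x_grad * ans_grad, None, None


    x = torch.randn(4, requires_grad=True)
    x_grads = [None]
    ans = SquareWithGrad.apply(x, x_grads, True)
    ans.sum().backward()
    print(x.grad)      # 2 * x, via the normal autograd path
    print(x_grads[0])  # 2 * x, handed back through the list

Because the list is mutated in place rather than returned, the wrapper (here mutual_information_recursion) can still expose the (scores, (px_grad, py_grad)) interface when return_grad=True, while validation-time calls such as rnnt_loss_simple(..., return_grad=False) avoid the TypeError that patch 1 first worked around.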