From 9ac1e78e6ae06f7c6e4e9efe0a119d5238f06bfe Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 8 Sep 2021 17:32:34 +0800
Subject: [PATCH 01/64] Update doc URL. (#821)

---
 README.md | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index b4c12c59f..c7768d163 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,11 @@
-
+
+
+
-[![Documentation Status](https://readthedocs.org/projects/k2/badge/?version=latest)](https://k2.readthedocs.io/en/latest/?badge=latest)
+
+
+[![Documentation Status](https://github.com/k2-fsa/k2/actions/workflows/build-doc.yml/badge.svg)](https://k2-fsa.github.io/k2/)
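
A minimal usage sketch of the indexing and slicing semantics that the next
patch introduces (assuming a k2 build that includes it): with 2 axes,
integer indexing returns a 1-D torch.Tensor; with more axes it returns a
RaggedTensor with one fewer axis; slicing (step 1 only) keeps all axes.

>>> import torch
>>> import k2.ragged as k2r
>>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]')
>>> a[0]          # 2 axes: a 1-D torch.Tensor is returned
tensor([1, 3], dtype=torch.int32)
>>> a[0:2]        # slicing keeps the ragged structure
[ [ 1 3 ] [ 9 ] ]
>>> b = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]')
>>> b[0]          # 3 axes: a RaggedTensor with 2 axes is returned
[ [ 1 3 ] [ ] [ 9 ] ]
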
From bbe0dedc67cf82021ec8277c5e863d2f07ecce49 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 14 Sep 2021 12:18:13 +0800 Subject: [PATCH 02/64] Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch --- .github/workflows/nightly-cpu.yml | 3 ++ k2/python/csrc/torch/v2/any.cu | 45 ++++++++++++++--- k2/python/csrc/torch/v2/doc/any.h | 70 ++++++++++++++++---------- k2/python/csrc/torch/v2/ragged_any.cu | 4 +- k2/python/csrc/torch/v2/ragged_any.h | 2 +- k2/python/k2/autograd_utils.py | 2 +- k2/python/k2/fsa.py | 2 +- k2/python/k2/fsa_algo.py | 4 +- k2/python/k2/utils.py | 3 +- k2/python/tests/index_test.py | 7 +-- k2/python/tests/ragged_ops_test.py | 38 +++++++------- k2/python/tests/ragged_shape_test.py | 2 +- k2/python/tests/ragged_tensor_test.py | 22 ++++---- k2/python/tests/ragged_test.py | 4 +- k2/python/tests/remove_epsilon_test.py | 2 +- 15 files changed, 133 insertions(+), 77 deletions(-) diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu.yml index 0b09beb50..6d9e1c025 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu.yml @@ -17,6 +17,9 @@ name: nightly-cpu on: + push: + branches: + - nightly-cpu schedule: # minute (0-59) # hour (0-23) diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index b21c1989d..472ea7cea 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -70,15 +70,48 @@ void PybindRaggedAny(py::module &m) { any.def( "__getitem__", - [](RaggedAny &self, int32_t i) -> RaggedAny { - return self.Index(/*axis*/ 0, i); + [](RaggedAny &self, int32_t i) -> py::object { + if (self.any.NumAxes() > 2) { + RaggedAny ragged = self.Index(/*axis*/ 0, i); + return py::cast(ragged); + } else { + K2_CHECK_EQ(self.any.NumAxes(), 2); + Array1 row_split = self.any.RowSplits(1).To(GetCpuContext()); + const int32_t *row_split_data = row_split.Data(); + int32_t begin = row_split_data[i], + end = row_split_data[i + 1]; + Dtype t = self.any.GetDtype(); + FOR_REAL_AND_INT32_TYPES(t, T, { + Array1 array = + self.any.Specialize().values.Arange(begin, end); + torch::Tensor tensor = ToTorch(array); + return py::cast(tensor); + }); + } + // Unreachable code + return py::none(); }, py::arg("i"), kRaggedAnyGetItemDoc); + any.def( + "__getitem__", + [](RaggedAny &self, const py::slice &slice) -> RaggedAny { + py::ssize_t start = 0, stop = 0, step = 0, slicelength = 0; + if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + int32_t istart = static_cast(start); + int32_t istop = static_cast(stop); + int32_t istep = static_cast(step); + K2_CHECK_EQ(istep, 1) << "Only support slicing with step 1, given : " + << istep; + + return self.Arange(/*axis*/ 0, istart, istop); + }, py::arg("key"), kRaggedAnyGetItemSliceDoc); + any.def("index", - static_cast( + static_cast( &RaggedAny::Index), - py::arg("indexes"), py::arg("remove_axis") = true, + py::arg("indexes"), kRaggedAnyRaggedIndexDoc); any.def("index", @@ -325,8 +358,8 @@ void PybindRaggedAny(py::module &m) { // Return the underlying memory of this tensor. // No data is copied. Memory is shared. 
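  // (Renamed from "data" to "values" by this patch; Python reads it as
  //  ragged.values and gets back a 1-D torch.Tensor that shares this
  //  ragged tensor's memory.)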
any.def_property_readonly( - "data", [](RaggedAny &self) -> torch::Tensor { return self.Data(); }, - kRaggedAnyDataDoc); + "values", [](RaggedAny &self) -> torch::Tensor { return self.Data(); }, + kRaggedAnyValuesDoc); any.def_property_readonly( "shape", [](RaggedAny &self) -> RaggedShape { return self.any.shape; }, diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index 46d2de7a9..b69ca3093 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -350,9 +350,6 @@ Select the i-th sublist along axis 0. Caution: Support for autograd is to be implemented. -Note: - It requires that this tensor has at least 3 axes. - >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]') @@ -363,11 +360,45 @@ Select the i-th sublist along axis 0. >>> a[1] [ [ 8 ] ] +>>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]') +>>> a +[ [ 1 3 ] [ 9 ] [ 8 ] ] +>>> a[0] +tensor([1, 3], dtype=torch.int32) +>>> a[1] +tensor([9], dtype=torch.int32) + Args: i: The i-th sublist along axis 0. Returns: - Return a new ragged tensor with one fewer axis. + Return a new ragged tensor with one fewer axis. If `num_axes == 2`, the + return value will be a 1D tensor. +)doc"; + +static constexpr const char *kRaggedAnyGetItemSliceDoc = R"doc( +Slices sublists along axis 0 with the given range. Only support slicing step +equals to 1. + +Caution: + Support for autograd is to be implemented. + +>>> import torch +>>> import k2.ragged as k2r +>>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] [[10 11]] ]') +>>> a +[ [ [ 1 3 ] [ ] [ 9 ] ] [ [ 8 ] ] [ [ 10 11 ] ] ] +>>> a[0:2] +[ [ [ 1 3 ] [ ] [ 9 ] [ [ 8 ] ] ] ] +>>> a[1:2] +[ [ [ 8 ] ] [ [ 10 11 ] ] ] + +Args: + key: + Slice containing integer constants. +Returns: + Return a new ragged tensor with the same axes as original ragged tensor, but + only contains the sublists within the range. )doc"; static constexpr const char *kRaggedAnyCloneDoc = R"doc( @@ -644,23 +675,23 @@ device(type='cuda', index=0) >>> b.device == torch.device('cuda:0') )doc"; -static constexpr const char *kRaggedAnyDataDoc = R"doc( +static constexpr const char *kRaggedAnyValuesDoc = R"doc( Return the underlying memory as a 1-D tensor. >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2], [], [5], [], [8, 9, 10]]) ->>> a.data +>>> a.values tensor([ 1, 2, 5, 8, 9, 10], dtype=torch.int32) ->>> isinstance(a.data, torch.Tensor) +>>> isinstance(a.values, torch.Tensor) True ->>> a.data[0] = -1 +>>> a.values[-2] = -1 >>> a [ [ -1 2 ] [ ] [ 5 ] [ ] [ 8 9 10 ] ] ->>> a.data[3] = -3 +>>> a.values[3] = -3 >>> a [ [ -1 2 ] [ ] [ 5 ] [ ] [ -3 9 10 ] ] ->>> a.data[2] = -2 +>>> a.values[2] = -2 >>> a [ [ -1 2 ] [ ] [ -2 ] [ ] [ -3 9 10 ] ] )doc"; @@ -1301,14 +1332,10 @@ Index a ragged tensor with a ragged tensor. >>> import k2.ragged as k2r >>> src = k2r.RaggedTensor([[10, 11], [12, 13.5]]) >>> indexes = k2r.RaggedTensor([[0, 1]]) - >>> src.index(indexes, remove_axis=True) - [ [ 10 11 12 13.5 ] ] - >>> src.index(indexes, remove_axis=False) + >>> src.index(indexes) [ [ [ 10 11 ] [ 12 13.5 ] ] ] >>> i = k2r.RaggedTensor([[0], [1], [0, 0]]) - >>> src.index(i, remove_axis=True) - [ [ 10 11 ] [ 12 13.5 ] [ 10 11 10 11 ] ] - >>> src.index(i, remove_axis=False) + >>> src.index(i) [ [ [ 10 11 ] ] [ [ 12 13.5 ] ] [ [ 10 11 ] [ 10 11 ] ] ] **Example 2**: @@ -1316,9 +1343,7 @@ Index a ragged tensor with a ragged tensor. 
>>> import k2.ragged as k2r >>> src = k2r.RaggedTensor([ [[1, 0], [], [2]], [[], [3], [0, 0, 1]], [[1, 2], [-1]]]) >>> i = k2r.RaggedTensor([[[0, 2], [1]], [[0]]]) - >>> src.index(i, remove_axis=True) - [ [ [ [ 1 0 2 ] [ 1 2 -1 ] ] [ [ 3 0 0 1 ] ] ] [ [ [ 1 0 2 ] ] ] ] - >>> src.index(i, remove_axis=False) + >>> src.index(i) [ [ [ [ [ 1 0 ] [ ] [ 2 ] ] [ [ 1 2 ] [ -1 ] ] ] [ [ [ ] [ 3 ] [ 0 0 1 ] ] ] ] [ [ [ [ 1 0 ] [ ] [ 2 ] ] ] ] ] Args: @@ -1328,13 +1353,6 @@ Index a ragged tensor with a ragged tensor. Caution: Its dtype has to be ``torch.int32``. - remove_axis: - If ``True``, then we remove the last-but-one axis, - which has the effect of appending lists, e.g. - if ``self`` is ``[[ 10 11 ] [ 12 13 ]]`` and ``indexes`` - is ``[[0 1]]`, this function will give us ``[[ 10 11 12 13 ]]``. - If ``False`` the answer will have at least 3 axes, e.g., ``[[[10 11]] [12 13]]]`` , - in this case. Returns: Return indexed tensor. )doc"; diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 87dc6846f..8366a5f45 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -560,13 +560,13 @@ torch::optional RaggedAny::Sort( return ans; } -RaggedAny RaggedAny::Index(RaggedAny &indexes, - bool remove_axis /* = true*/) /*const*/ { +RaggedAny RaggedAny::Index(RaggedAny &indexes) /*const*/ { K2_CHECK_EQ(indexes.any.GetDtype(), kInt32Dtype) << "Unsupported dtype: " << TraitsOf(indexes.any.GetDtype()).Name(); DeviceGuard guard(any.Context()); + bool remove_axis = false; Dtype t = any.GetDtype(); FOR_REAL_AND_INT32_TYPES(t, T, { return RaggedAny(k2::Index(any.Specialize(), diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index 1fbb00ca0..c3e171f59 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -226,7 +226,7 @@ struct RaggedAny { bool need_new2old_indexes = false); /// Wrapper for k2::Index - RaggedAny Index(RaggedAny &indexes, bool remove_axis = true) /*const*/; + RaggedAny Index(RaggedAny &indexes) /*const*/; /// Wrapper for k2::Index std::pair> Index( diff --git a/k2/python/k2/autograd_utils.py b/k2/python/k2/autograd_utils.py index 1ceab1545..2a328e6d6 100644 --- a/k2/python/k2/autograd_utils.py +++ b/k2/python/k2/autograd_utils.py @@ -105,7 +105,7 @@ def backward(ctx, out_fsa_scores_grad: torch.Tensor dtype=torch.float32, device=unused_in_fsa_scores.device, requires_grad=False) - _k2.index_add(arc_map.data, expanded, ans) + _k2.index_add(arc_map.values, expanded, ans) return ( None, # out_fsa diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py index 1652c0d4e..619f3a0e0 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/k2/fsa.py @@ -1394,7 +1394,7 @@ def set_scores_stochastic_(self, scores) -> None: # Note we use `to` here since `scores` and `self.scores` may not # be on the same device. 
- self.scores = ragged_scores.data.to(self.scores.device) + self.scores = ragged_scores.values.to(self.scores.device) def convert_attr_to_ragged_(self, name: str, remove_eps: bool = True) -> 'Fsa': diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 588dbc333..9efa6251a 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -466,7 +466,7 @@ def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa: ''' entering_arcs = fsa._get_entering_arcs(use_double_scores) ragged_arc, ragged_int = _k2.shortest_path(fsa.arcs, entering_arcs) - arc_map = ragged_int.data + arc_map = ragged_int.values out_fsa = k2.utils.fsa_from_unary_function_tensor(fsa, ragged_arc, arc_map) return out_fsa @@ -1016,7 +1016,7 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedTensor], if isinstance(symbols, k2.RaggedTensor): assert device is None assert symbols.num_axes == 2 - symbol_values = symbols.data + symbol_values = symbols.values else: symbol_values = torch.tensor( [it for symbol in symbols for it in symbol], diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index ae2e55b3e..d24780240 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -169,7 +169,7 @@ def convert_aux_label_to_symbol( if end == begin: return ':' - labels = aux_labels.data[begin:end] + labels = aux_labels.values[begin:end] ans = [] for label in labels.tolist(): if label == -1: @@ -538,6 +538,7 @@ def fsa_from_unary_function_ragged(src: Fsa, # We currently don't support float ragged attributes assert value.dtype == torch.int32 new_value = value.index(arc_map) + new_value = new_value.remove_axis(new_value.num_axes - 2) setattr(dest, name, new_value) for name, value in src.named_non_tensor_attr(): diff --git a/k2/python/tests/index_test.py b/k2/python/tests/index_test.py index 6f9a23a13..9ddcb0e51 100644 --- a/k2/python/tests/index_test.py +++ b/k2/python/tests/index_test.py @@ -145,6 +145,7 @@ def test(self): device=device) ragged_index = k2.RaggedTensor(index_shape, index_values) ans = src.index(ragged_index) + ans = ans.remove_axis(1) expected_row_splits = torch.tensor([0, 5, 5, 5, 9], dtype=torch.int32, device=device) @@ -153,7 +154,7 @@ def test(self): expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3, 1, 2], dtype=torch.int32, device=device) - self.assertTrue(torch.allclose(ans.data, expected_values)) + self.assertTrue(torch.allclose(ans.values, expected_values)) # index with tensor tensor_index = torch.tensor([0, 3, 2, 1, 2, 1], @@ -168,7 +169,7 @@ def test(self): expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3], dtype=torch.int32, device=device) - self.assertTrue(torch.allclose(ans.data, expected_values)) + self.assertTrue(torch.allclose(ans.values, expected_values)) class TestIndexTensorWithRaggedInt(unittest.TestCase): @@ -203,7 +204,7 @@ def test(self): expected_values = torch.tensor([1, 4, 3, 4, 6, 2, 4], dtype=torch.int32, device=device) - self.assertTrue(torch.allclose(ans.data, expected_values)) + self.assertTrue(torch.allclose(ans.values, expected_values)) if __name__ == '__main__': diff --git a/k2/python/tests/ragged_ops_test.py b/k2/python/tests/ragged_ops_test.py index 442eb0ad9..14cf9a339 100644 --- a/k2/python/tests/ragged_ops_test.py +++ b/k2/python/tests/ragged_ops_test.py @@ -195,7 +195,7 @@ def test_normalize_scores_use_log_non_zero_stride(self): for device in self.devices: for dtype in [torch.float32, torch.float64]: src = k2.RaggedTensor(s, dtype).to(device) - saved = src.data.clone().detach() + saved = src.values.clone().detach() 
saved.requires_grad_(True) src.requires_grad_(True) @@ -204,9 +204,9 @@ def test_normalize_scores_use_log_non_zero_stride(self): scale = torch.arange(ans.numel(), device=device) # the stride of grad is not 0 - (ans.data * scale).sum().backward() + (ans.values * scale).sum().backward() - expected = saved.new_zeros(*ans.data.shape) + expected = saved.new_zeros(*ans.values.shape) normalizer = saved[:3].exp().sum().log() expected[:3] = saved[:3] - normalizer @@ -219,7 +219,7 @@ def test_normalize_scores_use_log_non_zero_stride(self): normalizer = saved[6:8].exp().sum().log() expected[6:8] = saved[6:8] - normalizer - self.assertTrue(torch.allclose(expected, ans.data)) + self.assertTrue(torch.allclose(expected, ans.values)) (expected * scale).sum().backward() self.assertTrue(torch.allclose(saved.grad, src.grad)) @@ -231,16 +231,16 @@ def test_normalize_scores_use_log_zero_stride(self): for device in self.devices: for dtype in [torch.float32, torch.float64]: src = k2.RaggedTensor(s, dtype).to(device) - saved = src.data.clone().detach() + saved = src.values.clone().detach() saved.requires_grad_(True) src.requires_grad_(True) ans = src.normalize(use_log=True) # the stride of grad is 0 - ans.data.sum().backward() + ans.values.sum().backward() - expected = saved.new_zeros(*ans.data.shape) + expected = saved.new_zeros(*ans.values.shape) normalizer = saved[:3].exp().sum().log() expected[:3] = saved[:3] - normalizer @@ -253,7 +253,7 @@ def test_normalize_scores_use_log_zero_stride(self): normalizer = saved[6:8].exp().sum().log() expected[6:8] = saved[6:8] - normalizer - self.assertTrue(torch.allclose(expected, ans.data)) + self.assertTrue(torch.allclose(expected, ans.values)) expected.sum().backward() self.assertTrue(torch.allclose(saved.grad, src.grad)) @@ -281,7 +281,7 @@ def test_normalize_scores_use_log_from_shape(self): normalized_scores = ragged_scores.normalize(use_log=True) assert normalized_scores.requires_grad is True - fsa.scores = normalized_scores.data + fsa.scores = normalized_scores.values assert fsa.scores.requires_grad is True # arcs leaving state 0 @@ -306,16 +306,16 @@ def test_normalize_scores(self): [ [1 3 5] [2 -1] [] [3] [5 2] ] ''' src = k2.RaggedTensor(s, dtype=dtype).to(device) - saved = src.data + saved = src.values ans = src.normalize(use_log=False) - expected = saved.new_zeros(*ans.data.shape) + expected = saved.new_zeros(*ans.values.shape) expected[:3] = saved[:3] / saved[:3].sum() expected[3:5] = saved[3:5] / saved[3:5].sum() expected[5] = 1 expected[6:8] = saved[6:8] / saved[6:8].sum() - assert torch.allclose(ans.data, expected) + assert torch.allclose(ans.values, expected) def test_sum_per_sublist(self): s = ''' @@ -591,11 +591,11 @@ def test_argmax_per_sublist_two_axes_random(self): res.append([random.random() * -100]) # sublist with huge elements res.append([random.random() * -100 for x in range(5000)]) - ragged_cpu = k2.ragged.create_ragged2(res) - indexes_cpu = k2.ragged.argmax_per_sublist(ragged_cpu) + ragged_cpu = k2.RaggedTensor(res) + indexes_cpu = ragged_cpu.argmax() for device in self.devices: ragged = ragged_cpu.to(device) - indexes = k2.ragged.argmax_per_sublist(ragged).to("cpu") + indexes = ragged.argmax().to("cpu") assert torch.all(torch.eq(indexes, indexes_cpu)) def test_max_per_sublist_two_axes(self): @@ -636,8 +636,8 @@ def test_sort_sublist_ascending(self): assert src == expected_src assert torch.all(torch.eq(new2old, expected_new2old)) - expected_sorted = k2.index_select(src_clone.data, new2old) - sorted = src.data + expected_sorted = 
k2.index_select(src_clone.values, new2old) + sorted = src.values assert torch.all(torch.eq(expected_sorted, sorted)) def test_sort_sublist_descending(self): @@ -655,8 +655,8 @@ def test_sort_sublist_descending(self): assert src == sorted_src assert torch.all(torch.eq(new2old, expected_new2old)) - expected_sorted = k2.index_select(src_clone.data, new2old) - sorted = src.data + expected_sorted = k2.index_select(src_clone.values, new2old) + sorted = src.values assert torch.all(torch.eq(expected_sorted, sorted)) diff --git a/k2/python/tests/ragged_shape_test.py b/k2/python/tests/ragged_shape_test.py index 04eabcc44..1e6d41e4b 100644 --- a/k2/python/tests/ragged_shape_test.py +++ b/k2/python/tests/ragged_shape_test.py @@ -118,7 +118,7 @@ def test_compose_ragged_shape(self): abshape2 = ashape.compose(bshape) assert abshape == prod.shape assert abshape2 == prod.shape - prod2 = k2.RaggedTensor(abshape2, b.data) + prod2 = k2.RaggedTensor(abshape2, b.values) assert prod == prod2 diff --git a/k2/python/tests/ragged_tensor_test.py b/k2/python/tests/ragged_tensor_test.py index be71d476a..ec59d4831 100644 --- a/k2/python/tests/ragged_tensor_test.py +++ b/k2/python/tests/ragged_tensor_test.py @@ -65,22 +65,22 @@ def test_create_ragged_tensor_from_string(self): assert b.num_axes == 3 assert b.dim0 == 2 - def test_property_data(self): + def test_property_values(self): a = k2r.RaggedTensor([[1], [2], [], [3, 4]]) - assert torch.all(torch.eq(a.data, torch.tensor([1, 2, 3, 4]))) + assert torch.all(torch.eq(a.values, torch.tensor([1, 2, 3, 4]))) with self.assertRaises(AttributeError): - # the `data` attribute is const. You cannot rebind it - a.data = 10 + # the `values` attribute is const. You cannot rebind it + a.values = 10 - # However, we can change the elements of a.data - a.data[0] = 10 - a.data[-1] *= 2 + # However, we can change the elements of a.values + a.values[0] = 10 + a.values[-1] *= 2 expected = k2r.RaggedTensor([[10], [2], [], [3, 8]]) assert a == expected - a.data[0] = 1 + a.values[0] = 1 assert a != expected def test_clone(self): @@ -88,7 +88,7 @@ def test_clone(self): b = a.clone() assert a == b - a.data[0] = 10 + a.values[0] = 10 assert a != b @@ -181,7 +181,7 @@ def test_getstate_2axes(self): dtype=torch.int32, device=device) b_1 = "row_ids1" - b_2 = a.data + b_2 = a.values assert torch.all(torch.eq(b[0], b_0)) assert b[1] == b_1 @@ -203,7 +203,7 @@ def test_getstate_3axes(self): dtype=torch.int32, device=device) # noqa b_3 = "row_ids2" - b_4 = a.data + b_4 = a.values assert torch.all(torch.eq(b[0], b_0)) assert b[1] == b_1 diff --git a/k2/python/tests/ragged_test.py b/k2/python/tests/ragged_test.py index 704ffb936..2ea2d17da 100644 --- a/k2/python/tests/ragged_test.py +++ b/k2/python/tests/ragged_test.py @@ -48,8 +48,8 @@ def test_ragged_int_from_str(self): ragged_int = k2.RaggedTensor(s).to(device) print(ragged_int) assert torch.all( - torch.eq(ragged_int.data, torch.tensor([1, 2, 3], - device=device))) + torch.eq(ragged_int.values, torch.tensor([1, 2, 3], + device=device))) assert ragged_int.dim0 == 2 assert torch.all( torch.eq(ragged_int.shape.row_splits(1), diff --git a/k2/python/tests/remove_epsilon_test.py b/k2/python/tests/remove_epsilon_test.py index 9b2ca07eb..14d4690d5 100644 --- a/k2/python/tests/remove_epsilon_test.py +++ b/k2/python/tests/remove_epsilon_test.py @@ -352,7 +352,7 @@ def test1(self): self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring)) print("After removing epsilons: ", dest) - assert torch.where(dest.foo.data == filler)[0].numel() == 0 + assert 
torch.where(dest.foo.values == filler)[0].numel() == 0 if __name__ == '__main__': From 2c280701b82e96d99ecefa951485e5c00fddf43e Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 14 Sep 2021 16:35:39 +0800 Subject: [PATCH 03/64] Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo --- k2/csrc/array.h | 4 ++ k2/csrc/fsa_algo.h | 14 ++++- k2/csrc/intersect_dense.cu | 93 +++++++++++++++++++++----------- k2/csrc/intersect_test.cu | 21 ++++---- k2/python/csrc/torch/fsa_algo.cu | 10 ++-- k2/python/k2/autograd.py | 21 ++++++-- 6 files changed, 115 insertions(+), 48 deletions(-) diff --git a/k2/csrc/array.h b/k2/csrc/array.h index dcef42c18..d67f10d5a 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -139,6 +139,8 @@ class Array1 { Dtype dtype = DtypeOf::dtype) : dim_(dim), dtype_(dtype), byte_offset_(byte_offset), region_(region) { K2_CHECK(K2_TYPE_IS_ANY(T) || dtype == DtypeOf::dtype); + K2_CHECK_GE(dim_, 0) << "Array dim MUST be greater than or equal to 0, " + << "given :" << dim; } Array1(ContextPtr ctx, int32_t size, T elem, @@ -496,6 +498,8 @@ ToType(int64_t, Long) void Init(ContextPtr context, int32_t size, Dtype dtype) { K2_CHECK(K2_TYPE_IS_ANY(T) || dtype == DtypeOf::dtype); + K2_CHECK_GE(size, 0) << "Array size MUST be greater than or equal to 0, " + << "given :" << size; dtype_ = dtype; region_ = NewRegion(context, static_cast(size) * ElementSize()); dim_ = size; diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 974381bc8..d9ebf3e58 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -221,10 +221,20 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, `IsMonotonic(*a_to_b_map)` (this requirement is related to the length-sorting requirement of b_fsas). - @param[in] output_beam Beam with which we prune the output (analogous + @param[in] output_beam Beam with which we prune the output (analogous to lattice-beam in Kaldi), e.g. 8. We discard arcs in the output that are not on a path that's within `output_beam` of the best path of the composed output. + @param[in] max_states The max number of states with which we prune the + output, mainly to avoid out-of-memory and numerical overflow. + If number of states exceeds max_states, we'll decrease + output_beam to prune out more states, util the number of + states is less than max_states. + @param[in] max_arcs The max number of arcs with which we prune the + output, mainly to avoid out-of-memory and numerical overflow. + If number of arcs exceeds max_arcs, we'll decrease + output_beam to prune out more states, util the number of + arcs is less than max_arcs. @param[out] out Output vector of composed, pruned FSAs, with same Dim0() as a_fsas. Elements of it may be empty if the composed results was empty. 
All states in the output will be @@ -239,7 +249,7 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, */ void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas, const Array1 *a_to_b_map, - float output_beam, + float output_beam, int32_t max_states, int32_t max_arcs, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b); diff --git a/k2/csrc/intersect_dense.cu b/k2/csrc/intersect_dense.cu index 3b9913712..da78dbac6 100644 --- a/k2/csrc/intersect_dense.cu +++ b/k2/csrc/intersect_dense.cu @@ -106,9 +106,11 @@ class MultiGraphDenseIntersect { */ MultiGraphDenseIntersect(FsaVec &a_fsas, DenseFsaVec &b_fsas, const Array1 &a_to_b_map, - float output_beam) + float output_beam, int32_t max_states, + int32_t max_arcs) : a_fsas_(a_fsas), b_fsas_(b_fsas), a_to_b_map_(a_to_b_map), - output_beam_(output_beam) { + output_beam_(output_beam), max_states_(max_states), + max_arcs_(max_arcs) { NVTX_RANGE(K2_FUNC); c_ = GetContext(a_fsas.shape, b_fsas.shape, a_to_b_map); @@ -214,13 +216,10 @@ class MultiGraphDenseIntersect { int32_t product = ((size_t)(T_ + 1) * (size_t)num_states); Renumbering renumber_states; int32_t T = T_; - const int32_t *a_fsas_row_ids1_data = a_fsas_.RowIds(1).Data(); + const int32_t *a_fsas_row_ids1_data = a_fsas_.RowIds(1).Data(), + *a_fsas_row_splits2_data = a_fsas_.RowSplits(2).Data(); FsaInfo *fsa_info_data = fsa_info_.Data(); - // 15 million is max_states... this is to avoid out-of-memory conditions - // Eventually we can make this an option. - int32_t max_states = 15000000; - while (1) { // This code is in a loop is in case we get too many states and have to // retry. The limit `max_states` is to reduce the likelihood of @@ -230,6 +229,8 @@ class MultiGraphDenseIntersect { score_cutoffs = GetScoreCutoffs(); score_cutoffs_data = score_cutoffs.Data(); float **state_scores_data = state_scores_.Data(); + Array1 state_arcs(c_, product); + int64_t *state_arcs_data = state_arcs.Data(); // We'll do exclusive-sum on the following array, after setting its // elements to 1 if the corresponding state was not pruned away. The @@ -255,7 +256,10 @@ class MultiGraphDenseIntersect { int32_t idx_within_fsa = i - (T + 1) * fsa_info.state_offset, t = idx_within_fsa / fsa_info.num_states, - state_idx1 = idx_within_fsa % fsa_info.num_states; + state_idx1 = idx_within_fsa % fsa_info.num_states, + state_idx01 = fsa_info.state_offset + state_idx1, + num_arcs = a_fsas_row_splits2_data[state_idx01 + 1] - + a_fsas_row_splits2_data[state_idx01]; // In the state_scores arrays, there are 2 copies of each FSA's // states, for backward and forward. 
int32_t backward_state_idx = @@ -273,19 +277,25 @@ class MultiGraphDenseIntersect { if (forward_score + backward_score > cutoff) keep = 1; } keep_state_data[i] = keep; + state_arcs_data[i] = keep * num_arcs; }); int32_t tot_states = renumber_states.New2Old().Dim(); - if (tot_states > max_states) { - float cur_beam = output_beam_, - next_beam = cur_beam * sqrt(max_states * 1.0 / tot_states); - if (next_beam < cur_beam * 0.25) - next_beam = cur_beam * 0.25; - if (next_beam > cur_beam * 0.75) - next_beam = cur_beam * 0.75; + if (tot_states > max_states_) { + float cur_beam = output_beam_; + DecreaseBeam(max_states_, tot_states); K2_LOG(INFO) << "Num-states " << tot_states << " exceeds limit " - << max_states << ", decreasing beam from " << cur_beam - << " to " << next_beam; - output_beam_ = next_beam; + << max_states_ << ", decreasing beam from " << cur_beam + << " to " << output_beam_; + continue; + } + + int64_t tot_arcs = Sum(state_arcs); + if (tot_arcs > max_arcs_) { + float cur_beam = output_beam_; + DecreaseBeam(max_arcs_, tot_arcs); + K2_LOG(INFO) << "Num-arcs " << tot_arcs << " exceeds limit " + << max_arcs_ << ", decreasing beam from " << cur_beam + << " to " << output_beam_; } else { break; } @@ -328,7 +338,6 @@ class MultiGraphDenseIntersect { // the answer. Array1 ans_state_idx01(c_, ans_tot_num_states); int32_t *ans_state_idx01_data = ans_state_idx01.Data(); - const int32_t *a_fsas_row_splits2_data = a_fsas_.RowSplits(2).Data(); // set ans_row_ids2_data, which contains an ans_idx01 that combines // FSA-index and time-index. @@ -562,9 +571,10 @@ class MultiGraphDenseIntersect { // subsample the output shape, removing arcs that weren't kept // TODO: make this more efficient, avoid constructing and_row_ids3. RaggedShape ans_shape = RaggedShape4( - &ans_row_splits1, &ans_row_ids1, -1, - &ans_row_splits2, &ans_row_ids2, -1, - &ans_row_splits3_subsampled, &ans_row_ids3_subsampled, -1); + &ans_row_splits1, &ans_row_ids1, ans_row_ids1.Dim(), + &ans_row_splits2, &ans_row_ids2, ans_row_ids2.Dim(), + &ans_row_splits3_subsampled, &ans_row_ids3_subsampled, + ans_row_ids3_subsampled.Dim()); // .. remove the 't' axis return Ragged(RemoveAxis(ans_shape, 1), arcs); @@ -805,6 +815,21 @@ class MultiGraphDenseIntersect { &step.state_scores); } + /* + Decrease output beam according to num_states or num_arcs, `limit` would be + the max_states or max_arcs (mainly to avoid out-of-memory conditions), + `total` would be current total states or total arcs. + */ + void DecreaseBeam(int64_t limit, int64_t total) { + float cur_beam = output_beam_, + next_beam = cur_beam * sqrt(limit * 1.0 / total); + if (next_beam < cur_beam * 0.25) + next_beam = cur_beam * 0.25; + if (next_beam > cur_beam * 0.75) + next_beam = cur_beam * 0.75; + output_beam_ = next_beam; + } + /* Called after DoStep() is done for all time steps, returns the total scores minus output_beam_. (This is what it does in the absence of roundoff error @@ -825,8 +850,10 @@ class MultiGraphDenseIntersect { float **state_scores_data = state_scores_.Data(); FsaInfo *fsa_info_data = fsa_info_.Data(); - Array1 score_cutoffs(c_, num_fsas_); - float *score_cutoffs_data = score_cutoffs.Data(); + Array1 score_cutoffs(c_, num_fsas_), + score_diff(c_, num_fsas_); + float *score_cutoffs_data = score_cutoffs.Data(), + *score_diff_data = score_diff.Data(); float output_beam = output_beam_; const float minus_inf = -std::numeric_limits::infinity(); K2_EVAL( @@ -855,12 +882,13 @@ class MultiGraphDenseIntersect { tot_score_min = (tot_score_start < tot_score_end ? 
tot_score_start : tot_score_end); - K2_CHECK(tot_score_end == tot_score_start || - fabs(tot_score_end - tot_score_start) < 1.0) - << tot_score_end << " vs " - << tot_score_start; // TODO: remove this score_cutoffs_data[fsa_idx0] = tot_score_min - output_beam; + score_diff_data[fsa_idx0] = fabs(tot_score_end - tot_score_start); }); + float max_diff = MaxValue(score_diff); + if (max_diff >= 1.0) + K2_LOG(WARNING) << "The difference between forward score and backward" + << " score exceeds 1.0, the value is : " << max_diff; return score_cutoffs; } @@ -978,11 +1006,14 @@ class MultiGraphDenseIntersect { float output_beam_; int32_t T_; // == b_fsas_.MaxSize(1) + + int32_t max_states_; // number of max states to avoid out-of-memory + int32_t max_arcs_; // number of max arcs to avoid out-of-memory }; void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas, const Array1 *a_to_b_map, - float output_beam, + float output_beam, int32_t max_states, int32_t max_arcs, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b) { NVTX_RANGE("IntersectDense"); @@ -1001,7 +1032,9 @@ void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas, MultiGraphDenseIntersect intersector(a_fsas, b_fsas, *a_to_b_map, - output_beam); + output_beam, + max_states, + max_arcs); intersector.Intersect(); FsaVec ret = intersector.FormatOutput(arc_map_a, arc_map_b); diff --git a/k2/csrc/intersect_test.cu b/k2/csrc/intersect_test.cu index 6236b824d..25ac69a03 100644 --- a/k2/csrc/intersect_test.cu +++ b/k2/csrc/intersect_test.cu @@ -84,12 +84,13 @@ TEST(Intersect, Simple) { fsa = FsaToFsaVec(fsa); float output_beam = 1000; + int32_t max_states = 15000000, + max_arcs = 1 << 30; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; - IntersectDense(fsa, dfsavec, nullptr, - output_beam, &out_fsas, &arc_map_a, - &arc_map_b); + IntersectDense(fsa, dfsavec, nullptr, output_beam, max_states, max_arcs, + &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -232,9 +233,10 @@ TEST(Intersect, RandomSingle) { FsaVec out_fsas; float output_beam = 1000.0; - IntersectDense(fsa, dfsavec, nullptr, - output_beam, &out_fsas, &arc_map_a, - &arc_map_b); + int32_t max_states = 15000000, + max_arcs = 1 << 30; + IntersectDense(fsa, dfsavec, nullptr, output_beam, max_states, max_arcs, + &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b; FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec); @@ -305,9 +307,10 @@ TEST(Intersect, RandomFsaVec) { FsaVec out_fsas; float output_beam = 100000.0; // TODO(Dan) ... 
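  // A sketch of the retry logic the new max_states/max_arcs limits trigger
  // (as added in intersect_dense.cu): when a limit is exceeded,
  // DecreaseBeam() sets next_beam = output_beam * sqrt(limit / total),
  // clamped to [0.25, 0.75] * output_beam. E.g. with output_beam = 8 and
  // four times more states than max_states, 8 * sqrt(1/4) = 4, which lies
  // inside the clamp range, so pruning is retried with a beam of 4.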
- IntersectDense(fsavec, dfsavec, nullptr, - output_beam, &out_fsas, &arc_map_a, - &arc_map_b); + int32_t max_states = 15000000, + max_arcs = 1 << 30; + IntersectDense(fsavec, dfsavec, nullptr, output_beam, max_states, max_arcs, + &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 5907f394c..7e8d55847 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -243,7 +243,8 @@ static void PybindIntersectDense(py::module &m) { m.def( "intersect_dense", [](FsaVec &a_fsas, DenseFsaVec &b_fsas, - torch::optional a_to_b_map, float output_beam) + torch::optional a_to_b_map, float output_beam, + int32_t max_states, int32_t max_arcs) -> std::tuple { DeviceGuard guard(a_fsas.Context()); Array1 arc_map_a; @@ -260,12 +261,13 @@ static void PybindIntersectDense(py::module &m) { } else { a_to_b_map_array = Arange(a_fsa_vec.Context(), 0, a_fsa_vec.Dim0()); } - IntersectDense(a_fsa_vec, b_fsas, &a_to_b_map_array, output_beam, &out, - &arc_map_a, &arc_map_b); + IntersectDense(a_fsa_vec, b_fsas, &a_to_b_map_array, output_beam, + max_states, max_arcs, &out, &arc_map_a, &arc_map_b); return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b)); }, py::arg("a_fsas"), py::arg("b_fsas"), py::arg("a_to_b_map"), - py::arg("output_beam")); + py::arg("output_beam"), py::arg("max_states") = 15000000, + py::arg("max_arcs") = 1073741824 /* 2^30 */); } static void PybindConnect(py::module &m) { diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 7509d5aa4..96e6a9f72 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -507,6 +507,8 @@ def forward(ctx, b_fsas: DenseFsaVec, out_fsa: List[Fsa], output_beam: float, + max_states: int, + max_arcs: int, unused_scores_a: torch.Tensor, unused_scores_b: torch.Tensor, a_to_b_map: Optional[torch.Tensor] = None, @@ -560,7 +562,9 @@ def forward(ctx, a_fsas=a_fsas.arcs, b_fsas=b_fsas.dense_fsa_vec, a_to_b_map=a_to_b_map, - output_beam=output_beam) + output_beam=output_beam, + max_states=max_states, + max_arcs=max_arcs) out_fsa[0] = Fsa(ragged_arc) @@ -631,6 +635,8 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ None, # b_fsas None, # out_fsa None, # output_beam + None, # max_states + None, # max_arcs grad_a, # unused_scores_a grad_b, # unused_scores_b None, # a_to_b_map @@ -766,6 +772,8 @@ def intersect_dense_pruned(a_fsas: Fsa, def intersect_dense(a_fsas: Fsa, b_fsas: DenseFsaVec, output_beam: float, + max_states: int = 15000000, + max_arcs: int = 1073741824, a_to_b_map: Optional[torch.Tensor] = None, seqframe_idx_name: Optional[str] = None, frame_idx_name: Optional[str] = None) -> Fsa: @@ -783,8 +791,14 @@ def intersect_dense(a_fsas: Fsa, b_fsas: Input FSAs that correspond to neural network output. output_beam: - Beam to prune output, similar to lattice-beam in Kaldi. Relative - to best path of output. + Beam to prune output, similar to lattice-beam in Kaldi. Relative + to best path of output. + max_states: + The max number of states to prune the output, mainly to avoid + out-of-memory and numerical overflow, default 15,000,000. + max_arcs: + The max number of arcs to prune the output, mainly to avoid + out-of-memory and numerical overflow, default 1073741824(2^30). a_to_b_map: Maps from FSA-index in a to FSA-index in b to use for it. 
If None, then we expect the number of FSAs in a_fsas to equal @@ -825,6 +839,7 @@ def intersect_dense(a_fsas: Fsa, # the following return value is discarded since it is already contained # in `out_fsa[0].scores` _IntersectDenseFunction.apply(a_fsas, b_fsas, out_fsa, output_beam, + max_states, max_arcs, a_fsas.scores, b_fsas.scores, a_to_b_map, seqframe_idx_name, frame_idx_name) return out_fsa[0] From 210175c08ba8ca4b0e172a59a4f6fb4c677b176c Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 14 Sep 2021 16:51:29 +0800 Subject: [PATCH 04/64] Release v1.8 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 486f2f26d..75af4772f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.7") +set(K2_VERSION "1.8") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) From 33a212c55575c553b4e597d7b2db9e51b8cc8086 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 15 Sep 2021 20:59:01 +0800 Subject: [PATCH 05/64] Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. --- .github/workflows/build-doc.yml | 1 + k2/csrc/ragged_ops_inl.h | 2 + k2/python/csrc/torch/torch_util.cu | 4 +- k2/python/csrc/torch/torch_util.h | 8 +- k2/python/csrc/torch/v2/any.cu | 92 ++- k2/python/csrc/torch/v2/doc/any.h | 808 +++++++++++++++++++++----- k2/python/csrc/torch/v2/ragged_any.cu | 132 ++++- k2/python/csrc/torch/v2/ragged_any.h | 29 +- k2/python/tests/fsa_test.py | 7 +- k2/python/tests/ragged_ops_test.py | 65 ++- k2/python/tests/ragged_tensor_test.py | 96 ++- k2/python/tests/random_paths_test.py | 33 +- 12 files changed, 1026 insertions(+), 251 deletions(-) diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index 5f61d0683..0a8d65607 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -22,6 +22,7 @@ on: push: branches: - master + - doc env: # debug is faster in terms of compilation time diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index f2c4b7377..4d44761cd 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -412,6 +412,8 @@ std::istream &operator>>(std::istream &is, Ragged &r) { : (row_splits[cur_level + 1].size() - 1)); is.get(); // consume character 'c' if (cur_level == 0) break; + } else if (c == ',') { + is.get(); // consume character 'c' } else { InputFixer t; is >> t; diff --git a/k2/python/csrc/torch/torch_util.cu b/k2/python/csrc/torch/torch_util.cu index afe88cb79..1860180be 100644 --- a/k2/python/csrc/torch/torch_util.cu +++ b/k2/python/csrc/torch/torch_util.cu @@ -105,7 +105,7 @@ torch::Tensor ToTorch(Array1 &array) { } template <> -Array1 FromTorch(torch::Tensor &tensor) { +Array1 FromTorch(torch::Tensor tensor) { K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. 
Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -124,7 +124,7 @@ Array1 FromTorch(torch::Tensor &tensor) { return ans; } -Tensor FromTorch(torch::Tensor &tensor, TensorTag) { +Tensor FromTorch(torch::Tensor tensor, TensorTag) { Dtype dtype = ScalarTypeToDtype(tensor.scalar_type()); torch::IntArrayRef sizes = tensor.sizes(); torch::IntArrayRef strides = tensor.strides(); diff --git a/k2/python/csrc/torch/torch_util.h b/k2/python/csrc/torch/torch_util.h index 110e46a9a..7808bdc85 100644 --- a/k2/python/csrc/torch/torch_util.h +++ b/k2/python/csrc/torch/torch_util.h @@ -113,7 +113,7 @@ torch::Tensor ToTorch(Array1 &array) { input tensor. */ template -Array1 FromTorch(torch::Tensor &tensor) { +Array1 FromTorch(torch::Tensor tensor) { K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -158,12 +158,12 @@ torch::Tensor ToTorch(Array1 &array); the input tensor. */ template <> -Array1 FromTorch(torch::Tensor &tensor); +Array1 FromTorch(torch::Tensor tensor); struct Array2Tag {}; template -Array2 FromTorch(torch::Tensor &tensor, Array2Tag) { +Array2 FromTorch(torch::Tensor tensor, Array2Tag) { K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -199,7 +199,7 @@ torch::Tensor ToTorch(Array2 &array) { struct TensorTag {}; -Tensor FromTorch(torch::Tensor &tensor, TensorTag); +Tensor FromTorch(torch::Tensor tensor, TensorTag); torch::Tensor ToTorch(Tensor &tensor); /* Transfer an object to a specific device. diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 472ea7cea..8a7a11194 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -40,24 +40,31 @@ void PybindRaggedAny(py::module &m) { // k2.ragged.Tensor methods //-------------------------------------------------- - any.def( - py::init([](py::list data, - py::object dtype = py::none()) -> std::unique_ptr { - return std::make_unique(data, dtype); - }), - py::arg("data"), py::arg("dtype") = py::none(), kRaggedAnyInitDataDoc); + any.def(py::init(), py::arg("data"), + py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitDataDeviceDoc); - any.def( - py::init([](const std::string &s, - py::object dtype = py::none()) -> std::unique_ptr { - return std::make_unique(s, dtype); - }), - py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc); + any.def(py::init(), + py::arg("data"), py::arg("dtype") = py::none(), + py::arg("device") = "cpu", kRaggedAnyInitDataDeviceDoc); + + any.def(py::init(), + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitStrDeviceDoc); + + any.def(py::init(), + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitStrDeviceDoc); - // TODO(fangjun): add documentation for it any.def(py::init(), py::arg("shape"), py::arg("value"), kRaggedInitFromShapeAndTensorDoc); + any.def(py::init(), py::arg("tensor"), + kRaggedAnyInitTensorDoc); + any.def( "__str__", [](const RaggedAny &self) -> std::string { return self.ToString(); }, @@ -78,8 +85,7 @@ void PybindRaggedAny(py::module &m) { K2_CHECK_EQ(self.any.NumAxes(), 2); Array1 row_split = 
self.any.RowSplits(1).To(GetCpuContext()); const int32_t *row_split_data = row_split.Data(); - int32_t begin = row_split_data[i], - end = row_split_data[i + 1]; + int32_t begin = row_split_data[i], end = row_split_data[i + 1]; Dtype t = self.any.GetDtype(); FOR_REAL_AND_INT32_TYPES(t, T, { Array1 array = @@ -100,19 +106,18 @@ void PybindRaggedAny(py::module &m) { if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength)) throw py::error_already_set(); int32_t istart = static_cast(start); - int32_t istop = static_cast(stop); - int32_t istep = static_cast(step); - K2_CHECK_EQ(istep, 1) << "Only support slicing with step 1, given : " - << istep; + int32_t istop = static_cast(stop); + int32_t istep = static_cast(step); + K2_CHECK_EQ(istep, 1) + << "Only support slicing with step 1, given : " << istep; return self.Arange(/*axis*/ 0, istart, istop); - }, py::arg("key"), kRaggedAnyGetItemSliceDoc); + }, + py::arg("key"), kRaggedAnyGetItemSliceDoc); any.def("index", - static_cast( - &RaggedAny::Index), - py::arg("indexes"), - kRaggedAnyRaggedIndexDoc); + static_cast(&RaggedAny::Index), + py::arg("indexes"), kRaggedAnyRaggedIndexDoc); any.def("index", static_cast> ( @@ -408,21 +413,48 @@ void PybindRaggedAny(py::module &m) { // _k2.ragged.functions //-------------------------------------------------- - // TODO: change the function name from "create_tensor" to "tensor" m.def( "create_ragged_tensor", - [](py::list data, py::object dtype = py::none()) -> RaggedAny { - return RaggedAny(data, dtype); + [](py::list data, py::object dtype = py::none(), + torch::Device device = torch::kCPU) -> RaggedAny { + return RaggedAny(data, dtype, device); }, py::arg("data"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kCreateRaggedTensorDataDoc); + + m.def( + "create_ragged_tensor", + [](py::list data, py::object dtype = py::none(), + const std::string &device = "cpu") -> RaggedAny { + return RaggedAny(data, dtype, device); + }, + py::arg("data"), py::arg("dtype") = py::none(), py::arg("device") = "cpu", kCreateRaggedTensorDataDoc); m.def( "create_ragged_tensor", - [](const std::string &s, py::object dtype = py::none()) -> RaggedAny { - return RaggedAny(s, dtype); + [](const std::string &s, py::object dtype = py::none(), + torch::Device device = torch::kCPU) -> RaggedAny { + return RaggedAny(s, dtype, device); }, - py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc); + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kCreateRaggedTensorStrDoc); + + m.def( + "create_ragged_tensor", + [](const std::string &s, py::object dtype = py::none(), + const std::string &device = "cpu") -> RaggedAny { + return RaggedAny(s, dtype, device); + }, + py::arg("s"), py::arg("dtype") = py::none(), py::arg("device") = "cpu", + kCreateRaggedTensorStrDoc); + + m.def( + "create_ragged_tensor", + [](torch::Tensor tensor) -> RaggedAny { return RaggedAny(tensor); }, + py::arg("tensor"), kCreateRaggedTensorTensorDoc); } } // namespace k2 diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index b69ca3093..5c995fc9d 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -32,15 +32,20 @@ Create a ragged tensor with arbitrary number of axes. Hint: The returned tensor is on CPU. 
+>>> import torch >>> import k2.ragged as k2r >>> a = k2r.create_ragged_tensor([ [1, 2], [5], [], [9] ]) >>> a -[ [ 1 2 ] [ 5 ] [ ] [ 9 ] ] +RaggedTensor([[1, 2], + [5], + [], + [9]], dtype=torch.int32) >>> a.dtype torch.int32 >>> b = k2r.create_ragged_tensor([ [1, 3.0], [] ]) >>> b -[ [ 1 3 ] [ ] ] +RaggedTensor([[1, 3], + []], dtype=torch.float32) >>> b.dtype torch.float32 >>> c = k2r.create_ragged_tensor([ [1] ], dtype=torch.float64) @@ -48,18 +53,30 @@ torch.float32 torch.float64 >>> d = k2r.create_ragged_tensor([ [[1], [2, 3]], [[4], []] ]) >>> d -[ [ [ 1 ] [ 2 3 ] ] [ [ 4 ] [ ] ] ] +RaggedTensor([[[1], + [2, 3]], + [[4], + []]], dtype=torch.int32) >>> d.num_axes 3 >>> e = k2r.create_ragged_tensor([]) >>> e -[ ] +RaggedTensor([], dtype=torch.int32) >>> e.num_axes 2 >>> e.shape.row_splits(1) tensor([0], dtype=torch.int32) >>> e.shape.row_ids(1) tensor([], dtype=torch.int32) +>>> f = k2r.create_ragged_tensor([ [1, 2], [], [3] ], device=torch.device('cuda', 0)) +>>> f +RaggedTensor([[1, 2], + [], + [3]], device='cuda:0', dtype=torch.int32) +>>> e = k2r.create_ragged_tensor([[1], []], device='cuda:1') +>>> e +RaggedTensor([[1], + []], device='cuda:1', dtype=torch.int32) Args: data: @@ -70,6 +87,12 @@ tensor([], dtype=torch.int32) automatically, which is either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. + Returns: Return a ragged tensor. )doc"; @@ -77,28 +100,28 @@ tensor([], dtype=torch.int32) static constexpr const char *kCreateRaggedTensorStrDoc = R"doc( Create a ragged tensor from its string representation. +Fields are separated by space(s) **or** comma(s). + An example string for a 2-axis ragged tensor is given below:: - [ [1] [2] ] + [ [1] [2] [3, 4], [5 6 7, 8] ] An example string for a 3-axis ragged tensor is given below:: [ [[1]] [[]] ] -Hint: - The returned tensor is on CPU. - >>> import torch >>> import k2.ragged as k2r >>> a = k2r.create_ragged_tensor('[ [1] [] [3 4] ]') >>> a -[ [ 1 ] [ ] [ 3 4 ] ] +RaggedTensor([[1], + [], + [3, 4]], dtype=torch.int32) >>> a.num_axes 2 >>> a.dtype torch.int32 >>> b = k2r.create_ragged_tensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) ->>> b = k2r.create_ragged_tensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) >>> b [ [ [ ] [ 3 ] ] [ [ 10 ] ] ] >>> b.dtype @@ -109,6 +132,10 @@ torch.float32 >>> c.dtype torch.float32 +Note: + Number of spaces or commas in ``s`` does not affect the result. + Of course, numbers have to be separated by at least one space or comma. + Args: s: A string representation of a ragged tensor. @@ -117,9 +144,108 @@ torch.float32 to infer the correct dtype from ``s``, which is assumed to be either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. +Returns: + Return a ragged tensor. +)doc"; + +static constexpr const char *kCreateRaggedTensorTensorDoc = R"doc( +Create a ragged tensor from a torch tensor. + +Note: + It turns a regular tensor into a ragged tensor. + +Caution: + The input tensor has to have more than 1 dimension. + That is ``tensor.ndim > 1``. 
+ + Also, if the input tensor is contiguous, ``self`` + will share the underlying memory with it. Otherwise, + memory of the input tensor is copied to create ``self``. + + Supported dtypes of the input tensor are: ``torch.int32``, + ``torch.float32``, and ``torch.float64``. + +**Example 1**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(6, dtype=torch.int32).reshape(2, 3) + >>> b = k2r.create_ragged_tensor(a) + >>> a + tensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b + RaggedTensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b.dtype + torch.int32 + >>> a.is_contiguous() + True + >>> a[0, 0] = 10 + >>> b + RaggedTensor([[10, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b.values[1] = -2 + >>> a + tensor([[10, -2, 2], + [ 3, 4, 5]], dtype=torch.int32) + +**Example 2**: + + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.int32).reshape(2, 12)[:, ::4] + >>> a + tensor([[ 0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a.is_contiguous() + False + >>> b = k2r.create_ragged_tensor(a) + >>> b + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> b.dtype + torch.int32 + >>> a[0, 0] = 10 + >>> b + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a + tensor([[10, 4, 8], + [12, 16, 20]], dtype=torch.int32) + +**Example 3**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + >>> a + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]) + >>> b = k2r.create_ragged_tensor(a) + >>> b + RaggedTensor([[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]], dtype=torch.float32) + +Args: + tensor: + An N-D (N > 1) tensor. Returns: Return a ragged tensor. )doc"; + static constexpr const char *kRaggedInitFromShapeAndTensorDoc = R"doc( Create a ragged tensor from a shape and a value. @@ -129,7 +255,9 @@ Create a ragged tensor from a shape and a value. >>> value = torch.tensor([10, 0, 20, 30, 40], dtype=torch.float32) >>> ragged = k2r.RaggedTensor(shape, value) >>> ragged -[ [ 10 0 ] [ ] [ 20 30 40 ] ] +RaggedTensor([[10, 0], + [], + [20, 30, 40]], dtype=torch.float32) Args: shape: @@ -138,42 +266,46 @@ Create a ragged tensor from a shape and a value. The value of the tensor. )doc"; -static constexpr const char *kRaggedAnyInitDataDoc = R"doc( +static constexpr const char *kRaggedAnyInitDataDeviceDoc = R"doc( Create a ragged tensor with arbitrary number of axes. Note: A ragged tensor has at least two axes. -Hint: - The returned tensor is on CPU. 
- **Example 1**: >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [1, 2], [5], [], [9] ]) >>> a - [ [ 1 2 ] [ 5 ] [ ] [ 9 ] ] + RaggedTensor([[1, 2], + [5], + [], + [9]], dtype=torch.int32) >>> a.dtype torch.int32 >>> b = k2r.RaggedTensor([ [1, 3.0], [] ]) >>> b - [ [ 1 3 ] [ ] ] + RaggedTensor([[1, 3], + []], dtype=torch.float32) >>> b.dtype torch.float32 >>> c = k2r.RaggedTensor([ [1] ], dtype=torch.float64) >>> c - [ [ 1 ] ] + RaggedTensor([[1]], dtype=torch.float64) >>> c.dtype torch.float64 >>> d = k2r.RaggedTensor([ [[1], [2, 3]], [[4], []] ]) >>> d - [ [ [ 1 ] [ 2 3 ] ] [ [ 4 ] [ ] ] ] + RaggedTensor([[[1], + [2, 3]], + [[4], + []]], dtype=torch.int32) >>> d.num_axes 3 >>> e = k2r.RaggedTensor([]) >>> e - [ ] + RaggedTensor([], dtype=torch.int32) >>> e.num_axes 2 >>> e.shape.row_splits(1) @@ -184,7 +316,13 @@ Create a ragged tensor with arbitrary number of axes. **Example 2**: >>> k2r.RaggedTensor([ [[1, 2]], [], [[]] ]) - [ [ [ 1 2 ] ] [ ] [ [ ] ] ] + RaggedTensor([[[1, 2]], + [], + [[]]], dtype=torch.int32) + >>> k2r.RaggedTensor([ [[1, 2]], [], [[]] ], device='cuda:0') + RaggedTensor([[[1, 2]], + [], + [[]]], device='cuda:0', dtype=torch.int32) Args: data: @@ -195,34 +333,42 @@ Create a ragged tensor with arbitrary number of axes. automatically, which is either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. )doc"; -static constexpr const char *kRaggedAnyInitStrDoc = R"doc( +static constexpr const char *kRaggedAnyInitStrDeviceDoc = R"doc( Create a ragged tensor from its string representation. +Fields are separated by space(s) **or** comma(s). + An example string for a 2-axis ragged tensor is given below:: - [ [1] [2] ] + [ [1] [2] [3, 4], [5 6 7, 8] ] An example string for a 3-axis ragged tensor is given below:: [ [[1]] [[]] ] -Hint: - The returned tensor is on CPU. - >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [1] [] [3 4] ]') >>> a -[ [ 1 ] [ ] [ 3 4 ] ] +RaggedTensor([[1], + [], + [3, 4]], dtype=torch.int32) >>> a.num_axes 2 >>> a.dtype torch.int32 >>> b = k2r.RaggedTensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) >>> b -[ [ [ ] [ 3 ] ] [ [ 10 ] ] ] +RaggedTensor([[[], + [3]], + [[10]]], dtype=torch.float32) >>> b.dtype torch.float32 >>> b.num_axes @@ -230,10 +376,13 @@ torch.float32 >>> c = k2r.RaggedTensor('[[1.]]') >>> c.dtype torch.float32 +>>> d = k2r.RaggedTensor('[[1.]]', device='cuda:0') +>>> d +RaggedTensor([[1]], device='cuda:0', dtype=torch.float32) Note: - Number of spaces in ``s`` does not affect the result. - Of course, numbers have to be separated by at least one space. + Number of spaces or commas in ``s`` does not affect the result. + Of course, numbers have to be separated by at least one space or comma. Args: s: @@ -243,6 +392,103 @@ torch.float32 to infer the correct dtype from ``s``, which is assumed to be either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. 
+)doc"; + +static constexpr const char *kRaggedAnyInitTensorDoc = R"doc( +Create a ragged tensor from a torch tensor. + +Note: + It turns a regular tensor into a ragged tensor. + +Caution: + The input tensor has to have more than 1 dimension. + That is ``tensor.ndim > 1``. + + Also, if the input tensor is contiguous, ``self`` + will share the underlying memory with it. Otherwise, + memory of the input tensor is copied to create ``self``. + + Supported dtypes of the input tensor are: ``torch.int32``, + ``torch.float32``, and ``torch.float64``. + +**Example 1**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(6, dtype=torch.int32).reshape(2, 3) + >>> b = k2r.RaggedTensor(a) + >>> a + tensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b + RaggedTensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> a.is_contiguous() + True + >>> a[0, 0] = 10 + >>> b + RaggedTensor([[10, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b.values[1] = -2 + >>> a + tensor([[10, -2, 2], + [ 3, 4, 5]], dtype=torch.int32) + +**Example 2**: + + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.int32).reshape(2, 12)[:, ::4] + >>> a + tensor([[ 0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a.is_contiguous() + False + >>> b = k2r.RaggedTensor(a) + >>> b + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a[0, 0] = 10 + >>> b + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a + tensor([[10, 4, 8], + [12, 16, 20]], dtype=torch.int32) + +**Example 3**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + >>> a + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]) + >>> b = k2r.RaggedTensor(a) + >>> b + RaggedTensor([[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]], dtype=torch.float32) + >>> b.dtype + torch.float32 + >>> c = torch.tensor([[1, 2]], device='cuda:0', dtype=torch.float32) + >>> k2r.RaggedTensor(c) + RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.float32) + +Args: + tensor: + An N-D (N > 1) tensor. )doc"; static constexpr const char *kRaggedAnyToDeviceDoc = R"doc( @@ -339,9 +585,14 @@ Return a string representation of this tensor. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [2, 3], []]) >>> a -[ [ 1 ] [ 2 3 ] [ ] ] +RaggedTensor([[1], + [2, 3], + []], dtype=torch.int32) >>> str(a) -'[ [ 1 ] [ 2 3 ] [ ] ]' +'RaggedTensor([[1],\n [2, 3],\n []], dtype=torch.int32)' +>>> b = k2r.RaggedTensor([[1, 2]], device='cuda:0') +>>> b +RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyGetItemDoc = R"doc( @@ -350,23 +601,34 @@ Select the i-th sublist along axis 0. Caution: Support for autograd is to be implemented. 
->>> import torch ->>> import k2.ragged as k2r ->>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]') ->>> a -[ [ [ 1 3 ] [ ] [ 9 ] ] [ [ 8 ] ] ] ->>> a[0] -[ [ 1 3 ] [ ] [ 9 ] ] ->>> a[1] -[ [ 8 ] ] +**Example 1**: ->>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]') ->>> a -[ [ 1 3 ] [ 9 ] [ 8 ] ] ->>> a[0] -tensor([1, 3], dtype=torch.int32) ->>> a[1] -tensor([9], dtype=torch.int32) + >>> import torch + >>> import k2.ragged as k2r + >>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]') + >>> a + RaggedTensor([[[1, 3], + [], + [9]], + [[8]]], dtype=torch.int32) + >>> a[0] + RaggedTensor([[1, 3], + [], + [9]], dtype=torch.int32) + >>> a[1] + RaggedTensor([[8]], dtype=torch.int32) + +**Example 2**: + + >>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]') + >>> a + RaggedTensor([[1, 3], + [9], + [8]], dtype=torch.int32) + >>> a[0] + tensor([1, 3], dtype=torch.int32) + >>> a[1] + tensor([9], dtype=torch.int32) Args: i: @@ -387,11 +649,18 @@ equals to 1. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] [[10 11]] ]') >>> a -[ [ [ 1 3 ] [ ] [ 9 ] ] [ [ 8 ] ] [ [ 10 11 ] ] ] +RaggedTensor([[[1, 3], + [], + [9]], + [[8]], + [[10, 11]]], dtype=torch.int32) >>> a[0:2] -[ [ [ 1 3 ] [ ] [ 9 ] [ [ 8 ] ] ] ] +RaggedTensor([[[1, 3], + [], + [9]], + [[8]]], dtype=torch.int32) >>> a[1:2] -[ [ [ 8 ] ] [ [ 10 11 ] ] ] +RaggedTensor([[[8]]], dtype=torch.int32) Args: key: @@ -410,19 +679,25 @@ Return a copy of this tensor. >>> b = a >>> c = a.clone() >>> a -[ [ 1 2 ] [ 3 ] ] ->>> b.data[0] = 10 +RaggedTensor([[1, 2], + [3]], dtype=torch.int32) +>>> b.values[0] = 10 >>> a -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) >>> c -[ [ 1 2 ] [ 3 ] ] ->>> c.data[0] = -1 +RaggedTensor([[1, 2], + [3]], dtype=torch.int32) +>>> c.values[0] = -1 >>> c -[ [ -1 2 ] [ 3 ] ] +RaggedTensor([[-1, 2], + [3]], dtype=torch.int32) >>> a -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) >>> b -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyEqDoc = R"doc( @@ -501,7 +776,10 @@ calls to ``backward()`` will accumulate (add) gradients into it. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2], [3], [5, 6], []], dtype=torch.float32) >>> a.requires_grad_(True) -[ [ 1 2 ] [ 3 ] [ 5 6 ] [ ] ] +RaggedTensor([[1, 2], + [3], + [5, 6], + []], dtype=torch.float32) >>> b = a.sum() >>> b tensor([ 3., 3., 11., 0.], grad_fn=>) @@ -529,7 +807,7 @@ this tensor's :attr:`requires_grad` attribute **in-place**. >>> a.requires_grad False >>> a.requires_grad_(True) -[ [ 1 ] ] +RaggedTensor([[1]], dtype=torch.float64) >>> a.requires_grad True @@ -554,7 +832,10 @@ Compute the sum of sublists over the last axis of this tensor. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [[1 2] [] [5]] [[10]] ]', dtype=torch.float32) >>> a.requires_grad_(True) -[ [ [ 1 2 ] [ ] [ 5 ] ] [ [ 10 ] ] ] +RaggedTensor([[[1, 2], + [], + [5]], + [[10]]], dtype=torch.float32) >>> b = a.sum() >>> c = (b * torch.arange(4)).sum() >>> c.backward() @@ -577,7 +858,7 @@ tensor(40., grad_fn=) static constexpr const char *kRaggedAnyNumelDoc = R"doc( Returns: Return number of elements in this tensor. It equals to - ``self.data.numel()``. + ``self.values.numel()``. 
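A sketch of the invariants tying ``numel`` to ``values`` and to ``tot_size``
(``tot_size`` is documented just below)::

    # numel() counts leaf elements: it matches the flattened values
    # tensor and tot_size() of the last axis
    import k2.ragged as k2r

    a = k2r.RaggedTensor('[ [[1] [] []] [[2 3]] ]')
    assert a.numel() == a.values.numel() == a.tot_size(2) == 3
    assert a.tot_size(0) == a.dim0 == 2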
>>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [], [3, 4, 5, 6]]) @@ -590,6 +871,7 @@ static constexpr const char *kRaggedAnyNumelDoc = R"doc( >>> c.numel() 5 )doc"; + static constexpr const char *kRaggedAnyTotSizeDoc = R"doc( Return the number of elements of an given axis. If axis is 0, it's equivalent to the property ``dim0``. @@ -622,10 +904,10 @@ You are not expected to call it by yourself. Returns: If this tensor has 2 axes, return a tuple containing - (self.row_splits(1), "row_ids1", self.data). + (self.row_splits(1), "row_ids1", self.values). If this tensor has 3 axes, return a tuple containing (self.row_splits(1), "row_ids1", self.row_splits(1), - "row_ids2", self.data) + "row_ids2", self.values) Note: "row_ids1" and "row_ids2" in the returned value is for @@ -687,13 +969,25 @@ tensor([ 1, 2, 5, 8, 9, 10], dtype=torch.int32) True >>> a.values[-2] = -1 >>> a -[ [ -1 2 ] [ ] [ 5 ] [ ] [ 8 9 10 ] ] +RaggedTensor([[1, 2], + [], + [5], + [], + [8, -1, 10]], dtype=torch.int32) >>> a.values[3] = -3 >>> a -[ [ -1 2 ] [ ] [ 5 ] [ ] [ -3 9 10 ] ] +RaggedTensor([[1, 2], + [], + [5], + [], + [-3, -1, 10]], dtype=torch.int32) >>> a.values[2] = -2 >>> a -[ [ -1 2 ] [ ] [ -2 ] [ ] [ -3 9 10 ] ] +RaggedTensor([[1, 2], + [], + [-2], + [], + [-3, -1, 10]], dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyShapeDoc = R"doc( @@ -766,15 +1060,32 @@ last axis it is just removed and the number of elements may be changed. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [[1], [], [0, -1]], [[], [2, 3], []], [[0]], [[]] ]) >>> a - [ [ [ 1 ] [ ] [ 0 -1 ] ] [ [ ] [ 2 3 ] [ ] ] [ [ 0 ] ] [ [ ] ] ] + RaggedTensor([[[1], + [], + [0, -1]], + [[], + [2, 3], + []], + [[0]], + [[]]], dtype=torch.int32) >>> a.num_axes 3 >>> b = a.remove_axis(0) >>> b - [ [ 1 ] [ ] [ 0 -1 ] [ ] [ 2 3 ] [ ] [ 0 ] [ ] ] + RaggedTensor([[1], + [], + [0, -1], + [], + [2, 3], + [], + [0], + []], dtype=torch.int32) >>> c = a.remove_axis(1) >>> c - [ [ 1 0 -1 ] [ 2 3 ] [ 0 ] [ ] ] + RaggedTensor([[1, 0, -1], + [2, 3], + [0], + []], dtype=torch.int32) **Example 2**: @@ -782,16 +1093,42 @@ last axis it is just removed and the number of elements may be changed. 
>>> a.num_axes 4 >>> a - [ [ [ [ 1 ] [ ] [ 2 ] ] ] [ [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] ] [ [ [ ] [ 0 ] ] ] ] + RaggedTensor([[[[1], + [], + [2]]], + [[[3, 4], + [], + [5, 6], + []]], + [[[], + [0]]]], dtype=torch.int32) >>> b = a.remove_axis(0) >>> b - [ [ [ 1 ] [ ] [ 2 ] ] [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] [ [ ] [ 0 ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[3, 4], + [], + [5, 6], + []], + [[], + [0]]], dtype=torch.int32) >>> c = a.remove_axis(1) >>> c - [ [ [ 1 ] [ ] [ 2 ] ] [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] [ [ ] [ 0 ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[3, 4], + [], + [5, 6], + []], + [[], + [0]]], dtype=torch.int32) >>> d = a.remove_axis(2) >>> d - [ [ [ 1 2 ] ] [ [ 3 4 5 6 ] ] [ [ 0 ] ] ] + RaggedTensor([[[1, 2]], + [[3, 4, 5, 6]], + [[0]]], dtype=torch.int32) Args: axis: @@ -821,27 +1158,46 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [[1], [], [2]], [[], [4, 5], []], [[], [1]], [[]] ]) >>> a - [ [ [ 1 ] [ ] [ 2 ] ] [ [ ] [ 4 5 ] [ ] ] [ [ ] [ 1 ] ] [ [ ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[], + [4, 5], + []], + [[], + [1]], + [[]]], dtype=torch.int32) >>> a.num_axes 3 >>> b = a.arange(axis=0, begin=1, end=3) >>> b - [ [ [ ] [ 4 5 ] [ ] ] [ [ ] [ 1 ] ] ] + RaggedTensor([[[], + [4, 5], + []], + [[], + [1]]], dtype=torch.int32) >>> b.num_axes 3 >>> c = a.arange(axis=0, begin=1, end=2) >>> c - [ [ [ ] [ 4 5 ] [ ] ] ] + RaggedTensor([[[], + [4, 5], + []]], dtype=torch.int32) >>> c.num_axes 3 >>> d = a.arange(axis=1, begin=0, end=4) >>> d - [ [ 1 ] [ ] [ 2 ] [ ] ] + RaggedTensor([[1], + [], + [2], + []], dtype=torch.int32) >>> d.num_axes 2 >>> e = a.arange(axis=1, begin=2, end=5) >>> e - [ [ 2 ] [ ] [ 4 5 ] ] + RaggedTensor([[2], + [], + [4, 5]], dtype=torch.int32) >>> e.num_axes 2 @@ -852,17 +1208,34 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: 4 >>> b = a.arange(axis=0, begin=0, end=2) >>> b - [ [ [ [ ] [ 1 ] [ 2 3 ] ] [ [ 5 8 ] [ ] [ 9 ] ] ] [ [ [ 10 ] [ 0 ] [ ] ] ] ] + RaggedTensor([[[[], + [1], + [2, 3]], + [[5, 8], + [], + [9]]], + [[[10], + [0], + []]]], dtype=torch.int32) >>> b.num_axes 4 >>> c = a.arange(axis=1, begin=1, end=3) >>> c - [ [ [ 5 8 ] [ ] [ 9 ] ] [ [ 10 ] [ 0 ] [ ] ] ] + RaggedTensor([[[5, 8], + [], + [9]], + [[10], + [0], + []]], dtype=torch.int32) >>> c.num_axes 3 >>> d = a.arange(axis=2, begin=0, end=5) >>> d - [ [ ] [ 1 ] [ 2 3 ] [ 5 8 ] [ ] ] + RaggedTensor([[], + [1], + [2, 3], + [5, 8], + []], dtype=torch.int32) >>> d.num_axes 2 @@ -870,15 +1243,25 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: >>> a = k2r.RaggedTensor([[0], [1], [2], [], [3]]) >>> a - [ [ 0 ] [ 1 ] [ 2 ] [ ] [ 3 ] ] + RaggedTensor([[0], + [1], + [2], + [], + [3]], dtype=torch.int32) >>> a.num_axes 2 >>> b = a.arange(axis=0, begin=1, end=4) >>> b - [ [ 1 ] [ 2 ] [ ] ] - >>> b.data[0] = -1 + RaggedTensor([[1], + [2], + []], dtype=torch.int32) + >>> b.values[0] = -1 >>> a - [ [ 0 ] [ -1 ] [ 2 ] [ ] [ 3 ] ] + RaggedTensor([[0], + [-1], + [2], + [], + [3]], dtype=torch.int32) Args: axis: @@ -896,12 +1279,23 @@ target. Leaves all layers of the shape except for the last one unaffected. 
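Equivalently, in pure Python (a reference sketch of the semantics only,
not of how k2 computes it)::

    # drop every element equal to `target`; sublist boundaries are kept,
    # so sublists may become (or stay) empty
    def remove_values_eq_ref(sublists, target):
        return [[v for v in sub if v != target] for sub in sublists]

    assert remove_values_eq_ref([[1, 2, 3, 0, 3, 2], [], [3, 2, 3], [3]], 3) \
        == [[1, 2, 0, 2], [], [2], []]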
>>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2, 3, 0, 3, 2], [], [3, 2, 3], [3]]) +>>> a +RaggedTensor([[1, 2, 3, 0, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) >>> b = a.remove_values_eq(3) >>> b -[ [ 1 2 0 2 ] [ ] [ 2 ] [ ] ] +RaggedTensor([[1, 2, 0, 2], + [], + [2], + []], dtype=torch.int32) >>> c = a.remove_values_eq(2) >>> c -[ [ 1 3 0 3 ] [ ] [ 3 3 ] [ 3 ] ] +RaggedTensor([[1, 3, 0, 3], + [], + [3, 3], + [3]], dtype=torch.int32) Args: target: @@ -917,15 +1311,29 @@ Leaves all layers of the shape except for the last one unaffected. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2, 3, 0, 3, 2], [], [3, 2, 3], [3]]) +>>> a +RaggedTensor([[1, 2, 3, 0, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) >>> b = a.remove_values_leq(3) >>> b -[ [ ] [ ] [ ] [ ] ] +RaggedTensor([[], + [], + [], + []], dtype=torch.int32) >>> c = a.remove_values_leq(2) >>> c -[ [ 3 3 ] [ ] [ 3 3 ] [ 3 ] ] +RaggedTensor([[3, 3], + [], + [3, 3], + [3]], dtype=torch.int32) >>> d = a.remove_values_leq(1) >>> d -[ [ 2 3 3 2 ] [ ] [ 3 2 3 ] [ 3 ] ] +RaggedTensor([[2, 3, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) Args: cutoff: @@ -953,7 +1361,7 @@ tensor([ 3, -1, 7], dtype=torch.int32) >>> d = c.argmax(initial_value=0) >>> d tensor([ 3, -1, 7], dtype=torch.int32) ->>> c.data[3], c.data[7] +>>> c.values[3], c.values[7] (tensor(5, dtype=torch.int32), tensor(8, dtype=torch.int32)) >>> c.argmax(initial_value=6) tensor([-1, -1, 7], dtype=torch.int32) @@ -1045,9 +1453,16 @@ Concatenate a list of ragged tensor over a specified axis. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [], [2, 3]]) >>> k2r.cat([a, a], axis=0) - [ [ 1 ] [ ] [ 2 3 ] [ 1 ] [ ] [ 2 3 ] ] + RaggedTensor([[1], + [], + [2, 3], + [1], + [], + [2, 3]], dtype=torch.int32) >>> k2r.cat((a, a), axis=1) - [ [ 1 1 ] [ ] [ 2 3 2 3 ] ] + RaggedTensor([[1, 1], + [], + [2, 3, 2, 3]], dtype=torch.int32) **Example 2** @@ -1056,18 +1471,44 @@ Concatenate a list of ragged tensor over a specified axis. >>> b = k2r.RaggedTensor([[0], [1, 8], [], [-1], [10]]) >>> c = k2r.cat([a, b], axis=0) >>> c - [ [ 1 3 ] [ ] [ 5 8 ] [ ] [ 9 ] [ 0 ] [ 1 8 ] [ ] [ -1 ] [ 10 ] ] + RaggedTensor([[1, 3], + [], + [5, 8], + [], + [9], + [0], + [1, 8], + [], + [-1], + [10]], dtype=torch.int32) >>> c.num_axes 2 >>> d = k2r.cat([a, b], axis=1) >>> d - [ [ 1 3 0 ] [ 1 8 ] [ 5 8 ] [ -1 ] [ 9 10 ] ] + RaggedTensor([[1, 3, 0], + [1, 8], + [5, 8], + [-1], + [9, 10]], dtype=torch.int32) >>> d.num_axes 2 >>> k2r.RaggedTensor.cat([a, b], axis=1) - [ [ 1 3 0 ] [ 1 8 ] [ 5 8 ] [ -1 ] [ 9 10 ] ] + RaggedTensor([[1, 3, 0], + [1, 8], + [5, 8], + [-1], + [9, 10]], dtype=torch.int32) >>> k2r.cat((b, a), axis=0) - [ [ 0 ] [ 1 8 ] [ ] [ -1 ] [ 10 ] [ 1 3 ] [ ] [ 5 8 ] [ ] [ 9 ] ] + RaggedTensor([[0], + [1, 8], + [], + [-1], + [10], + [1, 3], + [], + [5, 8], + [], + [9]], dtype=torch.int32) Args: srcs: @@ -1103,34 +1544,79 @@ index on axis 0; if more than 3 axes, the earliest axes will be ignored. 
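A pure-Python sketch of the returned triple (in the examples below the unique
sublists happen to come out sorted; treat the exact ordering as an
implementation detail unless the documentation guarantees it)::

    # reference semantics: unique sublists, how many times each occurs,
    # and the index of one occurrence of each in the input (new2old)
    def unique_ref(sublists):
        first, count = {}, {}
        for i, sub in enumerate(sublists):
            key = tuple(sub)
            first.setdefault(key, i)
            count[key] = count.get(key, 0) + 1
        keys = sorted(first)  # matches the sorted order seen below
        return ([list(k) for k in keys],
                [count[k] for k in keys],
                [first[k] for k in keys])

    assert unique_ref([[3, 1], [3], [1], [1], [3, 1], [2]]) == \
        ([[1], [2], [3], [3, 1]], [2, 1, 1, 2], [2, 5, 1, 0])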
>>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[3, 1], [3], [1], [1], [3, 1], [2]]) >>> a.unique() - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], None, None) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), None, None) >>> a.unique(need_num_repeats=True, need_new2old_indexes=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], [ [ 2 1 1 2 ] ], tensor([2, 5, 1, 0], dtype=torch.int32)) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), RaggedTensor([[2, 1, 1, 2]], dtype=torch.int32), tensor([2, 5, 1, 0], dtype=torch.int32)) >>> a.unique(need_num_repeats=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], [ [ 2 1 1 2 ] ], None) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), RaggedTensor([[2, 1, 1, 2]], dtype=torch.int32), None) >>> a.unique(need_new2old_indexes=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], None, tensor([2, 5, 1, 0], dtype=torch.int32)) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), None, tensor([2, 5, 1, 0], dtype=torch.int32)) **Example 2** >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[[1, 2], [2, 1], [1, 2], [1, 2]], [[3], [2], [0, 1], [2]], [[], [2, 3], [], [3]] ]) >>> a.unique() - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], None, None) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), None, None) >>> a.unique(need_num_repeats=True, need_new2old_indexes=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], [ [ 3 1 ] [ 2 1 1 ] [ 2 1 1 ] ], tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), RaggedTensor([[3, 1], + [2, 1, 1], + [2, 1, 1]], dtype=torch.int32), tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) >>> a.unique(need_num_repeats=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], [ [ 3 1 ] [ 2 1 1 ] [ 2 1 1 ] ], None) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), RaggedTensor([[3, 1], + [2, 1, 1], + [2, 1, 1]], dtype=torch.int32), None) >>> a.unique(need_new2old_indexes=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], None, tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), None, tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) **Example 3** >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [3], [2]]) >>> a.unique(True, True) - ([ [ 1 ] [ 2 ] [ 3 ] ], [ [ 1 1 1 ] ], tensor([0, 2, 1], dtype=torch.int32)) - + (RaggedTensor([[1], + [2], + [3]], dtype=torch.int32), RaggedTensor([[1, 1, 1]], dtype=torch.int32), tensor([0, 2, 1], dtype=torch.int32)) Args: need_num_repeats: @@ -1198,14 +1684,26 @@ If ``use_log`` is ``False``, the normalization per sublist is done as follows: >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[0.1, 0.3], [], [1], [0.2, 0.8]]) >>> a.normalize(use_log=False) -[ [ 0.25 0.75 ] [ ] [ 1 ] [ 0.2 0.8 ] ] +RaggedTensor([[0.25, 0.75], + [], + [1], + [0.2, 0.8]], dtype=torch.float32) >>> a.normalize(use_log=True) -[ [ -0.798139 -0.598139 ] [ ] [ 0 ] [ -1.03749 -0.437488 ] ] +RaggedTensor([[-0.798139, -0.598139], + [], + [0], + [-1.03749, -0.437488]], dtype=torch.float32) >>> b = k2r.RaggedTensor([ [[0.1, 0.3], []], [[1], [0.2, 0.8]] ]) >>> b.normalize(use_log=False) -[ [ [ 0.25 0.75 ] [ ] ] [ [ 1 ] [ 0.2 0.8 ] ] ] 
+RaggedTensor([[[0.25, 0.75], + []], + [[1], + [0.2, 0.8]]], dtype=torch.float32) >>> b.normalize(use_log=True) -[ [ [ -0.798139 -0.598139 ] [ ] ] [ [ 0 ] [ -1.03749 -0.437488 ] ] ] +RaggedTensor([[[-0.798139, -0.598139], + []], + [[0], + [-1.03749, -0.437488]]], dtype=torch.float32) >>> a.num_axes 2 >>> b.num_axes @@ -1299,16 +1797,22 @@ Sort a ragged tensor over the last axis **in-place**. >>> b tensor([1, 0, 2, 4, 5, 3, 7, 6, 8], dtype=torch.int32) >>> a -[ [ 3 1 0 ] [ 5 3 2 ] [ ] [ 3 1 0 ] ] ->>> a_clone.data[b.long()] +RaggedTensor([[3, 1, 0], + [5, 3, 2], + [], + [3, 1, 0]], dtype=torch.float32) +>>> a_clone.values[b.long()] tensor([3., 1., 0., 5., 3., 2., 3., 1., 0.]) >>> a_clone = a.clone() >>> c = a.sort_(descending=False, need_new2old_indexes=True) >>> c tensor([2, 1, 0, 5, 4, 3, 8, 7, 6], dtype=torch.int32) >>> a -[ [ 0 1 3 ] [ 2 3 5 ] [ ] [ 0 1 3 ] ] ->>> a_clone.data[c.long()] +RaggedTensor([[0, 1, 3], + [2, 3, 5], + [], + [0, 1, 3]], dtype=torch.float32) +>>> a_clone.values[c.long()] tensor([0., 1., 3., 2., 3., 5., 0., 1., 3.]) Args: @@ -1318,7 +1822,7 @@ tensor([0., 1., 3., 2., 3., 5., 0., 1., 3.]) need_new2old_indexes: If ``True``, also returns a 1-D tensor, containing the indexes mapping from the sorted elements to the unsorted elements. We can use - ``self.clone().data[returned_tensor]`` to get a sorted tensor. + ``self.clone().values[returned_tensor]`` to get a sorted tensor. Returns: If ``need_new2old_indexes`` is False, returns None. Otherwise, returns a 1-D tensor of dtype ``torch.int32``. @@ -1333,10 +1837,14 @@ Index a ragged tensor with a ragged tensor. >>> src = k2r.RaggedTensor([[10, 11], [12, 13.5]]) >>> indexes = k2r.RaggedTensor([[0, 1]]) >>> src.index(indexes) - [ [ [ 10 11 ] [ 12 13.5 ] ] ] + RaggedTensor([[[10, 11], + [12, 13.5]]], dtype=torch.float32) >>> i = k2r.RaggedTensor([[0], [1], [0, 0]]) >>> src.index(i) - [ [ [ 10 11 ] ] [ [ 12 13.5 ] ] [ [ 10 11 ] [ 10 11 ] ] ] + RaggedTensor([[[10, 11]], + [[12, 13.5]], + [[10, 11], + [10, 11]]], dtype=torch.float32) **Example 2**: @@ -1344,11 +1852,21 @@ Index a ragged tensor with a ragged tensor. >>> src = k2r.RaggedTensor([ [[1, 0], [], [2]], [[], [3], [0, 0, 1]], [[1, 2], [-1]]]) >>> i = k2r.RaggedTensor([[[0, 2], [1]], [[0]]]) >>> src.index(i) - [ [ [ [ [ 1 0 ] [ ] [ 2 ] ] [ [ 1 2 ] [ -1 ] ] ] [ [ [ ] [ 3 ] [ 0 0 1 ] ] ] ] [ [ [ [ 1 0 ] [ ] [ 2 ] ] ] ] ] + RaggedTensor([[[[[1, 0], + [], + [2]], + [[1, 2], + [-1]]], + [[[], + [3], + [0, 0, 1]]]], + [[[[1, 0], + [], + [2]]]]], dtype=torch.int32) Args: indexes: - Its values must satisfy ``0 <= data[i] < self.dim0``. + Its values must satisfy ``0 <= values[i] < self.dim0``. Caution: Its dtype has to be ``torch.int32``. 
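A pure-Python sketch of the two-axis case above (reference semantics only,
not how k2 implements it)::

    # each leaf index i in `indexes` is replaced by self[i], so the
    # result gains one ragged axis relative to `indexes`
    def ragged_index_ref(src_sublists, index_sublists):
        return [[src_sublists[i] for i in sub] for sub in index_sublists]

    src = [[10, 11], [12, 13.5]]
    assert ragged_index_ref(src, [[0], [1], [0, 0]]) == \
        [[[10, 11]], [[12, 13.5]], [[10, 11], [10, 11]]]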
@@ -1373,14 +1891,19 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of >>> i = torch.tensor([2, 0, 3, 5], dtype=torch.int32) >>> b, value_indexes = a.index(i, axis=0, need_value_indexes=True) >>> b - [ [ 0 1 2 ] [ 0 2 3 ] [ ] [ 3 -1.25 ] ] + RaggedTensor([[0, 1, 2], + [0, 2, 3], + [], + [3, -1.25]], dtype=torch.float32) >>> value_indexes tensor([3, 4, 5, 0, 1, 2, 6, 7], dtype=torch.int32) - >>> a.data[value_indexes.long()] + >>> a.values[value_indexes.long()] tensor([ 0.0000, 1.0000, 2.0000, 0.0000, 2.0000, 3.0000, 3.0000, -1.2500]) >>> k = torch.tensor([2, -1, 0], dtype=torch.int32) >>> a.index(k, axis=0, need_value_indexes=True) - ([ [ 0 1 2 ] [ ] [ 0 2 3 ] ], tensor([3, 4, 5, 0, 1, 2], dtype=torch.int32)) + (RaggedTensor([[0, 1, 2], + [], + [0, 2, 3]], dtype=torch.float32), tensor([3, 4, 5, 0, 1, 2], dtype=torch.int32)) **Example 2**: @@ -1391,13 +1914,18 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of tensor([0, 0, 0, 1, 1, 1, 1], dtype=torch.int32) >>> b, value_indexes = a.index(i, axis=1, need_value_indexes=True) >>> b - [ [ [ 1 3 ] [ 2 ] [ ] ] [ [ 2 ] [ 5 8 ] [ -1 ] [ ] ] ] + RaggedTensor([[[1, 3], + [2], + []], + [[2], + [5, 8], + [-1], + []]], dtype=torch.int32) >>> value_indexes tensor([0, 1, 2, 6, 3, 4, 5], dtype=torch.int32) - >>> a.data[value_indexes.long()] + >>> a.values[value_indexes.long()] tensor([ 1, 3, 2, 2, 5, 8, -1], dtype=torch.int32) - Args: indexes: Array of indexes, which will be interpreted as indexes into axis ``axis`` @@ -1414,15 +1942,15 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of The axis to be indexed. Must satisfy ``0 <= axis < self.num_axes``. need_value_indexes: If ``True``, it will return a torch.Tensor containing the indexes into - ``self.data`` that ``ans.data`` has, as in - ``ans.data = self.data[value_indexes]``. + ``self.values`` that ``ans.values`` has, as in + ``ans.values = self.values[value_indexes]``. Returns: Return a tuple containing: - A ragged tensor, sharing the same dtype and device with ``self`` - ``None`` if ``need_value_indexes`` is False; a 1-D torch.tensor of - dtype ``torch.int32`` containing the indexes into ``self.data`` that - ``ans.data`` has. + dtype ``torch.int32`` containing the indexes into ``self.values`` that + ``ans.values`` has. )doc"; static constexpr const char *kRaggedAnyIndexTensorWithRaggedDoc = R"doc( @@ -1435,15 +1963,23 @@ Use a ragged tensor to index a 1-d torch tensor. 
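A reference sketch of the semantics before the doctest examples; note how
``-1`` behaves below (``default_value`` is the value it selects, and the
examples suggest 0 when it is not given)::

    # each leaf index picks src[i]; -1 picks default_value instead
    # (default_value=0 here is inferred from the examples below)
    def ragged_gather_ref(src, index_sublists, default_value=0):
        return [[src[i] if i >= 0 else default_value for i in sub]
                for sub in index_sublists]

    src = [0, 10, 20, 30, 40, 50]
    assert ragged_gather_ref(src, [[1, 5, 3], [0, 2]]) == [[10, 50, 30], [0, 20]]
    assert ragged_gather_ref(src, [[1, -1], [-1, 0], [-1]], -2) \
        == [[10, -2], [-2, 0], [-2]]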
>>> src tensor([ 0, 10, 20, 30, 40, 50], dtype=torch.int32) >>> k2r.index(src, i) -[ [ 10 50 30 ] [ 0 20 ] ] +RaggedTensor([[10, 50, 30], + [0, 20]], dtype=torch.int32) >>> k = k2r.RaggedTensor([ [[1, 5, 3], [0]], [[0, 2], [1, 3]] ]) >>> k2r.index(src, k) -[ [ [ 10 50 30 ] [ 0 ] ] [ [ 0 20 ] [ 10 30 ] ] ] +RaggedTensor([[[10, 50, 30], + [0]], + [[0, 20], + [10, 30]]], dtype=torch.int32) >>> n = k2r.RaggedTensor([ [1, -1], [-1, 0], [-1] ]) >>> k2r.index(src, n) -[ [ 10 0 ] [ 0 0 ] [ 0 ] ] +RaggedTensor([[10, 0], + [0, 0], + [0]], dtype=torch.int32) >>> k2r.index(src, n, default_value=-2) -[ [ 10 -2 ] [ -2 0 ] [ -2 ] ] +RaggedTensor([[10, -2], + [-2, 0], + [-2]], dtype=torch.int32) Args: src: diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 8366a5f45..987d3e598 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -35,6 +35,44 @@ namespace k2 { +static void PrintSpaces(std::ostream &os, int32_t num_spaces) { + K2_CHECK_GE(num_spaces, 0); + for (int32_t i = 0; i != num_spaces; ++i) os << " "; +} + +template +void RaggedAnyToStringIter(std::ostream &os, const Ragged ragged, + int32_t axis, int32_t begin_pos, int32_t end_pos, + int32_t num_indent) { + const auto &shape = ragged.shape; + K2_CHECK(axis >= 0 && axis < shape.NumAxes() && begin_pos >= 0 && + begin_pos <= end_pos && end_pos <= shape.TotSize(axis)); + std::string sep = ""; + bool is_first_row = true; + for (int32_t d = begin_pos; d < end_pos; d++) { + if (axis == shape.NumAxes() - 1) { + os << sep << ragged.values[d]; + sep = ", "; + } else { + const int32_t *row_splits = shape.RowSplits(axis + 1).Data(); + K2_DCHECK_LE(d, shape.RowSplits(axis + 1).Dim()); + int32_t row_start = row_splits[d], row_end = row_splits[d + 1]; + + if (!is_first_row) { + PrintSpaces(os, num_indent + 1); + } + is_first_row = false; + + os << "["; + + RaggedAnyToStringIter(os, ragged, axis + 1, row_start, row_end, + num_indent + 1); + os << "]"; + if (d != end_pos - 1) os << ",\n"; + } + } +} + /** One iteration of RaggedAnyFromList. @param data It is a list or a list-of sublist(s). @@ -154,21 +192,24 @@ RaggedAny::RaggedAny(const RaggedShape &shape, torch::Tensor value) K2_LOG(FATAL) << "Unsupported dtype: " << TraitsOf(t).Name(); } -RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { +RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/, + torch::Device device /*=torch::kCPU*/) { if (!dtype.is_none() && !THPDtype_Check(dtype.ptr())) { K2_LOG(FATAL) << "Expect an instance of torch.dtype. " << "Given: " << py::str(dtype); } + ContextPtr context = GetContext(device); + if (dtype.is_none()) { try { // We try int first, if it fails, use float - any = Ragged(s, /*throw_on_failure*/ true).Generic(); + any = Ragged(s, /*throw_on_failure*/ true).To(context).Generic(); return; } catch (const std::runtime_error &) { // Use float. 
If it fails again, another exception // is thrown and it is propagated to the user - any = Ragged(s).Generic(); + any = Ragged(s).To(context).Generic(); return; } } @@ -178,7 +219,7 @@ RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { Dtype t = ScalarTypeToDtype(scalar_type); FOR_REAL_AND_INT32_TYPES(t, T, { - any = Ragged(s).Generic(); + any = Ragged(s).To(context).Generic(); return; }); @@ -187,21 +228,24 @@ RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { << "and torch.float64"; } -RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { +RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/, + torch::Device device /*=torch::kCPU*/) { if (!dtype.is_none() && !THPDtype_Check(dtype.ptr())) { K2_LOG(FATAL) << "Expect an instance of torch.dtype. " << "Given: " << py::str(dtype); } + ContextPtr context = GetContext(device); + if (dtype.is_none()) { try { // We try int first; if it fails, use float - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; } catch (const std::exception &) { // Use float. If it fails again, another exception // is thrown and it is propagated to the user - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; } } @@ -211,7 +255,7 @@ RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { Dtype t = ScalarTypeToDtype(scalar_type); FOR_REAL_AND_INT32_TYPES(t, T, { - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; }); @@ -220,6 +264,51 @@ RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { << "and torch.float64"; } +RaggedAny::RaggedAny(torch::Tensor tensor) { + int32_t ndim = tensor.dim(); + K2_CHECK_GE(ndim, 2) << "Expect a tensor with more than 1-D"; + ContextPtr context = GetContext(tensor); + DeviceGuard guard(context); + std::vector shapes; + shapes.reserve(ndim - 1); + int32_t dim0 = tensor.size(0); + for (int32_t i = 1; i != ndim; ++i) { + int32_t dim1 = tensor.size(i); + shapes.push_back(RegularRaggedShape(context, dim0, dim1)); + dim0 *= dim1; + } + while (shapes.size() > 2u) { + RaggedShape c = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape b = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape a = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape abc = ComposeRaggedShapes3(a, b, c); + shapes.push_back(std::move(abc)); + } + + if (shapes.size() > 1u) { + RaggedShape b = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape a = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape ab = ComposeRaggedShapes(a, b); + shapes.push_back(std::move(ab)); + } + + Dtype t = ScalarTypeToDtype(tensor.scalar_type()); + FOR_REAL_AND_INT32_TYPES(t, T, { + Array1 values = FromTorch(tensor.contiguous().view({-1})); + any = Ragged(shapes[0], values).Generic(); + }); +} + const torch::Tensor &RaggedAny::Data() const { DeviceGuard guard(any.Context()); if (!data.defined()) { @@ -232,10 +321,33 @@ const torch::Tensor &RaggedAny::Data() const { return data; } -std::string RaggedAny::ToString() const { +std::string RaggedAny::ToString(int32_t device_id /*=-1*/) const { + ContextPtr context = any.Context(); + if (context->GetDeviceType() != kCpu) { + return To("cpu").ToString(context->GetDeviceId()); + } + std::ostringstream os; Dtype t = any.GetDtype(); - FOR_REAL_AND_INT32_TYPES(t, T, { os << any.Specialize(); }); + std::string dtype; + if (t 
== kInt32Dtype) + dtype = "torch.int32"; + else if (t == kFloatDtype) + dtype = "torch.float32"; + else if (t == kDoubleDtype) + dtype = "torch.float64"; + else + K2_LOG(FATAL) << "Unsupported dtype: " << TraitsOf(t).Name(); + + FOR_REAL_AND_INT32_TYPES(t, T, { + os << "RaggedTensor(["; + // 13 is strlen("RaggedTensor(") + RaggedAnyToStringIter(os, any.Specialize(), 0, 0, any.shape.Dim0(), 13); + os << "]"; + if (device_id != -1) os << ", device='cuda:" << device_id << "'"; + os << ", dtype=" << dtype; + os << ")"; + }); return os.str(); } diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index c3e171f59..2ddec79cd 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -52,6 +52,18 @@ struct RaggedAny { */ RaggedAny(const RaggedShape &shape, torch::Tensor value); + /* Create a ragged tensor from a torch tensor. + + @note The resulting ragged tensor has a regular structure. + + @params tensor An N-D PyTorch tensor, where N > 1. Supported dtypes are + torch.int32, torch.float32, torch.float64. + + @caution If the input tensor is contiguous, the ragged tensor shares the + underlying memory with the input tensor. Otherwise, memory is copied. + */ + explicit RaggedAny(torch::Tensor tensor); + RaggedAny(const RaggedAny &) = default; RaggedAny &operator=(const RaggedAny &) = default; RaggedAny(RaggedAny &&) = default; @@ -75,7 +87,12 @@ struct RaggedAny { @note We can support other dtypes if needed. */ - explicit RaggedAny(const std::string &s, py::object dtype = py::none()); + explicit RaggedAny(const std::string &s, py::object dtype = py::none(), + torch::Device device = torch::kCPU); + + explicit RaggedAny(const std::string &s, py::object dtype = py::none(), + const std::string &device = "cpu") + : RaggedAny(s, dtype, torch::Device(device)) {} /** Create a ragged tensor from a list of sublist(s). @@ -88,16 +105,22 @@ struct RaggedAny { @note It supports `data` with number of axes >= 2. */ - explicit RaggedAny(py::list data, py::object dtype = py::none()); + explicit RaggedAny(py::list data, py::object dtype = py::none(), + torch::Device device = torch::kCPU); + + explicit RaggedAny(py::list data, py::object dtype = py::none(), + const std::string device = "cpu") + : RaggedAny(data, dtype, torch::Device(device)) {} /// Populate `this->data` and return it const torch::Tensor &Data() const; /** Convert a ragged tensor to a string. + @param device_id -1 for CPU. 0 and above is for CUDA. @return Return a string representation of this tensor. */ - std::string ToString() const; + std::string ToString(int device_id = -1) const; /* Move a ragged tensor to a given device. 
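A quick illustration of the tensor constructor declared in the header above:
a regular N-D tensor yields a ragged tensor whose ``row_splits`` are evenly
spaced (a sketch using only APIs exercised elsewhere in this patch)::

    import torch
    import k2.ragged as k2r

    a = torch.arange(6, dtype=torch.int32).reshape(2, 3)
    b = k2r.RaggedTensor(a)
    # a regular 2x3 tensor becomes two rows of width 3
    assert b.shape.row_splits(1).tolist() == [0, 3, 6]
    assert b.shape.tot_sizes() == (2, 6)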
diff --git a/k2/python/tests/fsa_test.py b/k2/python/tests/fsa_test.py index aca8676a5..4349fbffc 100644 --- a/k2/python/tests/fsa_test.py +++ b/k2/python/tests/fsa_test.py @@ -986,12 +986,11 @@ def test_create_fsa_vec(self): fsa2 = k2.Fsa.from_str(s2) fsa2.aux_labels = k2.RaggedTensor('[ [5 8 9] ]') fsa = k2.create_fsa_vec([fsa1, fsa2]) - self.assertEqual(str(fsa.aux_labels), - '[ [ 1 0 2 ] [ 3 5 ] [ 5 8 9 ] ]') + assert fsa.aux_labels == k2.RaggedTensor('[ [ 1 0 2 ] [ 3 5 ] [ 5 8 9 ] ]') # noqa fsa = k2.Fsa.from_fsas([fsa1, fsa2]) - self.assertEqual(str(fsa.aux_labels), - '[ [ 1 0 2 ] [ 3 5 ] [ 5 8 9 ] ]') + assert fsa.aux_labels == k2.RaggedTensor( + '[ [ 1 0 2 ] [ 3 5 ] [ 5 8 9 ] ]') def test_index_fsa(self): for device in self.devices: diff --git a/k2/python/tests/ragged_ops_test.py b/k2/python/tests/ragged_ops_test.py index 14cf9a339..9a182a870 100644 --- a/k2/python/tests/ragged_ops_test.py +++ b/k2/python/tests/ragged_ops_test.py @@ -48,13 +48,15 @@ def test_remove_axis_ragged_array(self): [ [ [ 1 2 ] [ 0 ] ] [ [3 0 ] [ 2 ] ] ] ''' for device in self.devices: - src = k2.RaggedTensor(s).to(device) + src = k2.RaggedTensor(s, device=device) ans = src.remove_axis(0) - self.assertEqual(str(ans), '[ [ 1 2 ] [ 0 ] [ 3 0 ] [ 2 ] ]') + assert ans == k2.RaggedTensor('[ [ 1 2 ] [ 0 ] [ 3 0 ] [ 2 ] ]', + device=device) ans = src.remove_axis(1) - self.assertEqual(str(ans), '[ [ 1 2 0 ] [ 3 0 2 ] ]') + assert ans == k2.RaggedTensor('[ [ 1 2 0 ] [ 3 0 2 ] ]', + device=device) def test_remove_axis_ragged_shape(self): for device in self.devices: @@ -151,19 +153,28 @@ def test_remove_values_leq(self): ''' for device in self.devices: for dtype in self.dtypes: - src = k2.RaggedTensor(s, dtype).to(device) + src = k2.RaggedTensor(s, dtype=dtype, device=device) ans = src.remove_values_leq(0) - self.assertEqual(str(ans), '[ [ 1 2 ] [ 3 2 ] [ 8 6 ] [ ] ]') + assert ans == k2.RaggedTensor( + '[ [ 1 2 ] [ 3 2 ] [ 8 6 ] [ ] ]', + device=device, + dtype=dtype) ans = src.remove_values_leq(1) - self.assertEqual(str(ans), '[ [ 2 ] [ 3 2 ] [ 8 6 ] [ ] ]') + assert ans == k2.RaggedTensor('[ [ 2 ] [ 3 2 ] [ 8 6 ] [ ] ]', + dtype=dtype, + device=device) ans = src.remove_values_leq(6) - self.assertEqual(str(ans), '[ [ ] [ ] [ 8 ] [ ] ]') + assert ans == k2.RaggedTensor('[ [ ] [ ] [ 8 ] [ ] ]', + device=device, + dtype=dtype) ans = src.remove_values_leq(8) - self.assertEqual(str(ans), '[ [ ] [ ] [ ] [ ] ]') + assert ans == k2.RaggedTensor('[ [ ] [ ] [ ] [ ] ]', + dtype=dtype, + device=device) def test_remove_values_eq(self): s = ''' @@ -171,22 +182,31 @@ def test_remove_values_eq(self): ''' for device in self.devices: for dtype in self.dtypes: - src = k2.RaggedTensor(s).to(device) + src = k2.RaggedTensor(s, device=device, dtype=dtype) ans = src.remove_values_eq(0) - self.assertEqual(str(ans), '[ [ 1 2 ] [ 3 2 ] [ 8 6 ] [ ] ]') + assert ans == k2.RaggedTensor( + '[ [ 1 2 ] [ 3 2 ] [ 8 6 ] [ ] ]', + device=device, + dtype=dtype) ans = src.remove_values_eq(1) - self.assertEqual(str(ans), - '[ [ 2 0 ] [ 3 0 2 ] [ 0 8 0 6 0 ] [ 0 ] ]') + assert ans == k2.RaggedTensor( + '[ [ 2 0 ] [ 3 0 2 ] [ 0 8 0 6 0 ] [ 0 ] ]', + device=device, + dtype=dtype) ans = src.remove_values_eq(6) - self.assertEqual(str(ans), - '[ [ 1 2 0 ] [ 3 0 2 ] [ 0 8 0 0 ] [ 0 ] ]') + assert ans == k2.RaggedTensor( + '[ [ 1 2 0 ] [ 3 0 2 ] [ 0 8 0 0 ] [ 0 ] ]', + device=device, + dtype=dtype) ans = src.remove_values_eq(8) - self.assertEqual(str(ans), - '[ [ 1 2 0 ] [ 3 0 2 ] [ 0 0 6 0 ] [ 0 ] ]') + assert ans == k2.RaggedTensor( + '[ [ 1 2 0 ] [ 3 0 2 ] [ 0 
0 6 0 ] [ 0 ] ]', + device=device, + dtype=dtype) def test_normalize_scores_use_log_non_zero_stride(self): s = ''' @@ -350,9 +370,10 @@ def test_cat(self): ragged2 = k2.RaggedTensor('[ [] [10 20] [30] [40 50] ]', dtype).to(device) ragged = k2.ragged.cat([ragged1, ragged2], axis=0) - self.assertEqual( - str(ragged), - '[ [ 1 2 3 ] [ ] [ 4 5 ] [ ] [ 10 20 ] [ 30 ] [ 40 50 ] ]') + assert ragged == k2.RaggedTensor( + '[ [ 1 2 3 ] [ ] [ 4 5 ] [ ] [ 10 20 ] [ 30 ] [ 40 50 ] ]', + dtype=dtype, + device=device) def test_cat_axis1(self): for device in self.devices: @@ -362,8 +383,10 @@ def test_cat_axis1(self): ragged2 = k2.RaggedTensor('[ [10 20] [8] [9 10] ]', dtype).to(device) ragged = k2.ragged.cat([ragged1, ragged2], axis=1) - self.assertEqual(str(ragged), - '[ [ 1 2 3 10 20 ] [ 8 ] [ 4 5 9 10 ] ]') + assert ragged == k2.RaggedTensor( + '[ [ 1 2 3 10 20 ] [ 8 ] [ 4 5 9 10 ] ]', + device=device, + dtype=dtype) def test_get_layer_two_axes(self): for device in self.devices: diff --git a/k2/python/tests/ragged_tensor_test.py b/k2/python/tests/ragged_tensor_test.py index ec59d4831..2ae9063c0 100644 --- a/k2/python/tests/ragged_tensor_test.py +++ b/k2/python/tests/ragged_tensor_test.py @@ -31,7 +31,6 @@ class TestRaggedTensor(unittest.TestCase): - @classmethod def setUpClass(cls): cls.devices = [torch.device("cpu")] @@ -65,6 +64,50 @@ def test_create_ragged_tensor_from_string(self): assert b.num_axes == 3 assert b.dim0 == 2 + def test_create_ragged_tensor_from_torch_tensor(self): + for device in self.devices: + for func in [k2r.create_ragged_tensor, k2r.RaggedTensor]: + for dtype in self.dtypes: + a = torch.arange(24, dtype=dtype, device=device).reshape( + 2, 3, 4 + ) + b = func(a) + assert b.shape.tot_sizes() == (2, 2 * 3, 2 * 3 * 4) + + # a is contiguous, so memory is shared + c = a.reshape(-1) + c[0] = 10 + assert b.values[0] == 10 + b.values[1] = 100 + assert c[1] == 100 + + assert b.dtype == dtype + assert b.device == device + + assert torch.all(torch.eq(c, b.values)) + + for device in self.devices: + for func in [k2r.create_ragged_tensor, k2r.RaggedTensor]: + for dtype in self.dtypes: + a = torch.arange(100, dtype=dtype, device=device).reshape( + 10, 10 + )[:, ::2] + assert a.shape == (10, 5) + b = func(a) + assert b.dtype == dtype + assert b.device == device + + assert b.shape.tot_sizes() == (10, 10 * 5) + + c = a.reshape(-1) + assert torch.all(torch.eq(c, b.values)) + + # a is not contiguous, so memory is copied + c[0] = -10 + assert b.values[0] != -10 + b.values[1] = -100 + assert c[1] != -100 + def test_property_values(self): a = k2r.RaggedTensor([[1], [2], [], [3, 4]]) assert torch.all(torch.eq(a.values, torch.tensor([1, 2, 3, 4]))) @@ -128,17 +171,17 @@ def test_sum_with_grad(self): a = a.to(device) a.requires_grad_(True) b = a.sum() - expected_sum = torch.tensor([3, 0, 5], - dtype=dtype, - device=device) + expected_sum = torch.tensor( + [3, 0, 5], dtype=dtype, device=device + ) assert torch.all(torch.eq(b, expected_sum)) c = b[0] * 10 + b[1] * 20 + b[2] * 30 c.backward() - expected_grad = torch.tensor([10, 10, 30], - device=device, - dtype=dtype) + expected_grad = torch.tensor( + [10, 10, 30], device=device, dtype=dtype + ) assert torch.all(torch.eq(a.grad, expected_grad)) def test_sum_no_grad(self): @@ -147,26 +190,27 @@ def test_sum_no_grad(self): a = k2r.RaggedTensor([[1, 2], [], [5]], dtype=dtype) a = a.to(device) b = a.sum() - expected_sum = torch.tensor([3, 0, 5], - dtype=dtype, - device=device) + expected_sum = torch.tensor( + [3, 0, 5], dtype=dtype, device=device + ) assert 
torch.all(torch.eq(b, expected_sum)) def test_getitem(self): for device in self.devices: for dtype in self.dtypes: - a = k2r.RaggedTensor("[ [[1 2] [] [10]] [[3] [5]] ]", - dtype=dtype) + a = k2r.RaggedTensor( + "[ [[1 2] [] [10]] [[3] [5]] ]", dtype=dtype + ) a = a.to(device) b = a[0] - expected = k2r.RaggedTensor("[[1 2] [] [10]]", - dtype=dtype).to(device) + expected = k2r.RaggedTensor("[[1 2] [] [10]]", dtype=dtype).to( + device + ) assert b == expected b = a[1] - expected = k2r.RaggedTensor("[[3] [5]]", - dtype=dtype).to(device) + expected = k2r.RaggedTensor("[[3] [5]]", dtype=dtype).to(device) assert b == expected def test_getstate_2axes(self): @@ -177,9 +221,9 @@ def test_getstate_2axes(self): assert isinstance(b, tuple) assert len(b) == 3 # b contains (row_splits, "row_ids1", values) - b_0 = torch.tensor([0, 2, 3, 3], - dtype=torch.int32, - device=device) + b_0 = torch.tensor( + [0, 2, 3, 3], dtype=torch.int32, device=device + ) b_1 = "row_ids1" b_2 = a.values @@ -190,8 +234,9 @@ def test_getstate_2axes(self): def test_getstate_3axes(self): for device in self.devices: for dtype in self.dtypes: - a = k2r.RaggedTensor("[[[1 2] [3] []] [[4] [5 6]]]", - dtype=dtype).to(device) + a = k2r.RaggedTensor( + "[[[1 2] [3] []] [[4] [5 6]]]", dtype=dtype + ).to(device) b = a.__getstate__() assert isinstance(b, tuple) assert len(b) == 5 @@ -199,9 +244,9 @@ def test_getstate_3axes(self): # "row_ids2", values) b_0 = torch.tensor([0, 3, 5], dtype=torch.int32, device=device) b_1 = "row_ids1" - b_2 = torch.tensor([0, 2, 3, 3, 4, 6], - dtype=torch.int32, - device=device) # noqa + b_2 = torch.tensor( + [0, 2, 3, 3, 4, 6], dtype=torch.int32, device=device + ) # noqa b_3 = "row_ids2" b_4 = a.values @@ -255,7 +300,8 @@ def test_tot_size_3axes(self): for dtype in self.dtypes: a = k2r.RaggedTensor( "[ [[1 2 3] [] [5 8]] [[] [1 5 9 10 -1] [] [] []] ]", - dtype=dtype) + dtype=dtype, + ) a = a.to(device) assert a.tot_size(0) == 2 diff --git a/k2/python/tests/random_paths_test.py b/k2/python/tests/random_paths_test.py index 43e2a1341..f753c79ad 100644 --- a/k2/python/tests/random_paths_test.py +++ b/k2/python/tests/random_paths_test.py @@ -51,7 +51,8 @@ def test_single_fsa_case1(self): use_double_scores=use_double_scores, num_paths=2) assert path.num_axes == 3 - self.assertEqual(str(path), '[ [ [ 0 1 ] [ 0 1 ] ] ]') + assert path == k2.RaggedTensor('[ [ [ 0 1 ] [ 0 1 ] ] ]', + device=device) def test_single_fsa_case2(self): for device in self.devices: @@ -75,7 +76,8 @@ def test_single_fsa_case2(self): # iter 0, p is 0.5, select the second leaving arc of state 0 # iter 1, p is 0, select the first leaving arc of state 1 # iter 2, p is 0, select the first leaving arc of state 2 - self.assertEqual(str(path), '[ [ [ 1 2 4 ] ] ]') + assert path == k2.RaggedTensor('[ [ [ 1 2 4 ] ] ]', + device=device) path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, @@ -93,7 +95,8 @@ def test_single_fsa_case2(self): # iter 2, p is (0.5 - 0.5) / (1 - 0.5) = 0, select the # first leaving arc of state 2 assert path.num_axes == 3 - self.assertEqual(str(path), '[ [ [ 0 3 4 ] [ 1 3 4 ] ] ]') + assert path == k2.RaggedTensor('[ [ [ 0 3 4 ] [ 1 3 4 ] ] ]', + device=device) path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, @@ -123,9 +126,9 @@ def test_single_fsa_case2(self): # iter 2, p is 0.25/0.5=0.5, select the second leaving arc of # state 2 assert path.num_axes == 3 - self.assertEqual( - str(path), - '[ [ [ 0 2 5 ] [ 0 3 5 ] [ 1 2 5 ] [ 1 3 5 ] ] ]') + assert path == k2.RaggedTensor( + '[ [ [ 0 2 5 ] 
[ 0 3 5 ] [ 1 2 5 ] [ 1 3 5 ] ] ]', + device=device) def test_fsa_vec(self): for device in self.devices: @@ -171,8 +174,8 @@ def test_fsa_vec(self): # iter 2, p is 0, select arc 7 # iter 3, p is 0, select arc 9 # path 1 is [2, 4, 7, 9] + 6 = [8, 10, 13, 15] - self.assertEqual(str(path), - '[ [ [ 1 2 4 ] ] [ [ 8 10 13 15 ] ] ]') + assert path == k2.RaggedTensor( + '[ [ [ 1 2 4 ] ] [ [ 8 10 13 15 ] ] ]', device=device) path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, @@ -190,10 +193,9 @@ def test_fsa_vec(self): # iter 2, p is 0, select arc 7 # iter 3, p is 0, select arc 9 # path 1 is [3, 4, 7, 9] + 6 = [9, 10, 13, 15] - self.assertEqual( - str(path), - '[ [ [ 0 3 4 ] [ 1 3 4 ] ] [ [ 7 10 13 15 ] [ 9 10 13 15 ] ] ]' # noqa - ) + assert path == k2.RaggedTensor( + '[ [ [ 0 3 4 ] [ 1 3 4 ] ] [ [ 7 10 13 15 ] [ 9 10 13 15 ] ] ]', # noqa + device=device) # noqa path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, num_paths=4) @@ -234,10 +236,9 @@ def test_fsa_vec(self): # errors) # iter 3, p is 0.99911, select arc 9 # path 3 is [3, 5, 7, 9] + 6 = [9, 11, 13, 15] - self.assertEqual( - str(path), - '[ [ [ 0 2 5 ] [ 0 3 5 ] [ 1 2 5 ] [ 1 3 5 ] ] [ [ 6 11 13 15 ] [ 7 11 13 15 ] [ 8 11 13 15 ] [ 9 11 13 15 ] ] ]' # noqa - ) + assert path == k2.RaggedTensor( + '[ [ [ 0 2 5 ] [ 0 3 5 ] [ 1 2 5 ] [ 1 3 5 ] ] [ [ 6 11 13 15 ] [ 7 11 13 15 ] [ 8 11 13 15 ] [ 9 11 13 15 ] ] ]', # noqa + device=device) if __name__ == '__main__': From 971af7dbbebb5f0725bb8999a54b483e7e040649 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Sep 2021 12:34:29 +0800 Subject: [PATCH 06/64] Trigger GitHub actions manually. (#829) --- .github/workflows/build-cpu.yml | 4 ++-- .github/workflows/build.yml | 4 ++-- .github/workflows/run-tests-cpu.yml | 4 ++-- .github/workflows/run-tests.yml | 4 ++-- .github/workflows/windows.yml | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 1ab3b620b..018020b3f 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -23,14 +23,14 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] env: BUILD_TYPE: Release jobs: build-cpu: + if: ${{ github.event.label.name == 'ready' }} runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f15af6452..38781b432 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,14 +23,14 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] env: BUILD_TYPE: Release jobs: build: + if: ${{ github.event.label.name == 'ready' }} runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/run-tests-cpu.yml b/.github/workflows/run-tests-cpu.yml index e0e260c7d..e730a4dcc 100644 --- a/.github/workflows/run-tests-cpu.yml +++ b/.github/workflows/run-tests-cpu.yml @@ -23,11 +23,11 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] jobs: run-tests-cpu: + if: ${{ github.event.label.name == 'ready' }} runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index b145405d6..fbf8fb18c 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -23,11 +23,11 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] jobs: run-tests: + if: ${{ github.event.label.name == 'ready' }} runs-on: ${{ 
matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index d3b723ab0..27a509c6f 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -22,8 +22,7 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] env: BUILD_TYPE: Release @@ -31,6 +30,7 @@ env: jobs: build-windows: # see https://github.com/actions/virtual-environments/blob/win19/20210525.0/images/win/Windows2019-Readme.md + if: ${{ github.event.label.name == 'ready' }} runs-on: ${{ matrix.os }} strategy: fail-fast: false From 646704e142438bcd1aaf4a6e32d95e5ccd93a174 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Sep 2021 13:05:12 +0800 Subject: [PATCH 07/64] Run GitHub actions on merging. (#830) --- .github/workflows/build-cpu.yml | 2 +- .github/workflows/build.yml | 2 +- .github/workflows/run-tests-cpu.yml | 2 +- .github/workflows/run-tests.yml | 2 +- .github/workflows/windows.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 018020b3f..02c373799 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -30,7 +30,7 @@ env: jobs: build-cpu: - if: ${{ github.event.label.name == 'ready' }} + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 38781b432..756de8f54 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,7 +30,7 @@ env: jobs: build: - if: ${{ github.event.label.name == 'ready' }} + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/run-tests-cpu.yml b/.github/workflows/run-tests-cpu.yml index e730a4dcc..94e391a6f 100644 --- a/.github/workflows/run-tests-cpu.yml +++ b/.github/workflows/run-tests-cpu.yml @@ -27,7 +27,7 @@ on: jobs: run-tests-cpu: - if: ${{ github.event.label.name == 'ready' }} + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index fbf8fb18c..81a3ba933 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -27,7 +27,7 @@ on: jobs: run-tests: - if: ${{ github.event.label.name == 'ready' }} + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 27a509c6f..7890fb805 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -30,7 +30,7 @@ env: jobs: build-windows: # see https://github.com/actions/virtual-environments/blob/win19/20210525.0/images/win/Windows2019-Readme.md - if: ${{ github.event.label.name == 'ready' }} + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false From 8030001c9a002aa17e090a41de3f1146bdfe1e78 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Sep 2021 13:42:56 +0800 Subject: [PATCH 08/64] Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. 
--- .github/workflows/build-doc.yml | 1 + .github/workflows/nightly-cpu.yml | 9 +-- docs/source/conf.py | 3 +- docs/source/installation/for_developers.rst | 47 +++++++++-- docs/source/installation/pip.rst | 77 +++++++++--------- docs/source/installation/pip_pypi.rst | 37 ++++----- k2/python/csrc/torch/v2/any.cu | 7 ++ k2/python/csrc/torch/v2/doc/any.h | 22 +++++ k2/python/csrc/torch/v2/ragged_any.cu | 19 +++-- k2/python/csrc/torch/v2/ragged_any.h | 16 +++- k2/python/tests/fsa_test.py | 12 +-- k2/python/tests/ragged_tensor_test.py | 89 ++++++++++++++------- k2/python/tests/random_paths_test.py | 14 ++-- 13 files changed, 233 insertions(+), 120 deletions(-) diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index 0a8d65607..e0242cb01 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -23,6 +23,7 @@ on: branches: - master - doc + - doc-test env: # debug is faster in terms of compilation time diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu.yml index 6d9e1c025..4a33a59f9 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu.yml @@ -40,13 +40,10 @@ jobs: matrix: os: [ubuntu-18.04, macos-10.15] # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0 - # torch 1.3.1 supports only Python 3.5/6/7 python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.3.1", "1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] + torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.3.1" - - python-version: 3.9 + - python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.4.0" - python-version: 3.9 torch: "1.5.0" @@ -56,8 +53,6 @@ jobs: torch: "1.6.0" - python-version: 3.9 torch: "1.7.0" - - python-version: 3.8 # exclude Python 3.8 for [1.3.1] - torch: "1.3.1" steps: - uses: actions/checkout@v2 diff --git a/docs/source/conf.py b/docs/source/conf.py index c53403d2b..7f7c8957f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,7 +14,6 @@ import re import sys sys.path.insert(0, os.path.abspath('../../k2/python')) -sys.path.insert(0, os.path.abspath('../../build-ragged/lib')) sys.path.insert(0, os.path.abspath('../../build/lib')) import sphinx_rtd_theme @@ -22,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = 'k2' -copyright = '2020, k2 development team' +copyright = '2020-2021, k2 development team' author = 'k2 development team' diff --git a/docs/source/installation/for_developers.rst b/docs/source/installation/for_developers.rst index 83b885b02..5529191f8 100644 --- a/docs/source/installation/for_developers.rst +++ b/docs/source/installation/for_developers.rst @@ -9,7 +9,7 @@ First, you have to install CMake, CUDA toolkit (with cuDNN), and PyTorch: - CMake 3.11.0 and 3.18.0 are known to work. Other CMake versions may work but they are not tested. - - Install PyTorch. PyTorch 1.5.x and above are known to work. Other PyTorch + - Install PyTorch. PyTorch 1.4.x and above are known to work. Other PyTorch versions may work, but they are not tested. - Install CUDA toolkit and cuDNN. CUDA 10.1 and above are known to work. 
@@ -43,7 +43,7 @@ To build a release version, use: python3 -c "import k2; print(k2.__file__)" # It should print /some/path/k2/k2/python/k2/__init.py - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" # It should print /some/path/k2/build_release/lib/_k2.cpython-38-x86_64-linux-gnu.so # (I assume that you're using Python 3.8, so there is a string 38 above) @@ -63,10 +63,45 @@ To build a debug version, use: python3 -c "import k2; print(k2.__file__)" # It should print /some/path/k2/k2/python/k2/__init.py - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" # It should print /some/path/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so # (I assume that you're using Python 3.8, so there is a string 38 above) +.. HINT:: + + You can pass the option ``-DK2_WITH_CUDA=OFF`` to ``cmake`` to build + a CPU only version of k2. + + It is much faster to build a CPU version than that of building a CUDA + version. When you are adding new features to k2, we recommend you to + create a diretory to build a CPU version to test your code. Once it is + working on CPU, you can create a new directory to build a CUDA version + to test your code. + + That is, while adding and testing new features, use: + + .. code-block:: bash + + cd k2 + mkdir build-cpu + cd build-cpu + cmake -DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Debug .. + make -j5 + export PYTHONPATH=$PWD/../k2/python:$PWD/lib:$PYTHONPATH + # make test # to test your code + + After it is working for CPU, you can use: + + .. code-block:: bash + + cd k2 + mkdir build-cuda + cd build-cuda + cmake -DCMAKE_BUILD_TYPE=Debug .. + make -j5 + export PYTHONPATH=$PWD/../k2/python:$PWD/lib:$PYTHONPATH + # make test # to test your code + To run tests, use: .. code-block:: bash @@ -154,16 +189,16 @@ To run a specific Python test, use: To check whether you are using a release version or a debug version, run: - .. code-block:: + .. code-block:: bash - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" It should print the directory where k2 was built. That is, the above output contains a string ``build_release`` or ``build_debug``. Alternatively, you can run: - .. code-block:: + .. code-block:: bash python3 -m k2.version diff --git a/docs/source/installation/pip.rst b/docs/source/installation/pip.rst index fb5e2cb1f..ff1f1fdfb 100644 --- a/docs/source/installation/pip.rst +++ b/docs/source/installation/pip.rst @@ -29,77 +29,80 @@ versions of Python, CUDA, and PyTorch. automagically. You don't need to pre-install PyTorch and cudatoolkit when using ``conda install``. -The following commands install k2 with different CUDA versions: +The following commands install k2 with different versions of CUDA and PyTorch: .. 
code-block:: bash - # Install k2 0.3.3 with CUDA 10.2 built on 20210509 + # Install k2 1.8 with CUDA 10.1 built on 20210916 # - # cu102 means CUDA 10.2 + # You don't need to specifiy the Python version # - pip install k2==0.3.3+cu102.dev20210509 -f https://k2-fsa.org/nightly/ + pip install k2==1.8.dev20210916+cuda10.1.torch1.7.1 -f https://k2-fsa.org/nightly/ - # Install k2 0.3.3 with CUDA 11.0 built on 20210509 + # Install k2 1.8 with CUDA 10.2 built on 20210916 # - # cu110 means CUDA 11.0 # - pip install k2==0.3.3+cu110.dev20210509 -f https://k2-fsa.org/nightly/ + pip install k2==1.8.dev20210916+cuda10.2.torch1.7.1 -f https://k2-fsa.org/nightly/ - # Install k2 0.3.3 with CUDA 10.1 built on 20210509 + # Install k2 1.8 with CUDA 11.0 built on 20210916 # - # CAUTION: you don't need to specify cu101 since CUDA 10.1 is the default - # CUDA version - # - pip install k2==0.3.3.dev20210509 -f https://k2-fsa.org/nightly/ + pip install k2==1.8.dev20210916+cuda11.0.torch1.7.1 -f https://k2-fsa.org/nightly/ - # - # dev20210509 means that version is built on 2021.05.09 - # # Please always select the latest version. That is, the version # with the latest date. +.. Caution:: + + We only provide pre-compiled versions of k2 with torch 1.7.1. If you need + other versions of PyTorch, please consider one of the following alternatives + to install k2: + + - :ref:`install using conda` + - :ref:`install k2 from source` + The following is the log for installing k2: .. code-block:: - $ pip install k2==0.3.3.dev20210509 -f https://k2-fsa.org/nightly/ - Looking in links: https://k2-fsa.org/nightly/ - Collecting k2==0.3.3.dev20210509 - Downloading https://k2-fsa.org/nightly/whl/k2-0.3.3.dev20210509-cp38-cp38-linux_x86_64.whl (54.4 MB) - |________________________________| 54.4 MB 487 kB/s - Requirement already satisfied: torch in ./py38/lib/python3.8/site-packages (from k2==0.3.3.dev20210509) (1.7.1+cu101) - Requirement already satisfied: graphviz in ./py38/lib/python3.8/site-packages (from k2==0.3.3.dev20210509) (0.15) - Requirement already satisfied: numpy in ./py38/lib/python3.8/site-packages (from torch->k2==0.3.3.dev20210509) (1.19.5) - Requirement already satisfied: typing-extensions in ./py38/lib/python3.8/site-packages (from torch->k2==0.3.3.dev20210509) (3.7.4.3) - Installing collected packages: k2 - Successfully installed k2-0.3.3.dev20210509 - WARNING: You are using pip version 21.0.1; however, version 21.1.1 is available. - You should consider upgrading via the '/xxx/bin/python3.8 -m pip install --upgrade pip' command. 
+ $ pip install k2==1.8.dev20210916+cuda10.1.torch1.7.1 -f https://k2-fsa.org/nightly + + Looking in links: https://k2-fsa.org/nightly + Collecting k2==1.8.dev20210916+cuda10.1.torch1.7.1 + Downloading https://k2-fsa.org/nightly/whl/k2-1.8.dev20210916%2Bcuda10.1.torch1.7.1-cp38-cp38-linux_x86_64.whl (77.7 MB) + |________________________________| 77.7 MB 1.6 MB/s + Collecting torch==1.7.1 + Using cached torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl (776.8 MB) + Collecting graphviz + Using cached graphviz-0.17-py3-none-any.whl (18 kB) + Collecting typing-extensions + Downloading typing_extensions-3.10.0.2-py3-none-any.whl (26 kB) + Collecting numpy + Using cached numpy-1.21.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB) + Installing collected packages: typing-extensions, numpy, torch, graphviz, k2 + Successfully installed graphviz-0.17 k2-1.8.dev20210916+cuda10.1.torch1.7.1 numpy-1.21.2 torch-1.7.1 typing-extensions-3.10.0.2 To verify that k2 is installed successfully, run: .. code-block:: $ python3 -m k2.version - /xxx/lib/python3.8/runpy.py:127: RuntimeWarning: 'k2.version' found in sys.modules after import of package 'k2', but prior to execution of 'k2.version'; this may result in unpredictable behaviour - warn(RuntimeWarning(msg)) - Collecting environment information... - k2 version: 0.3.3 + k2 version: 1.8 Build type: Release - Git SHA1: 8e2fa82dca767782351fec57ec187aa04015dcf2 - Git date: Thu May 6 18:55:15 2021 + Git SHA1: 646704e142438bcd1aaf4a6e32d95e5ccd93a174 + Git date: Thu Sep 16 13:05:12 2021 Cuda used to build k2: 10.1 cuDNN used to build k2: 8.0.2 Python version used to build k2: 3.8 OS used to build k2: Ubuntu 18.04.5 LTS - CMake version: 3.20.2 + CMake version: 3.21.2 GCC version: 7.5.0 - CMAKE_CUDA_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 --expt-extended-lambda -gencode arch=compute_50,code=sm_50 --expt-extended-lambda -gencode arch=compute_60,code=sm_60 --expt-extended-lambda -gencode arch=compute_61,code=sm_61 --expt-extended-lambda -gencode arch=compute_70,code=sm_70 --expt-extended-lambda -gencode arch=compute_75,code=sm_75 --compiler-options -Wall --compiler-options -Wno-unknown-pragmas - CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 + CMAKE_CUDA_FLAGS: --expt-extended-lambda -gencode arch=compute_35,code=sm_35 --expt-extended-lambda -gencode arch=compute_50,code=sm_50 --expt-extended-lambda -gencode arch=compute_60,code=sm_60 --expt-extended-lambda -gencode arch=compute_61,code=sm_61 --expt-extended-lambda -gencode arch=compute_70,code=sm_70 --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-unknown-pragmas --compiler-options -Wno-strict-overflow + CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-strict-overflow PyTorch version used to build k2: 1.7.1+cu101 PyTorch is using Cuda: 10.1 NVTX enabled: True + With CUDA: True Disable debug: True Sync kernels : False Disable checks: False diff --git a/docs/source/installation/pip_pypi.rst b/docs/source/installation/pip_pypi.rst index ef3633bb3..57cf98eed 100644 --- a/docs/source/installation/pip_pypi.rst +++ b/docs/source/installation/pip_pypi.rst @@ -31,45 +31,40 @@ The following command installs k2 from PyPI: .. code-block:: bash - pip install --pre k2 + pip install k2 -The wheel packages on PyPI are built using `torch==1.7.1+cu101` on Ubuntu 18.04. 
-If you are using other Linux systems or a different PyTorch version,
-the pre-built wheel packages may NOT work on your system, please install
-k2 from source in this case.
+.. Caution::

-.. CAUTION::
+  The wheel packages on PyPI are built using `torch==1.7.1+cu101` on Ubuntu 18.04.
+  If you are using other Linux systems or a different PyTorch version, the
+  pre-built wheel packages may NOT work on your system. Please consider one of
+  the following alternatives to install k2:

-  k2 is still under active development and we are trying to keep
-  the packages on PyPI up to date. Please use ``--pre`` in ``pip install``.
-
-  If you want to try the latest version, please refer to
-  :ref:`install k2 from source`.
+    - :ref:`install using conda`
+    - :ref:`install k2 from source`

 To verify that k2 is installed successfully, run:

 .. code-block::

   $ python3 -m k2.version
-  /xxx/lib/python3.8/runpy.py:127: RuntimeWarning: 'k2.version' found in sys.modules after import of package 'k2', but prior to execution of 'k2.version'; this may result in unpredictable behaviour
-  warn(RuntimeWarning(msg))
-  Collecting environment information...
-  k2 version: 0.3.3
+  k2 version: 1.8
   Build type: Release
-  Git SHA1: d66cad5067563bb87710a40cf401af35cae816ff
-  Git date: Fri Apr 30 13:33:47 2021
+  Git SHA1: 646704e142438bcd1aaf4a6e32d95e5ccd93a174
+  Git date: Thu Sep 16 13:05:12 2021
   Cuda used to build k2: 10.1
   cuDNN used to build k2: 8.0.2
   Python version used to build k2: 3.8
   OS used to build k2: Ubuntu 18.04.5 LTS
-  CMake version: 3.20.1
-  GCC version: 5.5.0
-  CMAKE_CUDA_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 --expt-extended-lambda -gencode arch=compute_50,code=sm_50 --expt-extended-lambda -gencode arch=compute_60,code=sm_60 --expt-extended-lambda -gencode arch=compute_61,code=sm_61 --expt-extended-lambda -gencode arch=compute_70,code=sm_70 --expt-extended-lambda -gencode arch=compute_75,code=sm_75 --compiler-options -Wall --compiler-options -Wno-unknown-pragmas
-  CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0
+  CMake version: 3.21.2
+  GCC version: 7.5.0
+  CMAKE_CUDA_FLAGS: --expt-extended-lambda -gencode arch=compute_35,code=sm_35 --expt-extended-lambda -gencode arch=compute_50,code=sm_50 --expt-extended-lambda -gencode arch=compute_60,code=sm_60 --expt-extended-lambda -gencode arch=compute_61,code=sm_61 --expt-extended-lambda -gencode arch=compute_70,code=sm_70 --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-unknown-pragmas --compiler-options -Wno-strict-overflow
+  CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-strict-overflow
   PyTorch version used to build k2: 1.7.1+cu101
   PyTorch is using Cuda: 10.1
   NVTX enabled: True
+  With CUDA: True
   Disable debug: True
   Sync kernels : False
   Disable checks: False
diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu
index 8a7a11194..5f41dc624 100644
--- a/k2/python/csrc/torch/v2/any.cu
+++ b/k2/python/csrc/torch/v2/any.cu
@@ -70,6 +70,13 @@ void PybindRaggedAny(py::module &m) {
       [](const RaggedAny &self) -> std::string { return self.ToString(); },
       kRaggedAnyStrDoc);

+  any.def(
+      "to_str_simple",
+      [](const RaggedAny &self) -> std::string {
+        return self.ToString(/*compact*/ true);
+      },
+      kRaggedAnyToStrSimpleDoc);
+
   any.def(
       "__repr__",
       [](const RaggedAny &self) -> std::string { return self.ToString(); },
diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h
index 5c995fc9d..f92bd6571 100644
---
a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -595,6 +595,28 @@ RaggedTensor([[1], RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32) )doc"; +static constexpr const char *kRaggedAnyToStrSimpleDoc = R"doc( +Convert a ragged tensor to a string representation, which +is more compact than ``self.__str__``. + +An example output is given below:: + + RaggedTensor([[[1, 2, 3], [], [0]], [[2], [3, 10.5]]], dtype=torch.float32) + +>>> import k2.ragged as k2r +>>> a = k2r.RaggedTensor([ [[1, 2, 3], [], [0]], [[2], [3, 10.5]] ]) +>>> a +RaggedTensor([[[1, 2, 3], + [], + [0]], + [[2], + [3, 10.5]]], dtype=torch.float32) +>>> str(a) +'RaggedTensor([[[1, 2, 3],\n [],\n [0]],\n [[2],\n [3, 10.5]]], dtype=torch.float32)' +>>> a.to_str_simple() +'RaggedTensor([[[1, 2, 3], [], [0]], [[2], [3, 10.5]]], dtype=torch.float32)' +)doc"; + static constexpr const char *kRaggedAnyGetItemDoc = R"doc( Select the i-th sublist along axis 0. diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 987d3e598..6ae7278bd 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -43,7 +43,7 @@ static void PrintSpaces(std::ostream &os, int32_t num_spaces) { template void RaggedAnyToStringIter(std::ostream &os, const Ragged ragged, int32_t axis, int32_t begin_pos, int32_t end_pos, - int32_t num_indent) { + int32_t num_indent, bool compact) { const auto &shape = ragged.shape; K2_CHECK(axis >= 0 && axis < shape.NumAxes() && begin_pos >= 0 && begin_pos <= end_pos && end_pos <= shape.TotSize(axis)); @@ -58,7 +58,7 @@ void RaggedAnyToStringIter(std::ostream &os, const Ragged ragged, K2_DCHECK_LE(d, shape.RowSplits(axis + 1).Dim()); int32_t row_start = row_splits[d], row_end = row_splits[d + 1]; - if (!is_first_row) { + if (!compact && !is_first_row) { PrintSpaces(os, num_indent + 1); } is_first_row = false; @@ -66,9 +66,14 @@ void RaggedAnyToStringIter(std::ostream &os, const Ragged ragged, os << "["; RaggedAnyToStringIter(os, ragged, axis + 1, row_start, row_end, - num_indent + 1); + num_indent + 1, compact); os << "]"; - if (d != end_pos - 1) os << ",\n"; + if (d != end_pos - 1) { + if (compact) + os << ", "; + else + os << ",\n"; + } } } } @@ -321,7 +326,8 @@ const torch::Tensor &RaggedAny::Data() const { return data; } -std::string RaggedAny::ToString(int32_t device_id /*=-1*/) const { +std::string RaggedAny::ToString(bool compact /*=false*/, + int32_t device_id /*=-1*/) const { ContextPtr context = any.Context(); if (context->GetDeviceType() != kCpu) { return To("cpu").ToString(context->GetDeviceId()); @@ -342,7 +348,8 @@ std::string RaggedAny::ToString(int32_t device_id /*=-1*/) const { FOR_REAL_AND_INT32_TYPES(t, T, { os << "RaggedTensor(["; // 13 is strlen("RaggedTensor(") - RaggedAnyToStringIter(os, any.Specialize(), 0, 0, any.shape.Dim0(), 13); + RaggedAnyToStringIter(os, any.Specialize(), 0, 0, any.shape.Dim0(), 13, + compact); os << "]"; if (device_id != -1) os << ", device='cuda:" << device_id << "'"; os << ", dtype=" << dtype; diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index 2ddec79cd..e3f3bf513 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -117,10 +117,24 @@ struct RaggedAny { /** Convert a ragged tensor to a string. 
+ An example output for ``compact==false``: + + RaggedTensor([[[1, 2, 3], + [], + [0]], + [[2], + [3, 10.5]]], dtype=torch.float32) + + An example output for ``compact==true``: + + RaggedTensor([[[1, 2, 3], [], [0]], [[2], [3, 10.5]]], dtype=torch.float32) + @param device_id -1 for CPU. 0 and above is for CUDA. + @param compact If false, each sublist occupies a row. If true, all sublists + occupies only one row. @return Return a string representation of this tensor. */ - std::string ToString(int device_id = -1) const; + std::string ToString(bool compact = false, int device_id = -1) const; /* Move a ragged tensor to a given device. diff --git a/k2/python/tests/fsa_test.py b/k2/python/tests/fsa_test.py index 4349fbffc..5cf32848a 100644 --- a/k2/python/tests/fsa_test.py +++ b/k2/python/tests/fsa_test.py @@ -1078,8 +1078,8 @@ def test_clone(self): torch.eq(cloned.tensor_attr1, torch.tensor([10, 20, 30]).to(device))) - assert str(cloned.ragged_attr1) == str( - k2.RaggedTensor('[[100] [] [-1]]')) + assert cloned.ragged_attr1 == \ + k2.RaggedTensor('[[100] [] [-1]]', device=device) def test_detach_more_attributes(self): for device in self.devices: @@ -1123,15 +1123,15 @@ def test_convert_attr_to_ragged(self): dtype=torch.int32, device=device)[::2] fsa.convert_attr_to_ragged_(name='tensor_attr1', remove_eps=False) - expected = k2.RaggedTensor('[ [1] [3] [0] ]') - assert str(fsa.tensor_attr1) == str(expected) + expected = k2.RaggedTensor('[ [1] [3] [0] ]', device=device) + assert fsa.tensor_attr1 == expected fsa.tensor_attr2 = torch.tensor([1, 0, -1], dtype=torch.int32, device=device) fsa.convert_attr_to_ragged_(name='tensor_attr2', remove_eps=True) - expected = k2.RaggedTensor('[ [1] [] [-1] ]') - assert str(fsa.tensor_attr2) == str(expected) + expected = k2.RaggedTensor('[ [1] [] [-1] ]', device=device) + assert fsa.tensor_attr2 == expected def test_invalidate_cache(self): s = ''' diff --git a/k2/python/tests/ragged_tensor_test.py b/k2/python/tests/ragged_tensor_test.py index 2ae9063c0..2e9971292 100644 --- a/k2/python/tests/ragged_tensor_test.py +++ b/k2/python/tests/ragged_tensor_test.py @@ -44,25 +44,45 @@ def setUpClass(cls): def test_create_ragged_tensor(self): funcs = [k2r.create_ragged_tensor, k2r.RaggedTensor] for func in funcs: - a = func([[1000, 2], [3]]) - assert isinstance(a, k2r.RaggedTensor) - assert a.dtype == torch.int32 - - a = func([[1000, 2], [3]], dtype=torch.float32) - assert a.dtype == torch.float32 - - a = func([[1000, 2], [3]], dtype=torch.float64) - assert a.dtype == torch.float64 + for device in self.devices: + a = func([[1000, 2], [3]], device=device) + assert isinstance(a, k2r.RaggedTensor) + assert a.dtype == torch.int32 + assert a.device == device + + a = func([[1000, 2], [3]], dtype=torch.float32, device=device) + assert a.dtype == torch.float32 + assert a.device == device + + a = func([[1000, 2], [3]], dtype=torch.float64, device=device) + assert a.dtype == torch.float64 + assert a.device == device + for dtype in self.dtypes: + a = func([[1000, 2], [3]], dtype=dtype, device=device) + assert a.dtype == dtype + assert a.device == device def test_create_ragged_tensor_from_string(self): - a = k2r.RaggedTensor([[1], [2, 3, 4, 5], []]) - b = k2r.RaggedTensor("[[1] [2 3 4 5] []]") - assert a == b - assert b.dim0 == 3 - - b = k2r.RaggedTensor("[[[1] [2 3] []] [[10]]]") - assert b.num_axes == 3 - assert b.dim0 == 2 + funcs = [k2r.create_ragged_tensor, k2r.RaggedTensor] + for func in funcs: + for device in self.devices: + for dtype in self.dtypes: + a = func( + [[1], [2, 3, 
4, 5], []], dtype=dtype, device=device + ) + b = func("[[1] [2 3 4 5] []]", dtype=dtype, device=device) + assert a == b + assert b.dim0 == 3 + assert a.dtype == dtype + assert a.device == device + + b = k2r.RaggedTensor( + "[[[1] [2 3] []] [[10]]]", dtype=dtype, device=device + ) + assert b.num_axes == 3 + assert b.dim0 == 2 + assert b.dtype == dtype + assert b.device == device def test_create_ragged_tensor_from_torch_tensor(self): for device in self.devices: @@ -109,22 +129,33 @@ def test_create_ragged_tensor_from_torch_tensor(self): assert c[1] != -100 def test_property_values(self): - a = k2r.RaggedTensor([[1], [2], [], [3, 4]]) - assert torch.all(torch.eq(a.values, torch.tensor([1, 2, 3, 4]))) + for device in self.devices: + for dtype in self.dtypes: + a = k2r.RaggedTensor( + [[1], [2], [], [3, 4]], device=device, dtype=dtype + ) + assert torch.all( + torch.eq( + a.values, + torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), + ) + ) - with self.assertRaises(AttributeError): - # the `values` attribute is const. You cannot rebind it - a.values = 10 + with self.assertRaises(AttributeError): + # the `values` attribute is const. You cannot rebind it + a.values = 10 - # However, we can change the elements of a.values - a.values[0] = 10 - a.values[-1] *= 2 + # However, we can change the elements of a.values + a.values[0] = 10 + a.values[-1] *= 2 - expected = k2r.RaggedTensor([[10], [2], [], [3, 8]]) - assert a == expected + expected = k2r.RaggedTensor( + [[10], [2], [], [3, 8]], dtype=dtype, device=device + ) + assert a == expected - a.values[0] = 1 - assert a != expected + a.values[0] = 1 + assert a != expected def test_clone(self): a = k2r.RaggedTensor([[1, 2], [], [3]]) diff --git a/k2/python/tests/random_paths_test.py b/k2/python/tests/random_paths_test.py index f753c79ad..595c4500f 100644 --- a/k2/python/tests/random_paths_test.py +++ b/k2/python/tests/random_paths_test.py @@ -157,8 +157,8 @@ def test_fsa_vec(self): 4 ''' - fsa1 = k2.Fsa.from_str(s1) - fsa2 = k2.Fsa.from_str(s2) + fsa1 = k2.Fsa.from_str(s1).to(device) + fsa2 = k2.Fsa.from_str(s2).to(device) fsa_vec = k2.create_fsa_vec([fsa1, fsa2]) path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, @@ -175,7 +175,9 @@ def test_fsa_vec(self): # iter 3, p is 0, select arc 9 # path 1 is [2, 4, 7, 9] + 6 = [8, 10, 13, 15] assert path == k2.RaggedTensor( - '[ [ [ 1 2 4 ] ] [ [ 8 10 13 15 ] ] ]', device=device) + '[ [ [ 1 2 4 ] ] [ [ 8 10 13 15 ] ] ]', + device=device, + dtype=path.dtype) path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, @@ -195,7 +197,8 @@ def test_fsa_vec(self): # path 1 is [3, 4, 7, 9] + 6 = [9, 10, 13, 15] assert path == k2.RaggedTensor( '[ [ [ 0 3 4 ] [ 1 3 4 ] ] [ [ 7 10 13 15 ] [ 9 10 13 15 ] ] ]', # noqa - device=device) # noqa + device=device, + dtype=path.dtype) # noqa path = k2.random_paths(fsa_vec, use_double_scores=use_double_scores, num_paths=4) @@ -238,7 +241,8 @@ def test_fsa_vec(self): # path 3 is [3, 5, 7, 9] + 6 = [9, 11, 13, 15] assert path == k2.RaggedTensor( '[ [ [ 0 2 5 ] [ 0 3 5 ] [ 1 2 5 ] [ 1 3 5 ] ] [ [ 6 11 13 15 ] [ 7 11 13 15 ] [ 8 11 13 15 ] [ 9 11 13 15 ] ] ]', # noqa - device=device) + device=device, + dtype=path.dtype) if __name__ == '__main__': From d73a5b52905c552300ae596d557ad5989f39a30f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sun, 19 Sep 2021 09:39:14 +0800 Subject: [PATCH 09/64] Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests 
* Fix bug of accessing symbols

* Fix bug of accessing symbols

* Change argument name, add levenshtein_distance interface

* Fix test error, add tests for levenshtein_distance

* Fix review comments and add unit test for c++ side

* update the interface of levenshtein alignment

* Fix review comments
---
 k2/csrc/fsa_algo.cu                           | 141 +++++++++++++++--
 k2/csrc/fsa_algo.h                            |  30 +++-
 k2/csrc/fsa_algo_test.cu                      |  59 +++++--
 k2/python/csrc/torch/fsa_algo.cu              |  81 ++++------
 k2/python/csrc/torch/v2/ragged_any.cu         |   5 +
 k2/python/k2/__init__.py                      |   2 +
 k2/python/k2/fsa_algo.py                      | 144 +++++++++++++++---
 k2/python/tests/CMakeLists.txt                |   2 +
 k2/python/tests/levenshtein_alignment_test.py | 129 ++++++++++++++++
 k2/python/tests/levenshtein_graph_test.py     | 102 +++++++++++++
 10 files changed, 601 insertions(+), 94 deletions(-)
 create mode 100644 k2/python/tests/levenshtein_alignment_test.py
 create mode 100644 k2/python/tests/levenshtein_graph_test.py

diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu
index ff699b657..9cb3ae023 100644
--- a/k2/csrc/fsa_algo.cu
+++ b/k2/csrc/fsa_algo.cu
@@ -453,9 +453,129 @@ FsaVec LinearFsas(const Ragged<int32_t> &symbols) {
                 arcs);
 }
 
+FsaVec LevenshteinGraphs(const Ragged<int32_t> &symbols,
+                         float ins_del_score /* = -0.501 */,
+                         Array1<int32_t> *aux_labels /*= nullptr*/,
+                         Array1<float> *score_offsets /*= nullptr*/) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK_EQ(symbols.NumAxes(), 2);
+  ContextPtr &c = symbols.Context();
+
+  // For each fsa, the number of states will be the number of symbols plus 2:
+  // besides one state per symbol, we need a state after the last symbol and
+  // a super-final state, which is entered via a final arc.
+  RaggedShape fsa_to_states = ChangeSublistSize(symbols.shape, 2);
+
+  int32_t num_states = fsa_to_states.NumElements();
+  Array1<int32_t> num_arcs_for(c, num_states + 1);
+  int32_t *num_arcs_for_data = num_arcs_for.Data();
+  // "fts" is short for fsa to states
+  const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(),
+                *fts_row_ids1_data = fsa_to_states.RowIds(1).Data();
+  // set the number of arcs for each state
+  K2_EVAL(
+      c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void {
+        int32_t fsa_idx0 = fts_row_ids1_data[state_idx01],
+                final_state = fts_row_splits1_data[fsa_idx0 + 1] - 1,
+                current_num_arcs = 3;  // normally there are three arcs:
+                                       // a self-loop and two arcs pointing to
+                                       // the next state.
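        // (Illustrative note added for this write-up, not part of the
        // patch: with symbols == [[1, 2, 3]] there are 5 states -- 4
        // symbol positions plus the super-final state -- so num_arcs_for
        // becomes [3, 3, 3, 2, 0] before the exclusive sum, i.e. 11 arcs,
        // matching symbols.NumElements() * 3 + symbols.Dim0() * 2 below.)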
+ if (state_idx01 == final_state - 1) + current_num_arcs = 2; + else if (state_idx01 == final_state) + current_num_arcs = 0; + num_arcs_for_data[state_idx01] = current_num_arcs; + }); + ExclusiveSum(num_arcs_for, &num_arcs_for); + Array1 &states_to_arcs_row_splits = num_arcs_for; + int32_t num_arcs = symbols.NumElements() * 3 + symbols.Dim0() * 2; + RaggedShape states_to_arcs = + RaggedShape2(&states_to_arcs_row_splits, nullptr, num_arcs); + + // shape with a index of [fsa][state][arc] + RaggedShape shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); + Array1 arcs(c, num_arcs); + Arc *arcs_data = arcs.Data(); + const int32_t *row_splits1_data = shape.RowSplits(1).Data(), + *row_ids1_data = shape.RowIds(1).Data(), + *row_splits2_data = shape.RowSplits(2).Data(), + *row_ids2_data = shape.RowIds(2).Data(), + *symbols_data = symbols.values.Data(); + + int32_t *aux_labels_data = nullptr; + if (aux_labels != nullptr) { + *aux_labels = Array1(c, num_arcs); + aux_labels_data = aux_labels->Data(); + } + float *score_offsets_data = nullptr; + if (score_offsets != nullptr) { + *score_offsets = Array1(c, num_arcs); + score_offsets_data = score_offsets->Data(); + } + + K2_EVAL( + c, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { + int32_t state_idx01 = row_ids2_data[arc_idx012], + fsa_idx0 = row_ids1_data[state_idx01], + state_idx0x = row_splits1_data[fsa_idx0], + final_state_idx01 = row_splits1_data[fsa_idx0 + 1] - 1, + state_idx1 = state_idx01 - state_idx0x, + arc_idx01x = row_splits2_data[state_idx01], + arc_idx2 = arc_idx012 - arc_idx01x, + sym_state_idx01 = state_idx01 - 2 * fsa_idx0, + current_symbol = 0, + aux_labels_value = 0; + + if (state_idx01 != final_state_idx01 - 1 && + state_idx01 != final_state_idx01) { + current_symbol = symbols_data[sym_state_idx01]; + K2_CHECK((current_symbol != 0) && (current_symbol != -1)) + << "0 and -1 are not expected to be a symbol."; + } + + float score_offset_value = 0; + Arc arc; + arc.src_state = state_idx1; + + switch (arc_idx2) { + case 0: // the self loop arc + arc.label = 0; + arc.dest_state = state_idx1; + arc.score = ins_del_score; + aux_labels_value = 0; + score_offset_value = ins_del_score - (-0.5); + break; + case 1: // the arc pointing to next state with blank + if (state_idx01 == final_state_idx01 - 1) { // the arc pointing to + // final state + arc.label = -1; + arc.score = 0; + aux_labels_value = -1; + } else { + arc.label = 0; + arc.score = -0.5; + aux_labels_value = current_symbol; + } + arc.dest_state = state_idx1 + 1; + break; + case 2: // the arc pointing to the next state with symbol + arc.label = current_symbol; + arc.dest_state = state_idx1 + 1; + arc.score = 0; + aux_labels_value = current_symbol; + break; + default: + K2_LOG(FATAL) << "Arc index must be less than 3"; + } + + arcs_data[arc_idx012] = arc; + if (aux_labels) aux_labels_data[arc_idx012] = aux_labels_value; + if (score_offsets) score_offsets_data[arc_idx012] = score_offset_value; + }); + return Ragged(shape, arcs); +} FsaVec CtcGraphs(const Ragged &symbols, bool modified /*= false*/, - Array1 *arc_map /*= nullptr*/) { + Array1 *aux_labels /*= nullptr*/) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(symbols.NumAxes(), 2); ContextPtr &c = symbols.Context(); @@ -542,10 +662,10 @@ FsaVec CtcGraphs(const Ragged &symbols, bool modified /*= false*/, *ctc_row_ids1_data = ctc_shape.RowIds(1).Data(), *ctc_row_splits2_data = ctc_shape.RowSplits(2).Data(), *ctc_row_ids2_data = ctc_shape.RowIds(2).Data(); - int32_t *arc_map_data = nullptr; - if (arc_map != nullptr) { - *arc_map = 
Array1<int32_t>(c, num_arcs);
-    arc_map_data = arc_map->Data();
+  int32_t *aux_labels_data = nullptr;
+  if (aux_labels != nullptr) {
+    *aux_labels = Array1<int32_t>(c, num_arcs);
+    aux_labels_data = aux_labels->Data();
   }
 
   K2_EVAL(
@@ -565,7 +685,7 @@ FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified /*= false*/,
         Arc arc;
         arc.score = 0;
         arc.src_state = state_idx1;
-        int32_t arc_map_value = -1;
+        int32_t aux_labels_value = 0;
         if (remainder) {
           if (final_state) return;
           int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ?
@@ -588,8 +708,8 @@ FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified /*= false*/,
               break;
             case 2:  // the arc pointing to the next symbol state
               arc.label = next_symbol;
-              arc_map_value = sym_state_idx01 + 1 == sym_final_state ?
-                -1 : sym_state_idx01 + 1;
+              aux_labels_value = sym_state_idx01 + 1 == sym_final_state ?
+                0 : next_symbol;
               arc.dest_state = state_idx1 + 2;
               break;
             default:
@@ -600,10 +720,11 @@ FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified /*= false*/,
           K2_CHECK_LT(arc_idx2, 2);
           arc.label = arc_idx2 == 0 ? 0 : current_symbol;
           arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1;
-          arc_map_value = (arc_idx2 == 0 || final_state) ? -1 : sym_state_idx01;
+          aux_labels_value = (arc_idx2 == 0 || final_state) ?
+            0 : current_symbol;
         }
         arcs_data[arc_idx012] = arc;
-        if (arc_map) arc_map_data[arc_idx012] = arc_map_value;
+        if (aux_labels) aux_labels_data[arc_idx012] = aux_labels_value;
       });
   return Ragged<Arc>(ctc_shape, arcs);
 }
diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h
index d9ebf3e58..2e14e1c77 100644
--- a/k2/csrc/fsa_algo.h
+++ b/k2/csrc/fsa_algo.h
@@ -526,14 +526,36 @@ FsaVec LinearFsas(const Ragged<int32_t> &symbols);
   @param [in] modified Option to specify the type of CTC topology: "standard"
                        or "modified", where the "standard" one makes the blank
                        mandatory between a pair of identical symbols.
-  @param [out] It maps the arcs of output fsa to the symbols(idx01), the
-               olabel of the `arc[i]` would be `symbols[arc_map[i]]`,
-               and -1 for epsilon olabel.
+  @param [out] aux_labels  The olabels of the graphs.
   @return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`.
  */
 FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool modified = false,
-                 Array1<int32_t> *arc_map = nullptr);
+                 Array1<int32_t> *aux_labels = nullptr);
+
+/*
+  Create an FsaVec containing levenshtein graph FSAs, given a list of sequences
+  of symbols. See https://github.com/k2-fsa/k2/pull/828 for more details about
+  the levenshtein graph.
+
+  @param [in] symbols Input symbol sequences (must not contain
+                      kFinalSymbol == -1 and blank == 0). Its num_axes is 2.
+  @param [in] ins_del_score Specify the score of the self loops in the
+                            graphs; the main purpose of this score is to set
+                            an insertion and deletion penalty, which will
+                            affect the shortest path search procedure.
+  @param [out] aux_labels If not null, it will contain the aux_labels of the
+                          graphs.
+  @param [out] score_offsets The score offset of arcs: for self-loop arcs it
+                             will be `ins_del_score - (-0.5)`; for other arcs
+                             it will be zero. The purpose of this
+                             score_offsets is to calculate the levenshtein
+                             distance.
+ */
+FsaVec LevenshteinGraphs(const Ragged<int32_t> &symbols,
+                         float ins_del_score = -0.501,
+                         Array1<int32_t> *aux_labels = nullptr,
+                         Array1<float> *score_offsets = nullptr);
 
 /*
   Create ctc topology from max token id.
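To make the topology documented above concrete, here is a small pure-Python sketch of the arcs that ``LevenshteinGraphs`` creates for a single symbol sequence. It is an editor-added illustration derived from the declaration above and the unit tests below, not part of the patch, and the helper name ``levenshtein_arcs`` is hypothetical:

.. code-block:: python

    def levenshtein_arcs(symbols, ins_del_score=-0.501):
        """Hypothetical pure-Python mirror of LevenshteinGraphs for ONE sequence.

        Returns tuples (src, dest, label, aux_label, score, score_offset).
        """
        arcs = []
        n = len(symbols)
        for i, sym in enumerate(symbols):
            # Self-loop: absorbs insertions/deletions at this position.
            arcs.append((i, i, 0, 0, ins_del_score, ins_del_score - (-0.5)))
            # Advance with epsilon label: a substitution/deletion step.
            arcs.append((i, i + 1, 0, sym, -0.5, 0.0))
            # Advance on the correct symbol: no penalty.
            arcs.append((i, i + 1, sym, sym, 0.0, 0.0))
        # The state after the last symbol: self-loop plus the final arc.
        arcs.append((n, n, 0, 0, ins_del_score, ins_del_score - (-0.5)))
        arcs.append((n, n + 1, -1, -1, 0.0, 0.0))
        return arcs

With ``symbols == [1, 2, 3]`` and ``ins_del_score == -0.51`` this reproduces the first FSA of ``TestLevenshteinGraph`` below, including its ``aux_labels`` and ``score_offsets``.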
diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 091abed13..38970edf2 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1281,8 +1281,8 @@ TEST(FsaAlgo, TestReplaceRandom) { TEST(FsaAlgo, TestCtcGraph) { for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { Ragged symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]"); - Array1 arc_map; - FsaVec graph = CtcGraphs(symbols, false, &arc_map); + Array1 aux_labels; + FsaVec graph = CtcGraphs(symbols, false, &aux_labels); FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 ] " " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] " @@ -1292,19 +1292,19 @@ TEST(FsaAlgo, TestCtcGraph) { " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] " " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); - Array1 arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 -1 2 -1 -1 3 " - " -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 -1 5 " - " -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]"); + Array1 aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 0 2 0 0 3 " + " 0 3 0 0 0 0 0 0 1 0 0 2 0 2 " + " 0 0 3 0 3 0 0 0 0 0 ]"); K2_CHECK(Equal(graph, graph_ref)); - K2_CHECK(Equal(arc_map, arc_map_ref)); + K2_CHECK(Equal(aux_labels, aux_labels_ref)); } } TEST(FsaAlgo, TestCtcGraphSimplified) { for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { Ragged symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]"); - Array1 arc_map; - FsaVec graph = CtcGraphs(symbols, true, &arc_map); + Array1 aux_labels; + FsaVec graph = CtcGraphs(symbols, true, &aux_labels); FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 2 0] " " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] " @@ -1314,11 +1314,11 @@ TEST(FsaAlgo, TestCtcGraphSimplified) { " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] " " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); - Array1 arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 2 -1 2 -1 " - " -1 3 -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 " - " -1 5 -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]"); + Array1 aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 2 0 2 0 " + " 0 3 0 3 0 0 0 0 0 0 1 0 0 2 " + " 0 2 0 0 3 0 3 0 0 0 0 0 ]"); K2_CHECK(Equal(graph, graph_ref)); - K2_CHECK(Equal(arc_map, arc_map_ref)); + K2_CHECK(Equal(aux_labels, aux_labels_ref)); } } @@ -1348,4 +1348,39 @@ TEST(FsaAlgo, TestCtcTopo) { K2_CHECK(Equal(aux_label, aux_label_ref)); } } + +TEST(FsaAlgo, TestLevenshteinGraph) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + Ragged symbols(c, "[ [ 1 2 3 ] [ 4 5 6 ] ]"); + Array1 aux_labels; + Array1 score_offsets; + FsaVec graph = LevenshteinGraphs( + symbols, -0.51, &aux_labels, &score_offsets); + FsaVec graph_ref(c, "[ [ [ 0 0 0 -0.51 0 1 0 -0.5 0 1 1 0 ] " + " [ 1 1 0 -0.51 1 2 0 -0.5 1 2 2 0 ] " + " [ 2 2 0 -0.51 2 3 0 -0.5 2 3 3 0 ] " + " [ 3 3 0 -0.51 3 4 -1 0 ] [ ] ] " + " [ [ 0 0 0 -0.51 0 1 0 -0.5 0 1 4 0 ] " + " [ 1 1 0 -0.51 1 2 0 -0.5 1 2 5 0 ] " + " [ 2 2 0 -0.51 2 3 0 -0.5 2 3 6 0 ] " + " [ 3 3 0 -0.51 3 4 -1 0 ] [ ] ] ]"); + Array1 aux_labels_ref(c, "[ 0 1 1 0 2 2 0 3 3 0 -1 " + " 0 4 4 0 5 5 0 6 6 0 -1 ]"); + Array1 score_offsets_ref("[ -0.01 0 0 -0.01 0 0 -0.01 0 0" + " -0.01 0 -0.01 0 0 -0.01 0 0 " + " -0.01 0 0 -0.01 0 ]"); + K2_CHECK(Equal(graph, graph_ref)); + K2_CHECK(Equal(aux_labels, aux_labels_ref)); + + K2_CHECK_EQ(score_offsets.Dim(), score_offsets_ref.Dim()); + score_offsets = score_offsets.To(GetCpuContext()); + const float *score_offsets_data = 
score_offsets.Data(), + *score_offsets_ref_data = score_offsets_ref.Data(); + for (int32_t i = 0; i < score_offsets.Dim(); ++i) { + K2_CHECK_LT( + fabs(score_offsets_data[i] - score_offsets_ref_data[i]), 0.0001); + } + } +} + } // namespace k2 diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 7e8d55847..2af279240 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -703,61 +703,16 @@ static void PybindReplaceFsa(py::module &m) { static void PybindCtcGraph(py::module &m) { m.def( "ctc_graph", - [](RaggedAny &symbols, torch::optional = {}, - bool modified = false, - bool need_arc_map = - true) -> std::pair> { + [](RaggedAny &symbols, bool modified = false) + -> std::pair { DeviceGuard guard(symbols.any.Context()); - Array1 arc_map; + Array1 aux_labels; FsaVec graph = CtcGraphs(symbols.any.Specialize(), modified, - need_arc_map ? &arc_map : nullptr); - torch::optional tensor; - if (need_arc_map) tensor = ToTorch(arc_map); - return std::make_pair(graph, tensor); - }, - py::arg("symbols"), py::arg("device") = py::none(), - py::arg("modified") = false, py::arg("need_arc_map") = true); - - m.def( - "ctc_graph", - [](const std::vector> &symbols, - torch::optional device = {}, bool modified = false, - bool need_arc_map = - true) -> std::pair> { - ContextPtr context = - GetContext(device.value_or(torch::Device(torch::kCPU))); - - DeviceGuard guard(context); - Ragged ragged = CreateRagged2(symbols).To(context); - Array1 arc_map; - FsaVec graph = - CtcGraphs(ragged, modified, need_arc_map ? &arc_map : nullptr); - torch::optional tensor; - if (need_arc_map) tensor = ToTorch(arc_map); - return std::make_pair(graph, tensor); - }, - py::arg("symbols"), py::arg("device") = py::none(), - py::arg("modified") = false, py::arg("need_arc_map") = true); - - m.def( - "ctc_graph", - [](const std::vector> &symbols, - torch::optional device = {}, bool modified = false, - bool need_arc_map = - true) -> std::pair> { - ContextPtr context = GetContext(torch::Device(device.value_or("cpu"))); - - DeviceGuard guard(context); - Ragged ragged = CreateRagged2(symbols).To(context); - Array1 arc_map; - FsaVec graph = - CtcGraphs(ragged, modified, need_arc_map ? &arc_map : nullptr); - torch::optional tensor; - if (need_arc_map) tensor = ToTorch(arc_map); + &aux_labels); + torch::Tensor tensor = ToTorch(aux_labels); return std::make_pair(graph, tensor); }, - py::arg("symbols"), py::arg("device") = py::none(), - py::arg("modified") = false, py::arg("need_arc_map") = true); + py::arg("symbols"), py::arg("modified") = false); } static void PybindCtcTopo(py::module &m) { @@ -789,6 +744,29 @@ static void PybindCtcTopo(py::module &m) { py::arg("max_token"), py::arg("device") = py::none(), py::arg("modified") = false); } + +static void PybindLevenshteinGraph(py::module &m) { + m.def( + "levenshtein_graph", + [](RaggedAny &symbols, float ins_del_score = -0.501, + bool need_score_offset = + true) -> std::tuple> { + DeviceGuard guard(symbols.any.Context()); + Array1 aux_labels; + Array1 score_offsets; + FsaVec graph = LevenshteinGraphs(symbols.any.Specialize(), + ins_del_score, &aux_labels, + need_score_offset ? 
&score_offsets : nullptr); + torch::Tensor aux_labels_tensor = ToTorch(aux_labels); + torch::optional score_offsets_tensor; + if (need_score_offset) score_offsets_tensor = ToTorch(score_offsets); + return std::make_tuple(graph, aux_labels_tensor, score_offsets_tensor); + }, + py::arg("symbols"), py::arg("ins_del_score") = -0.501, + py::arg("need_score_offset") = true); +} + } // namespace k2 void PybindFsaAlgo(py::module &m) { @@ -806,6 +784,7 @@ void PybindFsaAlgo(py::module &m) { k2::PybindIntersectDensePruned(m); k2::PybindIntersectDevice(m); k2::PybindInvert(m); + k2::PybindLevenshteinGraph(m); k2::PybindLinearFsa(m); k2::PybindRemoveEpsilon(m); k2::PybindRemoveEpsilonSelfLoops(m); diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 6ae7278bd..ad26333e7 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -186,6 +186,9 @@ static Ragged RaggedAnyFromList(py::list data) { RaggedAny::RaggedAny(const RaggedShape &shape, torch::Tensor value) : data(value) { + ContextPtr context = GetContext(value); + DeviceGuard guard(context); + Dtype t = ScalarTypeToDtype(value.scalar_type()); FOR_REAL_AND_INT32_TYPES(t, T, { Array1 array = FromTorch(value); @@ -205,6 +208,7 @@ RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/, } ContextPtr context = GetContext(device); + DeviceGuard guard(context); if (dtype.is_none()) { try { @@ -241,6 +245,7 @@ RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/, } ContextPtr context = GetContext(device); + DeviceGuard guard(context); if (dtype.is_none()) { try { diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 640dc91da..8a5c9ded7 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -30,6 +30,8 @@ from .fsa_algo import intersect from .fsa_algo import intersect_device from .fsa_algo import invert +from .fsa_algo import levenshtein_alignment +from .fsa_algo import levenshtein_graph from .fsa_algo import linear_fsa from .fsa_algo import linear_fst from .fsa_algo import prune_on_arc_post diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 9efa6251a..f51477432 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -983,7 +983,7 @@ def replace_fsa( def ctc_graph(symbols: Union[List[List[int]], k2.RaggedTensor], modified: bool = False, - device: Optional[Union[torch.device, str]] = None) -> Fsa: + device: Optional[Union[torch.device, str]] = "cpu") -> Fsa: '''Construct ctc graphs from symbols. Note: @@ -1004,29 +1004,18 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedTensor], device: Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a torch.device. - If it is None, then the returned FSA is on CPU. It has to be None - if `symbols` is an instance of :class:`k2.RaggedTensor`, the returned + By default, the returned FSA is on CPU. + If `symbols` is an instance of :class:`k2.RaggedTensor`, the returned FSA will on the same device as `k2.RaggedTensor`. 
Returns:
       An FsaVec containing the returned ctc graphs, with "Dim0()" the same
       as "len(symbols)"(List[List[int]]) or "dim0"(k2.RaggedTensor)
     '''
-    symbol_values = None
-    if isinstance(symbols, k2.RaggedTensor):
-        assert device is None
-        assert symbols.num_axes == 2
-        symbol_values = symbols.values
-    else:
-        symbol_values = torch.tensor(
-            [it for symbol in symbols for it in symbol],
-            dtype=torch.int32,
-            device=device)
+    if not isinstance(symbols, k2.RaggedTensor):
+        symbols = k2.RaggedTensor(symbols, device=device)
 
-    need_arc_map = True
-    ragged_arc, arc_map = _k2.ctc_graph(symbols, device, modified,
-                                        need_arc_map)
-    aux_labels = k2.index_select(symbol_values, arc_map)
+    ragged_arc, aux_labels = _k2.ctc_graph(symbols, modified)
     fsa = Fsa(ragged_arc, aux_labels=aux_labels)
     return fsa
 
@@ -1069,3 +1058,124 @@ def ctc_topo(max_token: int,
     ragged_arc, aux_labels = _k2.ctc_topo(max_token, device, modified)
     fsa = Fsa(ragged_arc, aux_labels=aux_labels)
     return fsa
+
+
+def levenshtein_graph(
+        symbols: Union[k2.RaggedTensor, List[List[int]]],
+        ins_del_score: float = -0.501,
+        device: Optional[Union[torch.device, str]] = "cpu"
+) -> Fsa:
+    '''Construct levenshtein graphs from symbols.
+
+    See https://github.com/k2-fsa/k2/pull/828 for more details about the
+    levenshtein graph.
+
+    Args:
+      symbols:
+        It can be one of the following types:
+
+          - A list of list-of-integers, e.g., `[ [1, 2], [1, 2, 3] ]`
+          - An instance of :class:`k2.RaggedTensor`.
+            Must have `num_axes == 2` and with dtype `torch.int32`.
+
+      ins_del_score:
+        The score on the self-loop arcs in the graphs. The main purpose of
+        this score is to set an insertion and deletion penalty, which will
+        affect the shortest path search procedure.
+      device:
+        Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a
+        torch.device.
+        By default, the returned FSA is on CPU.
+        If `symbols` is an instance of :class:`k2.RaggedTensor`, the returned
+        FSA will be on the same device as `symbols`.
+
+    Returns:
+      An FsaVec containing the returned levenshtein graphs, with "Dim0()"
+      the same as "len(symbols)" (List[List[int]]) or "dim0" (k2.RaggedTensor).
+    '''
+    if not isinstance(symbols, k2.RaggedTensor):
+        symbols = k2.RaggedTensor(symbols, device=device)
+
+    ragged_arc, aux_labels, score_offsets = _k2.levenshtein_graph(
+        symbols, ins_del_score, True)
+    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
+    # Use the complicated name to avoid conflicts with user defined
+    # attribute names
+    setattr(fsa, "__ins_del_score_offset_internal_attr_", score_offsets)
+    return fsa
+
+
+def levenshtein_alignment(
+        refs: Fsa,
+        hyps: Fsa,
+        hyp_to_ref_map: torch.Tensor,
+        sorted_match_ref: bool = False,
+) -> Fsa:
+    '''Get the levenshtein alignment of two FsaVecs.
+
+    This function supports both CPU and GPU. But it is very slow on CPU.
+
+    Args:
+      refs:
+        An FsaVec (must have 3 axes, i.e., `len(refs.shape) == 3`). It is the
+        output Fsa of the :func:`levenshtein_graph`.
+      hyps:
+        An FsaVec (must have 3 axes) on the same device as `refs`. It is the
+        output Fsa of the :func:`levenshtein_graph`.
+      hyp_to_ref_map:
+        A 1-D torch.Tensor with dtype torch.int32 on the same device
+        as `refs`. Map from FSA-id in `hyps` to the corresponding
+        FSA-id in `refs` that we want to get levenshtein alignment with.
+        E.g. might be an identity map, or all-to-zero, or something the
+        user chooses.
+
+        Requires
+          - `hyp_to_ref_map.shape[0] == hyps.shape[0]`
+          - `0 <= hyp_to_ref_map[i] < refs.shape[0]`
+      sorted_match_ref:
+        If true, the arcs of refs must be sorted by label (checked by
+        calling code via properties), and we'll use a matching approach
+        that requires this.
+
+    Returns:
+      Returns an FsaVec containing the alignment information and satisfying
+      `ans.Dim0() == hyps.Dim0()`. Two attributes named `ref_labels` and
+      `hyp_labels` will be added to the returned FsaVec. `ref_labels` contains
+      the aligned sequences of refs and `hyp_labels` contains the aligned
+      sequences of hyps. You can get the levenshtein distance by calling
+      `get_tot_scores` on the returned FsaVec.
+
+    Examples:
+      >>> hyps = k2.levenshtein_graph([[1, 2, 3], [1, 3, 3, 2]])
+      >>> refs = k2.levenshtein_graph([[1, 2, 4]])
+      >>> alignment = k2.levenshtein_alignment(
+              refs, hyps,
+              hyp_to_ref_map=torch.tensor([0, 0], dtype=torch.int32),
+              sorted_match_ref=True)
+      >>> alignment.labels
+      tensor([ 1, 2, 0, -1, 1, 0, 0, 0, -1], dtype=torch.int32)
+      >>> alignment.ref_labels
+      tensor([ 1, 2, 4, -1, 1, 2, 4, 0, -1], dtype=torch.int32)
+      >>> alignment.hyp_labels
+      tensor([ 1, 2, 3, -1, 1, 3, 3, 2, -1], dtype=torch.int32)
+      >>> -alignment.get_tot_scores(
+              use_double_scores=False, log_semiring=False)
+      tensor([1., 3.])
+    '''
+    assert hasattr(refs, "aux_labels")
+    assert hasattr(hyps, "aux_labels")
+
+    hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")
+
+    lattice = k2.intersect_device(
+        refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref)
+    lattice = k2.remove_epsilon_self_loops(lattice)
+
+    alignment = k2.shortest_path(lattice, use_double_scores=True).invert_()
+    alignment.rename_tensor_attribute_("labels", "ref_labels")
+    alignment.rename_tensor_attribute_("aux_labels", "labels")
+
+    alignment.scores -= getattr(
+        alignment, "__ins_del_score_offset_internal_attr_")
+
+    return alignment
diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt
index 77562841e..8feb2dfc8 100644
--- a/k2/python/tests/CMakeLists.txt
+++ b/k2/python/tests/CMakeLists.txt
@@ -48,6 +48,8 @@ set(py_test_files
   intersect_device_test.py
   intersect_test.py
   invert_test.py
+  levenshtein_alignment_test.py
+  levenshtein_graph_test.py
   linear_fsa_test.py
   linear_fst_test.py
   multi_gpu_test.py
diff --git a/k2/python/tests/levenshtein_alignment_test.py b/k2/python/tests/levenshtein_alignment_test.py
new file mode 100644
index 000000000..f6e69ace4
--- /dev/null
+++ b/k2/python/tests/levenshtein_alignment_test.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# To run this single test, use +# +# ctest --verbose -R levenshtein_alignment_test_py + +from typing import List + +import random +import unittest + +import k2 +import torch + + +def levenshtein_distance(arr1: List[int], arr2: List[int]) -> int: + m = len(arr1) + 1 + n = len(arr2) + 1 + dp = [[0] * n for _ in range(m)] + for r in range(1, m): + dp[r][0] = r + for c in range(1, n): + dp[0][c] = c + for i in range(1, m): + for j in range(1, n): + if arr1[i - 1] == arr2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + else: + dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + return dp[-1][-1] + + +class TestLevenshteinDistance(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device("cpu")] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device("cuda", 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device("cuda", 1)) + + def test(self): + for device in self.devices: + refs_vec = [[1, 2, 3, 4, 5]] + hyps_vec = [[1, 2, 3, 3, 5], [1, 2, 4, 5], [1, 2, 3, 4, 5, 6]] + refs = k2.levenshtein_graph(refs_vec, device=device) + hyps = k2.levenshtein_graph(hyps_vec, device=device) + + alignment = k2.levenshtein_alignment( + refs, + hyps, + hyp_to_ref_map=torch.tensor( + [0, 0, 0], dtype=torch.int32, device=device + ), + sorted_match_ref=True, + ) + + labels_refs = torch.tensor( + [1, 2, 3, 0, 5, -1, 1, 2, 0, 4, 5, -1, 1, 2, 3, 4, 5, 0, -1], + dtype=torch.int32, + ) + ref_labels_refs = torch.tensor( + [1, 2, 3, 4, 5, -1, 1, 2, 3, 4, 5, -1, 1, 2, 3, 4, 5, 0, -1], + dtype=torch.int32, + ) + hyp_labels_refs = torch.tensor( + [1, 2, 3, 3, 5, -1, 1, 2, 0, 4, 5, -1, 1, 2, 3, 4, 5, 6, -1], + dtype=torch.int32, + ) + assert torch.all(torch.eq(alignment.labels.to("cpu"), labels_refs)) + assert torch.all( + torch.eq(alignment.ref_labels.to("cpu"), ref_labels_refs) + ) + assert torch.all( + torch.eq(alignment.hyp_labels.to("cpu"), hyp_labels_refs) + ) + + def test_distance(self): + for device in self.devices: + refs_num = random.randint(2, 10) + hyps_num = random.randint(2, 10) + refs_vec = [[random.randint(1, 100) for i in range(refs_num)]] + hyps_vec = [ + [random.randint(1, 100) for i in range(hyps_num)] + for j in range(hyps_num) + ] + refs = k2.levenshtein_graph(refs_vec, device=device) + hyps = k2.levenshtein_graph(hyps_vec, device=device) + + alignment = k2.levenshtein_alignment( + refs, + hyps, + hyp_to_ref_map=torch.tensor( + [0] * hyps_num, dtype=torch.int32, device=device + ), + sorted_match_ref=True, + ) + distance_vec = [] + for i in range(hyps_num): + distance_vec.append( + levenshtein_distance(refs_vec[0], hyps_vec[i]) + ) + + distance_refs = torch.tensor(distance_vec, dtype=torch.float32) + distance = -alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + assert torch.allclose(distance.to("cpu"), distance_refs) + + +if __name__ == "__main__": + unittest.main() diff --git a/k2/python/tests/levenshtein_graph_test.py b/k2/python/tests/levenshtein_graph_test.py new file mode 100644 index 000000000..f05647171 --- /dev/null +++ b/k2/python/tests/levenshtein_graph_test.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (authors: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R levenshtein_graph_test_py + +import unittest + +import k2 +import torch + + +class TestLevenshteinGraph(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + + def test(self): + for device in self.devices: + for score in [-0.5, -0.501, -0.502]: + s = ''' + [ [1 2 3] [ ] [4 5 6] ] + ''' + ragged_int = k2.RaggedTensor(s).to(device) + fsa_vec_ragged = k2.levenshtein_graph( + ragged_int, ins_del_score=score) + + fsa_vec = k2.levenshtein_graph( + [[1, 2, 3], [], [4, 5, 6]], device=device, + ins_del_score=score) + + expected_str0 = '\n'.join([ + f'0 0 0 0 {score}', '0 1 0 1 -0.5', '0 1 1 1 0', + f'1 1 0 0 {score}', '1 2 0 2 -0.5', '1 2 2 2 0', + f'2 2 0 0 {score}', '2 3 0 3 -0.5', '2 3 3 3 0', + f'3 3 0 0 {score}', '3 4 -1 -1 0', '4' + ]) + expected_str1 = '\n'.join([ + f'0 0 0 0 {score}', '0 1 -1 -1 0', '1' + ]) + expected_str2 = '\n'.join([ + f'0 0 0 0 {score}', '0 1 0 4 -0.5', '0 1 4 4 0', + f'1 1 0 0 {score}', '1 2 0 5 -0.5', '1 2 5 5 0', + f'2 2 0 0 {score}', '2 3 0 6 -0.5', '2 3 6 6 0', + f'3 3 0 0 {score}', '3 4 -1 -1 0', '4' + ]) + actual_str_ragged0 = k2.to_str_simple( + fsa_vec_ragged[0].to('cpu')) + actual_str_ragged1 = k2.to_str_simple( + fsa_vec_ragged[1].to('cpu')) + actual_str_ragged2 = k2.to_str_simple( + fsa_vec_ragged[2].to('cpu')) + actual_str0 = k2.to_str_simple(fsa_vec[0].to('cpu')) + actual_str1 = k2.to_str_simple(fsa_vec[1].to('cpu')) + actual_str2 = k2.to_str_simple(fsa_vec[2].to('cpu')) + assert actual_str0.strip() == expected_str0 + assert actual_str1.strip() == expected_str1 + assert actual_str2.strip() == expected_str2 + assert actual_str_ragged0.strip() == expected_str0 + assert actual_str_ragged1.strip() == expected_str1 + assert actual_str_ragged2.strip() == expected_str2 + + offset_value = score - (-0.5) + expected_offset = torch.tensor([ + offset_value, 0, 0, offset_value, 0, 0, offset_value, 0, 0, + offset_value, 0, offset_value, 0, + offset_value, 0, 0, offset_value, 0, 0, offset_value, 0, 0, + offset_value, 0], dtype=torch.float32) + + offset_ragged = getattr( + fsa_vec_ragged, "__ins_del_score_offset_internal_attr_") + offset_ragged = offset_ragged.to('cpu') + offset = getattr( + fsa_vec, "__ins_del_score_offset_internal_attr_").to('cpu') + assert torch.allclose(expected_offset, offset_ragged) + assert torch.allclose(expected_offset, offset) + + +if __name__ == '__main__': + unittest.main() From f2fd997f752ed11bbef4c306652c433e83f9cf12 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 19 Sep 2021 09:41:46 +0800 Subject: [PATCH 10/64] Release v1.9 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 75af4772f..1ba68c478 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION 
"1.8") +set(K2_VERSION "1.9") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) From 601d663fa11dbab0788f41013aee08db62c7ac18 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 25 Sep 2021 23:33:46 +0800 Subject: [PATCH 11/64] Support a[b[i]] where both a and b are ragged tensors. (#833) --- k2/python/csrc/torch/v2/any.cu | 20 ++++++++++ k2/python/csrc/torch/v2/doc/any.h | 53 +++++++++++++++++++++++++++ k2/python/tests/ragged_tensor_test.py | 40 +++++++++++++++++++- 3 files changed, 112 insertions(+), 1 deletion(-) diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 5f41dc624..0c9f07b4a 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -89,6 +89,7 @@ void PybindRaggedAny(py::module &m) { RaggedAny ragged = self.Index(/*axis*/ 0, i); return py::cast(ragged); } else { + DeviceGuard guard(self.any.Context()); K2_CHECK_EQ(self.any.NumAxes(), 2); Array1 row_split = self.any.RowSplits(1).To(GetCpuContext()); const int32_t *row_split_data = row_split.Data(); @@ -122,6 +123,25 @@ void PybindRaggedAny(py::module &m) { }, py::arg("key"), kRaggedAnyGetItemSliceDoc); + any.def( + "__getitem__", + [](RaggedAny &self, torch::Tensor key) -> RaggedAny { + // key is a 1-d torch tensor with dtype torch.int32 + DeviceGuard guard(self.any.Context()); + Array1 indexes = FromTorch(key); + Dtype t = self.any.GetDtype(); + FOR_REAL_AND_INT32_TYPES(t, T, { + Ragged ans = + k2::Index(self.any.Specialize(), /*axis*/ 0, indexes, + /*value_indexes*/ nullptr); + + return RaggedAny(ans.Generic()); + }); + // Unreachable code + return {}; + }, + py::arg("key"), kRaggedAnyGetItem1DTensorDoc); + any.def("index", static_cast(&RaggedAny::Index), py::arg("indexes"), kRaggedAnyRaggedIndexDoc); diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index f92bd6571..764cb8b04 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -692,6 +692,59 @@ RaggedTensor([[[8]]], dtype=torch.int32) only contains the sublists within the range. )doc"; +static constexpr const char *kRaggedAnyGetItem1DTensorDoc = R"doc( +Slice a ragged tensor along axis 0 using a 1-D torch.int32 tensor. + +**Example 1**: + + >>> import k2 + >>> a = k2.RaggedTensor([[1, 2, 0], [0, 1], [2, 3]]) + >>> b = k2.RaggedTensor([[10, 20], [300], [-10, 0, -1], [-2, 4, 5]]) + >>> a[0] + tensor([1, 2, 0], dtype=torch.int32) + >>> b[a[0]] + RaggedTensor([[300], + [-10, 0, -1], + [10, 20]], dtype=torch.int32) + >>> a[1] + tensor([0, 1], dtype=torch.int32) + >>> b[a[1]] + RaggedTensor([[10, 20], + [300]], dtype=torch.int32) + >>> a[2] + tensor([2, 3], dtype=torch.int32) + >>> b[a[2]] + RaggedTensor([[-10, 0, -1], + [-2, 4, 5]], dtype=torch.int32) + +**Example 2**: + + >>> import torch + >>> import k2 + >>> a = k2.RaggedTensor([ [[1], [2, 3], [0]], [[], [2]], [[10, 20]] ]) + >>> i = torch.tensor([0, 2, 1, 0], dtype=torch.int32) + >>> a[i] + RaggedTensor([[[1], + [2, 3], + [0]], + [[10, 20]], + [[], + [2]], + [[1], + [2, 3], + [0]]], dtype=torch.int32) + +Args: + key: + A 1-D torch.int32 tensor containing the indexes to select along + axis 0. + +Return: + Return a new ragged tensor with the same number of axes as ``self`` but + only contains the specified sublists. + +)doc"; + static constexpr const char *kRaggedAnyCloneDoc = R"doc( Return a copy of this tensor. 
diff --git a/k2/python/tests/ragged_tensor_test.py b/k2/python/tests/ragged_tensor_test.py index 2e9971292..da8849155 100644 --- a/k2/python/tests/ragged_tensor_test.py +++ b/k2/python/tests/ragged_tensor_test.py @@ -227,7 +227,7 @@ def test_sum_no_grad(self): assert torch.all(torch.eq(b, expected_sum)) - def test_getitem(self): + def test_getitem_scalar(self): for device in self.devices: for dtype in self.dtypes: a = k2r.RaggedTensor( @@ -244,6 +244,44 @@ def test_getitem(self): expected = k2r.RaggedTensor("[[3] [5]]", dtype=dtype).to(device) assert b == expected + def test_getitem_1d_tensor(self): + for device in self.devices: + for dtype in self.dtypes: + a = k2r.RaggedTensor([[1, 2, 0], [0, 1], [2, 3]], device=device) + b = k2.RaggedTensor( + # 0 1 2 3 + [[10, 20], [300], [-10, 0, -1], [-2, 4, 5]], + dtype=dtype, + device=device, + ) + + # for a[0] + index = torch.tensor( + [1, 2, 0], dtype=torch.int32, device=device + ) + assert torch.all(torch.eq(a[0], index)) + expected = k2.RaggedTensor( + [[300], [-10, 0, -1], [10, 20]], dtype=dtype, device=device + ) + + assert b[a[0]] == expected + + # for a[1] + index = torch.tensor([0, 1], dtype=torch.int32, device=device) + assert torch.all(torch.eq(a[1], index)) + expected = k2.RaggedTensor( + [[10, 20], [300]], dtype=dtype, device=device + ) + assert b[a[1]] == expected + + # for a[2] + index = torch.tensor([2, 3], dtype=torch.int32, device=device) + assert torch.all(torch.eq(a[2], index)) + expected = k2.RaggedTensor( + [[-10, 0, -1], [-2, 4, 5]], dtype=dtype, device=device + ) + assert b[a[2]] == expected + def test_getstate_2axes(self): for device in self.devices: for dtype in self.dtypes: From 8694fee66f564cf750792cb30c639d3cc404c18b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 30 Sep 2021 11:35:28 -0400 Subject: [PATCH 12/64] Display import error solution message on MacOS (#837) --- k2/python/k2/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 8a5c9ded7..d3f9207be 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -1,6 +1,15 @@ import torch # noqa -from _k2 import DeterminizeWeightPushingType -from _k2 import simple_ragged_index_select +try: + from _k2 import DeterminizeWeightPushingType + from _k2 import simple_ragged_index_select +except ImportError as e: + import sys + major_v, minor_v = sys.version_info[:2] + raise ImportError( + str(e) + "\nNote: If you're using anaconda and importing k2 on MacOS," + "\n you can probably fix this by setting the environment variable:" + f"\n export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python{major_v}.{minor_v}/site-packages:$DYLD_LIBRARY_PATH" + ) from .ragged import RaggedShape from .ragged import RaggedTensor From 86e54797ec02881a5642d231eaf2c2534e48ef16 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Oct 2021 22:55:31 +0800 Subject: [PATCH 13/64] Fix installation doc. (#841) * Fix installation doc. Remove Windows support. Will fix it later. * Fix style issues. --- docs/source/installation/conda.rst | 17 +++++++++++++++-- k2/python/k2/__init__.py | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/docs/source/installation/conda.rst b/docs/source/installation/conda.rst index 9afb539f3..4c50b8c9d 100644 --- a/docs/source/installation/conda.rst +++ b/docs/source/installation/conda.rst @@ -1,8 +1,12 @@ .. 
_install using conda: -Install using conda (Linux/macOS/Windows) -========================================= +Install using conda (Linux/macOS) +================================= + +.. HINT:: + + Windows is currently not supported. All you need is the following line @@ -38,6 +42,15 @@ To Install a CPU version, use: export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.8/site-packages:$DYLD_LIBRARY_PATH python3 -m k2.version # now it should work +.. HINT:: + + If you encounter the following error:: + + ModuleNotFoundError: no module named graphviz + + Please run:: + + conda install -c anaconda graphviz Read the following if you want to learn more. diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index d3f9207be..c60194d58 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -8,7 +8,7 @@ raise ImportError( str(e) + "\nNote: If you're using anaconda and importing k2 on MacOS," "\n you can probably fix this by setting the environment variable:" - f"\n export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python{major_v}.{minor_v}/site-packages:$DYLD_LIBRARY_PATH" + f"\n export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python{major_v}.{minor_v}/site-packages:$DYLD_LIBRARY_PATH" # noqa ) from .ragged import RaggedShape from .ragged import RaggedTensor From b72589cbce2dc1311a6b4ead96b67b8f8652d356 Mon Sep 17 00:00:00 2001 From: "Jan \"yenda\" Trmal" Date: Wed, 13 Oct 2021 18:19:02 -0400 Subject: [PATCH 14/64] fix typos in the install instructions (#844) --- INSTALL.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/INSTALL.rst b/INSTALL.rst index 241397d80..b3b761720 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -71,8 +71,8 @@ To build a release version, use: cd build_release cmake -DCMAKE_BUILD_TYPE=Release .. make -j - export PYHONPATH=$PWD/../k2/python:$PYHONPATH # for `import k2` - export PYHONPATH=$PWD/lib:$PYHONPATH # for `import _k2` + export PYTHONPATH=$PWD/../k2/python:$PYTHONPATH # for `import k2` + export PYTHONPATH=$PWD/lib:$PYTHONPATH # for `import _k2` # To test that your build is successful, run python3 -c "import k2; print(k2.__file__)" @@ -91,8 +91,8 @@ To build a debug version, use: cd build_debug cmake -DCMAKE_BUILD_TYPE=Debug .. 
make -j - export PYHONPATH=$PWD/../k2/python:$PYHONPATH # for `import k2` - export PYHONPATH=$PWD/lib:$PYHONPATH # for `import _k2` + export PYTHONPATH=$PWD/../k2/python:$PYTHONPATH # for `import k2` + export PYTHONPATH=$PWD/lib:$PYTHONPATH # for `import _k2` # To test that your build is successful, run python3 -c "import k2; print(k2.__file__)" @@ -141,8 +141,8 @@ To run a specific Python test, use: cd /some/path/k2/build_release # or switch to build_debug - export PYHONPATH=$PWD/../k2/python:$PYHONPATH # for `import k2` - export PYHONPATH=$PWD/lib:$PYHONPATH # for `import _k2` + export PYTHONPATH=$PWD/../k2/python:$PYTHONPATH # for `import k2` + export PYTHONPATH=$PWD/lib:$PYTHONPATH # for `import _k2` python3 ../k2/python/tests/index_test.py From 6ac97950bd51776c9323276b64f598cf81394f7b Mon Sep 17 00:00:00 2001 From: "Jan \"yenda\" Trmal" Date: Wed, 13 Oct 2021 19:18:33 -0400 Subject: [PATCH 15/64] make cmake adhere to the modernized way of finding packages outside default dirs (#845) --- CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ba68c478..b667178e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,9 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) if(POLICY CMP0111) cmake_policy(SET CMP0111 OLD) endif() +if(POLICY CMP0074) + cmake_policy(SET CMP0074 NEW) +endif() set(languages CXX) set(_K2_WITH_CUDA ON) From 2537a3fa927671faec5e4ca56b8b151806356324 Mon Sep 17 00:00:00 2001 From: "Jan \"yenda\" Trmal" Date: Thu, 14 Oct 2021 19:40:41 -0400 Subject: [PATCH 16/64] import torch first in the smoke tests to preven SEGFAULT (#846) --- INSTALL.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/INSTALL.rst b/INSTALL.rst index b3b761720..dd5fd8ef1 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -78,7 +78,7 @@ To build a release version, use: python3 -c "import k2; print(k2.__file__)" # It should print /some/path/k2/k2/python/k2/__init.py - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" # It should print /some/path/k2/build_release/lib/_k2.cpython-38-x86_64-linux-gnu.so # (I assume that you're using Python 3.8, so there is a string 38 above) @@ -98,7 +98,7 @@ To build a debug version, use: python3 -c "import k2; print(k2.__file__)" # It should print /some/path/k2/k2/python/k2/__init.py - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" # It should print /some/path/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so # (I assume that you're using Python 3.8, so there is a string 38 above) @@ -191,7 +191,7 @@ To run a specific Python test, use: .. code-block:: - python3 -c "import _k2; print(_k2.__file__)" + python3 -c "import torch; import _k2; print(_k2.__file__)" It should print the directory where k2 was built. That is, the above output contains a string ``build_release`` or ``build_debug``. From cae610a97aab52944c00792c8878d06f0257e7a7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 23 Oct 2021 17:35:57 +0800 Subject: [PATCH 17/64] Add doc about how to install a CPU version of k2. (#850) * Add doc about how to install a CPU version of k2. * Remove property setter of Fsa.labels * Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life. 
--- .github/workflows/build_conda.yml | 2 +- INSTALL.rst | 14 +++++++++++++- docs/source/installation/pip.rst | 14 ++++++++++++++ k2/python/k2/fsa.py | 13 ------------- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build_conda.yml b/.github/workflows/build_conda.yml index cdb80d71d..4fba8ca5f 100644 --- a/.github/workflows/build_conda.yml +++ b/.github/workflows/build_conda.yml @@ -32,7 +32,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-16.04] + os: [ubuntu-18.04] # anaconda does not support 3.9 as of 2021.05.08 python-version: [3.6, 3.7, 3.8, 3.9] # python-version: [3.6, 3.7, 3.8] diff --git a/INSTALL.rst b/INSTALL.rst index dd5fd8ef1..e4bbf7515 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -17,7 +17,19 @@ Use pip .. code-block:: bash - pip install --pre k2 + # Install a CUDA version compiled using CUDA 10.1 and PyTorch 1.7.1 + # + pip install k2 + + # Install a CPU version compiled against PyTorch 1.8.1 on 2021.10.22 + # + pip install k2==1.9.dev20211022+cpu.torch1.8.1 -f https://k2-fsa.org/nightly/ + + # Install a CPU version compiled against PyTorch 1.9.0 on 2021.10.22 + # + pip install k2==1.9.dev20211022+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/ + + # Please visit https://k2-fsa.org/nightly/ for more versions of k2 Read the following two pages to learn more: diff --git a/docs/source/installation/pip.rst b/docs/source/installation/pip.rst index ff1f1fdfb..20db2a886 100644 --- a/docs/source/installation/pip.rst +++ b/docs/source/installation/pip.rst @@ -51,6 +51,20 @@ The following commands install k2 with different versions of CUDA and PyTorch: # Please always select the latest version. That is, the version # with the latest date. +To install a version for CPU only, please use: + +.. code-block:: bash + + # Install a CPU version compiled against PyTorch 1.8.1 on 2021.10.22 + # + pip install k2==1.9.dev20211022+cpu.torch1.8.1 -f https://k2-fsa.org/nightly/ + + # Install a CPU version compiled against PyTorch 1.9.0 on 2021.10.22 + # + pip install k2==1.9.dev20211022+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/ + + # Please visit https://k2-fsa.org/nightly/ for more versions of k2 + .. Caution:: We only provide pre-compiled versions of k2 with torch 1.7.1. If you need diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py index 619f3a0e0..aabbc8dd0 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/k2/fsa.py @@ -437,19 +437,6 @@ def labels(self) -> torch.Tensor: traceback.print_exc() raise e - @labels.setter - def labels(self, values) -> None: - '''Set labels. - - Args: - values: - A 1-D `torch.tensor` with dtype `torch.int32`. - ''' - assert values.dtype == torch.int32 - self.arcs.values()[:, 2] = values - # Invalidate the properties since we changed the labels. - self.__dict__['_properties'] = None - @property def properties(self) -> int: # instead of accessing self._properties, we use From d061bc600be10a665a7108a6b9c6c12e703070a8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 24 Oct 2021 17:28:09 +0800 Subject: [PATCH 18/64] Support PyTorch 1.10. 
(#851) --- .github/workflows/build-cpu.yml | 4 +-- .github/workflows/build.yml | 47 +++++++++++++++++++++---- .github/workflows/build_conda.yml | 27 +++++++++----- .github/workflows/build_conda_cpu.yml | 28 +++++++-------- .github/workflows/nightly-cpu.yml | 4 +-- .github/workflows/nightly.yml | 7 ++++ .github/workflows/run-tests.yml | 4 +++ .github/workflows/wheel-stable.yml | 4 +++ .github/workflows/wheel.yml | 4 +++ CMakeLists.txt | 11 ++++-- scripts/github_actions/install_cuda.sh | 4 +++ scripts/github_actions/install_cudnn.sh | 27 ++++++-------- scripts/github_actions/install_torch.sh | 19 +++++++++- 13 files changed, 138 insertions(+), 52 deletions(-) diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 02c373799..90e7f6114 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -36,8 +36,8 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0, + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10 python-version: [3.6, 3.7, 3.8, 3.9] exclude: - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 756de8f54..6f86b6844 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,19 +37,42 @@ jobs: matrix: os: [ubuntu-18.04] # from https://download.pytorch.org/whl/torch_stable.html - # 1.9.0 supports: cuda10.2 (default), 11.1 + # Note: There are no torch versions for CUDA 11.2 + # + # 1.10 supports: cuda10.2 (default), 11.1, 11.3 + # 1.9.x supports: cuda10.2 (default), 11.1 # PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1 # PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0 # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default) # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default) # Other PyTorch versions are not tested - cuda: ["10.1", "10.2", "11.0", "11.1"] + # CUDA 11.3 is for torch 1.10 + cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] gcc: ["7"] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.0 + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + # + # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10 python-version: [3.6, 3.7, 3.8, 3.9] exclude: - - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0] + - cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] + torch: "1.5.0" + - cuda: "11.3" + torch: "1.5.1" + - cuda: "11.3" + torch: "1.6.0" + - cuda: "11.3" + torch: "1.7.0" + - cuda: "11.3" + torch: "1.7.1" + - cuda: "11.3" + torch: "1.8.0" + - cuda: "11.3" + torch: "1.8.1" + - cuda: "11.3" + torch: "1.9.0" + - cuda: "11.3" + torch: "1.9.1" + - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10] torch: "1.5.0" - cuda: "11.0" torch: "1.5.1" @@ -61,6 +84,10 @@ jobs: torch: "1.8.1" - cuda: "11.0" torch: "1.9.0" + - cuda: "11.0" + torch: "1.9.1" + - cuda: "11.0" + torch: "1.10" - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] torch: "1.5.0" - cuda: "11.1" @@ -71,8 +98,12 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0] + - cuda: "10.1" # 
exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10] torch: "1.9.0" + - cuda: "10.1" + torch: "1.9.1" + - cuda: "10.1" + torch: "1.10" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 @@ -117,6 +148,10 @@ jobs: echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: diff --git a/.github/workflows/build_conda.yml b/.github/workflows/build_conda.yml index 4fba8ca5f..91e703f2b 100644 --- a/.github/workflows/build_conda.yml +++ b/.github/workflows/build_conda.yml @@ -33,20 +33,19 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04] - # anaconda does not support 3.9 as of 2021.05.08 python-version: [3.6, 3.7, 3.8, 3.9] - # python-version: [3.6, 3.7, 3.8] - cuda: ["10.1", "10.2", "11.0", "11.1"] + cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] # from https://download.pytorch.org/whl/torch_stable.html # - # PyTorch 1.9.0 supports: 10.2 (default), 11.1 + # PyTorch 1.10 supports: 10.2 (default), 11.1, 11.3 + # PyTorch 1.9.x supports: 10.2 (default), 11.1 # PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1 # PyTorch 1.8.0 supports: cuda 10.1, 10.2 (default), 11.1 # PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0, 9.2 (not included in this setup) # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) # - # PyTorch 1.8.x and 1.7.1 support 3.6, 3.7, 3.8, 3.9 + # PyTorch 1.7.1, 1.8.x, 1.9.x, and 1.10 support 3.6, 3.7, 3.8, 3.9 # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 # # Other PyTorch versions are not tested @@ -57,9 +56,9 @@ jobs: # https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true # and # https://github.com/NVIDIA/apex/issues/805 - torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] + torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] exclude: - # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0] + # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10] # torch: "1.5.0" # - cuda: "11.0" # torch: "1.5.1" @@ -71,6 +70,10 @@ jobs: torch: "1.8.1" - cuda: "11.0" torch: "1.9.0" + - cuda: "11.0" + torch: "1.9.1" + - cuda: "11.0" + torch: "1.10" # - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] # torch: "1.5.0" # - cuda: "11.1" @@ -81,8 +84,12 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude 10.1 for [1.9.0] + - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10] torch: "1.9.0" + - cuda: "10.1" + torch: "1.9.1" + - cuda: "10.1" + torch: "1.10" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 @@ -142,6 +149,10 @@ jobs: conda info nproc + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Download cudnn 8.0 shell: bash -l {0} env: diff --git a/.github/workflows/build_conda_cpu.yml b/.github/workflows/build_conda_cpu.yml index b7d5fb72f..b7f2dcc4a 100644 --- a/.github/workflows/build_conda_cpu.yml +++ b/.github/workflows/build_conda_cpu.yml @@ -42,27 +42,25 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-16.04, macos-10.15] - # anaconda does not support 3.9 as of 2021.05.08 - # python-version: [3.6, 3.7, 3.8, 3.9] - 
python-version: [3.6, 3.7, 3.8] + os: [ubuntu-18.04, macos-10.15] + python-version: [3.6, 3.7, 3.8, 3.9] # from https://download.pytorch.org/whl/torch_stable.html # - # PyTorch 1.9.0, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9 + # PyTorch 1.10, 1.9.x, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9 # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 # # Other PyTorch versions are not tested # - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - # exclude: - # - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - # torch: "1.5.0" - # - python-version: 3.9 - # torch: "1.5.1" - # - python-version: 3.9 - # torch: "1.6.0" - # - python-version: 3.9 - # torch: "1.7.0" + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + exclude: + - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] + torch: "1.5.0" + - python-version: 3.9 + torch: "1.5.1" + - python-version: 3.9 + torch: "1.6.0" + - python-version: 3.9 + torch: "1.7.0" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu.yml index 4a33a59f9..249fd215d 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu.yml @@ -39,9 +39,9 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0 + # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10 python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] + torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] exclude: - python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.4.0" diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 084c8569d..0cc6cb1be 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -1,6 +1,9 @@ name: nightly on: + push: + branches: + - nightly schedule: # minute (0-59) # hour (0-23) @@ -80,6 +83,10 @@ jobs: ./scripts/github_actions/install_torch.sh python3 -c "import torch; print('torch version:', torch.__version__)" + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Download cudnn 8.0 env: cuda: ${{ matrix.cuda }} diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 81a3ba933..dbe73a0c8 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -91,6 +91,10 @@ jobs: ./scripts/github_actions/install_torch.sh python3 -c "import torch; print('torch version:', torch.__version__)" + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Download cudnn 8.0 env: cuda: ${{ matrix.cuda }} diff --git a/.github/workflows/wheel-stable.yml b/.github/workflows/wheel-stable.yml index f8c52dfe4..f142c2910 100644 --- a/.github/workflows/wheel-stable.yml +++ b/.github/workflows/wheel-stable.yml @@ -70,6 +70,10 @@ jobs: ./scripts/github_actions/install_torch.sh python3 -c "import torch; print('torch version:', torch.__version__)" + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Download cudnn 8.0 env: cuda: ${{ matrix.cuda }} diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index becfa40fc..74c46595a 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -70,6 +70,10 @@ jobs: ./scripts/github_actions/install_torch.sh python3 
-c "import torch; print('torch version:', torch.__version__)" + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + - name: Download cudnn 8.0 env: cuda: ${{ matrix.cuda }} diff --git a/CMakeLists.txt b/CMakeLists.txt index b667178e8..918aa079f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,10 +203,17 @@ if(K2_WITH_CUDA) cuda_select_nvcc_arch_flags(K2_COMPUTE_ARCH_FLAGS) message(STATUS "K2_COMPUTE_ARCH_FLAGS: ${K2_COMPUTE_ARCH_FLAGS}") -# set(K2_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72) - message(WARNING "arch 62/72 are not supported for now") + # set(K2_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72) + # message(WARNING "arch 62/72 are not supported for now") + # see https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/ + # https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/ set(K2_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75) + if(CUDA_VERSION VERSION_GREATER "11.0") + list(APPEND K2_COMPUTE_ARCH_CANDIDATES 80 86) + endif() + message(STATUS "K2_COMPUTE_ARCH_CANDIDATES ${K2_COMPUTE_ARCH_CANDIDATES}") + set(K2_COMPUTE_ARCHS) foreach(COMPUTE_ARCH IN LISTS K2_COMPUTE_ARCH_CANDIDATES) diff --git a/scripts/github_actions/install_cuda.sh b/scripts/github_actions/install_cuda.sh index 7d023b960..358c05b45 100755 --- a/scripts/github_actions/install_cuda.sh +++ b/scripts/github_actions/install_cuda.sh @@ -36,6 +36,10 @@ case "$cuda" in # url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run ;; + 11.3) + # url=https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run + url=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run + ;; *) echo "Unknown cuda version: $cuda" exit 1 diff --git a/scripts/github_actions/install_cudnn.sh b/scripts/github_actions/install_cudnn.sh index 853eba568..ceead2d36 100755 --- a/scripts/github_actions/install_cudnn.sh +++ b/scripts/github_actions/install_cudnn.sh @@ -17,23 +17,21 @@ case $cuda in 10.0) filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz - url=http://www.mediafire.com/file/1037lb1vmj9qdtq/cudnn-10.0-linux-x64-v7.6.5.32.tgz/file ;; 10.1) filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz - url=http://www.mediafire.com/file/fnl2wg0h757qhd7/cudnn-10.1-linux-x64-v8.0.2.39.tgz/file ;; 10.2) filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz - url=http://www.mediafire.com/file/sc2nvbtyg0f7ien/cudnn-10.2-linux-x64-v8.0.2.39.tgz/file ;; 11.0) filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz - url=https://www.mediafire.com/file/abyhnls106ko9kp/cudnn-11.0-linux-x64-v8.0.5.39.tgz/file ;; 11.1) - filename=cudnn-11.1-linux-x64-v8.0.5.39.tgz - url=https://www.mediafire.com/file/qx55zd65773xonv/cudnn-11.1-linux-x64-v8.0.5.39.tgz/file + filename=cudnn-11.1-linux-x64-v8.0.4.30.tgz + ;; + 11.3) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz ;; *) echo "Unsupported cuda version: $cuda" @@ -41,18 +39,15 @@ case $cuda in ;; esac -function retry() { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} +command -v git-lfs >/dev/null 2>&1 || { echo >&2 "\nPlease install 'git-lfs' first."; exit 2; } -# It is forked from https://github.com/Juvenal-Yescas/mediafire-dl -# https://github.com/Juvenal-Yescas/mediafire-dl/pull/2 changes the filename and breaks the CI. -# We use a separate fork to keep the link fixed. 
-retry wget https://raw.githubusercontent.com/csukuangfj/mediafire-dl/master/mediafire_dl.py +git clone https://huggingface.co/csukuangfj/cudnn +cd cudnn +git lfs pull --include="$filename" -sed -i 's/quiet=False/quiet=True/' mediafire_dl.py -retry python3 mediafire_dl.py "$url" sudo tar xf ./$filename -C /usr/local -rm -v ./$filename + +# save disk space +git lfs prune && cd .. && rm -rf cudnn sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h diff --git a/scripts/github_actions/install_torch.sh b/scripts/github_actions/install_torch.sh index 3ad1717bc..b0b822a13 100755 --- a/scripts/github_actions/install_torch.sh +++ b/scripts/github_actions/install_torch.sh @@ -78,7 +78,7 @@ case ${torch} in ;; esac ;; - 1.9.0) + 1.9.*) case ${cuda} in 10.2) package="torch==${torch}" @@ -91,6 +91,23 @@ case ${torch} in ;; esac ;; + 1.10) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.1) + package="torch==${torch}+cu111" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.3) + package="torch==${torch}+cu113" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; *) echo "Unsupported PyTorch version: ${torch}" exit 1 From 7178d67e594bc7fa89c2b331ad7bd1c62a6a9eb4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Oct 2021 22:12:54 +0800 Subject: [PATCH 19/64] Fix test cases for k2.union() (#853) --- k2/python/k2/__init__.py | 2 +- k2/python/k2/autograd.py | 67 ----------------------------------- k2/python/k2/fsa_algo.py | 21 +++++++++++ k2/python/tests/union_test.py | 31 ++++++++++++++++ 4 files changed, 53 insertions(+), 68 deletions(-) diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index c60194d58..c3fdf0cbd 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -21,7 +21,6 @@ # from .autograd import intersect_dense from .autograd import intersect_dense_pruned -from .autograd import union from .ctc_loss import CtcLoss from .ctc_loss import ctc_loss from .dense_fsa_vec import DenseFsaVec @@ -51,6 +50,7 @@ from .fsa_algo import replace_fsa from .fsa_algo import shortest_path from .fsa_algo import top_sort +from .fsa_algo import union from .fsa_properties import to_str as properties_to_str from .nbest import Nbest from .ops import cat diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 96e6a9f72..2c12c8975 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -645,53 +645,6 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ ) -class _UnionFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, fsas: Fsa, out_fsa: List[Fsa], - unused_fsas_scores: torch.Tensor) -> torch.Tensor: - '''Compute the union of all fsas in a FsaVec. - - Args: - fsas: - The input FsaVec. Caution: We require that each fsa in the FsaVec - is non-empty (i.e., with at least two states). - out_fsa: - A list containing one entry. Since this function can only return - values of type `torch.Tensor`, we return the union result in the - list. - unused_fsas_scores: - It is the same as `fsas.scores`, whose sole purpose is for autograd. - It is not used in this function. 
- ''' - need_arc_map = True - ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map) - out_fsa[0] = Fsa(ragged_arc) - - for name, value in fsas.named_tensor_attr(include_scores=False): - value = k2.index(value, arc_map) - setattr(out_fsa[0], name, value) - - for name, value in fsas.named_non_tensor_attr(): - setattr(out_fsa[0], name, value) - ctx.arc_map = arc_map - ctx.save_for_backward(unused_fsas_scores) - - return out_fsa[0].scores # the return value will be discarded - - @staticmethod - def backward(ctx, out_fsa_grad: torch.Tensor - ) -> Tuple[None, None, torch.Tensor]: # noqa - arc_map = ctx.arc_map - fsas_scores, = ctx.saved_tensors - ans = torch.zeros(fsas_scores.size(0), - dtype=torch.float32, - device=fsas_scores.device, - requires_grad=False) - _k2.index_add(arc_map, out_fsa_grad, ans) - return None, None, ans - - def intersect_dense_pruned(a_fsas: Fsa, b_fsas: DenseFsaVec, search_beam: float, @@ -843,23 +796,3 @@ def intersect_dense(a_fsas: Fsa, a_fsas.scores, b_fsas.scores, a_to_b_map, seqframe_idx_name, frame_idx_name) return out_fsa[0] - - -def union(fsas: Fsa) -> Fsa: - '''Compute the union of a FsaVec. - - Caution: - We require that every fsa in fsas is non-empty, i.e., - contains at least two states - - Args: - fsas: - A FsaVec. That is, len(fsas.shape) == 3. - - Returns: - A single Fsa that is the union of the input fsas. - ''' - - out_fsa = [0] # as a placeholder - _UnionFunction.apply(fsas, out_fsa, fsas.scores) - return out_fsa[0] diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index f51477432..146e11cb9 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1179,3 +1179,24 @@ def levenshtein_alignment( alignment, "__ins_del_score_offset_internal_attr_") return alignment + + +def union(fsas: Fsa) -> Fsa: + '''Compute the union of a FsaVec. + + Caution: + We require that every fsa in fsas is non-empty, i.e., + contains at least two states + + Args: + fsas: + A FsaVec. That is, len(fsas.shape) == 3. + + Returns: + A single Fsa that is the union of the input fsas. 
+ ''' + need_arc_map = True + ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map) + + out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map) + return out_fsa diff --git a/k2/python/tests/union_test.py b/k2/python/tests/union_test.py index d608d0d16..86c6155b9 100644 --- a/k2/python/tests/union_test.py +++ b/k2/python/tests/union_test.py @@ -63,9 +63,40 @@ def test(self): fsa1 = k2.Fsa.from_str(s1) fsa2 = k2.Fsa.from_str(s2) + fsa0.tensor_attr = torch.tensor([1, 2, 3, 4, 5, 6], + dtype=torch.int32, + device=device) + fsa0.ragged_tensor_attr = k2.RaggedTensor( + fsa0.tensor_attr.unsqueeze(-1)) + + fsa1.tensor_attr = torch.tensor([7], + dtype=torch.int32, + device=device) + + fsa1.ragged_tensor_attr = k2.RaggedTensor( + fsa1.tensor_attr.unsqueeze(-1)) + + fsa2.tensor_attr = torch.tensor([8, 9, 10, 11], + dtype=torch.int32, + device=device) + + fsa2.ragged_tensor_attr = k2.RaggedTensor( + fsa2.tensor_attr.unsqueeze(-1)) + fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2]).to(device) fsa = k2.union(fsa_vec) + + expected_tensor_attr = torch.tensor( + [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11]).to(fsa.tensor_attr) + assert torch.all(torch.eq(fsa.tensor_attr, expected_tensor_attr)) + + expected_ragged_tensor_attr = k2.RaggedTensor( + expected_tensor_attr.unsqueeze(-1)).remove_values_eq(0) + assert str(expected_ragged_tensor_attr) == str( + fsa.ragged_tensor_attr) + assert torch.allclose( fsa.arcs.values()[:, :3], torch.tensor([ From e6db5dcc0017d2788cbf5d57b7967d9f46c73601 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 2 Nov 2021 21:23:57 +0800 Subject: [PATCH 20/64] Fix out-of-boundary access (read). (#859) --- CMakeLists.txt | 2 +- k2/csrc/ragged_utils.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 918aa079f..47e0355c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.9") +set(K2_VERSION "1.10") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) diff --git a/k2/csrc/ragged_utils.cu b/k2/csrc/ragged_utils.cu index db48acab6..e038357dc 100644 --- a/k2/csrc/ragged_utils.cu +++ b/k2/csrc/ragged_utils.cu @@ -156,7 +156,7 @@ RaggedShape IntersperseRaggedLayer(int32_t layer, row_splits_ptrs_data[src][pos]; row_splits_data[i] = this_size; }; - EvalDevice(c, new_num_rows + 1, lambda_get_sizes); + EvalDevice(c, new_num_rows, lambda_get_sizes); ExclusiveSum(row_splits, &row_splits); } } From e8c589a47e8acb19c9c08df5ceaba7d19d78238e Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Thu, 4 Nov 2021 22:10:40 +0800 Subject: [PATCH 21/64] Update all the example codes in the docs (#861) * Update all the example codes in the docs I have run all the modified codes with the newest version k2. 
* do some changes --- docs/source/core_concepts/index.rst | 14 +++++++------- .../python_tutorials/fsa_algo/code/invert1.py | 2 +- .../python_tutorials/fsa_algo/code/invert2.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/core_concepts/index.rst b/docs/source/core_concepts/index.rst index 1fde45269..e67a3d569 100644 --- a/docs/source/core_concepts/index.rst +++ b/docs/source/core_concepts/index.rst @@ -193,13 +193,13 @@ In k2, you would use the following code to compute it: fsa = k2.Fsa.from_str(s) fsa.draw('fsa2.svg') fsa = k2.create_fsa_vec([fsa]) - total_scores = k2.get_tot_scores(fsa, log_semiring=False, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=False, use_double_scores=False) print(total_scores) # It prints: tensor([0.2000]) .. HINT:: - :func:`k2.get_tot_scores` takes a vector of FSAs as input, + :func:`k2.Fsa.get_tot_scores` takes a vector of FSAs as input, so we use :func:`k2.create_fsa_vec` to turn an FSA into a vector of FSAs. Most operations in k2 take a vector of FSAs as input and process them @@ -230,7 +230,7 @@ The code in k2 looks like: ''' fsa = k2.Fsa.from_str(s) fsa = k2.create_fsa_vec([fsa]) - total_scores = k2.get_tot_scores(fsa, log_semiring=True, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=True, use_double_scores=False) print(total_scores) # It prints: tensor([0.8444]) @@ -319,7 +319,7 @@ the FSA given in :numref:`autograd example`: fsa.scores = nnet_output fsa.draw('autograd_tropical.svg') fsa_vec = k2.create_fsa_vec([fsa]) - total_scores = k2.get_tot_scores(fsa_vec, log_semiring=False, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=False, use_double_scores=False) total_scores.backward() print(nnet_output.grad) @@ -366,11 +366,11 @@ Example 2: Autograd in log semiring For the log semiring, we just change:: - total_scores = k2.get_tot_scores(fsa_vec, log_semiring=False, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=False, use_double_scores=False) to:: - total_scores = k2.get_tot_scores(fsa_vec, log_semiring=True, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=True, use_double_scores=False) For completeness and ease of reference, we repost the code below. @@ -392,7 +392,7 @@ For completeness and ease of reference, we repost the code below. 
fsa.scores = nnet_output fsa.draw('autograd_log.svg') fsa_vec = k2.create_fsa_vec([fsa]) - total_scores = k2.get_tot_scores(fsa_vec, log_semiring=True, use_double_scores=False) + total_scores = fsa.get_tot_scores(log_semiring=True, use_double_scores=False) total_scores.backward() print(nnet_output.grad) diff --git a/docs/source/python_tutorials/fsa_algo/code/invert1.py b/docs/source/python_tutorials/fsa_algo/code/invert1.py index c7333c3f6..4d045034c 100644 --- a/docs/source/python_tutorials/fsa_algo/code/invert1.py +++ b/docs/source/python_tutorials/fsa_algo/code/invert1.py @@ -4,7 +4,7 @@ 1 2 -1 -1 0.2 2 ''' -fsa = k2.Fsa.from_str(s) +fsa = k2.Fsa.from_str(s, acceptor=False) inverted_fsa = k2.invert(fsa) fsa.draw('before_invert.svg', title='before invert') inverted_fsa.draw('after_invert.svg', title='after invert') diff --git a/docs/source/python_tutorials/fsa_algo/code/invert2.py b/docs/source/python_tutorials/fsa_algo/code/invert2.py index 86e80b30d..fb96ea655 100644 --- a/docs/source/python_tutorials/fsa_algo/code/invert2.py +++ b/docs/source/python_tutorials/fsa_algo/code/invert2.py @@ -5,7 +5,7 @@ 2 ''' fsa = k2.Fsa.from_str(s) -fsa.aux_labels = k2.RaggedInt('[ [10 20] [-1] ]') +fsa.aux_labels = k2.RaggedTensor('[ [10 20] [-1] ]') inverted_fsa = k2.invert(fsa) fsa.draw('before_invert_aux.svg', title='before invert with ragged tensors as aux_labels') From fd5565d32ffa8274ff9700453b1e543f34343ed1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Nov 2021 21:31:12 +0800 Subject: [PATCH 22/64] Fix compilation errors with CUB 1.15. (#865) --- cmake/cub.cmake | 5 +++-- k2/csrc/cub.h | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/cmake/cub.cmake b/cmake/cub.cmake index c65f8f6d2..c4d6e1c94 100644 --- a/cmake/cub.cmake +++ b/cmake/cub.cmake @@ -20,8 +20,9 @@ function(download_cub) include(FetchContent) - set(cub_URL "https://github.com/NVlabs/cub/archive/1.10.0.tar.gz") - set(cub_HASH "SHA256=8531e09f909aa021125cffa70a250761dfc247f960d7a1a12f65e6651ffb6477") + set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz") + set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685") + FetchContent_Declare(cub URL ${cub_URL} diff --git a/k2/csrc/cub.h b/k2/csrc/cub.h index 55331f8b2..d1df56f32 100644 --- a/k2/csrc/cub.h +++ b/k2/csrc/cub.h @@ -30,14 +30,40 @@ // that k2 and PyTorch use a different copy // of CUB. +#ifdef CUB_NS_PREFIX +#undef CUB_NS_PREFIX +#endif + +#ifdef CUB_NS_POSTFIX +#undef CUB_NS_POSTFIX +#endif + +#ifdef CUB_NS_QUALIFIER +#undef CUB_NS_QUALIFIER +#endif + +// see +// https://github.com/NVIDIA/cub/commit/6631c72630f10e370d93814a59146b12f7620d85 +// The above commit replaced "thrust" with "THRUST_NS_QUALIFIER" +#ifndef THRUST_NS_QUALIFIER +#define THRUST_NS_QUALIFIER thrust +#endif + #define CUB_NS_PREFIX namespace k2 { #define CUB_NS_POSTFIX } +// See +// https://github.com/NVIDIA/cub/commit/6631c72630f10e370d93814a59146b12f7620d85 +// and +// https://github.com/NVIDIA/cub/pull/350 +#define CUB_NS_QUALIFIER ::k2::cub + #ifdef K2_WITH_CUDA #include "cub/cub.cuh" // NOLINT #endif #undef CUB_NS_PREFIX #undef CUB_NS_POSTFIX +#undef CUB_NS_QUALIFIER #endif // K2_CSRC_CUB_H_ From bdcaaf828212913366110a0c45aab4ea40b1cbf9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 12 Nov 2021 15:41:51 +0800 Subject: [PATCH 23/64] Update README. (#873) * Update README. * Fix typos. 
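For readers who want a minimal smoke test of the quick start that the updated
README points to, here is an illustrative snippet, adapted from the examples
in docs/source/core_concepts updated in an earlier commit (it assumes k2 is
already installed):

  import k2

  s = '''
  0 1 -1 0.5
  1
  '''
  fsa = k2.Fsa.from_str(s)
  fsa_vec = k2.create_fsa_vec([fsa])
  # Total score of the best path in the tropical semiring;
  # for this one-arc FSA it should print tensor([0.5000]).
  print(fsa_vec.get_tot_scores(log_semiring=False, use_double_scores=False))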
--- README.md | 45 ++++++++++----------------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index c7768d163..f554ca2cf 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ speech recognition system with multiple decoding passes including lattice rescoring and confidence estimation. We hope k2 will have many other applications as well. -One of the key algorithms that we want to make efficient in the short term is +One of the key algorithms that we have implemented is pruned composition of a generic FSA with a "dense" FSA (i.e. one that corresponds to log-probs of symbols at the output of a neural network). This can be used as a fast implementation of decoding for ASR, and for CTC and @@ -78,46 +78,21 @@ general and extensible framework to allow further development of ASR technology. ## Current state of the code - A lot of the code is still unfinished (Sep 11, 2020). - We finished the CPU versions of many algorithms and this code is in `k2/csrc/host/`; - however, after that we figured out how to implement things on the GPU and decided - to change the interfaces so the CPU and GPU code had a more unified interface. - Currently in `k2/csrc/` we have more GPU-oriented implementations (although - these algorithms will also work on CPU). We had almost finished the Python - wrapping for the older code, in the `k2/python/` subdirectory, but we decided not to - release code with that wrapping because it would have had to be reworked to be compatible - with our GPU algorithms. Instead we will use the interfaces drafted in `k2/csrc/` - e.g. the Context object (which encapsulates things like memory managers from external - toolkits) and the Tensor object which can be used to wrap tensors from external toolkits; - and wrap those in Python (using pybind11). The code in host/ will eventually - be either deprecated, rewritten or wrapped with newer-style interfaces. - -## Plans for initial release - - We hope to get the first version working in early October. The current - short-term aim is to finish the GPU implementation of pruned composition of a - normal FSA with a dense FSA, which is the same as decoder search in speech - recognition and can be used to implement CTC training and lattice-free MMI (LF-MMI) training. The - proof-of-concept that we will release initially is something that's like CTC - but allowing more general supervisions (general FSAs rather than linear - sequences). This will work on GPU. The same underlying code will support - LF-MMI so that would be easy to implement soon after. We plan to put - example code in a separate repository. + We have wrapped all the C++ code to Python with [pybind11](https://github.com/pybind/pybind11) + and have finished the integration with [PyTorch](https://github.com/pytorch/pytorch). + + We are currently writing speech recognition recipes using k2, which are hosted in a + separate repository. Please see . ## Plans after initial release - We will then gradually implement more algorithms in a way that's compatible - with the interfaces in `k2/csrc/`. Some of them will be CPU-only to start - with. The idea is to eventually have very rich capabilities for operating on - collections of sequences, including methods to convert from a lattice to a - collection of linear sequences and back again (for purposes of neural language - model rescoring, neural confidence estimation and the like). 
+ We are currently trying to make k2 ready for production use (see the branch + [v2.0-pre](https://github.com/k2-fsa/k2/tree/v2.0-pre)). ## Quick start Want to try it out without installing anything? We have setup a [Google Colab][1]. - -Caution: k2 is not nearly ready for actual use! We are still coding the core -algorithms, and hope to have an early version working by early October. +You can find more Colab notebooks using k2 in speech recognition at +. [1]: https://colab.research.google.com/drive/1qbHUhNZUX7AYEpqnZyf29Lrz2IPHBGlX?usp=sharing From 31e1307c3131d03384ef83c33472309d22d2e79d Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 19 Nov 2021 14:47:09 +0800 Subject: [PATCH 24/64] Fix ctc graph (make aux_labels of final arcs -1) (#877) --- k2/csrc/fsa_algo.cu | 6 +++--- k2/csrc/fsa_algo_test.cu | 8 ++++---- k2/python/tests/ctc_graph_test.py | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 9cb3ae023..c7106bc0c 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -709,7 +709,7 @@ FsaVec CtcGraphs(const Ragged &symbols, bool modified /*= false*/, case 2: // the arc pointing to the next symbol state arc.label = next_symbol; aux_labels_value = sym_state_idx01 + 1 == sym_final_state ? - 0 : next_symbol; + -1 : next_symbol; arc.dest_state = state_idx1 + 2; break; default: @@ -720,8 +720,8 @@ FsaVec CtcGraphs(const Ragged &symbols, bool modified /*= false*/, K2_CHECK_LT(arc_idx2, 2); arc.label = arc_idx2 == 0 ? 0 : current_symbol; arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1; - aux_labels_value = (arc_idx2 == 0 || final_state) ? - 0 : current_symbol; + aux_labels_value = arc_idx2 == 0 ? 0 : current_symbol; + if (final_state && arc_idx2 != 0) aux_labels_value = -1; } arcs_data[arc_idx012] = arc; if (aux_labels) aux_labels_data[arc_idx012] = aux_labels_value; diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 38970edf2..903514318 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1293,8 +1293,8 @@ TEST(FsaAlgo, TestCtcGraph) { " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); Array1 aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 0 2 0 0 3 " - " 0 3 0 0 0 0 0 0 1 0 0 2 0 2 " - " 0 0 3 0 3 0 0 0 0 0 ]"); + " 0 3 0 0 -1 0 -1 0 1 0 0 2 0 2 " + " 0 0 3 0 3 0 0 -1 0 -1 ]"); K2_CHECK(Equal(graph, graph_ref)); K2_CHECK(Equal(aux_labels, aux_labels_ref)); } @@ -1315,8 +1315,8 @@ TEST(FsaAlgo, TestCtcGraphSimplified) { " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); Array1 aux_labels_ref(c, "[ 0 1 0 0 2 0 2 0 0 2 0 2 0 " - " 0 3 0 3 0 0 0 0 0 0 1 0 0 2 " - " 0 2 0 0 3 0 3 0 0 0 0 0 ]"); + " 0 3 0 3 0 0 -1 0 -1 0 1 0 0 2 " + " 0 2 0 0 3 0 3 0 0 -1 0 -1 ]"); K2_CHECK(Equal(graph, graph_ref)); K2_CHECK(Equal(aux_labels, aux_labels_ref)); } diff --git a/k2/python/tests/ctc_graph_test.py b/k2/python/tests/ctc_graph_test.py index 569c1d8da..285a55989 100644 --- a/k2/python/tests/ctc_graph_test.py +++ b/k2/python/tests/ctc_graph_test.py @@ -50,14 +50,14 @@ def test(self): '0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', '4 4 0 0 0', '4 5 2 2 0', '5 6 0 0 0', - '5 5 2 0 0', '5 7 -1 0 0', '6 6 0 0 0', '6 7 -1 0 0', '7' + '5 5 2 0 0', '5 7 -1 -1 0', '6 6 0 0 0', '6 7 -1 -1 0', '7' ]) expected_str1 = '\n'.join([ '0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', '3 5 
3 3 0', '4 4 0 0 0', '4 5 3 3 0', - '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', '6 6 0 0 0', - '6 7 -1 0 0', '7' + '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 -1 0', '6 6 0 0 0', + '6 7 -1 -1 0', '7' ]) actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu')) actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu')) @@ -81,15 +81,15 @@ def test_simplified(self): '0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', '3 5 2 2 0', '4 4 0 0 0', '4 5 2 2 0', - '5 6 0 0 0', '5 5 2 0 0', '5 7 -1 0 0', '6 6 0 0 0', - '6 7 -1 0 0', '7' + '5 6 0 0 0', '5 5 2 0 0', '5 7 -1 -1 0', '6 6 0 0 0', + '6 7 -1 -1 0', '7' ]) expected_str1 = '\n'.join([ '0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0', - '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', '6 6 0 0 0', - '6 7 -1 0 0', '7' + '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 -1 0', '6 6 0 0 0', + '6 7 -1 -1 0', '7' ]) actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu')) actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu')) From 12f591526a7afa3a3d55cc37c67576328cf926e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ludwig=20K=C3=BCrzinger?= Date: Thu, 25 Nov 2021 00:54:45 +0100 Subject: [PATCH 25/64] Fix LICENSE location to k2 folder (#880) --- MANIFEST.in | 2 ++ setup.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..19ba5c47a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE + diff --git a/setup.py b/setup.py index c45995b27..c0c4f52e0 100644 --- a/setup.py +++ b/setup.py @@ -215,7 +215,6 @@ def get_short_description(): packages=['k2', 'k2.ragged', 'k2.sparse', 'k2.version'], install_requires=install_requires, extras_require={'dev': dev_requirements}, - data_files=[('', ['LICENSE'])], ext_modules=[cmake_extension('_k2')], cmdclass={ 'build_ext': BuildExtension, From a0d75c8222a768adc1c68164f7c0730b621fff24 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 29 Nov 2021 19:00:22 +0800 Subject: [PATCH 26/64] Release v1.11. (#881) It contains bugfixes. --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 47e0355c0..9f90d1d10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.10") +set(K2_VERSION "1.11") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) From 2cb3eeaa49b61d23f2917e5cb23ca03ef8ca709d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Dec 2021 18:00:48 +0800 Subject: [PATCH 27/64] Update documentation for hash.h (#887) * Update documentation for hash.h * Typo fix --- k2/csrc/hash.h | 63 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/k2/csrc/hash.h b/k2/csrc/hash.h index 48459accb..ae651c6de 100644 --- a/k2/csrc/hash.h +++ b/k2/csrc/hash.h @@ -55,14 +55,43 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS( How class Hash works: - It can function as a map from key=uint32_t to value=uint32_t, or from - key=uint64_t to value=uint64_t where you choose NUM_KEY_BITS and - `key` must have only up to NUM_KEY_BITS set and `value` must have - only up to (64-NUM_KEY_BITS) set. 
You decide NUM_KEY_BITS when - you call Hash::Accessor() - - You can store any (key,value) pair except the pair where all the bits of + key=uint64_t to value=uint64_t, but you cannot use all 64 bits in the + key and value because we compress both of them into a single 64-bit + integer. There are several different modes of using this hash, + depending which accessor objects you use. The modes are: + + - Use Accessor with num_key_bits known at compile time; + the number of values bits will be 64 - NUM_KEY_BITS. + - Use GenericAccessor, which is like Accessor but the number of + key bits is not known at compile time; and they both must still + sum to 64. + - Use PackedAccessor, which allows you to have the number of key + plus value bits greater than 64; the rest of the bits are + implicit in groups of buckets (the number of buckets must + be >= 32 * 1 << (num_key_bits + num_value_bits - 64). + + - You must decide the number of key and value bits, and the number of + buckets, when you create the hash, but you can resize it (manually) + and when you resize it you can change the number of key and value bits. + + Some constraints: + - You can store any (key,value) pair allowed by the number of key and value + bits, except the pair where all the bits of both and key and value are set [that is used to mean "nothing here"] - - The number of buckets is a power of 2 provided by the user to the constructor; - currently no resizing is supported. + - The number of buckets must always be a power of 2. + - When deleting values from the hash you must delete them all at + once (necessary because there is no concept of a "tombstone". + + Some notes on usage: + + You use it by: constructing it, obtaining its Accessor with GetAccessor() + with appropriate template args depending on your chosen accessor type; and + inside kernels (or host code), calling functions Insert(), Find() or Delete() + of the Accessor object. Resizing is not automatic; it is the user's + responsibility to make sure the hash does not get too full (which could cause + assertion failures in kernels, and will be very slow). + + Some implementation notes: - When accessing hash[key], we use bucket_index == key % num_buckets, bucket_inc = 1 | (((key * 2) / num_buckets) ^ key). - If the bucket at `bucket_index` is occupied, we look in locations @@ -72,15 +101,7 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS( being odd ensures we eventually try all locations (of course for reasonable hash occupancy levels, we shouldn't ever have to try more than two or three). - - When deleting values from the hash you must delete them all at - once (necessary because there is no concept of a "tombstone". - You use it by: constructing it, obtaining its Accessor with - GetAccessor(), and inside kernels (or host code), calling - functions Insert(), Find() or Delete() of the Accessor object. There is no - resizing; sizing it correctly is the caller's responsibility and if the hash - gets full the code will just loop forever (of course it will get extremely - slow before it reaches that point). */ class Hash { public: @@ -94,10 +115,14 @@ class Hash { @param [in] num_key_bits Number of bits in the key of the hash; must satisfy 0 < num_key_bits < 64, and keys used must be less than (1< or GenericAccessor, + we require that num_key_bits + num_value_bits == 64. 
+ For PackedAccessor we allow that num_key_bits + num_value_bits > 64, + but with the constraint that + (num_buckets >> (64 - num_key_bits - num_value_bits)) >= 32 */ Hash(ContextPtr c, int32_t num_buckets, From aab2dd77d1fa8e31b153c9ff95e10295811a8b12 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 14 Dec 2021 12:34:46 +0800 Subject: [PATCH 28/64] Wrap MonotonicLowerBound (#883) * Wrap MonotonicLowerBound * Add unit tests * Support int64; update documents --- k2/csrc/dtype.h | 32 ++++++++++++ k2/python/csrc/torch.cu | 2 + k2/python/csrc/torch/CMakeLists.txt | 1 + k2/python/csrc/torch/array_ops.cu | 73 ++++++++++++++++++++++++++ k2/python/csrc/torch/array_ops.h | 30 +++++++++++ k2/python/k2/__init__.py | 1 + k2/python/k2/utils.py | 68 +++++++++++++++++++++++- k2/python/tests/CMakeLists.txt | 1 + k2/python/tests/array_ops_test.py | 81 +++++++++++++++++++++++++++++ 9 files changed, 287 insertions(+), 2 deletions(-) create mode 100644 k2/python/csrc/torch/array_ops.cu create mode 100644 k2/python/csrc/torch/array_ops.h create mode 100644 k2/python/tests/array_ops_test.py diff --git a/k2/csrc/dtype.h b/k2/csrc/dtype.h index 7f532516b..a36f75fcf 100644 --- a/k2/csrc/dtype.h +++ b/k2/csrc/dtype.h @@ -277,6 +277,38 @@ struct DtypeOf { } \ } while (0) +#define FOR_REAL_AND_INT_TYPES(DtypeValue, TypeName, ...) \ + do { \ + switch (DtypeValue) { \ + case kFloatDtype: { \ + using TypeName = float; \ + __VA_ARGS__; \ + break; \ + } \ + case kDoubleDtype: { \ + using TypeName = double; \ + __VA_ARGS__; \ + break; \ + } \ + case kInt32Dtype: { \ + using TypeName = int32_t; \ + __VA_ARGS__; \ + break; \ + } \ + case kInt64Dtype: { \ + using TypeName = int64_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + K2_LOG(FATAL) \ + << "Dtype " << TraitsOf(DtypeValue).Name() \ + << " not covered in switch statement. Op not supported for " \ + "this type?"; \ + break; \ + } \ + } while (0) + #define FOR_REAL_TYPES(DtypeValue, TypeName, ...) \ do { \ switch (DtypeValue) { \ diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index 4bf281e8b..ab941ff23 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -25,6 +25,7 @@ #if defined(K2_USE_PYTORCH) #include "k2/python/csrc/torch/arc.h" +#include "k2/python/csrc/torch/array_ops.h" #include "k2/python/csrc/torch/discounted_cum_sum.h" #include "k2/python/csrc/torch/fsa.h" #include "k2/python/csrc/torch/fsa_algo.h" @@ -37,6 +38,7 @@ void PybindTorch(py::module &m) { PybindArc(m); + PybindArrayOps(m); PybindDiscountedCumSum(m); PybindFsa(m); PybindFsaAlgo(m); diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 02783920b..4213807c7 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -1,6 +1,7 @@ # please keep the list sorted set(torch_srcs arc.cu + array_ops.cu discounted_cum_sum.cu fsa.cu fsa_algo.cu diff --git a/k2/python/csrc/torch/array_ops.cu b/k2/python/csrc/torch/array_ops.cu new file mode 100644 index 000000000..1042adf41 --- /dev/null +++ b/k2/python/csrc/torch/array_ops.cu @@ -0,0 +1,73 @@ +/** + * @brief python wrappers for array_ops.h + * + * @copyright + * Copyright 2021 Xiaomi Corp. (authors: Wei Kang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/array_ops.h" +#include "k2/csrc/device_guard.h" +#include "k2/python/csrc/torch/array_ops.h" +#include "k2/python/csrc/torch/torch_util.h" + +namespace k2 { + +static void PybindMonotonicLowerBound(py::module &m) { + m.def( + "monotonic_lower_bound", + [](torch::Tensor src, bool inplace = false) -> torch::Tensor { + Dtype t = ScalarTypeToDtype(src.scalar_type()); + ContextPtr c = GetContext(src); + DeviceGuard guard(c); + FOR_REAL_AND_INT_TYPES(t, T, { + if (src.dim() == 1) { + Array1 src_array = FromTorch(src); + Array1 dest_array = src_array; + if (!inplace) { + dest_array = Array1(c, src_array.Dim()); + } + MonotonicLowerBound(src_array, &dest_array); + return ToTorch(dest_array); + } else if (src.dim() == 2) { + Array2 src_array = FromTorch(src, Array2Tag{}); + Array2 dest_array = src_array; + if (!inplace) { + dest_array = Array2(c, src_array.Dim0(), src_array.Dim1()); + } + for (int32_t i = 0; i < src_array.Dim0(); ++i) { + Array1 row = dest_array.Row(i); + MonotonicLowerBound(src_array.Row(i), &row); + } + return ToTorch(dest_array); + } else { + K2_LOG(FATAL) + << "Only support 1 dimension and 2 dimensions tensor, given " + "dimension : " + << src.dim(); + return torch::Tensor(); + } + }); + // Unreachable code, to make compiler happy + return torch::Tensor(); + }, + py::arg("src"), py::arg("inplace") = false); +} + +} // namespace k2 + +void PybindArrayOps(py::module &m) { k2::PybindMonotonicLowerBound(m); } diff --git a/k2/python/csrc/torch/array_ops.h b/k2/python/csrc/torch/array_ops.h new file mode 100644 index 000000000..8af3cec38 --- /dev/null +++ b/k2/python/csrc/torch/array_ops.h @@ -0,0 +1,30 @@ +/** + * @brief python wrappers for array_ops.h + * + * @copyright + * Copyright 2021 Xiaomi Corp. (author: Wei Kang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_ARRAY_OPS_H_
+#define K2_PYTHON_CSRC_TORCH_ARRAY_OPS_H_
+
+#include "k2/python/csrc/torch.h"
+
+void PybindArrayOps(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_ARRAY_OPS_H_
diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py
index c3fdf0cbd..bff893967 100644
--- a/k2/python/k2/__init__.py
+++ b/k2/python/k2/__init__.py
@@ -64,6 +64,7 @@ from .utils import create_sparse
 from .utils import is_rand_equivalent
 from .utils import get_best_matching_stats
+from .utils import monotonic_lower_bound
 from .utils import to_dot
 from .utils import to_str
 from .utils import to_str_simple
diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py
index d24780240..abc28c21d 100644
--- a/k2/python/k2/utils.py
+++ b/k2/python/k2/utils.py
@@ -694,8 +694,8 @@ def random_fsa_vec(min_num_fsas: int = 1,


 def get_best_matching_stats(
-    tokens: k2.RaggedTensor, scores: torch.Tensor, counts: torch.Tensor,
-    eos: int, min_token: int, max_token: int, max_order: int
+        tokens: k2.RaggedTensor, scores: torch.Tensor, counts: torch.Tensor,
+        eos: int, min_token: int, max_token: int, max_order: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:  # noqa
     '''For "query" sentences, this function gets the mean and variance of
     scores from the best matching words-in-context in a set of provided "key"
@@ -781,3 +781,67 @@ def get_best_matching_stats(
     '''
     return _k2.get_best_matching_stats(tokens, scores, counts, eos,
                                        min_token, max_token, max_order)
+
+
+def monotonic_lower_bound(src: torch.Tensor,
+                          inplace: bool = False) -> torch.Tensor:
+    """Compute a monotonically increasing lower bound on the array `src`.
+    The basic idea is: we traverse the array in reverse order, and update
+    the current element with the following statement,
+
+      min_value = min(src[i], min_value)
+      dest[i] = min_value
+
+    we initialize the min_value with `inf`, so the last element always stays
+    the same. See the examples below: if the input tensor is
+    `[0, 2, 1, 3, 6, 5, 8]`, the output tensor will be `[0, 1, 1, 3, 5, 5, 8]`,
+    i.e. we traverse it in reverse order and guarantee that
+    `dest[i] <= dest[i+1]`.
+
+    Note: Only 1-D and 2-D tensors are supported, with dtype equal to
+    `torch.int32`, `torch.int64`, `torch.float` or `torch.float64`.
+
+    >>> import k2
+    >>> import torch
+    >>> src = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> k2.monotonic_lower_bound(src)
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> src
+    tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> k2.monotonic_lower_bound(src, inplace=True)
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> src
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> src = torch.randint(20, (3, 6), dtype=torch.int32)
+    >>> src
+    tensor([[12, 18,  5,  4, 18, 17],
+            [11, 14, 14,  3, 10,  4],
+            [19,  3,  8, 13,  7, 19]], dtype=torch.int32)
+    >>> k2.monotonic_lower_bound(src)
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+    >>> k2.monotonic_lower_bound(src, inplace=True)
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+    >>> src
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+
+    Args:
+      src:
+        The source tensor. MUST be a 1-D or 2-D tensor with
+        dtype equal to `torch.int32`, `torch.int64`, `torch.float` or
+        `torch.float64`.
+      inplace:
+        True to modify the source tensor in place, False to return another
+        tensor.
+
+    Returns:
+      Returns a tensor which is monotonic (i.e. satisfies
+      `dest[i] <= dest[i+1]`); the returned tensor shares the same underlying
+      memory with the source tensor if inplace is True.
+    """
+    return _k2.monotonic_lower_bound(src, inplace)
diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt
index 8feb2dfc8..56fb59114 100644
--- a/k2/python/tests/CMakeLists.txt
+++ b/k2/python/tests/CMakeLists.txt
@@ -19,6 +19,7 @@ endfunction()
 set(py_test_files
   add_epsilon_self_loops_test.py
   arc_sort_test.py
+  array_ops_test.py
   cat_test.py
   compose_arc_maps_test.py
   closure_test.py
diff --git a/k2/python/tests/array_ops_test.py b/k2/python/tests/array_ops_test.py
new file mode 100644
index 000000000..60e06af5f
--- /dev/null
+++ b/k2/python/tests/array_ops_test.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+#
+# Copyright      2021  Xiaomi Corporation (author: Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To run this single test, use
+#
+#  ctest --verbose -R array_ops_test_py
+
+import unittest
+
+import random
+import torch
+import k2
+
+
+class TestArrayOps(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.devices = [torch.device("cpu")]
+        if torch.cuda.is_available() and k2.with_cuda:
+            cls.devices.append(torch.device("cuda", 0))
+            if torch.cuda.device_count() > 1:
+                torch.cuda.set_device(1)
+                cls.devices.append(torch.device("cuda", 1))
+
+        cls.dtypes = [torch.int32, torch.int64, torch.float32, torch.float64]
+
+    def test_monotonic_lower_bound(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                # simple case
+                src = torch.tensor(
+                    [2, 1, 3, 7, 5, 8, 20, 15], dtype=dtype, device=device
+                )
+                expected = torch.tensor(
+                    [1, 1, 3, 5, 5, 8, 15, 15], dtype=dtype, device=device
+                )
+                dest = k2.monotonic_lower_bound(src)
+                assert torch.allclose(dest, expected)
+                assert torch.allclose(
+                    src,
+                    torch.tensor(
+                        [2, 1, 3, 7, 5, 8, 20, 15], dtype=dtype, device=device
+                    ),
+                )
+                k2.monotonic_lower_bound(src, inplace=True)
+                assert torch.allclose(src, expected)
+
+                # random case
+                src = torch.randint(100, (10, 100), dtype=dtype, device=device)
+                dest = k2.monotonic_lower_bound(src)
+                expected = torch.zeros_like(src, device=torch.device("cpu"))
+                dest = dest.to("cpu")
+                for i in range(src.shape[0]):
+                    min_value = 101
+                    for j in range(src.shape[1] - 1, -1, -1):
+                        min_value = min(dest[i][j], min_value)
+                        expected[i][j] = min_value
+                assert torch.allclose(dest, expected)
+
+                k2.monotonic_lower_bound(src, inplace=True)
+                src = src.to("cpu")
+                assert torch.allclose(src, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 5517b3e99fcc544016289c26335b845f37dc26f9 Mon Sep 17 00:00:00 2001
From: drawfish
Date: Sat, 25 Dec 2021 08:31:52 +0800
Subject: [PATCH 29/64] Remove extra commas after 'TOPSORTED' property and
 fix RaggedTensor constructor parameter 'byte_offset' out-of-range bug.
(#892)

Co-authored-by: gzchenduisheng
---
 k2/csrc/tensor.cu              | 2 +-
 k2/csrc/tensor.h               | 2 +-
 k2/python/k2/fsa_properties.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/k2/csrc/tensor.cu b/k2/csrc/tensor.cu
index 27fcd1c6b..e4dd4fb69 100644
--- a/k2/csrc/tensor.cu
+++ b/k2/csrc/tensor.cu
@@ -147,7 +147,7 @@ Tensor::Tensor(ContextPtr c, Dtype type, const std::vector<int32_t> &dims)
 }
 
 Tensor::Tensor(Dtype type, const Shape &shape, RegionPtr region,
-               int32_t byte_offset)
+               size_t byte_offset)
     : impl_(std::make_shared<TensorImpl>()) {
   int64_t begin_elem, end_elem;
   shape.GetReachableElems(&begin_elem, &end_elem);
diff --git a/k2/csrc/tensor.h b/k2/csrc/tensor.h
index 8fbb28df2..726cab209 100644
--- a/k2/csrc/tensor.h
+++ b/k2/csrc/tensor.h
@@ -176,7 +176,7 @@ class Tensor {
   Tensor(ContextPtr c, Dtype type, const std::vector<int32_t> &dims);
 
   // Create Tensor backed by existing memory.
-  Tensor(Dtype type, const Shape &shape, RegionPtr region, int32_t byte_offset);
+  Tensor(Dtype type, const Shape &shape, RegionPtr region, size_t byte_offset);
 
   Tensor(const Tensor &other) = default;
   Tensor &operator=(const Tensor &other) = default;
diff --git a/k2/python/k2/fsa_properties.py b/k2/python/k2/fsa_properties.py
index cdd9db8e2..7b99aebee 100644
--- a/k2/python/k2/fsa_properties.py
+++ b/k2/python/k2/fsa_properties.py
@@ -23,7 +23,7 @@
 VALID = 0x01  # Valid from a formatting perspective
 NONEMPTY = 0x02  # Nonempty as in, has at least one arc.
-TOPSORTED = 0x04,  # FSA is top-sorted, but possibly with
+TOPSORTED = 0x04  # FSA is top-sorted, but possibly with
                   # self-loops, dest_state >= src_state
 TOPSORTED_AND_ACYCLIC = 0x08  # Fsa is topsorted, dest_state > src_state
 ARC_SORTED = 0x10  # Fsa is arc-sorted: arcs leaving a state are sorted by

From 5f4cc790d319e93589cbc8a50c3f3d475a504871 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 6 Jan 2022 12:08:33 +0800
Subject: [PATCH 30/64] Fix small typos (#896)

---
 k2/csrc/intersect_dense_pruned.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu
index 61b0ea9c0..506396dcc 100644
--- a/k2/csrc/intersect_dense_pruned.cu
+++ b/k2/csrc/intersect_dense_pruned.cu
@@ -34,8 +34,8 @@ namespace intersect_pruned_internal {
 
 /* Information associated with a state active on a particular frame..  */
 struct StateInfo {
-  /* abs_state_id is the state-index in a_fsas_.  Note: the ind0 in here
-     won't necessarily match the ind0 within FrameInfo::state if
+  /* abs_state_id is the state-index in a_fsas_.  Note: the idx0 in here
+     won't necessarily match the idx0 within FrameInfo::state if
      a_fsas_stride_ == 0. */
   int32_t a_fsas_state_idx01;
@@ -649,7 +649,7 @@ class MultiGraphDenseIntersectPruned {
       const int32_t *a_fsas_row_splits2 = a_fsas_.shape.RowSplits(2).Data();
       const Arc *arcs = a_fsas_.values.Data();
 
-      // fsa_idx0 to ind0x (into b_fsas_), which gives the 1st row for this
+      // fsa_idx0 to idx0x (into b_fsas_), which gives the 1st row for this
       // sequence.
 const int32_t *b_fsas_row_ids1 = b_fsas_.shape.RowIds(1).Data();
       const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data();
@@ -673,7 +673,7 @@ class MultiGraphDenseIntersectPruned {
           Arc arc = arcs[a_fsas_arc_idx012];
 
           int32_t scores_idx0x = b_fsas_row_splits1[ai_fsa_idx0],
-                  scores_idx01 = scores_idx0x + t,  // t == ind1 into 'scores'
+                  scores_idx01 = scores_idx0x + t,  // t == idx1 into 'scores'
                   scores_idx2 =
                       arc.label + 1;  // the +1 is so that -1 can be handled
           K2_DCHECK_LT(static_cast<uint32_t>(scores_idx2),

From e799928bd17f5bde3a37c28021f6ed4f87f58687 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 13 Jan 2022 11:04:38 +0800
Subject: [PATCH 31/64] Fix k2.ragged.create_ragged_shape2 (#901)

Before the fix, we have to specify both `row_splits` and `row_ids`
while calling `k2.create_ragged_shape2` even if one of them is `None`.

After this fix, we only need to specify one of them.
---
 k2/python/csrc/torch/v2/doc/ragged_shape.h |  6 ++++--
 k2/python/csrc/torch/v2/ragged_shape.cu    |  6 +++---
 k2/python/tests/ragged_shape_test.py       | 16 ++++++++++++++++
 3 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/k2/python/csrc/torch/v2/doc/ragged_shape.h b/k2/python/csrc/torch/v2/doc/ragged_shape.h
index aa64cf26a..4e63d4018 100644
--- a/k2/python/csrc/torch/v2/doc/ragged_shape.h
+++ b/k2/python/csrc/torch/v2/doc/ragged_shape.h
@@ -509,9 +509,11 @@ the overall concepts, please see comments in k2/csrc/utils.h.
 
 Args:
   row_splits:
-    Optionally, a torch.Tensor with dtype=torch.int32 and one axis
+    Optional. A 1-D torch.Tensor with dtype torch.int32.
+    If ``None``, you have to specify ``row_ids``.
   row_ids:
-    Optionally, a torch.Tensor with dtype=torch.int32 and one axis.
+    Optional. A 1-D torch.Tensor with dtype torch.int32.
+    If ``None``, you have to specify ``row_splits``.
   cached_tot_size:
     The number of elements (length of row_ids, even if row_ids
     is not provided); would be identical to the last element of row_splits,
diff --git a/k2/python/csrc/torch/v2/ragged_shape.cu b/k2/python/csrc/torch/v2/ragged_shape.cu
index 81c3118ce..cb3bc8c13 100644
--- a/k2/python/csrc/torch/v2/ragged_shape.cu
+++ b/k2/python/csrc/torch/v2/ragged_shape.cu
@@ -232,8 +232,8 @@ void PybindRaggedShape(py::module &m) {
 
   m.def(
       "create_ragged_shape2",
-      [](torch::optional<torch::Tensor> row_splits,
-         torch::optional<torch::Tensor> row_ids,
+      [](torch::optional<torch::Tensor> row_splits = torch::nullopt,
+         torch::optional<torch::Tensor> row_ids = torch::nullopt,
          int32_t cached_tot_size = -1) -> RaggedShape {
         if (!row_splits.has_value() && !row_ids.has_value())
           K2_LOG(FATAL) << "Both row_splits and row_ids are None";
@@ -257,7 +257,7 @@ void PybindRaggedShape(py::module &m) {
             row_splits.has_value() ? &array_row_splits : nullptr,
             row_ids.has_value() ?
&array_row_ids : nullptr, cached_tot_size); }, - py::arg("row_splits"), py::arg("row_ids"), + py::arg("row_splits") = py::none(), py::arg("row_ids") = py::none(), py::arg("cached_tot_size") = -1, kCreateRaggedShape2Doc); m.def("random_ragged_shape", &RandomRaggedShape, "RandomRaggedShape", diff --git a/k2/python/tests/ragged_shape_test.py b/k2/python/tests/ragged_shape_test.py index 1e6d41e4b..15906c98c 100644 --- a/k2/python/tests/ragged_shape_test.py +++ b/k2/python/tests/ragged_shape_test.py @@ -121,6 +121,22 @@ def test_compose_ragged_shape(self): prod2 = k2.RaggedTensor(abshape2, b.values) assert prod == prod2 + def test_create_ragged_shape2_with_row_splits(self): + for device in self.devices: + row_splits = torch.tensor([0, 1, 3], + dtype=torch.int32, + device=device) + shape = k2.ragged.create_ragged_shape2(row_splits=row_splits) + expected_shape = k2.RaggedShape('[[x] [x x]]').to(device) + assert shape == expected_shape + + def test_create_ragged_shape2_with_row_ids(self): + for device in self.devices: + row_ids = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + shape = k2.ragged.create_ragged_shape2(row_ids=row_ids) + expected_shape = k2.RaggedShape('[[x] [x x]]').to(device) + assert shape == expected_shape + if __name__ == '__main__': unittest.main() From d6323d5fe6ae70eb36432e2e26018928e669b7ac Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 17 Jan 2022 10:33:23 +0800 Subject: [PATCH 32/64] Add rnnt loss (#891) * Add cpp code of mutual information * mutual information working * Add rnnt loss * Add pruned rnnt loss * Minor Fixes * Minor fixes & fix code style * Fix cpp style * Fix code style * Fix s_begin values in padding positions * Fix bugs related to boundary; Fix s_begin padding value; Add more tests * Minor fixes * Fix comments * Add boundary to pruned loss tests --- .flake8 | 5 + k2/python/csrc/torch.cu | 2 + k2/python/csrc/torch/CMakeLists.txt | 6 + k2/python/csrc/torch/mutual_information.cu | 68 ++ k2/python/csrc/torch/mutual_information.h | 107 ++ .../csrc/torch/mutual_information_cpu.cu | 214 ++++ .../csrc/torch/mutual_information_cuda.cu | 831 ++++++++++++++ k2/python/k2/__init__.py | 15 +- k2/python/k2/mutual_information.py | 302 +++++ k2/python/k2/rnnt_loss.py | 1012 +++++++++++++++++ k2/python/tests/CMakeLists.txt | 2 + k2/python/tests/mutual_information_test.py | 271 +++++ k2/python/tests/rnnt_loss_test.py | 424 +++++++ 13 files changed, 3258 insertions(+), 1 deletion(-) create mode 100644 k2/python/csrc/torch/mutual_information.cu create mode 100644 k2/python/csrc/torch/mutual_information.h create mode 100644 k2/python/csrc/torch/mutual_information_cpu.cu create mode 100644 k2/python/csrc/torch/mutual_information_cuda.cu create mode 100644 k2/python/k2/mutual_information.py create mode 100644 k2/python/k2/rnnt_loss.py create mode 100644 k2/python/tests/mutual_information_test.py create mode 100644 k2/python/tests/rnnt_loss_test.py diff --git a/.flake8 b/.flake8 index 49ba543a5..6cd23ad39 100644 --- a/.flake8 +++ b/.flake8 @@ -2,6 +2,11 @@ show-source=true statistics=true max-line-length=80 +per-file-ignores = + # line too long E501 + # line break before operator W503 + k2/python/k2/rnnt_loss.py: E501, W503 + k2/python/tests/rnnt_loss_test.py: W503 exclude = .git, setup.py, diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index ab941ff23..132065699 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -31,6 +31,7 @@ #include "k2/python/csrc/torch/fsa_algo.h" #include "k2/python/csrc/torch/index_add.h" #include 
"k2/python/csrc/torch/index_select.h" +#include "k2/python/csrc/torch/mutual_information.h" #include "k2/python/csrc/torch/nbest.h" #include "k2/python/csrc/torch/ragged.h" #include "k2/python/csrc/torch/ragged_ops.h" @@ -44,6 +45,7 @@ void PybindTorch(py::module &m) { PybindFsaAlgo(m); PybindIndexAdd(m); PybindIndexSelect(m); + PybindMutualInformation(m); PybindNbest(m); PybindRagged(m); PybindRaggedOps(m); diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 4213807c7..94f005ba8 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -7,6 +7,8 @@ set(torch_srcs fsa_algo.cu index_add.cu index_select.cu + mutual_information.cu + mutual_information_cpu.cu nbest.cu ragged.cu ragged_ops.cu @@ -19,6 +21,10 @@ set(torch_srcs v2/ragged_shape.cu ) +if (K2_WITH_CUDA) + list(APPEND torch_srcs mutual_information_cuda.cu) +endif() + set(torch_srcs_with_prefix) foreach(src IN LISTS torch_srcs) list(APPEND torch_srcs_with_prefix "torch/${src}") diff --git a/k2/python/csrc/torch/mutual_information.cu b/k2/python/csrc/torch/mutual_information.cu new file mode 100644 index 000000000..9d2c620a9 --- /dev/null +++ b/k2/python/csrc/torch/mutual_information.cu @@ -0,0 +1,68 @@ +/** + * @copyright + * Copyright 2021 Xiaomi Corporation (authors: Wei Kang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "k2/csrc/device_guard.h"
+#include "k2/python/csrc/torch/mutual_information.h"
+#include "k2/python/csrc/torch/torch_util.h"
+
+void PybindMutualInformation(py::module &m) {
+  m.def(
+      "mutual_information_forward",
+      [](torch::Tensor px, torch::Tensor py,
+         torch::optional<torch::Tensor> boundary,
+         torch::Tensor p) -> torch::Tensor {
+        k2::DeviceGuard guard(k2::GetContext(px));
+        if (px.device().is_cpu()) {
+          return k2::MutualInformationCpu(px, py, boundary, p);
+        } else {
+#ifdef K2_WITH_CUDA
+          return k2::MutualInformationCuda(px, py, boundary, p);
+#else
+          K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
+                        << "that you compiled the code with K2_WITH_CUDA.";
+          return torch::Tensor();
+#endif
+        }
+      },
+      py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));
+
+  m.def(
+      "mutual_information_backward",
+      [](torch::Tensor px, torch::Tensor py,
+         torch::optional<torch::Tensor> boundary, torch::Tensor p,
+         torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
+        k2::DeviceGuard guard(k2::GetContext(px));
+        if (px.device().is_cpu()) {
+          return k2::MutualInformationBackwardCpu(px, py, boundary, p,
+                                                  ans_grad);
+        } else {
+#ifdef K2_WITH_CUDA
+          return k2::MutualInformationBackwardCuda(px, py, boundary, p,
+                                                   ans_grad, true);
+#else
+          K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
+                        << "that you compiled the code with K2_WITH_CUDA.";
+          return std::vector<torch::Tensor>();
+#endif
+        }
+      },
+      py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
+      py::arg("ans_grad"));
+}
diff --git a/k2/python/csrc/torch/mutual_information.h b/k2/python/csrc/torch/mutual_information.h
new file mode 100644
index 000000000..efcdccaa3
--- /dev/null
+++ b/k2/python/csrc/torch/mutual_information.h
@@ -0,0 +1,107 @@
+/**
+ * @copyright
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
+#define K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
+
+#include <utility>
+
+#include <vector>
+
+#include "k2/python/csrc/torch.h"
+
+namespace k2 {
+/*
+  Forward of mutual_information.  See also comment of `mutual_information`
+  in mutual_information.py.  This is the core recursion
+  in the sequence-to-sequence mutual information computation.
+
+  @param px  Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+             generating the next x in the sequence, i.e.
+             xy[b][s][t] is the log of
+                 p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
+             i.e. the log-prob of generating x_s given subsequences of
+             lengths (s, t), divided by the prior probability of generating
+             x_s.  (See mutual_information.py for more info).
+  @param py  The log-odds ratio of generating the next y in the sequence.
+             Shape [B][S + 1][T]
+  @param p   This function writes to p[b][s][t] the mutual information between
+             sub-sequences of x and y of length s and t respectively, from the
+             b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
+             Concretely, this function implements the following recursion,
+             in the case where s_begin == t_begin == 0:
+
+               p[b,0,0] = 0.0
+               p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                  p[b,s,t-1] + py[b,s,t-1])
+                       if s > 0 or t > 0,
+                       treating values with any -1 index as -infinity.
+             .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
+  @param boundary  If set, a tensor of shape [B][4] of type int64_t, which
+                   contains, where for each batch element b, boundary[b]
+                   equals [s_begin, t_begin, s_end, t_end]
+                   which are the beginning and end (i.e. one-past-the-last)
+                   of the x and y sequences that we should process.
+                   Alternatively, may be a tensor of shape [0][0] and type
+                   int64_t; the elements will default to (0, 0, S, T).
+  @return  A tensor `ans` of shape [B], where this function will set
+              ans[b] = p[b][s_end][t_end],
+           with s_end and t_end being (S, T) if `boundary` was specified,
+           and (boundary[b][2], boundary[b][3]) otherwise.
+           `ans` represents the mutual information between each pair of
+           sequences (i.e. x[b] and y[b], although the sequences are not
+           supplied directly to this function).
+
+  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
+  be at least 128.
+*/
+torch::Tensor MutualInformationCpu(
+    torch::Tensor px,                         // [B][S][T+1]
+    torch::Tensor py,                         // [B][S+1][T]
+    torch::optional<torch::Tensor> boundary,  // [B][4], int64_t.
+    torch::Tensor p);                         // [B][S+1][T+1]; an output
+
+torch::Tensor MutualInformationCuda(
+    torch::Tensor px,                         // [B][S][T+1]
+    torch::Tensor py,                         // [B][S+1][T]
+    torch::optional<torch::Tensor> boundary,  // [B][4], int64_t.
+    torch::Tensor p);                         // [B][S+1][T+1]; an output
+
+/*
+  backward of mutual_information; returns (grad_px, grad_py)
+
+  if overwrite_ans_grad == true, this function will overwrite ans_grad with a
+  value that, if the computation worked correctly, should be identical to or
+  very close to the value of ans_grad at entry.  This can be used
+  to validate the correctness of this code.
+*/
+std::vector<torch::Tensor> MutualInformationBackwardCpu(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> boundary, torch::Tensor p,
+    torch::Tensor ans_grad);
+
+std::vector<torch::Tensor> MutualInformationBackwardCuda(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> boundary, torch::Tensor p,
+    torch::Tensor ans_grad, bool overwrite_ans_grad);
+
+}  // namespace k2
+
+void PybindMutualInformation(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
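
To make the recursion above concrete, here is how it could be written in plain
Python for a single batch element, with boundary fixed at (0, 0, S, T).  This
is a reference sketch for reading purposes only (`mutual_information_ref` is a
hypothetical name, not part of k2); the real implementations follow in
mutual_information_cpu.cu and mutual_information_cuda.cu:

    import torch

    def mutual_information_ref(px: torch.Tensor, py: torch.Tensor):
        # px: (S, T+1), py: (S+1, T); returns p[S][T].
        S, T = px.shape[0], py.shape[1]
        p = torch.full((S + 1, T + 1), float("-inf"), dtype=px.dtype)
        p[0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                x = p[s - 1, t] + px[s - 1, t] if s > 0 else None
                y = p[s, t - 1] + py[s, t - 1] if t > 0 else None
                if x is None:
                    p[s, t] = y
                elif y is None:
                    p[s, t] = x
                else:
                    p[s, t] = torch.logaddexp(x, y)
        return p[S, T]
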
diff --git a/k2/python/csrc/torch/mutual_information_cpu.cu b/k2/python/csrc/torch/mutual_information_cpu.cu
new file mode 100644
index 000000000..bcbe6ed27
--- /dev/null
+++ b/k2/python/csrc/torch/mutual_information_cpu.cu
@@ -0,0 +1,214 @@
+/**
+ * @copyright
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/csrc/utils.h"  // for LogAdd
+#include "k2/python/csrc/torch/mutual_information.h"
+
+namespace k2 {
+
+// forward of mutual_information.  See also comment of `mutual_information`
+// in k2/python/k2/mutual_information.py for documentation of the
+// behavior of this function.
+torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
+                                   torch::optional<torch::Tensor> opt_boundary,
+                                   torch::Tensor p) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(
+      px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
+      "inputs must be CPU tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  auto boundary = opt_boundary.value_or(
+      torch::tensor({0, 0, S, T},
+                    torch::dtype(torch::kInt64).device(torch::kCPU))
+          .reshape({1, 4})
+          .expand({B, 4}));
+  TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
+  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
+  TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
+
+  torch::Tensor ans = torch::empty({B}, opts);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      px.scalar_type(), "mutual_information_cpu_loop", ([&] {
+        auto px_a = px.accessor<scalar_t, 3>(),
+             py_a = py.accessor<scalar_t, 3>(),
+             p_a = p.accessor<scalar_t, 3>();
+        auto boundary_a = boundary.accessor<int64_t, 2>();
+        auto ans_a = ans.accessor<scalar_t, 1>();
+
+        for (int b = 0; b < B; b++) {
+          int s_begin = boundary_a[b][0];
+          int t_begin = boundary_a[b][1];
+          int s_end = boundary_a[b][2];
+          int t_end = boundary_a[b][3];
+          p_a[b][s_begin][t_begin] = 0.0;
+          for (int s = s_begin + 1; s <= s_end; ++s)
+            p_a[b][s][t_begin] =
+                p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+          for (int t = t_begin + 1; t <= t_end; ++t)
+            p_a[b][s_begin][t] =
+                p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+          for (int s = s_begin + 1; s <= s_end; ++s) {
+            scalar_t p_s_t1 = p_a[b][s][t_begin];
+            for (int t = t_begin + 1; t <= t_end; ++t) {
+              // The following statement is a small optimization of:
+              //   p_a[b][s][t] = LogAdd<scalar_t>()(
+              //       p_a[b][s - 1][t] + px_a[b][s - 1][t],
+              //       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+              // .. which obtains p_a[b][s][t - 1] from a register.
+              p_a[b][s][t] = p_s_t1 =
+                  LogAdd<scalar_t>()(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                                     p_s_t1 + py_a[b][s][t - 1]);
+            }
+          }
+          ans_a[b] = p_a[b][s_end][t_end];
+        }
+      }));
+  return ans;
+}
+
+// backward of mutual_information.  Returns (px_grad, py_grad).
+// p corresponds to what we computed in the forward pass.
+std::vector<torch::Tensor> MutualInformationBackwardCpu(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> opt_boundary, torch::Tensor p,
+    torch::Tensor ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
+
+  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
+                  p.device().is_cpu() && ans_grad.device().is_cpu(),
+              "inputs must be CPU tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  auto boundary = opt_boundary.value_or(
+      torch::tensor({0, 0, S, T},
+                    torch::dtype(torch::kInt64).device(torch::kCPU))
+          .reshape({1, 4})
+          .expand({B, 4}));
+  TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
+  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
+  TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
+
+  bool has_boundary = opt_boundary.has_value();
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
+                px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts)
+                                        : torch::empty({B, S, T + 1}, opts)),
+                py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
+                                        : torch::empty({B, S + 1, T}, opts));
+
+  AT_DISPATCH_FLOATING_TYPES(
+      px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
+        auto px_a = px.accessor<scalar_t, 3>(),
+             p_a = p.accessor<scalar_t, 3>(),
+             p_grad_a = p_grad.accessor<scalar_t, 3>(),
+             px_grad_a = px_grad.accessor<scalar_t, 3>(),
+             py_grad_a = py_grad.accessor<scalar_t, 3>();
+
+        auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
+        auto boundary_a = boundary.accessor<int64_t, 2>();
+
+        for (int b = 0; b < B; b++) {
+          int s_begin = boundary_a[b][0];
+          int t_begin = boundary_a[b][1];
+          int s_end = boundary_a[b][2];
+          int t_end = boundary_a[b][3];
+          // Backprop for: ans_a[b] = p_a[b][s_end][t_end];
+          p_grad_a[b][s_end][t_end] = ans_grad_a[b];
+
+          for (int s = s_end; s > s_begin; --s) {
+            for (int t = t_end; t > t_begin; --t) {
+              // The statement we are backpropagating here is:
+              //   p_a[b][s][t] = LogAdd<scalar_t>()(
+              //       p_a[b][s - 1][t] + px_a[b][s - 1][t],
+              //       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+              scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                       // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
+                       // actually needed..
+                       total = p_a[b][s][t];
+              if (total - total != 0) total = 0;
+              scalar_t term1_deriv = exp(term1 - total),
+                       term2_deriv = 1.0 - term1_deriv,
+                       grad = p_grad_a[b][s][t];
+              scalar_t term1_grad, term2_grad;
+              if (term1_deriv - term1_deriv == 0.0) {
+                term1_grad = term1_deriv * grad;
+                term2_grad = term2_deriv * grad;
+              } else {
+                // could happen if total == -inf
+                term1_grad = term2_grad = 0.0;
+              }
+              px_grad_a[b][s - 1][t] = term1_grad;
+              p_grad_a[b][s - 1][t] = term1_grad;
+              py_grad_a[b][s][t - 1] = term2_grad;
+              p_grad_a[b][s][t - 1] += term2_grad;
+            }
+          }
+          for (int t = t_end; t > t_begin; --t) {
+            // Backprop for:
+            //   p_a[b][s_begin][t] =
+            //       p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+            scalar_t this_p_grad = p_grad_a[b][s_begin][t];
+            p_grad_a[b][s_begin][t - 1] += this_p_grad;
+            py_grad_a[b][s_begin][t - 1] = this_p_grad;
+          }
+          for (int s = s_end; s > s_begin; --s) {
+            // Backprop for:
+            //   p_a[b][s][t_begin] =
+            //       p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+            scalar_t this_p_grad = p_grad_a[b][s][t_begin];
+            p_grad_a[b][s - 1][t_begin] += this_p_grad;
+            px_grad_a[b][s - 1][t_begin] = this_p_grad;
+          }
+          // There is no backprop for:
+          //   p_a[b][s_begin][t_begin] = 0.0;
+          // .. but we can use this for a check, that the grad at the beginning
+          // of the sequence is equal to the grad at the end of the sequence.
+          if (ans_grad_a[b] != 0.0) {
+            float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
+            if (fabs(grad_ratio - 1.0) > 0.01) {
+              K2_LOG(WARNING)
+                  << "Warning: mutual_information backprop: expected these "
+                  << "numbers to be the same:"
+                  << static_cast<float>(p_grad_a[b][s_begin][t_begin])
+                  << " vs " << static_cast<float>(ans_grad_a[b]);
+            }
+          }
+        }
+      }));
+
+  return std::vector<torch::Tensor>({px_grad, py_grad});
+}
+}  // namespace k2
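
The grad_ratio self-check above only catches gross errors.  From Python, the
whole forward/backward pair can also be exercised with
torch.autograd.gradcheck, along these lines (a sketch only; it assumes the
k2.mutual_information_recursion wrapper that this patch adds later, and uses
double precision on CPU for tight tolerances):

    import torch
    import k2

    B, S, T = 2, 3, 4
    px = torch.randn(B, S, T + 1, dtype=torch.float64, requires_grad=True)
    py = torch.randn(B, S + 1, T, dtype=torch.float64, requires_grad=True)
    torch.autograd.gradcheck(k2.mutual_information_recursion, (px, py))
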
diff --git a/k2/python/csrc/torch/mutual_information_cuda.cu b/k2/python/csrc/torch/mutual_information_cuda.cu
new file mode 100644
index 000000000..c858d4d7b
--- /dev/null
+++ b/k2/python/csrc/torch/mutual_information_cuda.cu
@@ -0,0 +1,831 @@
+/**
+ * @copyright
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <c10/cuda/CUDAStream.h>  // for getCurrentCUDAStream()
+#include <vector>
+
+#include "k2/csrc/utils.h"  // for LogAdd
+#include "k2/python/csrc/torch/mutual_information.h"
+
+namespace k2 {
+
+/*
+  Forward of mutual_information.  Each thread block computes blocks of the 'p'
+  array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
+  Thread-blocks loop over such blocks, but they might loop only once if there
+  is not that much data to process.  We sequentially launch thread groups in
+  such a way that thread-blocks within a group do not depend on each other
+  (see the "iter" parameter).  The blocks of the 'image' (i.e. of the p matrix)
+  that each group handles are arranged in a diagonal.
+
+  Template args:
+      scalar_t: the floating-point type, e.g. float, double; maybe eventually
+                half, although I think we don't support LogAdd for half yet.
+      BLOCK_SIZE: an integer power of two no greater than 32 (this limitation
+                is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
+                code).
+  Args:
+      px:  Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+           generating the next x in the sequence, i.e.
+           xy[b][s][t] is the log of
+               p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
+           i.e. the log-prob of generating x_s given subsequences of lengths
+           (s, t), divided by the prior probability of generating x_s.  (See
+           mutual_information.py for more info).
+      py:  The log-odds ratio of generating the next y in the sequence.
+           Shape [B][S + 1][T]
+      p:   This function writes to p[b][s][t] the mutual information between
+           sub-sequences of x and y of length s and t respectively, from the
+           b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
+           Concretely, this function implements the following recursion,
+           in the case where s_begin == t_begin == 0:
+
+             p[b,0,0] = 0.0
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])      (eq. 0)
+                     if s > 0 or t > 0,
+                     treating values with any -1 index as -infinity.
+           .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
+      boundary:  If set, a tensor of shape [B][4] of type int64_t, which
+           contains, where for each batch element b, boundary[b] equals
+           [s_begin, t_begin, s_end, t_end]
+           which are the beginning and end (i.e. one-past-the-last) of the
+           x and y sequences that we should process.  Otherwise, must be
+           a tensor of shape [0][0] of type int64_t; the values will
+           default to (0, 0, S, T).
+      ans:  a tensor `ans` of shape [B], where this function will set
+           ans[b] = p[b][s_end][t_end],
+           with s_end and t_end being (S, T) if `boundary` was specified,
+           and (boundary[b][2], boundary[b][3]) otherwise.
+           `ans` represents the mutual information between each pair of
+           sequences (i.e. x[b] and y[b], although the sequences are not
+           supplied directly to this function).
+
+  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
+  be at least 128.
+*/
+template <typename scalar_t,
+          int BLOCK_SIZE>  // e.g. BLOCK_SIZE == 16 or 32.
+__global__ void mutual_information_kernel(
+    // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> px,
+    torch::PackedTensorAccessor32<scalar_t, 3> py,  // B, S + 1, T.
+    // B, S + 1, T + 1.  This is an output.
+    torch::PackedTensorAccessor32<scalar_t, 3> p,
+    // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
+    torch::PackedTensorAccessor32<int64_t, 2> boundary,
+    torch::PackedTensorAccessor32<scalar_t, 1> ans,  // [B]
+    int iter) {  // This kernel is sequentially called with 'iter' = 0, 1, 2
+                 // and so on, up to num_iters - 1 where
+                 //   num_iters = num_s_blocks + num_t_blocks - 1,
+                 //   num_s_blocks = S / BLOCK_SIZE + 1,
+                 //   num_t_blocks = T / BLOCK_SIZE + 1,
+                 // so that each group depends on the previous group...
+  const int B = px.size(0), S = px.size(1), T = py.size(2);
+  // num_s_blocks and num_t_blocks are the number of blocks we need to cover
+  // the array of size (S, T) with blocks of this size, in the s and t
+  // directions respectively.
+  // You can read the following expressions as simplifications of, for example,
+  //   num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
+  // i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for
+  // (T + 1).
+ const int num_s_blocks = S / BLOCK_SIZE + 1; + //, num_t_blocks = T / BLOCK_SIZE + 1; + + // num_blocks_this_iter is an upper bound on the number of blocks of size + // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`). + // These iterations start from the bottom left of the image so that on iter == + // 0 we process only one block with block-index (0, 0) then on iter == 1 we + // process block-indexes (1, 0) and (0, 1); and then on iter==2 we process (2, + // 0), (1, 1) and (0, 2); and so on. We also will never have more than + // `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but + // the numbering we use corresponds to s and not t, so when we hit the + // num_t_blocks limit, the blocks with the lowest s indexes would just not be + // active and we'll 'continue' in the loop below). + int num_blocks_this_iter = min(iter + 1, num_s_blocks); + + // For the block with s_block_begin == 0 and t_block_begin == 0 (for + // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0 + // for out-of-range indexes into px. + // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]). + __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], + py_buf[BLOCK_SIZE][BLOCK_SIZE]; + + // p_buf[s][t] == p[s+s_block_begin-1][t+t_block_begin-1] + // 1st row/col of p_buf correspond to the previously computed blocks (lower + // `iter`), or to negative indexes into p. So, for the origin block, + // p_buf[s][t] corresponds to p[s - 1][t - 1]; or -inf for + // out-of-range values. + __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; + + // boundary_buf will be used to store the b'th row of `boundary` if we have + // boundary information supplied; or (0, 0, S, T) otherwise. + __shared__ int64_t boundary_buf[4]; + + if (threadIdx.x == 0) { + boundary_buf[0] = 0; + boundary_buf[1] = 0; + boundary_buf[2] = S; + boundary_buf[3] = T; + } + + // batch_block_iter iterates over batch elements (index b) and block + // indexes in the range [0..num_blocks_this_iter-1], combining both + // batch and block indexes. + for (int batch_block_iter = blockIdx.x; + batch_block_iter < B * num_blocks_this_iter; + batch_block_iter += gridDim.x) { + int block = batch_block_iter / B, + b = batch_block_iter % B; // b is the index into the batch + + // Note: `block` can be no greater than `iter` because num_blocks_this_iter + // <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and + // block < num_blocks_this_iter, so iter - block >= 0. + int s_block_begin = block * BLOCK_SIZE, + t_block_begin = (iter - block) * BLOCK_SIZE; + bool is_origin_block = (s_block_begin + t_block_begin == 0); + + __syncthreads(); + + if (threadIdx.x < 4) boundary_buf[threadIdx.x] = boundary[b][threadIdx.x]; + + __syncthreads(); + + int s_begin = boundary_buf[0], t_begin = boundary_buf[1], + s_end = boundary_buf[2], t_end = boundary_buf[3]; + + s_block_begin += s_begin; + t_block_begin += t_begin; + + // block_S and block_T are the actual sizes of this block (the block of `p` + // that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but + // possibly less than that if we are towards the end of the sequence. The + // last element in the output matrix p that we need to write is (s_end, + // t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1). + int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), + block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); + + if (block_S <= 0 || block_T <= 0) continue; + + // Load px_buf and py_buf. 
 We exponentiate; the assumption is that they
+    // most likely won't overflow or underflow, but if they do overflow we'll
+    // detect it later; we'll also detect certain kinds of underflow.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      // comparing as unsigned int makes sure the index is nonnegative.
+      // Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px
+      // and py values that are outside the proper boundaries that we need, but
+      // the corresponding p_buf values will end up being 0 so this won't
+      // matter.
+      scalar_t this_px = -INFINITY;
+      if (s > s_begin && s <= s_end && t <= t_end) this_px = px[b][s - 1][t];
+      px_buf[s_in_block][t_in_block] = this_px;
+
+      scalar_t this_py = -INFINITY;
+      if (t > t_begin && t <= t_end && s <= s_end) this_py = py[b][s][t - 1];
+      py_buf[s_in_block][t_in_block] = this_py;
+    }
+
+    // Load the 1st row and 1st column of p_buf (except element[0][0] is not
+    // needed).  This is the context from previously computed blocks of the
+    // image.  Remember: p_buf[s][t] will correspond to
+    // p[s + s_block_begin - 1][t + t_block_begin - 1]
+    if (threadIdx.x <= BLOCK_SIZE) {
+      // s_in_p_buf are simply the indexes into p_buf
+      int s_in_p_buf = threadIdx.x, t_in_p_buf = 0,
+          s = s_in_p_buf + s_block_begin - 1,
+          t = t_in_p_buf + t_block_begin - 1;
+
+      scalar_t this_p = -INFINITY;
+      if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
+        this_p = p[b][s][t];
+      /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s,
+        t, (float)this_p, (int)threadIdx.x,
+        (float)px_buf[s_in_p_buf][t_in_p_buf],
+        (float)py_buf[s_in_p_buf][t_in_p_buf]); */
+      p_buf[s_in_p_buf][t_in_p_buf] = this_p;
+    } else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <=
+               static_cast<unsigned int>(BLOCK_SIZE)) {
+      // Another warp handles the other leg.  Checking as unsigned
+      // tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
+      int s_in_p_buf = 0,
+          t_in_p_buf = static_cast<int>(threadIdx.x) - 64,
+          s = s_in_p_buf + s_block_begin - 1,
+          t = t_in_p_buf + t_block_begin - 1;
+
+      scalar_t this_p = -INFINITY;
+      if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
+        this_p = p[b][s][t];
+      /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s,
+        t, (float)this_p, (int)threadIdx.x,
+        (float)px_buf[s_in_p_buf][t_in_p_buf],
+        (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
+      p_buf[s_in_p_buf][t_in_p_buf] = this_p;
+    }
+
+    __syncthreads();
+
+    // from here to the next __syncthreads(), only the 1st warp should be
+    // active so we shouldn't need to synchronize.  (implicit within-warp
+    // synchronization).
+
+    if (threadIdx.x == 0) {
+      // This if-statement is an optimization and modification of the loop
+      // below for the value i == 0, i.e. inner-iteration == 0.  The
+      // modification is to set p_buf to 1.0 = exp(0.0) if this is the
+      // "origin block", i.e. s == s_begin, t == t_begin.  This corresponds
+      // to the probability of the pair of sequences of length (0, 0).
+      p_buf[1][1] =
+          (is_origin_block ? 0.0
+                           : LogAdd<scalar_t>()(p_buf[0][1] + px_buf[0][0],
+                                                p_buf[1][0] + py_buf[0][0]));
+    }
+
+    scalar_t p_buf_s1_t;  // This is for an optimization to avoid one
+                          // shared-memory read/write in the loop below.  It
+                          // represents p_buf[s + 1][t]; the first time we
+                          // access this, it will be for t == 0, except for
+                          // thread 0 when we first need it for t == 1.
+    if (threadIdx.x < BLOCK_SIZE) {
+      int s = threadIdx.x;
+      p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
+    }
+
+    int s = threadIdx.x;
+    for (int i = 1; i < block_S + block_T - 1; ++i) {
+      __syncwarp();
+      // i is the inner iteration, which corresponds to the (s + t) indexes of
+      // the elements within the block that we write.  So i == 0 writes
+      // positions (s, t) == (0, 0) (but we treated i == 0 as a special case
+      // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
+      // and (2, 1); and so on.  Note: not many threads participate in this
+      // part, only up to BLOCK_SIZE at most.  Unfortunately we couldn't figure
+      // out a very meaningful way for more threads to do work that looked
+      // like it would really speed things up.
+      // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
+      // but we do at least do the I/O in an efficient way and keep the
+      // inner loop simple and fast (e.g. no exp() or log()).
+      int t = i - s;
+      if (s < block_S &&
+          static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
+        // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
+        // row and column for context from previous blocks.  Taking into
+        // account the way these buffers relate to the tensors p, px and py,
+        // the statement below can be interpreted as follows,
+        // writing sbb for s_block_begin and tbb for t_block_begin:
+        //
+        //   p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] +
+        //                                   px[s+sbb-1][t+tbb],
+        //                               p[b][s+sbb][t+tbb-1] +
+        //                                   py[s+sbb][t+tbb-1])
+        //
+        // where you can see that apart from the offsets of tbb and sbb, this
+        // is the same as the recursion defined for p in
+        // mutual_information.py:mutual_information_recursion(); and (eq. 0)
+        // above.
+#if 0
+        p_buf[s + 1][t + 1] = LogAdd<scalar_t>()(
+            p_buf[s][t + 1] + px_buf[s][t], p_buf[s + 1][t] + py_buf[s][t]);
+
+        /*printf("threadIdx.x = %d, i = %d, s = %d, t = %d, p_buf[s+1][t+1] =
+          %f, p_buf[s][t+1] = %f, " "px_buf[s][t] = %f, p_buf[s + 1][t] = %f,
+          py_buf[s][t] = %f\n", (int)threadIdx.x, i, s, t,
+          (float)p_buf[s+1][t+1], (float)p_buf[s][t+1], (float)px_buf[s][t],
+          (float)p_buf[s+1][t], (float)py_buf[s][t]);*/
+#else
+        // This is an optimization of the statement above (the other half of
+        // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
+        // the need for a load from shared memory.
+        p_buf_s1_t = LogAdd<scalar_t>()(p_buf[s][t + 1] + px_buf[s][t],
+                                        p_buf_s1_t + py_buf[s][t]);
+        // The next time this thread reads p_buf_s1_t, t will be one greater,
+        // so p_buf_s1_t will contain p_buf[s + 1][t].  The first time this
+        // thread uses p_buf_s1_t is when t == 0, except for thread 0 where
+        // the 1st item accessed is for s == 0, t == 1.
+        p_buf[s + 1][t + 1] = p_buf_s1_t;
+#endif
+        // We don't need to do __syncthreads() in this loop because all the
+        // threads that are active are in the same warp.  (However, in future,
+        // if NVidia changes some things, we might need to sync here).
+      }
+    }
+    __syncthreads();
+
+    // Write out the data to p; check that nothing has gone out of numerical
+    // range, and write 'panic' flag if it has.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      if (s_in_block < block_S && t_in_block < block_T) {
+        scalar_t this_p = p_buf[s_in_block + 1][t_in_block + 1];
+        p[b][s][t] = this_p;
+      }
+    }
+
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+      // Write `ans`, if this is the final (top-right) block in its sequence.
+      // Logically, the following equation corresponds to:
+      //   ans[b] = p[b][s_end][t_end]
+      if (s_block_begin + block_S - 1 == s_end &&
+          t_block_begin + block_T - 1 == t_end) {
+        // you could read block_S below as block_S - 1 + 1, meaning,
+        // it's the last index in a block of size block_S, but the indexes into
+        // p_buf have a "+ 1".  Likewise for block_T.
+        ans[b] = p_buf[block_S][block_T];
+      }
+    }
+  }
+}
+
+/*
+  Backward of mutual_information.
+
+  If we were to write the forward pass in non-log space, it would be (ignoring
+  edge cases), as follows... we'll prefix all the variable names with e, e.g.
+  ep, to clarify that it's the exp of the actual argument p:
+
+     ep[b][s][t] = ep[b][s - 1][t] * epx[b][s - 1][t] +
+                   ep[b][s][t - 1] * epy[b][s][t - 1].          (eq. 1)
+
+  (A)
+  First we consider the part of the backprop that requires recursion or
+  iteration, i.e. the part involving only gradients of ep.
+  This is:
+     ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
+     ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].
+
+  .. and if we add 1 to the s index of the first equation above and 1 to the
+  t index of the second equation, we can see that:
+
+     ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
+                        ep_grad[b][s][t + 1] * epy[b][s][t].
+
+  Now, if ep = exp(p), and y is the loss function we are backprop'ing,
+  then ep_grad == dy/dep == (dy/dp) * (dp/dep) == (dy/dp) / (dep/dp)
+  == (dy/dp) / exp(p) == (dy/dp) / ep == p_grad / ep.
+  I.e. ep_grad = p_grad / ep.
+
+  So we can write the above as:
+    p_grad[b][s][t] / ep[b][s][t]
+       = p_grad[b][s + 1][t] / ep[b][s + 1][t] * epx[b][s][t] +
+         p_grad[b][s][t + 1] / ep[b][s][t + 1] * epy[b][s][t].
+
+  Or, rearranging:
+    p_grad[b][s][t] =
+       p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) +
+       p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]).
+                                                                   (eq. 2)
+
+  (B) The following is the backprop for epx and epy from (eq. 1):
+
+      epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
+      epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]
+
+  .. adding 1 to the s indexes in the 1st equation and to the t indexes in
+  the 2nd:
+
+      epx_grad[b][s][t] = ep_grad[b][s + 1][t] * ep[b][s][t]
+      epy_grad[b][s][t] = ep_grad[b][s][t + 1] * ep[b][s][t]
+
+  Using, similar to the above, ep_grad = p_grad / ep, and similarly,
+  epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p)
+  for ep and so on, the above becomes:
+
+     px_grad[b][s][t] / exp(px[b][s][t]) =
+          p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
+     py_grad[b][s][t] / exp(py[b][s][t]) =
+          p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
+  Rearranging:
+     px_grad[b][s][t] =
+          p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
+                                                                   (eq. 3a)
+     py_grad[b][s][t] =
+          p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
+                                                                   (eq. 3b)
+
+  Defining terms that are common to (eq. 2) and (eqs. 3a,3b), write:
+
+     xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])  (eq. 4)
+     yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])  (eq. 5)
+
+  .. and note that these quantities are <= 1 so there is no problem doing
+  the exponentiation.  So the recursion can be simplified, from
+  eqs. (2, 3a, 3b), as:
+
+     p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
+                       p_grad[b][s][t + 1] * yderiv[b][s][t]      (eq. 6)
+     px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]     (eq. 7)
+     py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]     (eq. 8)
+
+  (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but
+  it's not clear to me that this is the best strategy since that would require
+  an extra write to shared memory within the loop that's the limiting factor.)
+
+  The backward pass will be slightly different from the forward pass in terms
+  of how we store and index p (and p_grad), because for writing a particular
+  block of p_grad, we need context on the top and right instead of the bottom
+  and left.  So there are offsets of 1.
+ */
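
Spelled out in Python, eqs. (6)-(8) amount to the following reference loop
(a sketch only, for one batch element with boundary (0, 0, S, T);
`mutual_information_backward_ref` is a hypothetical name, not part of k2):

    import torch

    def mutual_information_backward_ref(px, py, p, ans_grad):
        # px: (S, T+1), py: (S+1, T), p: (S+1, T+1), ans_grad: scalar.
        S, T = px.shape[0], py.shape[1]
        p_grad = torch.zeros(S + 1, T + 1, dtype=p.dtype)
        px_grad, py_grad = torch.zeros_like(px), torch.zeros_like(py)
        p_grad[S, T] = ans_grad
        for s in range(S, -1, -1):
            for t in range(T, -1, -1):
                if s == S and t == T:
                    continue
                gx = gy = 0.0
                if s < S:  # xderiv of (eq. 4), used in (eq. 6) and (eq. 7)
                    gx = p_grad[s + 1, t] * torch.exp(
                        p[s, t] + px[s, t] - p[s + 1, t])
                    px_grad[s, t] = gx
                if t < T:  # yderiv of (eq. 5), used in (eq. 6) and (eq. 8)
                    gy = p_grad[s, t + 1] * torch.exp(
                        p[s, t] + py[s, t] - p[s, t + 1])
                    py_grad[s, t] = gy
                p_grad[s, t] = gx + gy  # (eq. 6)
        return px_grad, py_grad
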
+template <typename scalar_t, int BLOCK_SIZE>
+__global__ void mutual_information_backward_kernel(
+    // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> px,
+    torch::PackedTensorAccessor32<scalar_t, 3> py,  // B, S + 1, T.
+    // B, S + 1, T + 1.  Produced in forward pass.
+    torch::PackedTensorAccessor32<scalar_t, 3> p,
+    // [B].  This is an input.
+    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,
+    // B, S + 1, T + 1.  This is a temporary.
+    torch::PackedTensorAccessor32<scalar_t, 3> p_grad,
+    torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
+    torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
+    // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
+    torch::PackedTensorAccessor32<int64_t, 2> boundary,
+    int iter,  // This kernel is sequentially called with 'iter' = num_iters
+               // - 1, num_iters - 2, .. 0, where num_iters can be taken to
+               // be any sufficiently large number but will actually be:
+               // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
+               // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
+    bool overwrite_ans_grad) {  // If overwrite_ans_grad == true, this function
+                                // will overwrite ans_grad with a value which,
+                                // if everything is working correctly, should
+                                // be identical or very close to the value of
+                                // ans_grad that was passed in.
+  const int B = px.size(0), S = px.size(1), T = py.size(2);
+
+  // For statements that are the same as the forward pass, we are omitting some
+  // comments.  We'll focus, in the comments, on differences from the forward
+  // pass.
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            // num_t_blocks = T / BLOCK_SIZE + 1,
+            num_blocks_this_iter = min(iter + 1, num_s_blocks);
+
+  // px_buf and py_buf are used temporarily to store the px and py values,
+  // but then modified to store the "xderiv" and "yderiv" values defined
+  // in (eq. 4) and (eq. 5) above.  For out-of-range values, we'll write 0.0
+  // here.
+  // Initially (before xderiv/yderiv are written):
+  //   px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
+  //   py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Later (see eq. 4 and eq. 5):
+  //   px_buf[s][t] contains
+  //       exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
+  //   py_buf[s][t] contains
+  //       exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1])
+  //   where ss == s + s_block_begin, tt = t + t_block_begin.
+  // Unlike in the forward code, there is no offset of 1 in the indexes.
+ __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], + py_buf[BLOCK_SIZE][BLOCK_SIZE]; + + // p_buf is initially used to store p, and then (after we are done putting + // xderiv and yderiv into px_buf and py_buf) it is repurposed to store + // p_grad. + // + // Unlike in the forward pass, p_buf has the same numbering as px_buf and + // py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0] + // refers to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than + // the block for px_buf and py_buf; unlike in the forward pass, we store + // context on the top and right, not the bottom and left, i.e. the elements at + // (one past the largest indexes in the block). + // + // For out-of-range elements of p_buf, we'll put zero. + __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; + + // boundary_buf will be used to store the b'th row of `boundary` if we have + // boundary information supplied; or (0, 0, S, T) if not. + __shared__ int64_t boundary_buf[4]; + + if (threadIdx.x == 0) { + boundary_buf[0] = 0; + boundary_buf[1] = 0; + boundary_buf[2] = S; + boundary_buf[3] = T; + } + + // batch_block_iter iterates over both batch elements (index b), and block + // indexes in the range [0..num_blocks_this_iter-1]. The order here + // doesn't matter, since there are no interdependencies between these + // blocks (they are on a diagonal). + for (int batch_block_iter = blockIdx.x; + batch_block_iter < B * num_blocks_this_iter; + batch_block_iter += gridDim.x) { + int block = batch_block_iter / B, b = batch_block_iter % B; + int s_block_begin = block * BLOCK_SIZE, + t_block_begin = (iter - block) * BLOCK_SIZE; + + if (threadIdx.x < 4) boundary_buf[threadIdx.x] = boundary[b][threadIdx.x]; + __syncthreads(); + + int s_begin = boundary_buf[0], t_begin = boundary_buf[1], + s_end = boundary_buf[2], t_end = boundary_buf[3]; + s_block_begin += s_begin; + t_block_begin += t_begin; + + // block_S and block_T are the actual sizes of this block, no greater than + // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards + // the end of the sequence. + // The last element of the output matrix p_grad we write is (s_end, t_end), + // i.e. the one-past-the-end index of p_grad is (s_end + 1, t_end + 1). + int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), + block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); + + if (block_S <= 0 || block_T <= 0) continue; + + // Load px_buf and py_buf. At this point we just set them to the px and py + // for this block. + for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { + int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE, + s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; + // We let px and py default to -infinity if they are out of range, which + // will cause xderiv and yderiv for out-of-range values to be zero, and + // cause correct behavior in edge cases (for the top and right blocks). + // The issue is that p and p_grad are of larger size than px and py. + scalar_t this_px = -INFINITY; + if (s < s_end && t <= t_end) this_px = px[b][s][t]; + px_buf[s_in_block][t_in_block] = this_px; + scalar_t this_py = -INFINITY; + if (s <= s_end && t < t_end) this_py = py[b][s][t]; + py_buf[s_in_block][t_in_block] = this_py; + } + __syncthreads(); + + // load p. 
+    for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1);
+         i += blockDim.x) {
+      int s_in_block = i / (BLOCK_SIZE + 1), t_in_block = i % (BLOCK_SIZE + 1),
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      // Setting 0.0 for out-of-bounds elements of p, together with setting
+      // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
+      // ensure that we do the right thing in top and right edge cases,
+      // i.e. that no derivatives will be propagated from out-of-bounds points
+      // because the corresponding xderiv and yderiv values will be zero.
+      scalar_t this_p = 0.0;
+      if (s <= s_end && t <= t_end) this_p = p[b][s][t];
+      // if this_p is -inf, replace with large finite negative value, to avoid
+      // NaN's below.
+      // TODO: use a value that would work correctly in half precision
+      if (this_p < -1.0e+30) this_p = -1.0e+30;
+      p_buf[s_in_block][t_in_block] = this_p;
+    }
+    __syncthreads();
+
+    // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      // We can apply this formula to the entire block even if we are
+      // processing a partial block; we have ensured that px_buf and py_buf
+      // contain -infinity, and p contains 0, for out-of-range elements, so
+      // we'll get px_buf and py_buf containing 0 after applying the following
+      // formulas.
+      int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
+      // Mathematically the following is doing:
+      //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
+      // (with an offset on the s and t indexes)
+      px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+      // Mathematically the following is doing:
+      //   yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
+      // (with an offset on the s and t indexes)
+      py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
+    }
+
+    __syncthreads();
+
+    // Load p_grad for the top and right elements in p_buf: i.e. for elements
+    // p_buf[s][t] where s == block_S (exclusive-or) t == block_T.  We don't
+    // need to load the top-right corner [block_S][block_T]; that location will
+    // never be accessed.
+    // These are the p_grad values computed by previous instances of this
+    // kernel.  If this is one of the top or right blocks, some or all of the
+    // p_grad values we'd be reading here will be out of range, and we use
+    // zeros to ensure no gradient gets propagated from those positions.
+    if (threadIdx.x < block_S) {
+      int s_in_block = threadIdx.x, t_in_block = block_T,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] =
+          (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
+    } else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <
+               static_cast<unsigned int>(block_T)) {
+      // casting to unsigned before the comparison tests for both negative and
+      // out-of-range values of (int)threadIdx.x - 64.
+      int s_in_block = block_S,
+          t_in_block = static_cast<int>(threadIdx.x) - 64,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] =
+          (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
+    }
+
+    __syncthreads();
+
+    // The highest-numbered value in p_buf that we need (corresponding,
+    // of course, to p_grad), is:
+    //   p_buf[block_S - 1][block_T - 1],
+    // and the inner iteration number (i) on which we set this is the sum of
+    // these indexes, i.e. (block_S - 1) + (block_T - 1).
+    bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
+                           t_block_begin + block_T == t_end + 1);
+
+    int first_iter = block_S + block_T - 2;
+    if (is_final_block) {
+      // The following statement corresponds to:
+      //   p_grad[b][s_end][t_end] = ans_grad[b]
+      // Normally this element of p_buf would be set by the first iteration of
+      // the loop below, so if it's set this way we have to decrement
+      // first_iter to prevent it from being overwritten.
+      p_buf[block_S - 1][block_T - 1] = ans_grad[b];
+      --first_iter;
+    }
+
+    {
+      int s = threadIdx.x;
+      for (int i = first_iter; i >= 0; --i) {
+        __syncwarp();
+        int t = i - s;
+        if (s < block_S &&
+            static_cast<unsigned int>(t) <
+                static_cast<unsigned int>(block_T)) {
+          // The following statement is really operating on the gradients;
+          // it corresponds, with offsets of s_block_begin and t_block_begin
+          // on the indexes, to (eq. 6) defined above, i.e.:
+          //   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
+          //                     p_grad[b][s][t + 1] * yderiv[b][s][t]
+          p_buf[s][t] =
+              (p_buf[s + 1][t] * px_buf[s][t] + p_buf[s][t + 1] * py_buf[s][t]);
+        }
+      }
+    }
+
+    __syncthreads();
+
+    // Write out p_grad, px_grad and py_grad.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      // s_end and t_end are the one-past-the-end of the (x,y) sequences, but
+      // the one-past-the-end element of p_grad would be (s_end + 1, t_end +
+      // 1).
+      if (t <= t_end && s <= s_end) {
+        p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
+
+        if (s < s_end) {  // write px_grad, which is of shape [B][S][T + 1]
+          // From (eq. 7):
+          //   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
+          px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
+                              px_buf[s_in_block][t_in_block]);
+        }
+        if (t < t_end) {  // write py_grad, which is of shape [B][S + 1][T]
+          // from (eq. 8):
+          //   py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]
+          py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
+                              py_buf[s_in_block][t_in_block]);
+        }
+      }
+    }
+
+    if (threadIdx.x == 0 && s_block_begin == s_begin &&
+        t_block_begin == t_begin && overwrite_ans_grad)
+      ans_grad[b] = p_buf[0][0];
+  }
+}
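
Both kernels are launched once per anti-diagonal of blocks, so that every
block launched in one iteration depends only on blocks computed in earlier
iterations.  The schedule that the host-side loops below implement can be
sketched in Python like this (`block_schedule` is a hypothetical helper for
illustration, not part of k2):

    def block_schedule(S, T, BLOCK_SIZE=32):
        num_s_blocks = S // BLOCK_SIZE + 1
        num_t_blocks = T // BLOCK_SIZE + 1
        for it in range(num_s_blocks + num_t_blocks - 1):
            # Blocks processed together satisfy s_block + t_block == it,
            # i.e. they lie on one anti-diagonal of the (s, t) block grid.
            yield [(s, it - s)
                   for s in range(min(it + 1, num_s_blocks))
                   if it - s < num_t_blocks]

The forward pass consumes these iterations in increasing order and the
backward pass in decreasing order.
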
+
+// forward of mutual_information.  See """...""" comment of
+// `mutual_information` in mutual_information.py for documentation of the
+// behavior of this function.
+torch::Tensor MutualInformationCuda(torch::Tensor px, torch::Tensor py,
+                                    torch::optional<torch::Tensor> opt_boundary,
+                                    torch::Tensor p) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(
+      px.device().is_cuda() && py.device().is_cuda() && p.device().is_cuda(),
+      "inputs must be CUDA tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  auto boundary = opt_boundary.value_or(
+      torch::tensor({0, 0, S, T},
+                    torch::dtype(torch::kInt64).device(px.device()))
+          .reshape({1, 4})
+          .expand({B, 4}));
+  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
+  TORCH_CHECK(boundary.device().is_cuda() &&
+              boundary.dtype() == torch::kInt64);
+
+  torch::Tensor ans = torch::empty({B}, opts);
+
+  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
+  //  (however, num_threads may not be less than 128).
+  const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;
+
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  //   (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            num_t_blocks = T / BLOCK_SIZE + 1,
+            num_iters = num_s_blocks + num_t_blocks - 1;
+
+  AT_DISPATCH_FLOATING_TYPES(
+      px.scalar_type(), "mutual_information_cuda_stub", ([&] {
+        for (int iter = 0; iter < num_iters; ++iter) {
+          mutual_information_kernel<scalar_t, BLOCK_SIZE>
+              <<<num_blocks, num_threads, 0,
+                 c10::cuda::getCurrentCUDAStream()>>>(
+                  px.packed_accessor32<scalar_t, 3>(),
+                  py.packed_accessor32<scalar_t, 3>(),
+                  p.packed_accessor32<scalar_t, 3>(),
+                  boundary.packed_accessor32<int64_t, 2>(),
+                  ans.packed_accessor32<scalar_t, 1>(), iter);
+        }
+      }));
+  return ans;
+}
+
+// backward of mutual_information; returns (grad_px, grad_py)
+// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
+// should be identical to the original ans_grad if the computation worked
+// as it should.
+std::vector<torch::Tensor> MutualInformationBackwardCuda(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> opt_boundary, torch::Tensor p,
+    torch::Tensor ans_grad, bool overwrite_ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
+
+  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
+                  p.device().is_cuda() && ans_grad.device().is_cuda(),
+              "inputs must be CUDA tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
+
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  auto boundary = opt_boundary.value_or(
+      torch::tensor({0, 0, S, T},
+                    torch::dtype(torch::kInt64).device(px.device()))
+          .reshape({1, 4})
+          .expand({B, 4}));
+  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
+  TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64);
+  TORCH_CHECK(ans_grad.size(0) == B);
+
+  bool has_boundary = opt_boundary.has_value();
+
+  torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
+                px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts)
+                                        : torch::empty({B, S, T + 1}, opts)),
+                py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
+                                        : torch::empty({B, S + 1, T}, opts));
+
+  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
+  // (however, num_threads may not be less than 128).
+  const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;
+
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            num_t_blocks = T / BLOCK_SIZE + 1,
+            num_iters = num_s_blocks + num_t_blocks - 1;
+
+  AT_DISPATCH_FLOATING_TYPES(
+      px.scalar_type(), "mutual_information_backward_stub", ([&] {
+        for (int iter = num_iters - 1; iter >= 0; --iter) {
+          mutual_information_backward_kernel<scalar_t, BLOCK_SIZE>
+              <<<num_blocks, num_threads>>>(
+                  px.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
+                  py.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
+                  p.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
+                  ans_grad.packed_accessor32<scalar_t, 1,
+                                             torch::RestrictPtrTraits>(),
+                  p_grad.packed_accessor32<scalar_t, 3,
+                                           torch::RestrictPtrTraits>(),
+                  px_grad.packed_accessor32<scalar_t, 3,
+                                            torch::RestrictPtrTraits>(),
+                  py_grad.packed_accessor32<scalar_t, 3,
+                                            torch::RestrictPtrTraits>(),
+                  boundary.packed_accessor32<int64_t, 2,
+                                             torch::RestrictPtrTraits>(),
+                  iter, overwrite_ans_grad);
+        }
+      }));
+  return std::vector<torch::Tensor>({px_grad, py_grad});
+}
+}  // namespace k2
diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py
index bff893967..0f9d4e51b 100644
--- a/k2/python/k2/__init__.py
+++ b/k2/python/k2/__init__.py
@@ -52,13 +52,26 @@ from .fsa_algo import top_sort
 from .fsa_algo import union
 from .fsa_properties import to_str as properties_to_str
+from .mutual_information import joint_mutual_information_recursion
+from .mutual_information import mutual_information_recursion
 from .nbest import Nbest
 from .ops import cat
 from .ops import compose_arc_maps
 from .ops import index_add
 from .ops import index_fsa
 from .ops import index_select
-#
+
+from .rnnt_loss import do_rnnt_pruning
+from .rnnt_loss import get_rnnt_logprobs
+from .rnnt_loss import get_rnnt_logprobs_joint
+from .rnnt_loss import get_rnnt_logprobs_pruned
+from .rnnt_loss import get_rnnt_logprobs_smoothed
+from .rnnt_loss import get_rnnt_prune_ranges
+from .rnnt_loss import rnnt_loss
+from .rnnt_loss import rnnt_loss_pruned
+from .rnnt_loss import rnnt_loss_simple
+from .rnnt_loss import rnnt_loss_smoothed
+
 from .symbol_table import SymbolTable
 from .utils import create_fsa_vec
 from .utils import create_sparse
diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py
new file mode 100644
index 000000000..00123f806
--- /dev/null
+++ b/k2/python/k2/mutual_information.py
@@ -0,0 +1,302 @@
+# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import _k2
+from torch import Tensor
+from typing import Tuple, Optional, Sequence, Union
+
+
+class MutualInformationRecursionFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        px: torch.Tensor,
+        py: torch.Tensor,
+        boundary: Optional[torch.Tensor] = None,
+        return_grad: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        (B, S, T1) = px.shape
+        T = T1 - 1
+        assert py.shape == (B, S + 1, T)
+        if boundary is not None:
+            assert boundary.shape == (B, 4)
+
+        # p is a tensor of shape (B, S + 1, T + 1), where p[s][t] is
+        # the mutual information of the pair of subsequences of x and y that
+        # are of length s and t respectively.  
p[0][0] will be 0.0 and p[S][T]
+        # is the mutual information of the entire pair of sequences,
+        # i.e. of lengths S and T respectively.
+        # It is computed as follows (in C++ and CUDA):
+        #   p[b,0,0] = 0.0
+        #   p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+        #                      p[b,s,t-1] + py[b,s,t-1])
+        #              if s > 0 or t > 0,
+        #              treating values with any -1 index as -infinity.
+        # .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
+
+        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
+
+        ans = _k2.mutual_information_forward(px, py, boundary, p)
+
+        px_grad, py_grad = None, None
+        if return_grad or px.requires_grad or py.requires_grad:
+            ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
+            (px_grad, py_grad) = _k2.mutual_information_backward(
+                px, py, boundary, p, ans_grad
+            )
+            ctx.save_for_backward(px_grad, py_grad)
+        return ans, px_grad, py_grad
+
+    @staticmethod
+    def backward(
+        ctx, ans_grad: Tensor, dummy_px_grad: Tensor, dummy_py_grad: Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
+        (px_grad, py_grad) = ctx.saved_tensors
+        (B,) = ans_grad.shape
+        ans_grad = ans_grad.reshape(B, 1, 1)  # (B, 1, 1)
+        px_grad *= ans_grad
+        py_grad *= ans_grad
+        return (px_grad, py_grad, None, None)
+
+
+def mutual_information_recursion(
+    px: Tensor,
+    py: Tensor,
+    boundary: Optional[Tensor] = None,
+    return_grad: bool = False,
+) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
+    """A recursion that is useful in computing mutual information between two
+    sequences of real vectors, but may be useful more generally in
+    sequence-to-sequence tasks where monotonic alignment between pairs of
+    sequences is desired.  The definitions of the arguments are the ones that
+    would be used when computing this type of mutual information, but you can
+    also view them as arbitrary quantities and just make use of the formula
+    computed by this function.
+
+    Args:
+      px:
+        A torch.Tensor of some floating point type, with shape [B][S][T+1],
+        where B is the batch size, S is the length of the 'x' sequence
+        (including representations of EOS symbols but not BOS symbols), and
+        T is the length of the 'y' sequence (including representations of
+        EOS symbols but not BOS symbols).  In the mutual information
+        application, px[b][s][t] would represent the following log odds
+        ratio; ignoring the b index on the right to make the notation more
+        compact,
+
+          px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
+
+        This expression also implicitly includes the log-probability of
+        choosing to generate an x value as opposed to a y value.  In
+        practice it might be computed as a + b, where a is the log
+        probability of choosing to extend the sequence of length (s,t)
+        with an x as opposed to a y value; and b might in practice be
+        of the form:
+            log(N exp f(x_s, y_{t-1}) / sum_t'  exp f(x_s, y_t'))
+        where N is the number of terms that the sum over t' includes, which
+        might include some or all of the other sequences as well as this one.
+
+        Note: we don't require px and py to be contiguous, but the
+        code assumes for optimization purposes that the T axis has
+        stride 1.
+
+      py:
+        A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
+        representing
+
+          py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
+
+        This function does not treat x and y differently; the only difference
+        is that for optimization purposes we assume the last axis
+        (the t axis) has stride of 1; this is true if px and py are
+        contiguous.
+
+      boundary:
+        If supplied, a torch.LongTensor of shape [B][4], where each
+        row contains [s_begin, t_begin, s_end, t_end],
+        with 0 <= s_begin <= s_end <= S and 0 <= t_begin <= t_end <= T
+        (this implies that empty sequences are allowed).
+        If not supplied, the values [0, 0, S, T] will be assumed.
+        These are the beginning and one-past-the-last positions in the x and
+        y sequences respectively, and can be used if not all sequences are
+        of the same length.
+
+      return_grad:
+        Whether to return grads of px and py; these grads, which can be
+        interpreted as occupation probabilities, are the output of the
+        backward pass with a `fake gradient` input (all ones).  This is
+        useful to implement the pruned version of the rnnt loss.
+
+    Returns:
+      Returns a torch.Tensor of shape [B], containing the log of the mutual
+      information between the b'th pair of sequences.  This is defined by
+      the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
+      representing a mutual information between sub-sequences of lengths s
+      and t:
+
+             p[b,0,0] = 0.0
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])
+                        (if s > 0 or t > 0)
+
+      where we handle edge cases by treating quantities with negative indexes
+      as -infinity.  The extension to cases where the boundaries are specified
+      should be obvious; it just works on shorter sequences with offsets into
+      px and py.
+    """
+    assert px.ndim == 3
+    B, S, T1 = px.shape
+    T = T1 - 1
+    assert py.shape == (B, S + 1, T)
+    assert px.dtype == py.dtype
+    if boundary is not None:
+        assert boundary.dtype == torch.int64
+        assert boundary.shape == (B, 4)
+        for s_begin, t_begin, s_end, t_end in boundary.tolist():
+            assert 0 <= s_begin <= s_end <= S
+            assert 0 <= t_begin <= t_end <= T
+    # The following assertions are for efficiency
+    assert px.is_contiguous()
+    assert py.is_contiguous()
+
+    m, px_grad, py_grad = MutualInformationRecursionFunction.apply(
+        px, py, boundary, return_grad
+    )
+    return (m, (px_grad, py_grad)) if return_grad else m
+
+
+def _inner_product(a: Tensor, b: Tensor) -> Tensor:
+    """
+    Does inner product on the last dimension, with expected broadcasting,
+    i.e. equivalent to (a * b).sum(dim=-1)
+    without creating a large temporary.
+    """
+    assert a.shape[-1] == b.shape[-1]  # The last dim must be equal
+    a = a.unsqueeze(-2)  # (..., 1, K)
+    b = b.unsqueeze(-1)  # (..., K, 1)
+    c = torch.matmul(a, b)  # (..., 1, 1)
+    return c.squeeze(-1).squeeze(-1)
+
+
+def joint_mutual_information_recursion(
+    px: Sequence[Tensor],
+    py: Sequence[Tensor],
+    boundary: Optional[Tensor] = None,
+) -> Sequence[Tensor]:
+    """A recursion that is useful for modifications of RNN-T and similar loss
+    functions, where the recursion probabilities have a number of terms and you
+    want them reported separately.  See mutual_information_recursion() for more
+    documentation of the basic aspects of this.
+
+    Args:
+      px:
+        a sequence of Tensors, each of the same shape [B][S][T+1]
+      py:
+        a sequence of Tensors, each of the same shape [B][S+1][T];
+        the sequence must be the same length as px.
+      boundary:
+        optionally, a LongTensor of shape [B][4] containing rows
+        [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
+        and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
+        These are the beginning and one-past-the-last positions in the x
+        and y sequences respectively, and can be used if not all
+        sequences are of the same length.
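+
+    Example (a sketch, where px1, px2 and py1, py2 are tensors of the
+    shapes described above):
+
+        parts = joint_mutual_information_recursion(
+            (px1, px2), (py1, py2), boundary)        # (2, B)
+        total = mutual_information_recursion(
+            px1 + px2, py1 + py2, boundary)          # (B,)
+        # parts.sum(dim=0) should agree with total up to roundoff.
+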
+    Returns:
+      a Tensor of shape (len(px), B),
+      whose sum over dim 0 is the total log-prob of the recursion mentioned
+      below, per sequence.  The first element of the sequence of length len(px)
+      is "special", in that it has an offset term reflecting the difference
+      between sum-of-log and log-of-sum; for more interpretable loss values,
+      the "main" part of your loss function should be first.
+
+      The recursion below applies if boundary == None, in which case it
+      defaults to (0, 0, S, T); here px_sum, py_sum are the elementwise sums
+      over the supplied px and py tensors:
+
+          p = tensor of shape (B, S+1, T+1), containing -infinity
+          p[b,0,0] = 0.0
+          # do the following in loop over s and t:
+          p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
+                             p[b,s,t-1] + py_sum[b,s,t-1])
+                     (if s > 0 or t > 0)
+          return p[:][S][T]
+
+    This function lets you implement the above recursion efficiently, except
+    that it gives you a breakdown of the contribution from all the elements of
+    px and py separately.  As noted above, the first element of the
+    sequence is "special".
+    """
+    N = len(px)
+    assert len(py) == N and N > 0
+    B, S, T1 = px[0].shape
+    T = T1 - 1
+    assert py[0].shape == (B, S + 1, T)
+    assert px[0].dtype == py[0].dtype
+
+    px_cat = torch.stack(px, dim=0)  # (N, B, S, T+1)
+    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
+    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
+    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)
+
+    if boundary is not None:
+        assert boundary.dtype == torch.int64
+        assert boundary.shape == (B, 4)
+        for s_begin, t_begin, s_end, t_end in boundary.tolist():
+            assert 0 <= s_begin <= s_end <= S
+            assert 0 <= t_begin <= t_end <= T
+
+    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
+    # The following assertions are for efficiency
+    assert px_tot.ndim == 3
+    assert py_tot.ndim == 3
+
+    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
+
+    # note, tot_probs is without grad.
+    tot_probs = _k2.mutual_information_forward(px_tot, py_tot, boundary, p)
+
+    # this is a kind of "fake gradient" that we use, in effect to compute
+    # occupation probabilities.  The backprop will work regardless of the
+    # actual derivative w.r.t. the total probs.
+    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
+
+    (px_grad, py_grad) = _k2.mutual_information_backward(
+        px_tot, py_tot, boundary, p, ans_grad
+    )
+
+    px_grad = px_grad.reshape(1, B, -1)
+    py_grad = py_grad.reshape(1, B, -1)
+    px_cat = px_cat.reshape(N, B, -1)
+    py_cat = py_cat.reshape(N, B, -1)
+    # get rid of -inf, which would generate nan on product with 0
+    px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
+    py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)
+
+    x_prods = _inner_product(px_grad, px_cat)  # (N, B)
+    y_prods = _inner_product(py_grad, py_cat)  # (N, B)
+
+    # If all the occupation counts were exactly 1.0 (i.e. no partial counts),
+    # "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
+    # will be more positive due to the difference between log-of-sum and
+    # sum-of-log
+    prods = x_prods + y_prods  # (N, B)
+    with torch.no_grad():
+        offset = tot_probs - prods.sum(dim=0)  # (B,)
+    prods[0] += offset
+    return prods  # (N, B)
diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
new file mode 100644
index 000000000..9823fac87
--- /dev/null
+++ b/k2/python/k2/rnnt_loss.py
@@ -0,0 +1,1012 @@
+# Copyright      2021  Xiaomi Corp.       (author: Daniel Povey, Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import k2
+import torch
+from torch import Tensor
+from typing import Optional, Tuple, Union
+from .mutual_information import mutual_information_recursion
+
+
+def get_rnnt_logprobs(
+    lm: Tensor,
+    am: Tensor,
+    symbols: Tensor,
+    termination_symbol: int,
+    boundary: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+    """
+    Reduces the RNN-T problem (the simple case, where the joiner network is
+    just addition) to a compact, standard form that can then be given
+    (with boundaries) to mutual_information_recursion().
+    This function is called from rnnt_loss_simple(), but may be useful for
+    other purposes.
+
+    Args:
+      lm:
+        Language model part of un-normalized logprobs of symbols, to be added to
+        acoustic model part before normalizing.  Of shape:
+           [B][S+1][C]
+        where B is the batch size, S is the maximum sequence length of
+        the symbol sequence, possibly including the EOS symbol; and
+        C is the size of the symbol vocabulary, including the
+        termination/next-frame symbol.
+        Conceptually, lm[b][s] is a vector of length [C] representing the
+        "language model" part of the un-normalized logprobs of symbols,
+        given all symbols *earlier than* s in the sequence.  The reason
+        we still need this for position S is that we may still be emitting
+        the termination/next-frame symbol at this point.
+      am:
+        Acoustic-model part of un-normalized logprobs of symbols, to be added
+        to language-model part before normalizing.  Of shape:
+           [B][T][C]
+        where B is the batch size, T is the maximum sequence length of
+        the acoustic sequences (in frames); and C is the size of the symbol
+        vocabulary, including the termination/next-frame symbol.  It reflects
+        the "acoustic" part of the probability of any given symbol appearing
+        next on this frame.
+      symbols:
+        A LongTensor of shape [B][S], containing the symbols at each position
+        of the sequence, possibly including EOS
+      termination_symbol:
+        The identity of the termination symbol, must be in {0..C-1}
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+    Returns:
+      (px, py) (the names are quite arbitrary).
+         px: logprobs, of shape [B][S][T+1]
+         py: logprobs, of shape [B][S+1][T]
+      in the recursion:
+         p[b,0,0] = 0.0
+         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                            p[b,s,t-1] + py[b,s,t-1])
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences
+      of length s and t respectively.  px[b][s][t] represents the
+      probability of extending the subsequences of length (s,t) by one in
+      the s direction, given the particular symbol, and py[b][s][t]
+      represents the probability of extending the subsequences of length
+      (s,t) by one in the t direction,
+      i.e. of emitting the termination/next-frame symbol.
+ + px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame + we cannot emit any symbols. This is simply a way of incorporating + the probability of the termination symbol on the last frame. + """ + assert lm.ndim == 3 + assert am.ndim == 3 + assert lm.shape[0] == am.shape[0] + assert lm.shape[2] == am.shape[2] + + (B, T, C) = am.shape + S = lm.shape[1] - 1 + assert symbols.shape == (B, S) + + # subtracting am_max and lm_max is to ensure the probs are in a good range + # to do exp() without causing underflow or overflow. + am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1] + lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1] + am_probs = (am - am_max).exp() + lm_probs = (lm - lm_max).exp() + # normalizers: [B][S+1][T] + normalizers = ( + torch.matmul(lm_probs, am_probs.transpose(1, 2)) + + torch.finfo(am_probs.dtype).tiny + ).log() + + # add lm_max and am_max to normalizers, to make it as if we had not + # subtracted am_max and lm_max above. + normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T] + + # px is the probs of the actual symbols.. + px_am = torch.gather( + am.unsqueeze(1).expand(B, S, T, C), + dim=3, + index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1), + ).squeeze( + -1 + ) # [B][S][T] + + px_am = torch.cat( + ( + px_am, + torch.full( + (B, S, 1), float("-inf"), device=px_am.device, dtype=px_am.dtype + ), + ), + dim=2, + ) # now: [B][S][T+1], index [:,:,T] has -inf.. + + if boundary is not None: + assert boundary.shape == (B, 4) + mask = ( + torch.arange(0, T + 1, device=px_am.device) + .reshape(1, T + 1) + .expand(B, T + 1) + ) + mask = mask < boundary[:, 3].reshape(B, 1) + mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1) + px_am = torch.where( + mask, + px_am, + torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device), + ) + + px_lm = torch.gather( + lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] + + px = px_am + px_lm # [B][S][T+1], last slice with indexes out of + # boundary is -inf + px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] + + # py is the probs of termination symbols, of shape [B][S+1][T] + py_am = am[:, :, termination_symbol].unsqueeze(1) # [B][1][T] + py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1] + py = py_am + py_lm - normalizers + + return (px, py) + + +def rnnt_loss_simple( + lm: Tensor, + am: Tensor, + symbols: Tensor, + termination_symbol: int, + boundary: Optional[Tensor] = None, + return_grad: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: + """A simple case of the RNN-T loss, where the 'joiner' network is just + addition. Returns negated total loss value. + + Args: + lm: + language-model part of unnormalized log-probs of symbols, with shape + (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes + am: + acoustic-model part of unnormalized log-probs of symbols, with shape + (B, T, C), i.e. batch, frame, num_classes + symbols: + the symbol sequences, a LongTensor of shape [B][S], and elements in + {0..C-1}. + termination_symbol: + the termination symbol, with 0 <= termination_symbol < C + boundary: + a LongTensor of shape [B, 4] with elements interpreted as + [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as + [0, 0, S, T] + if boundary is not supplied. + Most likely you will want begin_symbol and begin_frame to be zero. 
+      return_grad:
+        Whether to return grads of px and py; these grads, which can be
+        interpreted as occupation probabilities, are the output of the
+        backward pass with a `fake gradient` input (all ones).  This is
+        useful to implement the pruned version of the rnnt loss.
+    Returns:
+      If return_grad is False, returns a Tensor of shape (B,), containing the
+      NEGATED total RNN-T loss values for each element of the batch
+      (like log-probs of sequences).
+      If return_grad is True, the grads of px and py, which are the output of
+      the backward pass with a `fake gradient` input, will be returned too,
+      and the returned value will be a tuple like (loss, (px_grad, py_grad)).
+    """
+    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol, boundary)
+    return mutual_information_recursion(px, py, boundary, return_grad)
+
+
+def get_rnnt_logprobs_joint(
+    joint: Tensor,
+    symbols: Tensor,
+    termination_symbol: int,
+    boundary: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Reduces the RNN-T problem to a compact, standard form that can then be
+    given (with boundaries) to mutual_information_recursion().
+    This function is called from rnnt_loss().
+
+    Args:
+      joint:
+        The output of joiner network, with shape (B, T, S + 1, C),
+        i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
+      symbols:
+        A LongTensor of shape [B][S], containing the symbols at each position
+        of the sequence, possibly including EOS
+      termination_symbol:
+        The identity of the termination symbol, must be in {0..C-1}
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+    Returns:
+      (px, py) (the names are quite arbitrary).
+         px: logprobs, of shape [B][S][T+1]
+         py: logprobs, of shape [B][S+1][T]
+      in the recursion:
+         p[b,0,0] = 0.0
+         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                            p[b,s,t-1] + py[b,s,t-1])
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
+      length s and t respectively.  px[b][s][t] represents the probability of
+      extending the subsequences of length (s,t) by one in the s direction,
+      given the particular symbol, and py[b][s][t] represents the probability
+      of extending the subsequences of length (s,t) by one in the t direction,
+      i.e. of emitting the termination/next-frame symbol.
+
+      px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
+      we cannot emit any symbols.  This is simply a way of incorporating
+      the probability of the termination symbol on the last frame.
+    """
+    assert joint.ndim == 4
+    (B, T, S1, C) = joint.shape
+    S = S1 - 1
+    assert symbols.shape == (B, S)
+
+    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = normalizers.permute((0, 2, 1))
+
+    px = torch.gather(
+        joint, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
+    ).squeeze(-1)
+    px = px.permute((0, 2, 1))
+    px = torch.cat(
+        (
+            px,
+            torch.full(
+                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+            ),
+        ),
+        dim=2,
+    )  # now: [B][S][T+1], index [:,:,T] has -inf..
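+
+    # (When boundary is supplied, the mask below sets px to -inf for frames
+    # t >= end_frame of each sequence, so that no symbol can be emitted at or
+    # after a sequence's one-past-the-last frame.)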
+
+    if boundary is not None:
+        assert boundary.shape == (B, 4)
+        mask = (
+            torch.arange(0, T + 1, device=px.device)
+            .reshape(1, T + 1)
+            .expand(B, T + 1)
+        )
+        mask = mask < boundary[:, 3].reshape(B, 1)
+        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
+        px = torch.where(
+            mask,
+            px,
+            torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
+        )
+
+    px[:, :, :T] -= normalizers[:, :S, :]
+
+    py = (
+        joint[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
+    )  # [B][S+1][T]
+    py -= normalizers
+    px = px.contiguous()
+    py = py.contiguous()
+
+    return (px, py)
+
+
+def rnnt_loss(
+    joint: Tensor,
+    symbols: Tensor,
+    termination_symbol: int,
+    boundary: Optional[Tensor] = None,
+) -> Tensor:
+    """A normal RNN-T loss, which takes a 'joiner' network output as input,
+    i.e. a 4-dimensional tensor.
+
+    Args:
+      joint:
+        The output of joiner network, with shape (B, T, S + 1, C),
+        i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
+      symbols:
+        The symbol sequences, a LongTensor of shape [B][S], and elements
+        in {0..C-1}.
+      termination_symbol:
+        the termination symbol, with 0 <= termination_symbol < C
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+
+    Returns:
+      A Tensor of shape (B,), containing the total RNN-T loss values for each
+      element of the batch (like log-probs of sequences).
+    """
+    px, py = get_rnnt_logprobs_joint(
+        joint, symbols, termination_symbol, boundary
+    )
+    return mutual_information_recursion(px, py, boundary)
+
+
+def _adjust_pruning_lower_bound(
+    s_begin: torch.Tensor, s_range: int
+) -> torch.Tensor:
+    """Adjust s_begin (the pruning lower bound) to make it satisfy the
+    following constraints:
+
+      - monotonically increasing, i.e. s_begin[i] <= s_begin[i + 1]
+      - start with symbol 0 at the first frame.
+      - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
+        any symbols.
+
+    To make it monotonically increasing, we can use the
+    `monotonic_lower_bound` function in k2, which guarantees
+    `s_begin[i] <= s_begin[i + 1]`.  The main idea is: traverse the array in
+    reverse order and update the elements by
+    `min_value = min(s_begin[i], min_value)`, with the initial `min_value`
+    set to `inf`.
+
+    The method we use to satisfy the `s_begin[i + 1] - s_begin[i] < s_range`
+    constraint is a little tricky.  We first transform `s_begin` with
+    `s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T))`,
+    then we make the transformed `s_begin` monotonically increasing; after
+    that, we transform `s_begin` back with the same formula as the previous
+    transformation.  The idea is: if we want to make
+    `s_begin[i + 1] - s_begin[i] < s_range` we only need to make
+    `-(s_begin[i] - i * (s_range - 1))` a non-decreasing array.  Proof:
+
+      -(s_begin[i] - i * (s_range - 1)) <= -(s_begin[i + 1] - (i + 1) * (s_range - 1))
+      -s_begin[i] <= -s_begin[i + 1] + (i + 1) * (s_range - 1) - i * (s_range - 1)
+      -s_begin[i] <= -s_begin[i + 1] + s_range - 1
+      s_begin[i + 1] - s_begin[i] <= s_range - 1
+      s_begin[i + 1] - s_begin[i] < s_range
+
+    The above transformation cannot guarantee the start symbol to be 0, so we
+    have to set all the elements that are less than 0 to 0 before transforming
+    back the `s_begin`.
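+
+    A worked (hypothetical) example, with T = 4 and s_range = 2.  Take
+    s_begin = [0, 2, 2, 3], which skips a symbol (a diff of 2 >= s_range);
+    it is already monotonic, so the first `monotonic_lower_bound` call
+    leaves it unchanged.  With (s_range - 1) * arange(T) = [0, 1, 2, 3]:
+
+        transformed      = [0, -1, 0, 0]    # arange - s_begin
+        monotonic        = [-1, -1, 0, 0]   # after monotonic_lower_bound
+        clamped at zero  = [0, 0, 0, 0]
+        transformed back = [0, 1, 2, 3]     # satisfies all three constraints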
+ """ + # s_begin (B, T) + (B, T) = s_begin.shape + s_begin = k2.monotonic_lower_bound(s_begin) + # do the magic transformation + s_begin = -( + s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) + ) + # make the transformed tensor to be non-decreasing + s_begin = k2.monotonic_lower_bound(s_begin) + # make start symbol to be zero. + s_begin = torch.where(s_begin < 0, 0, s_begin) + # do the magic transformation again to recover s_begin + s_begin = -( + s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) + ) + return s_begin + + +def get_rnnt_prune_ranges( + px_grad: torch.Tensor, + py_grad: torch.Tensor, + boundary: torch.Tensor, + s_range: int, +) -> torch.Tensor: + """Get the pruning ranges of normal rnnt loss according to the grads + of px and py returned by mutual_information_recursion. + + For each sequence with T frames, we will generate a tensor with the shape of + (T, s_range) containing the information that which symbols will be token + into consideration for each frame. For example, here is a sequence with 10 + frames and the corresponding symbols are `[A B C D E F]`, if the s_range + equals 3, one possible ranges tensor will be: + + [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3], + [1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]] + + which means we only consider `[A B C]` at frame 0, 1, 2, 3, and `[B C D]` + at frame 4, 5, 6, `[D E F]` at frame 7, 8, 9. + + We can only consider limited number of symbols because frames and symbols + are monotonic aligned, theoretically it can only generate particular range + of symbols given a particular frame. + + Note: + For the generated tensor ranges, ranges[:, 0] is a monotonic increasing + tensor from 0 to `len(symbols)` and it satisfies + `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any + symbols. + + Args: + px_grad: + The gradient of px, see docs in `mutual_information_recursion` for more + details of px. + py_grad: + The gradient of py, see docs in `mutual_information_recursion` for more + details of py. + boundary: + a LongTensor of shape [B, 4] with elements interpreted as + [begin_symbol, begin_frame, end_symbol, end_frame] + s_range: + How many symbols to keep for each frame. + Returns: + A tensor contains the kept symbols indexes for each frame, with shape + (B, T, s_range). + """ + (B, S, T1) = px_grad.shape + T = T1 - 1 + assert py_grad.shape == (B, S + 1, T) + assert boundary.shape == (B, 4) + assert s_range >= 1 + if s_range > S: + s_range = S + + px_pad = torch.zeros( + (B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device + ) + py_pad = torch.zeros( + (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device + ) + tot_grad = torch.cat((px_grad, px_pad), dim=1) + torch.cat( + (py_grad, py_pad), dim=2 + ) # (B, S + 1, T + 1) + + tot_grad = torch.cat( + ( + torch.zeros( + (B, 1, T + 1), dtype=tot_grad.dtype, device=tot_grad.device + ), + tot_grad, + ), + dim=1, + ) + tot_grad = torch.cumsum(tot_grad, dim=1) + diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :] + s_begin = torch.argmax(diff_grad, dim=1) + s_begin = s_begin[:, :T] + + # handle the values of s_begin in padding positions. 
+    # set the s_begin in padding positions to `len(symbols) - s_range + 1`
+    mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
+    mask = mask < boundary[:, 3].reshape(B, 1)
+
+    s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
+    # handle the cases when `len(symbols) < s_range`
+    s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)
+
+    s_begin = torch.where(mask, s_begin, s_begin_padding)
+
+    # adjust the lower bound to make it satisfy the constraints; see docs in
+    # `_adjust_pruning_lower_bound` for more details of these constraints.
+    s_begin = _adjust_pruning_lower_bound(s_begin, s_range)
+    ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
+        s_range, device=px_grad.device
+    )
+    return ranges
+
+
+def do_rnnt_pruning(
+    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Prune the encoder (am) output and the prediction network (lm) output
+    of RNN-T.
+
+    Args:
+      am:
+        The encoder output, with shape (B, T, C)
+      lm:
+        The prediction network output, with shape (B, S + 1, C)
+      ranges:
+        A tensor containing the symbol indexes for each frame that we want to
+        keep.  Its shape is (B, T, s_range), see the docs in
+        `get_rnnt_prune_ranges` for more details of this tensor.
+
+    Returns:
+      Return the pruned am and lm with shape (B, T, s_range, C)
+    """
+    # am (B, T, C)
+    # lm (B, S + 1, C)
+    # ranges (B, T, s_range)
+    assert ranges.shape[0] == am.shape[0]
+    assert ranges.shape[0] == lm.shape[0]
+    assert am.shape[1] == ranges.shape[1]
+    (B, T, s_range) = ranges.shape
+    (B, S1, C) = lm.shape
+    S = S1 - 1
+
+    # (B, T, s_range, C)
+    am_pruning = am.unsqueeze(2).expand((B, T, s_range, C))
+
+    # (B, T, s_range, C)
+    lm_pruning = torch.gather(
+        lm.unsqueeze(1).expand((B, T, S + 1, C)),
+        dim=2,
+        index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
+    )
+    return am_pruning, lm_pruning
+
+
+def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
+    """Roll the tensor with different shifts for each row.
+
+    Note:
+      We assume src is a 3-dimensional tensor and we roll its last dimension.
+
+    Example:
+
+      >>> src = torch.arange(15).reshape((1,3,5))
+      >>> src
+      tensor([[[ 0,  1,  2,  3,  4],
+               [ 5,  6,  7,  8,  9],
+               [10, 11, 12, 13, 14]]])
+      >>> shift = torch.tensor([[1, 2, 3]])
+      >>> shift
+      tensor([[1, 2, 3]])
+      >>> _roll_by_shifts(src, shift)
+      tensor([[[ 4,  0,  1,  2,  3],
+               [ 8,  9,  5,  6,  7],
+               [12, 13, 14, 10, 11]]])
+    """
+    assert src.dim() == 3
+    (B, T, S) = src.shape
+    assert shifts.shape == (B, T)
+
+    index = (
+        torch.arange(S, device=src.device)
+        .view((1, S))
+        .repeat((T, 1))
+        .repeat((B, 1, 1))
+    )
+    index = (index - shifts.reshape(B, T, 1)) % S
+    return torch.gather(src, 2, index)
+
+
+def get_rnnt_logprobs_pruned(
+    joint: Tensor,
+    symbols: Tensor,
+    ranges: Tensor,
+    termination_symbol: int,
+    boundary: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Construct px, py for mutual_information_recursion with pruned output.
+
+    Args:
+      joint:
+        The pruned output of joiner network, with shape (B, T, s_range, C)
+      symbols:
+        The symbol sequences, a LongTensor of shape [B][S], and elements in
+        {0..C-1}.
+      ranges:
+        A tensor containing the symbol ids for each frame that we want to keep.
+      termination_symbol:
+        the termination symbol, with 0 <= termination_symbol < C
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T]
+        if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+    Returns:
+      Return the px (B, S, T + 1) and py (B, S + 1, T) needed by
+      mutual_information_recursion.
+    """
+    # joint (B, T, s_range, C)
+    # symbols (B, S)
+    # ranges (B, T, s_range)
+    assert joint.ndim == 4
+    (B, T, s_range, C) = joint.shape
+    assert ranges.shape == (B, T, s_range)
+    (B, S) = symbols.shape
+
+    normalizers = torch.logsumexp(joint, dim=3)
+
+    symbols_with_terminal = torch.cat(
+        (
+            symbols,
+            torch.tensor(
+                [termination_symbol] * B,
+                dtype=torch.int64,
+                device=symbols.device,
+            ).reshape((B, 1)),
+        ),
+        dim=1,
+    )
+
+    # (B, T, s_range)
+    pruning_symbols = torch.gather(
+        symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
+        dim=2,
+        index=ranges,
+    )
+
+    # (B, T, s_range)
+    px = torch.gather(
+        joint, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
+    ).squeeze(-1)
+    px = px - normalizers
+
+    # (B, T, S + 1), with indexes larger than s_range in dim 2 filled with -inf
+    px = torch.cat(
+        (
+            px,
+            torch.full(
+                (B, T, S + 1 - s_range),
+                float("-inf"),
+                device=px.device,
+                dtype=px.dtype,
+            ),
+        ),
+        dim=2,
+    )
+
+    # (B, T, S), with indexes out of s_range in dim 2 filled with -inf
+    px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]
+
+    px = px.permute((0, 2, 1))
+    px = torch.cat(
+        (
+            px,
+            torch.full(
+                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+            ),
+        ),
+        dim=2,
+    )  # now: [B][S][T+1], index [:,:,T] has -inf..
+
+    if boundary is not None:
+        assert boundary.shape == (B, 4)
+        mask = (
+            torch.arange(0, T + 1, device=px.device)
+            .reshape(1, T + 1)
+            .expand(B, T + 1)
+        )
+        mask = mask < boundary[:, 3].reshape(B, 1)
+        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
+        px = torch.where(
+            mask,
+            px,
+            torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
+        )
+
+    py = joint[:, :, :, termination_symbol]  # (B, T, s_range)
+    py = py - normalizers
+
+    # (B, T, S + 1), with indexes larger than s_range in dim 2 filled with -inf
+    py = torch.cat(
+        (
+            py,
+            torch.full(
+                (B, T, S + 1 - s_range),
+                float("-inf"),
+                device=py.device,
+                dtype=py.dtype,
+            ),
+        ),
+        dim=2,
+    )
+
+    # (B, T, S + 1), with indexes out of s_range in dim 2 filled with -inf
+    py = _roll_by_shifts(py, ranges[:, :, 0])
+    # (B, S + 1, T)
+    py = py.permute((0, 2, 1))
+
+    px = px.contiguous()
+    py = py.contiguous()
+    return (px, py)
+
+
+def rnnt_loss_pruned(
+    joint: Tensor,
+    symbols: Tensor,
+    ranges: Tensor,
+    termination_symbol: int,
+    boundary: Optional[Tensor] = None,
+) -> Tensor:
+    """An RNN-T loss with pruning, which uses a pruned 'joiner' network output
+    as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C),
+    where s_range is the number of symbols kept for each frame.
+
+    Args:
+      joint:
+        The pruned output of joiner network, with shape (B, T, s_range, C),
+        i.e. batch, time_seq_len, prune_range, num_classes
+      symbols:
+        A LongTensor of shape [B][S], containing the symbols at each position
+        of the sequence, possibly including EOS
+      ranges:
+        A tensor containing the symbol ids for each frame that we want to keep.
+      termination_symbol:
+        The identity of the termination symbol, must be in {0..C-1}
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+    Returns:
+      A Tensor of shape (B,), containing the total RNN-T loss values for each
+      element of the batch (like log-probs of sequences).
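+
+    A typical pruned-training pipeline looks roughly like the following
+    (a sketch; `joiner` and `s_range` stand for your own joiner network and
+    pruning width, neither of which is defined in this file):
+
+        # cheap "simple" loss; also get the occupation probabilities
+        simple_loss, (px_grad, py_grad) = rnnt_loss_simple(
+            lm, am, symbols, termination_symbol, boundary, return_grad=True)
+        # derive the pruning bounds from those gradients
+        ranges = get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range)
+        # evaluate the full joiner only inside the pruned ranges
+        am_pruned, lm_pruned = do_rnnt_pruning(am, lm, ranges)
+        logits = joiner(am_pruned, lm_pruned)  # (B, T, s_range, C)
+        pruned_loss = rnnt_loss_pruned(
+            logits, symbols, ranges, termination_symbol, boundary)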
+ """ + px, py = get_rnnt_logprobs_pruned( + joint, symbols, ranges, termination_symbol, boundary + ) + return mutual_information_recursion(px, py, boundary) + + +def get_rnnt_logprobs_smoothed( + lm: Tensor, + am: Tensor, + symbols: Tensor, + termination_symbol: int, + lm_only_scale: float = 0.1, + am_only_scale: float = 0.1, + boundary: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + """Reduces RNN-T problem (the simple case, where joiner network is just + addition), to a compact, standard form that can then be given + (with boundaries) to mutual_information_recursion(). + This version allows you to make the loss-function one of the form: + lm_only_scale * lm_probs + + am_only_scale * am_probs + + (1-lm_only_scale-am_only_scale) * combined_probs + where lm_probs and am_probs are the probabilities given the lm and acoustic + model independently. + + This function is called from + rnnt_loss_smoothed(), but may be useful for other purposes. + + Args: + lm: + Language model part of un-normalized logprobs of symbols, to be added to + acoustic model part before normalizing. Of shape: + [B][S+1][C] + where B is the batch size, S is the maximum sequence length of + the symbol sequence, possibly including the EOS symbol; and + C is size of the symbol vocabulary, including the termination/next-frame + symbol. + Conceptually, lm[b][s] is a vector of length [C] representing the + "language model" part of the un-normalized logprobs of symbols, + given all symbols *earlier than* s in the sequence. The reason + we still need this for position S is that we may still be emitting + the termination/next-frame symbol at this point. + am: + Acoustic-model part of un-normalized logprobs of symbols, to be added + to language-model part before normalizing. Of shape: + [B][T][C] + where B is the batch size, T is the maximum sequence length of + the acoustic sequences (in frames); and C is size of the symbol + vocabulary, including the termination/next-frame symbol. It reflects + the "acoustic" part of the probability of any given symbol appearing + next on this frame. + symbols: + A LongTensor of shape [B][S], containing the symbols at each position + of the sequence, possibly including EOS + termination_symbol: + The identity of the termination symbol, must be in {0..C-1} + lm_only_scale: + the scale on the "LM-only" part of the loss. + am_only_scale: + the scale on the "AM-only" part of the loss, for which we use + an "averaged" LM (averaged over all histories, so effectively unigram). + boundary: + a LongTensor of shape [B, 4] with elements interpreted as + [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as + [0, 0, S, T] + if boundary is not supplied. + Most likely you will want begin_symbol and begin_frame to be zero. + Returns: + (px, py) (the names are quite arbitrary). + px: logprobs, of shape [B][S][T+1] + py: logprobs, of shape [B][S+1][T] + in the recursion: + p[b,0,0] = 0.0 + p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], + p[b,s,t-1] + py[b,s,t-1]) + .. where p[b][s][t] is the "joint score" of the pair of subsequences + of length s and t respectively. px[b][s][t] represents the + probability of extending the subsequences of length (s,t) by one in + the s direction, given the particular symbol, and py[b][s][t] + represents the probability of extending the subsequences of length + (s,t) by one in the t direction, + i.e. of emitting the termination/next-frame symbol. + + px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame + we cannot emit any symbols. 
This is simply a way of incorporating + the probability of the termination symbol on the last frame. + """ + assert lm.ndim == 3 + assert am.ndim == 3 + assert lm.shape[0] == am.shape[0] + assert lm.shape[2] == am.shape[2] + (B, T, C) = am.shape + S = lm.shape[1] - 1 + assert symbols.shape == (B, S) + + # Caution: some parts of this code are a little less clear than they could + # be due to optimizations. In particular it may not be totally obvious that + # all of the logprobs here are properly normalized. We test that + # this code is invariant to adding constants in the appropriate ways. + + # subtracting am_max and lm_max is to ensure the probs are in a good range + # to do exp() without causing underflow or overflow. + am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1] + lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1] + am_probs = (am - am_max).exp() # [B][T][C] + lm_probs = (lm - lm_max).exp() # [B][S+1][C] + # normalizers: [B][S+1][T] + normalizers = ( + torch.matmul(lm_probs, am_probs.transpose(1, 2)) + + torch.finfo(lm_probs.dtype).tiny + ).log() + + # normalizer per frame, if we take only the LM probs by themselves + lmonly_normalizers = lm_probs.sum( + dim=2, keepdim=True + ) # lmonly_normalizers: [B][S+1][1] + unigram_lm = ( + torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1), keepdim=True) + + torch.finfo(lm_probs.dtype).tiny + ) # [1][1][C] + amonly_normalizers = ( + torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)) + .reshape(B, T, 1) + .log() + + am_max + ) # [B][T][1] + amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T] + unigram_lm = unigram_lm.log() + lmonly_normalizers = ( + lmonly_normalizers.log() + lm_max + ) # [B][S+1][1], log-normalizer, used for LM-only part of prob. + + # add lm_max and am_max to normalizers, to make it as if we had not + # subtracted am_max and lm_max above. + normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T] + + # px is the probs of the actual symbols (not yet normalized).. + px_am = torch.gather( + am.unsqueeze(1).expand(B, S, T, C), + dim=3, + index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1), + ).squeeze( + -1 + ) # [B][S][T] + px_am = torch.cat( + ( + px_am, + torch.full( + (B, S, 1), float("-inf"), device=px_am.device, dtype=px_am.dtype + ), + ), + dim=2, + ) # now: [B][S][T+1], index [:,:,T] has -inf.. 
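+
+    # (As in get_rnnt_logprobs(), the mask below forbids emitting any symbol
+    # on frames at or beyond each sequence's end_frame.)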
+
+    if boundary is not None:
+        assert boundary.shape == (B, 4)
+        mask = (
+            torch.arange(0, T + 1, device=px_am.device)
+            .reshape(1, T + 1)
+            .expand(B, T + 1)
+        )
+        mask = mask < boundary[:, 3].reshape(B, 1)
+        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
+        px_am = torch.where(
+            mask,
+            px_am,
+            torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device),
+        )
+
+    px_lm = torch.gather(
+        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
+    )  # [B][S][1]
+    px_lm_unigram = torch.gather(
+        unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
+    )  # [B][S][1]
+
+    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
+    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1]
+
+    px_amonly = px_am + px_lm_unigram  # [B][S][T+1]
+    px_amonly[:, :, :T] -= amonly_normalizers
+    px_lmonly = px_lm - lmonly_normalizers[:, :S, :]
+
+    # py is the probs of termination symbols, of shape [B][S+1][T]
+    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
+    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
+    py = py_am + py_lm - normalizers
+
+    py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
+    py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
+    py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]
+
+    combined_scale = 1.0 - lm_only_scale - am_only_scale
+
+    # We need to avoid exact zeros in the scales because otherwise multiplying
+    # -inf by zero generates nan.
+    if lm_only_scale == 0.0:
+        lm_only_scale = 1.0e-20
+    if am_only_scale == 0.0:
+        am_only_scale = 1.0e-20
+
+    px_interp = (
+        px * combined_scale
+        + px_lmonly * lm_only_scale
+        + px_amonly * am_only_scale
+    )
+    py_interp = (
+        py * combined_scale
+        + py_lmonly * lm_only_scale
+        + py_amonly * am_only_scale
+    )
+
+    return (px_interp, py_interp)
+
+
+def rnnt_loss_smoothed(
+    lm: Tensor,
+    am: Tensor,
+    symbols: Tensor,
+    termination_symbol: int,
+    lm_only_scale: float = 0.1,
+    am_only_scale: float = 0.1,
+    boundary: Optional[Tensor] = None,
+    return_grad: bool = False,
+) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
+    """A simple case of the RNN-T loss, where the 'joiner' network is just
+    addition.  Returns the negated total loss value.
+
+    Args:
+      lm:
+        language-model part of unnormalized log-probs of symbols, with shape
+        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
+        These are assumed to be well-normalized, in the sense that we could
+        use them as probabilities separately from the am scores
+      am:
+        acoustic-model part of unnormalized log-probs of symbols, with shape
+        (B, T, C), i.e. batch, frame, num_classes
+      symbols:
+        the symbol sequences, a LongTensor of shape [B][S], and elements in
+        {0..C-1}.
+      termination_symbol:
+        the termination symbol, with 0 <= termination_symbol < C
+      lm_only_scale:
+        the scale on the "LM-only" part of the loss.
+      am_only_scale:
+        the scale on the "AM-only" part of the loss, for which we use
+        an "averaged" LM (averaged over all histories, so effectively unigram).
+      boundary:
+        a LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+      return_grad:
+        Whether to return grads of px and py; these grads, which can be
+        interpreted as occupation probabilities, are the output of the
+        backward pass with a `fake gradient` input (all ones).  This is
+        useful to implement the pruned version of the rnnt loss.
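+
+    Note: with `lm_only_scale` and `am_only_scale` both zero this reduces to
+    `rnnt_loss_simple()`; a quick self-check (a sketch):
+
+        a = rnnt_loss_smoothed(lm, am, symbols, termination_symbol,
+                               lm_only_scale=0.0, am_only_scale=0.0)
+        b = rnnt_loss_simple(lm, am, symbols, termination_symbol)
+        # a and b should agree up to roundoff.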
+ + Returns: + If return_grad is False, returns a Tensor of shape (B,), containing the + NEGATED total RNN-T loss values for each element of the batch + (like log-probs of sequences). + If return_grad is True, the grads of px and py, which is the output of + backward with a `fake gradient` input, will be returned too. And the + returned value will be a tuple like (loss, (px_grad, py_grad)). + """ + px, py = get_rnnt_logprobs_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale, + am_only_scale, + boundary, + ) + return mutual_information_recursion(px, py, boundary, return_grad) diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 56fb59114..0cb2c45fe 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -54,6 +54,7 @@ set(py_test_files linear_fsa_test.py linear_fst_test.py multi_gpu_test.py + mutual_information_test.py nbest_test.py numerical_gradient_check_test.py ragged_ops_test.py @@ -63,6 +64,7 @@ set(py_test_files random_paths_test.py remove_epsilon_self_loops_test.py remove_epsilon_test.py + rnnt_loss_test.py shortest_path_test.py sparse_abs_test.py symbol_table_test.py diff --git a/k2/python/tests/mutual_information_test.py b/k2/python/tests/mutual_information_test.py new file mode 100644 index 000000000..fc48cd1da --- /dev/null +++ b/k2/python/tests/mutual_information_test.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey, +# Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R mutual_information_test_py + +import random +import unittest + +import k2 +import torch + + +# Caution: this will fail occasionally due to cutoffs not being quite large +# enough. As long as it passes most of the time, it's OK. 
+class TestMutualInformation(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + cls.dtypes = [torch.float32, torch.float64] + + def test_mutual_information_basic(self): + for _iter in range(100): + (B, S, T) = (random.randint(1, 10), random.randint(1, 16), + random.randint(1, 500)) + random_px = (random.random() < 0.2) + random_py = (random.random() < 0.2) + random_boundary = (random.random() < 0.7) + big_px = (random.random() < 0.2) + big_py = (random.random() < 0.2) + + for dtype in self.dtypes: + for device in self.devices: + if random_boundary: + + def get_boundary_row(): + s_begin = random.randint(0, S - 1) + t_begin = random.randint(0, T - 1) + # allow empty sequence + s_end = random.randint(s_begin, S) + # allow empty sequence + t_end = random.randint(t_begin, T) + return [s_begin, t_begin, s_end, t_end] + + if device == torch.device('cpu'): + boundary = torch.tensor( + [get_boundary_row() for _ in range(B)], + dtype=torch.int64, + device=device) + else: + boundary = boundary.to(device) + else: + # Use default boundary, but either specified directly + # or not. + if random.random() < 0.5: + boundary = torch.tensor( + [0, 0, S, T], + dtype=torch.int64).unsqueeze(0).expand( + B, 4).to(device) + else: + boundary = None + + if device == torch.device('cpu'): + if random_px: + # log of an odds ratio + px = torch.randn(B, S, T + 1, + dtype=dtype).to(device) + if S > 1 and not random_boundary: + px[:, :, -1:] = float('-inf') + else: + # log of an odds ratio + px = torch.zeros(B, S, T + 1, + dtype=dtype).to(device) + # px and py get exponentiated, and then multiplied + # together up to 32 times (BLOCK_SIZE in the CUDA code), + # so 15 is actually a big number that could lead to + # overflow. 
+ if big_px: + px += 15.0 + if random_py: + # log of an odds ratio + py = torch.randn(B, S + 1, T, + dtype=dtype).to(device) + else: + # log of an odds ratio + py = torch.zeros(B, S + 1, T, + dtype=dtype).to(device) + if big_py: + py += 15.0 + + else: + px = px.to(device).detach() + py = py.to(device).detach() + px.requires_grad = True + py.requires_grad = True + + m = k2.mutual_information_recursion(px, py, boundary) + + m2 = k2.joint_mutual_information_recursion((px,), (py,), + boundary) + + m3 = k2.joint_mutual_information_recursion( + (px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary) + + # it is supposed to be identical only after + # summing over dim 0, corresponding to the + # sequence dim + m3 = m3.sum(dim=0) + + assert torch.allclose(m, m2) + assert torch.allclose(m, m3) + + # the loop this is in checks that the CPU and CUDA versions + # give the same derivative; + # by randomizing which of m, m2 or m3 we backprop, we also + # ensure that the joint version of the code gives the same + # derivative as the regular version + scale = 3 + if random.random() < 0.5: + (m.sum() * scale).backward() + elif random.random() < 0.5: + (m2.sum() * scale).backward() + else: + (m3.sum() * scale).backward() + + if device == torch.device("cpu"): + expected_px_grad = px.grad + expected_py_grad = py.grad + expected_m = m + assert torch.allclose(px.grad, + expected_px_grad.to(device), + atol=1.0e-02, + rtol=1.0e-02) + assert torch.allclose(py.grad, + expected_py_grad.to(device), + atol=1.0e-02, + rtol=1.0e-02) + assert torch.allclose(m, + expected_m.to(device), + atol=1.0e-02, + rtol=1.0e-02) + + def test_mutual_information_deriv(self): + for _iter in range(100): + (B, S, T) = (random.randint(1, 10), random.randint(1, 200), + random.randint(1, 200)) + random_px = (random.random() < 0.2) + random_py = (random.random() < 0.2) + random_boundary = (random.random() < 0.7) + big_px = (random.random() < 0.2) + big_py = (random.random() < 0.2) + + for dtype in self.dtypes: + for device in self.devices: + + if random_boundary: + + def get_boundary_row(): + s_begin = random.randint(0, S - 1) + t_begin = random.randint(0, T - 1) + s_end = random.randint(s_begin + 1, S) + t_end = random.randint(t_begin + 1, T) + return [s_begin, t_begin, s_end, t_end] + + if device == torch.device('cpu'): + boundary = torch.tensor( + [get_boundary_row() for _ in range(B)], + dtype=torch.int64, + device=device) + else: + boundary = boundary.to(device) + else: + # Use default boundary, but either specified directly + # or not. + if random.random() < 0.5: + boundary = torch.tensor( + [0, 0, S, T], + dtype=torch.int64).unsqueeze(0).expand( + B, 4).to(device) + else: + boundary = None + + if device == torch.device('cpu'): + if random_px: + # log of an odds ratio + px = torch.randn(B, S, T + 1, + dtype=dtype).to(device) + else: + # log of an odds ratio + px = torch.zeros(B, S, T + 1, + dtype=dtype).to(device) + # px and py get exponentiated, and then multiplied + # together up to 32 times (BLOCK_SIZE in the CUDA code), + # so 15 is actually a big number that could lead to + # overflow. 
+ if big_px: + px += 15.0 + if random_py: + # log of an odds ratio + py = torch.randn(B, S + 1, T, + dtype=dtype).to(device) + else: + # log of an odds ratio + py = torch.zeros(B, S + 1, T, + dtype=dtype).to(device) + if big_py: + py += 15.0 + else: + px = px.to(device).detach() + py = py.to(device).detach() + px.requires_grad = True + py.requires_grad = True + + m = k2.mutual_information_recursion(px, py, boundary) + + m_grad = torch.randn(B, dtype=dtype, device=device) + m.backward(gradient=m_grad) + delta = 1.0e-04 + delta_px = delta * torch.randn_like(px) + m2 = k2.mutual_information_recursion( + px + delta_px, py, boundary) + delta_m = m2 - m + observed_delta = (delta_m * m_grad).sum().to('cpu') + predicted_delta = (delta_px * px.grad).sum().to('cpu') + + atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 + rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 + + assert torch.allclose(observed_delta, + predicted_delta, + atol=atol, + rtol=rtol) + + delta_py = delta * torch.randn_like(py) + m2 = k2.mutual_information_recursion( + px, py + delta_py, boundary) + delta_m = m2 - m + observed_delta = (delta_m * m_grad).sum().to('cpu') + predicted_delta = (delta_py * py.grad).sum().to('cpu') + + assert torch.allclose(observed_delta, + predicted_delta, + atol=atol, + rtol=rtol) + + +if __name__ == "__main__": + unittest.main() diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py new file mode 100644 index 000000000..16cc7875f --- /dev/null +++ b/k2/python/tests/rnnt_loss_test.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey, +# Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# To run this single test, use +# +# ctest --verbose -R rnnt_loss_test_py + +import unittest + +import k2 +import random +import torch + + +class TestRnntLoss(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device("cpu")] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device("cuda", 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device("cuda", 1)) + try: + import torchaudio + import torchaudio.functional + + if hasattr(torchaudio.functional, "rnnt_loss"): + cls.has_torch_rnnt_loss = True + else: + cls.has_torch_rnnt_loss = False + print( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Skipping the tests of comparing rnnt loss with torch " + "one, to enable these tests please install a " + "version >= 0.10.0" + ) + except ImportError as e: + cls.has_torch_rnnt_loss = False + print( + f"Import torchaudio error, error message: {e}\n" + "Skipping the tests of comparing rnnt loss with torch " + "one, to enable these tests, please install torchaudio " + "with version >= 0.10.0" + ) + + def test_rnnt_loss_basic(self): + B = 1 + S = 3 + T = 4 + # C = 3 + for device in self.devices: + # lm: [B][S+1][C] + lm = torch.tensor( + [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], + dtype=torch.float, + device=device, + ) + # am: [B][T][C] + am = torch.tensor( + [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], + dtype=torch.float, + device=device, + ) + termination_symbol = 2 + symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) + + px, py = k2.get_rnnt_logprobs(lm, am, symbols, termination_symbol) + assert px.shape == (B, S, T + 1) + assert py.shape == (B, S + 1, T) + assert symbols.shape == (B, S) + m = k2.mutual_information_recursion(px, py) + + if device == torch.device("cpu"): + expected = m + assert torch.allclose(m, expected.to(device)) + + # test rnnt_loss_simple + m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None) + assert torch.allclose(m, expected.to(device)) + + # test rnnt_loss_smoothed + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.0, + boundary=None, + ) + assert torch.allclose(m, expected.to(device)) + + probs = am.unsqueeze(2) + lm.unsqueeze(1) + + # test rnnt_loss + m = k2.rnnt_loss(probs, symbols, termination_symbol, None) + assert torch.allclose(m, expected.to(device)) + + # compare with torchaudio rnnt_loss + if self.has_torch_rnnt_loss: + import torchaudio.functional + + m = torchaudio.functional.rnnt_loss( + logits=probs, + targets=symbols.int(), + logit_lengths=torch.tensor( + [T] * B, dtype=torch.int32, device=device + ), + target_lengths=torch.tensor( + [S] * B, dtype=torch.int32, device=device + ), + blank=termination_symbol, + reduction="none", + ) + assert torch.allclose(-m, expected.to(device)) + + # should be invariant to adding a constant for any frame. 
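
A sketch of why the invariance checked below holds, assuming only PyTorch: the losses normalize the logits with logsumexp over the class dimension, and logsumexp(x + c) == logsumexp(x) + c for any per-frame constant c, so the normalized log-probabilities are unchanged by the shift:

    import torch

    x = torch.randn(4, 10)   # (frames, classes)
    c = torch.randn(4, 1)    # one constant per frame
    lhs = torch.logsumexp(x + c, dim=1)
    rhs = torch.logsumexp(x, dim=1) + c.squeeze(1)
    assert torch.allclose(lhs, rhs)
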
+ lm += torch.randn(B, S + 1, 1, device=device) + am += torch.randn(B, T, 1, device=device) + + m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None) + assert torch.allclose(m, expected.to(device)) + + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.0, + boundary=None, + ) + assert torch.allclose(m, expected.to(device)) + + probs = am.unsqueeze(2) + lm.unsqueeze(1) + m = k2.rnnt_loss(probs, symbols, termination_symbol, None) + assert torch.allclose(m, expected.to(device)) + + def test_rnnt_loss_random(self): + B = 5 + S = 20 + T = 300 + C = 100 + frames = torch.randint(S, T, (B,)) + seq_length = torch.randint(3, S - 1, (B,)) + T = torch.max(frames) + S = torch.max(seq_length) + + am_ = torch.randn((B, T, C), dtype=torch.float32) + lm_ = torch.randn((B, S + 1, C), dtype=torch.float32) + symbols_ = torch.randint(0, C - 1, (B, S)) + termination_symbol = C - 1 + + boundary_ = torch.zeros((B, 4), dtype=torch.int64) + boundary_[:, 2] = seq_length + boundary_[:, 3] = frames + + for device in self.devices: + + # lm: [B][S+1][C] + lm = lm_.to(device) + # am: [B][T][C] + am = am_.to(device) + symbols = symbols_.to(device) + boundary = boundary_.to(device) + + px, py = k2.get_rnnt_logprobs( + lm, am, symbols, termination_symbol, boundary + ) + assert px.shape == (B, S, T + 1) + assert py.shape == (B, S + 1, T) + assert symbols.shape == (B, S) + m = k2.mutual_information_recursion(px, py, boundary) + + if device == torch.device("cpu"): + expected = m + assert torch.allclose(m, expected.to(device)) + + m = k2.rnnt_loss_simple( + lm, am, symbols, termination_symbol, boundary + ) + assert torch.allclose(m, expected.to(device)) + + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.0, + boundary=boundary, + ) + assert torch.allclose(m, expected.to(device)) + + probs = am.unsqueeze(2) + lm.unsqueeze(1) + m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary) + assert torch.allclose(m, expected.to(device)) + + # compare with torchaudio rnnt_loss + if self.has_torch_rnnt_loss: + import torchaudio.functional + + m = torchaudio.functional.rnnt_loss( + logits=probs, + targets=symbols.int(), + logit_lengths=boundary[:, 3].int(), + target_lengths=boundary[:, 2].int(), + blank=termination_symbol, + reduction="none", + ) + assert torch.allclose(-m, expected.to(device)) + + # should be invariant to adding a constant for any frame. 
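
For reference, a toy sketch of the boundary tensor format used throughout these tests: each row is [s_begin, t_begin, s_end, t_end], and the tests fill only the end columns, leaving the begin columns at zero (the lengths below are illustrative values):

    import torch

    B = 3
    seq_length = torch.tensor([5, 7, 4])   # symbols per sequence
    frames = torch.tensor([40, 35, 50])    # frames per sequence

    boundary = torch.zeros((B, 4), dtype=torch.int64)
    boundary[:, 2] = seq_length            # s_end
    boundary[:, 3] = frames                # t_end
    # boundary[:, 0] and boundary[:, 1] stay 0: every sequence starts at (0, 0)
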
+ lm += torch.randn(B, S + 1, 1, device=device) + am += torch.randn(B, T, 1, device=device) + + m = k2.rnnt_loss_simple( + lm, am, symbols, termination_symbol, boundary + ) + assert torch.allclose(m, expected.to(device)) + + probs = am.unsqueeze(2) + lm.unsqueeze(1) + m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary) + assert torch.allclose(m, expected.to(device)) + + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.0, + boundary=boundary, + ) + assert torch.allclose(m, expected.to(device)) + + def test_rnnt_loss_gradient(self): + if self.has_torch_rnnt_loss: + import torchaudio.functional + + B = 5 + S = 20 + T = 300 + C = 100 + frames = torch.randint(S, T, (B,)) + seq_length = torch.randint(3, S - 1, (B,)) + T = torch.max(frames) + S = torch.max(seq_length) + + am_ = torch.randn((B, T, C), dtype=torch.float32) + lm_ = torch.randn((B, S + 1, C), dtype=torch.float32) + symbols_ = torch.randint(0, C - 1, (B, S)) + termination_symbol = C - 1 + + boundary_ = torch.zeros((B, 4), dtype=torch.int64) + boundary_[:, 2] = seq_length + boundary_[:, 3] = frames + + for device in self.devices: + + # lm: [B][S+1][C] + lm = lm_.to(device) + # am: [B][T][C] + am = am_.to(device) + symbols = symbols_.to(device) + boundary = boundary_.to(device) + + logprobs = am.unsqueeze(2) + lm.unsqueeze(1) + logprobs.requires_grad_() + k2_loss = k2.rnnt_loss( + logprobs, symbols, termination_symbol, boundary + ) + k2_grad = torch.autograd.grad( + k2_loss, logprobs, -torch.ones_like(k2_loss) + ) + k2_grad = k2_grad[0] + + logprobs2 = logprobs.detach().clone().float() + logprobs2.requires_grad_() + torch_loss = torchaudio.functional.rnnt_loss( + logprobs2, + symbols.int(), + boundary[:, 3].int(), + boundary[:, 2].int(), + blank=termination_symbol, + reduction="none", + ) + torch_grad = torch.autograd.grad( + torch_loss, logprobs2, torch.ones_like(torch_loss) + ) + torch_grad = torch_grad[0] + + assert torch.allclose( + -k2_loss, torch_loss, atol=1e-2, rtol=1e-2 + ) + + assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2) + + def test_rnnt_loss_smoothed(self): + B = 1 + S = 3 + T = 4 + # C = 3 + for device in self.devices: + # lm: [B][S+1][C] + lm = torch.tensor( + [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], + dtype=torch.float, + device=device, + ) + # am: [B][T][C] + am = torch.tensor( + [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], + dtype=torch.float, + device=device, + ) + + termination_symbol = 2 + symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) + + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.333, + boundary=None, + ) + + if device == torch.device("cpu"): + expected = m + assert torch.allclose(m, expected.to(device)) + + # should be invariant to adding a constant for any frame. 
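
A condensed sketch of the pruning pipeline exercised in test_rnnt_loss_pruned below, using the k2 functions that appear in this patch series; the shapes and the s_range value are toy choices:

    import k2
    import torch

    B, T, S, C, s_range = 4, 50, 10, 20, 5
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(0, C - 1, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # 1. the simple loss can also return occupation-style grads of px/py
    loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
        lm, am, symbols, C - 1, boundary, return_grad=True
    )
    # 2. turn the grads into per-frame symbol ranges of width s_range
    ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range)
    # 3. gather the pruned (B, T, s_range, C) views of am and lm
    am_p, lm_p = k2.do_rnnt_pruning(am, lm, ranges)
    # 4. compute the loss on the much smaller pruned joint logits
    pruned = k2.rnnt_loss_pruned(am_p + lm_p, symbols, ranges, C - 1, boundary)
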
+ lm += torch.randn(B, S + 1, 1, device=device) + am += torch.randn(B, T, 1, device=device) + + m = k2.rnnt_loss_smoothed( + lm, + am, + symbols, + termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.333, + boundary=None, + ) + assert torch.allclose(m, expected.to(device)) + + def test_rnnt_loss_pruned(self): + B = 4 + T = 300 + S = 50 + C = 10 + + frames = torch.randint(S, T, (B,)) + seq_length = torch.randint(3, S - 1, (B,)) + T = torch.max(frames) + S = torch.max(seq_length) + + am_ = torch.randn((B, T, C), dtype=torch.float64) + lm_ = torch.randn((B, S + 1, C), dtype=torch.float64) + symbols_ = torch.randint(0, C - 1, (B, S)) + terminal_symbol = C - 1 + + boundary_ = torch.zeros((B, 4), dtype=torch.int64) + boundary_[:, 2] = seq_length + boundary_[:, 3] = frames + + for device in self.devices: + # normal rnnt + am = am_.to(device) + lm = lm_.to(device) + symbols = symbols_.to(device) + boundary = boundary_.to(device) + t_am = am.unsqueeze(2).float() + t_lm = lm.unsqueeze(1).float() + t_prob = t_am + t_lm + # nonlinear transform + t_prob = torch.sigmoid(t_prob) + k2_loss = k2.rnnt_loss(t_prob, symbols, terminal_symbol, boundary) + + print("unpruned rnnt loss: ", k2_loss) + + # pruning + k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( + lm, am, symbols, terminal_symbol, boundary, True + ) + + for r in range(2, 50, 5): + ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, r) + # (B, T, r, C) + am_p, lm_p = k2.do_rnnt_pruning(am, lm, ranges) + + t_prob_p = am_p + lm_p + + # nonlinear transform + t_prob_p = torch.sigmoid(t_prob_p) + + pruning_loss = k2.rnnt_loss_pruned( + t_prob_p, symbols, ranges, terminal_symbol, boundary + ) + print(f"pruning loss with range {r} : ", pruning_loss) + + +if __name__ == "__main__": + unittest.main() From d3fbb1b79ead4758fb4dc674ea5cd89f90f67411 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 25 Jan 2022 15:17:53 +0800 Subject: [PATCH 33/64] Use more efficient way to fix boundaries (#906) --- k2/python/k2/rnnt_loss.py | 85 ++++++++++++--------------------------- 1 file changed, 25 insertions(+), 60 deletions(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 9823fac87..ca4ffcf8f 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -23,6 +23,26 @@ from .mutual_information import mutual_information_recursion +def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor: + """ + Insert -inf's into `px` in appropriate places if `boundary` is not + None. If boundary == None and modified == False, px[:,:,-1] will + be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]] + to be -infinity. + Args: + px: a Tensor of of shape [B][S][T+1] (this function is only + called if modified == False, see other docs for `modified`) + px is modified in-place and returned. + boundary: None, or a Tensor of shape [B][3] containing + [s_begin, t_begin, s_end, t_end]; we need only t_end. + """ + if boundary is None: + return px + B, S, T1 = px.shape + boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1) + return px.scatter_(dim=2, index=boundary, value=float("-inf")) + + def get_rnnt_logprobs( lm: Tensor, am: Tensor, @@ -135,21 +155,6 @@ def get_rnnt_logprobs( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. 
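
A toy illustration, assuming only PyTorch, of what the scatter_ call in fix_for_boundary above does: for each batch element b it writes -inf in place at px[b, :, boundary[b, 3]], i.e. the one-past-the-last-frame position of each sequence, so no further symbol can be emitted there:

    import torch

    B, S, T1 = 2, 3, 6
    px = torch.zeros(B, S, T1)
    t_end = torch.tensor([5, 3])                     # boundary[:, 3]
    index = t_end.reshape(B, 1, 1).expand(B, S, 1)   # one write per (b, s)
    px.scatter_(dim=2, index=index, value=float("-inf"))
    assert (px[0, :, 5] == float("-inf")).all()
    assert (px[1, :, 3] == float("-inf")).all()
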
- if boundary is not None: - assert boundary.shape == (B, 4) - mask = ( - torch.arange(0, T + 1, device=px_am.device) - .reshape(1, T + 1) - .expand(B, T + 1) - ) - mask = mask < boundary[:, 3].reshape(B, 1) - mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1) - px_am = torch.where( - mask, - px_am, - torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device), - ) - px_lm = torch.gather( lm[:, :S], dim=2, index=symbols.unsqueeze(-1) ) # [B][S][1] @@ -163,6 +168,7 @@ def get_rnnt_logprobs( py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1] py = py_am + py_lm - normalizers + px = fix_for_boundary(px, boundary) return (px, py) @@ -278,21 +284,6 @@ def get_rnnt_logprobs_joint( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - if boundary is not None: - assert boundary.shape == (B, 4) - mask = ( - torch.arange(0, T + 1, device=px.device) - .reshape(1, T + 1) - .expand(B, T + 1) - ) - mask = mask < boundary[:, 3].reshape(B, 1) - mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1) - px = torch.where( - mask, - px, - torch.tensor(float("-inf"), dtype=px.dtype, device=px.device), - ) - px[:, :, :T] -= normalizers[:, :S, :] py = ( @@ -302,6 +293,7 @@ def get_rnnt_logprobs_joint( px = px.contiguous() py = py.contiguous() + px = fix_for_boundary(px, boundary) return (px, py) @@ -660,21 +652,6 @@ def get_rnnt_logprobs_pruned( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - if boundary is not None: - assert boundary.shape == (B, 4) - mask = ( - torch.arange(0, T + 1, device=px.device) - .reshape(1, T + 1) - .expand(B, T + 1) - ) - mask = mask < boundary[:, 3].reshape(B, 1) - mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1) - px = torch.where( - mask, - px, - torch.tensor(float("-inf"), dtype=px.dtype, device=px.device), - ) - py = joint[:, :, :, termination_symbol] # (B, T, s_range) py = py - normalizers @@ -699,6 +676,8 @@ def get_rnnt_logprobs_pruned( px = px.contiguous() py = py.contiguous() + + px = fix_for_boundary(px, boundary) return (px, py) @@ -887,21 +866,6 @@ def get_rnnt_logprobs_smoothed( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. 
- if boundary is not None: - assert boundary.shape == (B, 4) - mask = ( - torch.arange(0, T + 1, device=px_am.device) - .reshape(1, T + 1) - .expand(B, T + 1) - ) - mask = mask < boundary[:, 3].reshape(B, 1) - mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1) - px_am = torch.where( - mask, - px_am, - torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device), - ) - px_lm = torch.gather( lm[:, :S], dim=2, index=symbols.unsqueeze(-1) ) # [B][S][1] @@ -945,6 +909,7 @@ def get_rnnt_logprobs_smoothed( + py_amonly * am_only_scale ) + px_interp = fix_for_boundary(px_interp, boundary) return (px_interp, py_interp) From 9a91ec66eb542de68ff873f1f0e5e58b6a247ef1 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 25 Jan 2022 15:26:27 +0800 Subject: [PATCH 34/64] Release v1.12 (#907) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f90d1d10..aa58d1b37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.11") +set(K2_VERSION "1.12") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) From 3367c7f1a1dbf69a6d493221c3ca91d211018f9e Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 29 Jan 2022 15:35:49 +0800 Subject: [PATCH 35/64] Change the sign of the rnnt_loss and add reduction argument (#911) * Add right boundary constrains for s_begin * Minor fixes to the interface of rnnt_loss to make it return positive value * Fix comments * Release a new version * Minor fixes * Minor fixes to the docs --- CMakeLists.txt | 2 +- k2/python/k2/rnnt_loss.py | 227 +++++++++++++++++++++--------- k2/python/tests/rnnt_loss_test.py | 200 +++++++++++++++++--------- 3 files changed, 299 insertions(+), 130 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index aa58d1b37..386e412d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.12") +set(K2_VERSION "1.13") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index ca4ffcf8f..2150d4ed2 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -26,15 +26,14 @@ def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor: """ Insert -inf's into `px` in appropriate places if `boundary` is not - None. If boundary == None and modified == False, px[:,:,-1] will - be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]] + None. If boundary == None, px[:,:,-1] will be -infinity, + but if boundary is specified, we need px[b,:,boundary[b,3]] to be -infinity. Args: - px: a Tensor of of shape [B][S][T+1] (this function is only - called if modified == False, see other docs for `modified`) - px is modified in-place and returned. - boundary: None, or a Tensor of shape [B][3] containing - [s_begin, t_begin, s_end, t_end]; we need only t_end. + px: a Tensor of of shape [B][S][T+1], px is modified in-place + and returned. + boundary: None, or a Tensor of shape [B][3] containing + [s_begin, t_begin, s_end, t_end]; we need only t_end. """ if boundary is None: return px @@ -82,7 +81,7 @@ def get_rnnt_logprobs( next on this frame. 
symbols: A LongTensor of shape [B][S], containing the symbols at each position - of the sequence, possibly including EOS + of the sequence. termination_symbol: The identity of the termination symbol, must be in {0..C-1} boundary: @@ -178,10 +177,11 @@ def rnnt_loss_simple( symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, + reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: """A simple case of the RNN-T loss, where the 'joiner' network is just - addition. Returns negated total loss value. + addition. Args: lm: @@ -201,25 +201,53 @@ def rnnt_loss_simple( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + reduction: + Specifies the reduction to apply to the output: `none`, `mean` or `sum`. + `none`: no reduction will be applied. + `mean`: apply `torch.mean` over the batches. + `sum`: the output will be summed. + Default: `mean` return_grad: Whether to return grads of px and py, this grad standing for the occupation probability is the output of the backward with a - `fake gradient` input (all ones) This is useful to implement the - pruned version of rnnt loss. + `fake gradient`, the `fake gradient` is the same as the gradient you'd + get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the + loss here is the loss with reduction "none". + This is useful to implement the pruned version of rnnt loss. Returns: - If return_grad is False, returns a Tensor of shape (B,), containing the - NEGATED total RNN-T loss values for each element of the batch - (like log-probs of sequences). + If return_grad is False, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch if reduction equals + to "none", otherwise a scalar with the reduction applied. If return_grad is True, the grads of px and py, which is the output of - backward with a `fake gradient` input, will be returned too. And the + backward with a `fake gradient`(see above), will be returned too. And the returned value will be a tuple like (loss, (px_grad, py_grad)). """ - px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol, boundary) - return mutual_information_recursion(px, py, boundary, return_grad) + px, py = get_rnnt_logprobs( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + ) + scores_and_grads = mutual_information_recursion( + px=px, py=py, boundary=boundary, return_grad=return_grad + ) + negated_loss = scores_and_grads[0] if return_grad else scores_and_grads + if reduction == "none": + loss = -negated_loss + elif reduction == "mean": + loss = -torch.mean(negated_loss) + elif reduction == "sum": + loss = -torch.sum(negated_loss) + else: + assert ( + False + ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + return (loss, scores_and_grads[1]) if return_grad else loss def get_rnnt_logprobs_joint( - joint: Tensor, + logits: Tensor, symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, @@ -229,12 +257,12 @@ def get_rnnt_logprobs_joint( This function is called from rnnt_loss(). Args: - joint: + logits: The output of joiner network, with shape (B, T, S + 1, C), i.e. batch, time_seq_len, symbol_seq_len+1, num_classes symbols: A LongTensor of shape [B][S], containing the symbols at each position - of the sequence, possibly including EOS + of the sequence. 
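
A quick sketch of how the new reduction argument of rnnt_loss_simple above behaves, with toy shapes: "none" gives the per-sequence losses, from which the "mean" and "sum" modes follow:

    import k2
    import torch

    B, T, S, C = 2, 5, 3, 4
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(0, C - 1, (B, S))

    per_seq = k2.rnnt_loss_simple(lm, am, symbols, C - 1, reduction="none")
    assert per_seq.shape == (B,)
    mean_loss = k2.rnnt_loss_simple(lm, am, symbols, C - 1)  # default: "mean"
    assert torch.allclose(mean_loss, per_seq.mean())
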
termination_symbol: The identity of the termination symbol, must be in {0..C-1} boundary: @@ -262,16 +290,16 @@ def get_rnnt_logprobs_joint( we cannot emit any symbols. This is simply a way of incorporating the probability of the termination symbol on the last frame. """ - assert joint.ndim == 4 - (B, T, S1, C) = joint.shape + assert logits.ndim == 4 + (B, T, S1, C) = logits.shape S = S1 - 1 assert symbols.shape == (B, S) - normalizers = torch.logsumexp(joint, dim=3) + normalizers = torch.logsumexp(logits, dim=3) normalizers = normalizers.permute((0, 2, 1)) px = torch.gather( - joint, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) + logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) ).squeeze(-1) px = px.permute((0, 2, 1)) px = torch.cat( @@ -287,7 +315,7 @@ def get_rnnt_logprobs_joint( px[:, :, :T] -= normalizers[:, :S, :] py = ( - joint[:, :, :, termination_symbol].permute((0, 2, 1)).clone() + logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() ) # [B][S+1][T] py -= normalizers px = px.contiguous() @@ -298,16 +326,17 @@ def get_rnnt_logprobs_joint( def rnnt_loss( - joint: Tensor, + logits: Tensor, symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, + reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input, i.e. a 4 dimensions tensor. Args: - joint: + logits: The output of joiner network, with shape (B, T, S + 1, C), i.e. batch, time_seq_len, symbol_seq_len+1, num_classes symbols: @@ -320,15 +349,35 @@ def rnnt_loss( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + reduction: + Specifies the reduction to apply to the output: `none`, `mean` or `sum`. + `none`: no reduction will be applied. + `mean`: apply `torch.mean` over the batches. + `sum`: the output will be summed. + Default: `mean` Returns: - A Tensor of shape (B,), containing the total RNN-T loss values for each - element of the batch (like log-probs of sequences). + If recursion is `none`, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch, otherwise a scalar + with the reduction applied. """ px, py = get_rnnt_logprobs_joint( - joint, symbols, termination_symbol, boundary + logits=logits, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, ) - return mutual_information_recursion(px, py, boundary) + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + if reduction == "none": + return -negated_loss + elif reduction == "mean": + return -torch.mean(negated_loss) + elif reduction == "sum": + return -torch.sum(negated_loss) + else: + assert ( + False + ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" def _adjust_pruning_lower_bound( @@ -464,10 +513,13 @@ def get_rnnt_prune_ranges( s_begin = torch.argmax(diff_grad, dim=1) s_begin = s_begin[:, :T] - # handle the values of s_begin in padding positions. - # set the s_begin in paddding positions to `len(symbols) - s_range + 1` + # Handle the values of s_begin in padding positions. + # -1 here means we fill the position of the last frame of real data with + # padding value which is `len(symbols) - s_range + 1`. + # This is to guarantee that we reach the last symbol at last frame of real + # data. 
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T) - mask = mask < boundary[:, 3].reshape(B, 1) + mask = mask < boundary[:, 3].reshape(B, 1) - 1 s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1 # handle the cases when `len(symbols) < s_range` @@ -561,7 +613,7 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor): def get_rnnt_logprobs_pruned( - joint: Tensor, + logits: Tensor, symbols: Tensor, ranges: Tensor, termination_symbol: int, @@ -570,7 +622,7 @@ def get_rnnt_logprobs_pruned( """Construct px, py for mutual_information_recursion with pruned output. Args: - joint: + logits: The pruned output of joiner network, with shape (B, T, s_range, C) symbols: The symbol sequences, a LongTensor of shape [B][S], and elements in @@ -589,15 +641,15 @@ def get_rnnt_logprobs_pruned( Return the px (B, S, T + 1) and py (B, S + 1, T) needed by mutual_information_recursion. """ - # joint (B, T, s_range, C) + # logits (B, T, s_range, C) # symbols (B, S) # ranges (B, T, s_range) - assert joint.ndim == 4 - (B, T, s_range, C) = joint.shape + assert logits.ndim == 4 + (B, T, s_range, C) = logits.shape assert ranges.shape == (B, T, s_range) (B, S) = symbols.shape - normalizers = torch.logsumexp(joint, dim=3) + normalizers = torch.logsumexp(logits, dim=3) symbols_with_terminal = torch.cat( ( @@ -620,7 +672,7 @@ def get_rnnt_logprobs_pruned( # (B, T, s_range) px = torch.gather( - joint, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1) + logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1) ).squeeze(-1) px = px - normalizers @@ -652,7 +704,7 @@ def get_rnnt_logprobs_pruned( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - py = joint[:, :, :, termination_symbol] # (B, T, s_range) + py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range) py = py - normalizers # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf @@ -682,23 +734,24 @@ def get_rnnt_logprobs_pruned( def rnnt_loss_pruned( - joint: Tensor, + logits: Tensor, symbols: Tensor, ranges: Tensor, termination_symbol: int, boundary: Tensor = None, + reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C), s_range means the symbols number kept for each frame. Args: - joint: + logits: The pruned output of joiner network, with shape (B, T, s_range, C), i.e. batch, time_seq_len, prune_range, num_classes symbols: A LongTensor of shape [B][S], containing the symbols at each position - of the sequence, possibly including EOS + of the sequence. ranges: A tensor containing the symbol ids for each frame that we want to keep. termination_symbol: @@ -708,14 +761,35 @@ def rnnt_loss_pruned( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + reduction: + Specifies the reduction to apply to the output: `none`, `mean` or `sum`. + `none`: no reduction will be applied. + `mean`: apply `torch.mean` over the batches. + `sum`: the output will be summed. + Default: `mean` Returns: - A Tensor of shape (B,), containing the total RNN-T loss values for each - element of the batch (like log-probs of sequences). + If recursion is `none`, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch, otherwise a scalar + with the reduction applied. 
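
To make the ranges argument consumed by rnnt_loss_pruned concrete, here is the 10-frame / 6-symbol example from the get_rnnt_prune_ranges docstring (s_range == 3) as a tensor, together with the two properties a valid ranges tensor satisfies:

    import torch

    ranges = torch.tensor(
        [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3],
         [1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]]
    )
    # each row is a contiguous block of s_range symbol positions ...
    assert (ranges[:, 1:] - ranges[:, :-1] == 1).all()
    # ... and the start positions never decrease from frame to frame
    assert (ranges[1:, 0] >= ranges[:-1, 0]).all()
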
""" px, py = get_rnnt_logprobs_pruned( - joint, symbols, ranges, termination_symbol, boundary + logits=logits, + symbols=symbols, + ranges=ranges, + termination_symbol=termination_symbol, + boundary=boundary, ) - return mutual_information_recursion(px, py, boundary) + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + if reduction == "none": + return -negated_loss + elif reduction == "mean": + return -torch.mean(negated_loss) + elif reduction == "sum": + return -torch.sum(negated_loss) + else: + assert ( + False + ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" def get_rnnt_logprobs_smoothed( @@ -765,7 +839,7 @@ def get_rnnt_logprobs_smoothed( next on this frame. symbols: A LongTensor of shape [B][S], containing the symbols at each position - of the sequence, possibly including EOS + of the sequence. termination_symbol: The identity of the termination symbol, must be in {0..C-1} lm_only_scale: @@ -921,10 +995,11 @@ def rnnt_loss_smoothed( lm_only_scale: float = 0.1, am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, + reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Tensor: """A simple case of the RNN-T loss, where the 'joiner' network is just - addition. Returns negated total loss value. + addition. Args: lm: @@ -951,27 +1026,49 @@ def rnnt_loss_smoothed( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + reduction: + Specifies the reduction to apply to the output: `none`, `mean` or `sum`. + `none`: no reduction will be applied. + `mean`: apply `torch.mean` over the batches. + `sum`: the output will be summed. + Default: `mean` return_grad: Whether to return grads of px and py, this grad standing for the occupation probability is the output of the backward with a - `fake gradient` input (all ones) This is useful to implement the - pruned version of rnnt loss. + `fake gradient`, the `fake gradient` is the same as the gradient you'd + get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the + loss here is the loss with reduction "none". + This is useful to implement the pruned version of rnnt loss. Returns: - If return_grad is False, returns a Tensor of shape (B,), containing the - NEGATED total RNN-T loss values for each element of the batch - (like log-probs of sequences). + If return_grad is False, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch if reduction equals + to "none", otherwise a scalar with the reduction applied. If return_grad is True, the grads of px and py, which is the output of - backward with a `fake gradient` input, will be returned too. And the + backward with a `fake gradient`(see above), will be returned too. And the returned value will be a tuple like (loss, (px_grad, py_grad)). 
""" px, py = get_rnnt_logprobs_smoothed( - lm, - am, - symbols, - termination_symbol, - lm_only_scale, - am_only_scale, - boundary, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + lm_only_scale=lm_only_scale, + am_only_scale=am_only_scale, + boundary=boundary, + ) + scores_and_grads = mutual_information_recursion( + px=px, py=py, boundary=boundary, return_grad=return_grad ) - return mutual_information_recursion(px, py, boundary, return_grad) + negated_loss = scores_and_grads[0] if return_grad else scores_and_grads + if reduction == "none": + loss = -negated_loss + elif reduction == "mean": + loss = -torch.mean(negated_loss) + elif reduction == "sum": + loss = -torch.sum(negated_loss) + else: + assert ( + False + ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + return (loss, scores_and_grads[1]) if return_grad else loss diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py index 16cc7875f..d619591a8 100644 --- a/k2/python/tests/rnnt_loss_test.py +++ b/k2/python/tests/rnnt_loss_test.py @@ -81,36 +81,55 @@ def test_rnnt_loss_basic(self): termination_symbol = 2 symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) - px, py = k2.get_rnnt_logprobs(lm, am, symbols, termination_symbol) + px, py = k2.get_rnnt_logprobs( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + ) assert px.shape == (B, S, T + 1) assert py.shape == (B, S + 1, T) assert symbols.shape == (B, S) - m = k2.mutual_information_recursion(px, py) + m = k2.mutual_information_recursion(px=px, py=py, boundary=None) if device == torch.device("cpu"): - expected = m - assert torch.allclose(m, expected.to(device)) + expected = -m + assert torch.allclose(-m, expected.to(device)) # test rnnt_loss_simple - m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None) + m = k2.rnnt_loss_simple( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=None, + reduction="none", + ) assert torch.allclose(m, expected.to(device)) # test rnnt_loss_smoothed m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.0, boundary=None, + reduction="none", ) assert torch.allclose(m, expected.to(device)) probs = am.unsqueeze(2) + lm.unsqueeze(1) # test rnnt_loss - m = k2.rnnt_loss(probs, symbols, termination_symbol, None) + m = k2.rnnt_loss( + logits=probs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=None, + reduction="none", + ) assert torch.allclose(m, expected.to(device)) # compare with torchaudio rnnt_loss @@ -129,28 +148,42 @@ def test_rnnt_loss_basic(self): blank=termination_symbol, reduction="none", ) - assert torch.allclose(-m, expected.to(device)) + assert torch.allclose(m, expected.to(device)) # should be invariant to adding a constant for any frame. 
lm += torch.randn(B, S + 1, 1, device=device) am += torch.randn(B, T, 1, device=device) - m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None) + m = k2.rnnt_loss_simple( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=None, + reduction="none", + ) assert torch.allclose(m, expected.to(device)) m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.0, boundary=None, + reduction="none", ) assert torch.allclose(m, expected.to(device)) probs = am.unsqueeze(2) + lm.unsqueeze(1) - m = k2.rnnt_loss(probs, symbols, termination_symbol, None) + m = k2.rnnt_loss( + logits=probs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=None, + reduction="none", + ) assert torch.allclose(m, expected.to(device)) def test_rnnt_loss_random(self): @@ -182,27 +215,35 @@ def test_rnnt_loss_random(self): boundary = boundary_.to(device) px, py = k2.get_rnnt_logprobs( - lm, am, symbols, termination_symbol, boundary + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, ) assert px.shape == (B, S, T + 1) assert py.shape == (B, S + 1, T) assert symbols.shape == (B, S) - m = k2.mutual_information_recursion(px, py, boundary) + m = k2.mutual_information_recursion(px=px, py=py, boundary=boundary) if device == torch.device("cpu"): - expected = m - assert torch.allclose(m, expected.to(device)) + expected = -torch.mean(m) + assert torch.allclose(-torch.mean(m), expected.to(device)) m = k2.rnnt_loss_simple( - lm, am, symbols, termination_symbol, boundary + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, ) assert torch.allclose(m, expected.to(device)) m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.0, boundary=boundary, @@ -210,7 +251,12 @@ def test_rnnt_loss_random(self): assert torch.allclose(m, expected.to(device)) probs = am.unsqueeze(2) + lm.unsqueeze(1) - m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary) + m = k2.rnnt_loss( + logits=probs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + ) assert torch.allclose(m, expected.to(device)) # compare with torchaudio rnnt_loss @@ -223,28 +269,36 @@ def test_rnnt_loss_random(self): logit_lengths=boundary[:, 3].int(), target_lengths=boundary[:, 2].int(), blank=termination_symbol, - reduction="none", ) - assert torch.allclose(-m, expected.to(device)) + assert torch.allclose(m, expected.to(device)) # should be invariant to adding a constant for any frame. 
lm += torch.randn(B, S + 1, 1, device=device) am += torch.randn(B, T, 1, device=device) m = k2.rnnt_loss_simple( - lm, am, symbols, termination_symbol, boundary + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, ) assert torch.allclose(m, expected.to(device)) probs = am.unsqueeze(2) + lm.unsqueeze(1) - m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary) + m = k2.rnnt_loss( + logits=probs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + ) assert torch.allclose(m, expected.to(device)) m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.0, boundary=boundary, @@ -285,31 +339,27 @@ def test_rnnt_loss_gradient(self): logprobs = am.unsqueeze(2) + lm.unsqueeze(1) logprobs.requires_grad_() k2_loss = k2.rnnt_loss( - logprobs, symbols, termination_symbol, boundary - ) - k2_grad = torch.autograd.grad( - k2_loss, logprobs, -torch.ones_like(k2_loss) + logits=logprobs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, ) + k2_grad = torch.autograd.grad(k2_loss, logprobs) k2_grad = k2_grad[0] logprobs2 = logprobs.detach().clone().float() logprobs2.requires_grad_() torch_loss = torchaudio.functional.rnnt_loss( - logprobs2, - symbols.int(), - boundary[:, 3].int(), - boundary[:, 2].int(), + logits=logprobs2, + targets=symbols.int(), + logit_lengths=boundary[:, 3].int(), + target_lengths=boundary[:, 2].int(), blank=termination_symbol, - reduction="none", - ) - torch_grad = torch.autograd.grad( - torch_loss, logprobs2, torch.ones_like(torch_loss) ) + torch_grad = torch.autograd.grad(torch_loss, logprobs2) torch_grad = torch_grad[0] - assert torch.allclose( - -k2_loss, torch_loss, atol=1e-2, rtol=1e-2 - ) + assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2) assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2) @@ -336,10 +386,10 @@ def test_rnnt_loss_smoothed(self): symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.333, boundary=None, @@ -354,10 +404,10 @@ def test_rnnt_loss_smoothed(self): am += torch.randn(B, T, 1, device=device) m = k2.rnnt_loss_smoothed( - lm, - am, - symbols, - termination_symbol, + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, lm_only_scale=0.0, am_only_scale=0.333, boundary=None, @@ -395,19 +445,36 @@ def test_rnnt_loss_pruned(self): t_prob = t_am + t_lm # nonlinear transform t_prob = torch.sigmoid(t_prob) - k2_loss = k2.rnnt_loss(t_prob, symbols, terminal_symbol, boundary) + k2_loss = k2.rnnt_loss( + logits=t_prob, + symbols=symbols, + termination_symbol=terminal_symbol, + boundary=boundary, + reduction="none", + ) print("unpruned rnnt loss: ", k2_loss) # pruning k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( - lm, am, symbols, terminal_symbol, boundary, True + lm=lm, + am=am, + symbols=symbols, + termination_symbol=terminal_symbol, + boundary=boundary, + return_grad=True, + reduction="none", ) for r in range(2, 50, 5): - ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, r) + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=r, + ) # (B, T, r, C) - am_p, lm_p = k2.do_rnnt_pruning(am, lm, ranges) + am_p, 
lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges) t_prob_p = am_p + lm_p @@ -415,9 +482,14 @@ def test_rnnt_loss_pruned(self): t_prob_p = torch.sigmoid(t_prob_p) pruning_loss = k2.rnnt_loss_pruned( - t_prob_p, symbols, ranges, terminal_symbol, boundary + logits=t_prob_p, + symbols=symbols, + ranges=ranges, + termination_symbol=terminal_symbol, + boundary=boundary, + reduction="none", ) - print(f"pruning loss with range {r} : ", pruning_loss) + print(f"pruned loss with range {r} : ", pruning_loss) if __name__ == "__main__": From 779a9bda1bfcd842c96b80eed4bf3c1bd8ce15b9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 29 Jan 2022 15:36:43 +0800 Subject: [PATCH 36/64] Fix building doc. (#908) * Fix building doc. * Minor fixes. * Minor fixes. --- docs/source/conf.py | 9 ++- k2/python/csrc/torch/v2/doc/any.h | 4 +- k2/python/k2/fsa_algo.py | 11 ++-- k2/python/k2/mutual_information.py | 84 +++++++++++++------------- k2/python/k2/nbest.py | 2 +- k2/python/k2/rnnt_loss.py | 94 ++++++++++++++++++------------ k2/python/k2/utils.py | 7 ++- 7 files changed, 123 insertions(+), 88 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7f7c8957f..389a00086 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = 'k2' -copyright = '2020-2021, k2 development team' +copyright = '2020-2022, k2 development team' author = 'k2 development team' @@ -147,7 +147,12 @@ def find_source(): # Replace key with value in the generated doc REPLACE_PATTERN = { - '_k2.ragged': 'k2.ragged', + # somehow it results in errors + # Handler for event + # 'autodoc-process-docstring' threw an exception (exception: + # is a built-in module) + # + # '_k2.ragged': 'k2.ragged', 'at::Tensor': 'torch.Tensor' } diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index 764cb8b04..b99ae4bbf 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -969,7 +969,7 @@ equivalent to the property ``dim0``. )doc"; static constexpr const char *kRaggedAnyGetStateDoc = R"doc( -__getstate__(self: _k2.ragged.Tensor) -> tuple +__getstate__(self: k2.RaggedTensor) -> tuple Requires a tensor with 2 axes or 3 axes. Other number of axes are not implemented yet. @@ -990,7 +990,7 @@ You are not expected to call it by yourself. )doc"; static constexpr const char *kRaggedAnySetStateDoc = R"doc( -__setstate__(self: _k2.ragged.Tensor, arg0: tuple) -> None +__setstate__(self: k2.RaggedTensor, arg0: tuple) -> None Set the content of this class from ``arg0``. diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 146e11cb9..5c608cb77 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -950,11 +950,12 @@ def replace_fsa( labels outside this range are just copied. Labels on final-arcs in `src` (Which will be -1) would be set to 0(epsilon) in the result fsa. - Caution: Attributes of the result inherits from `index` and `src` via - `arc_map_index` and `arc_map_src`, But if there are attributes - with same name, only the attributes with dtype `torch.float32` - are supported, the other kinds of attributes are discarded. - See docs in `fsa_from_binary_function_tensor` for details. 
+ Caution: + Attributes of the result inherits from `index` and `src` via + `arc_map_index` and `arc_map_src`, But if there are attributes + with same name, only the attributes with dtype `torch.float32` + are supported, the other kinds of attributes are discarded. + See docs in `fsa_from_binary_function_tensor` for details. Args: src: diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 00123f806..79dc20b9d 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -90,74 +90,78 @@ def mutual_information_recursion( Args: px: - A torch.Tensor of some floating point type, with shape [B][S][T+1], - where B is the batch size, S is the length of the 'x' sequence - (including representations of EOS symbols but not BOS symbols), and - S is the length of the 'y' sequence (including representations of - EOS symbols but not BOS symbols). In the mutual information - application, px[b][s][t] would represent the following log odds + A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``, + where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence + (including representations of ``EOS`` symbols but not ``BOS`` symbols), + and ``S`` is the length of the ``y`` sequence (including representations + of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information + application, ``px[b][s][t]`` would represent the following log odds ratio; ignoring the b index on the right to make the notation more - compact, + compact:: px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ] This expression also implicitly includes the log-probability of - choosing to generate an x value as opposed to a y value. In - practice it might be computed as a + b, where a is the log - probability of choosing to extend the sequence of length (s,t) - with an x as opposed to a y value; and b might in practice be - of the form: + choosing to generate an ``x`` value as opposed to a ``y`` value. In + practice it might be computed as ``a + b``, where ``a`` is the log + probability of choosing to extend the sequence of length ``(s,t)`` + with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice + be of the form:: + log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t')) - where N is the number of terms that the sum over t' included, which - might include some or all of the other sequences as well as this one. - Note: we don't require px and py to be contiguous, but the - code assumes for optimization purposes that the T axis has - stride 1. + where ``N`` is the number of terms that the sum over ``t'`` included, + which might include some or all of the other sequences as well as this + one. + + Note: + we don't require ``px`` and py to be contiguous, but the + code assumes for optimization purposes that the ``T`` axis has + stride 1. py: - A torch.Tensor of the same dtype as px, with shape [B][S+1][T], - representing + A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``, + representing:: py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ] - This function does not treat x and y differently; the only difference - is that for optimization purposes we assume the last axis - (the t axis) has stride of 1; this is true if px and py are + This function does not treat ``x`` and ``y`` differently; the only + difference is that for optimization purposes we assume the last axis + (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are contiguous. 
boundary: - If supplied, a torch.LongTensor of shape [B][4], where each - row contains [s_begin, t_begin, s_end, t_end], - with 0 <= s_begin <= s_end < S and 0 <= t_begin <= t_end < T + If supplied, a torch.LongTensor of shape ``[B][4]``, where each + row contains ``[s_begin, t_begin, s_end, t_end]``, + with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T`` (this implies that empty sequences are allowed). - If not supplied, the values [0, 0, S, T] will be assumed. - These are the beginning and one-past-the-last positions in the x and - y sequences respectively, and can be used if not all sequences are + If not supplied, the values ``[0, 0, S, T]`` will be assumed. + These are the beginning and one-past-the-last positions in the ``x`` and + ``y`` sequences respectively, and can be used if not all sequences are of the same length. return_grad: - Whether to return grads of px and py, this grad standing for the + Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a - `fake gradient` input (all ones) This is useful to implement the + ``fake gradient`` input (all ones) This is useful to implement the pruned version of rnnt loss. Returns: - Returns a torch.Tensor of shape [B], containing the log of the mutual - information between the b'th pair of sequences. This is defined by - the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]), - representing a mutual information between sub-sequences of lengths s - and t: + Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual + information between the b'th pair of sequences. This is defined by + the following recursion on ``p[b,s,t]`` (where ``p`` is of shape + ``[B,S+1,T+1]``), representing a mutual information between sub-sequences + of lengths ``s`` and ``t``:: p[b,0,0] = 0.0 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) (if s > 0 or t > 0) - where we handle edge cases by treating quantities with negative indexes - as -infinity. The extension to cases where the boundaries are specified - should be obvious; it just works on shorter sequences with offsets into - px and py. + where we handle edge cases by treating quantities with negative indexes + as **-infinity**. The extension to cases where the boundaries are + specified should be obvious; it just works on shorter sequences with offsets + into ``px`` and ``py``. """ assert px.ndim == 3 B, S, T1 = px.shape @@ -227,7 +231,7 @@ def joint_mutual_information_recursion( The recursion below applies if boundary == None, when it defaults to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px - and py: + and py:: p = tensor of shape (B, S+1, T+1), containing -infinity p[b,0,0] = 0.0 diff --git a/k2/python/k2/nbest.py b/k2/python/k2/nbest.py index c245f75e7..cb109694e 100644 --- a/k2/python/k2/nbest.py +++ b/k2/python/k2/nbest.py @@ -62,7 +62,7 @@ def intersect(self, lats: Fsa) -> 'Nbest': Returns: Return a new Nbest. This new Nbest shares the same shape with `self`, while its `fsa` is the 1-best path from intersecting `self.fsa` and - `lats. + `lats`. 
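
The recursion documented above is small enough to write as a naive pure-PyTorch reference (boundary == None case), useful for sanity-checking the optimized kernel on tiny float64 inputs; this is only a sketch with O(S*T) Python loops, not part of k2:

    import torch

    def mutual_information_reference(px, py):
        # px: (B, S, T+1), py: (B, S+1, T); returns p[:, S, T] per the
        # recursion p[b,0,0] = 0 and
        # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
        #                    p[b,s,t-1] + py[b,s,t-1])
        B, S, T1 = px.shape
        T = T1 - 1
        p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
        p[:, 0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                terms = []
                if s > 0:
                    terms.append(p[:, s - 1, t] + px[:, s - 1, t])
                if t > 0:
                    terms.append(p[:, s, t - 1] + py[:, s, t - 1])
                p[:, s, t] = (torch.logaddexp(*terms)
                              if len(terms) == 2 else terms[0])
        return p[:, S, T]

    # On tiny float64 inputs this should agree with the optimized version:
    #   torch.allclose(mutual_information_reference(px, py),
    #                  k2.mutual_information_recursion(px, py))
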
''' assert self.fsa.device == lats.device, \ f'{self.fsa.device} vs {lats.device}' diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 2150d4ed2..6805ae2a4 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -59,8 +59,10 @@ def get_rnnt_logprobs( Args: lm: Language model part of un-normalized logprobs of symbols, to be added to - acoustic model part before normalizing. Of shape: + acoustic model part before normalizing. Of shape:: + [B][S+1][C] + where B is the batch size, S is the maximum sequence length of the symbol sequence, possibly including the EOS symbol; and C is size of the symbol vocabulary, including the termination/next-frame @@ -72,8 +74,10 @@ def get_rnnt_logprobs( the termination/next-frame symbol at this point. am: Acoustic-model part of un-normalized logprobs of symbols, to be added - to language-model part before normalizing. Of shape: + to language-model part before normalizing. Of shape:: + [B][T][C] + where B is the batch size, T is the maximum sequence length of the acoustic sequences (in frames); and C is size of the symbol vocabulary, including the termination/next-frame symbol. It reflects @@ -91,24 +95,28 @@ def get_rnnt_logprobs( if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. Returns: - (px, py) (the names are quite arbitrary). + (px, py) (the names are quite arbitrary):: + px: logprobs, of shape [B][S][T+1] py: logprobs, of shape [B][S+1][T] - in the recursion: + + in the recursion:: + p[b,0,0] = 0.0 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) - .. where p[b][s][t] is the "joint score" of the pair of subsequences - of length s and t respectively. px[b][s][t] represents the - probability of extending the subsequences of length (s,t) by one in - the s direction, given the particular symbol, and py[b][s][t] - represents the probability of extending the subsequences of length - (s,t) by one in the t direction, - i.e. of emitting the termination/next-frame symbol. - - px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame - we cannot emit any symbols. This is simply a way of incorporating - the probability of the termination symbol on the last frame. + + where p[b][s][t] is the "joint score" of the pair of subsequences + of length s and t respectively. px[b][s][t] represents the + probability of extending the subsequences of length (s,t) by one in + the s direction, given the particular symbol, and py[b][s][t] + represents the probability of extending the subsequences of length + (s,t) by one in the t direction, + i.e. of emitting the termination/next-frame symbol. + + px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame + we cannot emit any symbols. This is simply a way of incorporating + the probability of the termination symbol on the last frame. """ assert lm.ndim == 3 assert am.ndim == 3 @@ -272,14 +280,18 @@ def get_rnnt_logprobs_joint( if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. Returns: - (px, py) (the names are quite arbitrary). + (px, py) (the names are quite arbitrary):: + px: logprobs, of shape [B][S][T+1] py: logprobs, of shape [B][S+1][T] - in the recursion: + + in the recursion:: + p[b,0,0] = 0.0 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) - .. where p[b][s][t] is the "joint score" of the pair of subsequences of + + where p[b][s][t] is the "joint score" of the pair of subsequences of length s and t respectively. 
px[b][s][t] represents the probability of extending the subsequences of length (s,t) by one in the s direction, given the particular symbol, and py[b][s][t] represents the probability @@ -447,7 +459,7 @@ def get_rnnt_prune_ranges( (T, s_range) containing the information that which symbols will be token into consideration for each frame. For example, here is a sequence with 10 frames and the corresponding symbols are `[A B C D E F]`, if the s_range - equals 3, one possible ranges tensor will be: + equals 3, one possible ranges tensor will be:: [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3], [1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]] @@ -804,21 +816,25 @@ def get_rnnt_logprobs_smoothed( """Reduces RNN-T problem (the simple case, where joiner network is just addition), to a compact, standard form that can then be given (with boundaries) to mutual_information_recursion(). - This version allows you to make the loss-function one of the form: + This version allows you to make the loss-function one of the form:: + lm_only_scale * lm_probs + am_only_scale * am_probs + (1-lm_only_scale-am_only_scale) * combined_probs + where lm_probs and am_probs are the probabilities given the lm and acoustic model independently. This function is called from - rnnt_loss_smoothed(), but may be useful for other purposes. + :func:`rnnt_loss_smoothed`, but may be useful for other purposes. Args: lm: Language model part of un-normalized logprobs of symbols, to be added to - acoustic model part before normalizing. Of shape: + acoustic model part before normalizing. Of shape:: + [B][S+1][C] + where B is the batch size, S is the maximum sequence length of the symbol sequence, possibly including the EOS symbol; and C is size of the symbol vocabulary, including the termination/next-frame @@ -830,8 +846,10 @@ def get_rnnt_logprobs_smoothed( the termination/next-frame symbol at this point. am: Acoustic-model part of un-normalized logprobs of symbols, to be added - to language-model part before normalizing. Of shape: + to language-model part before normalizing. Of shape:: + [B][T][C] + where B is the batch size, T is the maximum sequence length of the acoustic sequences (in frames); and C is size of the symbol vocabulary, including the termination/next-frame symbol. It reflects @@ -854,24 +872,28 @@ def get_rnnt_logprobs_smoothed( if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. Returns: - (px, py) (the names are quite arbitrary). + (px, py) (the names are quite arbitrary):: + px: logprobs, of shape [B][S][T+1] py: logprobs, of shape [B][S+1][T] - in the recursion: + + in the recursion:: + p[b,0,0] = 0.0 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) - .. where p[b][s][t] is the "joint score" of the pair of subsequences - of length s and t respectively. px[b][s][t] represents the - probability of extending the subsequences of length (s,t) by one in - the s direction, given the particular symbol, and py[b][s][t] - represents the probability of extending the subsequences of length - (s,t) by one in the t direction, - i.e. of emitting the termination/next-frame symbol. - - px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame - we cannot emit any symbols. This is simply a way of incorporating - the probability of the termination symbol on the last frame. + + where p[b][s][t] is the "joint score" of the pair of subsequences + of length s and t respectively. 
px[b][s][t] represents the + probability of extending the subsequences of length (s,t) by one in + the s direction, given the particular symbol, and py[b][s][t] + represents the probability of extending the subsequences of length + (s,t) by one in the t direction, + i.e. of emitting the termination/next-frame symbol. + + px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame + we cannot emit any symbols. This is simply a way of incorporating + the probability of the termination symbol on the last frame. """ assert lm.ndim == 3 assert am.ndim == 3 diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index abc28c21d..aec54dd02 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -715,9 +715,12 @@ def get_best_matching_stats( a collection of key and query sequences. If 3 axes, this represents a set of such collections. - 2-axis example: + 2-axis example:: + [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ] - 3-axis example: + + 3-axis example:: + [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ], [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ] From 47c4b754bb418b2a40c3ee0f24ca5ed12b08997f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 29 Jan 2022 17:39:32 +0800 Subject: [PATCH 37/64] Fix building doc (#912) * Fix building doc * Fix flake8 --- docs/requirements.txt | 14 +++++++------- k2/python/k2/mutual_information.py | 13 +++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b426623eb..86d0a4cb8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,8 +1,8 @@ -dataclasses -graphviz -recommonmark -sphinx -sphinx-autodoc-typehints -sphinx_rtd_theme -sphinxcontrib-bibtex +dataclasses==0.6 +graphviz==0.19.1 +recommonmark==0.7.1 +sphinx==4.3.2 +sphinx-autodoc-typehints==1.12.0 +sphinx_rtd_theme==1.0.0 +sphinxcontrib-bibtex==2.4.1 torch>=1.6.0 diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 79dc20b9d..6a61d8f1f 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -143,8 +143,9 @@ def mutual_information_recursion( return_grad: Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a - ``fake gradient`` input (all ones) This is useful to implement the - pruned version of rnnt loss. + ``fake gradient`` the ``fake gradient`` is the same as the gradient + you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``. + This is useful to implement the pruned version of rnnt loss. Returns: Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual @@ -160,8 +161,8 @@ def mutual_information_recursion( where we handle edge cases by treating quantities with negative indexes as **-infinity**. The extension to cases where the boundaries are - specified should be obvious; it just works on shorter sequences with offsets - into ``px`` and ``py``. + specified should be obvious; it just works on shorter sequences with + offsets into ``px`` and ``py``. 
""" assert px.ndim == 3 B, S, T1 = px.shape @@ -179,10 +180,10 @@ def mutual_information_recursion( assert px.is_contiguous() assert py.is_contiguous() - m, px_grad, py_grad = MutualInformationRecursionFunction.apply( + scores, px_grad, py_grad = MutualInformationRecursionFunction.apply( px, py, boundary, return_grad ) - return (m, (px_grad, py_grad)) if return_grad else m + return (scores, (px_grad, py_grad)) if return_grad else scores def _inner_product(a: Tensor, b: Tensor) -> Tensor: From cf32e2d854101e9727f8efc4891185d9359535ef Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Feb 2022 08:00:45 +0800 Subject: [PATCH 38/64] Support torch 1.10.x (#914) * Support torch 1.10.x * Fix installing PyTorch. --- .github/workflows/build-cpu.yml | 4 ++-- .github/workflows/build.yml | 22 +++++++++++++++------- .github/workflows/build_conda.yml | 20 ++++++++++++++------ .github/workflows/build_conda_cpu.yml | 2 +- .github/workflows/nightly-cpu.yml | 2 +- scripts/github_actions/install_torch.sh | 2 +- 6 files changed, 34 insertions(+), 18 deletions(-) diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 90e7f6114..4283d8dc4 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -36,8 +36,8 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10 + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] + # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x python-version: [3.6, 3.7, 3.8, 3.9] exclude: - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6f86b6844..769bdf05a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,7 @@ jobs: # from https://download.pytorch.org/whl/torch_stable.html # Note: There are no torch versions for CUDA 11.2 # - # 1.10 supports: cuda10.2 (default), 11.1, 11.3 + # 1.10.x supports: cuda10.2 (default), 11.1, 11.3 # 1.9.x supports: cuda10.2 (default), 11.1 # PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1 # PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0 @@ -49,9 +49,9 @@ jobs: # CUDA 11.3 is for torch 1.10 cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] gcc: ["7"] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] # - # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10 + # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10.x python-version: [3.6, 3.7, 3.8, 3.9] exclude: - cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] @@ -72,7 +72,7 @@ jobs: torch: "1.9.0" - cuda: "11.3" torch: "1.9.1" - - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10] + - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] torch: "1.5.0" - cuda: "11.0" torch: "1.5.1" @@ -87,7 +87,11 @@ jobs: - cuda: "11.0" torch: "1.9.1" - cuda: "11.0" - torch: "1.10" + torch: "1.10.0" + - cuda: "11.0" + torch: "1.10.1" + - cuda: "11.0" + torch: "1.10.2" - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] torch: "1.5.0" - cuda: 
"11.1" @@ -98,12 +102,16 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10] + - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10.0, 10.1, 10.2] torch: "1.9.0" - cuda: "10.1" torch: "1.9.1" - cuda: "10.1" - torch: "1.10" + torch: "1.10.0" + - cuda: "10.1" + torch: "1.10.1" + - cuda: "10.1" + torch: "1.10.2" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 diff --git a/.github/workflows/build_conda.yml b/.github/workflows/build_conda.yml index 91e703f2b..74ad2238c 100644 --- a/.github/workflows/build_conda.yml +++ b/.github/workflows/build_conda.yml @@ -37,7 +37,7 @@ jobs: cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] # from https://download.pytorch.org/whl/torch_stable.html # - # PyTorch 1.10 supports: 10.2 (default), 11.1, 11.3 + # PyTorch 1.10.x supports: 10.2 (default), 11.1, 11.3 # PyTorch 1.9.x supports: 10.2 (default), 11.1 # PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1 # PyTorch 1.8.0 supports: cuda 10.1, 10.2 (default), 11.1 @@ -56,9 +56,9 @@ jobs: # https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true # and # https://github.com/NVIDIA/apex/issues/805 - torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] exclude: - # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10] + # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] # torch: "1.5.0" # - cuda: "11.0" # torch: "1.5.1" @@ -73,7 +73,11 @@ jobs: - cuda: "11.0" torch: "1.9.1" - cuda: "11.0" - torch: "1.10" + torch: "1.10.0" + - cuda: "11.0" + torch: "1.10.1" + - cuda: "11.0" + torch: "1.10.2" # - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] # torch: "1.5.0" # - cuda: "11.1" @@ -84,12 +88,16 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10] + - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] torch: "1.9.0" - cuda: "10.1" torch: "1.9.1" - cuda: "10.1" - torch: "1.10" + torch: "1.10.0" + - cuda: "10.1" + torch: "1.10.1" + - cuda: "10.1" + torch: "1.10.2" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 diff --git a/.github/workflows/build_conda_cpu.yml b/.github/workflows/build_conda_cpu.yml index b7f2dcc4a..aec3e114d 100644 --- a/.github/workflows/build_conda_cpu.yml +++ b/.github/workflows/build_conda_cpu.yml @@ -51,7 +51,7 @@ jobs: # # Other PyTorch versions are not tested # - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] exclude: - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu.yml index 249fd215d..af052ecdf 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu.yml @@ -41,7 +41,7 @@ jobs: os: [ubuntu-18.04, macos-10.15] # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10 python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"] + torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", 
"1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] exclude: - python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.4.0" diff --git a/scripts/github_actions/install_torch.sh b/scripts/github_actions/install_torch.sh index b0b822a13..8024729f1 100755 --- a/scripts/github_actions/install_torch.sh +++ b/scripts/github_actions/install_torch.sh @@ -91,7 +91,7 @@ case ${torch} in ;; esac ;; - 1.10) + 1.10.*) case ${cuda} in 10.2) package="torch==${torch}" From 9e7b2a995622d06b85c090705887468d7a8833fe Mon Sep 17 00:00:00 2001 From: alexei-v-ivanov Date: Mon, 7 Feb 2022 22:43:06 -0800 Subject: [PATCH 39/64] Update INSTALL.rst (#915) * Update INSTALL.rst Setting a few additional env variables to enable compilation from source *with CUDA GPU computation support enabled* --- INSTALL.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/INSTALL.rst b/INSTALL.rst index e4bbf7515..c94e2aebe 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -44,6 +44,18 @@ From source git clone https://github.com/k2-fsa/k2.git cd k2 python3 setup.py install + +From source (with CUDA support) +========================= + +.. code-block:: bash + + git clone https://github.com/k2-fsa/k2.git + cd k2 + export K2_CMAKE_ARGS="-DK2_WITH_CUDA=ON -DCMAKE_BUILD_TYPE=Release" + export LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/lib:$LD_LIBRARY_PATH + export PATH=$PATH:/usr/local/cuda/bin + python3 setup.py install Read ``_ to learn more From 43ed45012d8b1bf9a04aff4272f5994a45bb1f67 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 10 Feb 2022 12:55:46 +0800 Subject: [PATCH 40/64] Fix torch/cuda/python versions in the doc. (#918) * Fix torch/cuda/python versions in the doc. * Minor fixes. --- docs/source/installation/conda.rst | 6 ++--- docs/source/installation/from_source.rst | 6 ++--- docs/source/installation/images/README.md | 14 +++--------- .../cuda-10.1_10.2_11.0_11.1-orange.svg | 1 - .../images/cuda_ge_10.1-orange.svg | 1 + .../images/pip_cuda-10.1_10.2_11.0-orange.svg | 1 - .../images/pip_python-3.6_3.7_3.8-blue.svg | 1 - .../images/pip_pytorch-1.7.1-green.svg | 1 - .../images/python-3.6_3.7_3.8-blue.svg | 1 - .../images/python_ge_3.6-blue.svg | 1 + .../images/pytorch_ge_1.5.0-green.svg | 1 + ...a-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg | 1 - .../source_python-3.6_3.7_3.8_3.9-blue.svg | 1 - ...ch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg | 1 - docs/source/installation/index.rst | 22 +++++++++---------- docs/source/installation/pip.rst | 6 ++--- 16 files changed, 26 insertions(+), 39 deletions(-) delete mode 100644 docs/source/installation/images/cuda-10.1_10.2_11.0_11.1-orange.svg create mode 100644 docs/source/installation/images/cuda_ge_10.1-orange.svg delete mode 100644 docs/source/installation/images/pip_cuda-10.1_10.2_11.0-orange.svg delete mode 100644 docs/source/installation/images/pip_python-3.6_3.7_3.8-blue.svg delete mode 100644 docs/source/installation/images/pip_pytorch-1.7.1-green.svg delete mode 100644 docs/source/installation/images/python-3.6_3.7_3.8-blue.svg create mode 100644 docs/source/installation/images/python_ge_3.6-blue.svg create mode 100644 docs/source/installation/images/pytorch_ge_1.5.0-green.svg delete mode 100644 docs/source/installation/images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg delete mode 100644 docs/source/installation/images/source_python-3.6_3.7_3.8_3.9-blue.svg delete mode 100644 docs/source/installation/images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg diff --git 
a/docs/source/installation/conda.rst b/docs/source/installation/conda.rst index 4c50b8c9d..d4ef221e2 100644 --- a/docs/source/installation/conda.rst +++ b/docs/source/installation/conda.rst @@ -57,13 +57,13 @@ Read the following if you want to learn more. Supported versions ------------------ -.. |conda_python_versions| image:: ./images/python-3.6_3.7_3.8-blue.svg +.. |conda_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |conda_cuda_versions| image:: ./images/cuda-10.1_10.2_11.0_11.1-orange.svg +.. |conda_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |conda_pytorch_versions| image:: ./images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions - |conda_python_versions| diff --git a/docs/source/installation/from_source.rst b/docs/source/installation/from_source.rst index 9d78a864b..5aac6d406 100644 --- a/docs/source/installation/from_source.rst +++ b/docs/source/installation/from_source.rst @@ -9,13 +9,13 @@ The following versions of Python, CUDA, and PyTorch are known to work. - |source_cuda_versions| - |source_pytorch_versions| -.. |source_python_versions| image:: ./images/source_python-3.6_3.7_3.8_3.9-blue.svg +.. |source_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |source_cuda_versions| image:: ./images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg +.. |source_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |source_pytorch_versions| image:: ./images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions Before compiling k2, some preparation work has to be done: diff --git a/docs/source/installation/images/README.md b/docs/source/installation/images/README.md index 9562b5ba7..a63890a06 100644 --- a/docs/source/installation/images/README.md +++ b/docs/source/installation/images/README.md @@ -3,18 +3,10 @@ is used to create the following files: -- python-3.6_3.7_3.8-blue.svg -- cuda-10.1_10.2_11.0_11.1-orange.svg -- pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +- python_ge_3.6-blue.svg +- cuda_ge_10.1-orange.svg +- pytorch_ge_1.5.0-green.svg - pypi_python-3.6_3.7_3.8-blue.svg - pypi_cuda-10.1-orange.svg - pypi_pytorch-1.7.1-green.svg - -- pip_python-3.6_3.7_3.8-blue.svg -- pip_cuda-10.1_10.2_11.0-orange.svg -- pip_pytorch-1.7.1-green.svg - -- source_python-3.6_3.7_3.8_3.9-blue.svg -- source_cuda-10.1_10.2_11.0_11.1-orange.svg -- source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg diff --git a/docs/source/installation/images/cuda-10.1_10.2_11.0_11.1-orange.svg b/docs/source/installation/images/cuda-10.1_10.2_11.0_11.1-orange.svg deleted file mode 100644 index f4043e4d6..000000000 --- a/docs/source/installation/images/cuda-10.1_10.2_11.0_11.1-orange.svg +++ /dev/null @@ -1 +0,0 @@ -cuda: 10.1 | 10.2 | 11.0 | 11.1cuda10.1 | 10.2 | 11.0 | 11.1 \ No newline at end of file diff --git a/docs/source/installation/images/cuda_ge_10.1-orange.svg b/docs/source/installation/images/cuda_ge_10.1-orange.svg new file mode 100644 index 000000000..6aa534287 --- /dev/null +++ b/docs/source/installation/images/cuda_ge_10.1-orange.svg @@ -0,0 +1 @@ +cuda: >= 10.1cuda>= 10.1 \ No newline at end of file diff --git a/docs/source/installation/images/pip_cuda-10.1_10.2_11.0-orange.svg 
b/docs/source/installation/images/pip_cuda-10.1_10.2_11.0-orange.svg deleted file mode 100644 index 72db8624a..000000000 --- a/docs/source/installation/images/pip_cuda-10.1_10.2_11.0-orange.svg +++ /dev/null @@ -1 +0,0 @@ -cuda: 10.1 | 10.2 | 11.0cuda10.1 | 10.2 | 11.0 \ No newline at end of file diff --git a/docs/source/installation/images/pip_python-3.6_3.7_3.8-blue.svg b/docs/source/installation/images/pip_python-3.6_3.7_3.8-blue.svg deleted file mode 100644 index 275a432ec..000000000 --- a/docs/source/installation/images/pip_python-3.6_3.7_3.8-blue.svg +++ /dev/null @@ -1 +0,0 @@ -python: 3.6 | 3.7 | 3.8python3.6 | 3.7 | 3.8 \ No newline at end of file diff --git a/docs/source/installation/images/pip_pytorch-1.7.1-green.svg b/docs/source/installation/images/pip_pytorch-1.7.1-green.svg deleted file mode 100644 index 1e8abdbe6..000000000 --- a/docs/source/installation/images/pip_pytorch-1.7.1-green.svg +++ /dev/null @@ -1 +0,0 @@ -pytorch: 1.7.1pytorch1.7.1 \ No newline at end of file diff --git a/docs/source/installation/images/python-3.6_3.7_3.8-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8-blue.svg deleted file mode 100644 index 275a432ec..000000000 --- a/docs/source/installation/images/python-3.6_3.7_3.8-blue.svg +++ /dev/null @@ -1 +0,0 @@ -python: 3.6 | 3.7 | 3.8python3.6 | 3.7 | 3.8 \ No newline at end of file diff --git a/docs/source/installation/images/python_ge_3.6-blue.svg b/docs/source/installation/images/python_ge_3.6-blue.svg new file mode 100644 index 000000000..4254dc58a --- /dev/null +++ b/docs/source/installation/images/python_ge_3.6-blue.svg @@ -0,0 +1 @@ +python: >= 3.6python>= 3.6 \ No newline at end of file diff --git a/docs/source/installation/images/pytorch_ge_1.5.0-green.svg b/docs/source/installation/images/pytorch_ge_1.5.0-green.svg new file mode 100644 index 000000000..ce4826490 --- /dev/null +++ b/docs/source/installation/images/pytorch_ge_1.5.0-green.svg @@ -0,0 +1 @@ +pytorch: >= 1.5.0pytorch>= 1.5.0 \ No newline at end of file diff --git a/docs/source/installation/images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg b/docs/source/installation/images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg deleted file mode 100644 index 74818a756..000000000 --- a/docs/source/installation/images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg +++ /dev/null @@ -1 +0,0 @@ -cuda: 10.1 | 10.2 | 11.0 | 11.1 | 11.2 | 11.3cuda10.1 | 10.2 | 11.0 | 11.1 | 11.2 | 11.3 \ No newline at end of file diff --git a/docs/source/installation/images/source_python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/source_python-3.6_3.7_3.8_3.9-blue.svg deleted file mode 100644 index 676feba2c..000000000 --- a/docs/source/installation/images/source_python-3.6_3.7_3.8_3.9-blue.svg +++ /dev/null @@ -1 +0,0 @@ -python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 \ No newline at end of file diff --git a/docs/source/installation/images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg b/docs/source/installation/images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg deleted file mode 100644 index dc940a51a..000000000 --- a/docs/source/installation/images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +++ /dev/null @@ -1 +0,0 @@ -pytorch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1pytorch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 \ No newline at end of file diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index b64a8db5b..5b6537ba2 100644 --- a/docs/source/installation/index.rst +++ 
b/docs/source/installation/index.rst @@ -14,7 +14,7 @@ below: - From conda (**recommended**) - |conda_python_versions| - - |conda_cuda_versions|- + - |conda_cuda_versions| - |conda_pytorch_versions| - From pip (k2-fsa.org) @@ -26,7 +26,7 @@ below: - From pypi (pypi.org) - |pypi_python_versions| - - |pypi_cuda_versions|- + - |pypi_cuda_versions| - |pypi_pytorch_versions| - From source (**for advanced users**) @@ -44,22 +44,22 @@ below: from_source for_developers -.. |conda_python_versions| image:: ./images/python-3.6_3.7_3.8-blue.svg +.. |conda_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |conda_cuda_versions| image:: ./images/cuda-10.1_10.2_11.0_11.1-orange.svg +.. |conda_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |conda_pytorch_versions| image:: ./images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions -.. |pip_python_versions| image:: ./images/pip_python-3.6_3.7_3.8-blue.svg +.. |pip_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |pip_cuda_versions| image:: ./images/pip_cuda-10.1_10.2_11.0-orange.svg +.. |pip_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |pip_pytorch_versions| image:: ./images/pip_pytorch-1.7.1-green.svg +.. |pip_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions .. |pypi_python_versions| image:: ./images/pypi_python-3.6_3.7_3.8-blue.svg @@ -71,13 +71,13 @@ below: .. |pypi_pytorch_versions| image:: ./images/pypi_pytorch-1.7.1-green.svg :alt: Supported pytorch versions -.. |source_python_versions| image:: ./images/source_python-3.6_3.7_3.8_3.9-blue.svg +.. |source_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |source_cuda_versions| image:: ./images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg +.. |source_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |source_pytorch_versions| image:: ./images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg +.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions Reporting issues diff --git a/docs/source/installation/pip.rst b/docs/source/installation/pip.rst index 20db2a886..b756263b7 100644 --- a/docs/source/installation/pip.rst +++ b/docs/source/installation/pip.rst @@ -1,13 +1,13 @@ Install using pip (k2-fsa.org) ============================== -.. |pip_python_versions| image:: ./images/pip_python-3.6_3.7_3.8-blue.svg +.. |pip_python_versions| image:: ./images/python_ge_3.6-blue.svg :alt: Supported python versions -.. |pip_cuda_versions| image:: ./images/pip_cuda-10.1_10.2_11.0-orange.svg +.. |pip_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |pip_pytorch_versions| image:: ./images/pip_pytorch-1.7.1-green.svg +.. |pip_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg :alt: Supported pytorch versions You can find a list of nightly pre-built From f4fefe4882bc0ae59af951da3f47335d5495ef71 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 10 Feb 2022 15:16:02 +0800 Subject: [PATCH 41/64] Fix building for CUDA 11.6 (#917) * Fix building for CUDA 11.6 * Minor fixes. 
---
 k2/csrc/CMakeLists.txt |  5 +++++
 k2/csrc/cub.h          | 42 ------------------------------------------
 2 files changed, 5 insertions(+), 42 deletions(-)

diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt
index a97e291e8..1759244ea 100644
--- a/k2/csrc/CMakeLists.txt
+++ b/k2/csrc/CMakeLists.txt
@@ -92,6 +92,11 @@ add_library(context ${context_srcs})
 target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MAJOR=${K2_TORCH_VERSION_MAJOR})
 target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MINOR=${K2_TORCH_VERSION_MINOR})

+# see https://github.com/NVIDIA/thrust/issues/1401
+# and https://github.com/k2-fsa/k2/pull/917
+target_compile_definitions(context PUBLIC CUB_WRAPPED_NAMESPACE=k2)
+target_compile_definitions(context PUBLIC THRUST_NS_QUALIFIER=thrust)
+
 set_target_properties(context PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
 set_target_properties(context PROPERTIES OUTPUT_NAME "k2context")

diff --git a/k2/csrc/cub.h b/k2/csrc/cub.h
index d1df56f32..aeb524ec6 100644
--- a/k2/csrc/cub.h
+++ b/k2/csrc/cub.h
@@ -20,50 +20,8 @@
 #ifndef K2_CSRC_CUB_H_
 #define K2_CSRC_CUB_H_

-// See
-// https://github.com/k2-fsa/k2/issues/698
-// and
-// https://github.com/pytorch/pytorch/issues/54245#issuecomment-805707551
-// for why we need the following two macros
-//
-// NOTE: We define the following two macros so
-// that k2 and PyTorch use a different copy
-// of CUB.
-
-#ifdef CUB_NS_PREFIX
-#undef CUB_NS_PREFIX
-#endif
-
-#ifdef CUB_NS_POSTFIX
-#undef CUB_NS_POSTFIX
-#endif
-
-#ifdef CUB_NS_QUALIFIER
-#undef CUB_NS_QUALIFIER
-#endif
-
-// see
-// https://github.com/NVIDIA/cub/commit/6631c72630f10e370d93814a59146b12f7620d85
-// The above commit replaced "thrust" with "THRUST_NS_QUALIFIER"
-#ifndef THRUST_NS_QUALIFIER
-#define THRUST_NS_QUALIFIER thrust
-#endif
-
-#define CUB_NS_PREFIX namespace k2 {
-#define CUB_NS_POSTFIX }
-
-// See
-// https://github.com/NVIDIA/cub/commit/6631c72630f10e370d93814a59146b12f7620d85
-// and
-// https://github.com/NVIDIA/cub/pull/350
-#define CUB_NS_QUALIFIER ::k2::cub
-
 #ifdef K2_WITH_CUDA
 #include "cub/cub.cuh"  // NOLINT
 #endif

-#undef CUB_NS_PREFIX
-#undef CUB_NS_POSTFIX
-#undef CUB_NS_QUALIFIER
-
 #endif  // K2_CSRC_CUB_H_

From 56edc82769221fa016f2436595ac5a4bcb6691f8 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sun, 20 Feb 2022 11:10:24 +0800
Subject: [PATCH 42/64] Implement Unstack (#920)

* Implement unstack

* Remove code that does not relate to this PR

* Remove for loop on output dim; add Unstack ragged

* Add more docs

* Fix comments

* Fix docs & unit tests
---
 k2/csrc/ragged_ops.cu    | 328 ++++++++++++++++++++++++++++++++++++++-
 k2/csrc/ragged_ops.h     | 121 ++++++++++++++-
 k2/csrc/ragged_ops_inl.h |  56 +++++++
 k2/csrc/ragged_test.cu   | 299 +++++++++++++++++++++++++++++++----
 4 files changed, 767 insertions(+), 37 deletions(-)

diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu
index efd546a59..6120428cc 100644
--- a/k2/csrc/ragged_ops.cu
+++ b/k2/csrc/ragged_ops.cu
@@ -235,11 +235,18 @@ RaggedShape RaggedShapeFromTotSizes(ContextPtr c, int32_t num_axes,
   NVTX_RANGE(K2_FUNC);
   K2_CHECK_GE(num_axes, 2);
   std::vector<RaggedShapeLayer> axes(num_axes - 1);
-  // In future we might choose to allocate everything in one big array, to avoid
-  // multiple allocations, but for now just do it the simple way.
+  int32_t tot_size = 0;
   for (int32_t axis = 1; axis < num_axes; ++axis) {
-    axes[axis - 1].row_splits = Array1<int32_t>(c, tot_sizes[axis - 1] + 1);
-    axes[axis - 1].row_ids = Array1<int32_t>(c, tot_sizes[axis]);
+    tot_size += tot_sizes[axis - 1] + 1 + tot_sizes[axis];
+  }
+  Array1<int32_t> buf(c, tot_size);
+  int32_t start = 0;
+  for (int32_t axis = 1; axis < num_axes; ++axis) {
+    axes[axis - 1].row_splits = buf.Arange(start,
+                                           start + tot_sizes[axis - 1] + 1);
+    start += tot_sizes[axis - 1] + 1;
+    axes[axis - 1].row_ids = buf.Arange(start, start + tot_sizes[axis]);
+    start += tot_sizes[axis];
     axes[axis - 1].cached_tot_size = tot_sizes[axis];
   }
   // No check here as we did not set the values of row_splits and row_ids
@@ -418,7 +425,6 @@ static RaggedShape IndexAxis0(RaggedShape &src, const Array1<int32_t> &new2old,
                               Array1<int32_t> *elem_indexes /*=nullptr*/) {
   NVTX_RANGE(K2_FUNC);
   ContextPtr &c = src.Context();
-  bool is_cpu = (c->GetDeviceType() == kCpu);
   K2_CHECK(IsCompatible(src, new2old));
   int32_t num_axes = src.NumAxes(), src_dim0 = src.Dim0(),
           ans_dim0 = new2old.Dim();
@@ -470,7 +476,6 @@ static RaggedShape IndexAxis0(RaggedShape &src, const Array1<int32_t> &new2old,
                    tot_sizes.data[i]);
   }

-
   int32_t *elem_indexes_data =
       (elem_indexes != nullptr ? elem_indexes->Data() : nullptr);

@@ -575,6 +580,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis,
     Array1<int32_t> last_row_splits(last_row_ids.Context(),
                                     src.TotSize(num_axes - 2) + 1);
     RowIdsToRowSplits(last_row_ids, &last_row_splits);
+
     if (elem_indexes) *elem_indexes = indexes;

@@ -689,7 +695,6 @@ static RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src,
   int32_t num_axes_in = src[0]->NumAxes(),
           num_axes_out = num_axes_in + 1;
   ContextPtr c = src[0]->Context();
-  bool is_cpu = (c->GetDeviceType() == kCpu);

   // Check if they have same num-axes and compatible context
   for (int32_t i = 1; i < num_srcs; ++i) {
@@ -702,7 +707,6 @@ static RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src,
   Array2<int32_t> offsets = GetOffsets(num_srcs, src);
   auto offsets_acc = offsets.Accessor();

-
   SmallVec<int32_t, 6> tot_sizes_out;
   K2_CHECK(num_axes_out <= 6);
   int32_t max_tot_size = 0;
@@ -1100,6 +1104,314 @@ RaggedShape Stack(int32_t axis, int32_t num_srcs, RaggedShape **src,
   return RaggedShape(ans_layers);
 }

+/*
+  Select a ragged tensor's shape on axis 0 with a two-axis ragged index.
+
+    @param [in] src  Source RaggedShape to select.
+    @param [in] indexes  A **TWO**-axis ragged tensor containing the indexes
+                   into axis 0 of src. We also support -1 as an index,
+                   which will result in the empty list (as if it were the
+                   index into a position in `src` that had an empty list),
+                   i.e. with `-1 <= indexes[i] < src.TotSize(0)`.
+    @param [out] out  The container the output RaggedShapes will be written
+                   to; MUST NOT be a nullptr. Will be reallocated and the
+                   final size of `out` will equal `indexes.TotSize(0)`.
+                   Note: the `NumAxes()` of each output RaggedShape is the
+                   same as the `NumAxes()` of src.
+    @param [out] split_map  If not nullptr, will store the element-indexes
+                   within src telling where the elements of each split
+                   RaggedShape come from. Will be reallocated and the final
+                   size of `split_map` will equal `indexes.TotSize(0)`.
+
+  Suppose indexes is `[ [ 0 3 5 ] [ 1 2 4 ] [ 6 -1 ] ]`; it means that we will
+  select elements 0,3,5 of src's axis 0 to construct the first output
+  RaggedShape, 1,2,4 to construct the second output RaggedShape, and 6 and an
+  empty list to construct the third output RaggedShape.
+ */
+static void SelectAxis0(RaggedShape &src, const Ragged<int32_t> &indexes,
+    std::vector<RaggedShape> *out, std::vector<Array1<int32_t>> *split_map) {
+  NVTX_RANGE(K2_FUNC);
+  ContextPtr &c = src.Context();
+  K2_CHECK(IsCompatible(src, indexes));
+  K2_CHECK_EQ(indexes.NumAxes(), 2);
+  K2_CHECK(out != nullptr);
+  int32_t num_axes = src.NumAxes(),
+          out_size = indexes.Dim0(),
+          tot_elems = indexes.NumElements();
+  if (out_size == 0) {
+    *out = std::vector<RaggedShape>();
+    if (split_map) {
+      *split_map = std::vector<Array1<int32_t>>();
+    }
+    return;
+  }
+
+  Array2<int32_t> old_offsets,  // num_axes by tot_elems
+      new_offsets;              // num_axes by (tot_elems + 1).
+  GetOldAndNewOffsets(src, indexes.values, &old_offsets, &new_offsets);
+
+  const int32_t *indexes_row_split1_data = indexes.RowSplits(1).Data(),
+                *indexes_row_ids1_data = indexes.RowIds(1).Data();
+
+  // Contains the `TotSize` of each axis of each output RaggedShape
+  Array2<int32_t> tot_sizes(c, out_size, num_axes);
+  Array2Accessor<int32_t> tot_sizes_acc = tot_sizes.Accessor();
+  Array2Accessor<int32_t> new_offsets_acc = new_offsets.Accessor();
+
+  K2_EVAL2(c, out_size, num_axes, lambda_set_tot_sizes,
+      (int32_t i, int32_t j) -> void {
+    int32_t idx0 = indexes_row_split1_data[i],
+            idx0_next = indexes_row_split1_data[i + 1];
+    tot_sizes_acc(i, j) =
+        new_offsets_acc(j, idx0_next) - new_offsets_acc(j, idx0);
+  });
+
+  auto tot_sizes_cpu = tot_sizes.To(GetCpuContext());
+  auto tot_sizes_cpu_acc = tot_sizes_cpu.Accessor();
+  out->resize(out_size);
+  if (split_map != nullptr) split_map->resize(out_size);
+  // We cannot avoid this for loop on dim0, as we want to allocate memory
+  // separately; may consider using a ThreadPool later.
+  for (int32_t i = 0; i < out_size; ++i) {
+    out->at(i) = RaggedShapeFromTotSizes(c,
+        num_axes, tot_sizes_cpu.Row(i).Data());
+    if (split_map != nullptr) {
+      split_map->at(i) =
+          Array1<int32_t>(c, tot_sizes_cpu_acc(i, num_axes - 1));
+    };
+  }
+
+  // Caution: e.g. old_row_splits_acc(i) == src.RowSplits(i+1).
+  RowSplitsAccessor<5> old_row_splits_acc(src);
+  RowIdsAccessor<5> old_row_ids_acc(src);
+  auto old_offsets_acc = old_offsets.Accessor();
+
+  // axes_elems contains the number of elements on each axis before splitting
+  // into different RaggedShapes; it should equal the column sum of
+  // `tot_sizes` above.
+  Array1<int32_t> axes_elems =
+      Array1<int32_t>(new_offsets.Col(tot_elems)).To(GetCpuContext());
+
+  for (int32_t axis = 0; axis < num_axes; axis++) {
+    // Contains the RowSplits & RowIds pointers for the current layer;
+    // has a dimension of dim0 * 2, the layout is splits_pointer0,
+    // ids_pointer0, splits_pointer1, ids_pointer1, ...
+    Array1<int32_t *> splits_ids_ptr(GetCpuContext(), out_size * 2);
+    int32_t **splits_ids_ptr_data = splits_ids_ptr.Data();
+
+    // Contains the pointers for split_map
+    Array1<int32_t *> split_map_ptr;
+    int32_t **split_map_ptr_data;
+
+    if (axis == num_axes - 1 && split_map != nullptr) {
+      split_map_ptr = Array1<int32_t *>(GetCpuContext(), out_size);
+      split_map_ptr_data = split_map_ptr.Data();
+    }
+
+    for (int32_t i = 0; i < out_size; ++i) {
+      splits_ids_ptr_data[2 * i] = axis == num_axes - 1 ? nullptr :
+          out->at(i).RowSplits(axis + 1).Data();
+
+      splits_ids_ptr_data[2 * i + 1] =
+          axis == 0 ? nullptr : out->at(i).RowIds(axis).Data();
+
+      if (axis == num_axes - 1 && split_map != nullptr) {
+        split_map_ptr_data[i] = split_map->at(i).Data();
+      }
+    }
+    // transfer to GPU if we're using a GPU
+    splits_ids_ptr = splits_ids_ptr.To(c);
+    splits_ids_ptr_data = splits_ids_ptr.Data();
+
+    // set row split1
+    if (axis == 0) {
+      K2_EVAL(c, tot_elems, lambda_set_row_split1, (int32_t idx01) {
+        int32_t index_idx0 = indexes_row_ids1_data[idx01],
+                idx0x = indexes_row_split1_data[index_idx0];
+        splits_ids_ptr_data[2 * index_idx0][idx01 - idx0x]
+            = new_offsets_acc(axis + 1, idx01) -
+              new_offsets_acc(axis + 1, idx0x);
+
+        // Set the last elements of row_splits1 of each output shape
+        if (idx01 == tot_elems - 1 ||
+            index_idx0 != indexes_row_ids1_data[idx01 + 1]) {
+          splits_ids_ptr_data[2 * index_idx0][idx01 - idx0x + 1]
+              = new_offsets_acc(axis + 1, idx01 + 1) -
+                new_offsets_acc(axis + 1, idx0x);
+        }
+      });
+      continue;
+    }
+
+    // set last element of each row_splits
+    // TODO: Integrate this kernel into the kernel below.
+    if (axis < num_axes - 1) {
+      K2_EVAL(c, out_size, lambda_set_last_row_splits, (int32_t idx0) {
+        int32_t idx0x = indexes_row_split1_data[idx0],
+                idx0x_next = indexes_row_split1_data[idx0 + 1],
+                value = new_offsets_acc(axis + 1, idx0x_next) -
+                        new_offsets_acc(axis + 1, idx0x),
+                pos = tot_sizes_acc(idx0, axis);
+        splits_ids_ptr_data[2 * idx0][pos] = value;
+      });
+    }
+
+    if (axis == num_axes - 1 && split_map != nullptr) {
+      split_map_ptr = split_map_ptr.To(c);
+      split_map_ptr_data = split_map_ptr.Data();
+    }
+
+    int32_t num_elems = axes_elems[axis];
+
+    // composed_row_ids maps current idx to idx01 of indexes
+    Array1<int32_t> composed_row_ids(c, num_elems);
+    RowSplitsToRowIds(new_offsets.Row(axis), &composed_row_ids);
+
+    const int32_t *composed_row_ids_data = composed_row_ids.Data();
+
+    K2_EVAL(c, num_elems, lambda_set_row_splits_and_ids, (int32_t i) {
+      // tot_elems = indexes.NumElements(), so tot_idx0 can be interpreted as
+      // index_idx01
+      int32_t tot_idx0 = composed_row_ids_data[i],
+              index_idx0 = indexes_row_ids1_data[tot_idx0],
+              index_idx0x = indexes_row_split1_data[index_idx0],
+
+              begin_base = new_offsets_acc(axis, index_idx0x),
+              begin = new_offsets_acc(axis, tot_idx0),
+              this_idx0 = i - begin,
+              this_idx01 = i - begin_base;
+
+      K2_CHECK_GE(this_idx0, 0);
+      K2_CHECK_GE(this_idx01, 0);
+
+      // "prev" means for axis - 1
+      int32_t new_prev_offset = new_offsets_acc(axis - 1, tot_idx0),
+              old_prev_offset = old_offsets_acc(axis - 1, tot_idx0),
+              old_offset = old_offsets_acc(axis, tot_idx0),
+              old_idx = old_offset + this_idx0;
+
+      if (split_map != nullptr && axis == num_axes - 1)
+        split_map_ptr_data[index_idx0][this_idx01] = old_idx;
+
+      // set row ids
+      const int32_t *this_old_row_ids = old_row_ids_acc(axis - 1);
+      int32_t old_row_id = this_old_row_ids[old_idx],
+              new_row_id = old_row_id + new_prev_offset - old_prev_offset,
+              new_pre_offset_idx0x = new_offsets_acc(axis - 1, index_idx0x);
+
+      splits_ids_ptr_data[2 * index_idx0 + 1][this_idx01] =
+          new_row_id - new_pre_offset_idx0x;
+
+      // set row splits
+      if (axis + 1 < num_axes) {
+        int32_t new_next_offset = new_offsets_acc(axis + 1, tot_idx0),
+                old_next_offset = old_offsets_acc(axis + 1, tot_idx0),
+                next_offset_diff = new_next_offset - old_next_offset;
+        const int32_t *old_row_splits_data = old_row_splits_acc(axis);
+        int32_t row_split_value =
+                    next_offset_diff + old_row_splits_data[old_idx],
+                new_next_offset_idx0x = new_offsets_acc(axis + 1, index_idx0x);
+        splits_ids_ptr_data[2 * index_idx0][this_idx01]
+            = row_split_value - new_next_offset_idx0x;
+      }
+    });
+  }
+}
+
+void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
+             std::vector<Array1<int32_t>> *split_map) {
+  ContextPtr &c = src.Context();
+  if (axis == 0) {
+    if (src.NumAxes() == 2) {
+      auto new_src = ComposeRaggedShapes(
+          TrivialShape(c, src.TotSize(0)), src);
+      return Unstack(new_src, 1, out, split_map);
+    }
+    auto indexes = Ragged<int32_t>(RegularRaggedShape(c, src.Dim0(), 1),
+                                   Arange(c, 0, src.Dim0()));
+
+    SelectAxis0(src, indexes, out, split_map);
+    for (size_t i = 0; i < out->size(); ++i) {
+      out->at(i) = RemoveAxis(out->at(i), 0);
+    }
+  } else {
+    int32_t tot_size_axis_minus1 = src.TotSize(axis - 1),
+            tot_size_axis = src.TotSize(axis);
+    const int32_t *row_splits_axis = src.RowSplits(axis).Data(),
+                  *row_ids_axis = src.RowIds(axis).Data();
+
+    // Get the number of elements of current axis on each sublist
+    Array1<int32_t> sublists_size(c, tot_size_axis_minus1);
+    int32_t *sublists_size_data = sublists_size.Data();
+    K2_EVAL(c, tot_size_axis_minus1, lambda_get_sublists_size, (int32_t i) {
+      sublists_size_data[i] = row_splits_axis[i + 1] - row_splits_axis[i];
+    });
+
+    // Each sublist contains the elements of axis `axis`; the unstack
+    // operation will split all these elements in a sublist into different
+    // RaggedShapes, so the number of output RaggedShapes is the size of the
+    // sublist with max elements.
+    int32_t num_out = MaxValue(sublists_size);
+
+    out->resize(num_out);
+    if (split_map != nullptr) split_map->resize(num_out);
+
+    // We will select the elements of axis `axis` on each sublist; the number
+    // of sublists equals `src.TotSize(axis - 1)`.
+    // Initialize with -1 here, because not all the sublists have the same
+    // size; -1s here mean that we don't select anything on those positions
+    Array1<int32_t> indexes(c, num_out * tot_size_axis_minus1, -1);
+    int32_t *indexes_data = indexes.Data();
+
+    // Decide which output RaggedShape each element of axis `axis` will go to
+    K2_EVAL(c, tot_size_axis, lambda_set_indexes, (int32_t idx01) {
+      int32_t idx0 = row_ids_axis[idx01],
+              idx0x = row_splits_axis[idx0],
+              idx1 = idx01 - idx0x;
+      indexes_data[idx1 * tot_size_axis_minus1 + idx0] = idx01;
+    });
+
+    // To make `DecomposeRaggedShape` work, we add a RegularRaggedShape
+    // layer after axis `axis` if axis equals `src.NumAxes() - 1`.
+    // We have to remove the added layer at the end.
+    bool remove_last_axis = false;
+    if (axis == src.NumAxes() - 1) {
+      src = ComposeRaggedShapes(src,
+          RegularRaggedShape(c, src.NumElements(), 1));
+      remove_last_axis = true;
+    }
+
+    RaggedShape top, bottom;
+    DecomposeRaggedShape(src, axis, &top, &bottom);
+
+    // Unstack will remove the current axis (the last axis of top after
+    // decomposing on axis); to make `RemoveAxis` work, we add a TrivialShape
+    // layer before axis 0; finally we will remove the added layer.
+    bool remove_axis0 = false;
+    if (top.NumAxes() == 2) {
+      top = ComposeRaggedShapes(
+          TrivialShape(c, top.TotSize(0)), top);
+      remove_axis0 = true;
+    }
+    top = RemoveAxis(top, top.NumAxes() - 1);
+
+    auto ragged_indexes = Ragged<int32_t>(RegularRaggedShape(c,
+        num_out, tot_size_axis_minus1), indexes);
+
+    // Select elements according to indexes into corresponding RaggedShape
+    SelectAxis0(bottom, ragged_indexes, out, split_map);
+
+    for (int32_t i = 0; i < num_out; ++i) {
+      out->at(i) = ComposeRaggedShapes(top, out->at(i));
+      if (remove_axis0 && !remove_last_axis)
+        out->at(i) = RemoveAxis(out->at(i), 0);
+      if (remove_last_axis) {
+        out->at(i) = RemoveEmptyLists(out->at(i), out->at(i).NumAxes() - 2);
+        out->at(i) = RemoveAxis(out->at(i), out->at(i).NumAxes() - 1);
+      }
+    }
+  }
+}
+
 RaggedShape Merge(int32_t num_srcs, RaggedShape **src,
                   const Array1<uint32_t> &merge_map,
                   Array1<uint32_t> *merge_map_out) {
diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h
index 441434819..62e8b640c 100644
--- a/k2/csrc/ragged_ops.h
+++ b/k2/csrc/ragged_ops.h
@@ -186,6 +186,66 @@ void OrPerSublist(Ragged<T> &src, T initial_value, Array1<T> *or_values) {
 RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src,
                   Array1<uint32_t> *merge_map = nullptr);

+
+/*
+  Unstack a RaggedShape to a list of RaggedShapes; all the output RaggedShapes
+  have one less axis.
+  This function tries to do the opposite of Stack(), i.e. to generate an array
+  out such that `Equal(src, Stack(axis, out->size(), out->data()))`. But note
+  that `Stack` needs an array of RaggedShape pointers (`RaggedShape **`),
+  while Unstack produces only RaggedShapes, so you should do some conversion
+  before using Stack.
+
+    @param [in] src  The shape to unstack.
+    @param [in] axis  The axis to be removed; all the elements of this axis
+                   will be rearranged into the output RaggedShapes.
+    @param [out] out  The container the output RaggedShapes will be written
+                   to. MUST NOT be a nullptr; will be reallocated.
+    @param [out] split_map  If not nullptr, will store the element-indexes
+                   within `src` telling where the elements of each split
+                   RaggedShape come from. It has the same size as `out`; see
+                   notes below for its dimension. Each Array1 in `split_map`
+                   satisfies `split_map[i].Dim() == out[i].NumElements()`,
+                   and `0 <= split_map[i][j] < src.NumElements()`.
+                   `split_map` will be reallocated by this function.
+
+  Caution: If `src.NumAxes() == 2`, the output shapes will only have one
+           dimension; to make them RaggedShapes, we will add a TrivialShape
+           to each of the outputs.
+
+  Note: The output RaggedShapes may contain empty lists on axis `axis`; you
+        can remove them with RemoveEmptyLists if needed.
+
+  Note: The number of output RaggedShapes is decided by the size of the
+        sublist with the max number of elements along axis `axis`; for
+        `axis == 0` there is only one sublist along `axis == 0` (i.e. the src
+        itself), so the number of output RaggedShapes will be equal to
+        `src.Dim0()`.
+
+  A small example of unstacking a 3-axis RaggedShape:
+
+    src: [ [ [ x x ] [ x ] ]  [ [ x ] ] ]
+    unstack on axis 0:
+    src.Dim0() == 2, will produce 2 RaggedShapes.
+
+    out[0] : [ [ x x ] [ x ] ]    split_map[0] : [0, 1, 2]
+    out[1] : [ [ x ] ]            split_map[1] : [3]
+
+    unstack on axis 1:
+    two sublists along axis 1, the sizes are [2, 1], will produce 2
+    RaggedShapes.
+
+    out[0] : [ [ x x ] [ x ] ]    split_map[0] : [0, 1, 3]
+    out[1] : [ [ x ] [ ] ]        split_map[1] : [2]
+
+    unstack on axis 2:
+    three sublists along axis 2, the sizes are [2, 1, 1], will produce 2
+    RaggedShapes.
+
+    out[0] : [ [ x x ] [ x ] ]    split_map[0] : [0, 2, 3]
+    out[1] : [ [ x ] [ ] ]        split_map[1] : [1]
+ */
+void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
+             std::vector<Array1<int32_t>> *split_map = nullptr);
+
 /*
   Return a modified version of `src` in which all sub-lists on the last axis of
   the tensor have size modified by `size_delta`.  `size_delta` may have either
@@ -708,7 +768,6 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false,
  */
 RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering);

-
 /*
   Return ragged shape with only a subset of the elements on the last
   and one-before-last axes kept.
@@ -804,6 +863,8 @@ RaggedShape RemoveEmptyListsAxis0(RaggedShape &src_shape,
 RaggedShape RenumberAxis0Simple(RaggedShape &src_shape,
                                 Renumbering &renumbering);

+
+
 /*
   Return ragged array with only a subset of the bottom-level elements kept.
   Require renumbering.NumOldElems() == src.NumElements().  Note: all
@@ -852,6 +913,64 @@ template <typename T>
 Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
                 Array1<uint32_t> *merge_map = nullptr);

+/*
+  Unstack a Ragged tensor to a list of Ragged tensors; all the output Ragged
+  tensors have one less axis. Similar to TF's Unstack (or unbind in PyTorch).
+
+    @param [in] src  The ragged tensor to be unstacked.
+    @param [in] axis  The axis to be removed; all the elements of this axis
+                   will be rearranged into the output Ragged tensors.
+    @param [out] out  The container the output ragged tensors will be written
+                   to. MUST NOT be a nullptr; will be reallocated.
+    @param [out] split_map  If not nullptr, will store the element-indexes
+                   within `src` telling where the elements of each split
+                   Ragged come from. It has the same size as `out`; see notes
+                   below for the dimension of `out`. Each Array1 in
+                   `split_map` satisfies
+                   `split_map[i].Dim() == out[i].values.Dim()`, and it
+                   contains the element-indexes within `src`
+                   (i.e. `out[i].values[j] == src.values[split_map[i][j]]`).
+                   `split_map` will be reallocated by this function.
+
+  Caution: If `src.NumAxes() == 2`, the output shapes will only have one
+           dimension; to make them ragged tensors, we will add a TrivialShape
+           to each of the outputs.
+
+  Note: The output ragged tensors may contain empty lists on axis `axis`;
+        you can remove them with RemoveEmptyLists if needed.
+
+  Note: The number of output ragged tensors is decided by the size of the
+        sublist with the max number of elements along axis `axis`; for
+        `axis == 0` there is only one sublist along `axis == 0` (i.e. the src
+        itself), so the number of output ragged tensors will be equal to
+        `src.Dim0()`.
+
+  A small example of unstacking a 3-axis Ragged:
+
+    src: [ [ [ 1 2 ] [ 3 ] ]  [ [ 4 ] ] ]
+    unstack on axis 0:
+    src.Dim0() == 2, will produce 2 ragged tensors.
+
+    out[0] : [ [ 1 2 ] [ 3 ] ]    split_map[0] : [0, 1, 2]
+    out[1] : [ [ 4 ] ]            split_map[1] : [3]
+
+    unstack on axis 1:
+    two sublists along axis 1, the sizes are [2, 1], will produce 2 ragged
+    tensors.
+
+    out[0] : [ [ 1 2 ] [ 4 ] ]    split_map[0] : [0, 1, 3]
+    out[1] : [ [ 3 ] [ ] ]        split_map[1] : [2]
+
+    unstack on axis 2:
+    three sublists along axis 2, the sizes are [2, 1, 1], will produce 2
+    ragged tensors.
+
+    out[0] : [ [ 1 3 ] [ 4 ] ]    split_map[0] : [0, 2, 3]
+    out[1] : [ [ 2 ] [ ] ]        split_map[1] : [1]
+ */
+
+template <typename T>
+void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
+             std::vector<Array1<int32_t>> *split_map = nullptr);
+
 /*
   Concatenate a list of Ragged<T> to form a single Ragged<T>.
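To make the intended call pattern concrete, here is a minimal usage sketch;
it is pieced together from the unit tests added later in this patch (see
ragged_test.cu below), so the expected outputs in the comments are exactly
what those tests assert:

    ContextPtr c = GetCpuContext();
    Ragged<int32_t> src(c, "[ [ 10 20 ] [ 30 40 50 ] [ 60 ] ]");
    std::vector<Ragged<int32_t>> out;
    std::vector<Array1<int32_t>> split_map;
    Unstack(src, 0 /*axis*/, &out, &split_map);
    // out[0] is [ [ 10 20 ] ],    split_map[0] is [0, 1]
    // out[1] is [ [ 30 40 50 ] ], split_map[1] is [2, 3, 4]
    // out[2] is [ [ 60 ] ],       split_map[2] is [5]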
diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h
index 4d44761cd..bee7efb21 100644
--- a/k2/csrc/ragged_ops_inl.h
+++ b/k2/csrc/ragged_ops_inl.h
@@ -169,6 +169,62 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
   return Stack(axis, num_srcs, temp.data(), merge_map);
 }

+template <typename T>
+void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
+             std::vector<Array1<int32_t>> *split_map /* = nullptr */) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(out != nullptr);
+  ContextPtr &c = src.Context();
+  std::vector<Array1<int32_t>> split_map_tmp;
+  std::vector<Array1<int32_t>> *split_map_ptr =
+      (split_map != nullptr ? split_map : &split_map_tmp);
+  std::vector<RaggedShape> shape_out;
+
+  Unstack(src.shape, axis, &shape_out, split_map_ptr);
+
+  out->resize(shape_out.size());
+  // +1 here because we need to do ExclusiveSum on this Array1 later
+  Array1<int32_t> elem_nums(GetCpuContext(), shape_out.size() + 1);
+  Array1<T *> values_ptr(GetCpuContext(), shape_out.size());
+  Array1<int32_t *> map_ptr(GetCpuContext(), shape_out.size());
+  int32_t *elem_nums_data = elem_nums.Data();
+  T **values_ptr_data = values_ptr.Data();
+  int32_t **map_ptr_data = map_ptr.Data();
+
+  int32_t tot_elems = 0;
+  // Cannot avoid this for loop as we want to allocate memory separately.
+  for (size_t i = 0; i < shape_out.size(); ++i) {
+    int32_t elem_num = shape_out[i].NumElements();
+    out->at(i) = Ragged<T>(shape_out[i], Array1<T>(c, elem_num));
+    elem_nums_data[i] = elem_num;
+    tot_elems += elem_num;
+    values_ptr_data[i] = out->at(i).values.Data();
+    map_ptr_data[i] = split_map_ptr->at(i).Data();
+  }
+
+  Array1<int32_t> row_splits(c, shape_out.size() + 1);
+  ExclusiveSum(elem_nums.To(c), &row_splits);
+
+  Array1<int32_t> row_ids(c, tot_elems);
+  RowSplitsToRowIds(row_splits, &row_ids);
+
+  const int32_t *row_splits_data = row_splits.Data(),
+                *row_ids_data = row_ids.Data();
+  const T *src_value_data = src.values.Data();
+  // Transfer to GPU if we are using a GPU
+  map_ptr = map_ptr.To(c);
+  map_ptr_data = map_ptr.Data();
+  values_ptr = values_ptr.To(c);
+  values_ptr_data = values_ptr.Data();
+
+  K2_EVAL(c, tot_elems, lambda_set_values, (int32_t idx01) {
+    int32_t idx0 = row_ids_data[idx01],
+            idx0x = row_splits_data[idx0],
+            idx1 = idx01 - idx0x;
+    values_ptr_data[idx0][idx1] = src_value_data[map_ptr_data[idx0][idx1]];
+  });
+}
+
 template <typename T>
 Ragged<T> Cat(int32_t axis, int32_t num_srcs, Ragged<T> **src,
               Array1<uint32_t> *merge_map /* = nullptr*/) {
diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu
index a8fdd1286..7fe9c5fae 100644
--- a/k2/csrc/ragged_test.cu
+++ b/k2/csrc/ragged_test.cu
@@ -44,27 +44,27 @@ namespace k2 {
 TEST(RaggedShapeOpsTest, CatMoreAxes) {
   for (auto &c : {GetCpuContext(), GetCudaContext()}) {
-    RaggedShape shape1 =
-        RaggedShape("[ [ [ [ x x ] ] [ [x ] ] ] [[[x]]]]").To(c),
-        shape2 =
-        RaggedShape("[ [ [ [x ] ] [ [x ] ] ] [[[x x]]]]").To(c),
-        shape3 = RaggedShape("[ [ [ [ ] ] [ [ x ] ] ] [[[]]]]").To(c);
+    RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]"
+                                        "  [ [ [ x ] ] ] ]"),
+                shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]"
+                                        "  [ [ [ x x ] ] ] ]"),
+                shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]"
+                                        "  [ [ [ ] ] ] ]");

     RaggedShape cat_axis2_ref =
-        RaggedShape("[ [ [[ x x ][ x ][]] [[x ][x][ x ]] ] [[[x ][ x x][]]]]")
-            .To(c);
+        RaggedShape(c, "[ [ [ [ x x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] ] ]"
+                       "  [ [ [ x ] [ x x ] [ ] ] ] ]");
     RaggedShape cat_axis3_ref =
-        RaggedShape("[ [ [[ x x x ]] [[x x x ]] ] [[[x x x]]]]").To(c);
+        RaggedShape(c, "[ [ [ [ x x x ] ] [ [ x x x ] ] ]"
+                       "  [ [ [ x x x ] ] ] ]");
     RaggedShape *srcs[] = {&shape1, &shape2, &shape3};
     Array1<uint32_t> merge_map2;
     Array1<uint32_t> merge_map3;
     RaggedShape cat_axis2 = Cat(2, 3, srcs, &merge_map2);
     RaggedShape cat_axis3 = Cat(3, 3, srcs, &merge_map3);
-    K2_LOG(INFO) << "cat_axis2 = " << cat_axis2;
-    K2_LOG(INFO) << "cat_axis3 = " << cat_axis3;
     K2_CHECK(Equal(cat_axis2, cat_axis2_ref));
-    K2_CHECK(Equal(cat_axis2, cat_axis2_ref));
+    K2_CHECK(Equal(cat_axis3, cat_axis3_ref));

     std::vector<uint32_t> merge_values = {0, 3, 1, 6, 4, 2, 9, 7, 10};
     CheckArrayData(merge_map2, merge_values);
@@ -74,27 +74,29 @@ TEST(RaggedShapeOpsTest, StackMoreAxes) {
   for (auto &c : {GetCpuContext(), GetCudaContext()}) {
-    RaggedShape shape1 =
-        RaggedShape("[ [ [ [ x x ] ] [ [x ] ] ] [[[x]]]]").To(c),
-        shape2 =
-        RaggedShape("[ [ [ [x ] ] [ [x ] ] ] [[[x x]]]]").To(c),
-        shape3 = RaggedShape("[ [ [ [ ] ] [ [ x ] ] ] [[[]]]]").To(c);
-
-    RaggedShape stacked_ref =
-        RaggedShape(
-            "[ [ [[[ x x ]][[ x ]][[]]] [[[x ]][[x]][[ x ]]] ] "
-            "[[[[x ]][[ x x]][[]]]]]")
-            .To(c);
+    RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]"
+                                        "  [ [ [ x ] ] ] ]"),
+                shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]"
+                                        "  [ [ [ x x ] ] ] ]"),
+                shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]"
+                                        "  [ [ [ ] ] ] ]");
+
+    RaggedShape stacked2_ref =
+        RaggedShape(c, "[ [ [ [ [ x x ] ] [ [ x ] ] [ [ ] ] ]"
+                       "    [ [ [ x ] ] [ [ x ] ] [ [ x ] ] ] ]"
+                       "  [ [ [ [ x ] ] [ [ x x ] ] [ [ ] ] ] ] ]");
+    RaggedShape stacked3_ref =
+        RaggedShape(c, "[ [ [ [ [ x x ] [ x ] [ ] ] ]"
+                       "    [ [ [ x ] [ x ] [ x ] ] ] ]"
+                       "  [ [ [ [ x ] [ x x ] [ ] ] ] ] ]");
     RaggedShape *srcs[] = {&shape1, &shape2, &shape3};
     Array1<uint32_t> merge_map2;
     Array1<uint32_t> merge_map3;
     RaggedShape stacked_axis2 = Stack(2, 3, srcs, &merge_map2);
     RaggedShape stacked_axis3 = Stack(3, 3, srcs, &merge_map3);
-    K2_LOG(INFO) << "stacked_axis2 = " << stacked_axis2;
-    K2_LOG(INFO) << "stacked_axis3 = " << stacked_axis3;
-    K2_CHECK(Equal(stacked_axis2, stacked_ref));
-    K2_CHECK(Equal(stacked_axis2, stacked_ref));
+    K2_CHECK(Equal(stacked_axis2, stacked2_ref));
+    K2_CHECK(Equal(stacked_axis3, stacked3_ref));

     std::vector<uint32_t> merge_values = {0, 3, 1, 6, 4, 2, 9, 7, 10};
     CheckArrayData(merge_map2, merge_values);
@@ -102,6 +104,185 @@
   }
 }

+TEST(RaggedShapeOpsTest, Unstack2Axes) {
+  for (auto &c : {GetCpuContext(), GetCudaContext()}) {
+    auto shape = RaggedShape(c, "[ [ x x ] [ x x x ] [ x ] ]");
+
+    std::vector<RaggedShape> out;
+    std::vector<Array1<int32_t>> out_map;
+
+    // axis = 0
+    Unstack(shape, 0, &out, &out_map);
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ x x ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 1})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ x x x ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{2, 3, 4})));
+    K2_CHECK(Equal(out[2],
+          RaggedShape(c, "[ [ x ] ]")));
+    K2_CHECK(Equal(out_map[2],
+          Array1<int32_t>(c, std::vector<int32_t>{5})));
+
+    std::vector<RaggedShape *> out_ptr;
+    out_ptr.clear();
+    for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+    auto dest = Stack(0, out.size(), out_ptr.data());
+    dest = RemoveAxis(dest, 1);
+    K2_CHECK(Equal(dest, shape));
+
+    // axis = 1
+    Unstack(shape, 1, &out, &out_map);
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ x x x ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 2, 5})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ x x ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{1, 3})));
+    K2_CHECK(Equal(out[2],
+          RaggedShape(c, "[ [ x ] ]")));
+    K2_CHECK(Equal(out_map[2],
+          Array1<int32_t>(c, std::vector<int32_t>{4})));
+    // cannot test Stack here, because the element numbers along axis 1 are
+    // not the same
+  }
+}
+
+TEST(RaggedShapeOpsTest, Unstack) {
+  for (auto &c : {GetCpuContext(), GetCudaContext()}) {
+    RaggedShape shape(c, "[ [ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"
+                         "  [ [ [ x x x ] ] ] ]");
+    std::vector<RaggedShape> out;
+    std::vector<Array1<int32_t>> out_map;
+    Unstack(shape, 0, &out, &out_map);
+
+    // axis = 0
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 1, 2, 3, 4, 5})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ [ x x x ] ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{6, 7, 8})));
+
+    std::vector<RaggedShape *> out_ptr;
+    for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+    auto dest = Stack(0, out.size(), out_ptr.data());
+    K2_CHECK(Equal(dest, shape));
+
+    // axis = 1
+    Unstack(shape, 1, &out, &out_map);
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 1, 2, 6, 7, 8})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ [ x ] [ x x ] ] [ ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{3, 4, 5})));
+
+    out_ptr.clear();
+    for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+    dest = Stack(1, out.size(), out_ptr.data());
+    dest = RemoveEmptyLists(dest, 1);
+    K2_CHECK(Equal(dest, shape));
+
+    // axis = 2
+    Unstack(shape, 2, &out, &out_map);
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 1, 3, 6, 7, 8})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ [ x ] [ x x ] ] [ [ ] ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{2, 4, 5})));
+
+    out_ptr.clear();
+    for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+    dest = Stack(2, out.size(), out_ptr.data());
+    dest = RemoveEmptyLists(dest, 2);
+    K2_CHECK(Equal(dest, shape));
+
+    // axis = 3
+    Unstack(shape, 3, &out, &out_map);
+    K2_CHECK(Equal(out[0],
+          RaggedShape(c, "[ [ [ x x ] [ x x ] ] [ [ x ] ] ]")));
+    K2_CHECK(Equal(out_map[0],
+          Array1<int32_t>(c, std::vector<int32_t>{0, 2, 3, 4, 6})));
+    K2_CHECK(Equal(out[1],
+          RaggedShape(c, "[ [ [ x ] [ x ] ] [ [ x ] ] ]")));
+    K2_CHECK(Equal(out_map[1],
+          Array1<int32_t>(c, std::vector<int32_t>{1, 5, 7})));
+    K2_CHECK(Equal(out[2],
+          RaggedShape(c, "[ [ [ ] [ ] ] [ [ x ] ] ]")));
+    K2_CHECK(Equal(out_map[2],
+          Array1<int32_t>(c, std::vector<int32_t>{8})));
+    // cannot test Stack here, because the element numbers along axis 3 are
+    // not the same
+  }
+}
+
+TEST(RaggedShapeOpsTest, UnstackMoreAxes) {
+  for (auto &c : {GetCpuContext(), GetCudaContext()}) {
+    RaggedShape shape(c, "[ [ [ [ [ x ] [ ] ] [ [ x x x ] ] ] ]"
+                         "  [ [ [ [ x x x ] ] [ [ x x ] ] [ [ x ] ] ]"
+                         "    [ [ [ x x ] [ x ] [ ] [ x ] ] ]"
+                         "    [ [ [ x ] ] [ [ x ] [ x x x x ] ] ] ]"
+                         "  [ [ [ [ x ] ] [ ] ]"
+                         "    [ [ [ x x ] ] ] ] ]");
+
+    std::vector<RaggedShape> out;
+    std::vector<Array1<int32_t>> out_map;
+    std::vector<RaggedShape *> out_ptr;
+
+    for (int32_t axis = 0; axis < 4; axis++) {
+      Unstack(shape, axis, &out, &out_map);
+
+      out_ptr.clear();
+      for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+      auto dest = Stack(axis, out.size(), out_ptr.data());
+      dest = RemoveEmptyLists(dest, axis);
+      K2_CHECK(Equal(dest, RemoveEmptyLists(shape, axis)));
+    }
+  }
+}
+
+TEST(RaggedShapeOpsTest, UnstackRandom) {
+  RaggedShape random_shape_ = RandomRaggedShape(true,  // set_row_ids
+                                                5,     // min_num_axes
+                                                5,     // max_num_axes
+                                                1,     // min_num_elements
+                                                100);  // max_num_elements
+  for (auto &c : {GetCpuContext(), GetCudaContext()}) {
+    auto random_shape0 = random_shape_.To(c);
+    std::vector<RaggedShape> out;
+    std::vector<RaggedShape *> out_ptr;
+    for (int32_t axis = 0; axis < 4; axis++) {
+      auto random_shape = RemoveEmptyLists(random_shape0, axis);
+
+      Unstack(random_shape, axis, &out, nullptr);
+
+      out_ptr.clear();
+      for (size_t i = 0; i < out.size(); ++i) {
+        out_ptr.emplace_back(&(out[i]));
+      }
+      // There is a bug in `Stack` for stacking a shape with itself;
+      // not urgent, so skipping here.
+      // TODO: Remove this line when the bug is fixed.
+      if (out.size() == 1) continue;
+      auto dest = Stack(axis, out.size(), out_ptr.data());
+      dest = RemoveEmptyLists(dest, axis);
+
+      K2_CHECK(Equal(dest, random_shape));
+    }
+  }
+}

 class RaggedShapeOpsSuiteTest : public ::testing::Test {
  protected:
@@ -1549,7 +1730,6 @@ TEST(RaggedShapeOpsTest, TestIndex) {
   }
 }

-
 TEST(RaggedShapeOpsTest, TestIndexAxis1) {
   for (auto &context : {GetCpuContext(), GetCudaContext()}) {
     {
@@ -2041,6 +2221,70 @@ TEST(RaggedTest, TestStackRagged) {
   TestStackRagged();
 }

+template <typename T>
+void TestUnstackRagged() {
+  for (auto &c : {GetCpuContext(), GetCudaContext()}) {
+    // two axes
+    auto ragged = Ragged<T>(c, "[ [ 10 20 ] [ 30 40 50 ] [ 60 ] ]");
+    std::vector<Ragged<T>> out;
+
+    // axis = 0
+    Unstack(ragged, 0, &out);
+    K2_CHECK(Equal(out[0], Ragged<T>(c, "[ [ 10 20 ] ]")));
+    K2_CHECK(Equal(out[1], Ragged<T>(c, "[ [ 30 40 50 ] ]")));
+    K2_CHECK(Equal(out[2], Ragged<T>(c, "[ [ 60 ] ]")));
+
+    // axis = 1
+    Unstack(ragged, 1, &out);
+    K2_CHECK(Equal(out[0], Ragged<T>(c, "[ [ 10 30 60 ] ]")));
+    K2_CHECK(Equal(out[1], Ragged<T>(c, "[ [ 20 40 ] ]")));
+    K2_CHECK(Equal(out[2], Ragged<T>(c, "[ [ 50 ] ]")));
+
+    // more axes
+    ragged = Ragged<T>(c, "[ [ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ]"
+                          "    [ [ 41 ] [ 51 ] ] ]"
+                          "  [ [ [ 61 62 63 ] ] ] ]");
+
+    // axis = 0
+    Unstack(ragged, 0, &out);
+    K2_CHECK(Equal(out[0], Ragged<T>(c,
+          "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 41 ] [ 51 ] ] ]")));
+    K2_CHECK(Equal(out[1],
+          Ragged<T>(c, "[ [ [ 61 62 63 ] ] ]")));
+
+    // axis = 1
+    Unstack(ragged, 1, &out);
+    K2_CHECK(Equal(out[0], Ragged<T>(c,
+          "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 61 62 63 ] ] ]")));
+    K2_CHECK(Equal(out[1],
+          Ragged<T>(c, "[ [ [ 41 ] [ 51 ] ] [ ] ]")));
+
+    // axis = 2
+    Unstack(ragged, 2, &out);
+    K2_CHECK(Equal(out[0],
+          Ragged<T>(c, "[ [ [ 1 11 21 ] [ 41 ] ] [ [ 61 62 63 ] ] ]")));
+    K2_CHECK(Equal(out[1],
+          Ragged<T>(c, "[ [ [ 21 22 ] [ 51 ] ] [ [ ] ] ]")));
+    K2_CHECK(Equal(out[2],
+          Ragged<T>(c, "[ [ [ 31 ] [ ] ] [ [ ] ] ]")));
+
+    // axis = 3
+    Unstack(ragged, 3, &out);
+    K2_CHECK(Equal(out[0],
+          Ragged<T>(c, "[ [ [ 1 21 31 ] [ 41 51 ] ] [ [ 61 ] ] ]")));
+    K2_CHECK(Equal(out[1],
+          Ragged<T>(c, "[ [ [ 11 22 ] [ ] ] [ [ 62 ] ] ]")));
+    K2_CHECK(Equal(out[2],
+          Ragged<T>(c, "[ [ [ 21 ] [ ] ] [ [ 63 ] ] ]")));
+  }
+}
+
+TEST(RaggedTest, TestUnstack) {
+  TestUnstackRagged();
+  TestUnstackRagged();
+  TestUnstackRagged();
+}
+
 TEST(RaggedTest, TestMaxSize) {
   for (int32_t i = 0; i <= 10; i++) {
     ContextPtr c = (i % 2 == 0 ? GetCpuContext() : GetCudaContext());
@@ -2795,5 +3039,4 @@ TEST(RaggedTest, TestPadRagged) {
   TestPadRagged();
   TestPadRagged();
 }
-
 }  // namespace k2

From 854b792368214a2adb4e89cd83f6bc09ddbbcdae Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sun, 20 Feb 2022 11:22:31 +0800
Subject: [PATCH 43/64] SubsetRagged & PruneRagged (#919)

* Extend interface of SubsampleRagged.

* Add interface for pruning ragged tensor.
* Draft of new RNN-T decoding method
* Implements SubsampleRaggedShape
* Implements PruneRagged
* Rename subsample -> subset
* Minor fixes
* Fix comments

Co-authored-by: Daniel Povey
---
 k2/csrc/algorithms.h         |  31 ++-
 k2/csrc/algorithms_test.cu   |   9 +
 k2/csrc/ragged_ops.cu        |  25 +--
 k2/csrc/ragged_ops.h         | 120 +++++++++--
 k2/csrc/ragged_ops_inl.h     | 139 ++++++++++++-
 k2/csrc/ragged_shape_test.cu |  59 +++++-
 k2/csrc/ragged_test.cu       | 380 ++++++++++++++++++++++-------------
 k2/csrc/rm_epsilon.cu        |   4 +-
 8 files changed, 580 insertions(+), 187 deletions(-)

diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h
index 439ac5ba3..6e11a31cb 100644
--- a/k2/csrc/algorithms.h
+++ b/k2/csrc/algorithms.h
@@ -111,9 +111,7 @@ class Renumbering {
          (pre-renumbering) indexes.  Its dimension is the number of
          new indexes (i.e. the number of 1's in keep_), but internally
          it has one extra element which contains the number of old
-         elements, so it's OK to read one past the end.  (We may
-         later make it possible to access the array with the one-larger
-         dimension).
+         elements, so it's OK to read one past the end.
   */
   Array1<int32_t> &New2Old() {
     NVTX_RANGE(K2_FUNC);
@@ -121,17 +119,40 @@ class Renumbering {
     return new2old_;
   }
 
+  /* Return a mapping from new index to old index; if extra_element == true
+     it has one extra element, containing the number of old elements.
+     If Keep() can be interpreted as a tails vector, i.e. with 1 at the end
+     of sub-lists of elements, then New2Old(true) would correspond to a
+     row-splits array and Old2New(false) would correspond to a row-ids
+     array.
+   */
+  Array1<int32_t> New2Old(bool extra_element) {
+    Array1<int32_t> &new2old_part = New2Old();
+    if (!extra_element) {
+      return new2old_part;
+    } else {
+      // This is a little perverse, using low-level interfaces to increase the
+      // dimension of the array; but we know it does have one more element.
+      // Because we normally use New2Old() with no arg (equivalent to false),
+      // the overloaded version of this function returns a reference for
+      // efficiency.
+      return Array1<int32_t>(new2old_part.Dim() + 1,
+                             new2old_part.GetRegion(), 0);
+    }
+  }
+
   /* Return a mapping from old index to new index. This is created on
      demand (must only be called after the Keep() array has been populated).
 
     @param [in] extra_element  If true, will return the array of size
                 NumOldElems() + 1, which includes one more element;
                 otherwise it will return an array of size NumOldElems().
+
+
+    @return  Returns an array mapping the old indexes to the new indexes.
                 This array is just the exclusive sum of Keep().
                 It gives the mapping for indexes that are kept; element
                 i is kept if `Old2New()[i+1] > Old2New()[i]`.
-
-    @return  Returns an array mapping the old indexes to the new indexes.
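+
+     Example (an illustrative sketch, not part of the original patch):
+     if Keep() == [ 1 0 1 ], then
+        Old2New()     == [ 0 1 1 ],
+        Old2New(true) == [ 0 1 1 2 ],
+        New2Old()     == [ 0 2 ],
+        New2Old(true) == [ 0 2 3 ]   (the trailing 3 being the number
+                                      of old elements).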
*/ Array1 Old2New(bool extra_element = false) { NVTX_RANGE(K2_FUNC); diff --git a/k2/csrc/algorithms_test.cu b/k2/csrc/algorithms_test.cu index 5edc01cac..bbf310d12 100644 --- a/k2/csrc/algorithms_test.cu +++ b/k2/csrc/algorithms_test.cu @@ -45,6 +45,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 0); } { @@ -67,6 +70,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 5); } { @@ -93,6 +99,9 @@ TEST(AlgorithmsTest, TestRenumbering) { std::vector cpu_new2old(new2old.Data(), new2old.Data() + new2old.Dim()); EXPECT_THAT(cpu_new2old, ::testing::ElementsAre(0, 2, 3, 6)); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 5); + EXPECT_EQ(new2old.Back(), 7); } } } diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 6120428cc..808ad41e2 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -569,7 +569,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis, if (axis == 0) { return IndexAxis0(src, indexes, elem_indexes); } else if (axis == src.NumAxes() - 1) { - // This code is related to SubsampleRaggedShape(). `indexes` corresponds + // This code is related to SubsetRaggedShape(). `indexes` corresponds // to `new2old`. Array1 last_row_ids = src.RowIds(num_axes - 1)[indexes]; #ifndef NDEBUG @@ -1944,21 +1944,16 @@ Ragged AddPrefixToRagged(Ragged &src, return Ragged(dst_shape, dst_values); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering) { +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, + int32_t axis, Array1 *elems_new2old) { NVTX_RANGE(K2_FUNC); - K2_CHECK_EQ(renumbering.NumOldElems(), src.NumElements()); - - // Make sure final row-ids are populated. - src.RowIds(src.NumAxes() - 1); - std::vector axes = src.Layers(); - axes.back().row_ids = axes.back().row_ids[renumbering.New2Old()]; - axes.back().row_splits = renumbering.Old2New()[axes.back().row_splits]; - axes.back().cached_tot_size = axes.back().row_ids.Dim(); - return RaggedShape(axes); + axis = axis < 0 ? 
src.NumAxes() + axis : axis; + K2_CHECK_EQ(renumbering.NumOldElems(), src.TotSize(axis)); + return Index(src, axis, renumbering.New2Old(), elems_new2old); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &r_before_last, - Renumbering &r_last) { +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &r_before_last, + Renumbering &r_last) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(r_before_last.NumOldElems(), src.TotSize(src.NumAxes() - 2)); K2_CHECK_EQ(r_last.NumOldElems(), src.NumElements()); @@ -2103,7 +2098,7 @@ RaggedShape RemoveEmptyLists(RaggedShape &src_shape, int32_t axis, Renumbering r_temp; if (!renumbering_out) renumbering_out = &r_temp; bottom_shape = RemoveEmptyListsAxis0(bottom_shape, renumbering_out); - top_shape = SubsampleRaggedShape(top_shape, *renumbering_out); + top_shape = SubsetRaggedShape(top_shape, *renumbering_out); return ComposeRaggedShapes(top_shape, bottom_shape); } @@ -2117,7 +2112,7 @@ RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, int32_t axis, DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape); bottom_shape = RenumberAxis0Simple(bottom_shape, renumbering); - top_shape = SubsampleRaggedShape(top_shape, renumbering); + top_shape = SubsetRaggedShape(top_shape, renumbering); return ComposeRaggedShapes(top_shape, bottom_shape); } diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 62e8b640c..83c2ca238 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -759,14 +759,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, int32_t max_num_elements = 2000); /* - Return ragged shape with only a subset of the bottom-level elements kept. - Require renumbering.NumOldElems() == src.NumElements(). Note: all - dimensions and tot-sizes preceding the final axis will remain the same, which - might give rise to empty lists. + Return ragged shape with only a subset of the elements or sub-lists + on the specified axis kept. (This is not regular sampling, it is + irregular subsampling with specified elements kept). + + @param [in] src The ragged shape that we are subsampling + @param [in] renumbering The renumbering object that dictates + which elements of `src` we keep; we require + renumbering.NumOldElems() == src.TotSize(axis2) + where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). + @param [in] axis The axis to subsample; if negative, will be + interpreted as an offset from src.NumAxes(). + @param [out] elems_new2old If supplied, this function will + output to this location a new2old vector that + dictates how the elements of a ragged tensor + with shape `src` would be renumbered. + @return Returns the subsampled shape. All dimensions and tot-sizes + preceding the axis `axis` will remain the same, which might give + rise to empty lists on those axes; these can be removed if + necessary with RemoveEmptyLists(). Notice the other version of this function below. */ -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); +RaggedShape SubsetRaggedShape(RaggedShape &src, + Renumbering &renumbering, + int32_t axis = -1, + Array1 *elems_new2old = nullptr); /* Return ragged shape with only a subset of the elements on the last @@ -777,9 +795,9 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); Note: all dimensions and tot-sizes preceding the last two axes will remain the same, which might give rise to empty lists. 
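+
+  Example (an illustrative sketch, not part of the original patch): for
+  src == [ [ x x ] [ x ] ], with renumbering_before_last keeping only the
+  first sub-list and renumbering_last keeping only the first element, the
+  result would be [ [ x ] ].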
 */
-RaggedShape SubsampleRaggedShape(RaggedShape &src,
-                                 Renumbering &renumbering_before_last,
-                                 Renumbering &renumbering_last);
+RaggedShape SubsetRaggedShape(RaggedShape &src,
+                              Renumbering &renumbering_before_last,
+                              Renumbering &renumbering_last);
 
 /*
   Removes empty lists on a particular axis (not last axis) of a RaggedShape,
@@ -866,17 +884,82 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape,
 
 
 /*
-  Return ragged array with only a subset of the bottom-level elements kept.
-  Require renumbering.NumOldElems() == src.NumElements().  Note: all
-  dimensions and tot-sizes preceding the final axis will remain the same, which
-  might give rise to empty lists.
+  Return ragged array with only a subset of the elements or sub-lists
+  on the specified axis kept.  (This is not regular sampling; it is
+  irregular subsampling with specified elements kept.)
+
+    @param [in] src  The ragged array that we are subsampling
+    @param [in] renumbering  The renumbering object that dictates
+                which elements of `src` we keep; we require
+                renumbering.NumOldElems() == src.TotSize(axis2)
+                where axis2 = (axis < 0 ? src.NumAxes() + axis : axis).
+    @param [in] axis  The axis to subsample; if negative, will be
+                interpreted as an offset from src.NumAxes().
+    @param [out] elems_new2old  If supplied, this function will
+                output to this location a new2old array that
+                dictates how the elements of a ragged tensor
+                with shape `src` would be renumbered.
+    @return  Returns the subsampled ragged array.  All dimensions and
+            tot-sizes preceding the axis `axis` will remain the same, which
+            might give rise to empty lists on those axes; these can be
+            removed if necessary with RemoveEmptyLists().
 */
 template <typename T>
-Ragged<T> SubsampleRagged(Ragged<T> &src, Renumbering &renumbering) {
-  return Ragged<T>(SubsampleRaggedShape(src.shape, renumbering),
-                   src.values[renumbering.New2Old()]);
+Ragged<T> SubsetRagged(Ragged<T> &src, Renumbering &renumbering,
+                       int32_t axis = -1,
+                       Array1<int32_t> *elems_new2old = nullptr) {
+  Array1<int32_t> tmp;
+  if (elems_new2old == nullptr)
+    elems_new2old = &tmp;
+  RaggedShape shape = SubsetRaggedShape(src.shape, renumbering,
+                                        axis, elems_new2old);
+  return Ragged<T>(shape, src.values[*elems_new2old]);
 }
 
+/*
+  This function creates a Renumbering object that can be used to obtain subsets
+  of ragged arrays via SubsetRaggedShape().  It implements beam pruning as
+  used in pruned Viterbi search and similar algorithms, where there is both a
+  beam and a max-active (`max_elems`) constraint.  T will probably be float or
+  double, interpreted in a "positive-is-better" sense, i.e. as scores.
+
+   @param [in] src  The ragged object to be subsampled.
+   @param [in] axis  The axis to be subsampled, must satisfy
+          0 <= axis < src.NumAxes().  The axis before `axis`, if axis > 0,
+          will be interpreted as a "batch" axis.
+   @param [in] beam  The main pruning beam.  The sub-lists of elements on axis
+          `axis` will be removed if their maximum element (or the element
+          itself, if axis + 1 == src.NumAxes()) is less than
+          this_best_elem - beam, where this_best_elem
+          is the maximum element taken over axis `axis-1` (or over the
+          entire array, if axis == 0).  Think of axis `axis-1`, if
+          axis > 0, as the "batch" axis, and axis `axis` as the axis that we
+          actually remove elements or sub-lists on.  Empty sub-lists on axis
+          `axis` will always be pruned, as their score would be treated
+          as -infinity.
+ @param [in] max_elems If max_elems > 0, it is the maximum number of + sub-lists or elements that are allowed within any sub-list + on axis `axis-1` (or the maximum number of top-level sub-lists + after subsampling, if axis == 0). We keep the best ones. + If max_elems <= 0, there is no such constraint. + @return Returns the renumbering object to be used to actually + prune/subsample the specified axis. + + Example: + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ] + + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2 -3] ] + */ +template +Renumbering PruneRagged(Ragged &src, + int32_t axis, + T beam, + int32_t max_elems); + /* Stack a list of Ragged arrays to create a Ragged array with one more axis. Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All @@ -974,8 +1057,7 @@ void Unstack(Ragged src, int32_t axis, std::vector> *out, /* Concatenate a list of Ragged to form a single Ragged. - @param [in] axis Axis to append them on. Currently - we only support axis == 0 or axis == 1. + @param [in] axis Axis to append them on. Previous axes must have the same shape, i.e. if axis == 1 then `src[i]->Dim0()` must all have the @@ -1368,7 +1450,7 @@ Ragged Merge(int32_t num_srcs, Ragged **src, /* Returns a ragged tensor after removing all 'values' that were <= a provided cutoff. Leaves all layers of the shape except for the last one unaffected. - Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] <= + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] <= cutoff). */ template @@ -1377,7 +1459,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff); /* Returns a ragged tensor after removing all 'values' that equal a provided target. Leaves all layers of the shape except for the last one unaffected. - Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] == + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] == target). */ template diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index bee7efb21..47297fab4 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -288,7 +288,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] > cutoff); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } template @@ -301,7 +301,7 @@ Ragged RemoveValuesEq(Ragged &src, T target) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] != target); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } // Recursive function that prints (part of) a ragged shape. @@ -373,9 +373,9 @@ static void SortSublistsCpu(Ragged *src, Array1 *order) { int32_t cur = row_splits[i]; int32_t next = row_splits[i + 1]; if (order != nullptr) - std::sort(order->Data() + cur, order->Data() + next, lambda_comp); + std::stable_sort(order->Data() + cur, order->Data() + next, lambda_comp); - std::sort(p + cur, p + next, comp); + std::stable_sort(p + cur, p + next, comp); } } @@ -815,6 +815,137 @@ Array2 PadRagged(Ragged &src, const std::string &mode, T padding_value) { return res; } +/* Prune a two axes ragged tensor on axis0. 
+ * This is a special case of PruneRagged with axis == 0 and src.NumAxes() == 2.
+ * To get more details, please refer to the docs for PruneRagged in
+ * ragged_ops.h.
+ */
+template <typename T>
+Renumbering PruneRaggedAxis0(Ragged<T> &src, T beam, int32_t max_elems) {
+  K2_CHECK_EQ(src.NumAxes(), 2);
+  const ContextPtr &c = src.Context();
+  int32_t total_elements = src.TotSize(0);
+  Renumbering renumbering(c, total_elements);
+
+  T negative_infinity = -std::numeric_limits<T>::infinity();
+  Array1<T> sub_max(c, total_elements);
+  MaxPerSublist(src, negative_infinity, &sub_max);
+
+  T max_value = MaxValue(src.values);
+
+  bool prune_with_max_elems =
+      max_elems > 0 && max_elems < total_elements;
+
+  Array1<int32_t> order_map;
+  const int32_t *order_map_data;
+  if (prune_with_max_elems) {
+    order_map = Array1<int32_t>(c, total_elements);
+    Sort<T, GreaterThan<T>>(&sub_max, &order_map);
+    order_map_data = order_map.Data();
+  }
+
+  char *keep_data = renumbering.Keep().Data();
+  const T *sub_max_data = sub_max.Data();
+
+  // prune_with_max_elems means we have sorted the source ragged tensor
+  if (prune_with_max_elems) {
+    K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t i) {
+      bool pruned_by_beam = sub_max_data[i] < max_value - beam;
+      bool pruned_by_max_elems = i >= max_elems;
+      keep_data[order_map_data[i]] =
+          !(pruned_by_max_elems || pruned_by_beam);
+    });
+  } else {
+    K2_EVAL(c, total_elements, lambda_set_keep, (int32_t i) {
+      keep_data[i] = sub_max_data[i] >= max_value - beam;
+    });
+  }
+  return renumbering;
+}
+
+/* Prune a two axes ragged tensor on axis1.
+ * This is a special case of PruneRagged with axis == 1 and src.NumAxes() == 2.
+ * To get more details, please refer to the docs for PruneRagged in
+ * ragged_ops.h.
+ */
+template <typename T>
+Renumbering PruneRaggedAxis1(Ragged<T> &src, T beam,
+                             int32_t max_elems) {
+  K2_CHECK_EQ(src.NumAxes(), 2);
+  const ContextPtr &c = src.Context();
+  int32_t total_elements = src.TotSize(1);
+  Renumbering renumbering(c, total_elements);
+
+  T negative_infinity = -std::numeric_limits<T>::infinity();
+  Array1<T> sub_max(c, src.TotSize(0));
+  MaxPerSublist(src, negative_infinity, &sub_max);
+
+  bool prune_with_max_elems =
+      max_elems > 0 && max_elems < total_elements;
+
+  Array1<int32_t> order_map;
+  const int32_t *order_map_data;
+  if (prune_with_max_elems) {
+    Ragged<T> sorted_src = src.Clone();
+    order_map = Array1<int32_t>(c, total_elements);
+    SortSublists<T, GreaterThan<T>>(&sorted_src, &order_map);
+    order_map_data = order_map.Data();
+  }
+
+  char *keep_data = renumbering.Keep().Data();
+  const T *sub_max_data = sub_max.Data(),
+          *src_data = src.values.Data();
+  const int32_t *row_ids1_data = src.RowIds(1).Data(),
+                *row_splits1_data = src.RowSplits(1).Data();
+  // prune_with_max_elems means we have sorted the source ragged tensor
+  if (prune_with_max_elems) {
+    K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t idx01) {
+      // idx01 is the index after sorting
+      int32_t original_idx01 = order_map_data[idx01],
+              // SortSublists wouldn't change idx0 & idx0x
+              idx0 = row_ids1_data[original_idx01],
+              idx0x = row_splits1_data[idx0],
+              // idx1 is the index after sorting
+              idx1 = idx01 - idx0x;
+      bool pruned_by_max_elems = idx1 >= max_elems,
+           pruned_by_beam =
+               src_data[original_idx01] < sub_max_data[idx0] - beam;
+      keep_data[original_idx01] =
+          !(pruned_by_max_elems || pruned_by_beam);
+    });
+  } else {
+    K2_EVAL(c, total_elements, lambda_set_keep, (int32_t idx01) {
+      int32_t idx0 = row_ids1_data[idx01];
+      keep_data[idx01] = src_data[idx01] >= sub_max_data[idx0] - beam;
+    });
+  }
+  return renumbering;
+}
+
+template <typename T>
+Renumbering 
PruneRagged(Ragged &src, int32_t axis, T beam, + int32_t max_elems) { + NVTX_RANGE(K2_FUNC); + if (axis == 0) { + auto reduced_src = src; + while (reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, reduced_src.NumAxes() - 2); + } + return PruneRaggedAxis0(reduced_src, beam, max_elems); + } else if (axis == src.NumAxes() - 1) { + auto reduced_src = src; + while (reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, 0); + } + return PruneRaggedAxis1(reduced_src, beam, max_elems); + } else { + RaggedShape top, bottom; + DecomposeRaggedShape(src.shape, axis, &top, &bottom); + Ragged bottom_ragged(bottom, src.values); + return PruneRagged(bottom_ragged, 0, beam, max_elems); + } +} + } // namespace k2 #endif // K2_CSRC_RAGGED_OPS_INL_H_ diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index f09d7edcd..451a2cf60 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -360,11 +360,6 @@ TEST(RaggedShapeTest, RemoveEmptyLists) { } } - - - - - TEST(RaggedShapeTest, RaggedShapeIterator) { // note RaggedShapeIndexIterator works only for CPU ContextPtr context = GetCpuContext(); @@ -425,4 +420,58 @@ TEST(RaggedShapeTest, RandomRaggedShape) { } } +TEST(RaggedShapeTest, SubsetRaggedShape) { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + RaggedShape src(c, + "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); + // axis = 2 + Array1 ref_keep(c, std::vector({1, 0, 1, 0, 0, 1, 0, 1, 0})); + Renumbering renumbering(c, src.NumElements()); + auto &keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + Array1 new2old; + auto dest = SubsetRaggedShape(src, renumbering, 2, &new2old); + Array1 ref_new2old(c, std::vector({0, 2, 5, 7})); + RaggedShape ref_dest(c, + "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -1 + dest = SubsetRaggedShape(src, renumbering, -1, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + // axis = 1 + ref_keep = Array1(c, std::vector({1, 0, 0, 1, 1, 1})); + renumbering = Renumbering(c, src.TotSize(1)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsetRaggedShape(src, renumbering, 1, &new2old); + ref_new2old = Array1(c, std::vector({0, 1, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] ] [ [ ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -2 + dest = SubsetRaggedShape(src, renumbering, -2, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + // axis = 0 + ref_keep = Array1(c, std::vector({1, 0, 1})); + renumbering = Renumbering(c, src.TotSize(0)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsetRaggedShape(src, renumbering, 0, &new2old); + ref_new2old = Array1(c, + std::vector({0, 1, 2, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -3 + dest = SubsetRaggedShape(src, renumbering, -3, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + } +} + } // namespace k2 diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu index 7fe9c5fae..801e51f5c 100644 --- a/k2/csrc/ragged_test.cu +++ b/k2/csrc/ragged_test.cu @@ -41,22 +41,25 @@ namespace k2 { - TEST(RaggedShapeOpsTest, CatMoreAxes) { for (auto &c : 
{GetCpuContext(), GetCudaContext()}) { - RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x ] ] ] ]"), - shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]" - " [ [ [ x x ] ] ] ]"), - shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]" - " [ [ [ ] ] ] ]"); + RaggedShape shape1 = RaggedShape(c, + "[ [ [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x ] ] ] ]"), + shape2 = RaggedShape(c, + "[ [ [ [ x ] ] [ [ x ] ] ]" + " [ [ [ x x ] ] ] ]"), + shape3 = RaggedShape(c, + "[ [ [ [ ] ] [ [ x ] ] ]" + " [ [ [ ] ] ] ]"); RaggedShape cat_axis2_ref = - RaggedShape(c, "[ [ [ [ x x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] ] ]" - " [ [ [ x ] [ x x ] [ ] ] ] ]"); - RaggedShape cat_axis3_ref = - RaggedShape(c, "[ [ [ [ x x x ] ] [ [ x x x ] ] ]" - " [ [ [ x x x ] ] ] ]"); + RaggedShape(c, + "[ [ [ [ x x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] ] ]" + " [ [ [ x ] [ x x ] [ ] ] ] ]"); + RaggedShape cat_axis3_ref = RaggedShape(c, + "[ [ [ [ x x x ] ] [ [ x x x ] ] ]" + " [ [ [ x x x ] ] ] ]"); RaggedShape *srcs[] = {&shape1, &shape2, &shape3}; Array1 merge_map2; Array1 merge_map3; @@ -74,21 +77,25 @@ TEST(RaggedShapeOpsTest, CatMoreAxes) { TEST(RaggedShapeOpsTest, StackMoreAxes) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x ] ] ] ]"), - shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]" - " [ [ [ x x ] ] ] ]"), - shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]" - " [ [ [ ] ] ] ]"); + RaggedShape shape1 = RaggedShape(c, + "[ [ [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x ] ] ] ]"), + shape2 = RaggedShape(c, + "[ [ [ [ x ] ] [ [ x ] ] ]" + " [ [ [ x x ] ] ] ]"), + shape3 = RaggedShape(c, + "[ [ [ [ ] ] [ [ x ] ] ]" + " [ [ [ ] ] ] ]"); RaggedShape stacked2_ref = - RaggedShape(c, "[ [ [ [ [ x x ] ] [ [ x ] ] [ [ ] ] ]" - " [ [ [ x ] ] [ [ x ] ] [ [ x ] ] ] ]" - " [ [ [ [ x ] ] [ [ x x ] ] [ [ ] ] ] ] ]"); - RaggedShape stacked3_ref = - RaggedShape(c, "[ [ [ [ [ x x ] [ x ] [ ] ] ]" - " [ [ [ x ] [ x ] [ x ] ] ] ]" - " [ [ [ [ x ] [ x x ] [ ] ] ] ] ]"); + RaggedShape(c, + "[ [ [ [ [ x x ] ] [ [ x ] ] [ [ ] ] ]" + " [ [ [ x ] ] [ [ x ] ] [ [ x ] ] ] ]" + " [ [ [ [ x ] ] [ [ x x ] ] [ [ ] ] ] ] ]"); + RaggedShape stacked3_ref = RaggedShape(c, + "[ [ [ [ [ x x ] [ x ] [ ] ] ]" + " [ [ [ x ] [ x ] [ x ] ] ] ]" + " [ [ [ [ x ] [ x x ] [ ] ] ] ] ]"); RaggedShape *srcs[] = {&shape1, &shape2, &shape3}; Array1 merge_map2; Array1 merge_map3; @@ -113,18 +120,13 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { // axis = 0 Unstack(shape, 0, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ x x ] ]"))); - K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ x x x ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{2, 3, 4}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ x ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{5}))); + K2_CHECK(Equal(out[0], RaggedShape(c, "[ [ x x ] ]"))); + K2_CHECK(Equal(out_map[0], Array1(c, std::vector{0, 1}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ x x x ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{2, 3, 4}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ x ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{5}))); std::vector out_ptr; out_ptr.clear(); @@ -135,18 +137,13 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { // axis = 1 Unstack(shape, 1, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ x x x ] ]"))); - K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 2, 5}))); - 
K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ x x ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{1, 3}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ x ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{4}))); + K2_CHECK(Equal(out[0], RaggedShape(c, "[ [ x x x ] ]"))); + K2_CHECK( + Equal(out_map[0], Array1(c, std::vector{0, 2, 5}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ x x ] ]"))); + K2_CHECK(Equal(out_map[1], Array1(c, std::vector{1, 3}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ x ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{4}))); // can not test Stack here, because the element numbers of axis 1 is not // the same } @@ -154,21 +151,21 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { TEST(RaggedShapeOpsTest, Unstack) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape(c, "[ [ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]" - " [ [ [ x x x ] ] ] ]"); + RaggedShape shape(c, + "[ [ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]" + " [ [ [ x x x ] ] ] ]"); std::vector out; std::vector> out_map; Unstack(shape, 0, &out, &out_map); // axis = 0 K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"))); + RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 2, 3, 4, 5}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x x x ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{6, 7, 8}))); + Array1(c, std::vector{0, 1, 2, 3, 4, 5}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{6, 7, 8}))); std::vector out_ptr; for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -177,14 +174,13 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 1 Unstack(shape, 1, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 2, 6, 7, 8}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x x ] ] [ ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{3, 4, 5}))); + Array1(c, std::vector{0, 1, 2, 6, 7, 8}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x x ] ] [ ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{3, 4, 5}))); out_ptr.clear(); for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -194,14 +190,13 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 2 Unstack(shape, 2, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 3, 6, 7, 8}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x x ] ] [ [ ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{2, 4, 5}))); + Array1(c, std::vector{0, 1, 3, 6, 7, 8}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x x ] ] [ [ ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{2, 4, 5}))); out_ptr.clear(); for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -211,18 +206,15 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 3 Unstack(shape, 3, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x x ] ] [ [ x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x x ] ] [ [ x ] ] ]"))); 
K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 2, 3, 4, 6}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x ] ] [ [ x ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{1, 5, 7}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ [ ] [ ] ] [ [ x ] ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{8}))); + Array1(c, std::vector{0, 2, 3, 4, 6}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x ] ] [ [ x ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{1, 5, 7}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ [ ] [ ] ] [ [ x ] ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{8}))); // can not test Stack here, because the element numbers of axis 3 is not // the same } @@ -230,12 +222,13 @@ TEST(RaggedShapeOpsTest, Unstack) { TEST(RaggedShapeOpsTest, UnstackMoreAxes) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape(c, "[ [ [ [ [ x ] [ ] ] [ [ x x x ] ] ] ]" - " [ [ [ [ x x x ] ] [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x x ] [ x ] [ ] [ x ] ] ]" - " [ [ [ x ] ] [ [ x ] [ x x x x ] ] ] ]" - " [ [ [ [ x ] ] [ ] ]" - " [ [ [ x x ] ] ] ] ]"); + RaggedShape shape(c, + "[ [ [ [ [ x ] [ ] ] [ [ x x x ] ] ] ]" + " [ [ [ [ x x x ] ] [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x x ] [ x ] [ ] [ x ] ] ]" + " [ [ [ x ] ] [ [ x ] [ x x x x ] ] ] ]" + " [ [ [ [ x ] ] [ ] ]" + " [ [ [ x x ] ] ] ] ]"); std::vector out; std::vector> out_map; @@ -254,11 +247,11 @@ TEST(RaggedShapeOpsTest, UnstackMoreAxes) { } TEST(RaggedShapeOpsTest, UnstackRandom) { - RaggedShape random_shape_ = RandomRaggedShape(true, // set_row_ids - 5, // min_num_axes - 5, // max_num_axes - 1, // min_num_elements - 100); // max_num_elements + RaggedShape random_shape_ = RandomRaggedShape(true, // set_row_ids + 5, // min_num_axes + 5, // max_num_axes + 1, // min_num_elements + 100); // max_num_elements for (auto &c : {GetCpuContext(), GetCudaContext()}) { auto random_shape0 = random_shape_.To(c); std::vector out; @@ -1733,9 +1726,13 @@ TEST(RaggedShapeOpsTest, TestIndex) { TEST(RaggedShapeOpsTest, TestIndexAxis1) { for (auto &context : {GetCpuContext(), GetCudaContext()}) { { - Ragged input = Ragged(" [ [ 1 2 ] [ 3 4 5 ] [ 6 7 ] [ ] ]").To(context); // NOLINT + Ragged input = + Ragged(" [ [ 1 2 ] [ 3 4 5 ] [ 6 7 ] [ ] ]") + .To(context); // NOLINT Array1 indexes = Array1(" [ 1 0 4 2 6 5 ]").To(context); - Ragged output = Ragged(" [ [ 2 1 ] [ 5 3 ] [ 7 6 ] [ ] ]").To(context); // NOLINT + Ragged output = + Ragged(" [ [ 2 1 ] [ 5 3 ] [ 7 6 ] [ ] ]") + .To(context); // NOLINT Ragged indexed = Index(input, 1, indexes); EXPECT_EQ(Equal(output, indexed), true); @@ -1743,8 +1740,6 @@ TEST(RaggedShapeOpsTest, TestIndexAxis1) { } } - - TEST(GetTransposeReordering, NoDuplicates) { // col0 col1 col2 col3 col4 col5 // row0 a0 b1 @@ -2241,41 +2236,39 @@ void TestUnstackRagged() { K2_CHECK(Equal(out[2], Ragged(c, "[ [ 50 ] ]"))); // more axes - ragged = Ragged(c, "[ [ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ]" - " [ [ 41 ] [ 51 ] ] ]" - " [ [ [ 61 62 63 ] ] ] ]"); + ragged = Ragged(c, + "[ [ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ]" + " [ [ 41 ] [ 51 ] ] ]" + " [ [ [ 61 62 63 ] ] ] ]"); // axis = 0 Unstack(ragged, 0, &out); - K2_CHECK(Equal(out[0], Ragged(c, - "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 41 ] [ 51 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal( + out[0], + Ragged(c, + "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 41 ] [ 51 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 61 62 63 ] ] ]"))); // axis = 1 Unstack(ragged, 1, &out); - 
K2_CHECK(Equal(out[0], Ragged(c, - "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 61 62 63 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 41 ] [ 51 ] ] [ ] ]"))); + K2_CHECK(Equal( + out[0], + Ragged(c, "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 41 ] [ 51 ] ] [ ] ]"))); // axis = 2 Unstack(ragged, 2, &out); - K2_CHECK(Equal(out[0], - Ragged(c, "[ [ [ 1 11 21 ] [ 41 ] ] [ [ 61 62 63 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 21 22 ] [ 51 ] ] [ [ ] ] ]"))); - K2_CHECK(Equal(out[2], - Ragged(c, "[ [ [ 31 ] [ ] ] [ [ ] ] ]"))); + K2_CHECK(Equal( + out[0], Ragged(c, "[ [ [ 1 11 21 ] [ 41 ] ] [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 21 22 ] [ 51 ] ] [ [ ] ] ]"))); + K2_CHECK(Equal(out[2], Ragged(c, "[ [ [ 31 ] [ ] ] [ [ ] ] ]"))); // axis = 3 Unstack(ragged, 3, &out); K2_CHECK(Equal(out[0], - Ragged(c, "[ [ [ 1 21 31 ] [ 41 51 ] ] [ [ 61 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 11 22 ] [ ] ] [ [ 62 ] ] ]"))); - K2_CHECK(Equal(out[2], - Ragged(c, "[ [ [ 21 ] [ ] ] [ [ 63 ] ] ]"))); + Ragged(c, "[ [ [ 1 21 31 ] [ 41 51 ] ] [ [ 61 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 11 22 ] [ ] ] [ [ 62 ] ] ]"))); + K2_CHECK(Equal(out[2], Ragged(c, "[ [ [ 21 ] [ ] ] [ [ 63 ] ] ]"))); } } @@ -2916,8 +2909,6 @@ TEST(RaggedOpsTest, TestComputeHash) { } } - - TEST(RaggedOpsTest, TestUniqueSequences) { for (int32_t i = 0; i < 20; i++) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { @@ -2932,7 +2923,7 @@ TEST(RaggedOpsTest, TestUniqueSequences) { ContextPtr cpu = GetCpuContext(); Array1 hash_src = ComputeHash(src).To(cpu), - hash_unique = ComputeHash(unique).To(cpu); + hash_unique = ComputeHash(unique).To(cpu); RaggedShape src_hash_shape = RemoveAxis(src.shape, src.NumAxes() - 1).To(cpu); @@ -2946,9 +2937,10 @@ TEST(RaggedOpsTest, TestUniqueSequences) { K2_CHECK_EQ(src_hash_shape.Dim0(), unique_hash_shape.Dim0()); const int32_t *src_hash_row_splits = src_hash_shape.RowSplits(1).Data(), - *unique_hash_row_splits = unique_hash_shape.RowSplits(1).Data(); + *unique_hash_row_splits = + unique_hash_shape.RowSplits(1).Data(); const int32_t *src_hash_data = hash_src.Data(), - *unique_hash_data = hash_unique.Data(); + *unique_hash_data = hash_unique.Data(); for (int32_t r = 0; r < src_hash_shape.Dim0(); r++) { int32_t src_begin = src_hash_row_splits[r], @@ -2979,7 +2971,6 @@ TEST(RaggedIntTest, TestCreateRagged2Int) { K2_CHECK(Equal(r, r2)); } - TEST(RaggedFloatTest, TestCreateRagged2Float) { std::vector> vecs{{1.2, 2.3}, {}, {3.4, 5.6}}; std::vector expected_values{1.2, 2.3, 3.4, 5.6}; @@ -3004,10 +2995,8 @@ static void TestPadRagged() { T padding_value = 0; Array2 res = PadRagged(src, "constant", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, 0, 0, - 3, 4, 3, 0, - 0, 0, 0, 0, - 5, 6, 7, 8}; + std::vector expected = {1, 2, 0, 0, 3, 4, 3, 0, + 0, 0, 0, 0, 5, 6, 7, 8}; CheckArrayData(dst, expected); } { @@ -3015,10 +3004,8 @@ static void TestPadRagged() { T padding_value = -1; Array2 res = PadRagged(src, "constant", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, -1, -1, - 3, 4, 3, -1, - -1, -1, -1, -1, - 5, 6, 7, 8}; + std::vector expected = {1, 2, -1, -1, 3, 4, 3, -1, + -1, -1, -1, -1, 5, 6, 7, 8}; CheckArrayData(dst, expected); } { @@ -3026,10 +3013,8 @@ static void TestPadRagged() { T padding_value = 100; Array2 res = PadRagged(src, "replicate", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, 2, 2, - 3, 4, 3, 
3, - 100, 100, 100, 100, - 5, 6, 7, 8}; + std::vector expected = {1, 2, 2, 2, 3, 4, 3, 3, + 100, 100, 100, 100, 5, 6, 7, 8}; CheckArrayData(dst, expected); } } @@ -3039,4 +3024,125 @@ TEST(RaggedTest, TestPadRagged) { TestPadRagged(); TestPadRagged(); } + +template +static void TestPruneRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged src(c, + "[ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] " + " [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ]"); + + T beam = 2.0; + auto renumbering = PruneRagged(src, 0, beam, 2); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // no sublist is pruned by beam, 5.0 is pruned by max-elems + // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] + // [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + Array1 keep_ref(c, std::vector{1, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 0.1; + renumbering = PruneRagged(src, 0, beam, 3); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // 6.1 & 5.0 are pruned by beam + // keep : [ [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + keep_ref = Array1(c, std::vector{0, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.2 & -inf are pruned by beam, 4.4 is pruned by max-elems. + // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 2.2 6.3 ] ] + // [ [ 2.3 5.0 ] ] ] + keep_ref = Array1(c, std::vector{1, 1, 1, 0, 1, 0, 0, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 1.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // all sublists are pruned by beam, except 6.1 & 6.3 + // keep : [ [ [ 6.1 ] ] [ [ 2.2 6.3 ] ] ] + keep_ref = Array1(c, std::vector{0, 0, 1, 0, 1, 0, 0, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 4.0; + renumbering = PruneRagged(src, 2, beam, 3); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1, 1.0, 2.2 are pruned by beam. + // keep : [ [ [ 2.1 5.2 ] [ 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1( + c, std::vector{0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 5.0; + renumbering = PruneRagged(src, 2, beam, 2); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1 is pruned by max-elems. 
+ // keep : [ [ [ 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 2.2 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1( + c, std::vector{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + } +} + +TEST(RaggedTest, TestPruneRagged) { + TestPruneRagged(); + TestPruneRagged(); +} + +template +static void TestPruneRaggedAndSubsetRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged src(c, + "[ [ [ 1.1 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.3 4.4 ] [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + T beam = 1.0; + auto renumbering = PruneRagged(src, 0, beam, 3); + Array1 new2old; + auto dest = SubsetRagged(src, renumbering, 0, &new2old); + Ragged dest_ref(c, "[ [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] ]"); + Array1 new2old_ref(c, std::vector{6, 7, 8, 9, 10, 11}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + dest = SubsetRagged(src, renumbering, 1, &new2old); + dest_ref = + Ragged(c, + "[ [ [ 5.0 3.1 ] ] [ [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + new2old_ref = Array1( + c, std::vector{4, 5, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 3.0; + renumbering = PruneRagged(src, 2, beam, 3); + dest = SubsetRagged(src, renumbering, 2, &new2old); + dest_ref = Ragged( + c, + "[ [ [ 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] [ [ 1.2 ] [ 6.3 ] [ 6.1 ] [ 5.1 ] ]" + " [ [ 4.4 ] [ 2.3 5.2 3.6 ] ] ]"); + new2old_ref = Array1( + c, std::vector{1, 2, 3, 4, 5, 6, 8, 10, 11, 13, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + } +} + +TEST(RaggedTest, TestPruneRaggedAndSubsetRagged) { + TestPruneRaggedAndSubsetRagged(); + TestPruneRaggedAndSubsetRagged(); +} + } // namespace k2 diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index b8774bf53..c20775c1a 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -795,7 +795,7 @@ void ComputeEpsilonClosure(FsaVec &epsilon_fsa, FsaVec *closure_fsa, const Arc &cur_arc = arcs_data[arc_idx012]; arc_keep_data[arc_idx012] = (cur_arc.src_state != cur_arc.dest_state); }); - *closure_fsa = SubsampleRagged(*closure_fsa, arc_renumbering); + *closure_fsa = SubsetRagged(*closure_fsa, arc_renumbering); *arc_map = Index(*arc_map, 0, arc_renumbering.New2Old()); } @@ -1081,7 +1081,7 @@ void RemoveEpsilonDevice(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, non_epsilon_arc_map, foll_shape, &combined_foll, &combined_foll_arc_map); FsaVec epsilon_closure_prec = - SubsampleRagged(epsilon_closure_mapped, epsilon_prec_renumbering); + SubsetRagged(epsilon_closure_mapped, epsilon_prec_renumbering); Ragged epsilon_closure_prec_arc_map = Index( epsilon_closure_mapped_arc_map, 0, epsilon_prec_renumbering.New2Old()); // `combined_prec` will be set to an FSA, with the same state numbering as From 3cc74f19ccc045a7176a8a96a4b24c4af76e9653 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 22 Feb 2022 12:24:04 +0800 Subject: [PATCH 44/64] Add Hash64 (#895) * Add hash64 * Fix tests * Resize hash64 * Fix comments * fix typo --- k2/csrc/hash.h | 347 ++++++++++++++++++++++++++++++++++++++++++- k2/csrc/hash_test.cu | 98 +++++++++++- 2 files changed, 442 insertions(+), 3 deletions(-) diff --git a/k2/csrc/hash.h b/k2/csrc/hash.h index ae651c6de..5be08b574 100644 --- a/k2/csrc/hash.h +++ b/k2/csrc/hash.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Xiaomi 
Corporation (authors: Daniel Povey)
+ * Copyright      2020  Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
  *
  * See LICENSE for clarification regarding multiple authors
  *
@@ -1014,6 +1014,350 @@ class Hash {
 };
 
+/*
+  How class Hash64 works:
+
+    - It can function as a map from key=uint64_t to value=uint64_t; you must
+      decide the number of buckets when you create the hash, but you can
+      resize it (manually).
+
+  Note:
+    Each bucket contains a pair of key/value, each 64 bits; key is stored at
+    data[2 * bucket_index] and value is stored at data[2 * bucket_index + 1].
+
+  Some constraints:
+    - You can store any (key,value) pair, except the pair where all the bits of
+      both key and value are set [that is used to mean "nothing here"].
+    - The number of buckets must always be a power of 2.
+    - When deleting values from the hash you must delete them all at
+      once (necessary because there is no concept of a "tombstone").
+
+  Some notes on usage:
+
+    You use it by: constructing it, obtaining its Accessor with GetAccessor();
+    and inside kernels (or host code), calling functions Insert(), Find() or
+    Delete() of the Accessor object.  Resizing is not automatic; it is the
+    user's responsibility to make sure the hash does not get too full
+    (which could cause assertion failures in kernels, and will be very slow).
+
+  Some implementation notes:
+    - When accessing hash[key], we use bucket_index == key % num_buckets,
+      bucket_inc = 1 | (((key * 2) / num_buckets) ^ key).
+    - If the bucket at `bucket_index` is occupied, we look in locations
+      `(bucket_index + n * bucket_inc)%num_buckets` for n = 1, 2, ...;
+      this choice ensures that if multiple keys hash to the same bucket,
+      they don't all access the same sequence of locations; and bucket_inc
+      being odd ensures we eventually try all locations (of course for
+      reasonable hash occupancy levels, we shouldn't ever have to try
+      more than two or three).
+
+*/
+class Hash64 {
+ public:
+  /* Constructor.  Context can be for CPU or GPU.
+
+     @param [in] num_buckets   Number of buckets in the hash; must be
+            a power of 2 and >= 128 (this limit was arbitrarily chosen).
+            The number of items in the hash cannot exceed the number of
+            buckets, or the code will loop infinitely when you try to add
+            items; aim for less than 50% occupancy.
+  */
+  Hash64(ContextPtr c, int64_t num_buckets) {
+    K2_CHECK_GE(num_buckets, 128);
+    data_ = Array1<uint64_t>(c, num_buckets * 2, ~(uint64_t)0);
+    int64_t n = 2;
+    for (buckets_num_bitsm1_ = 0; n < num_buckets;
+         n *= 2, buckets_num_bitsm1_++) {
+    }
+    K2_CHECK_EQ(num_buckets, 2 << buckets_num_bitsm1_)
+        << " num_buckets must be a power of 2.";
+  }
+
+  // Only to be used prior to assignment.
+  Hash64() = default;
+
+  int64_t NumBuckets() const { return data_.Dim() / 2; }
+
+  // Returns data pointer; for testing.
+  uint64_t *Data() { return data_.Data(); }
+
+  // Shallow copy
+  Hash64 &operator=(const Hash64 &src) = default;
+  // Copy constructor (shallow copy)
+  explicit Hash64(const Hash64 &src) = default;
+
+  ContextPtr &Context() const { return data_.Context(); }
+
+  class Accessor {
+   public:
+    Accessor(Hash64 &hash)
+        : data_(hash.data_.Data()),
+          num_buckets_mask_(uint64_t(hash.NumBuckets()) - 1),
+          buckets_num_bitsm1_(hash.buckets_num_bitsm1_) {}
+
+    // Copy constructor
+    Accessor(const Accessor &src) = default;
+
+    /*
+      Try to insert pair (key,value) into hash.
+      @param [in] key  Key into hash; it is an error if ~key == 0, i.e. if all
+                  the allowed bits of `key` are set.
+      @param [in] value  Value to set; it is an error if ~value == 0, i.e. 
if + all the allowed bits `value` are set. + @param [out] old_value If not nullptr, this location will be set to + the existing value *if this key was already present* in the + hash (or set by another thread in this kernel), i.e. only + if this function returns false. + @param [out] key_value_location If not nullptr, its contents will be + set to the address of the (key,value) pair (either the + existing or newly-written one). + @return Returns true if this (key,value) pair was inserted, false + otherwise. + + Note: the const is with respect to the metadata only; it is required, to + avoid compilation errors. + */ + __forceinline__ __host__ __device__ bool Insert( + uint64_t key, uint64_t value, uint64_t *old_value = nullptr, + uint64_t **key_value_location = nullptr) const { + uint64_t cur_bucket = key & num_buckets_mask_, + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); + + while (1) { + uint64_t cur_key = data_[2 * cur_bucket]; + uint64_t cur_value = data_[2 * cur_bucket + 1]; + if (cur_key == key) { + if (old_value) *old_value = cur_value; + if (key_value_location) *key_value_location = data_ + 2 * cur_bucket; + return false; // key exists in hash + } else if (~cur_key == 0) { + // we have a version of AtomicCAS that also works on host. + uint64_t old_key = AtomicCAS( + (unsigned long long *)(data_ + 2 * cur_bucket), cur_key, key); + if (old_key == cur_key) { + // set value + data_[2 * cur_bucket + 1] = value; + if (key_value_location) + *key_value_location = data_ + 2 * cur_bucket; + return true; // Successfully inserted. + } + if (old_key == key) { + if (old_value) *old_value = cur_value; + if (key_value_location) + *key_value_location = data_ + 2 * cur_bucket; + return false; // Another thread inserted this key + } + } + // Rotate bucket index until we find a free location. This will + // eventually visit all bucket indexes before it returns to the same + // location, because bucket_inc is odd (so only satisfies + // (n * bucket_inc) % num_buckets == 0 for n == num_buckets). + // Note: n here is the number of times we went around the loop. + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; + } + } + + /* + Look up this key in this hash; output the value and optionally the + location of the (key,value) pair if found. + + @param [in] key Key to look up; + @param [out] value_out If found, value will be written to here. This may + seem redundant with key_value_location, but this should + compile to a local variable, and we want to avoid + redundant memory reads. + @param [out] key_value_location (optional) The memory address of the + (key,value) pair, in case the caller wants to overwrite + the value via SetValue(); must be used for no other + purpose. + @return Returns true if an item with this key was found in the + hash, otherwise false. + + Note: the const is with respect to the metadata only; it is required, to + avoid compilation errors. 
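+
+      Example (an illustrative sketch, not part of the original patch):
+
+        uint64_t value, *loc;
+        if (acc.Find(key, &value, &loc))
+          acc.SetValue(loc, new_value);  // overwrite the value in place;
+                                         // `new_value` is hypothetical.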
+ */ + __forceinline__ __host__ __device__ bool Find( + uint64_t key, uint64_t *value_out, + uint64_t **key_value_location = nullptr) const { + uint64_t cur_bucket = key & num_buckets_mask_, + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); + while (1) { + uint64_t old_key = data_[2 * cur_bucket]; + uint64_t old_value = data_[2 * cur_bucket + 1]; + if (~old_key == 0) { + return false; + } else if (old_key == key) { + while (~old_value == 0) old_value = data_[2 * cur_bucket + 1]; + *value_out = old_value; + if (key_value_location) *key_value_location = data_ + 2 * cur_bucket; + return true; + } else { + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; + } + } + } + + /* + Overwrite a value in a (key,value) pair whose location was obtained using + Find(). + @param [in] key_value_location Location that was obtained from + a successful call to Find(). + @param [in] value Value to write; + + Note: the const is with respect to the metadata only; it is required, to + avoid compilation errors. + */ + __forceinline__ __host__ __device__ void SetValue( + uint64_t *key_value_location, uint64_t value) const { + *(key_value_location + 1) = value; + } + + /* Deletes a key from a hash. Caution: this cannot be combined with other + operations on a hash; after you delete a key you cannot do Insert() or + Find() until you have deleted all keys. This is an open-addressing hash + table with no tombstones, which is why this limitation exists). + + @param [in] key Key to be deleted. Each key present in the hash must + be deleted by exactly one thread, or it will loop + forever! + + Note: the const is with respect to the metadata only; required, to avoid + compilation errors. + */ + __forceinline__ __host__ __device__ void Delete(uint64_t key) const { + uint64_t cur_bucket = key & num_buckets_mask_, + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); + while (1) { + uint64_t old_key = data_[2 * cur_bucket]; + if (old_key == key) { + data_[2 * cur_bucket] = ~((uint64_t)0); + data_[2 * cur_bucket + 1] = ~((uint64_t)0); + return; + } else { + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; + } + } + } + + private: + // pointer to data + uint64_t *data_; + // num_buckets_mask is num_buckets (i.e. size of `data_` array) minus one; + // num_buckets is a power of 2 so this can be used as a mask to get a number + // modulo num_buckets. + uint64_t num_buckets_mask_; + // A number satisfying num_buckets == 1 << (1+buckets_num_bitsm1_) + // the number of bits in `num_buckets` minus one. + uint64_t buckets_num_bitsm1_; + }; + + /* + Return an Accessor object which can be used in kernel code (or on CPU if the + context is a CPU context). + */ + Accessor GetAccessor() { return Accessor(*this); } + + // You should call this before the destructor is called if the hash will still + // contain values when it is destroyed, to bypass a check. + void Destroy() { data_ = Array1(); } + + void CheckEmpty() const { + if (data_.Dim() == 0) return; + ContextPtr c = Context(); + Array1 error(c, 1, -1); + int64_t *error_data = error.Data(); + const uint64_t *hash_data = data_.Data(); + + K2_EVAL( + Context(), data_.Dim(), lambda_check_data, (int64_t i)->void { + if (~(hash_data[i]) != 0) error_data[0] = i; + }); + int64_t i = error[0]; + if (i >= 0) { // there was an error; i is the index into the hash where + // there was an element. + int64_t elem = data_[i]; + // We don't know the number of bits the user was using for the key vs. + // value, so print in hex, maybe they can figure it out. 
+ K2_LOG(FATAL) << "Destroying hash: still contains values: position " << i + << ", content = " << std::hex << elem; + } + } + + /* Resize the hash to a new number of buckets. + + @param [in] new_num_buckets New number of buckets; must be a power of 2, + and must be large enough to accommodate all values in the hash + (we assume the caller is keeping track of the number of elements + in the hash somehow). + + CAUTION: Resizing will invalidate any accessor objects you have; you need + to re-get the accessors before accessing the hash again. + */ + void Resize(int64_t new_num_buckets, bool copy_data = true) { + NVTX_RANGE(K2_FUNC); + + K2_CHECK_GT(new_num_buckets, 0); + K2_CHECK_EQ(new_num_buckets & (new_num_buckets - 1), 0); // power of 2. + + ContextPtr c = data_.Context(); + Hash64 new_hash(c, new_num_buckets); + + if (copy_data) { + new_hash.CopyDataFromSimple(*this); + } + + *this = new_hash; + new_hash.Destroy(); // avoid failed check in destructor (it would otherwise + // expect the hash to be empty when destroyed). + } + + /* + Copies all data elements from `src` to `*this`. + */ + void CopyDataFromSimple(Hash64 &src) { + NVTX_RANGE(K2_FUNC); + int64_t num_buckets = data_.Dim() / 2, + src_num_buckets = src.data_.Dim() / 2; + const uint64_t *src_data = src.data_.Data(); + uint64_t *data = data_.Data(); + uint64_t new_num_buckets_mask = static_cast(num_buckets) - 1, + new_buckets_num_bitsm1 = buckets_num_bitsm1_; + ContextPtr c = data_.Context(); + K2_EVAL(c, src_num_buckets, lambda_copy_data, (uint64_t i) -> void { + uint64_t key = src_data[2 * i]; + uint64_t value = src_data[2 * i + 1]; + if (~key == 0) return; // equals -1.. nothing there. + uint64_t bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key); + uint64_t cur_bucket = key & new_num_buckets_mask; + while (1) { + uint64_t assumed = ~((uint64_t)0), + old_elem = AtomicCAS((unsigned long long*)(data + 2 * cur_bucket), + assumed, key); + if (old_elem == assumed) { + *(data + 2 * cur_bucket + 1) = value; + return; + } + cur_bucket = (cur_bucket + bucket_inc) & new_num_buckets_mask; + // Keep iterating until we find a free spot in the new hash... + } + }); + } + + // The destructor checks that the hash is empty, if we are in debug mode. + // If you don't want this, call Destroy() before the destructor is called. + ~Hash64() { +#ifndef NDEBUG + if (data_.Dim() != 0) CheckEmpty(); +#endif + } + + private: + Array1 data_; + + // number satisfying data_.Dim() == 1 << (1+buckets_num_bitsm1_) + uint64_t buckets_num_bitsm1_; +}; + /* Returns the number of bits needed for an unsigned integer sufficient to store the nonnegative value `size`. @@ -1029,7 +1373,6 @@ inline int32_t NumBitsNeededFor(int64_t size) { return 1 + HighestBitSet(size); } - } // namespace k2 #endif // K2_CSRC_HASH_H_ diff --git a/k2/csrc/hash_test.cu b/k2/csrc/hash_test.cu index 5bcb8e6b5..307bd7e18 100644 --- a/k2/csrc/hash_test.cu +++ b/k2/csrc/hash_test.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey) + * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey, Wei Kang) * * See LICENSE for clarification regarding multiple authors * @@ -300,6 +300,98 @@ void TestHashConstructPacked(int32_t num_key_bits, } } +void TestHash64Construct() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + for (int32_t size : {128, 1024, 2048, 65536, 1048576}) { + Hash64 hash(c, size); + + // obviously we're not going to fill it completely... this hash is not + // resizable. 
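+      // (Editor's note, an assumption based on the Resize() method above:
+      // Hash64 *can* be resized manually with Resize(), which this test
+      // exercises below; it is just not resized automatically.)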
+ int32_t num_elems = size / 2; + + // Some keys may be identical. + int32_t key_bound = num_elems * 2; + Array1 keys = RandUniformArray1(c, num_elems, 0, + key_bound - 1), + values = + RandUniformArray1(c, num_elems, 0, 10000), + success(c, num_elems, 0); + + Array1 cpu_keys = keys.To(GetCpuContext()); + Array1 count_per_key(GetCpuContext(), key_bound, 0); + int32_t *count_per_key_data = count_per_key.Data(); + + for (int32_t i = 0; i < cpu_keys.Dim(); ++i) { + ++count_per_key_data[cpu_keys[i]]; + } + count_per_key = count_per_key.To(c); + + if (size <= 2048) { + K2_LOG(INFO) << "keys = " << keys << ", values = " << values + << ", counts = " << count_per_key; + } + uint64_t *keys_data = keys.Data(), *values_data = values.Data(), + *success_data = success.Data(); + int32_t *counts_data = count_per_key.Data(); + Hash64::Accessor acc = hash.GetAccessor(); + K2_EVAL( + c, num_elems, lambda_insert_pairs, (int32_t i)->void { + uint64_t key = keys_data[i], value = values_data[i], success; + + int32_t count = counts_data[key]; + + uint64_t *key_value_location; + if (acc.Insert(key, value, nullptr, &key_value_location)) { + success = 1; + } else { + success = 0; + K2_CHECK(count > 1) << ", key = " << key << ", i = " << i; + } + uint64_t keyval = *key_value_location; + if (success) { + acc.SetValue(key_value_location, value); + K2_DCHECK_EQ(keyval, *key_value_location); + } + success_data[i] = success; + }); + + hash.Resize(hash.NumBuckets() * 2); + acc = hash.GetAccessor(); + + K2_EVAL( + c, num_elems, lambda_check_find, (int32_t i)->void { + uint64_t key = keys_data[i], value = values_data[i], + success = success_data[i]; + + uint64_t val = 0; + uint64_t *key_val_addr = nullptr; + bool ans = acc.Find(key, &val, &key_val_addr), + ans2 = acc.Find(key + key_bound, &val, &key_val_addr); + K2_CHECK(ans); // key should be present. + K2_CHECK(!ans2); // key == key + key_bound should not be present. + + if (success) { + // if this was the key that won the data race, its value should be + // present. + K2_CHECK_EQ(val, value); + K2_CHECK_EQ(*key_val_addr, key); + K2_CHECK_EQ(*(key_val_addr + 1), value); + } + }); + + + + K2_EVAL( + c, num_elems, lambda_check_delete, (int32_t i)->void { + uint64_t key = (uint64_t)keys_data[i]; + uint64_t success = success_data[i]; + + if (success) acc.Delete(key); + }); + } + } +} + TEST(Hash, Construct) { // This indirection gets around a limitation of the CUDA compiler. 
@@ -320,4 +412,8 @@ TEST(Hash, Construct) {
   }
 }
 
+TEST(Hash64, Construct) {
+  TestHash64Construct();
+}
+
 }  // namespace k2

From 0feefc731f65c104f0849e7d4b747aac83c3217a Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Fri, 25 Feb 2022 10:30:00 +0800
Subject: [PATCH 45/64] Modified rnnt (#902)

* Add modified mutual_information_recursion

* Add modified rnnt loss

* Using more efficient way to fix boundaries

* Fix modified pruned rnnt loss

* Fix the s_begin constraints of the pruned loss for the modified version of
  the transducer
---
 k2/python/csrc/torch/mutual_information.h     |  29 +-
 .../csrc/torch/mutual_information_cpu.cu      |  96 ++++--
 .../csrc/torch/mutual_information_cuda.cu     | 300 ++++++++----------
 k2/python/k2/mutual_information.py            |  19 +-
 k2/python/k2/rnnt_loss.py                     | 278 +++++++++-------
 k2/python/tests/mutual_information_test.py    | 217 +++++++------
 k2/python/tests/rnnt_loss_test.py             | 283 +++++++++--------
 7 files changed, 682 insertions(+), 540 deletions(-)

diff --git a/k2/python/csrc/torch/mutual_information.h b/k2/python/csrc/torch/mutual_information.h
index efcdccaa3..2efbaaa1f 100644
--- a/k2/python/csrc/torch/mutual_information.h
+++ b/k2/python/csrc/torch/mutual_information.h
@@ -33,13 +33,15 @@ namespace k2 {
    in mutual_information.py. This is the core recursion
    in the sequence-to-sequence mutual information computation.
 
-     @param px  Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
-         generating the next x in the sequence, i.e.
-         xy[b][s][t] is the log of
-             p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
-         i.e. the log-prob of generating x_s given subsequences of
-         lengths (s, t), divided by the prior probability of generating
-         x_s.  (See mutual_information.py for more info).
+     @param px  Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
+         modified. `modified` can be worked out from this. In the not-modified
+         case, it can be thought of as the log-odds ratio of generating the
+         next x in the sequence, i.e.
+         xy[b][s][t] is the log of
+             p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
+         i.e. the log-prob of generating x_s given subsequences of
+         lengths (s, t), divided by the prior probability of generating x_s.
+         (See mutual_information.py for more info).
      @param py  The log-odds ratio of generating the next y in the sequence.
          Shape [B][S + 1][T]
      @param p  This function writes to p[b][s][t] the mutual information between
@@ -49,10 +51,13 @@ namespace k2 {
          in the case where s_begin == t_begin == 0:
 
             p[b,0,0] = 0.0
-            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+         if not modified:
+            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
-                       if s > 0 or t > 0,
-                       treating values with any -1 index as -infinity.
+         if modified:
+            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                               p[b,s,t-1] + py[b,s,t-1])
+         ... treating values with any -1 index as -infinity.
        .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
      @param boundary  If set, a tensor of shape [B][4] of type int64_t, which
          contains, where for each batch element b, boundary[b]
@@ -79,8 +84,8 @@ torch::Tensor MutualInformationCpu(
     torch::Tensor p);  // [B][S+1][T+1]; an output

torch::Tensor MutualInformationCuda(
-    torch::Tensor px,  // [B][S][T+1]
-    torch::Tensor py,  // [B][S+1][T]
+    torch::Tensor px,  // [B][S][T+1] if !modified, [B][S][T] if modified.
+    torch::Tensor py,  // [B][S+1][T]
     torch::optional<torch::Tensor> boundary,  // [B][4], int64_t.
torch::Tensor p);  // [B][S+1][T+1]; an output

diff --git a/k2/python/csrc/torch/mutual_information_cpu.cu b/k2/python/csrc/torch/mutual_information_cpu.cu
index bcbe6ed27..7b1b56c6d 100644
--- a/k2/python/csrc/torch/mutual_information_cpu.cu
+++ b/k2/python/csrc/torch/mutual_information_cpu.cu
@@ -23,9 +23,30 @@

 namespace k2 {

-// forward of mutual_information.  See also comment of `mutual_information`
+// forward of mutual_information.  See the """...""" docstring comment of
+// `mutual_information_recursion`
 // in k2/python/k2/mutual_information.py for documentation of the
 // behavior of this function.
+
+// px: of shape [B, S, T+1] if !modified, else [B, S, T]  <-- work out
+// `modified` from this.
+// py: of shape [B, S+1, T]
+// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
+// defaulting to (0, 0, S, T).
+// p: of shape [B, S+1, T+1]
+// Computes the recursion:
+// if !modified:
+//    p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+//                       p[b,s,t-1] + py[b,s,t-1])
+// if modified:
+//    p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+//                       p[b,s,t-1] + py[b,s,t-1])
+
+// .. treating out-of-range elements as -infinity and with special cases:
+// p[b, s_begin, t_begin] = 0.0
+//
+// and this function returns a tensor of shape (B,) consisting of elements
+// p[b, s_end, t_end]
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
                                   torch::optional<torch::Tensor> opt_boundary,
                                   torch::Tensor p) {
@@ -36,10 +57,13 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
          px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
      "inputs must be CPU tensors");

+  bool modified = (px.size(2) == py.size(2));
+
   auto scalar_t = px.scalar_type();
   auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

-  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
+  const int B = px.size(0), S = px.size(1), T = py.size(2);
+  TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

@@ -61,15 +85,22 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
    auto boundary_a = boundary.accessor<int64_t, 2>();
    auto ans_a = ans.accessor<scalar_t, 1>();

+    int t_offset = (modified ? -1 : 0);
    for (int b = 0; b < B; b++) {
      int s_begin = boundary_a[b][0];
      int t_begin = boundary_a[b][1];
      int s_end = boundary_a[b][2];
      int t_end = boundary_a[b][3];
      p_a[b][s_begin][t_begin] = 0.0;
-      for (int s = s_begin + 1; s <= s_end; ++s)
-        p_a[b][s][t_begin] =
-            p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+      if (modified) {
+        for (int s = s_begin + 1; s <= s_end; ++s)
+          p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
+      } else {
+        // note: t_offset = 0 so don't need t_begin + t_offset below.
+        for (int s = s_begin + 1; s <= s_end; ++s)
+          p_a[b][s][t_begin] =
+              p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+      }
      for (int t = t_begin + 1; t <= t_end; ++t)
        p_a[b][s_begin][t] =
            p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
@@ -77,12 +108,13 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
        scalar_t p_s_t1 = p_a[b][s][t_begin];
        for (int t = t_begin + 1; t <= t_end; ++t) {
          // The following statement is a small optimization of:
-          // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
-          //                       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+          //   p_a[b][s][t] = LogAdd(
+          //      p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
+          //      p_a[b][s][t - 1] + py_a[b][s][t - 1]);
          // .. which obtains p_a[b][s][t - 1] from a register.
- p_a[b][s][t] = p_s_t1 = - LogAdd()(p_a[b][s - 1][t] + px_a[b][s - 1][t], - p_s_t1 + py_a[b][s][t - 1]); + p_a[b][s][t] = p_s_t1 = LogAdd()( + p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset], + p_s_t1 + py_a[b][s][t - 1]); } } ans_a[b] = p_a[b][s_end][t_end]; @@ -102,6 +134,8 @@ std::vector MutualInformationBackwardCpu( TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional."); + bool modified = (px.size(2) == py.size(2)); + TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu() && ans_grad.device().is_cpu(), "inputs must be CPU tensors"); @@ -109,8 +143,9 @@ std::vector MutualInformationBackwardCpu( auto scalar_t = px.scalar_type(); auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); - const int B = px.size(0), S = px.size(1), T = px.size(2) - 1; - TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); + const int B = px.size(0), S = px.size(1), T = py.size(2); + TORCH_CHECK(px.size(2) == (modified ? T : T + 1)); + TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); auto boundary = opt_boundary.value_or( @@ -123,9 +158,10 @@ std::vector MutualInformationBackwardCpu( TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64); bool has_boundary = opt_boundary.has_value(); + int T1 = T + (modified ? 0 : 1); torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts), - px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) - : torch::empty({B, S, T + 1}, opts)), + px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts) + : torch::empty({B, S, T1}, opts)), py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) : torch::empty({B, S + 1, T}, opts)); @@ -138,6 +174,7 @@ std::vector MutualInformationBackwardCpu( auto ans_grad_a = ans_grad.accessor(); auto boundary_a = boundary.accessor(); + int t_offset = (modified ? -1 : 0); for (int b = 0; b < B; b++) { int s_begin = boundary_a[b][0]; @@ -151,10 +188,12 @@ std::vector MutualInformationBackwardCpu( for (int t = t_end; t > t_begin; --t) { // The s,t indexes correspond to // The statement we are backpropagating here is: - // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t], - // p_a[b][s][t - 1] + py_a[b][s][t - 1]); + // p_a[b][s][t] = LogAdd( + // p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset], + // p_a[b][s][t - 1] + py_a[b][s][t - 1]); // .. which obtains p_a[b][s][t - 1] from a register. - scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t], + scalar_t term1 = p_a[b][s - 1][t + t_offset] + + px_a[b][s - 1][t + t_offset], // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not // actually needed.. 
          total = p_a[b][s][t];
@@ -170,8 +209,8 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
          // could happen if total == -inf
          term1_grad = term2_grad = 0.0;
        }
-        px_grad_a[b][s - 1][t] = term1_grad;
-        p_grad_a[b][s - 1][t] = term1_grad;
+        px_grad_a[b][s - 1][t + t_offset] = term1_grad;
+        p_grad_a[b][s - 1][t + t_offset] = term1_grad;
        py_grad_a[b][s][t - 1] = term2_grad;
        p_grad_a[b][s][t - 1] += term2_grad;
      }
@@ -184,14 +223,17 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
        p_grad_a[b][s_begin][t - 1] += this_p_grad;
        py_grad_a[b][s_begin][t - 1] = this_p_grad;
      }
-      for (int s = s_end; s > s_begin; --s) {
-        // Backprop for:
-        // p_a[b][s][t_begin] =
-        //     p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
-        scalar_t this_p_grad = p_grad_a[b][s][t_begin];
-        p_grad_a[b][s - 1][t_begin] += this_p_grad;
-        px_grad_a[b][s - 1][t_begin] = this_p_grad;
-      }
+      if (!modified) {
+        for (int s = s_end; s > s_begin; --s) {
+          // Backprop for:
+          // p_a[b][s][t_begin] =
+          //     p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+          scalar_t this_p_grad = p_grad_a[b][s][t_begin];
+          p_grad_a[b][s - 1][t_begin] += this_p_grad;
+          px_grad_a[b][s - 1][t_begin] = this_p_grad;
+        }
+      }  // else these were all -infinity's and there is nothing to
+         // backprop.
      // There is no backprop for:
      //   p_a[b][s_begin][t_begin] = 0.0;
      // .. but we can use this for a check, that the grad at the beginning
diff --git a/k2/python/csrc/torch/mutual_information_cuda.cu b/k2/python/csrc/torch/mutual_information_cuda.cu
index c858d4d7b..84e60871e 100644
--- a/k2/python/csrc/torch/mutual_information_cuda.cu
+++ b/k2/python/csrc/torch/mutual_information_cuda.cu
@@ -42,13 +42,14 @@ namespace k2 {
      is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading code).

  Args:
-      px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
-          generating the next x in the sequence, i.e.
-          xy[b][s][t] is the log of
-             p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
+      px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified;
+          may be interpreted as the log-odds ratio of
+          generating the next x in the sequence, i.e.
+              xy[b][s][t] is the log of
+                 p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
          i.e. the log-prob of generating x_s given subsequences of lengths
          (s, t), divided by the prior probability of generating x_s. (See
-          mutual_information.py for more info).
+          mutual_information.py for more info).
      py: The log-odds ratio of generating the next y in the sequence.
          Shape [B][S + 1][T]
      p:  This function writes to p[b][s][t] the mutual information between
@@ -58,10 +59,14 @@ namespace k2 {
         in the case where s_begin == t_begin == 0:

               p[b,0,0] = 0.0
+          if not `modified`:
               p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                  p[b,s,t-1] + py[b,s,t-1])    (eq. 0)
-               if s > 0 or t > 0,
-               treating values with any -1 index as -infinity.
+          if `modified`:
+              p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                                 p[b,s,t-1] + py[b,s,t-1])    (eq. 0)
+
+          treating values with any -1 index as -infinity.
        .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
      boundary:  If set, a tensor of shape [B][4] of type int64_t, which
          contains, where for each batch element b, boundary[b] equals
@@ -98,6 +103,9 @@ __global__ void mutual_information_kernel(
  // num_t_blocks = T / BLOCK_SIZE + 1
  // so that each group depends on the previous group...
  const int B = px.size(0), S = px.size(1), T = py.size(2);
+  const bool modified = (px.size(2) == T);
+  const int t_offset = (modified ? -1 : 0);  // see CPU code to understand.
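+  // In other words: wherever the not-modified recursion reads px[b][s-1][t],
+  // the modified recursion reads px[b][s-1][t-1]; adding t_offset to the t
+  // index of each px access lets a single code path handle both cases.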
+
  // num_s_blocks and num_t_blocks are the number of blocks we need to cover the
  // array of size (S, T) with blocks of this size, in the s and t directions
  // respectively.
@@ -121,9 +129,9 @@ __global__ void mutual_information_kernel(
    int num_blocks_this_iter = min(iter + 1, num_s_blocks);

    // For the block with s_block_begin == 0 and t_block_begin == 0 (for
-    // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
-    // for out-of-range indexes into px.
-    // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
+    // easy illustration), px_buf[s][t] will contain px[s - 1][t + t_offset], or
+    // -infinity for out-of-range indexes into px.  Likewise, py_buf[s][t] will
+    // contain py[s][t - 1].
    __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
        py_buf[BLOCK_SIZE][BLOCK_SIZE];

@@ -183,19 +191,23 @@ __global__ void mutual_information_kernel(

    if (block_S <= 0 || block_T <= 0) continue;

-    // Load px_buf and py_buf.  We exponentiate; the assumption is that they
-    // most likely won't overflow or underflow, but if they do overflow we'll
-    // detect it later; we'll also detect certain kinds of underflow.
+    // Load px_buf and py_buf.
    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
-          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin,
+          t_off = t + t_offset;
      // comparing as unsigned int makes sure the index is nonnegative.
      // Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px
      // and py values that are outside the proper boundaries that we need, but
      // the corresponding p_buf values will end up being 0 so this won't
      // matter.
      scalar_t this_px = -INFINITY;
-      if (s > s_begin && s <= s_end && t <= t_end) this_px = px[b][s - 1][t];
+      // Below, "&& t <= t_end" can be interpreted as:
+      //   "&& (modified ? t_off < t_end : t_off <= t_end)"
+      // [since px's last valid index is t_end - 1 if modified, else t_end.]
+      if (s > s_begin && s <= s_end && t_off >= t_begin && t <= t_end)
+        this_px = px[b][s - 1][t_off];
+
      px_buf[s_in_block][t_in_block] = this_px;

      scalar_t this_py = -INFINITY;
@@ -203,12 +215,12 @@ __global__ void mutual_information_kernel(
      py_buf[s_in_block][t_in_block] = this_py;
    }

-    // Load the 1st row and 1st column of p_buf (except element[0][0] is not
-    // needed).  This is the context from previously computed blocks of the
+    // Load the 1st row and 1st column of p_buf.
+    // This is the context from previously computed blocks of the
    // image.  Remember: p_buf[s][t] will correspond to p[s + s_block_begin
    // - 1][t + t_block_begin - 1]
    if (threadIdx.x <= BLOCK_SIZE) {
-      // s_in_p_buf are simply the indexes into p_buf
+      // s_in_p_buf and t_in_p_buf are simply the indexes into p_buf
      int s_in_p_buf = threadIdx.x, t_in_p_buf = 0,
          s = s_in_p_buf + s_block_begin - 1,
          t = t_in_p_buf + t_block_begin - 1;
@@ -216,10 +228,6 @@ __global__ void mutual_information_kernel(
      scalar_t this_p = -INFINITY;
      if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
        this_p = p[b][s][t];
-      /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s,
-        t, (float)this_p, (int)threadIdx.x,
-        (float)px_buf[s_in_p_buf][t_in_p_buf],
-        (float)py_buf[s_in_p_buf][t_in_p_buf]); */
      p_buf[s_in_p_buf][t_in_p_buf] = this_p;
    } else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <=
               static_cast<unsigned int>(BLOCK_SIZE)) {
@@ -232,10 +240,6 @@ __global__ void mutual_information_kernel(
      scalar_t this_p = -INFINITY;
      if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
        this_p = p[b][s][t];
-      /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s,
-        t, (float)this_p, (int)threadIdx.x,
-        (float)px_buf[s_in_p_buf][t_in_p_buf],
-        (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
      p_buf[s_in_p_buf][t_in_p_buf] = this_p;
    }

@@ -253,18 +257,10 @@ __global__ void mutual_information_kernel(
      // probability of the pair of sequences of length (0, 0).
      p_buf[1][1] = (is_origin_block
                         ? 0.0
-                         : LogAdd<scalar_t>()(p_buf[0][1] + px_buf[0][0],
-                                              p_buf[1][0] + py_buf[0][0]));
-    }
-
-    scalar_t p_buf_s1_t;  // This is for an optimization to avoid one
-                          // shared-memory read/write in the loop below.  it
-                          // represents p_buf[s + 1][t]; the first time we
-                          // access this, it will be for t == 0, except for
-                          // thread 0 when we first need it for t == 1.
-    if (threadIdx.x < BLOCK_SIZE) {
-      int s = threadIdx.x;
-      p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
+                         : LogAdd<scalar_t>()(
+                               // px_buf has t_offset applied.
+                               p_buf[0][1 + t_offset] + px_buf[0][0],
+                               p_buf[1][0] + py_buf[0][0]));
    }

    int s = threadIdx.x;
@@ -299,27 +295,11 @@ __global__ void mutual_information_kernel(
        // the same as the recursion defined for p in
        // mutual_information.py:mutual_information_recursion(); and (eq. 0)
        // above.
-#if 0
-        p_buf[s + 1][t + 1] = LogAdd<scalar_t>()(
-            p_buf[s][t + 1] + px_buf[s][t], p_buf[s + 1][t] + py_buf[s][t]);
-
-        /*printf("threadIdx.x = %d, i = %d, s = %d, t = %d, p_buf[s+1][t+1] =
-          %f, p_buf[s][t+1] = %f, " "px_buf[s][t] = %f, p_buf[s + 1][t] = %f,
-          py_buf[s][t] = %f\n", (int)threadIdx.x, i, s, t,
-          (float)p_buf[s+1][t+1], (float)p_buf[s][t+1], (float)px_buf[s][t],
-          (float)p_buf[s+1][t], (float)py_buf[s][t]);*/
-#else
-        // This is an optimization of the statement above (the other half of
-        // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
-        // the need for a load from shared memory.
-        p_buf_s1_t = LogAdd<scalar_t>()(p_buf[s][t + 1] + px_buf[s][t],
-                                        p_buf_s1_t + py_buf[s][t]);
-        // The next time this thread reads p_buf_s1_t, t will be one greater,
-        // so p_buf_s1_t will contain p_buf[s + 1][t].  The first time this
-        // thread uses p_buf_s1_t is when t == 0, except for thread 0 where
-        // the 1st item accessed is for s == 0, t == 1.
-        p_buf[s + 1][t + 1] = p_buf_s1_t;
-#endif
+
+        // note: px_buf has t_offset applied..
+        p_buf[s + 1][t + 1] =
+            LogAdd<scalar_t>()(p_buf[s][t + 1 + t_offset] + px_buf[s][t],
+                               p_buf[s + 1][t] + py_buf[s][t]);

        // We don't need to do __syncthreads() in this loop because all the
        // threads that are active are in the same warp.  (However, in future,
        // if NVidia changes some things, we might need to sync here).
        __syncwarp();
      }
    }
    __syncthreads();

-    // Write out the data to p; check that nothing has gone out of numerical
-    // range, and write 'panic' flag if it has.
+    // Write out the data to p.
    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
@@ -355,90 +334,64 @@ __global__ void mutual_information_kernel(
  }
}

+// like exp(), but returns 0 if arg is inf/nan, or if result would be
+// infinity or nan (note: this can happen for out-of-range elements
+// when setting px_buf and py_buf if block_S != BLOCK_SIZE or
+// block_T != BLOCK_SIZE; and it's a problem because even though
+// out-of-range gradients are zero, if we multiply them by infinity
+// we get NaN).
+template <typename Real>
+__forceinline__ __device__ Real safe_exp(Real x) {
+  if (x - x != 0)
+    return 0;
+  else {
+    Real ans = exp(x);
+    if (ans - ans != 0.0) return 0;
+    return ans;
+  }
+}
+
/*
  Backward of mutual_information.

-  If we were to write the forward pass in non-log space, it would be (ignoring
-  edge cases), as follows... we'll prefix all the variable names with e, e.g.
-ep, to clarify that it's the exp of the actual argument p:
-
-     ep[b][s][t] = ep[b][s - 1][t] * epx[b][s - 1][t] +
-                   ep[b][s][t - 1] * epy[b][s][t - 1].     (eq. 1)
-
-(A)
-  First we consider the part of the backprop that requires recursion or
-iteration, i.e. the part involving only gradients of ep.
-    This is: ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
-             ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].
-
-  .. and if we add 1 to the s index of the first equation above and 1 to the
-  t index of the second equation, we can see that:
+  The forward pass is:

-     ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
-                        ep_grad[b][s][t + 1] * epy[b][s][t].
-
-   Now, if ep = exp(p), and y is the loss function we are backprop'ing,
-   then ep_grad == dy/dep == dy/dp
-   dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
-      == dy/dp / ep. == p_grad / ep.
-   I.e. ep_grad = p_grad / ep.
-
-   So we can write the above as:
-    p_grad[b][s][t] / ep[b][s][t]
-     = p_grad[b][s + 1][t] / ep[b][s + 1][t] * epx[b][s][t] +
-       p_grad[b][s][t + 1] / ep[b][s][t + 1] * epy[b][s][t].
-
-   Or, rearranging:
-     p_grad[b][s][t] =
-        p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) +
-        p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]).
-           (eq. 2)
-
-   (B)  The following is the backprop for epx and epy from (eq. 1):
-
-       epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
-       epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]
-
-   .. adding 1 to the s indexes in the 1st equation and to the t indexes in the
-2nd:
+     p[b,s,t] = log_add(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset],
+                        p[b,s,t-1] + py[b,s,t-1])            (eq. 0)

-     epx_grad[b][s][t] = ep_grad[b][s + 1][t] * ep[b][s][t]
-     epy_grad[b][s][t] = ep_grad[b][s][t + 1] * ep[b][s][t]
+  where t_offset = (modified ? -1 : 0)

-   Using, similar to the above, ep_grad = p_grad / ep, and similarly,
-   epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p)
-for p and so on, the above becomes:
+  The backprop for the above, implemented in the obvious way, would be as
+  follows (note, we define term1 and term2 with offsets in the indexes, which
+  will be convenient later..):

-     px_grad[b][s][t] / exp(px[b][s][t]) =
-        p_grad[b][s + 1][t] / exp(p[b][s + 1][t])  * exp(p[b][s][t])
-     py_grad[b][s][t] / exp(py[b][s][t]) =
-        p_grad[b][s][t + 1] / exp(p[b][s][t + 1])  * exp(p[b][s][t])
-   Rearranging:
-     px_grad[b][s][t] =
-        p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
-          (eq. 3a)
-     py_grad[b][s][t] =
-        p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
-          (eq. 3b)
+     term1(b,s-1,t+t_offset) =
+        exp(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset] - p[b,s,t])       (0a)
+     term2(b,s,t-1) = exp(p[b,s,t-1] + py[b,s,t-1] - p[b,s,t])           (0b)

+     p_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset)  (1a)
+     px_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1b)
+     p_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1)                    (1c)
+     py_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1)                   (1d)

-   Defining terms that are common to (eq. 2) and (eqs. 3a,3b), write:
+  Adding 1 and -t_offset to the s and t indexes of (1a) and (1b), and
+  1 to the t index of (1c) and (1d), the equations become:

-     xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])  (eq. 4)
-     yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])  (eq. 5)
+     p_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t)   (2a)
+     px_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t)  (2b)
+     p_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t)            (2c)
+     py_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t)           (2d)

-   .. and note that these quantities are <= 1 so there is no problem doing
-   the exponentiation.  So the recursion can be simplified as from eqs. (2, 3a,
-3b), as:
+  .. and replacing "+=" with "=", we can write:

-     p_grad[b][s][t] =  p_grad[b][s + 1][t] * xderiv[b][s][t] +
-                        p_grad[b][s][t + 1] * yderiv[b][s][t]   (eq. 6)
-     px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]   (eq. 7)
-     py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]   (eq. 8)
+     p_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) +    (3a)
+                     p_grad[b,s,t+1] * term2(b,s,t)
+     px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t)     (3b)
+     py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t)              (3c)

-   (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but
-it's not clear to me that this is the best strategy since that would require an
-extra write to shared memory within the loop that's the limiting factor.)
+  Writing the definitions of term1 and term2 in a more convenient way:
+
+     term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset])   (4a)
+     term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1])            (4b)

  The backward pass will be slightly different from the forward pass in terms
  of how we store and index p (and p_grad), because for writing a particular
  block of p_grad we need context on the top and right instead of the bottom
  and left.
 */
template <typename scalar_t>
__global__ void mutual_information_backward_kernel(
-    // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
-    torch::PackedTensorAccessor32<scalar_t, 3> px,
+    torch::PackedTensorAccessor32<scalar_t, 3>
+        px,  // B, S, T + 1 if !modified; B, S, T if modified.
    torch::PackedTensorAccessor32<scalar_t, 3> py,  // B, S + 1, T.
    // B, S + 1, T + 1.  Produced in forward pass.
torch::PackedTensorAccessor32<scalar_t, 3> p,
    // [B].  This is an input.
    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,
-    // B, S + 1, T + 1.  This is a temporary.
-    torch::PackedTensorAccessor32<scalar_t, 3> p_grad,
+    torch::PackedTensorAccessor32<scalar_t, 3>
+        p_grad,  // B, S + 1, T + 1.  This is a temporary.
    torch::PackedTensorAccessor32<scalar_t, 3>
        px_grad,  // B, S, T + 1 if !modified; B, S, T if modified.
    torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
    // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
@@ -471,6 +424,8 @@ __global__ void mutual_information_backward_kernel(
  //   identical or very close to the value of
  //   ans_grad that was passed in.
  const int B = px.size(0), S = px.size(1), T = py.size(2);
+  const bool modified = (px.size(2) == T);
+  const int neg_t_offset = (modified ? 1 : 0);

  // For statements that are the same as the forward pass, we are omitting some
  // comments.  We'll focus, in the comments, on differences from the forward
@@ -487,17 +442,17 @@ __global__ void mutual_information_backward_kernel(
  //    px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
  //    py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
  //  Later (see eqs. 4a and 4b):
-  //    px_buf[s][t] contains
-  //        exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
-  //    py_buf[s][t] contains
-  //        exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]
+  //    px_buf[s][t] contains term1(b,ss,tt) ==
+  //       exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt-t_offset]),
+  //    py_buf[s][t] contains term2(b,ss,tt) ==
+  //       exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]),
  //  where ss == s + s_block_begin, tt = t + t_block_begin.
  // Unlike in the forward code, there is no offset of 1 in the indexes.
  __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
      py_buf[BLOCK_SIZE][BLOCK_SIZE];

  // p_buf is initially used to store p, and then (after we are done putting
-  // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
+  // term1 and term2 into px_buf and py_buf) it is repurposed to store
  // p_grad.
  //
  // Unlike in the forward pass, p_buf has the same numbering as px_buf and
@@ -588,7 +543,7 @@ __global__ void mutual_information_backward_kernel(
  }
  __syncthreads();

-  // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
+  // Set term1 and term2; see equations (4a) and (4b) above.
  for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
    // We can apply this formula to the entire block even if we are processing
    // a partial block; we have ensured that px_buf and py_buf contain
    // -infinity in out-of-range positions, which will lead to px_buf and
    // py_buf containing 0 after applying the following formulas.
    int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
    // Mathematically the following is doing:
-    //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
+    //   term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset])   (4a)
    // (with an offset on the s and t indexes)
-    px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+    // Use safe_exp() not exp(), as we could have (-inf) - (-inf) = nan, want
+    // any finite number in this case as derivs would be zero.
+    // Also want -inf->zero.
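+    // A sketch of the intended behaviour of safe_exp() (see its definition
+    // above):
+    //   safe_exp(-INFINITY) == 0   (x - x is nan, so the early return fires)
+    //   safe_exp(NAN)       == 0   (likewise)
+    //   safe_exp(800.0)     == 0   (exp() overflows to inf; inf - inf is nan)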
+ px_buf[s][t] = + safe_exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t + neg_t_offset]); // Mathematically the following is doing: - // yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) + // term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1]) (4b) // (with an offset on the s and t indexes) - py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]); + py_buf[s][t] = safe_exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]); } __syncthreads(); // Load p_grad for the top and right elements in p_buf: i.e. for elements - // p_buf[s][t] where s == block_S (exclusive-or) t == block_T. We don't - // need to load the top-right corner [block_S][block_T]; that location will - // never be accessed. + // p_buf[s][t] where s == block_S (exclusive-or) t == block_T. // These are the p_grad values computed by previous instances of this kernel // If this is one of the top or right blocks, some or all of the p_grad // values we'd be reading here will be out of range, and we use zeros // to ensure no gradient gets propagated from those positions. - if (threadIdx.x < block_S) { + if (threadIdx.x <= block_S) { int s_in_block = threadIdx.x, t_in_block = block_T, s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; p_buf[s_in_block][t_in_block] = @@ -660,11 +617,12 @@ __global__ void mutual_information_backward_kernel( static_cast(t) < static_cast(block_T)) { // The following statement is really operating on the gradients; // it corresponds, with offsets of s_block_begin and t_block_begin - // on the indexes, to (eq. 6) defined above, i.e.: - // p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] + - // p_grad[b][s][t + 1] * yderiv[b][s][t] - p_buf[s][t] = - (p_buf[s + 1][t] * px_buf[s][t] + p_buf[s][t + 1] * py_buf[s][t]); + // on the indexes, to equation (3a) above, i.e.: + // p_grad[b,s,t] = + // p_grad[b,s+1,t-t_offset] * term1(b,s,t) + (3a) + // p_grad[b,s,t+1] * term2(b,s,t) + p_buf[s][t] = (p_buf[s + 1][t + neg_t_offset] * px_buf[s][t] + + p_buf[s][t + 1] * py_buf[s][t]); } } } @@ -680,15 +638,20 @@ __global__ void mutual_information_backward_kernel( if (t <= t_end && s <= s_end) { p_grad[b][s][t] = p_buf[s_in_block][t_in_block]; - if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1] - // From (eq. 7): - // px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] - px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] * + if (s < s_end && t <= t_end - neg_t_offset) { + // write px_grad, which is of shape [B][S][T + 1] if !modified, + // [B][S][T] if modified. the condition "t <= t_end - neg_t_offset" + // becomes "t <= t_end" if !modified, and "t <= t_end - 1" if + // modified, keeping us within the bounds of px_grad. + + // From (eq. 3b): + // px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) + px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block + neg_t_offset] * px_buf[s_in_block][t_in_block]); } if (t < t_end) { // write py_grad, which is of shape [B][S + 1][T] - // from (eq. 8): - // py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] + // from (eq. 
3c): + // py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t) py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] * py_buf[s_in_block][t_in_block]); } @@ -717,7 +680,8 @@ torch::Tensor MutualInformationCuda(torch::Tensor px, torch::Tensor py, auto scalar_t = px.scalar_type(); auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); - const int B = px.size(0), S = px.size(1), T = px.size(2) - 1; + const int B = px.size(0), S = px.size(1), T = py.size(2); + TORCH_CHECK(px.size(2) == T || px.size(2) == T + 1); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); @@ -777,9 +741,12 @@ std::vector MutualInformationBackwardCuda( auto scalar_t = px.scalar_type(); auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); - const int B = px.size(0), S = px.size(1), T = px.size(2) - 1; + const int B = px.size(0), S = px.size(1), T = py.size(2); - TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); + TORCH_CHECK(px.size(2) == T || + px.size(2) == T + 1); // modified case || not-modified case + const bool modified = (px.size(2) == T); + TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); auto boundary = opt_boundary.value_or( @@ -793,9 +760,10 @@ std::vector MutualInformationBackwardCuda( bool has_boundary = opt_boundary.has_value(); + int T1 = T + (modified ? 0 : 1); torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts), - px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) - : torch::empty({B, S, T + 1}, opts)), + px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts) + : torch::empty({B, S, T1}, opts)), py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) : torch::empty({B, S + 1, T}, opts)); diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 6a61d8f1f..7e94105e2 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -31,7 +31,8 @@ def forward( return_grad: bool = False, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: (B, S, T1) = px.shape - T = T1 - 1 + T = py.shape[-1] + assert T1 in [T, T + 1] assert py.shape == (B, S + 1, T) if boundary is not None: assert boundary.shape == (B, 4) @@ -155,9 +156,12 @@ def mutual_information_recursion( of lengths ``s`` and ``t``:: p[b,0,0] = 0.0 + if !modified: p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) - (if s > 0 or t > 0) + if modified: + p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], + p[b,s,t-1] + py[b,s,t-1]) where we handle edge cases by treating quantities with negative indexes as **-infinity**. The extension to cases where the boundaries are @@ -166,10 +170,10 @@ def mutual_information_recursion( """ assert px.ndim == 3 B, S, T1 = px.shape - T = T1 - 1 + T = py.shape[-1] + assert px.shape[-1] in [T, T + 1] # if T, then "modified". assert py.shape == (B, S + 1, T) assert px.dtype == py.dtype - (B, S, T) = px.shape if boundary is not None: assert boundary.dtype == torch.int64 assert boundary.shape == (B, 4) @@ -250,11 +254,14 @@ def joint_mutual_information_recursion( N = len(px) assert len(py) == N and N > 0 B, S, T1 = px[0].shape - T = T1 - 1 + T = py[0].shape[2] + assert T1 in [T, T + 1] # T if modified... 
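+    # For example (a shape illustration only): with B == 2, S == 3, T == 5,
+    # each px[i] is (2, 3, 6) in the standard case, or (2, 3, 5) if
+    # modified; each py[i] is (2, 4, 5) in both cases.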
    assert py[0].shape == (B, S + 1, T)
    assert px[0].dtype == py[0].dtype

-    px_cat = torch.stack(px, dim=0)  # (N, B, S, T+1)
+    px_cat = torch.stack(
+        px, dim=0
+    )  # (N, B, S, T+1) if !modified, (N, B, S, T) if modified.
    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)
diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
index 6805ae2a4..1e68668df 100644
--- a/k2/python/k2/rnnt_loss.py
+++ b/k2/python/k2/rnnt_loss.py
@@ -26,14 +26,16 @@ def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
    """
    Insert -inf's into `px` in appropriate places if `boundary` is not
-    None.  If boundary == None, px[:,:,-1] will be -infinity,
-    but if boundary is specified, we need px[b,:,boundary[b,3]]
+    None.  If boundary == None and modified == False, px[:,:,-1] will
+    be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
    to be -infinity.
+
    Args:
-      px: a Tensor of of shape [B][S][T+1], px is modified in-place
-        and returned.
-      boundary: None, or a Tensor of shape [B][3] containing
-        [s_begin, t_begin, s_end, t_end]; we need only t_end.
+      px: a Tensor of shape [B][S][T+1] (this function is only
+        called if modified == False, see other docs for `modified`)
+        px is modified in-place and returned.
+      boundary: None, or a Tensor of shape [B][4] containing
+        [s_begin, t_begin, s_end, t_end]; we need only t_end.
    """
    if boundary is None:
        return px
@@ -48,6 +50,7 @@ def get_rnnt_logprobs(
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
+    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Reduces RNN-T problem (the simple case, where joiner network is just
@@ -89,34 +92,39 @@ def get_rnnt_logprobs(
      termination_symbol:
        The identity of the termination symbol, must be in {0..C-1}
      boundary:
-        a LongTensor of shape [B, 4] with elements interpreted as
+        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
+      modified: if True, each time a real symbol is consumed a frame will
+        also be consumed, so at most 1 symbol can appear per frame.

    Returns:
-        (px, py) (the names are quite arbitrary)::
-
-           px: logprobs, of shape [B][S][T+1]
+        (px, py) (the names are quite arbitrary).
+          px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
          py: logprobs, of shape [B][S+1][T]

        in the recursion::

          p[b,0,0] = 0.0
-          p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
-                             p[b,s,t-1] + py[b,s,t-1])
-
-        where p[b][s][t] is the "joint score" of the pair of subsequences
-        of length s and t respectively. px[b][s][t] represents the
-        probability of extending the subsequences of length (s,t) by one in
-        the s direction, given the particular symbol, and py[b][s][t]
-        represents the probability of extending the subsequences of length
-        (s,t) by one in the t direction,
-        i.e. of emitting the termination/next-frame symbol.
-
-        px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
-        we cannot emit any symbols. This is simply a way of incorporating
-        the probability of the termination symbol on the last frame.
+          if !modified:
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])
+          if modified:
+             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                                p[b,s,t-1] + py[b,s,t-1])
+        .. where p[b][s][t] is the "joint score" of the pair of subsequences
+          of length s and t respectively. px[b][s][t] represents the
+          probability of extending the subsequences of length (s,t) by one in
+          the s direction, given the particular symbol, and py[b][s][t]
+          represents the probability of extending the subsequences of length
+          (s,t) by one in the t direction,
+          i.e. of emitting the termination/next-frame symbol.
+
+          if !modified, px[:,:,T] equals -infinity, meaning on the
+          "one-past-the-last" frame we cannot emit any symbols.
+          This is simply a way of incorporating
+          the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3
    assert am.ndim == 3
@@ -152,15 +160,19 @@ def get_rnnt_logprobs(
        -1
    )  # [B][S][T]

-    px_am = torch.cat(
-        (
-            px_am,
-            torch.full(
-                (B, S, 1), float("-inf"), device=px_am.device, dtype=px_am.dtype
+    if not modified:
+        px_am = torch.cat(
+            (
+                px_am,
+                torch.full(
+                    (B, S, 1),
+                    float("-inf"),
+                    device=px_am.device,
+                    dtype=px_am.dtype,
+                ),
            ),
-        ),
-        dim=2,
-    )  # now: [B][S][T+1], index [:,:,T] has -inf..
+            dim=2,
+        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(
        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
@@ -175,7 +187,9 @@ def get_rnnt_logprobs(
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

-    px = fix_for_boundary(px, boundary)
+    if not modified:
+        px = fix_for_boundary(px, boundary)
+
    return (px, py)


@@ -185,6 +199,7 @@ def rnnt_loss_simple(
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
+    modified: bool = False,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
@@ -204,11 +219,13 @@ def rnnt_loss_simple(
      termination_symbol:
        the termination symbol, with 0 <= termination_symbol < C
      boundary:
-        a LongTensor of shape [B, 4] with elements interpreted as
+        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
+      modified: if True, each time a real symbol is consumed a frame will
+        also be consumed, so at most 1 symbol can appear per frame.
      reduction:
        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
        `none`: no reduction will be applied.
@@ -236,6 +253,7 @@ def rnnt_loss_simple(
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
+        modified=modified,
    )
    scores_and_grads = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
@@ -259,6 +277,7 @@ def get_rnnt_logprobs_joint(
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
+    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Reduces RNN-T problem to a compact, standard form that can then be given
    (with boundaries) to mutual_information_recursion().
@@ -274,11 +293,13 @@ def get_rnnt_logprobs_joint(
      termination_symbol:
        The identity of the termination symbol, must be in {0..C-1}
      boundary:
-        a LongTensor of shape [B, 4] with elements interpreted as
+        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
+      modified: if True, each time a real symbol is consumed a frame will
+        also be consumed, so at most 1 symbol can appear per frame.
    Returns:
      (px, py) (the names are quite arbitrary)::

         px: logprobs, of shape [B][S][T+1]
         py: logprobs, of shape [B][S+1][T]

      in the recursion::

         p[b,0,0] = 0.0
-         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
-                            p[b,s,t-1] + py[b,s,t-1])
-
-      where p[b][s][t] is the "joint score" of the pair of subsequences of
+         if !modified:
+            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                               p[b,s,t-1] + py[b,s,t-1])
+         if modified:
+            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                               p[b,s,t-1] + py[b,s,t-1])
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
      length s and t respectively. px[b][s][t] represents the probability of
      extending the subsequences of length (s,t) by one in the s direction,
      given the particular symbol, and py[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the t direction,
      i.e. of emitting the termination/next-frame symbol.

-      px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
-      we cannot emit any symbols. This is simply a way of incorporating
+      if !modified, px[:,:,T] equals -infinity, meaning on the
+      "one-past-the-last" frame we cannot emit any symbols.
+      This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
    """
    assert logits.ndim == 4
@@ -314,15 +339,17 @@ def get_rnnt_logprobs_joint(
        logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
    ).squeeze(-1)
    px = px.permute((0, 2, 1))
-    px = torch.cat(
-        (
-            px,
-            torch.full(
-                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+
+    if not modified:
+        px = torch.cat(
+            (
+                px,
+                torch.full(
+                    (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+                ),
            ),
-        ),
-        dim=2,
-    )  # now: [B][S][T+1], index [:,:,T] has -inf..
+            dim=2,
+        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px[:, :, :T] -= normalizers[:, :S, :]

@@ -333,7 +360,9 @@ def get_rnnt_logprobs_joint(
    px = px.contiguous()
    py = py.contiguous()

-    px = fix_for_boundary(px, boundary)
+    if not modified:
+        px = fix_for_boundary(px, boundary)
+
    return (px, py)


@@ -342,6 +371,7 @@ def rnnt_loss(
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
+    modified: bool = False,
    reduction: Optional[str] = "mean",
) -> Tensor:
    """A normal RNN-T loss, which uses a 'joiner' network output as input,
@@ -357,10 +387,12 @@ def rnnt_loss(
      termination_symbol:
        the termination symbol, with 0 <= termination_symbol < C
      boundary:
-        a LongTensor of shape [B, 4] with elements interpreted as
+        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
+      modified: if True, each time a real symbol is consumed a frame will
+        also be consumed, so at most 1 symbol can appear per frame.
      reduction:
        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
        `none`: no reduction will be applied.
@@ -378,6 +410,7 @@ def rnnt_loss(
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
+        modified=modified,
    )
    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
@@ -494,27 +527,27 @@ def get_rnnt_prune_ranges(
        (B, T, s_range).
""" (B, S, T1) = px_grad.shape - T = T1 - 1 + T = py_grad.shape[-1] + assert T1 in [T, T + 1] assert py_grad.shape == (B, S + 1, T) assert boundary.shape == (B, 4) assert s_range >= 1 if s_range > S: s_range = S - px_pad = torch.zeros( - (B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device - ) + px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device) py_pad = torch.zeros( (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device ) - tot_grad = torch.cat((px_grad, px_pad), dim=1) + torch.cat( - (py_grad, py_pad), dim=2 - ) # (B, S + 1, T + 1) + py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2) + tot_grad = ( + torch.cat((px_grad, px_pad), dim=1) + py_grad_padded + ) # (B, S + 1, T1) tot_grad = torch.cat( ( torch.zeros( - (B, 1, T + 1), dtype=tot_grad.dtype, device=tot_grad.device + (B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device ), tot_grad, ), @@ -539,9 +572,12 @@ def get_rnnt_prune_ranges( s_begin = torch.where(mask, s_begin, s_begin_padding) - # adjusting lower bound to make it satisfied constrains, see docs in + # adjusting lower bound to make it satisfied some constrains, see docs in # `adjust_pruning_lower_bound` for more details of these constrains. - s_begin = _adjust_pruning_lower_bound(s_begin, s_range) + # T1 == T here means we are using the modified version of transducer, + # the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because + # it only emits one symbol per frame. + s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range) ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange( s_range, device=px_grad.device ) @@ -630,6 +666,7 @@ def get_rnnt_logprobs_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor, + modified: bool = False, ) -> Tuple[Tensor, Tensor]: """Construct px, py for mutual_information_recursion with pruned output. @@ -644,14 +681,16 @@ def get_rnnt_logprobs_pruned( termination_symbol: the termination symbol, with 0 <= termination_symbol < C boundary: - a LongTensor of shape [B, 4] with elements interpreted as + a optional LongTensor of shape [B, 4] with elements interpreted as [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. Returns: - Return the px (B, S, T + 1) and py (B, S + 1, T) needed by - mutual_information_recursion. + Return the px (B, S, T) if modified else (B, S, T + 1) and + py (B, S + 1, T) needed by mutual_information_recursion. 
""" # logits (B, T, s_range, C) # symbols (B, S) @@ -676,7 +715,7 @@ def get_rnnt_logprobs_pruned( ) # (B, T, s_range) - pruning_symbols = torch.gather( + pruned_symbols = torch.gather( symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)), dim=2, index=ranges, @@ -684,7 +723,7 @@ def get_rnnt_logprobs_pruned( # (B, T, s_range) px = torch.gather( - logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1) + logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1) ).squeeze(-1) px = px - normalizers @@ -706,15 +745,17 @@ def get_rnnt_logprobs_pruned( px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S] px = px.permute((0, 2, 1)) - px = torch.cat( - ( - px, - torch.full( - (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + + if not modified: + px = torch.cat( + ( + px, + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), ), - ), - dim=2, - ) # now: [B][S][T+1], index [:,:,T] has -inf.. + dim=2, + ) # now: [B][S][T+1], index [:,:,T] has -inf.. py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range) py = py - normalizers @@ -741,7 +782,9 @@ def get_rnnt_logprobs_pruned( px = px.contiguous() py = py.contiguous() - px = fix_for_boundary(px, boundary) + if not modified: + px = fix_for_boundary(px, boundary) + return (px, py) @@ -751,6 +794,7 @@ def rnnt_loss_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor = None, + modified: bool = False, reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output @@ -773,6 +817,8 @@ def rnnt_loss_pruned( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -790,6 +836,7 @@ def rnnt_loss_pruned( ranges=ranges, termination_symbol=termination_symbol, boundary=boundary, + modified=modified, ) negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": @@ -812,8 +859,10 @@ def get_rnnt_logprobs_smoothed( lm_only_scale: float = 0.1, am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, + modified: bool = False, ) -> Tuple[Tensor, Tensor]: - """Reduces RNN-T problem (the simple case, where joiner network is just + """ + Reduces RNN-T problem (the simple case, where joiner network is just addition), to a compact, standard form that can then be given (with boundaries) to mutual_information_recursion(). This version allows you to make the loss-function one of the form:: @@ -866,34 +915,38 @@ def get_rnnt_logprobs_smoothed( the scale on the "AM-only" part of the loss, for which we use an "averaged" LM (averaged over all histories, so effectively unigram). boundary: - a LongTensor of shape [B, 4] with elements interpreted as + a optional LongTensor of shape [B, 4] with elements interpreted as [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. 
Returns: - (px, py) (the names are quite arbitrary):: - - px: logprobs, of shape [B][S][T+1] + (px, py) (the names are quite arbitrary). + px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified. py: logprobs, of shape [B][S+1][T] in the recursion:: p[b,0,0] = 0.0 - p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], - p[b,s,t-1] + py[b,s,t-1]) - - where p[b][s][t] is the "joint score" of the pair of subsequences - of length s and t respectively. px[b][s][t] represents the - probability of extending the subsequences of length (s,t) by one in - the s direction, given the particular symbol, and py[b][s][t] - represents the probability of extending the subsequences of length - (s,t) by one in the t direction, - i.e. of emitting the termination/next-frame symbol. - - px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame - we cannot emit any symbols. This is simply a way of incorporating - the probability of the termination symbol on the last frame. + if !modified: + p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], + p[b,s,t-1] + py[b,s,t-1]) + if modified: + p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], + p[b,s,t-1] + py[b,s,t-1]) + .. where p[b][s][t] is the "joint score" of the pair of subsequences + of length s and t respectively. px[b][s][t] represents the + probability of extending the subsequences of length (s,t) by one in + the s direction, given the particular symbol, and py[b][s][t] + represents the probability of extending the subsequences of length + (s,t) by one in the t direction, + i.e. of emitting the termination/next-frame symbol. + + px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame + we cannot emit any symbols. This is simply a way of incorporating + the probability of the termination symbol on the last frame. """ assert lm.ndim == 3 assert am.ndim == 3 @@ -952,15 +1005,20 @@ def get_rnnt_logprobs_smoothed( ).squeeze( -1 ) # [B][S][T] - px_am = torch.cat( - ( - px_am, - torch.full( - (B, S, 1), float("-inf"), device=px_am.device, dtype=px_am.dtype + + if not modified: + px_am = torch.cat( + ( + px_am, + torch.full( + (B, S, 1), + float("-inf"), + device=px_am.device, + dtype=px_am.dtype, + ), ), - ), - dim=2, - ) # now: [B][S][T+1], index [:,:,T] has -inf.. + dim=2, + ) # now: [B][S][T+1], index [:,:,T] has -inf.. px_lm = torch.gather( lm[:, :S], dim=2, index=symbols.unsqueeze(-1) @@ -969,10 +1027,12 @@ def get_rnnt_logprobs_smoothed( unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1) ) # [B][S][1] - px = px_am + px_lm # [B][S][T+1], last slice indexed [:,:,T] is -inf - px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] + px = px_am + px_lm # [B][S][T+1] if not modified, [B][S][T] if modified + px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T] - px_amonly = px_am + px_lm_unigram # [B][S][T+1] + px_amonly = ( + px_am + px_lm_unigram + ) # [B][S][T+1] if !modified; [B][S][T] if modified. 
px_amonly[:, :, :T] -= amonly_normalizers px_lmonly = px_lm - lmonly_normalizers[:, :S, :] @@ -1005,7 +1065,9 @@ def get_rnnt_logprobs_smoothed( + py_amonly * am_only_scale ) - px_interp = fix_for_boundary(px_interp, boundary) + if not modified: + px_interp = fix_for_boundary(px_interp, boundary) + return (px_interp, py_interp) @@ -1017,6 +1079,7 @@ def rnnt_loss_smoothed( lm_only_scale: float = 0.1, am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, + modified: bool = False, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Tensor: @@ -1048,6 +1111,8 @@ def rnnt_loss_smoothed( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -1078,6 +1143,7 @@ def rnnt_loss_smoothed( lm_only_scale=lm_only_scale, am_only_scale=am_only_scale, boundary=boundary, + modified=modified, ) scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad diff --git a/k2/python/tests/mutual_information_test.py b/k2/python/tests/mutual_information_test.py index fc48cd1da..11917f18f 100644 --- a/k2/python/tests/mutual_information_test.py +++ b/k2/python/tests/mutual_information_test.py @@ -31,69 +31,85 @@ # Caution: this will fail occasionally due to cutoffs not being quite large # enough. As long as it passes most of the time, it's OK. class TestMutualInformation(unittest.TestCase): - @classmethod def setUpClass(cls): - cls.devices = [torch.device('cpu')] + cls.devices = [torch.device("cpu")] if torch.cuda.is_available() and k2.with_cuda: - cls.devices.append(torch.device('cuda', 0)) + cls.devices.append(torch.device("cuda", 0)) if torch.cuda.device_count() > 1: torch.cuda.set_device(1) - cls.devices.append(torch.device('cuda', 1)) + cls.devices.append(torch.device("cuda", 1)) cls.dtypes = [torch.float32, torch.float64] def test_mutual_information_basic(self): for _iter in range(100): - (B, S, T) = (random.randint(1, 10), random.randint(1, 16), - random.randint(1, 500)) - random_px = (random.random() < 0.2) - random_py = (random.random() < 0.2) - random_boundary = (random.random() < 0.7) - big_px = (random.random() < 0.2) - big_py = (random.random() < 0.2) + (B, S, T) = ( + random.randint(1, 10), + random.randint(1, 16), + random.randint(1, 500), + ) + random_px = random.random() < 0.2 + random_py = random.random() < 0.2 + random_boundary = random.random() < 0.7 + big_px = random.random() < 0.2 + big_py = random.random() < 0.2 + + modified = random.random() < 0.5 + + if modified and T < S: + T = S + random.randint(0, 30) for dtype in self.dtypes: for device in self.devices: if random_boundary: def get_boundary_row(): - s_begin = random.randint(0, S - 1) - t_begin = random.randint(0, T - 1) - # allow empty sequence - s_end = random.randint(s_begin, S) - # allow empty sequence - t_end = random.randint(t_begin, T) + this_S = random.randint( + 0, S + ) # allow empty sequence + this_T = random.randint( + this_S if modified else 1, T + ) + s_begin = random.randint(0, S - this_S) + t_begin = random.randint(0, T - this_T) + s_end = s_begin + this_S + t_end = t_begin + this_T return [s_begin, t_begin, s_end, t_end] - if device == torch.device('cpu'): + if device == torch.device("cpu"): boundary = torch.tensor( [get_boundary_row() for _ in 
range(B)], dtype=torch.int64, - device=device) + device=device, + ) else: boundary = boundary.to(device) else: # Use default boundary, but either specified directly # or not. if random.random() < 0.5: - boundary = torch.tensor( - [0, 0, S, T], - dtype=torch.int64).unsqueeze(0).expand( - B, 4).to(device) + boundary = ( + torch.tensor([0, 0, S, T], dtype=torch.int64) + .unsqueeze(0) + .expand(B, 4) + .to(device) + ) else: boundary = None - if device == torch.device('cpu'): + if device == torch.device("cpu"): if random_px: # log of an odds ratio - px = torch.randn(B, S, T + 1, - dtype=dtype).to(device) - if S > 1 and not random_boundary: - px[:, :, -1:] = float('-inf') + px = torch.randn( + B, S, T + (0 if modified else 1), dtype=dtype + ).to(device) + if S > 1 and not random_boundary and not modified: + px[:, :, -1:] = float("-inf") else: # log of an odds ratio - px = torch.zeros(B, S, T + 1, - dtype=dtype).to(device) + px = torch.zeros( + B, S, T + (0 if modified else 1), dtype=dtype + ).to(device) # px and py get exponentiated, and then multiplied # together up to 32 times (BLOCK_SIZE in the CUDA code), # so 15 is actually a big number that could lead to @@ -102,12 +118,14 @@ def get_boundary_row(): px += 15.0 if random_py: # log of an odds ratio - py = torch.randn(B, S + 1, T, - dtype=dtype).to(device) + py = torch.randn(B, S + 1, T, dtype=dtype).to( + device + ) else: # log of an odds ratio - py = torch.zeros(B, S + 1, T, - dtype=dtype).to(device) + py = torch.zeros(B, S + 1, T, dtype=dtype).to( + device + ) if big_py: py += 15.0 @@ -119,11 +137,13 @@ def get_boundary_row(): m = k2.mutual_information_recursion(px, py, boundary) - m2 = k2.joint_mutual_information_recursion((px,), (py,), - boundary) + m2 = k2.joint_mutual_information_recursion( + (px,), (py,), boundary + ) m3 = k2.joint_mutual_information_recursion( - (px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary) + (px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary + ) # it is supposed to be identical only after # summing over dim 0, corresponding to the @@ -150,28 +170,39 @@ def get_boundary_row(): expected_px_grad = px.grad expected_py_grad = py.grad expected_m = m - assert torch.allclose(px.grad, - expected_px_grad.to(device), - atol=1.0e-02, - rtol=1.0e-02) - assert torch.allclose(py.grad, - expected_py_grad.to(device), - atol=1.0e-02, - rtol=1.0e-02) - assert torch.allclose(m, - expected_m.to(device), - atol=1.0e-02, - rtol=1.0e-02) + assert torch.allclose( + px.grad, + expected_px_grad.to(device), + atol=1.0e-02, + rtol=1.0e-02, + ) + assert torch.allclose( + py.grad, + expected_py_grad.to(device), + atol=1.0e-02, + rtol=1.0e-02, + ) + assert torch.allclose( + m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02 + ) def test_mutual_information_deriv(self): for _iter in range(100): - (B, S, T) = (random.randint(1, 10), random.randint(1, 200), - random.randint(1, 200)) - random_px = (random.random() < 0.2) - random_py = (random.random() < 0.2) - random_boundary = (random.random() < 0.7) - big_px = (random.random() < 0.2) - big_py = (random.random() < 0.2) + (B, S, T) = ( + random.randint(1, 100), + random.randint(1, 200), + random.randint(1, 200), + ) + random_px = random.random() < 0.2 + random_py = random.random() < 0.2 + random_boundary = random.random() < 0.7 + big_px = random.random() < 0.2 + big_py = random.random() < 0.2 + + modified = random.random() < 0.5 + + if modified and T < S: + T = S + random.randint(0, 30) for dtype in self.dtypes: for device in self.devices: @@ -179,39 +210,45 @@ def 
test_mutual_information_deriv(self): if random_boundary: def get_boundary_row(): - s_begin = random.randint(0, S - 1) - t_begin = random.randint(0, T - 1) - s_end = random.randint(s_begin + 1, S) - t_end = random.randint(t_begin + 1, T) + this_S = random.randint(1, S) + this_T = random.randint( + this_S if modified else 1, T + ) + s_begin = random.randint(0, S - this_S) + t_begin = random.randint(0, T - this_T) + s_end = s_begin + this_S + t_end = t_begin + this_T return [s_begin, t_begin, s_end, t_end] - if device == torch.device('cpu'): + if device == torch.device("cpu"): boundary = torch.tensor( [get_boundary_row() for _ in range(B)], dtype=torch.int64, - device=device) + device=device, + ) else: boundary = boundary.to(device) else: # Use default boundary, but either specified directly # or not. if random.random() < 0.5: - boundary = torch.tensor( - [0, 0, S, T], - dtype=torch.int64).unsqueeze(0).expand( - B, 4).to(device) + boundary = ( + torch.tensor([0, 0, S, T], dtype=torch.int64) + .unsqueeze(0) + .expand(B, 4) + .to(device) + ) else: boundary = None - if device == torch.device('cpu'): + T1 = T + (0 if modified else 1) + if device == torch.device("cpu"): if random_px: # log of an odds ratio - px = torch.randn(B, S, T + 1, - dtype=dtype).to(device) + px = torch.randn(B, S, T1, dtype=dtype).to(device) else: # log of an odds ratio - px = torch.zeros(B, S, T + 1, - dtype=dtype).to(device) + px = torch.zeros(B, S, T1, dtype=dtype).to(device) # px and py get exponentiated, and then multiplied # together up to 32 times (BLOCK_SIZE in the CUDA code), # so 15 is actually a big number that could lead to @@ -220,12 +257,14 @@ def get_boundary_row(): px += 15.0 if random_py: # log of an odds ratio - py = torch.randn(B, S + 1, T, - dtype=dtype).to(device) + py = torch.randn(B, S + 1, T, dtype=dtype).to( + device + ) else: # log of an odds ratio - py = torch.zeros(B, S + 1, T, - dtype=dtype).to(device) + py = torch.zeros(B, S + 1, T, dtype=dtype).to( + device + ) if big_py: py += 15.0 else: @@ -241,30 +280,30 @@ def get_boundary_row(): delta = 1.0e-04 delta_px = delta * torch.randn_like(px) m2 = k2.mutual_information_recursion( - px + delta_px, py, boundary) + px + delta_px, py, boundary + ) delta_m = m2 - m - observed_delta = (delta_m * m_grad).sum().to('cpu') - predicted_delta = (delta_px * px.grad).sum().to('cpu') + observed_delta = (delta_m * m_grad).sum().to("cpu") + predicted_delta = (delta_px * px.grad).sum().to("cpu") atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 - assert torch.allclose(observed_delta, - predicted_delta, - atol=atol, - rtol=rtol) + assert torch.allclose( + observed_delta, predicted_delta, atol=atol, rtol=rtol + ) delta_py = delta * torch.randn_like(py) m2 = k2.mutual_information_recursion( - px, py + delta_py, boundary) + px, py + delta_py, boundary + ) delta_m = m2 - m - observed_delta = (delta_m * m_grad).sum().to('cpu') - predicted_delta = (delta_py * py.grad).sum().to('cpu') + observed_delta = (delta_m * m_grad).sum().to("cpu") + predicted_delta = (delta_py * py.grad).sum().to("cpu") - assert torch.allclose(observed_delta, - predicted_delta, - atol=atol, - rtol=rtol) + assert torch.allclose( + observed_delta, predicted_delta, atol=atol, rtol=rtol + ) if __name__ == "__main__": diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py index d619591a8..e04526767 100644 --- a/k2/python/tests/rnnt_loss_test.py +++ b/k2/python/tests/rnnt_loss_test.py @@ -205,105 +205,114 @@ def 
test_rnnt_loss_random(self):
         boundary_[:, 2] = seq_length
         boundary_[:, 3] = frames
 
-        for device in self.devices:
-
-            # lm: [B][S+1][C]
-            lm = lm_.to(device)
-            # am: [B][T][C]
-            am = am_.to(device)
-            symbols = symbols_.to(device)
-            boundary = boundary_.to(device)
-
-            px, py = k2.get_rnnt_logprobs(
-                lm=lm,
-                am=am,
-                symbols=symbols,
-                termination_symbol=termination_symbol,
-                boundary=boundary,
-            )
-            assert px.shape == (B, S, T + 1)
-            assert py.shape == (B, S + 1, T)
-            assert symbols.shape == (B, S)
-            m = k2.mutual_information_recursion(px=px, py=py, boundary=boundary)
-
-            if device == torch.device("cpu"):
-                expected = -torch.mean(m)
-            assert torch.allclose(-torch.mean(m), expected.to(device))
+        for modified in [True, False]:
+            for device in self.devices:
+                # lm: [B][S+1][C]
+                lm = lm_.to(device)
+                # am: [B][T][C]
+                am = am_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
 
-            m = k2.rnnt_loss_simple(
-                lm=lm,
-                am=am,
-                symbols=symbols,
-                termination_symbol=termination_symbol,
-                boundary=boundary,
-            )
-            assert torch.allclose(m, expected.to(device))
+                px, py = k2.get_rnnt_logprobs(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=termination_symbol,
+                    boundary=boundary,
+                    modified=modified,
+                )
+                # parenthesized so the conditional picks the expected shape;
+                # without the parentheses the assert could never fail when
+                # modified is False
+                assert px.shape == ((B, S, T) if modified else (B, S, T + 1))
+                assert py.shape == (B, S + 1, T)
+                assert symbols.shape == (B, S)
+                m = k2.mutual_information_recursion(
+                    px=px, py=py, boundary=boundary
+                )
 
-            m = k2.rnnt_loss_smoothed(
-                lm=lm,
-                am=am,
-                symbols=symbols,
-                termination_symbol=termination_symbol,
-                lm_only_scale=0.0,
-                am_only_scale=0.0,
-                boundary=boundary,
-            )
-            assert torch.allclose(m, expected.to(device))
+                if device == torch.device("cpu"):
+                    expected = -torch.mean(m)
+                assert torch.allclose(-torch.mean(m), expected.to(device))
 
-            probs = am.unsqueeze(2) + lm.unsqueeze(1)
-            m = k2.rnnt_loss(
-                logits=probs,
-                symbols=symbols,
-                termination_symbol=termination_symbol,
-                boundary=boundary,
-            )
-            assert torch.allclose(m, expected.to(device))
+                m = k2.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=termination_symbol,
+                    boundary=boundary,
+                    modified=modified,
+                )
+                assert torch.allclose(m, expected.to(device))
 
-            # compare with torchaudio rnnt_loss
-            if self.has_torch_rnnt_loss:
-                import torchaudio.functional
+                m = k2.rnnt_loss_smoothed(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=termination_symbol,
+                    lm_only_scale=0.0,
+                    am_only_scale=0.0,
+                    boundary=boundary,
+                    modified=modified,
+                )
+                assert torch.allclose(m, expected.to(device))
 
-                m = torchaudio.functional.rnnt_loss(
+                probs = am.unsqueeze(2) + lm.unsqueeze(1)
+                m = k2.rnnt_loss(
                     logits=probs,
-                    targets=symbols.int(),
-                    logit_lengths=boundary[:, 3].int(),
-                    target_lengths=boundary[:, 2].int(),
-                    blank=termination_symbol,
+                    symbols=symbols,
+                    termination_symbol=termination_symbol,
+                    boundary=boundary,
+                    modified=modified,
                 )
                 assert torch.allclose(m, expected.to(device))
 
-            # should be invariant to adding a constant for any frame.
- lm += torch.randn(B, S + 1, 1, device=device) - am += torch.randn(B, T, 1, device=device) - - m = k2.rnnt_loss_simple( - lm=lm, - am=am, - symbols=symbols, - termination_symbol=termination_symbol, - boundary=boundary, - ) - assert torch.allclose(m, expected.to(device)) + # compare with torchaudio rnnt_loss + if self.has_torch_rnnt_loss and not modified: + import torchaudio.functional + + m = torchaudio.functional.rnnt_loss( + logits=probs, + targets=symbols.int(), + logit_lengths=boundary[:, 3].int(), + target_lengths=boundary[:, 2].int(), + blank=termination_symbol, + ) + assert torch.allclose(m, expected.to(device)) + + # should be invariant to adding a constant for any frame. + lm += torch.randn(B, S + 1, 1, device=device) + am += torch.randn(B, T, 1, device=device) + + m = k2.rnnt_loss_simple( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + modified=modified, + ) + assert torch.allclose(m, expected.to(device)) - probs = am.unsqueeze(2) + lm.unsqueeze(1) - m = k2.rnnt_loss( - logits=probs, - symbols=symbols, - termination_symbol=termination_symbol, - boundary=boundary, - ) - assert torch.allclose(m, expected.to(device)) + probs = am.unsqueeze(2) + lm.unsqueeze(1) + m = k2.rnnt_loss( + logits=probs, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + modified=modified, + ) + assert torch.allclose(m, expected.to(device)) - m = k2.rnnt_loss_smoothed( - lm=lm, - am=am, - symbols=symbols, - termination_symbol=termination_symbol, - lm_only_scale=0.0, - am_only_scale=0.0, - boundary=boundary, - ) - assert torch.allclose(m, expected.to(device)) + m = k2.rnnt_loss_smoothed( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=termination_symbol, + lm_only_scale=0.0, + am_only_scale=0.0, + boundary=boundary, + modified=modified, + ) + assert torch.allclose(m, expected.to(device)) def test_rnnt_loss_gradient(self): if self.has_torch_rnnt_loss: @@ -434,62 +443,68 @@ def test_rnnt_loss_pruned(self): boundary_[:, 2] = seq_length boundary_[:, 3] = frames - for device in self.devices: - # normal rnnt - am = am_.to(device) - lm = lm_.to(device) - symbols = symbols_.to(device) - boundary = boundary_.to(device) - t_am = am.unsqueeze(2).float() - t_lm = lm.unsqueeze(1).float() - t_prob = t_am + t_lm - # nonlinear transform - t_prob = torch.sigmoid(t_prob) - k2_loss = k2.rnnt_loss( - logits=t_prob, - symbols=symbols, - termination_symbol=terminal_symbol, - boundary=boundary, - reduction="none", - ) - - print("unpruned rnnt loss: ", k2_loss) - - # pruning - k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( - lm=lm, - am=am, - symbols=symbols, - termination_symbol=terminal_symbol, - boundary=boundary, - return_grad=True, - reduction="none", - ) + for modified in [True, False]: + for device in self.devices: + # normal rnnt + am = am_.to(device) + lm = lm_.to(device) + symbols = symbols_.to(device) + boundary = boundary_.to(device) + t_am = am.unsqueeze(2).float() + t_lm = lm.unsqueeze(1).float() + t_prob = t_am + t_lm - for r in range(2, 50, 5): - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, + # nonlinear transform + t_prob = torch.sigmoid(t_prob) + k2_loss = k2.rnnt_loss( + logits=t_prob, + symbols=symbols, + termination_symbol=terminal_symbol, boundary=boundary, - s_range=r, + modified=modified, ) - # (B, T, r, C) - am_p, lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges) - - t_prob_p = am_p + lm_p - # nonlinear transform - t_prob_p = torch.sigmoid(t_prob_p) + print( + f"unpruned 
rnnt loss with modified {modified} : {k2_loss}"
+                    )
 
+                    # pruning
+                    k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
+                        lm=lm,
+                        am=am,
                         symbols=symbols,
-                    ranges=ranges,
                         termination_symbol=terminal_symbol,
                         boundary=boundary,
+                        modified=modified,
+                        return_grad=True,
                         reduction="none",
                     )
-            print(f"pruned loss with range {r} : ", pruning_loss)
+
+                    for r in range(2, 50, 5):
+                        ranges = k2.get_rnnt_prune_ranges(
+                            px_grad=px_grad,
+                            py_grad=py_grad,
+                            boundary=boundary,
+                            s_range=r,
+                        )
+                        # (B, T, r, C)
+                        am_p, lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
+
+                        t_prob_p = am_p + lm_p
+
+                        # nonlinear transform
+                        t_prob_p = torch.sigmoid(t_prob_p)
+
+                        pruned_loss = k2.rnnt_loss_pruned(
+                            logits=t_prob_p,
+                            symbols=symbols,
+                            ranges=ranges,
+                            termination_symbol=terminal_symbol,
+                            boundary=boundary,
+                            modified=modified,
+                            reduction="none",
+                        )
+                        print(f"pruned loss with range {r} : {pruned_loss}")
 
 
 if __name__ == "__main__":

From 2239c39dd406003a593a1ab2c338b9d37c4cb13a Mon Sep 17 00:00:00 2001
From: "Wang, Guanbo"
Date: Thu, 24 Feb 2022 21:31:38 -0500
Subject: [PATCH 46/64] Fix Stack (#925)

* return the correct layer

* unskip the test
---
 k2/csrc/ragged_test.cu  | 5 +----
 k2/csrc/ragged_utils.cu | 4 +++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu
index 801e51f5c..345def9a2 100644
--- a/k2/csrc/ragged_test.cu
+++ b/k2/csrc/ragged_test.cu
@@ -265,10 +265,7 @@ TEST(RaggedShapeOpsTest, UnstackRandom) {
     for (size_t i = 0; i < out.size(); ++i) {
       out_ptr.emplace_back(&(out[i]));
     }
-    // There is a bug in `Stack` for stacking a shape itself,
-    // not urgent, so skipping here.
-    // TODO: Remove this line when the bug fixed.
-    if (out.size() == 1) continue;
+
     auto dest = Stack(axis, out.size(), out_ptr.data());
     dest = RemoveEmptyLists(dest, axis);
 
diff --git a/k2/csrc/ragged_utils.cu b/k2/csrc/ragged_utils.cu
index e038357dc..f4c942534 100644
--- a/k2/csrc/ragged_utils.cu
+++ b/k2/csrc/ragged_utils.cu
@@ -92,7 +92,9 @@ RaggedShape IntersperseRaggedLayer(int32_t layer,
     if (merge_map)
       *(reinterpret_cast<Array1<uint32_t> *>(merge_map)) =
           Range(src[0]->Context(), src[0]->TotSize(layer + 1), 0);
-    return *src[0];
+    std::vector<RaggedShapeLayer> layers;
+    layers.emplace_back(src[0]->Layers()[layer]);
+    return RaggedShape(layers);
   }
 
   std::vector<const int32_t *> row_splits_ptrs_vec(num_srcs);

From 5ee082ea55f50e8bd42203ba266945ea5a236ab8 Mon Sep 17 00:00:00 2001
From: drawfish
Date: Sun, 27 Feb 2022 09:00:48 +0800
Subject: [PATCH 47/64] Fix 'TypeError' of rnnt_loss_pruned function. (#924)

* Fix 'TypeError' of rnnt_loss_simple function.

Fix 'TypeError' exception when calling rnnt_loss_simple(..., return_grad=False)
at validation steps.

* Fix 'MutualInformationRecursionFunction.forward()' return type check error
  for pytorch < 1.10.x

* Modify return type.

* Add documents about class MutualInformationRecursionFunction.

* Formatted code style.

* Fix rnnt_loss_smoothed return type.
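
Editor's note (not part of the patch): the fix above relies on the fact that `torch.autograd.Function.forward()` is expected to return only Tensors (older PyTorch versions type-check the return value), so the (px_grad, py_grad) pair is handed back through a mutable list argument instead of being returned. A self-contained sketch of that pattern, with `DoubleFn` as a hypothetical toy function, not k2 code:

    import torch
    from typing import List, Optional

    class DoubleFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, grads_out: List[Optional[torch.Tensor]]):
            grad = torch.full_like(x, 2.0)  # grad we want to expose eagerly
            ctx.save_for_backward(grad)
            grads_out[0] = grad   # passed out via the list, not the return
            return 2 * x          # forward() returns a plain Tensor

        @staticmethod
        def backward(ctx, ans_grad):
            (grad,) = ctx.saved_tensors
            return ans_grad * grad, None  # None for the non-Tensor input

    grads = [None]
    x = torch.randn(3, requires_grad=True)
    y = DoubleFn.apply(x, grads)  # grads[0] now holds the exposed grad
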
Co-authored-by: gzchenduisheng
---
 k2/python/k2/mutual_information.py | 135 +++++++++++++++++++++++++----
 k2/python/k2/rnnt_loss.py          |   2 +-
 2 files changed, 121 insertions(+), 16 deletions(-)

diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py
index 7e94105e2..a88e6c32a 100644
--- a/k2/python/k2/mutual_information.py
+++ b/k2/python/k2/mutual_information.py
@@ -18,18 +18,120 @@
 import torch
 import _k2
 from torch import Tensor
-from typing import Tuple, Optional, Sequence, Union
+from typing import Tuple, Optional, Sequence, Union, List
 
 
 class MutualInformationRecursionFunction(torch.autograd.Function):
+    """A recursion that is useful in computing mutual information between two
+    sequences of real vectors, but may be useful more generally in
+    sequence-to-sequence tasks where monotonic alignment between pairs of
+    sequences is desired.
+    """
+
     @staticmethod
     def forward(
         ctx,
         px: torch.Tensor,
         py: torch.Tensor,
+        pxy_grads: List[Optional[torch.Tensor]],
         boundary: Optional[torch.Tensor] = None,
         return_grad: bool = False,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> torch.Tensor:
+        """
+        Computing mutual information between two sequences of real vectors.
+        Args:
+          px:
+            A torch.Tensor of some floating point type, with shape
+            ``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
+            length of the ``x`` sequence (including representations of
+            ``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
+            length of the ``y`` sequence (including representations of
+            ``EOS`` symbols but not ``BOS`` symbols).  In the mutual
+            information application, ``px[b][s][t]`` would represent the
+            following log odds ratio; ignoring the b index on the right
+            to make the notation more
+            compact::
+
+              px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
+
+            This expression also implicitly includes the log-probability of
+            choosing to generate an ``x`` value as opposed to a ``y`` value.  In
+            practice it might be computed as ``a + b``, where ``a`` is the log
+            probability of choosing to extend the sequence of length ``(s,t)``
+            with an ``x`` as opposed to a ``y`` value; and ``b`` might in
+            practice be of the form::
+
+              log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
+
+            where ``N`` is the number of terms that the sum over ``t'``
+            included, which might include some or all of the other sequences as
+            well as this one.
+
+            Note:
+              we don't require ``px`` and ``py`` to be contiguous, but the
+              code assumes for optimization purposes that the ``T`` axis has
+              stride 1.
+
+          py:
+            A torch.Tensor of the same dtype as ``px``, with shape
+            ``[B][S+1][T]``, representing::
+
+              py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
+
+            This function does not treat ``x`` and ``y`` differently; the only
+            difference is that for optimization purposes we assume the last axis
+            (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
+            are contiguous.
+
+          pxy_grads:
+            A List to store the return grads of ``px`` and ``py``
+            if return_grad == True.
+            It remains unchanged if return_grad == False.
+
+            See `this PR <https://github.com/k2-fsa/k2/pull/924>`_ for more
+            information about why we add this parameter.
+
+            Note:
+              the length of the list must be 2, where the first element
+              represents the grads of ``px`` and the second one represents
+              the grads of ``py``.
+
+          boundary:
+            If supplied, a torch.LongTensor of shape ``[B][4]``, where each
+            row contains ``[s_begin, t_begin, s_end, t_end]``,
+            with ``0 <= s_begin <= s_end <= S`` and
+            ``0 <= t_begin <= t_end <= T``
+            (this implies that empty sequences are allowed).
+            If not supplied, the values ``[0, 0, S, T]`` will be assumed.
+            These are the beginning and one-past-the-last positions in the
+            ``x`` and ``y`` sequences respectively, and can be used if not
+            all sequences are
+            of the same length.
+
+          return_grad:
+            Whether to return grads of ``px`` and ``py``.  This grad, standing
+            for the occupation probability, is the output of the backward with
+            a ``fake gradient``; the ``fake gradient`` is the same as the
+            gradient you'd get if you did
+            ``torch.autograd.grad((scores.sum()), [px, py])``.
+            This is useful to implement the pruned version of rnnt loss.
+
+        Returns:
+          Returns a torch.Tensor of shape ``[B]``, containing the log of
+          the mutual information between the b'th pair of sequences.  This is
+          defined by the following recursion on ``p[b,s,t]`` (where ``p``
+          is of shape ``[B,S+1,T+1]``), representing a mutual information
+          between sub-sequences of lengths ``s`` and ``t``::
+
+                 p[b,0,0] = 0.0
+                 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                    p[b,s,t-1] + py[b,s,t-1])
+                            (if s > 0 or t > 0)
+
+          where we handle edge cases by treating quantities with negative
+          indexes as **-infinity**.  The extension to cases where the
+          boundaries are specified should be obvious; it just works on
+          shorter sequences with offsets into ``px`` and ``py``.
+        """
         (B, S, T1) = px.shape
         T = py.shape[-1]
         assert T1 in [T, T + 1]
@@ -58,21 +160,24 @@ def forward(
         if return_grad or px.requires_grad or py.requires_grad:
             ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
             (px_grad, py_grad) = _k2.mutual_information_backward(
-                px, py, boundary, p, ans_grad
-            )
+                px, py, boundary, p, ans_grad)
             ctx.save_for_backward(px_grad, py_grad)
-        return ans, px_grad, py_grad
+        assert len(pxy_grads) == 2
+        pxy_grads[0] = px_grad
+        pxy_grads[1] = py_grad
+
+        return ans
 
     @staticmethod
     def backward(
-        ctx, ans_grad: Tensor, dummy_px_grad: Tensor, dummy_py_grad: Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
+        ctx, ans_grad: Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
         (px_grad, py_grad) = ctx.saved_tensors
         (B,) = ans_grad.shape
         ans_grad = ans_grad.reshape(B, 1, 1)  # (B, 1, 1)
         px_grad *= ans_grad
         py_grad *= ans_grad
-        return (px_grad, py_grad, None, None)
+        return (px_grad, py_grad, None, None, None)
 
 
 def mutual_information_recursion(
@@ -183,10 +288,10 @@ def mutual_information_recursion(
     # The following assertions are for efficiency
     assert px.is_contiguous()
     assert py.is_contiguous()
-
-    scores, px_grad, py_grad = MutualInformationRecursionFunction.apply(
-        px, py, boundary, return_grad
-    )
+    pxy_grads = [None, None]
+    scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
+                                                      boundary, return_grad)
+    px_grad, py_grad = pxy_grads
 
     return (scores, (px_grad, py_grad)) if return_grad else scores
 
@@ -288,9 +393,9 @@ def joint_mutual_information_recursion(
     # actual derivative w.r.t. the total probs.
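
Editor's note (not part of the patch): a short usage sketch of the public wrapper whose internals appear above, assuming a k2 build that includes this patch; shapes follow the docstring (px: [B][S][T+1], py: [B][S+1][T], both contiguous):

    import torch
    import k2

    B, S, T = 2, 3, 5
    px = torch.randn(B, S, T + 1, requires_grad=True)
    py = torch.randn(B, S + 1, T, requires_grad=True)

    # scores has shape [B]; with return_grad=True we also get the
    # occupation probabilities ("fake gradients") that the pruned
    # RNN-T loss uses to choose pruning ranges.
    scores, (px_grad, py_grad) = k2.mutual_information_recursion(
        px, py, boundary=None, return_grad=True
    )
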
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype) - (px_grad, py_grad) = _k2.mutual_information_backward( - px_tot, py_tot, boundary, p, ans_grad - ) + (px_grad, + py_grad) = _k2.mutual_information_backward(px_tot, py_tot, boundary, p, + ans_grad) px_grad = px_grad.reshape(1, B, -1) py_grad = py_grad.reshape(1, B, -1) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 1e68668df..5918d7b9e 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -1082,7 +1082,7 @@ def rnnt_loss_smoothed( modified: bool = False, reduction: Optional[str] = "mean", return_grad: bool = False, -) -> Tensor: +) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: """A simple case of the RNN-T loss, where the 'joiner' network is just addition. From 36e2b8d528d761ba73f99c34a25292873dd5ec02 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Mar 2022 15:39:45 +0800 Subject: [PATCH 48/64] Support torch 1.11.0 and CUDA 11.5 (#931) * Support torch 1.11.0 and CUDA 11.5 --- .github/workflows/build-cpu.yml | 33 ++++++++- .github/workflows/build.yml | 78 +++++++++++++++++-- .github/workflows/build_conda.yml | 99 ++++++++++++++++++++++--- .github/workflows/build_conda_cpu.yml | 39 ++++++++-- .github/workflows/nightly-cpu.yml | 44 +++++++++-- scripts/github_actions/install_cuda.sh | 3 + scripts/github_actions/install_cudnn.sh | 8 +- scripts/github_actions/install_torch.sh | 17 +++++ 8 files changed, 284 insertions(+), 37 deletions(-) diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 4283d8dc4..173e59ac8 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -36,10 +36,35 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x - python-version: [3.6, 3.7, 3.8, 3.9] + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] + # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, 1.11.x + # Python 3.10 is for PyTorch 1.11.x + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] exclude: + - python-version: "3.10" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.5.0" + - python-version: "3.10" + torch: "1.5.1" + - python-version: "3.10" + torch: "1.6.0" + - python-version: "3.10" + torch: "1.7.0" + - python-version: "3.10" + torch: "1.7.1" + - python-version: "3.10" + torch: "1.8.0" + - python-version: "3.10" + torch: "1.8.1" + - python-version: "3.10" + torch: "1.9.0" + - python-version: "3.10" + torch: "1.9.1" + - python-version: "3.10" + torch: "1.10.0" + - python-version: "3.10" + torch: "1.10.1" + - python-version: "3.10" + torch: "1.10.2" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 @@ -48,6 +73,8 @@ jobs: torch: "1.6.0" - python-version: 3.9 torch: "1.7.0" + - python-version: 3.6 # exclude Python 3.6 for [1.11.0] + torch: "1.11.0" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 769bdf05a..374c8bccf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,6 +39,7 @@ jobs: # from https://download.pytorch.org/whl/torch_stable.html # Note: There are no torch versions for CUDA 11.2 # + # 
1.11.x supports: cuda10.2 (default), 11.3, 11.5 # 1.10.x supports: cuda10.2 (default), 11.1, 11.3 # 1.9.x supports: cuda10.2 (default), 11.1 # PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1 @@ -46,14 +47,43 @@ jobs: # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default) # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default) # Other PyTorch versions are not tested - # CUDA 11.3 is for torch 1.10 - cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] + # CUDA 10.1 is for 1.5.x, 1.6.0, 1.7.x, 1.8.x + # CUDA 11.1 is for torch 1.8.x, 1.9.x, 1.10.x + # CUDA 11.3 is for torch 1.10, 1.11.x + # CUDA 11.5 is for torch 1.11.x + cuda: ["10.1", "10.2", "11.0", "11.1", "11.3", "11.5"] gcc: ["7"] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] # - # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10.x - python-version: [3.6, 3.7, 3.8, 3.9] + # torch 1.11.x does not support Python 3.6 + # From torch 1.11.x, it supports Python 3.10 + # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10.x, 11.x + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] exclude: + - cuda: "11.5" # exclude 11.5 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.5.0" + - cuda: "11.5" + torch: "1.5.1" + - cuda: "11.5" + torch: "1.6.0" + - cuda: "11.5" + torch: "1.7.0" + - cuda: "11.5" + torch: "1.7.1" + - cuda: "11.5" + torch: "1.8.0" + - cuda: "11.5" + torch: "1.8.1" + - cuda: "11.5" + torch: "1.9.0" + - cuda: "11.5" + torch: "1.9.1" + - cuda: "11.5" + torch: "1.10.0" + - cuda: "11.5" + torch: "1.10.1" + - cuda: "11.5" + torch: "1.10.2" - cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] torch: "1.5.0" - cuda: "11.3" @@ -72,7 +102,7 @@ jobs: torch: "1.9.0" - cuda: "11.3" torch: "1.9.1" - - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] torch: "1.5.0" - cuda: "11.0" torch: "1.5.1" @@ -92,7 +122,9 @@ jobs: torch: "1.10.1" - cuda: "11.0" torch: "1.10.2" - - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] + - cuda: "11.0" + torch: "1.11.0" + - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.11.0] torch: "1.5.0" - cuda: "11.1" torch: "1.5.1" @@ -102,7 +134,9 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10.0, 10.1, 10.2] + - cuda: "11.1" + torch: "1.11.0" + - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10.0, 10.1, 10.2, 1.11.0] torch: "1.9.0" - cuda: "10.1" torch: "1.9.1" @@ -112,6 +146,8 @@ jobs: torch: "1.10.1" - cuda: "10.1" torch: "1.10.2" + - cuda: "10.1" + torch: "1.11.0" - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - python-version: 3.9 @@ -120,6 +156,32 @@ jobs: torch: "1.6.0" - python-version: 3.9 torch: "1.7.0" + - python-version: "3.10" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.5.0" + - python-version: "3.10" + torch: "1.5.1" + - python-version: "3.10" + torch: "1.6.0" + - python-version: "3.10" + torch: "1.7.0" + - python-version: "3.10" + torch: "1.7.1" + - python-version: "3.10" + torch: "1.8.0" 
+ - python-version: "3.10" + torch: "1.8.1" + - python-version: "3.10" + torch: "1.9.0" + - python-version: "3.10" + torch: "1.9.1" + - python-version: "3.10" + torch: "1.10.0" + - python-version: "3.10" + torch: "1.10.1" + - python-version: "3.10" + torch: "1.10.2" + - python-version: "3.6" # exclude Python 3.6 for [1.11.0] + torch: "1.11.0" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/build_conda.yml b/.github/workflows/build_conda.yml index 74ad2238c..b8107d0c6 100644 --- a/.github/workflows/build_conda.yml +++ b/.github/workflows/build_conda.yml @@ -33,10 +33,12 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04] - python-version: [3.6, 3.7, 3.8, 3.9] - cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + cuda: ["10.1", "10.2", "11.0", "11.1", "11.3", "11.5"] # from https://download.pytorch.org/whl/torch_stable.html + # Note: There are no torch versions for CUDA 11.2 # + # 1.11.x supports: cuda10.2 (default), 11.3, 11.5 # PyTorch 1.10.x supports: 10.2 (default), 11.1, 11.3 # PyTorch 1.9.x supports: 10.2 (default), 11.1 # PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1 @@ -45,7 +47,8 @@ jobs: # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) # - # PyTorch 1.7.1, 1.8.x, 1.9.x, and 1.10 support 3.6, 3.7, 3.8, 3.9 + # PyTorch 1.11.x supports Python 3.10 + # PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, and 1.11.x support 3.6, 3.7, 3.8, 3.9 # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 # # Other PyTorch versions are not tested @@ -56,9 +59,51 @@ jobs: # https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true # and # https://github.com/NVIDIA/apex/issues/805 - torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] + torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] exclude: - # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + - cuda: "11.5" # exclude cuda 11.5 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.5.0" + - cuda: "11.5" + torch: "1.5.1" + - cuda: "11.5" + torch: "1.6.0" + - cuda: "11.5" + torch: "1.7.0" + - cuda: "11.5" + torch: "1.7.1" + - cuda: "11.5" + torch: "1.8.0" + - cuda: "11.5" + torch: "1.8.1" + - cuda: "11.5" + torch: "1.9.0" + - cuda: "11.5" + torch: "1.9.1" + - cuda: "11.5" + torch: "1.10.0" + - cuda: "11.5" + torch: "1.10.1" + - cuda: "11.5" + torch: "1.10.2" + - cuda: "11.3" # exclude cuda 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] + torch: "1.5.0" + - cuda: "11.3" + torch: "1.5.1" + - cuda: "11.3" + torch: "1.6.0" + - cuda: "11.3" + torch: "1.7.0" + - cuda: "11.3" + torch: "1.7.1" + - cuda: "11.3" + torch: "1.8.0" + - cuda: "11.3" + torch: "1.8.1" + - cuda: "11.3" + torch: "1.9.0" + - cuda: "11.3" + torch: "1.9.1" + # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] # torch: "1.5.0" # - cuda: "11.0" # torch: "1.5.1" @@ -78,7 +123,9 @@ jobs: torch: "1.10.1" - cuda: "11.0" torch: "1.10.2" - # - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1] + - cuda: "11.0" + torch: "1.11.0" + # - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.11.0] # torch: "1.5.0" # - cuda: "11.1" # torch: 
"1.5.1" @@ -88,7 +135,9 @@ jobs: torch: "1.7.0" - cuda: "11.1" torch: "1.7.1" - - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + - cuda: "11.1" + torch: "1.11.0" + - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] torch: "1.9.0" - cuda: "10.1" torch: "1.9.1" @@ -98,14 +147,42 @@ jobs: torch: "1.10.1" - cuda: "10.1" torch: "1.10.2" - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] + - cuda: "10.1" + torch: "1.11.0" + - python-version: "3.9" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] + torch: "1.5.0" + - python-version: "3.9" + torch: "1.5.1" + - python-version: "3.9" + torch: "1.6.0" + - python-version: "3.9" + torch: "1.7.0" + - python-version: "3.10" # exclude Python 3.10 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] torch: "1.5.0" - - python-version: 3.9 + - python-version: "3.10" torch: "1.5.1" - - python-version: 3.9 + - python-version: "3.10" torch: "1.6.0" - - python-version: 3.9 + - python-version: "3.10" torch: "1.7.0" + - python-version: "3.10" + torch: "1.7.1" + - python-version: "3.10" + torch: "1.8.0" + - python-version: "3.10" + torch: "1.8.1" + - python-version: "3.10" + torch: "1.9.0" + - python-version: "3.10" + torch: "1.9.1" + - python-version: "3.10" + torch: "1.10.0" + - python-version: "3.10" + torch: "1.10.1" + - python-version: "3.10" + torch: "1.10.2" + - python-version: "3.6" # exclude Python 3.6 for [1.11.0] + torch: "1.11.0" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/build_conda_cpu.yml b/.github/workflows/build_conda_cpu.yml index aec3e114d..fe3e552ab 100644 --- a/.github/workflows/build_conda_cpu.yml +++ b/.github/workflows/build_conda_cpu.yml @@ -43,24 +43,51 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] # from https://download.pytorch.org/whl/torch_stable.html # + # PyTorch 1.11.x supports 3.7, 3.8, 3.9, 3.10 # PyTorch 1.10, 1.9.x, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9 # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 # # Other PyTorch versions are not tested # - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] + torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] + - python-version: "3.9" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.5.0" - - python-version: 3.9 + - python-version: "3.9" torch: "1.5.1" - - python-version: 3.9 + - python-version: "3.9" torch: "1.6.0" - - python-version: 3.9 + - python-version: "3.9" torch: "1.7.0" + - python-version: "3.10" # exclude Python 3.10 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.5.0" + - python-version: "3.10" + torch: "1.5.1" + - python-version: "3.10" + torch: "1.6.0" + - python-version: "3.10" + torch: "1.7.0" + - python-version: "3.10" + torch: "1.7.1" + - python-version: "3.10" + torch: "1.8.0" + - python-version: "3.10" + torch: "1.8.1" + - python-version: "3.10" + torch: "1.9.0" + - python-version: "3.10" + torch: "1.9.1" + - python-version: "3.10" + torch: "1.10.0" + - python-version: "3.10" + torch: "1.10.1" + - python-version: "3.10" + torch: "1.10.2" + - python-version: "3.6" # exclude Python 
3.6 for [1.11.0] + torch: "1.11.0" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu.yml index af052ecdf..8fdc6d0a6 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu.yml @@ -39,20 +39,48 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04, macos-10.15] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10 - python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"] + # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, 1.11.x + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] + - python-version: "3.9" # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] torch: "1.4.0" - - python-version: 3.9 + - python-version: "3.9" torch: "1.5.0" - - python-version: 3.9 + - python-version: "3.9" torch: "1.5.1" - - python-version: 3.9 + - python-version: "3.9" torch: "1.6.0" - - python-version: 3.9 + - python-version: "3.9" torch: "1.7.0" + - python-version: "3.10" # exclude Python 3.10 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] + torch: "1.4.0" + - python-version: "3.10" + torch: "1.5.0" + - python-version: "3.10" + torch: "1.5.1" + - python-version: "3.10" + torch: "1.6.0" + - python-version: "3.10" + torch: "1.7.0" + - python-version: "3.10" + torch: "1.7.1" + - python-version: "3.10" + torch: "1.8.0" + - python-version: "3.10" + torch: "1.8.1" + - python-version: "3.10" + torch: "1.9.0" + - python-version: "3.10" + torch: "1.9.1" + - python-version: "3.10" + torch: "1.10.0" + - python-version: "3.10" + torch: "1.10.1" + - python-version: "3.10" + torch: "1.10.2" + - python-version: "3.6" # exclude Python 3.6 for [1.11.0] + torch: "1.11.0" steps: - uses: actions/checkout@v2 diff --git a/scripts/github_actions/install_cuda.sh b/scripts/github_actions/install_cuda.sh index 358c05b45..b84de8924 100755 --- a/scripts/github_actions/install_cuda.sh +++ b/scripts/github_actions/install_cuda.sh @@ -40,6 +40,9 @@ case "$cuda" in # url=https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run url=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run ;; + 11.5) + url=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run + ;; *) echo "Unknown cuda version: $cuda" exit 1 diff --git a/scripts/github_actions/install_cudnn.sh b/scripts/github_actions/install_cudnn.sh index ceead2d36..8feafbea3 100755 --- a/scripts/github_actions/install_cudnn.sh +++ b/scripts/github_actions/install_cudnn.sh @@ -33,6 +33,12 @@ case $cuda in 11.3) filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz ;; + 11.5) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; + # 11.5) + # filename=cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive.tar.xz + # ;; *) echo "Unsupported cuda version: $cuda" exit 1 @@ -45,7 +51,7 @@ git clone https://huggingface.co/csukuangfj/cudnn cd cudnn git lfs pull --include="$filename" -sudo tar xf ./$filename -C /usr/local +sudo tar xf ./$filename --strip-components=1 -C /usr/local/cuda # save disk space git lfs prune && cd .. 
&& rm -rf cudnn diff --git a/scripts/github_actions/install_torch.sh b/scripts/github_actions/install_torch.sh index 8024729f1..ed813c5a8 100755 --- a/scripts/github_actions/install_torch.sh +++ b/scripts/github_actions/install_torch.sh @@ -108,6 +108,23 @@ case ${torch} in ;; esac ;; + 1.11.*) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.3) + package="torch==${torch}+cu113" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.5) + package="torch==${torch}+cu115" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; *) echo "Unsupported PyTorch version: ${torch}" exit 1 From f4b42477774e3482ffa4fe6b699abf160e432b9b Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 16 Mar 2022 10:18:54 +0800 Subject: [PATCH 49/64] Implement Rnnt decoding (#926) * first working draft of rnnt decoding * FormatOutput works... * Different num frames for FormatOutput works * Update docs * Fix comments, break advance into several stages, add more docs * Add python wrapper * Add more docs * Minor fixes * Fix comments --- .flake8 | 1 + k2/csrc/CMakeLists.txt | 4 + k2/csrc/algorithms.h | 9 +- k2/csrc/array_of_ragged.cu | 54 ++ k2/csrc/array_of_ragged.h | 200 +++++++ k2/csrc/array_of_ragged_test.cu | 78 +++ k2/csrc/array_ops.h | 65 ++- k2/csrc/array_ops_inl.h | 41 +- k2/csrc/fsa_algo.cu | 35 ++ k2/csrc/fsa_algo.h | 20 + k2/csrc/ragged_ops.cu | 30 +- k2/csrc/ragged_ops.h | 64 ++- k2/csrc/ragged_ops_inl.h | 111 ++-- k2/csrc/ragged_test.cu | 78 +-- k2/csrc/rnnt_decode.cu | 817 ++++++++++++++++++++++++++++ k2/csrc/rnnt_decode.h | 390 +++++++++++++ k2/csrc/rnnt_decode_test.cu | 111 ++++ k2/python/csrc/torch.cu | 2 + k2/python/csrc/torch/CMakeLists.txt | 1 + k2/python/csrc/torch/fsa_algo.cu | 52 +- k2/python/csrc/torch/rnnt_decode.cu | 164 ++++++ k2/python/csrc/torch/rnnt_decode.h | 30 + k2/python/k2/__init__.py | 5 + k2/python/k2/fsa.py | 3 +- k2/python/k2/fsa_algo.py | 23 + k2/python/k2/rnnt_decode.py | 243 +++++++++ k2/python/tests/CMakeLists.txt | 1 + k2/python/tests/rnnt_decode_test.py | 78 +++ 28 files changed, 2550 insertions(+), 160 deletions(-) create mode 100644 k2/csrc/array_of_ragged.cu create mode 100644 k2/csrc/array_of_ragged.h create mode 100644 k2/csrc/array_of_ragged_test.cu create mode 100644 k2/csrc/rnnt_decode.cu create mode 100644 k2/csrc/rnnt_decode.h create mode 100644 k2/csrc/rnnt_decode_test.cu create mode 100644 k2/python/csrc/torch/rnnt_decode.cu create mode 100644 k2/python/csrc/torch/rnnt_decode.h create mode 100644 k2/python/k2/rnnt_decode.py create mode 100644 k2/python/tests/rnnt_decode_test.py diff --git a/.flake8 b/.flake8 index 6cd23ad39..c0ad0c420 100644 --- a/.flake8 +++ b/.flake8 @@ -7,6 +7,7 @@ per-file-ignores = # line break before operator W503 k2/python/k2/rnnt_loss.py: E501, W503 k2/python/tests/rnnt_loss_test.py: W503 + k2/python/tests/rnnt_decode_test.py: W503 exclude = .git, setup.py, diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 1759244ea..f48b895d5 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -45,6 +45,7 @@ add_subdirectory(host) # please keep it sorted set(context_srcs algorithms.cu + array_of_ragged.cu array_ops.cu connect.cu context.cu @@ -65,6 +66,7 @@ set(context_srcs ragged_utils.cu rand.cu rm_epsilon.cu + rnnt_decode.cu tensor.cu tensor_ops.cu thread_pool.cu @@ -142,6 +144,7 @@ target_link_libraries(test_utils PUBLIC context gtest) # please sort the source files alphabetically set(cuda_test_srcs algorithms_test.cu + 
+  array_of_ragged_test.cu
   array_ops_test.cu
   array_test.cu
   connect_test.cu
@@ -163,6 +166,7 @@ set(cuda_test_srcs
   ragged_utils_test.cu
   rand_test.cu
   rm_epsilon_test.cu
+  rnnt_decode_test.cu
   tensor_ops_test.cu
   tensor_test.cu
   thread_pool_test.cu
diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h
index 6e11a31cb..0877e70b6 100644
--- a/k2/csrc/algorithms.h
+++ b/k2/csrc/algorithms.h
@@ -119,12 +119,9 @@ class Renumbering {
     return new2old_;
   }
 
-  /* Return a mapping from new index to old index, with one extra element
-     containing the total number of kept elements if extra_element == true.
-     If Keep() can be interpreted as a tails vector, i.e. with 1 at the end
-     of sub-lists of elements, then New2Old(true) would corresponds to a
-     row-splits array and Old2New(false) would correspond to a row-ids
-     array.
+  /*
+    Return a mapping from new index to old index, with one extra element
+    containing the total number of kept elements if extra_element == true.
   */
   Array1<int32_t> New2Old(bool extra_element) {
     Array1<int32_t> &new2old_part = New2Old();
diff --git a/k2/csrc/array_of_ragged.cu b/k2/csrc/array_of_ragged.cu
new file mode 100644
index 000000000..cd93434d9
--- /dev/null
+++ b/k2/csrc/array_of_ragged.cu
@@ -0,0 +1,54 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/csrc/array_of_ragged.h"
+
+namespace k2 {
+
+Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs)
+    : num_srcs_(num_srcs) {
+  K2_CHECK_GE(num_srcs, 1);
+  K2_CHECK(src);
+  num_axes_ = src[0].NumAxes();
+  c_ = src[0].Context();
+
+  row_splits_ =
+      Array2<const int32_t *>(GetCpuContext(), num_axes_ - 1, num_srcs_);
+  row_ids_ = Array2<const int32_t *>(GetCpuContext(), num_axes_ - 1, num_srcs_);
+  tot_sizes_ = Array1<int32_t>(GetCpuContext(), num_axes_, 0);
+
+  auto row_splits_acc = row_splits_.Accessor(),
+       row_ids_acc = row_ids_.Accessor();
+  int32_t *tot_sizes_data = tot_sizes_.Data();
+
+  for (int32_t i = 0; i < num_srcs_; ++i) {
+    K2_CHECK_EQ(src[i].NumAxes(), num_axes_);
+    K2_CHECK(c_->IsCompatible(*(src[i].Context())));
+    for (int32_t j = 1; j < num_axes_; ++j) {
+      row_splits_acc(j - 1, i) = src[i].RowSplits(j).Data();
+      row_ids_acc(j - 1, i) = src[i].RowIds(j).Data();
+      tot_sizes_data[j] += src[i].TotSize(j);
+    }
+    tot_sizes_data[0] += src[i].TotSize(0);
+  }
+
+  row_splits_ = row_splits_.To(c_);
+  row_ids_ = row_ids_.To(c_);
+}
+
+}  // namespace k2
diff --git a/k2/csrc/array_of_ragged.h b/k2/csrc/array_of_ragged.h
new file mode 100644
index 000000000..31349cf91
--- /dev/null
+++ b/k2/csrc/array_of_ragged.h
@@ -0,0 +1,200 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_
+#define K2_CSRC_ARRAY_OF_RAGGED_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/array.h"
+#include "k2/csrc/context.h"
+#include "k2/csrc/log.h"
+#include "k2/csrc/ragged_ops.h"
+
+namespace k2 {
+/*
+  Array1OfRaggedShape is a convenience class that gives you easy access
+  to pointers-of-pointers for an array of ragged shapes.
+ */
+class Array1OfRaggedShape {
+ public:
+  /*
+    Constructor.
+    Args:
+       srcs: pointers to the source shapes, a CPU pointer
+       num_srcs: the number of source shapes.  All shapes must have the
+            same NumAxes() and must be on the same device.
+
+    TODO: we'll likely, later, add optional args which dictate which of
+    the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should
+    enable us to save kernels by combining certain operations across the
+    axes.
+   */
+  Array1OfRaggedShape(RaggedShape *srcs, int32_t num_srcs);
+  Array1OfRaggedShape() = default;
+
+  int32_t NumSrcs() const { return num_srcs_; }
+  int32_t NumAxes() const { return num_axes_; }
+
+  ContextPtr &Context() { return c_; }
+
+  // Returns device-accessible array of row-splits for the individual shapes,
+  // indexed [axis-1][src], with 0 <= src < num_srcs.  The shape of this
+  // Array2 is [NumAxes() - 1][NumSrcs()].
+  const Array2<const int32_t *> *RowSplits() const { return &row_splits_; }
+
+  // Returns device-accessible vector of row-splits for a particular
+  // axis, indexed by 0 <= src < num_srcs.
+  const int32_t **RowSplits(int32_t axis) {
+    return row_splits_.Row(axis - 1).Data();
+  }
+
+  // Returns device-accessible array of row-ids for the individual shapes
+  // indexed [axis-1][src], with 0 <= src < num_srcs.  The shape of this
+  // Array2 is [NumAxes() - 1][NumSrcs()].
+  const Array2<const int32_t *> *RowIds() const { return &row_ids_; }
+
+  // Returns device-accessible vector of row-ids for a particular
+  // axis, indexed by 0 <= src < num_srcs.
+  const int32_t **RowIds(int32_t axis) { return row_ids_.Row(axis - 1).Data(); }
+
+  /* Return the total size on this axis, which is the sum of the TotSize() of
+     the individual shapes.  Requires 0 <= axis < NumAxes() and
+     for axis=0 the returned value is the same as Dim0().
+  */
+  int32_t TotSize(int32_t axis) const { return tot_sizes_[axis]; }
+
+  // equivalent to TotSize(0).
+  int32_t Dim0() const { return TotSize(0); }
+
+  /* Return the device-accessible meta-row-splits, which is the cumulative sum,
+     along the src axis, of the tot-sizes of the individual arrays.
+     This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src];
+     caution, the indexing is different from RowSplits(), there is no offset.
+     Also, the meta_row_splits0 is a thing, unlike with regular row-splits
+     which start from 1.
+
+     Caution: the lengths of the arrays pointed to by the elements of this
+     Array2 (which contains pointers!) are of course all different, and
+     these lengths are currently only available
+
+     Implementation note: we can probably just populate this on CPU and
+     transfer to GPU, this will be faster than invoking an extra kernel in
+     normal cases when the NumSrcs() is small.
     [Also: see GetRowInfoMulti()].
+  */
+  // TODO: implement it...
+  Array2<int32_t> MetaRowSplits();
+
+  // could POSSIBLY add this so this code could be used in functions like
+  // Stack().  would be like MetaRowSplits but with an extra 1st row containing
+  // 0,1,2,...  We could perhaps create it with 1 extra initial row so this is
+  // always convenient to output.
+  // TODO: implement it...
+  Array2<int32_t> Offsets();
+
+  /*
+    Returns the meta-row-splits for a particular axis, with 0 <= axis <
+    NumAxes(); this is the cumulative sum of the TotSize(axis) for all of the
+    sources, with MetaRowSplits(axis).Dim() == NumSrcs() + 1.
+
+    Note: in ragged_ops.cu we refer to this as composed_row_splits
+  */
+  // TODO: implement it...
+  Array1<int32_t> MetaRowSplits(int32_t axis);
+
+  /* Return the device-accessible meta-row-ids, which are the row-ids
+     corresponding to MetaRowSplits(); this tells us, for indexes into the
+     appended/concatenated array, which source array they belong to, i.e.
+     elements are in [0,NumSrcs()-1].
+
+     This cannot be an Array2 because unlike the MetaRowSplits(), all the
+     row-ids arrays are of different lengths.
+
+     Note: in ragged_ops.cu we refer to this as composed_row_ids.
+  */
+  // TODO: implement it...
+  Array1<int32_t> MetaRowIds();
+
+  /*
+    Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes();
+    this is the row-ids corresponding to MetaRowSplits(axis), and its elements
+    give, for indexes into the concatenated shape (concatenated on axis 0),
+    which source they come from.  E.g. element 100 of MetaRowIds(2)
+    would tell us which source an idx012 with value 100 into axis 2 of
+    the concatenated array would come from.
+  */
+  // TODO: implement it...
+  Array1<int32_t> MetaRowIds(int32_t axis);
+
+ private:
+  ContextPtr c_;
+  int32_t num_srcs_;
+  int32_t num_axes_;
+  Array2<const int32_t *> row_splits_;  // shape [num_axes_ - 1][num_srcs_]
+  Array2<const int32_t *> row_ids_;     // shape [num_axes_ - 1][num_srcs_]
+  Array1<int32_t> tot_sizes_;           // dim num_axes_, this is on CPU
+};
+
+/*
+  Array1OfRagged is a 1-dimensional array of Ragged<T>.
+  It is intended for situations where you want to do some operations on
+  arrays of ragged arrays, without explicitly concatenating them (e.g. to
+  save time).  This is a fairly low-level interface, intended to
+  be used mostly by CUDA/C++ implementation code.  It is a convenience
+ */ + Array1OfRagged(Ragged *srcs, int32_t num_srcs) { + K2_CHECK_GE(num_srcs, 1); + K2_CHECK(srcs); + values = Array1(GetCpuContext(), num_srcs); + T **values_data = values.Data(); + std::vector shapes(num_srcs); + for (int32_t i = 0; i < num_srcs; ++i) { + shapes[i] = srcs[i].shape; + values_data[i] = srcs[i].values.Data(); + } + shape = Array1OfRaggedShape(shapes.data(), num_srcs); + values = values.To(shape.Context()); + } +}; + +} // namespace k2 + +#endif // K2_CSRC_ARRAY_OF_RAGGED_H_ diff --git a/k2/csrc/array_of_ragged_test.cu b/k2/csrc/array_of_ragged_test.cu new file mode 100644 index 000000000..176be139b --- /dev/null +++ b/k2/csrc/array_of_ragged_test.cu @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "k2/csrc/array_of_ragged.h" +#include "k2/csrc/ragged.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/csrc/ragged_utils.h" +#include "k2/csrc/test_utils.h" + +namespace k2 { + +template +void TestArray1OfRaggedConstruct() { + int32_t num_srcs = 5; + int32_t num_axes = 4; + + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + std::vector> raggeds; + for (int32_t i = 0; i < num_srcs; ++i) { + raggeds.emplace_back( + RandomRagged(0 /*min_value*/, 100 /*max_value*/, + num_axes /*min_num_axes*/, num_axes /*max_num_axes*/, + 0 /*min_num_elements*/, 100 /*max_num_elements*/) + .To(c, true /*copy_all*/)); + } + auto array_of_ragged = Array1OfRagged(raggeds.data(), num_srcs); + for (int32_t j = 1; j < num_axes; ++j) { + const int32_t **row_splits = array_of_ragged.shape.RowSplits(j); + const int32_t **row_ids = array_of_ragged.shape.RowIds(j); + Array1 excepted_row_splits(GetCpuContext(), num_srcs); + Array1 excepted_row_ids(GetCpuContext(), num_srcs); + int32_t **excepted_row_splits_data = excepted_row_splits.Data(); + int32_t **excepted_row_ids_data = excepted_row_ids.Data(); + for (int32_t i = 0; i < num_srcs; ++i) { + excepted_row_splits_data[i] = raggeds[i].RowSplits(j).Data(); + excepted_row_ids_data[i] = raggeds[i].RowIds(j).Data(); + } + excepted_row_splits = excepted_row_splits.To(c); + excepted_row_ids = excepted_row_ids.To(c); + excepted_row_splits_data = excepted_row_splits.Data(); + excepted_row_ids_data = excepted_row_ids.Data(); + Array1 flags(c, 2, 1); + int32_t *flags_data = flags.Data(); + K2_EVAL( + c, num_srcs, lambda_check_pointer, (int32_t i) { + if (row_splits[i] != excepted_row_splits_data[i]) flags_data[0] = 0; + if (row_ids[i] != excepted_row_ids_data[i]) flags_data[1] = 0; + }); + K2_CHECK(Equal(flags, Array1(c, std::vector{1, 1}))); + } + for (int32_t i = 0; i < num_srcs; ++i) { + K2_CHECK_EQ(array_of_ragged.values[i], raggeds[i].values.Data()); + } + } +} + +TEST(Array1OfRagged, Construct) { + TestArray1OfRaggedConstruct(); + TestArray1OfRaggedConstruct(); +} + +} // namespace k2 diff --git a/k2/csrc/array_ops.h b/k2/csrc/array_ops.h index d8c73e8d7..a661badaa 
diff --git a/k2/csrc/array_ops.h b/k2/csrc/array_ops.h
index d8c73e8d7..a661badaa 100644
--- a/k2/csrc/array_ops.h
+++ b/k2/csrc/array_ops.h
@@ -83,9 +83,8 @@ void ExclusiveSum(const Array1<T> &src, Array1<T> *dest) {
   ExclusiveSum(src.Context(), dest_dim, src.Data(), dest->Data());
 }
 
-/* wrapper for the ExclusiveSum above (returns array with same dim as src). Will satisfy
-   ans[i] = sum_{j=0}^{i-1} src[j] for i > 0.
-   ans[0] is always 0.
+/* wrapper for the ExclusiveSum above (returns array with same dim as src).
+   Will satisfy ans[i] = sum_{j=0}^{i-1} src[j] for i > 0. ans[0] is always 0.
 */
 template <typename T>
 Array1<T> ExclusiveSum(const Array1<T> &src) {
@@ -182,7 +181,6 @@ Array1<T> Cat(ContextPtr c, int32_t src_size, const Array1<T> **src);
 template <typename T>
 Array1<T> Cat(ContextPtr c, int32_t src_size, const Array1<T> *src);
 
-
 /*
   Concatenate arrays, adding offsets to each one.
 
     @param [in] offsets  Array containing the offsets to add
@@ -220,10 +218,6 @@ Array1<T> CatWithOffsets(const Array1<int32_t> &offsets,
  */
 Array1<int32_t> SpliceRowSplits(int32_t src_size, const Array1<int32_t> **src);
 
-
-
-
-
 /*
   Get the reduction value from the array `src` with a binary operator `Op`,
   initialized with `default_value`.  Will be used to implement
@@ -377,11 +371,58 @@ Array1<T> RandUniformArray1(ContextPtr c, int32_t dim, T min_value, T max_value,
      @param[in] min_value  Minimum value allowed in the array
      @param[in] max_value  Maximum value allowed in the array;
                            require max_value >= min_value.
+     @param [in] seed      The seed for the random generator. 0 to
+                           use the default seed. Set it to a non-zero
+                           value for reproducibility.
      @return    Returns the randomly generated array
  */
 template <typename T>
 Array2<T> RandUniformArray2(ContextPtr c, int32_t dim0, int32_t dim1,
-                            T min_value, T max_value);
+                            T min_value, T max_value, int32_t seed = 0);
+
+/*
+  Returns a random Array1, normally distributed with `mean` and
+  `std`. CAUTION: for now, this will be randomly generated on CPU and
+  then transferred to other devices if c is not a CPU context, so it will be
+  slow if c is not a CPU context.
+  Note: T should be a floating-point type.
+
+    @param[in] c      Context for this array; note, this function will be slow
+                      if this is not a CPU context
+    @param [in] dim   Dimension, must be > 0
+    @param [in] mean  Mean value of the normal distribution.
+    @param [in] std   Standard deviation of the normal distribution.
+    @param [in] seed  The seed for the random generator. 0 to
+                      use the default seed. Set it to a non-zero
+                      value for reproducibility.
+    @return  Returns the randomly generated array
+ */
+template <typename T>
+Array1<T> RandGaussianArray1(ContextPtr c, int32_t dim, T mean, T std,
+                             int32_t seed = 0);
+
+/*
+  Returns a random Array2, each row normally distributed with `mean` and
+  `std`. CAUTION: for now, this will be randomly generated on CPU and
+  then transferred to other devices if c is not a CPU context, so it will be
+  slow if c is not a CPU context.
+  Note: T should be a floating-point type.
+
+    @param[in] c       Context for this array; note, this function will be slow
+                       if this is not a CPU context
+    @param [in] dim0   Dimension 0 of answer, must be >= 0.
+    @param [in] dim1   Dimension 1 of answer, must be >= 0.
+    @param [in] mean   Mean value of the normal distribution.
+    @param [in] std    Standard deviation of the normal distribution.
+    @param [in] seed   The seed for the random generator. 0 to
+                       use the default seed. Set it to a non-zero
+                       value for reproducibility.
+
+    @return  Returns the randomly generated array
+ */
+template <typename T>
+Array2<T> RandGaussianArray2(ContextPtr c, int32_t dim0, int32_t dim1,
+                             T mean, T std, int32_t seed = 0);
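
A quick sketch of what the new `seed` parameter buys (the functions are declared above; the values are illustrative): the same non-zero seed reproduces the same array, while seed = 0 keeps the previous non-deterministic behaviour.

  ContextPtr c = GetCpuContext();
  Array1<float> a = RandGaussianArray1<float>(c, 10 /*dim*/, 0.0f /*mean*/,
                                              1.0f /*std*/, 42 /*seed*/);
  Array1<float> b = RandGaussianArray1<float>(c, 10, 0.0f, 1.0f, 42);
  K2_CHECK(Equal(a, b));  // identical, since both used seed 42
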
 /*
   Return a newly allocated Array1 whose values form a linear sequence,
@@ -403,7 +444,6 @@ Array1<T> Arange(ContextPtr c, T begin, T end, T inc = 1);
 void RowSplitsToRowIds(const Array1<int32_t> &row_splits,
                        Array1<int32_t> *row_ids);
 
-
 /*
   This function is like RowSplitsToRowIds() but after subtracting the first
   element of `row_splits_part` from `row_splits_part`.
@@ -418,8 +458,6 @@ void RowSplitsToRowIds(const Array1<int32_t> &row_splits,
 void RowSplitsToRowIdsOffset(const Array1<int32_t> &row_splits_part,
                              Array1<int32_t> *row_ids_part);
 
-
-
 /*
   Given a vector of row_splits, return a vector of sizes.
 
@@ -660,8 +698,6 @@ Array2<T> ToContiguous(const Array2<T> &src);
 template <typename T>
 bool Equal(const Array2<T> &a, const Array2<T> &b);
 
-
-
 /*
   Return true if all elements of the two arrays are equal.
   Will crash if the sizes differ.
@@ -669,7 +705,6 @@ bool Equal(const Array2<T> &a, const Array2<T> &b);
 template <typename T>
 bool ApproxEqual(const Array2<T> &a, const Array2<T> &b, T tol = T(0.0001));
 
-
 /*
   Index `src` with `indexes`, as in src[indexes].
     @param [in] src   Array whose elements are to be read
diff --git a/k2/csrc/array_ops_inl.h b/k2/csrc/array_ops_inl.h
index d24e9a58c..a213bd15a 100644
--- a/k2/csrc/array_ops_inl.h
+++ b/k2/csrc/array_ops_inl.h
@@ -193,6 +193,19 @@ static void RandArray1Internal(int32_t dim, T min_value, T max_value, T *data,
   for (int32_t i = 0; i < dim; ++i) data[i] = dis(gen);
 }
 
+// called in RandGaussianArray1 & RandGaussianArray2
+template <typename T,
+          typename std::enable_if<std::is_floating_point<T>::value,
+                                  T>::type * = nullptr>
+static void GaussianArray1Internal(int32_t dim, T mean, T std, T *data,
+                                   int32_t seed = 0) {
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  if (seed != 0) gen = std::mt19937(seed);
+  std::normal_distribution<T> dis(mean, std);
+  for (int32_t i = 0; i < dim; ++i) data[i] = dis(gen);
+}
+
+
 }  // namespace internal
 }  // namespace k2
 
@@ -442,17 +455,41 @@ Array1<T> RandUniformArray1(ContextPtr c, int32_t dim, T min_value, T max_value,
 
 template <typename T>
 Array2<T> RandUniformArray2(ContextPtr c, int32_t dim0, int32_t dim1,
-                            T min_value, T max_value) {
+                            T min_value, T max_value, int32_t seed /*= 0*/) {
   int32_t dim1_extra = RandInt(0, 2),  // make it randomly not contiguous.
           new_dim1 = dim1 + dim1_extra;
   Array1<T> array1temp =
-      RandUniformArray1(c, dim0 * new_dim1, min_value, max_value);
+      RandUniformArray1(c, dim0 * new_dim1, min_value, max_value, seed);
   Array2<T> array2temp(array1temp, dim0, new_dim1);
 
   int32_t offset = RandInt(0, dim1_extra);
   return array2temp.ColArange(offset, offset + dim1);
 }
 
+template <typename T>
+Array1<T> RandGaussianArray1(ContextPtr c, int32_t dim, T mean, T std,
+                             int32_t seed /*= 0*/) {
+  static_assert(std::is_floating_point<T>::value,
+                "Only support floating-point type");
+  Array1<T> temp(GetCpuContext(), dim);
+  T *data = temp.Data();
+  internal::GaussianArray1Internal(dim, mean, std, data, seed);
+  return temp.To(c);
+}
+
+template <typename T>
+Array2<T> RandGaussianArray2(ContextPtr c, int32_t dim0, int32_t dim1,
+                             T mean, T std, int32_t seed /*= 0*/) {
+  static_assert(std::is_floating_point<T>::value,
+                "Only support floating-point type");
+  Array2<T> temp(GetCpuContext(), dim0, dim1);
+  for (int32_t i = 0; i < dim0; ++i) {
+    internal::GaussianArray1Internal(
+        dim1, mean, std, temp.Row(i).Data(), seed);
+  }
+  return temp.To(c);
+}
+
 template <typename T>
 Array1<T> Range(ContextPtr c, int32_t dim, T first_value, T inc /*=1*/) {
   NVTX_RANGE(K2_FUNC);
diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu
index c7106bc0c..45736d639 100644
--- a/k2/csrc/fsa_algo.cu
+++ b/k2/csrc/fsa_algo.cu
@@ -816,6 +816,41 @@ Fsa CtcTopo(const ContextPtr &c, int32_t max_token, bool modified,
   }
 }
 
+Fsa TrivialGraph(const ContextPtr &c, int32_t max_token,
+                 Array1<int32_t> *aux_labels) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(aux_labels);
+  int32_t num_arcs = max_token + 2;
+  Array1<int32_t> row_splits(
+      c, std::vector<int32_t>{0, max_token + 2, max_token + 2});
+  Array1<int32_t> row_ids(c, num_arcs);
+  Array1<Arc> values(c, num_arcs);
+  *aux_labels = Array1<int32_t>(c, num_arcs);
+  int32_t *row_ids_data = row_ids.Data(),
+          *aux_labels_data = aux_labels->Data();
+  Arc *values_data = values.Data();
+
+  K2_EVAL(
+      c, num_arcs, lambda, (int32_t idx)->void {
+        Arc arc;
+        arc.score = 0;
+        arc.src_state = 0;
+        arc.dest_state = 0;
+        arc.label = idx;
+        int32_t aux_label = idx, row_id = 0;
+        if (idx == num_arcs - 1) {
+          row_id = 0;
+          arc.dest_state = 1;
+          arc.label = -1;
+          aux_label = -1;
+        }
+        row_ids_data[idx] = row_id;
+        values_data[idx] = arc;
+        aux_labels_data[idx] = aux_label;
+      });
+  return Ragged<Arc>(RaggedShape2(&row_splits, &row_ids, num_arcs), values);
+}
+
 void ArcSort(Fsa *fsa) {
   if (fsa->NumAxes() < 2) return;  // it is empty
   SortSublists(fsa);
diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h
index 2e14e1c77..8f8779b0f 100644
--- a/k2/csrc/fsa_algo.h
+++ b/k2/csrc/fsa_algo.h
@@ -579,6 +579,26 @@ FsaVec LevenshteinGraphs(const Ragged<int32_t> &symbols,
 Fsa CtcTopo(const ContextPtr &c, int32_t max_token, bool modified,
             Array1<int32_t> *aux_labels);
 
+/*
+  Create a trivial graph which has only two states: on state 0 there are
+  `max_token + 1` self-loops (i.e. a loop for each symbol, including blank),
+  and state 1 is the final state.
+
+    @param [in] c          The context with which we'll allocate memory for
+                           the trivial graph.
+    @param [in] max_token  The maximum token ID (inclusive). We assume that
+                           token IDs are contiguous (from 1 to `max_token`).
+                           0 represents blank.
+    @param [out] aux_labels  The output labels of the graph will be written
+                           to this array, which will be reallocated. The label
+                           and aux_label on each arc are equal
+                           (i.e. aux_labels = Arange(0, max_token + 1)).
+
+    @return Returns the expected trivial graph on the given device.
+ */
+Fsa TrivialGraph(const ContextPtr &c, int32_t max_token,
+                 Array1<int32_t> *aux_labels);
+
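
For concreteness, a small sketch of the declaration above (variable names illustrative); the expected arcs follow directly from the kernel in fsa_algo.cu:

  Array1<int32_t> aux_labels;
  Fsa graph = TrivialGraph(GetCpuContext(), 2 /*max_token*/, &aux_labels);
  // graph: state 0 has self-loops labelled 0, 1 and 2, plus one arc
  // 0 -> 1 labelled -1 (state 1 is final); all arc scores are 0.
  // aux_labels == [0, 1, 2, -1]
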
 /*
   Compute the forward shortest path in the tropical semiring.
 
     @param [in] fsas  Input FsaVec (must have 3 axes).  Must be
diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu
index 808ad41e2..78bfed8ab 100644
--- a/k2/csrc/ragged_ops.cu
+++ b/k2/csrc/ragged_ops.cu
@@ -1317,14 +1317,16 @@ static void SelectAxis0(RaggedShape &src, const Ragged<int32_t> &indexes,
   }
 }
 
-void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
+void Unstack(RaggedShape &src, int32_t axis, bool pad_right,
+             std::vector<RaggedShape> *out,
              std::vector<std::vector<int32_t>> *split_map) {
+  NVTX_RANGE(K2_FUNC);
   ContextPtr &c = src.Context();
   if (axis == 0) {
     if (src.NumAxes() == 2) {
       auto new_src = ComposeRaggedShapes(
           TrivialShape(c, src.TotSize(0)), src);
-      return Unstack(new_src, 1, out, split_map);
+      return Unstack(new_src, 1, pad_right, out, split_map);
     }
     auto indexes = Ragged<int32_t>(RegularRaggedShape(c, src.Dim0(), 1),
                                    Arange(c, 0, src.Dim0()));
@@ -1339,18 +1341,11 @@ void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
     const int32_t *row_splits_axis = src.RowSplits(axis).Data(),
                   *row_ids_axis = src.RowIds(axis).Data();
 
-    // Get the number of elements of current axis on each sublist
-    Array1<int32_t> sublists_size(c, tot_size_axis_minus1);
-    int32_t *sublists_size_data = sublists_size.Data();
-    K2_EVAL(c, tot_size_axis_minus1, lambda_get_sublists_size, (int32_t i) {
-      sublists_size_data[i] = row_splits_axis[i + 1] - row_splits_axis[i];
-    });
-
     // Each sublist contains the elements of axis `axis`, unstack operation will
     // split all these elements in a sublist to different RaggedShapes, so the
     // number of output RaggedShapes is the size of the sublist with max
     // elements.
-    int32_t num_out = MaxValue(sublists_size);
+    int32_t num_out = src.MaxSize(axis);
 
     out->resize(num_out);
     if (split_map != nullptr) split_map->resize(num_out);
@@ -1366,8 +1361,14 @@ void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
     K2_EVAL(c, tot_size_axis, lambda_set_indexes, (int32_t idx01) {
       int32_t idx0 = row_ids_axis[idx01],
               idx0x = row_splits_axis[idx0],
-              idx1 = idx01 - idx0x;
-      indexes_data[idx1 * tot_size_axis_minus1 + idx0] = idx01;
+              idx1 = idx01 - idx0x,
+              idx_row = idx1;
+      if (!pad_right) {
+        int32_t idx0x_next = row_splits_axis[idx0 + 1],
+                num_elems = idx0x_next - idx0x;
+        idx_row = num_out - num_elems + idx1;
+      }
+      indexes_data[idx_row * tot_size_axis_minus1 + idx0] = idx01;
     });
 
     // To make `DecomposeRaggedShape` work, we add a RegularRaggedShape
@@ -1412,6 +1413,11 @@ void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
   }
 }
 
+void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
+             std::vector<std::vector<int32_t>> *split_map /*= nullptr*/) {
+  Unstack(src, axis, true /*pad_right*/, out, split_map);
+}
+
 RaggedShape Merge(int32_t num_srcs, RaggedShape **src,
                   const Array1<uint32_t> &merge_map,
                   Array1<uint32_t> *merge_map_out) {
diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h
index 83c2ca238..85268d287 100644
--- a/k2/csrc/ragged_ops.h
+++ b/k2/csrc/ragged_ops.h
@@ -200,6 +200,12 @@ RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src,
                      be rearranged into output RaggedShapes.
    @param [out] out  The container where the output RaggedShapes would write
                      to. MUST NOT be a nullptr, will be reallocated.
+   @param [in] pad_right  Before unstacking, we (conceptually) pad the
+                     sublists along axis `axis` to the same size with empty
+                     lists. `pad_right` tells where to put the padding empty
+                     lists; see the example below for more details.
+                     Note: `pad_right` makes no difference when `axis == 0`
+                     or `axis == src.NumAxes() - 1`.
    @param [out] split_map  If not nullptr will store the element-index within
                      `src` telling where the elements of each split
                      RaggedShape come from. It has the same size as `out`,
                      see notes below.
@@ -221,7 +227,7 @@ RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src,
          only one sublist along `axis == 0` (i.e. the src itself), so the
          number of output RaggedShape will be equal to `src.Dim0()`.
 
-  A small example of unstacking a 3-axis RaggedShape:
+  A small example of unstacking a 3-axis RaggedShape (with pad_right=true):
 
     src: [ [ [ x x ] [ x ] ]   [ [ x ] ] ]
 
    unstack on axis 0:
@@ -233,6 +239,9 @@
    unstack on axis 1:
      two sublists along axis 1, the sizes are [2, 1], will produce
      2 RaggedShape
+     imagine that we first pad src to [ [ [ x x ] [ x ] ]  [ [ x ] [ ] ] ],
+     then select elements along axis 1 into separate ragged shapes
+
      out[0] : [ [ x x ] [ x ] ]    split_map[0] : [0, 1, 3]
      out[1] : [ [ x ] [ ] ]        split_map[1] : [2]
 
@@ -242,6 +251,25 @@
    unstack on axis 2:
      out[0] : [ [ x x ] [ x ] ]    split_map[0] : [0, 2, 3]
      out[1] : [ [ x ] [ ] ]        split_map[1] : [1]
+
+  for pad_right == false:
+
+    src: [ [ [ x x ] [ x ] ]   [ [ x ] ] ]
+
+   unstack on axis 1:
+
+     imagine that we first pad src to [ [ [ x x ] [ x ] ]  [ [ ] [ x ] ] ],
+     then select elements along axis 1 into separate ragged shapes
+
+     out[0] : [ [ x x ] [ ] ]      split_map[0] : [0, 1]
+     out[1] : [ [ x ] [ x ] ]      split_map[1] : [2, 3]
+ */
+void Unstack(RaggedShape &src, int32_t axis, bool pad_right,
+             std::vector<RaggedShape> *out,
+             std::vector<std::vector<int32_t>> *split_map = nullptr);
+
+/*
+ * The same as above, except that it uses `pad_right=true`.
  */
 void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
              std::vector<std::vector<int32_t>> *split_map = nullptr);
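
To see `pad_right` end to end, a small sketch that mirrors the example above (the shape literal reuses the doc's `x` notation; the outputs are those listed in the doc):

  RaggedShape src("[ [ [ x x ] [ x ] ] [ [ x ] ] ]");
  std::vector<RaggedShape> out;
  std::vector<std::vector<int32_t>> split_map;
  Unstack(src, 1 /*axis*/, true /*pad_right*/, &out, &split_map);
  // out[0] == [ [ x x ] [ x ] ],  out[1] == [ [ x ] [ ] ]
  Unstack(src, 1 /*axis*/, false /*pad_right*/, &out, &split_map);
  // out[0] == [ [ x x ] [ ] ],   out[1] == [ [ x ] [ x ] ]
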
@@ -1003,6 +1031,12 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
    @param [in] src   The ragged tensor to be unstacked.
    @param [in] axis  The axis to be removed, all the elements of this axis will
                      be rearranged into output Raggeds.
+   @param [in] pad_right  Before unstacking, we (conceptually) pad the
+                     sublists along axis `axis` to the same size with empty
+                     lists. `pad_right` tells where to put the padding empty
+                     lists; see the example below for more details.
+                     Note: `pad_right` makes no difference when `axis == 0`
+                     or `axis == src.NumAxes() - 1`.
    @param [out] out  The container where the output ragged tensors would write
                      to. MUST NOT be a nullptr, will be reallocated.
    @param [out] split_map  If not nullptr will store the element-index within
@@ -1027,7 +1061,7 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
          only one sublist along `axis == 0` (i.e. the src itself), so the
          number of output ragged will be equal to `src.Dim0()`.
 
-  A small example of unstacking a 3-axis Ragged:
+  A small example of unstacking a 3-axis Ragged (with pad_right = true):
 
     src: [ [ [ 1 2 ] [ 3 ] ]   [ [ 4 ] ] ]
 
    unstack on axis 0:
@@ -1039,6 +1073,9 @@
    unstack on axis 1:
      two sublists along axis 1, the sizes are [2, 1], will produce
      2 ragged tensors
+     imagine that we first pad src to [ [ [ 1 2 ] [ 3 ] ]  [ [ 4 ] [ ] ] ],
+     then select elements along axis 1 into separate raggeds
+
      out[0] : [ [ 1 2 ] [ 4 ] ]    split_map[0] : [0, 1, 3]
      out[1] : [ [ 3 ] [ ] ]        split_map[1] : [2]
 
@@ -1048,11 +1085,32 @@
    unstack on axis 2:
      out[0] : [ [ 1 3 ] [ 4 ] ]    split_map[0] : [0, 2, 3]
      out[1] : [ [ 2 ] [ ] ]        split_map[1] : [1]
+
+  for pad_right == false:
+
+    src: [ [ [ 1 2 ] [ 3 ] ]   [ [ 4 ] ] ]
+
+   unstack on axis 1:
+
+     imagine that we first pad src to [ [ [ 1 2 ] [ 3 ] ]  [ [ ] [ 4 ] ] ],
+     then select elements along axis 1 into separate raggeds
+
+     out[0] : [ [ 1 2 ] [ ] ]      split_map[0] : [0, 1]
+     out[1] : [ [ 3 ] [ 4 ] ]      split_map[1] : [2, 3]
  */
+template <typename T>
+void Unstack(Ragged<T> src, int32_t axis, bool pad_right,
+             std::vector<Ragged<T>> *out,
+             std::vector<std::vector<int32_t>> *split_map = nullptr);
+
+/*
+ * The same as above, except that it uses `pad_right=true`.
+ */
 template <typename T>
 void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
-             std::vector<std::vector<int32_t>> *split_map = nullptr);
+             std::vector<std::vector<int32_t>> *split_map = nullptr) {
+  Unstack(src, axis, true /*pad_right*/, out, split_map);
+}
 
 /*
   Concatenate a list of Ragged to form a single Ragged.
diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h
index 47297fab4..03b91eb12 100644
--- a/k2/csrc/ragged_ops_inl.h
+++ b/k2/csrc/ragged_ops_inl.h
@@ -170,7 +170,8 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
 }
 
 template <typename T>
-void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
+void Unstack(Ragged<T> src, int32_t axis, bool pad_right,
+             std::vector<Ragged<T>> *out,
              std::vector<std::vector<int32_t>> *split_map /* = nullptr */) {
   NVTX_RANGE(K2_FUNC);
   K2_CHECK(out != nullptr);
@@ -180,7 +181,7 @@ void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
       (split_map != nullptr ? split_map : &split_map_tmp);
 
   std::vector<RaggedShape> shape_out;
-  Unstack(src.shape, axis, &shape_out, split_map_ptr);
+  Unstack(src.shape, axis, pad_right, &shape_out, split_map_ptr);
 
   out->resize(shape_out.size());
 
   // +1 here because we need to do ExclusiveSum on this Array1 later
@@ -815,89 +816,46 @@ Array2<T> PadRagged(Ragged<T> &src, const std::string &mode, T padding_value) {
   return res;
 }
 
-/* Prune a two axes ragged tensor on axis0.
- * This is a special case of PruneRagged with axis == 0 and src.NumAxes() == 2,
- * To get more details, please refer to the docs for PruneRagged in
- * ragged_ops.h.
- */
-template <typename T>
-Renumbering PruneRaggedAxis0(Ragged<T> &src, T beam, int32_t max_elems) {
-  K2_CHECK_EQ(src.NumAxes(), 2);
-  const ContextPtr &c = src.Context();
-  int32_t total_elements = src.TotSize(0);
-  Renumbering renumbering(c, total_elements);
-
-  T negative_infinity = -std::numeric_limits<T>::infinity();
-  Array1<T> sub_max(c, total_elements);
-  MaxPerSublist(src, negative_infinity, &sub_max);
-
-  T max_value = MaxValue(src.values);
-
-  bool prune_with_max_elems =
-      max_elems > 0 && max_elems < total_elements;
-
-  Array1<int32_t> order_map;
-  const int32_t *order_map_data;
-  if (prune_with_max_elems) {
-    order_map = Array1<int32_t>(c, total_elements);
-    Sort<T, GreaterThan<T>>(&sub_max, &order_map);
-    order_map_data = order_map.Data();
-  }
-
-  char *keep_data = renumbering.Keep().Data();
-  const T *sub_max_data = sub_max.Data();
-
-  // prune_with_max_elems means we have sorted the source ragged tensor
-  if (prune_with_max_elems) {
-    K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t i) {
-      bool pruned_by_beam = sub_max_data[i] < max_value - beam;
-      bool pruned_by_max_elems = i >= max_elems;
-      keep_data[order_map_data[i]] =
-          !(pruned_by_max_elems || pruned_by_beam);
-    });
-  } else {
-    K2_EVAL(c, total_elements, lambda_set_keep, (int32_t i) {
-      keep_data[i] = sub_max_data[i] >= max_value - beam;
-    });
-  }
-  return renumbering;
-}
-
-/* Prune a two axes ragged tensor on axis1
- * This is a special case of PruneRagged with axis == 1 and src.NumAxes() == 2,
+/* Prune a three-axis ragged tensor on axis 1.
+ * This is a special case of PruneRagged with axis == 1 and src.NumAxes() == 3,
  * To get more details, please refer to the docs for PruneRagged in
  * ragged_ops.h.
  */
 template <typename T>
 Renumbering PruneRaggedAxis1(Ragged<T> &src, T beam,
                              int32_t max_elems) {
-  K2_CHECK_EQ(src.NumAxes(), 2);
-  const ContextPtr &c = src.Context();
+  K2_CHECK_EQ(src.NumAxes(), 3);
+  ContextPtr &c = src.Context();
   int32_t total_elements = src.TotSize(1);
   Renumbering renumbering(c, total_elements);
 
   T negative_infinity = -std::numeric_limits<T>::infinity();
-  Array1<T> sub_max(c, src.TotSize(0));
+  Array1<T> sub_max(c, src.TotSize(1));
   MaxPerSublist(src, negative_infinity, &sub_max);
 
+  Array1<T> best_scores(c, src.TotSize(0));
+  Ragged<T> ragged_sub_max(RemoveAxis(src.shape, 2), sub_max);
+  MaxPerSublist(ragged_sub_max, negative_infinity, &best_scores);
+
   bool prune_with_max_elems =
       max_elems > 0 && max_elems < total_elements;
 
   Array1<int32_t> order_map;
   const int32_t *order_map_data;
   if (prune_with_max_elems) {
-    Ragged<T> sorted_src = src.Clone();
+    Ragged<T> sorted_sub_max = ragged_sub_max.Clone();
     order_map = Array1<int32_t>(c, total_elements);
-    SortSublists<T, GreaterThan<T>>(&sorted_src, &order_map);
+    SortSublists<T, GreaterThan<T>>(&sorted_sub_max, &order_map);
     order_map_data = order_map.Data();
   }
 
   char *keep_data = renumbering.Keep().Data();
   const T *sub_max_data = sub_max.Data(),
+          *best_scores_data = best_scores.Data(),
          *src_data = src.values.Data();
   const int32_t *row_ids1_data = src.RowIds(1).Data(),
                 *row_splits1_data = src.RowSplits(1).Data();
-  // prune_with_max_elems means we have sorted the source ragged tensor
+  // prune_with_max_elems means we have sorted the sub-max ragged tensor
   if (prune_with_max_elems) {
     K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t idx01) {
       // idx01 is the index after sorting
@@ -909,14 +867,14 @@ Renumbering PruneRaggedAxis1(Ragged<T> &src, T beam,
              idx1 = idx01 - idx0x;
      bool pruned_by_max_elems = idx1 >= max_elems,
           pruned_by_beam =
-              src_data[original_idx01] < sub_max_data[idx0] - beam;
+              sub_max_data[original_idx01] < best_scores_data[idx0] - beam;
      keep_data[original_idx01] =
          !(pruned_by_max_elems || pruned_by_beam);
    });
  } else {
    K2_EVAL(c, total_elements, lambda_set_keep, (int32_t idx01) {
      int32_t idx0 = row_ids1_data[idx01];
-      keep_data[idx01] = src_data[idx01] >= sub_max_data[idx0] - beam;
+      keep_data[idx01] = sub_max_data[idx01] >= best_scores_data[idx0] - beam;
    });
  }
  return renumbering;
@@ -926,23 +884,32 @@
 template <typename T>
 Renumbering PruneRagged(Ragged<T> &src, int32_t axis, T beam,
                         int32_t max_elems) {
   NVTX_RANGE(K2_FUNC);
+  ContextPtr &c = src.Context();
   if (axis == 0) {
-    auto reduced_src = src;
-    while (reduced_src.NumAxes() > 2) {
-      reduced_src = RemoveAxis(reduced_src, reduced_src.NumAxes() - 2);
+    auto new_shape = ComposeRaggedShapes(
+        TrivialShape(c, src.TotSize(0)), src.shape);
+    auto new_src = Ragged<T>(new_shape, src.values);
+    while (new_src.NumAxes() > 3) {
+      new_src = RemoveAxis(new_src, new_src.NumAxes() - 2);
     }
-    return PruneRaggedAxis0(reduced_src, beam, max_elems);
+    return PruneRaggedAxis1(new_src, beam, max_elems);
   } else if (axis == src.NumAxes() - 1) {
-    auto reduced_src = src;
-    while (reduced_src.NumAxes() > 2) {
-      reduced_src = RemoveAxis(reduced_src, 0);
+    auto new_shape = ComposeRaggedShapes(src.shape,
+        RegularRaggedShape(c, src.NumElements(), 1));
+    auto new_src = Ragged<T>(new_shape, src.values);
+    while (new_src.NumAxes() > 3) {
+      new_src = RemoveAxis(new_src, 0);
     }
-    return PruneRaggedAxis1(reduced_src, beam, max_elems);
+    return PruneRaggedAxis1(new_src, beam, max_elems);
   } else {
-    RaggedShape top, bottom;
-    DecomposeRaggedShape(src.shape, axis, &top, &bottom);
-    Ragged<T> bottom_ragged(bottom, src.values);
-    return PruneRagged(bottom_ragged, 0, beam, max_elems);
+    auto new_src = src;
+    while (--axis) {
+      new_src = RemoveAxis(new_src, 0);
+    }
+    while (new_src.NumAxes() > 3) {
+      new_src = RemoveAxis(new_src, new_src.NumAxes() - 2);
+    }
+    return PruneRaggedAxis1(new_src, beam, max_elems);
   }
 }
diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu
index 345def9a2..e2cae8907 100644
--- a/k2/csrc/ragged_test.cu
+++ b/k2/csrc/ragged_test.cu
@@ -235,13 +235,15 @@ TEST(RaggedShapeOpsTest, UnstackMoreAxes) {
     std::vector<RaggedShape *> out_ptr;
 
     for (int32_t axis = 0; axis < 4; axis++) {
-      Unstack(shape, axis, &out, &out_map);
+      for (bool pad_right : {true, false}) {
+        Unstack(shape, axis, pad_right, &out, &out_map);
 
-      out_ptr.clear();
-      for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
-      auto dest = Stack(axis, out.size(), out_ptr.data());
-      dest = RemoveEmptyLists(dest, axis);
-      K2_CHECK(Equal(dest, RemoveEmptyLists(shape, axis)));
+        out_ptr.clear();
+        for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i]));
+        auto dest = Stack(axis, out.size(), out_ptr.data());
+        dest = RemoveEmptyLists(dest, axis);
+        K2_CHECK(Equal(dest, RemoveEmptyLists(shape, axis)));
+      }
     }
   }
 }
@@ -258,18 +260,18 @@ TEST(RaggedShapeOpsTest, UnstackRandom) {
     std::vector<RaggedShape *> out_ptr;
     for (int32_t axis = 0; axis < 4; axis++) {
       auto random_shape = RemoveEmptyLists(random_shape0, axis);
+      for (bool pad_right : {true, false}) {
+        Unstack(random_shape, axis, pad_right, &out, nullptr);
 
-      Unstack(random_shape, axis, &out, nullptr);
+        out_ptr.clear();
+        for (size_t i = 0; i < out.size(); ++i) {
+          out_ptr.emplace_back(&(out[i]));
+        }
+        auto dest = Stack(axis, out.size(), out_ptr.data());
+        dest = RemoveEmptyLists(dest, axis);
 
-      out_ptr.clear();
-      for (size_t i = 0; i < out.size(); ++i) {
-        out_ptr.emplace_back(&(out[i]));
+        K2_CHECK(Equal(dest, random_shape));
       }
-
-      auto dest = Stack(axis, out.size(), out_ptr.data());
-      dest = RemoveEmptyLists(dest, axis);
-
-      K2_CHECK(Equal(dest, random_shape));
     }
   }
 }
@@ -3032,7 +3034,7 @@ static void TestPruneRagged() {
 
   T beam = 2.0;
   auto renumbering = PruneRagged(src, 0, beam, 2);
-  // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0]
+  // best_score=6.3, max scores for sublists are [6.1, 6.3, 5.0]
   // no sublist is pruned by beam, 5.0 is pruned by max-elems
   // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ]
   //         [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ]
@@ -3041,29 +3043,29 @@ static void TestPruneRagged() {
 
   beam = 0.1;
   renumbering = PruneRagged(src, 0, beam, 3);
-  // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0]
+  // best_score=6.3, max scores for sublists are [6.1, 6.3, 5.0]
   // 6.1 & 5.0 are pruned by beam
   // keep : [ [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ]
   keep_ref = Array1<char>(c, std::vector<char>{0, 1, 0});
   K2_CHECK(Equal(renumbering.Keep(), keep_ref));
 
   beam = 2.0;
-  renumbering = PruneRagged(src, 1, beam, 5);
-  // best_score=6.3, best scores for sublists are
-  // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0]
-  // 1.2 & -inf are pruned by beam, 4.4 is pruned by max-elems.
-  // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 2.2 6.3 ] ]
-  //         [ [ 2.3 5.0 ] ] ]
-  keep_ref = Array1<char>(c, std::vector<char>{1, 1, 1, 0, 1, 0, 0, 1});
+  renumbering = PruneRagged(src, 1, beam, 2);
+  // best_score=[6.1, 6.3, 5.0], max scores for sublists are
+  // [[5.2, 5.1, 6.1], [1.2, 6.3, -inf], [4.4, 5.0]]
+  // 1.2 & -inf are pruned by beam, 5.1 is pruned by max-elems.
+  // keep : [ [ [ 1.1 2.1 5.2 ] [ 6.1 ] ] [ [ 2.2 6.3 ] ]
+  //         [ [1.3 4.4] [ 2.3 5.0 ] ] ]
+  keep_ref = Array1<char>(c, std::vector<char>{1, 0, 1, 0, 1, 0, 1, 1});
   K2_CHECK(Equal(renumbering.Keep(), keep_ref));
 
-  beam = 1.0;
-  renumbering = PruneRagged(src, 1, beam, 5);
-  // best_score=6.3, best scores for sublists are
-  // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0]
-  // all sublists are pruned by beam, except 6.1 & 6.3
-  // keep : [ [ [ 6.1 ] ] [ [ 2.2 6.3 ] ] ]
-  keep_ref = Array1<char>(c, std::vector<char>{0, 0, 1, 0, 1, 0, 0, 0});
+  beam = 0.5;
+  renumbering = PruneRagged(src, 1, beam, 2);
+  // best_score=[6.1, 6.3, 5.0], max scores for sublists are
+  // [[5.2, 5.1, 6.1], [1.2, 6.3, -inf], [4.4, 5.0]]
+  // all sublists are pruned by beam, except 6.1 & 6.3 & 5.0
+  // keep : [ [ [ 6.1 ] ] [ [ 2.2 6.3 ] ] [ [ 2.3 5.0 ] ] ]
+  keep_ref = Array1<char>(c, std::vector<char>{0, 0, 1, 0, 1, 0, 0, 1});
   K2_CHECK(Equal(renumbering.Keep(), keep_ref));
 
   beam = 4.0;
@@ -3111,15 +3113,19 @@ static void TestPruneRaggedAndSubsetRagged() {
   K2_CHECK(Equal(dest, dest_ref));
   K2_CHECK(Equal(new2old, new2old_ref));
 
-  beam = 2.0;
-  renumbering = PruneRagged(src, 1, beam, 5);
+  beam = 1.0;
+  renumbering = PruneRagged(src, 1, beam, 2);
   dest = SubsetRagged(src, renumbering, 1, &new2old);
+  // [5.0, 6.3, 5.2]
+  // [ [4.2 5.0] [1.2 6.3 6.1 5.1] [4.4 5.2]]
   dest_ref = Ragged<T>(c,
-      "[ [ [ 5.0 3.1 ] ] [ [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] "
-      " [ [ 1.4 0.8 2.3 5.2 3.6 ] ] ]");
+      "[ [ [ 1.1 4.2 2.1 1.8] [ 5.0 3.1 ] ]"
+      "  [ [ 2.2 6.3 ] [ 2.4 6.1 ] ] "
+      "  [ [ 1.3 4.4 ] [ 1.4 0.8 2.3 5.2 3.6 ] ] ]");
   new2old_ref = Array1<int32_t>(
-      c, std::vector<int32_t>{4, 5, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18});
+      c, std::vector<int32_t>{
+          0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18});
   K2_CHECK(Equal(dest, dest_ref));
   K2_CHECK(Equal(new2old, new2old_ref));
 
diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu
new file mode 100644
index 000000000..5fc1780f3
--- /dev/null
+++ b/k2/csrc/rnnt_decode.cu
@@ -0,0 +1,817 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/fsa.h"
+#include "k2/csrc/macros.h"
+#include "k2/csrc/ragged_ops.h"
+#include "k2/csrc/rnnt_decode.h"
+
+namespace k2 {
+namespace rnnt_decoding {
+
+std::shared_ptr<RnntDecodingStream> CreateStream(
+    const std::shared_ptr<Fsa> &graph) {
+  K2_CHECK_EQ(graph->shape.NumAxes(), 2);
+  ContextPtr &c = graph->shape.Context();
+  RnntDecodingStream stream;
+  stream.graph = graph;
+  stream.num_graph_states = graph->shape.Dim0();
+  // initialize to start state
+  stream.states = Ragged<int64_t>(RegularRaggedShape(c, 1, 1),
+                                  Array1<int64_t>(c, std::vector<int64_t>{0}));
+  stream.scores = Ragged<double>(stream.states.shape,
+                                 Array1<double>(c, std::vector<double>{0.0}));
+  return std::make_shared<RnntDecodingStream>(stream);
+}
+
+RnntDecodingStreams::RnntDecodingStreams(
+    std::vector<std::shared_ptr<RnntDecodingStream>> &srcs,
+    const RnntDecodingConfig &config)
+    : attached_(true), num_streams_(srcs.size()), srcs_(srcs), config_(config) {
+  K2_CHECK_GE(num_streams_, 1);
+  c_ = srcs_[0]->graph->shape.Context();
+
+  Array1<int32_t> num_graph_states(GetCpuContext(), num_streams_);
+
+  int32_t *num_graph_states_data = num_graph_states.Data();
+
+  std::vector<Ragged<int64_t> *> states_ptr(num_streams_);
+  std::vector<Ragged<double> *> scores_ptr(num_streams_);
+  std::vector<Fsa> graphs(num_streams_);
+
+  for (int32_t i = 0; i < num_streams_; ++i) {
+    K2_CHECK(c_->IsCompatible(*(srcs_[i]->graph->shape.Context())));
+    num_graph_states_data[i] = srcs_[i]->num_graph_states;
+    states_ptr[i] = &(srcs_[i]->states);
+    scores_ptr[i] = &(srcs_[i]->scores);
+    graphs[i] = *(srcs_[i]->graph);
+  }
+
+  num_graph_states_ = num_graph_states.To(c_);
+  states_ = Stack(0, num_streams_, states_ptr.data());
+  scores_ = Stack(0, num_streams_, scores_ptr.data());
+  graphs_ = Array1OfRagged<Arc>(graphs.data(), num_streams_);
+
+  // We don't combine prev_frames_ here, will do that when needed, for example
+  // when we need all prev_frames_ to format output fsas.
+}
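
A small sketch of how these pieces are meant to be combined; `graph1`, `graph2` stand for `std::shared_ptr<Fsa>` decoding graphs on the same device, and the config values are illustrative only:

  RnntDecodingConfig config(500 /*vocab_size*/, 2 /*decoder_history_len*/,
                            8.0 /*beam*/, 32 /*max_states*/, 8 /*max_contexts*/);
  // One stream per utterance, batched into a single RnntDecodingStreams
  // object for decoding.
  std::vector<std::shared_ptr<RnntDecodingStream>> srcs = {
      CreateStream(graph1), CreateStream(graph2)};
  RnntDecodingStreams streams(srcs, config);
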
+void RnntDecodingStreams::TerminateAndFlushToStreams() {
+  NVTX_RANGE(K2_FUNC);
+  // return directly if already detached or no frames decoded.
+  if (!attached_ || prev_frames_.empty()) return;
+  std::vector<Ragged<int64_t>> states;
+  std::vector<Ragged<double>> scores;
+  Unstack(states_, 0, &states);
+  Unstack(scores_, 0, &scores);
+  K2_CHECK_EQ(static_cast<int32_t>(states.size()), num_streams_);
+  K2_CHECK_EQ(static_cast<int32_t>(scores.size()), num_streams_);
+
+  // detach prev_frames_
+  std::vector<Ragged<ArcInfo> *> frames_ptr;
+  for (size_t i = 0; i < prev_frames_.size(); ++i) {
+    frames_ptr.emplace_back(prev_frames_[i].get());
+  }
+  // stack_frames has a shape of [t][stream][state][arc]
+  auto stack_frames = Stack(0, prev_frames_.size(), frames_ptr.data());
+  // stack_frames now has a shape of [stream][state][arc];
+  // its Dim0() equals `num_streams_ * prev_frames_.size()`
+  stack_frames = stack_frames.RemoveAxis(0);
+
+  std::vector<Ragged<ArcInfo>> frames;
+  Unstack(stack_frames, 0, &frames);
+
+  K2_CHECK_EQ(num_streams_ * prev_frames_.size(), frames.size());
+
+  for (int32_t i = 0; i < num_streams_; ++i) {
+    for (size_t j = 0; j < prev_frames_.size(); ++j) {
+      srcs_[i]->prev_frames.emplace_back(
+          std::make_shared<Ragged<ArcInfo>>(frames[j * num_streams_ + i]));
+    }
+    srcs_[i]->states = states[i];
+    srcs_[i]->scores = scores[i];
+  }
+
+  attached_ = false;
+  prev_frames_.clear();
+}
+
+void RnntDecodingStreams::GetContexts(RaggedShape *shape,
+                                      Array2<int32_t> *contexts) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(shape);
+  K2_CHECK(contexts);
+  K2_CHECK_EQ(states_.NumAxes(), 3);
+
+  // shape has a shape of [stream][context]
+  *shape = RemoveAxis(states_.shape, 2);
+  int32_t num_contexts = shape->TotSize(1),
+          decoder_history_len = config_.decoder_history_len,
+          vocab_size = config_.vocab_size;
+
+  // contexts has a shape of [num_contexts][decoder_history_len]
+  *contexts = Array2<int32_t>(c_, num_contexts, decoder_history_len);
+  const int32_t *shape_row_ids1_data = shape->RowIds(1).Data(),
+                *states_row_splits2_data = states_.RowSplits(2).Data();
+  auto contexts_acc = contexts->Accessor();
+
+  const int32_t *num_graph_states_data = num_graph_states_.Data();
+  const int64_t *states_values_data = states_.values.Data();
+
+  K2_EVAL2(
+      c_, num_contexts, decoder_history_len, lambda_set_contexts,
+      (int32_t row, int32_t col) {
+        int32_t idx0 = shape_row_ids1_data[row],
+                num_graph_states = num_graph_states_data[idx0],
+                state_idx01x = states_row_splits2_data[row];
+        // state_value = context_state * num_graph_states + graph_state
+        // We want to extract token ids from context_state below.
+        // Suppose vocab_size=10 & decoder_history_len=3, and we have a
+        // context_state of "284" that we want to extract into [2, 8, 4].
+        // For col=0 (to get 2), it performs like `(284 % 10^3) / 10^2`,
+        // for col=1 (to get 8), it performs like `(284 % 10^2) / 10^1`,
+        // for col=2 (to get 4), it performs like `(284 % 10^1) / 10^0`.
+        int64_t state_value = states_values_data[state_idx01x],
+                context_state = state_value / num_graph_states,
+                exp = decoder_history_len - col,
+                state = context_state % (int64_t)pow(vocab_size, exp);
+        state = state / (int64_t)pow(vocab_size, exp - 1);
+        contexts_acc(row, col) = state;
+      });
+}
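
The digit extraction in `lambda_set_contexts` is easier to see host-side; a minimal standalone paraphrase (the function name `DecodeContext` is made up for this sketch):

  #include <cmath>
  #include <cstdint>
  #include <vector>

  // Unpack a packed context_state into its token ids, oldest symbol first.
  std::vector<int64_t> DecodeContext(int64_t context_state, int32_t vocab_size,
                                     int32_t decoder_history_len) {
    std::vector<int64_t> tokens(decoder_history_len);
    for (int32_t col = 0; col < decoder_history_len; ++col) {
      int64_t exp = decoder_history_len - col;
      int64_t rem = context_state % (int64_t)std::pow(vocab_size, exp);
      tokens[col] = rem / (int64_t)std::pow(vocab_size, exp - 1);
    }
    return tokens;  // DecodeContext(284, 10, 3) == {2, 8, 4}
  }
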
+Ragged<double> RnntDecodingStreams::PruneTwice(Ragged<double> &incoming_scores,
+                                               Array1<int32_t> *arcs_new2old) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK_EQ(incoming_scores.NumAxes(), 4);
+  K2_CHECK_EQ(incoming_scores.Dim0(), num_streams_);
+  // TODO: could introduce a max-arcs per stream to prune on... this would
+  // be done at this point, before pruning on states... thus avoiding any
+  // problems created by empty lists, although it's perhaps not an optimal way
+  // to prune.
+
+  // incoming_scores has a shape of [stream][context][state][arc]
+  // states_prune is a renumbering on the states axis.
+  Renumbering states_prune = PruneRagged(incoming_scores, 2 /*axis*/,
+                                         config_.beam, config_.max_states);
+
+  Array1<int32_t> arcs_new2old1;
+  Ragged<double> temp_scores =
+      SubsetRagged(incoming_scores, states_prune, 2 /*axis*/, &arcs_new2old1);
+
+  // incoming_scores has a shape of [stream][context][state][arc]
+  // context_prune is a renumbering on the context axis.
+  Renumbering context_prune =
+      PruneRagged(temp_scores, 1 /*axis*/, config_.beam, config_.max_contexts);
+  Array1<int32_t> arcs_new2old2;
+  Ragged<double> ans_scores =
+      SubsetRagged(temp_scores, context_prune, 1 /*axis*/, &arcs_new2old2);
+
+  if (arcs_new2old) *arcs_new2old = arcs_new2old1[arcs_new2old2];
+
+  return ans_scores;
+}
+
+RaggedShape RnntDecodingStreams::ExpandArcs() {
+  NVTX_RANGE(K2_FUNC);
+  int32_t num_states = states_.NumElements();
+
+  Array1<int32_t> num_arcs(c_, num_states + 1);
+  // populate array of num-arcs, indexed by idx012 into `states`.
+  // These num-arcs are the num-arcs leaving the state in the corresponding
+  // graph, plus one for the implicit epsilon self-loop.
+  const int32_t *states_row_ids2_data = states_.RowIds(2).Data(),
+                *states_row_ids1_data = states_.RowIds(1).Data(),
+                *num_graph_states_data = num_graph_states_.Data();
+
+  const int64_t *states_values_data = states_.values.Data();
+  const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1);
+  int32_t *num_arcs_data = num_arcs.Data();
+
+  K2_EVAL(
+      c_, num_states, lambda_set_num_arcs, (int32_t idx012) {
+        int64_t state_value = states_values_data[idx012];
+        int32_t idx0 = states_row_ids1_data[states_row_ids2_data[idx012]],
+                num_graph_states = num_graph_states_data[idx0],
+                graph_state = state_value % num_graph_states;
+
+        const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
+        // plus one for the implicit epsilon self-loop
+        num_arcs_data[idx012] = graph_row_split1_data[graph_state + 1] -
+                                graph_row_split1_data[graph_state] + 1;
+      });
+
+  // Compute exclusive sum of num-arcs above.
+  ExclusiveSum(num_arcs, &num_arcs);
+  RaggedShape states2arcs_shape = RaggedShape2(&num_arcs, nullptr, -1);
+
+  // unpruned_arcs_shape has 4 axes: [stream][context][state][arc]
+  RaggedShape unpruned_arcs_shape =
+      ComposeRaggedShapes(states_.shape, states2arcs_shape);
+  return unpruned_arcs_shape;
+}
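
The closing lines of `ExpandArcs()` use a common k2 idiom worth seeing in isolation: turn per-item counts into row_splits in place, then into a shape. A minimal sketch with made-up counts:

  // 3 rows with sizes [2, 0, 3]; one extra slot so ExclusiveSum can write
  // the total into the last element.
  Array1<int32_t> counts(c, std::vector<int32_t>{2, 0, 3, 0});
  ExclusiveSum(counts, &counts);  // counts is now row_splits [0, 2, 2, 5]
  RaggedShape shape = RaggedShape2(&counts, nullptr, -1 /*cached_tot_size*/);
  // shape is [ [ x x ] [ ] [ x x x ] ]
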
+Renumbering RnntDecodingStreams::DoFisrtPassPruning(
+    RaggedShape &unpruned_arcs_shape, Array2<float> &logprobs) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4);
+
+  // Do an initial pruning pass on the arcs (because it will be quite a large
+  // array), populating the `keep` array of a Renumbering object. The pruning
+  // rule is:
+  //   (1) keep all epsilon transitions to the next frame, to ensure there is
+  //       no way we can have no states surviving.
+  //   (2) for all other arcs, keep it if the forward score after the
+  //       arc would be >= the max_scores_per_stream entry for this stream,
+  //       minus the beam from the config.
+  Array1<double> max_scores_per_stream(c_, num_streams_);
+  double minus_inf = -std::numeric_limits<double>::infinity();
+  {
+    // scores_ has 3 axes: [stream][context][score]
+    Ragged<double> scores_per_stream = scores_.RemoveAxis(1);
+    MaxPerSublist(scores_per_stream, minus_inf, &max_scores_per_stream);
+  }
+  Renumbering pass1_renumbering(c_, unpruned_arcs_shape.NumElements());
+  char *pass1_keep_data = pass1_renumbering.Keep().Data();
+  const auto logprobs_acc = logprobs.Accessor();
+  const double *scores_data = scores_.values.Data(),
+               *max_scores_per_stream_data = max_scores_per_stream.Data();
+  double beam = config_.beam;
+  // "uas" is short for unpruned_arcs_shape
+  const int32_t *uas_row_ids3_data = unpruned_arcs_shape.RowIds(3).Data(),
+                *uas_row_splits3_data = unpruned_arcs_shape.RowSplits(3).Data(),
+                *uas_row_ids2_data = unpruned_arcs_shape.RowIds(2).Data(),
+                *uas_row_ids1_data = unpruned_arcs_shape.RowIds(1).Data(),
+                *num_graph_states_data = num_graph_states_.Data();
+  const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1);
+  const int64_t *states_values_data = states_.values.Data();
+
+  Arc **graphs_arcs_data = graphs_.values.Data();
+
+  K2_EVAL(
+      c_, unpruned_arcs_shape.NumElements(), lambda_pass1_pruning,
+      (int32_t idx0123) {
+        int32_t idx012 = uas_row_ids3_data[idx0123],
+                idx012x = uas_row_splits3_data[idx012],
+                idx3 = idx0123 - idx012x;
+        // keep the implicit epsilon self-loop
+        if (idx3 == 0) {
+          pass1_keep_data[idx0123] = 1;
+          return;
+        }
+
+        int32_t idx01 = uas_row_ids2_data[idx012],
+                idx0 = uas_row_ids1_data[idx01],
+                num_graph_states = num_graph_states_data[idx0];
+
+        const Arc *graph_arcs_data = graphs_arcs_data[idx0];
+        const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
+        int64_t state = states_values_data[idx012];
+        int32_t graph_state = state % num_graph_states,
+                graph_idx0x = graph_row_split1_data[graph_state],
+                graph_idx01 =
+                    graph_idx0x + idx3 - 1;  // minus 1 as the implicit epsilon
+                                             // self-loop takes the position 0.
+        Arc arc = graph_arcs_data[graph_idx01];
+
+        // keep the epsilon transitions
+        if (arc.label == 0) {
+          pass1_keep_data[idx0123] = 1;
+          return;
+        }
+
+        // prune the arcs pointing to the final state.
+        if (arc.label == -1) {
+          pass1_keep_data[idx0123] = 0;
+          return;
+        }
+
+        double this_score = scores_data[idx012], arc_score = arc.score,
+               log_prob = logprobs_acc(idx01, arc.label),
+               score = this_score + arc_score + log_prob,
+               max_score = max_scores_per_stream_data[idx0];
+        // prune with beam
+        if (score >= max_score - beam) {
+          pass1_keep_data[idx0123] = 1;
+        } else {
+          pass1_keep_data[idx0123] = 0;
+        }
+      });
+  return pass1_renumbering;
+}
+
+RaggedShape RnntDecodingStreams::GroupStatesByContexts(
+    Ragged<int64_t> &states) {
+  NVTX_RANGE(K2_FUNC);
+  // states has a shape of [stream][state]
+  K2_CHECK_EQ(states.NumAxes(), 2);
+  // state_boundaries and context_boundaries are Renumbering objects
+  // that we use in a slightly different way from normal.
+  // We populate their Keep() arrays with:
+  //  for context_boundaries: a 1 if next_stream != this_stream,
+  //     or next_context != this_context.
+  //  for state_boundaries: a 1 if next_stream != this_stream or
+  //     next_state (i.e. the 64-bit state index) != this_state.
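
A worked example of these keep arrays (the concrete values are illustrative): suppose a single stream with num_graph_states = 10 and sorted state values [30, 30, 31, 45], so the context states are [3, 3, 3, 4]. The lambda below then produces:

  // state_boundaries.Keep()   == [0, 1, 1, 1]   (30==30; 30!=31; 31!=45; last)
  // context_boundaries.Keep() == [0, 0, 1, 1]   (3==3;  3==3;  3!=4;   last)
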
+  int32_t cur_num_arcs = states.NumElements();
+  Renumbering context_boundaries(c_, cur_num_arcs);
+  Renumbering state_boundaries(c_, cur_num_arcs);
+  const int32_t *states_row_ids1_data = states.RowIds(1).Data(),
+                *num_graph_states_data = num_graph_states_.Data();
+  char *context_boundaries_keep_data = context_boundaries.Keep().Data(),
+       *state_boundaries_keep_data = state_boundaries.Keep().Data();
+  const int64_t *states_data = states.values.Data();
+
+  K2_EVAL(
+      c_, cur_num_arcs, lambda_set_boundaries, (int32_t idx01) {
+        int32_t context_keep = 0, state_keep = 0;
+
+        if (idx01 != cur_num_arcs - 1) {
+          int32_t idx0 = states_row_ids1_data[idx01],
+                  idx0_next = states_row_ids1_data[idx01 + 1],
+                  num_graph_states = num_graph_states_data[idx0],
+                  next_num_graph_states = num_graph_states_data[idx0_next];
+          int64_t state_value = states_data[idx01],
+                  next_state_value = states_data[idx01 + 1],
+                  context_state = state_value / num_graph_states,
+                  next_context_state = next_state_value / next_num_graph_states;
+          if (idx0 != idx0_next || context_state != next_context_state)
+            context_keep = 1;
+          if (idx0 != idx0_next || state_value != next_state_value)
+            state_keep = 1;
+        } else {
+          context_keep = 1;
+          state_keep = 1;
+        }
+
+        context_boundaries_keep_data[idx01] = context_keep;
+        state_boundaries_keep_data[idx01] = state_keep;
+      });
+
+  Array1<int32_t> arc2state_row_ids_extra = state_boundaries.Old2New(true),
+                  arc2state_row_ids = state_boundaries.Old2New(),
+                  state_boundaries_new2old = state_boundaries.New2Old();
+
+  RaggedShape state_arc_shape =
+      RaggedShape2(nullptr, &arc2state_row_ids, arc2state_row_ids.Dim());
+
+  Array1<int32_t> arc2ctx_row_ids_extra = context_boundaries.Old2New(true),
+                  arc2ctx_row_ids = context_boundaries.Old2New(),
+                  context_boundaries_new2old = context_boundaries.New2Old();
+
+  Array1<int32_t> state2ctx_row_ids = arc2ctx_row_ids[state_boundaries_new2old];
+
+  RaggedShape ctx_state_shape =
+      RaggedShape2(nullptr, &state2ctx_row_ids, state2ctx_row_ids.Dim());
+
+  RaggedShape &stream_arc_shape = states.shape;
+  Array1<int32_t> &arc2stream_row_ids = stream_arc_shape.RowIds(1),
+                  &stream2arc_row_splits = stream_arc_shape.RowSplits(1);
+
+  Array1<int32_t> ctx2stream_row_ids =
+                      arc2stream_row_ids[context_boundaries_new2old],
+                  stream2ctx_row_splits =
+                      arc2ctx_row_ids_extra[stream2arc_row_splits];
+
+  RaggedShape stream_ctx_shape = RaggedShape2(
+      &stream2ctx_row_splits, &ctx2stream_row_ids, ctx2stream_row_ids.Dim());
+
+  // grouped_arcs_shape has indexes [stream][context][state][arc].
+  // It represents the incoming arcs sorted by destination state.
+  RaggedShape grouped_arcs_shape =
+      ComposeRaggedShapes3(stream_ctx_shape, ctx_state_shape, state_arc_shape);
+  return grouped_arcs_shape;
+}
+
+/*
+  There are several steps to finish this `Advance()` procedure.
+    (1) Expand arcs based on source states (i.e. the states_ member).
+    (2) Do initial pruning (beam pruning with some special rules) to reduce
+        the number of arcs.
+    (3) Figure out the dest-states and corresponding scores.
+    (4) Re-arrange dest-states by contexts and states.
+    (5) Second pass pruning (prune on context axis and state axis).
+    (6) Update states_, scores_ and prev_frames_.
+ */
+void RnntDecodingStreams::Advance(Array2<float> &logprobs) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(attached_) << "Streams terminated.";
+  K2_CHECK_EQ(logprobs.Dim0(), states_.TotSize(1));
+  K2_CHECK_EQ(logprobs.Dim1(), config_.vocab_size);
+
+  ContextPtr c = logprobs.Context();
+  K2_CHECK(c_->IsCompatible(*c));
+
+  // (1) Expand arcs.
+  // unpruned_arcs_shape has a shape of [stream][context][state][arc]
+  auto unpruned_arcs_shape = ExpandArcs();
+
+  // (2) Do initial pruning.
+  auto pass1_renumbering = DoFisrtPassPruning(unpruned_arcs_shape, logprobs);
+
+  // pass1_arcs_shape has a shape of [stream][context][state][arc]
+  auto pass1_arcs_shape =
+      SubsetRaggedShape(unpruned_arcs_shape, pass1_renumbering);
+
+  // (3) Figure out the dest-states and corresponding scores.
+  // stream_arc_shape is pass1_arcs indexed [stream][arc].
+  // We need to rearrange so it's by destination context and state, not source.
+  RaggedShape stream_arc_shape = RemoveAxis(pass1_arcs_shape, 2);
+  stream_arc_shape = RemoveAxis(stream_arc_shape, 1);
+
+  // arcs, indexed [stream][context][state][arc].
+  Ragged<ArcInfo> arcs(pass1_arcs_shape);
+  // dest-states of arcs, indexed [stream][arc]
+  Ragged<int64_t> states(stream_arc_shape);
+  // final-scores after arcs, indexed [stream][arc]
+  Ragged<double> scores(stream_arc_shape);
+
+  // We will populate arcs, states and scores below; this computes
+  // the destination state for each arc and puts it in 'states',
+  // and the after-the-arc scores for each arc and puts them in
+  // 'scores'.
+  int32_t cur_num_arcs = arcs.NumElements();
+  // This renumbering object will be used for renumbering the arcs after we
+  // finish the pruning.
+  Renumbering renumber_arcs(c_, cur_num_arcs);
+  char *renumber_arcs_keep_data = renumber_arcs.Keep().Data();
+
+  const int64_t *this_states_values_data = states_.values.Data();
+  int64_t *states_data = states.values.Data();
+  const double *this_scores_data = scores_.values.Data();
+  double *scores_data = scores.values.Data();
+  ArcInfo *arcs_data = arcs.values.Data();
+  int32_t vocab_size = config_.vocab_size,
+          decoder_history_len = config_.decoder_history_len;
+  // "uas" is short for unpruned_arcs_shape, see above, it is the output of
+  // `ExpandArcs()`.
+  const int32_t *num_graph_states_data = num_graph_states_.Data(),
+                *uas_row_ids3_data = unpruned_arcs_shape.RowIds(3).Data(),
+                *uas_row_splits3_data = unpruned_arcs_shape.RowSplits(3).Data(),
+                *uas_row_ids2_data = unpruned_arcs_shape.RowIds(2).Data(),
+                *uas_row_ids1_data = unpruned_arcs_shape.RowIds(1).Data(),
+                *pass1_new2old_data = pass1_renumbering.New2Old().Data();
+  const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1);
+  const auto logprobs_acc = logprobs.Accessor();
+  Arc **graphs_arcs_data = graphs_.values.Data();
+
+  K2_EVAL(
+      c_, cur_num_arcs, lambda_populate_arcs_states_scores, (int32_t arc_idx) {
+        // Init renumber_arcs to 0, placed here to save one kernel.
+        renumber_arcs_keep_data[arc_idx] = 0;
+        // The idx below is the index into unpruned_arcs_shape, which has a
+        // shape of [stream][context][state][arc]
+        // Note: states_.shape == unpruned_arcs_shape.RemoveAxis(-1).
+        int32_t idx0123 = pass1_new2old_data[arc_idx],
+                idx012 = uas_row_ids3_data[idx0123],
+                idx012x = uas_row_splits3_data[idx012],
+                idx3 = idx0123 - idx012x,  // `idx3 - 1` can be interpreted as
+                                           // idx1 into the corresponding
+                                           // decoding graph; minus 1 here
+                                           // because we add an implicit
+                                           // self-loop for each state, see
+                                           // `ExpandArcs()`.
+                idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01],
+                num_graph_states = num_graph_states_data[idx0];
+        int64_t this_state = this_states_values_data[idx012];
+        double this_score = this_scores_data[idx012];
+
+        // handle the implicit epsilon self-loop
+        if (idx3 == 0) {
+          states_data[arc_idx] = this_state;
+          // we assume the termination symbol to be 0 here.
+          scores_data[arc_idx] = this_score + logprobs_acc(idx01, 0);
+          ArcInfo ai;
+          ai.graph_arc_idx01 = -1;
+          ai.score = logprobs_acc(idx01, 0);
+          arcs_data[arc_idx] = ai;
+          return;
+        }
+
+        const Arc *graph_arcs_data = graphs_arcs_data[idx0];
+        const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
+
+        int64_t this_context_state = this_state / num_graph_states;
+        int32_t this_graph_state = this_state % num_graph_states,
+                graph_idx0x = graph_row_split1_data[this_graph_state],
+                graph_idx01 = graph_idx0x + idx3 - 1;  // minus 1 here as the
+                                                       // epsilon self-loop
+                                                       // takes the position 0.
+        Arc arc = graph_arcs_data[graph_idx01];
+        int64_t context_state = this_context_state;
+
+        // non-epsilon transitions, update context_state
+        if (arc.label != 0) {
+          // Suppose vocab_size=10, decoder_history_len=3,
+          // this_context_state=358 and arc.label=6; we need to update
+          // context_state to 586. First we extract 58 from 358 (that can be
+          // done with `358 % 10^2`), then we append 6 to 58 (that can be
+          // done with `58 * 10 + 6`).
+          context_state = this_context_state %
+                          (int64_t)pow(vocab_size, decoder_history_len - 1);
+          context_state = context_state * vocab_size + arc.label;
+        }
+
+        // the next state is the state the current arc points to.
+        int64_t state = context_state * num_graph_states + arc.dest_state;
+        states_data[arc_idx] = state;
+
+        double arc_score = arc.score, log_prob = logprobs_acc(idx01, arc.label);
+
+        scores_data[arc_idx] = this_score + arc_score + log_prob;
+
+        ArcInfo ai;
+        ai.graph_arc_idx01 = graph_idx01;
+        ai.score = arc_score + log_prob;
+        arcs_data[arc_idx] = ai;
+      });
+
+  // (4) Re-arrange dest-states by contexts and states.
+  // sort states so that we can group states by context-state
+  Array1<int32_t> dest_state_sort_new2old(c, states.NumElements());
+  SortSublists(&states, &dest_state_sort_new2old);
+
+  auto incoming_arcs_shape = GroupStatesByContexts(states);
+
+  scores.values = scores.values[dest_state_sort_new2old];
+  Ragged<double> incoming_scores(incoming_arcs_shape, scores.values);
+
+  // (5) Second pass pruning (prune on context axis and state axis).
+  // The scores have been re-arranged by destination context and state.
+  Array1<int32_t> arcs_prune2_new2old;
+  Ragged<double> pruned_incoming_scores =
+      PruneTwice(incoming_scores, &arcs_prune2_new2old);
+
+  Ragged<int64_t> pruned_dest_states(pruned_incoming_scores.shape,
+                                     states.values[arcs_prune2_new2old]);
+
+  // (6) Update states_, scores_ and prev_frames_.
+  // Here, use MaxPerSublist to reduce `pruned_incoming_scores` to be per
+  // state, not per arc. (Need to remove the last axis from the shape.)
+  int32_t num_dest_states = pruned_incoming_scores.TotSize(2);
+  Array1<double> dest_state_scores_values(c_, num_dest_states);
+  double minus_inf = -std::numeric_limits<double>::infinity();
+  MaxPerSublist(pruned_incoming_scores, minus_inf, &dest_state_scores_values);
+
+  // dest_state_scores will be the 'scores' held by this object on the next
+  // frame
+  Ragged<double> dest_state_scores(RemoveAxis(pruned_incoming_scores.shape, 3),
+                                   dest_state_scores_values);
+  scores_ = dest_state_scores;
+
+  // dest_states will be the `states` held by this object on the next frame.
+  // sub-lists along the last axis have the same values, so we just pick the
+  // first one; see `GroupStatesByContexts()` for more details.
+  auto pruned_row_split3 = pruned_dest_states.RowSplits(3);
+  Ragged<int64_t> dest_states(
+      dest_state_scores.shape,
+      pruned_dest_states
+          .values[pruned_row_split3.Arange(0, pruned_row_split3.Dim() - 1)]);
+  states_ = dest_states;
+
+  // Update prev_frames_.
+  // arcs_new2old is a new2old map from indexes in `incoming_scores` or
+  // `pruned_dest_states` to indexes into `arcs` (remember, we did not renumber
+  // arcs; they are in the original order after pass-1 pruning).
+  Array1<int32_t> arcs_new2old = dest_state_sort_new2old[arcs_prune2_new2old];
+
+  // Renumber the original arcs; we created and initialized the renumbering
+  // object when we created the arcs, see above.
+  // arcs has a shape of [stream][context][state][arc]
+  int32_t *arcs_new2old_data = arcs_new2old.Data();
+  K2_EVAL(
+      c_, arcs_new2old.Dim(), lambda_renumber_arcs, (int32_t idx) {
+        int32_t arc_idx0123 = arcs_new2old_data[idx];
+        renumber_arcs_keep_data[arc_idx0123] = 1;
+      });
+
+  // pruned_arcs is indexed [stream][context][src_state][arc].
+  Ragged<ArcInfo> pruned_arcs = SubsetRagged(arcs, renumber_arcs);
+
+  // arcs_dest2src maps from an arc-index in `pruned_dest_states` to an
+  // arc-index in `pruned_arcs`. This is a permutation of integers
+  // 0..num_pruned_arcs-1.
+  Array1<int32_t> arcs_dest2src = renumber_arcs.Old2New()[arcs_new2old];
+
+  // reduce_pruned_dest_states has a shape of [stream][state][arc];
+  // we don't need the context axis in prev_frames_.
+  auto reduce_pruned_dest_states = RemoveAxis(pruned_dest_states, 1);
+  // "rpds" is short for reduce_pruned_dest_states
+  const int32_t *rpds_row_ids2_data =
+                    reduce_pruned_dest_states.RowIds(2).Data(),
+                *rpds_row_ids1_data =
+                    reduce_pruned_dest_states.RowIds(1).Data(),
+                *rpds_row_splits1_data =
+                    reduce_pruned_dest_states.RowSplits(1).Data(),
+                *arcs_dest2src_data = arcs_dest2src.Data();
+  ArcInfo *pruned_arcs_data = pruned_arcs.values.Data();
+
+  // Set the dest_state of the arcs in pruned_arcs.
+  // It works as follows:
+  //  For each arc_idx012 in `reduce_pruned_dest_states`:
+  //    work out the state_idx1, which will be the `dest_state`
+  //    for the corresponding ArcInfo.
+  //    Work out the arc-index (arc_idx0123) in `pruned_arcs`, which
+  //    is just arcs_dest2src[arc_idx012], and then set the dest_state
+  //    in `pruned_arcs`.
+  K2_EVAL(
+      c_, reduce_pruned_dest_states.NumElements(), lambda_set_dest_states,
+      (int32_t idx012) {
+        int32_t idx01 = rpds_row_ids2_data[idx012],
+                idx0 = rpds_row_ids1_data[idx01],
+                idx0x = rpds_row_splits1_data[idx0], idx1 = idx01 - idx0x,
+                pruned_arc_idx0123 = arcs_dest2src_data[idx012];
+
+        ArcInfo info = pruned_arcs_data[pruned_arc_idx0123];
+        info.dest_state = idx1;
+        pruned_arcs_data[pruned_arc_idx0123] = info;
+      });
+
+  prev_frames_.emplace_back(
+      std::make_shared<Ragged<ArcInfo>>(pruned_arcs.RemoveAxis(1)));
+}
+
+void RnntDecodingStreams::GatherPrevFrames(std::vector<int32_t> &num_frames) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(!attached_) << "Please call TerminateAndFlushToStreams() first.";
+  K2_CHECK_EQ(num_streams_, static_cast<int32_t>(num_frames.size()));
+  std::vector<Ragged<ArcInfo> *> frames_ptr;
+  Array1<int32_t> stream2t_row_splits(GetCpuContext(), num_frames.size() + 1);
+
+  for (size_t i = 0; i < num_frames.size(); ++i) {
+    stream2t_row_splits.Data()[i] = num_frames[i];
+    K2_CHECK_LE(num_frames[i],
+                static_cast<int32_t>(srcs_[i]->prev_frames.size()));
+    for (int32_t j = 0; j < num_frames[i]; ++j) {
+      frames_ptr.push_back(srcs_[i]->prev_frames[j].get());
+    }
+  }
+
+  // frames has a shape of [t][state][arc];
+  // its Dim0() equals the sum of num_frames
+  auto frames = Stack(0, frames_ptr.size(), frames_ptr.data());
+
+  stream2t_row_splits = stream2t_row_splits.To(c_);
+  ExclusiveSum(stream2t_row_splits, &stream2t_row_splits);
+  auto stream2t_shape = RaggedShape2(&stream2t_row_splits, nullptr, -1);
+
+  // now frames has a shape of [stream][t][state][arc]
+  frames = Ragged<ArcInfo>(ComposeRaggedShapes(stream2t_shape, frames.shape),
+                           frames.values);
+
+  std::vector<Ragged<ArcInfo>> prev_frames;
+  Unstack(frames, 1, false /*pad_right*/, &prev_frames);
+
+  prev_frames_.resize(prev_frames.size());
+  for (size_t i = 0; i < prev_frames.size(); ++i) {
+    prev_frames_[i] = std::make_shared<Ragged<ArcInfo>>(prev_frames[i]);
+  }
+}
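
Taken together, the intended calling sequence looks roughly like the sketch below; `RunJoiner` is a placeholder for the user's joiner network and is not a k2 function:

  void DecodeSketch(RnntDecodingStreams &streams, int32_t num_streams,
                    int32_t num_frames_to_decode) {
    for (int32_t t = 0; t < num_frames_to_decode; ++t) {
      RaggedShape context_shape;
      Array2<int32_t> contexts;
      streams.GetContexts(&context_shape, &contexts);
      // Evaluate the joiner network for each context; the result has shape
      // [tot_contexts][vocab_size].
      Array2<float> logprobs = RunJoiner(contexts);  // hypothetical helper
      streams.Advance(logprobs);
    }
    streams.TerminateAndFlushToStreams();
    std::vector<int32_t> num_frames(num_streams, num_frames_to_decode);
    FsaVec lattice;
    Array1<int32_t> out_map;
    streams.FormatOutput(num_frames, &lattice, &out_map);
  }
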
+void RnntDecodingStreams::FormatOutput(std::vector<int32_t> &num_frames,
+                                       FsaVec *ofsa, Array1<int32_t> *out_map) {
+  NVTX_RANGE(K2_FUNC);
+  K2_CHECK(!attached_)
+      << "You can only get outputs after calling TerminateAndFlushToStreams()";
+  K2_CHECK(ofsa);
+  K2_CHECK(out_map);
+  K2_CHECK_EQ(static_cast<int32_t>(num_frames.size()), num_streams_);
+
+  GatherPrevFrames(num_frames);
+
+  int32_t frames = prev_frames_.size();
+
+  auto last_frame_shape = prev_frames_[frames - 1]->shape;
+
+  auto pre_final_arcs_shape = ComposeRaggedShapes(
+      RemoveAxis(last_frame_shape, 1),
+      RegularRaggedShape(c_, last_frame_shape.NumElements(), 1));
+
+  auto stream_state_shape = RegularRaggedShape(c_, num_streams_, 1);
+  auto state_arc_shape =
+      RegularRaggedShape(c_, stream_state_shape.NumElements(), 0);
+  auto final_arcs_shape =
+      ComposeRaggedShapes(stream_state_shape, state_arc_shape);
+
+  RaggedShape oshape;
+  // see documentation of Stack() in ragged_ops.h for explanation.
+  Array1<uint32_t> oshape_merge_map;
+
+  Array1<ArcInfo *> arcs_data_ptrs(GetCpuContext(), frames);
+  ArcInfo **arcs_data_ptrs_data = arcs_data_ptrs.Data();
+
+  {
+    // each of these has 3 axes.
+    std::vector<RaggedShape *> arcs_shapes(frames + 2);
+    for (int32_t t = 0; t < frames; t++) {
+      arcs_shapes[t] = &(prev_frames_[t]->shape);
+      arcs_data_ptrs_data[t] = prev_frames_[t]->values.Data();
+    }
+
+    arcs_shapes[frames] = &pre_final_arcs_shape;
+    arcs_shapes[frames + 1] = &final_arcs_shape;
+
+    // oshape is a 4-axis ragged tensor which is indexed:
+    //   oshape[stream][t][state_idx][arc_idx]
+    int32_t axis = 1;
+    oshape = Stack(axis, frames + 2, arcs_shapes.data(), &oshape_merge_map);
+  }
+
+  int32_t num_arcs = oshape.NumElements();
+
+  // transfer to GPU if we're using a GPU
+  arcs_data_ptrs = arcs_data_ptrs.To(c_);
+  arcs_data_ptrs_data = arcs_data_ptrs.Data();
+  uint32_t *oshape_merge_map_data = oshape_merge_map.Data();
+
+  *out_map = Array1<int32_t>(c_, num_arcs);
+  int32_t *out_map_data = out_map->Data();
+
+  int32_t *oshape_row_ids3 = oshape.RowIds(3).Data(),
+          *oshape_row_ids2 = oshape.RowIds(2).Data(),
+          *oshape_row_ids1 = oshape.RowIds(1).Data(),
+          *oshape_row_splits2 = oshape.RowSplits(2).Data(),
+          *oshape_row_splits1 = oshape.RowSplits(1).Data();
+
+  Array1<Arc> arcs_out(c_, num_arcs);
+  Arc *arcs_out_data = arcs_out.Data();
+  Arc **graphs_arcs_data = graphs_.values.Data();
+
+  K2_EVAL(
+      c_, num_arcs, lambda_set_arcs, (int32_t oarc_idx0123) {
+        int32_t oarc_idx012 = oshape_row_ids3[oarc_idx0123],
+                oarc_idx01 = oshape_row_ids2[oarc_idx012],
+                oarc_idx0 = oshape_row_ids1[oarc_idx01],
+                oarc_idx0x = oshape_row_splits1[oarc_idx0],
+                oarc_idx0xx = oshape_row_splits2[oarc_idx0x],
+                oarc_idx1 = oarc_idx01 - oarc_idx0x,
+                oarc_idx01x_next = oshape_row_splits2[oarc_idx01 + 1];
+
+        int32_t m = oshape_merge_map_data[oarc_idx0123],
+                // actually we won't get t == frames + 1
+                // here since those frames have no arcs.
+                t = m % (frames + 2),
+                // arc_idx012 into prev_frames_ arcs on time t, index of the
+                // arc on that frame.
+                arcs_idx012 = m / (frames + 2);
+
+        K2_CHECK_EQ(t, oarc_idx1);
+
+        ArcInfo arc_info;
+        Arc arc;
+
+        // all arcs in t == frames point to the final state
+        if (t == frames) {
+          arc.src_state = oarc_idx012 - oarc_idx0xx;
+          arc.dest_state = oarc_idx01x_next - oarc_idx0xx;
+          arc.label = -1;
+          arc.score = 0;
+          arc_info.graph_arc_idx01 = -1;
+        } else {
+          const Arc *graph_arcs_data = graphs_arcs_data[oarc_idx0];
+          const ArcInfo *arcs_data = arcs_data_ptrs_data[t];
+
+          arc_info = arcs_data[arcs_idx012];
+          arc.src_state = oarc_idx012 - oarc_idx0xx;
+          // Note: the idx1 w.r.t. the frame's `arcs` is an idx2 w.r.t.
+          // `oshape`.
+          int32_t dest_state_idx012 = oarc_idx01x_next + arc_info.dest_state;
+          arc.dest_state = dest_state_idx012 - oarc_idx0xx;
+
+          // graph_arc_idx01 == -1 means this is an implicit epsilon self-loop
+          if (arc_info.graph_arc_idx01 == -1) {
+            arc.label = 0;
+          } else {
+            arc.label = graph_arcs_data[arc_info.graph_arc_idx01].label;
+          }
+          arc.score = arc_info.score;
+        }
+        out_map_data[oarc_idx0123] = arc_info.graph_arc_idx01;
+        arcs_out_data[oarc_idx0123] = arc;
+      });
+
+  // Remove axis 1, which corresponds to time.
+  *ofsa = FsaVec(RemoveAxis(oshape, 1), arcs_out);
+}
+
+}  // namespace rnnt_decoding
+}  // namespace k2
diff --git a/k2/csrc/rnnt_decode.h b/k2/csrc/rnnt_decode.h
new file mode 100644
index 000000000..10f2f1bd6
--- /dev/null
+++ b/k2/csrc/rnnt_decode.h
@@ -0,0 +1,390 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_CSRC_RNNT_DECODE_H_
+#define K2_CSRC_RNNT_DECODE_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/array.h"
+#include "k2/csrc/array_of_ragged.h"
+#include "k2/csrc/array_ops.h"
+#include "k2/csrc/log.h"
+#include "k2/csrc/macros.h"
+
+namespace k2 {
+namespace rnnt_decoding {
+
+/*
+  The RNN-T decoding implemented here is for what we call "modified" RNN-T, or
+  equivalently, regular RNN-T but with max_sym_per_frame set to 1.  (We can
+  train with the "modified" option set with some probability, in order to
+  ensure that the trained model is compatible with decoding with this
+  setting).
+
+  - contexts are finite symbol left-contexts, of length
+    RnntDecodingConfig::decoder_history_len.  Conceptually they represent a
+    list of `decoder_history_len` symbols; they are represented numerically
+    as, for example in the length-2-history case:
+    symbol_{t-1} + symbol_{t-2} * vocab_size.
+
+  - frames come from the transcription network, they are derived from
+    sub-sampling of acoustic frames or samples.
+ */
+
+struct RnntDecodingConfig {
+  RnntDecodingConfig(int32_t vocab_size, int32_t decoder_history_len,
+                     double beam, int32_t max_states, int32_t max_contexts)
+      : vocab_size(vocab_size),
+        decoder_history_len(decoder_history_len),
+        beam(beam),
+        max_states(max_states),
+        max_contexts(max_contexts) {
+    num_context_states = pow(vocab_size, decoder_history_len);
+  }
+
+  // vocab_size is the largest symbol plus one.
+  int32_t vocab_size;
+
+  // decoder_history_len is the number of symbols
+  // of history the decoder takes; will normally
+  // be one or two ("stateless decoder"); this
+  // RNN-T decoding setup does not support
+  // unlimited decoder context such as with LSTMs.
+  int32_t decoder_history_len;
+
+  // num_context_states == pow(vocab_size, decoder_history_len).
+  // We need a unique id for each context state; for example, if vocab_size
+  // is 10, we need ids 0 ~ 9 to distinguish the context states when
+  // decoder_history_len is 1, ids 0 ~ 99 (10 ^ 2 ids) when
+  // decoder_history_len is 2, and ids 0 ~ 999 (10 ^ 3 ids) when
+  // decoder_history_len is 3.
+  int32_t num_context_states;
+
+  // `beam` imposes a limit on the score of a state, relative to the
+  // best-scoring state on the same frame.  E.g. 10.
+  double beam;
+
+  // `max_states` is a limit on the number of distinct states that we allow per
+  // frame, per stream; the number of states will not be allowed to exceed
+  // this limit.
+  int32_t max_states;
+
+  // `max_contexts` is a limit on the number of distinct contexts that we allow
+  // per frame, per stream; the number of contexts will not be allowed to
+  // exceed this limit.
+  int32_t max_contexts;
+};
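To make the encoding above concrete, here is a small Python sketch of how contexts and decoder states pack into single integers. It is illustrative only; the helper names and the num_graph_states value are ours, not part of k2's API.

    # Sketch of the integer packing described above (illustrative only).
    vocab_size = 10
    decoder_history_len = 2
    num_graph_states = 100  # assumed size of some decoding graph

    num_context_states = vocab_size ** decoder_history_len  # 10^2 == 100

    def pack_context(symbols):
        # [symbol_{t-2}, symbol_{t-1}] -> symbol_{t-1} + symbol_{t-2} * vocab_size
        context_state = 0
        for s in symbols:
            context_state = context_state * vocab_size + s
        return context_state

    def pack_state(context_state, graph_state):
        return context_state * num_graph_states + graph_state

    context = pack_context([3, 7])  # == 3 * 10 + 7 == 37
    state = pack_state(context, 5)  # == 37 * 100 + 5 == 3705
    assert state // num_graph_states == context
    assert state % num_graph_states == 5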
+
+struct ArcInfo {
+  // The arc-index within the RnntDecodingStream::graph that corresponds to
+  // this arc, or -1 if this arc is a "termination symbol" (these do not
+  // appear in the graph).
+  int32_t graph_arc_idx01;
+
+  // The score on the arc; contains both the graph score (if any) and the
+  // score from the RNN-T joiner.
+  float score;
+
+  // dest_state is the state index within the array of states on the next
+  // frame; it would be an idx1 or idx2 depending on whether this is part of
+  // an RnntDecodingStream or RnntDecodingStreams object.
+  int32_t dest_state;
+};
+
+struct RnntDecodingStream {
+  // `graph` is a pointer to the FSA (decoding graph) that we are decoding this
+  // stream with.  Different streams might have different graphs.  This must
+  // be an Fsa, not FsaVec (i.e. 2 axes).
+  std::shared_ptr<Fsa> graph;
+
+  // The number of states in the graph; equals graph->shape.Dim0().
+  int32_t num_graph_states;
+
+  // `states` contains int64_t which represents the decoder state; this is:
+  //   state_idx = context_state * num_graph_states + graph_state.
+  // `states` would be indexed
+  // [context_state][state], i.e. the states are grouped first
+  // by context_state (they are sorted, to make this possible).
+  Ragged<int64_t> states;
+
+  // `scores` contains the forward scores of the states in `states`;
+  // it has the same shape as `states`.
+  Ragged<double> scores;
+
+  // frames contains the arc information, for previously decoded
+  // frames, that we can later use to create a lattice.
+  // It contains Ragged<ArcInfo> with 2 axes (state, arc).
+  std::vector<std::shared_ptr<Ragged<ArcInfo>>> prev_frames;
+};
+
+class RnntDecodingStreams {
+ public:
+  /* Constructor.  Combines multiple RnntDecodingStream objects to create a
+     RnntDecodingStreams object */
+  RnntDecodingStreams(std::vector<std::shared_ptr<RnntDecodingStream>> &srcs,
+                      const RnntDecodingConfig &config);
+
+  /* This function must be called prior to evaluating the joiner network
+     for a particular frame.  It tells the calling code which contexts
+     it must evaluate the joiner network for.
+
+       @param [out] shape  A RaggedShape with 2 axes, representing
+                    [stream][context], will be written to here.
+       @param [out] contexts  An array of shape
+                    [tot_contexts][decoder_history_len], will be output to
+                    here, where tot_contexts == shape->TotSize(1) and
+                    decoder_history_len comes from the config; it represents
+                    the number of symbols in the context of the decoder
+                    network (assumed to be finite).  It contains the token ids
+                    of the vocabulary (i.e. `0 <= value < vocab_size`).
+   */
+  void GetContexts(RaggedShape *shape, Array2<int32_t> *contexts);
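A context id can be unpacked back into its symbols by repeated division; the short sketch below (plain Python, assuming the packing shown earlier) mirrors what the rows of `contexts` contain:

    # Illustrative only: decode a packed context id back into token ids,
    # oldest symbol first, matching the rows of `contexts`.
    def unpack_context(c, vocab_size, history_len):
        symbols = []
        for _ in range(history_len):
            symbols.append(c % vocab_size)
            c //= vocab_size
        return symbols[::-1]

    assert unpack_context(37, vocab_size=10, history_len=2) == [3, 7]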
+
+  /*
+    Advance decoding streams by one frame.
+
+      @param [in] logprobs  Array of shape [tot_contexts][num_symbols],
+                   containing log-probs of symbols given the contexts output
+                   by `GetContexts()`.  Will satisfy
+                   logprobs.Dim0() == states.TotSize(1).
+   */
+  void Advance(Array2<float> &logprobs);
+
+  /*
+    Generate the lattice.
+
+    Note: prev_frames_ only contains the frames decoded by the current
+    object; in order to generate the lattice we will first gather all the
+    previous frames from the individual streams.
+
+      @param [in] num_frames  A vector containing the number of frames we want
+                   to gather for each stream (note: the frames we have
+                   ever received).
+                   It MUST satisfy `num_frames.size() == num_streams_`, and
+                   `num_frames[i] < srcs_[i].prev_frames.size()`.
+      @param [out] ofsa  The output lattice will be written to here; its
+                   num_axes equals 3; it will be re-allocated.
+      @param [out] out_map  It is an Array1<int32_t> with Dim() equal to
+                   ofsa.NumElements(), containing the idx01 into the graph of
+                   each individual stream, mapping each arc in ofsa to the
+                   original decoding graphs.  It may contain -1, which means
+                   the arc is a "termination symbol".
+   */
+  void FormatOutput(std::vector<int32_t> &num_frames, FsaVec *ofsa,
+                    Array1<int32_t> *out_map);
+
+  /*
+    Terminate the decoding process of the current RnntDecodingStreams object;
+    it will update the states & scores of each individual stream and split &
+    append the prev_frames_ in the current object to the prev_frames of the
+    individual streams.
+
+    Note: We cannot decode with this object anymore after calling
+    TerminateAndFlushToStreams().
+   */
+  void TerminateAndFlushToStreams();
+
+  const ContextPtr &Context() const { return c_; }
+  const Ragged<int64_t> &States() const { return states_; }
+  const Ragged<double> &Scores() const { return scores_; }
+  const Array1<int32_t> &NumGraphStates() const { return num_graph_states_; }
+  int32_t NumStreams() const { return num_streams_; }
+
+  // Note: The following three functions should be private members; they are
+  // not expected to be called outside this class.  We make them public because
+  // of the extended lambda restrictions, see
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#extended-lambda-restrictions
+  // for more details.
+
+  /* Expand arcs according to states_.
+
+     `states_` has a shape of [stream][context][state]; each of its values is
+     a combination of context_state and graph_state, that is:
+     `state = context_state * num_graph_states + graph_state`.  The
+     graph_state is the idx0 into the corresponding individual graph (which
+     has shape [state][arc]).  This function will expand each of these states
+     into several arcs (i.e. the out-going arcs of that state), so that we can
+     get a new shape of [stream][context][state][arc].
+
+     Caution: This function is intended to be used in `Advance()` only.
+
+       @return  Return the expected 4 axes shape
+                (i.e. [stream][context][state][arc]).
+   */
+  RaggedShape ExpandArcs();
+
+  /*
+    Do initial pruning pass on the arcs (because it will be quite a large
+    array), populating the `keep` array of a Renumbering object.  The pruning
+    rule is:
+      (1) keep all epsilon transitions to the next frame, to ensure there is
+          no way we can have no states surviving.
+      (2) for all other arcs, keep it if the forward score after the
+          arc would be >= the max_scores_per_stream entry for this stream
+          minus the beam from the config.
+
+    Caution: This function is intended to be used in `Advance()` only.
+
+      @param [in] unprund_arcs_shape  The RaggedShape returned by
+                  `ExpandArcs()`.
+      @param [in] logprobs  Array of shape [tot_contexts][num_symbols],
+                  containing log-probs of symbols given the contexts output
+                  by `GetContexts()`.  Will satisfy
+                  logprobs.Dim0() == states_.TotSize(1).
+                  (Note: states_.TotSize(1) == unprund_arcs_shape.TotSize(1)).
+
+      @return  Return the renumbering object indicating which arcs will be
+               kept.
+   */
+  Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape,
+                                 Array2<float> &logprobs);
+  /*
+    Group states by contexts.
+
+    `states` has a shape of [stream][arc]; it contains the sorted values
+    (per stream) which are:
+    `state = context_state * num_graph_states + graph_state`.  This guarantees
+    that the context_states of the states are sorted too, so we can easily
+    separate these states by finding the boundaries of the context_states.
+
+    Note: Actually we will group the states by contexts and states, because
+          we need a shape of [stream][context][state][arc]; obviously the
+          sub-lists along axis -1 contain the same values.
+
+    Here is an example: suppose vocab_size=10, num_graph_states=10,
+    decoder_history_len=2, and we have states like:
+
+      [ [ 112 120 123 125 345 345 ] [ 123 124 567 568 670 ] ]
+
+    the context_states are (context_state = state / num_graph_states):
+
+      [ [ 11 12 12 12 34 34 ] [ 12 12 56 56 67 ] ]
+
+    It will finally be grouped into ([stream][context][state][arc]):
+
+      [ [ [ [ 112 ] ] [ [ 120 ] [ 123 ] [ 125 ] ] [ [ 345 345 ] ] ]
+        [ [ [ 123 ] [ 124 ] ] [ [ 567 ] [ 568 ] ] [ [ 670 ] ] ] ]
+
+    Caution: This function is intended to be used in `Advance()` only.
+
+      @param [in] states  A ragged tensor with two axes, with each sub-list
+                  **sorted**.
+
+      @return  Return a RaggedShape with 4 axes
+               (i.e. [stream][context][state][arc]); it satisfies
+               `ans.NumElements() == states.NumElements()` and
+               `ans.Dim0() == states.Dim0()`.
+   */
+  RaggedShape GroupStatesByContexts(Ragged<int64_t> &states);
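The worked example above can be reproduced with a few lines of ordinary Python; this is a sketch of the grouping logic only, not of how the CUDA code performs it:

    from itertools import groupby

    num_graph_states = 10
    states = [[112, 120, 123, 125, 345, 345], [123, 124, 567, 568, 670]]

    def group(stream):
        ans = []
        # group first by context_state (state // num_graph_states) ...
        for _, ctx in groupby(stream, key=lambda s: s // num_graph_states):
            # ... then group equal states within a context
            ans.append([list(g) for _, g in groupby(ctx)])
        return ans

    assert [group(s) for s in states] == [
        [[[112]], [[120], [123], [125]], [[345, 345]]],
        [[[123], [124]], [[567], [568]], [[670]]],
    ]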
+
+ private:
+  /*
+    Prune the incoming scores based on beam, max-states and max-contexts.
+    Actually the beam part is not really necessary, as we already pruned
+    with the beam, but it doesn't cost anything extra.
+    Args:
+       incoming_scores [in]  The ragged array of scores to be pruned, indexed
+                  [stream][context][state][arc].  The scores are per arc, but
+                  it's at the state and context level
+                  that we prune, based on settings in this->config, so entire
+                  sub-lists of arcs will be deleted.
+       arcs_new2old [out]  The new2old map of the pruned arcs will be
+                  written to here.
+    Returns: pruned array of incoming scores, indexed
+       [stream][context][state][arc].
+   */
+  Ragged<double> PruneTwice(Ragged<double> &incoming_scores,
+                            Array1<int32_t> *arcs_new2old);
+
+  /*
+    Gather all previously decoded frames up to now; we need all the previous
+    frames to generate the lattice.
+
+    Note: The prev_frames_ in the current object only contains the frames from
+    the point we created this object to the frame we called
+    `TerminateAndFlushToStreams()` (i.e. prev_frames_.size() equals the number
+    of times we called `Advance()`).
+
+      @param [in] num_frames  A vector containing the number of frames we want
+                   to gather for each stream.
+                   It MUST satisfy `num_frames.size() == num_streams_`, and
+                   `num_frames[i] < srcs_[i].prev_frames.size()`.
+   */
+  void GatherPrevFrames(std::vector<int32_t> &num_frames);
+
+  ContextPtr c_;
+
+  bool attached_;  // A flag indicating whether this streams object is still
+                   // attached; initialized to true.  Only after
+                   // TerminateAndFlushToStreams() is called will `attached_`
+                   // be set to false, which means we cannot decode any more.
+
+  int32_t num_streams_;  // The number of RnntDecodingStream
+
+  // RnntDecodingStream pointers.
+  std::vector<std::shared_ptr<RnntDecodingStream>> srcs_;
+
+  // The configuration object.
+  const RnntDecodingConfig config_;
+
+  // array of the individual graphs of the streams, with graphs.NumSrcs() ==
+  // number of streams.  All the graphs might actually be the same.
+  Array1OfRagged<Arc> graphs_;
+
+  // Number of graph states, per graph; this is used in constructing:
+  //   state_idx = context_state * num_graph_states + graph_state.
+  // for elements of `states`.
+  Array1<int32_t> num_graph_states_;
+
+  // `states` contains int64_t which represents the decoder state; this is:
+  //   state = context_state * num_graph_states + graph_state.
+  // the num_graph_states is specific to the decoding stream,
+  // and would be an element of the array `num_graph_states_`.
+  //
+  // `states` is indexed [stream][context_state][state],
+  // i.e. the states are grouped first
+  // by context_state (they are sorted, to make this possible).
+  Ragged<int64_t> states_;
+
+  // `scores` contains the forward scores of the states in `states`;
+  // it has the same shape as `states`.
+  Ragged<double> scores_;
+
+  // frames contains the arc information for previously decoded
+  // frames, to be split and appended to the prev_frames of the
+  // individual streams when we are done with this RnntDecodingStreams
+  // object.  These arrays are indexed [stream][state][arc].
+  std::vector<std::shared_ptr<Ragged<ArcInfo>>> prev_frames_;
+};
+
+/* Create a new decoding stream.
+
+   Every sequence (wave data) needs a decoding stream; this function is
+   expected to be called when a new sequence comes.  We support different
+   decoding graphs for different streams.
+
+     @param [in] graph  The decoding graph used in this stream.
+
+     @return  The pointer to this decoding stream, which will be combined into
+              `RnntDecodingStreams` to do decoding together with other
+              sequences in parallel.
+ */
+std::shared_ptr<RnntDecodingStream> CreateStream(
+    const std::shared_ptr<Fsa> &graph);
+
+}  // namespace rnnt_decoding
+}  // namespace k2
+
+#endif  // K2_CSRC_RNNT_DECODE_H_
diff --git a/k2/csrc/rnnt_decode_test.cu b/k2/csrc/rnnt_decode_test.cu
new file mode 100644
index 000000000..7b4020ff6
--- /dev/null
+++ b/k2/csrc/rnnt_decode_test.cu
@@ -0,0 +1,111 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "k2/csrc/fsa.h"
+#include "k2/csrc/fsa_algo.h"
+#include "k2/csrc/fsa_utils.h"
+#include "k2/csrc/rnnt_decode.h"
+
+namespace k2 {
+namespace rnnt_decoding {
+TEST(RnntDecodeStream, CreateRnntDecodeStream) {
+  for (const auto &c : {GetCpuContext(), GetCudaContext()}) {
+    Array1<int32_t> aux_labels;
+    auto graph = std::make_shared<Fsa>(TrivialGraph(c, 5, &aux_labels));
+    auto stream = CreateStream(graph);
+    K2_CHECK(Equal(*graph, *(stream->graph)));
+    K2_CHECK(Equal(stream->states, Ragged<int64_t>(c, "[[0]]")));
+    K2_CHECK(Equal(stream->scores, Ragged<double>(c, "[[0]]")));
+    K2_CHECK_EQ(stream->num_graph_states, graph->Dim0());
+  }
+}
+
+// This test does not do any checking; it just makes sure the code runs
+// normally.
+TEST(RnntDecodingStreams, Basic) {
+  for (const auto &c : {GetCpuContext(), GetCudaContext()}) {
+    int32_t vocab_size = 6;
+    auto config =
+        RnntDecodingConfig(vocab_size, 2 /*decoder_history_len*/,
+                           8.0f /*beam*/, 2 /*max_states*/,
+                           3 /*max_contexts*/);
+
+    Array1<int32_t> aux_labels;
+    auto trivial_graph =
+        std::make_shared<Fsa>(TrivialGraph(c, 5, &aux_labels));
+    auto ctc_topo = std::make_shared<Fsa>(CtcTopo(c, 5, false, &aux_labels));
+    auto ctc_topo_modified =
+        std::make_shared<Fsa>(CtcTopo(c, 5, true, &aux_labels));
+    std::vector<std::shared_ptr<Fsa>> graphs(
+        {trivial_graph, ctc_topo, ctc_topo_modified});
+
+    int32_t num_streams = 3;
+    std::vector<std::shared_ptr<RnntDecodingStream>> streams_vec(num_streams);
+    for (int32_t i = 0; i < num_streams; ++i) {
+      streams_vec[i] = CreateStream(graphs[RandInt(0, 2)]);
+    }
+    auto streams = RnntDecodingStreams(streams_vec, config);
+
+    K2_LOG(INFO) << "states : " << streams.States();
+    K2_LOG(INFO) << "scores : " << streams.Scores();
+    K2_LOG(INFO) << "num_graph_states : " << streams.NumGraphStates();
+
+    float mean = 5, std = 3;
+    RaggedShape context_shape;
+    Array2<int32_t> context;
+
+    int32_t steps = 5;
+
+    for (int32_t i = 0; i < steps; ++i) {
+      streams.GetContexts(&context_shape, &context);
+      K2_LOG(INFO) << "context_shape : " << context_shape;
+      K2_LOG(INFO) << "context : " << context;
+      auto logprobs = RandGaussianArray2<float>(
+          c, context_shape.NumElements(), vocab_size, mean, std);
+      K2_LOG(INFO) << "logprobs : " << logprobs;
+      streams.Advance(logprobs);
+      K2_LOG(INFO) << "states : " << streams.States();
+      K2_LOG(INFO) << "scores : " << streams.Scores();
+    }
+    streams.TerminateAndFlushToStreams();
+
+    std::vector<int32_t> num_frames(num_streams, steps);
+    Array1<int32_t> out_map;
+    FsaVec ofsa;
+    streams.FormatOutput(num_frames, &ofsa, &out_map);
+    K2_LOG(INFO) << "ofsa : " << ofsa;
+    K2_LOG(INFO) << "out map : " << out_map;
+    std::vector<Fsa> fsas;
+    Unstack(ofsa, 0, &fsas);
+    for (size_t i = 0; i < fsas.size(); ++i) {
+      K2_LOG(INFO) << FsaToString(fsas[i]);
+    }
+
+    // different num frames
+    num_frames = std::vector<int32_t>({2, 5, 4});
+    streams.FormatOutput(num_frames, &ofsa, &out_map);
+    K2_LOG(INFO) << "ofsa : " << ofsa;
+    K2_LOG(INFO) << "out map : " << out_map;
+    Unstack(ofsa, 0, &fsas);
+    for (size_t i = 0; i < fsas.size(); ++i) {
+      K2_LOG(INFO) << FsaToString(fsas[i]);
+    }
+  }
+}
+
+}  // namespace rnnt_decoding
+}  // namespace k2
diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu
index 132065699..38c954b99 100644
--- a/k2/python/csrc/torch.cu
+++ b/k2/python/csrc/torch.cu
@@ -35,6 +35,7 @@
 #include "k2/python/csrc/torch/nbest.h"
 #include "k2/python/csrc/torch/ragged.h"
 #include "k2/python/csrc/torch/ragged_ops.h"
+#include "k2/python/csrc/torch/rnnt_decode.h"
 #include "k2/python/csrc/torch/v2/k2.h"
 
 void PybindTorch(py::module &m) {
@@ -49,6 +50,7 @@ void PybindTorch(py::module &m) {
   PybindNbest(m);
   PybindRagged(m);
   PybindRaggedOps(m);
+  PybindRnntDecode(m);
   k2::PybindV2(m);
 }
 
diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt
index 94f005ba8..a05678053 100644
--- a/k2/python/csrc/torch/CMakeLists.txt
+++ b/k2/python/csrc/torch/CMakeLists.txt
@@ -12,6 +12,7 @@ set(torch_srcs
   nbest.cu
   ragged.cu
   ragged_ops.cu
+  rnnt_decode.cu
   torch_util.cu
 
   v2/any.cu
diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu
index 2af279240..f4016695d 100644
--- a/k2/python/csrc/torch/fsa_algo.cu
+++ b/k2/python/csrc/torch/fsa_algo.cu
@@ -244,8 +244,8 @@ static void PybindIntersectDense(py::module &m) {
       "intersect_dense",
       [](FsaVec &a_fsas, DenseFsaVec &b_fsas,
         torch::optional<torch::Tensor> a_to_b_map,
        float output_beam,
-         int32_t max_states, int32_t max_arcs)
-          -> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
+         int32_t max_states, int32_t max_arcs)
+          -> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
        DeviceGuard guard(a_fsas.Context());
        Array1<int32_t> arc_map_a;
        Array1<int32_t> arc_map_b;
@@ -703,12 +703,12 @@ static void PybindReplaceFsa(py::module &m) {
 static void PybindCtcGraph(py::module &m) {
   m.def(
       "ctc_graph",
-      [](RaggedAny &symbols, bool modified = false)
-          -> std::pair<FsaVec, torch::Tensor> {
+      [](RaggedAny &symbols,
+         bool modified = false) -> std::pair<FsaVec, torch::Tensor> {
         DeviceGuard guard(symbols.any.Context());
         Array1<int32_t> aux_labels;
-        FsaVec graph = CtcGraphs(symbols.any.Specialize<int32_t>(), modified,
-                                 &aux_labels);
+        FsaVec graph =
+            CtcGraphs(symbols.any.Specialize<int32_t>(), modified, &aux_labels);
         torch::Tensor tensor = ToTorch(aux_labels);
         return std::make_pair(graph, tensor);
       },
@@ -745,19 +745,46 @@ static void PybindCtcTopo(py::module &m) {
       py::arg("modified") = false);
 }
 
+static void PybindTrivialGraph(py::module &m) {
+  m.def(
+      "trivial_graph",
+      [](int32_t max_token, torch::optional<torch::Device> device = {})
+          -> std::pair<Fsa, torch::Tensor> {
+        ContextPtr context = GetContext(device.value_or(torch::Device("cpu")));
+        DeviceGuard guard(context);
+        Array1<int32_t> aux_labels;
+        Fsa fsa = TrivialGraph(context, max_token, &aux_labels);
+        torch::Tensor tensor = ToTorch(aux_labels);
+        return std::make_pair(fsa, tensor);
+      },
+      py::arg("max_token"), py::arg("device") = py::none());
+
+  m.def(
+      "trivial_graph",
+      [](int32_t max_token, torch::optional<std::string> device = {})
+          -> std::pair<Fsa, torch::Tensor> {
+        ContextPtr context = GetContext(torch::Device(device.value_or("cpu")));
+        DeviceGuard guard(context);
+        Array1<int32_t> aux_labels;
+        Fsa fsa = TrivialGraph(context, max_token, &aux_labels);
+        torch::Tensor tensor = ToTorch(aux_labels);
+        return std::make_pair(fsa, tensor);
+      },
+      py::arg("max_token"), py::arg("device") = py::none());
+}
+
 static void PybindLevenshteinGraph(py::module &m) {
   m.def(
       "levenshtein_graph",
       [](RaggedAny &symbols, float ins_del_score = -0.501,
-         bool need_score_offset =
-             true) -> std::tuple<FsaVec, torch::Tensor,
-                                 torch::optional<torch::Tensor>> {
+         bool need_score_offset = true)
+          -> std::tuple<FsaVec, torch::Tensor,
+                        torch::optional<torch::Tensor>> {
         DeviceGuard guard(symbols.any.Context());
         Array1<int32_t> aux_labels;
         Array1<float> score_offsets;
-        FsaVec graph = LevenshteinGraphs(symbols.any.Specialize<int32_t>(),
-                                         ins_del_score, &aux_labels,
-                                         need_score_offset ? &score_offsets : nullptr);
+        FsaVec graph = LevenshteinGraphs(
+            symbols.any.Specialize<int32_t>(), ins_del_score, &aux_labels,
+            need_score_offset ? &score_offsets : nullptr);
         torch::Tensor aux_labels_tensor = ToTorch(aux_labels);
         torch::optional<torch::Tensor> score_offsets_tensor;
         if (need_score_offset) score_offsets_tensor = ToTorch(score_offsets);
@@ -791,5 +818,6 @@ void PybindFsaAlgo(py::module &m) {
   k2::PybindReplaceFsa(m);
   k2::PybindShortestPath(m);
   k2::PybindTopSort(m);
+  k2::PybindTrivialGraph(m);
   k2::PybindUnion(m);
 }
diff --git a/k2/python/csrc/torch/rnnt_decode.cu b/k2/python/csrc/torch/rnnt_decode.cu
new file mode 100644
index 000000000..6678bc7b8
--- /dev/null
+++ b/k2/python/csrc/torch/rnnt_decode.cu
@@ -0,0 +1,164 @@
+/**
+ * @brief python wrappers for rnnt_decode.h
+ *
+ * @copyright
+ * Copyright      2022  Xiaomi Corp.  (authors: Wei Kang)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/device_guard.h"
+#include "k2/csrc/fsa.h"
+#include "k2/csrc/rnnt_decode.h"
+#include "k2/python/csrc/torch/rnnt_decode.h"
+#include "k2/python/csrc/torch/torch_util.h"
+
+namespace k2 {
+static void PybindRnntDecodingConfig(py::module &m) {
+  using PyClass = rnnt_decoding::RnntDecodingConfig;
+  py::class_<PyClass> config(m, "RnntDecodingConfig");
+  config.def(py::init<int32_t, int32_t, double, int32_t, int32_t>(),
+             py::arg("vocab_size"), py::arg("decoder_history_len"),
+             py::arg("beam"), py::arg("max_states"), py::arg("max_contexts"),
+             R"(
+             Construct a RnntDecodingConfig object; it contains the parameters
+             needed by RNN-T decoding.
+
+             Args:
+               vocab_size:
+                 It indicates how many symbols we are using; equals the
+                 largest symbol plus one.
+               decoder_history_len:
+                 `decoder_history_len` is the number of symbols of history the
+                 decoder takes; will normally be one or two
+                 ("stateless decoder"); our RNN-T decoding setup does not
+                 support unlimited decoder context such as with LSTMs.
+               beam:
+                 `beam` imposes a limit on the score of a state, relative to
+                 the best-scoring state on the same frame.  E.g. 10.
+               max_states:
+                 `max_states` is a limit on the number of distinct states that
+                 we allow per frame, per stream; the number of states will not
+                 be allowed to exceed this limit.
+               max_contexts:
+                 `max_contexts` is a limit on the number of distinct contexts
+                 that we allow per frame, per stream; the number of contexts
+                 will not be allowed to exceed this limit.
+             )");
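For reference, constructing the config from Python looks like this (a sketch mirroring the unit test added later in this patch; the values are arbitrary):

    import k2

    config = k2.RnntDecodingConfig(
        vocab_size=10,
        decoder_history_len=2,
        beam=3.0,
        max_states=3,
        max_contexts=3,
    )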
+ )"); + + config.def_readwrite("vocab_size", &PyClass::vocab_size) + .def_readwrite("decoder_history_len", &PyClass::decoder_history_len) + .def_readwrite("beam", &PyClass::beam) + .def_readwrite("max_states", &PyClass::max_states) + .def_readwrite("max_contexts", &PyClass::max_contexts); + + config.def("__str__", [](const PyClass &self) -> std::string { + std::ostringstream os; + os << "RnntDecodingConfig : {\n" + << " vocab_size : " << self.vocab_size << "\n" + << " decoder_history_len : " << self.decoder_history_len << "\n" + << " beam : " << self.beam << "\n" + << " max_states : " << self.max_states << "\n" + << " max_contexts : " << self.max_contexts << "\n" + << "}"; + return os.str(); + }); +} + +static void PybindRnntDecodingStream(py::module &m) { + using PyClass = rnnt_decoding::RnntDecodingStream; + py::class_> stream(m, "RnntDecodingStream"); + + stream.def("__str__", [](const PyClass &self) -> std::string { + std::ostringstream os; + os << "RnntDecodingStream : {\n" + << " num graph states : " << self.graph->Dim0() << "\n" + << " num graph arcs : " << self.graph->NumElements() << "\n" + << " num contexts : " << self.states.Dim0() << "\n" + << " num states : " << self.states.NumElements() << "\n" + << " num prev frames : " << self.prev_frames.size() << "\n" + << "}"; + return os.str(); + }); + + m.def("create_rnnt_decoding_stream", + [](Fsa &graph) -> std::shared_ptr { + DeviceGuard guard(graph.Context()); + return rnnt_decoding::CreateStream(std::make_shared(graph)); + }); +} + +static void PybindRnntDecodingStreams(py::module &m) { + using PyClass = rnnt_decoding::RnntDecodingStreams; + py::class_ streams(m, "RnntDecodingStreams"); + + streams.def(py::init( + [](std::vector> &srcs, + const rnnt_decoding::RnntDecodingConfig &config) + -> std::unique_ptr { + K2_CHECK_GE(srcs.size(), 1); + DeviceGuard guard(srcs[0]->graph->Context()); + return std::make_unique(srcs, config); + })); + + streams.def("advance", [](PyClass &self, torch::Tensor logprobs) -> void { + DeviceGuard guard(self.Context()); + logprobs = logprobs.to(torch::kFloat); + Array2 logprobs_array = FromTorch(logprobs, Array2Tag{}); + self.Advance(logprobs_array); + }); + + streams.def("get_contexts", + [](PyClass &self) -> std::pair { + DeviceGuard guard(self.Context()); + RaggedShape shape; + Array2 contexts; + self.GetContexts(&shape, &contexts); + torch::Tensor contexts_tensor = ToTorch(contexts); + return std::make_pair(shape, contexts_tensor); + }); + + streams.def("terminate_and_flush_to_streams", [](PyClass &self) -> void { + DeviceGuard guard(self.Context()); + self.TerminateAndFlushToStreams(); + }); + + streams.def( + "format_output", + [](PyClass &self, + std::vector &num_frames) -> std::pair { + DeviceGuard guard(self.Context()); + FsaVec ofsa; + Array1 out_map; + self.FormatOutput(num_frames, &ofsa, &out_map); + torch::Tensor out_map_tensor = ToTorch(out_map); + return std::make_pair(ofsa, out_map_tensor); + }); +} + +} // namespace k2 + +void PybindRnntDecode(py::module &m) { + k2::PybindRnntDecodingConfig(m); + k2::PybindRnntDecodingStream(m); + k2::PybindRnntDecodingStreams(m); +} diff --git a/k2/python/csrc/torch/rnnt_decode.h b/k2/python/csrc/torch/rnnt_decode.h new file mode 100644 index 000000000..e61a7e97c --- /dev/null +++ b/k2/python/csrc/torch/rnnt_decode.h @@ -0,0 +1,30 @@ +/** + * @brief python wrappers for rnnt_decode.h + * + * @copyright + * Copyright 2022 Xiaomi Corp. 
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_RNNT_DECODE_H_
+#define K2_PYTHON_CSRC_TORCH_RNNT_DECODE_H_
+
+#include "k2/python/csrc/torch.h"
+
+void PybindRnntDecode(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_RNNT_DECODE_H_
diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py
index 0f9d4e51b..1091f4348 100644
--- a/k2/python/k2/__init__.py
+++ b/k2/python/k2/__init__.py
@@ -50,6 +50,7 @@
 from .fsa_algo import replace_fsa
 from .fsa_algo import shortest_path
 from .fsa_algo import top_sort
+from .fsa_algo import trivial_graph
 from .fsa_algo import union
 from .fsa_properties import to_str as properties_to_str
 from .mutual_information import joint_mutual_information_recursion
@@ -61,6 +62,10 @@
 from .ops import index_fsa
 from .ops import index_select
 
+from .rnnt_decode import RnntDecodingConfig
+from .rnnt_decode import RnntDecodingStream
+from .rnnt_decode import RnntDecodingStreams
+
 from .rnnt_loss import do_rnnt_pruning
 from .rnnt_loss import get_rnnt_logprobs
 from .rnnt_loss import get_rnnt_logprobs_joint
diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py
index aabbc8dd0..d924c9b93 100644
--- a/k2/python/k2/fsa.py
+++ b/k2/python/k2/fsa.py
@@ -283,8 +283,7 @@ def to_str(self, openfst: bool = False) -> str:
                 ans += 'FsaVec[' + str(i) + ']: ' + _k2.fsa_to_str(
                     ragged_arc, openfst=openfst,
                     extra_labels=[x[start:end] for x in extra_labels],
-                    ragged_labels=[_k2.ragged_int_arange(x, 0, start, end)
-                                   for x in ragged_labels])
+                    ragged_labels=[x[start:end] for x in ragged_labels])
         ans += 'properties_str = ' + _k2.fsa_properties_as_str(
             self._properties) + '.'
         for name, value in self.named_tensor_attr(include_scores=False):
diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py
index 5c608cb77..7b893782e 100644
--- a/k2/python/k2/fsa_algo.py
+++ b/k2/python/k2/fsa_algo.py
@@ -1061,6 +1061,29 @@ def ctc_topo(max_token: int,
     return fsa
 
 
+def trivial_graph(max_token: int,
+                  device: Optional[Union[torch.device, str]] = None) -> k2.Fsa:
+    '''
+    Create a trivial graph which has only two states. On state 0, there are
+    `max_token + 1` self-loops (i.e. a loop for each symbol, including blank),
+    and state 1 is the final state.
+
+    Args:
+      max_token:
+        The maximum token ID (inclusive). We assume that token IDs
+        are contiguous (from 1 to `max_token`). 0 represents blank.
+      device:
+        Optional. It can be either a string (e.g., 'cpu',
+        'cuda:0') or a torch.device.
+        If it is None, then the returned FSA is on CPU.
+
+    Returns:
+      Returns the expected trivial graph on the given device.
+    '''
+    ragged_arc, aux_labels = _k2.trivial_graph(max_token, device)
+    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
+    return fsa
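A short usage sketch, assuming the usual k2 convention that the arc entering the final state carries label -1:

    import k2

    fsa = k2.trivial_graph(max_token=2)
    # Expected structure: 3 self-loops on state 0 (labels 0, 1, 2) plus
    # one arc to the final state 1 with label -1, i.e. 4 arcs in total.
    print(fsa.num_arcs)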
+
+
 def levenshtein_graph(
         symbols: Union[k2.RaggedTensor, List[List[int]]],
         ins_del_score: float = -0.501,
diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py
new file mode 100644
index 000000000..bd6a4bc66
--- /dev/null
+++ b/k2/python/k2/rnnt_decode.py
@@ -0,0 +1,243 @@
+# Copyright      2022  Xiaomi Corp.  (author: Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from typing import Tuple
+
+import k2
+import torch
+import _k2
+
+from k2 import Fsa
+from k2 import RaggedShape
+from k2 import RaggedTensor
+from torch import Tensor
+from .ops import index_select
+from _k2 import RnntDecodingConfig
+
+
+class RnntDecodingStream(object):
+    """Create a new rnnt decoding stream.
+
+    Every sequence (wave data) needs a decoding stream; this function is expected
+    to be called when a new sequence comes. We support different decoding graphs
+    for different streams.
+
+    Args:
+      graph:
+        The decoding graph used in this stream.
+
+    Returns:
+      An rnnt decoding stream object, which will be combined into
+      `RnntDecodingStreams` to do decoding together with other
+      sequences in parallel.
+    """
+    def __init__(self, fsa: Fsa) -> None:
+        self.fsa = fsa
+        self.stream = _k2.create_rnnt_decoding_stream(fsa.arcs)
+        self.device = fsa.device
+
+    """Return a string representation of this object
+
+    For visualization and debug only.
+    """
+    def __str__(self) -> str:
+        return f"{self.stream}, device : {self.device}\n"
+
+
+class RnntDecodingStreams(object):
+    """
+    Combines multiple RnntDecodingStream objects to create a RnntDecodingStreams
+    object, so that all these streams can do decoding in parallel.
+
+    Args:
+      src_streams:
+        A list of RnntDecodingStream objects to be combined.
+      config:
+        A configuration object which contains decoding parameters like
+        `vocab_size`, `decoder_history_len`, `beam`, `max_states`,
+        `max_contexts` etc.
+
+    Returns:
+      Return a RnntDecodingStreams object.
+    """
+    def __init__(
+        self, src_streams: List[RnntDecodingStream], config: RnntDecodingConfig
+    ) -> None:
+        assert len(src_streams) > 0
+        self.num_streams = len(src_streams)
+        self.src_streams = src_streams
+        self.device = self.src_streams[0].device
+        streams = [x.stream for x in self.src_streams]
+        self.streams = _k2.RnntDecodingStreams(streams, config)
+
+    '''Return a string representation of this object
+
+    For visualization and debug only.
+    '''
+    def __str__(self) -> str:
+        s = f"num_streams : {self.num_streams}\n"
+        for i in range(self.num_streams):
+            s += f"stream[{i}] : {self.src_streams[i]}"
+        return s
+
+    """
+    This function must be called prior to evaluating the joiner network
+    for a particular frame. It tells the calling code which contexts
+    it must evaluate the joiner network for.
+
+    Returns:
+      Return a two-element tuple containing a RaggedShape and a tensor.
+      shape:
+        A RaggedShape with 2 axes, representing [stream][context].
+      contexts:
+        A tensor of shape [tot_contexts][decoder_history_len], where
+        tot_contexts == shape->TotSize(1) and decoder_history_len comes from
+        the config; it represents the number of symbols in the context of the
+        decoder network (assumed to be finite). It contains the token ids
+        of the vocabulary (i.e. `0 <= value < vocab_size`).
+    """
+    def get_contexts(self) -> Tuple[RaggedShape, Tensor]:
+        return self.streams.get_contexts()
+
+    """
+    Advance decoding streams by one frame.
+
+    Args:
+      logprobs:
+        A tensor of shape [tot_contexts][num_symbols], containing log-probs of
+        symbols given the contexts output by `get_contexts()`. Will satisfy
+        logprobs.Dim0() == shape.TotSize(1).
+    """
+    def advance(self, logprobs: Tensor) -> None:
+        self.streams.advance(logprobs)
+
+    """
+    Terminate the decoding process of the current RnntDecodingStreams object.
+    It will update the decoding states and store the decoding results obtained
+    so far into each of the individual streams.
+
+    Note: We cannot decode with this object anymore after calling
+    terminate_and_flush_to_streams().
+    """
+    def terminate_and_flush_to_streams(self) -> None:
+        self.streams.terminate_and_flush_to_streams()
+
+    """
+    Generate the lattice Fsa obtained so far.
+
+    Note: The attributes of the generated lattice are a union of the attributes
+    of all the decoding graphs. For example, a streams object contains three
+    individual streams, each stream has its own decoding graph; graph[0]
+    has attributes attr1, attr2; graph[1] has attributes attr1, attr3;
+    graph[2] has attributes attr3, attr4; then the generated lattice has
+    attributes attr1, attr2, attr3, attr4.
+
+    Args:
+      num_frames:
+        A List containing the number of frames we want to gather for each stream
+        (note: the frames we have ever received for the corresponding stream).
+        It MUST satisfy `len(num_frames) == self.num_streams`.
+    Returns:
+      Return the lattice Fsa with all the attributes propagated. The returned
+      Fsa has 3 axes with `fsa.dim0==self.num_streams`.
+ """ + def format_output(self, num_frames: List[int]) -> Fsa: + assert len(num_frames) == self.num_streams + + ragged_arcs, out_map = self.streams.format_output(num_frames) + fsa = Fsa(ragged_arcs) + + # propagate attributes + tensor_attr_info = dict() + # gather the attributes info of all the decoding graphs, + for i in range(self.num_streams): + src = self.src_streams[i].fsa + for name, value in src.named_tensor_attr(include_scores=False): + if name not in tensor_attr_info: + filler = 0.0 + if isinstance(value, Tensor): + filler = float(src.get_filler(name)) + dtype = value.dtype + tensor_type = "Tensor" + else: + assert isinstance(value, k2.RaggedTensor) + # Only integer types ragged attributes are supported now + assert value.dtype == torch.int32 + assert value.num_axes == 2 + dtype = torch.int32 + tensor_type = "RaggedTensor" + tensor_attr_info[name] = { + "filler": filler, + "dtype": dtype, + "tensor_type": tensor_type, + } + # combine the attributes propagating from different decoding graphs + for name, info in tensor_attr_info.items(): + values = list() + start = 0 + for i in range(self.num_streams): + src = self.src_streams[i].fsa + device = self.device + num_arcs = fsa[i].num_arcs + arc_map = out_map[start:start + num_arcs] + start = start + num_arcs + if hasattr(src, name): + value = getattr(src, name) + if info["tensor_type"] == "Tensor": + assert isinstance(value, Tensor) + new_value = index_select( + value, arc_map, default_value=filler + ) + else: + assert isinstance(value, RaggedTensor) + # Only integer types ragged attributes are supported now + assert value.num_axes == 2 + assert value.dtype == torch.int32 + new_value, _ = value.index( + arc_map, axis=0, need_value_indexes=False + ) + else: + if info["tensor_type"] == "Tensor": + # fill with filler value + new_value = torch.tensor( + [filler] * num_arcs, + dtype=info["dtype"], + device=device, + ) + else: + # fill with empty RaggedTensor + new_value = RaggedTensor( + torch.empty( + (num_arcs, 0), + dtype=info["dtype"], + device=device, + ) + ) + values.append(new_value) + if info["tensor_type"] == "Tensor": + new_value = torch.cat(values) + else: + new_value = k2.ragged.cat(values, axis=0) + setattr(fsa, name, new_value) + + # set non_tensor_attrs + for i in range(self.num_streams): + src = self.src_streams[i].fsa + for name, value in src.named_non_tensor_attr(): + setattr(fsa, name, value) + + return fsa diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 0cb2c45fe..57525979e 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -64,6 +64,7 @@ set(py_test_files random_paths_test.py remove_epsilon_self_loops_test.py remove_epsilon_test.py + rnnt_decode_test.py rnnt_loss_test.py shortest_path_test.py sparse_abs_test.py diff --git a/k2/python/tests/rnnt_decode_test.py b/k2/python/tests/rnnt_decode_test.py new file mode 100644 index 000000000..8766adfaa --- /dev/null +++ b/k2/python/tests/rnnt_decode_test.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (authors: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt
index 0cb2c45fe..57525979e 100644
--- a/k2/python/tests/CMakeLists.txt
+++ b/k2/python/tests/CMakeLists.txt
@@ -64,6 +64,7 @@ set(py_test_files
   random_paths_test.py
   remove_epsilon_self_loops_test.py
   remove_epsilon_test.py
+  rnnt_decode_test.py
   rnnt_loss_test.py
   shortest_path_test.py
   sparse_abs_test.py
diff --git a/k2/python/tests/rnnt_decode_test.py b/k2/python/tests/rnnt_decode_test.py
new file mode 100644
index 000000000..8766adfaa
--- /dev/null
+++ b/k2/python/tests/rnnt_decode_test.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+#
+# Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To run this single test, use
+#
+#  ctest --verbose -R rnnt_decode_test_py
+
+import unittest
+
+import k2
+import torch
+
+
+class TestRnntDecode(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.devices = [torch.device("cpu")]
+        if torch.cuda.is_available() and k2.with_cuda:
+            cls.devices.append(torch.device("cuda", 0))
+            if torch.cuda.device_count() > 1:
+                torch.cuda.set_device(1)
+                cls.devices.append(torch.device("cuda", 1))
+
+    def test(self):
+        for device in self.devices:
+            fsa1 = k2.ctc_topo(5, device=device)
+            fsa1.attr1 = torch.tensor([1] * fsa1.num_arcs, device=device)
+
+            stream1 = k2.RnntDecodingStream(fsa1)
+
+            fsa2 = k2.trivial_graph(3, device=device)
+            fsa2.attr1 = torch.tensor([2] * fsa2.num_arcs, device=device)
+            fsa2.attr2 = torch.tensor([22] * fsa2.num_arcs, device=device)
+
+            stream2 = k2.RnntDecodingStream(fsa2)
+
+            fsa3 = k2.ctc_topo(3, modified=True, device=device)
+            fsa3.attr3 = k2.RaggedTensor(
+                torch.ones((fsa3.num_arcs, 2), dtype=torch.int32, device=device)
+                * 3
+            )
+
+            stream3 = k2.RnntDecodingStream(fsa3)
+
+            config = k2.RnntDecodingConfig(10, 2, 3.0, 3, 3)
+            streams = k2.RnntDecodingStreams(
+                [stream1, stream2, stream3], config
+            )
+
+            for i in range(5):
+                shape, context = streams.get_contexts()
+                logprobs = torch.randn(
+                    (context.shape[0], 10), dtype=torch.float32, device=device
+                )
+                streams.advance(logprobs)
+
+            streams.terminate_and_flush_to_streams()
+            ofsa = streams.format_output([3, 4, 5])
+            print(ofsa)
+
+
+if __name__ == "__main__":
+    unittest.main()
From 9a0d72cabe8102c435214bf23621ddabeb2ab4cc Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Wed, 16 Mar 2022 11:06:16 +0800
Subject: [PATCH 50/64] fix building docs (#933)

---
 k2/python/k2/rnnt_decode.py | 170 +++++++++++++++++++-----------------
 1 file changed, 88 insertions(+), 82 deletions(-)

diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py
index bd6a4bc66..8d4814f59 100644
--- a/k2/python/k2/rnnt_decode.py
+++ b/k2/python/k2/rnnt_decode.py
@@ -30,53 +30,54 @@


 class RnntDecodingStream(object):
-    """Create a new rnnt decoding stream.
+    def __init__(self, fsa: Fsa) -> None:
+        """Create a new rnnt decoding stream.
 
-    Every sequence (wave data) needs a decoding stream; this function is expected
-    to be called when a new sequence comes. We support different decoding graphs
-    for different streams.
+        Every sequence (wave data) needs a decoding stream; this function is
+        expected to be called when a new sequence comes. We support different
+        decoding graphs for different streams.
 
-    Args:
-      graph:
-        The decoding graph used in this stream.
+        Args:
+          graph:
+            The decoding graph used in this stream.
 
-    Returns:
-      An rnnt decoding stream object, which will be combined into
-      `RnntDecodingStreams` to do decoding together with other
-      sequences in parallel.
-    """
-    def __init__(self, fsa: Fsa) -> None:
+        Returns:
+          An rnnt decoding stream object, which will be combined into
+          `RnntDecodingStreams` to do decoding together with other
+          sequences in parallel.
+ """ self.fsa = fsa self.stream = _k2.create_rnnt_decoding_stream(fsa.arcs) self.device = fsa.device - """Return a string representation of this object - - For visualization and debug only. - """ def __str__(self) -> str: + """Return a string representation of this object + + For visualization and debug only. + """ return f"{self.stream}, device : {self.device}\n" class RnntDecodingStreams(object): - """ - Combines multiple RnntDecodingStream objects to create a RnntDecodingStreams - object, then all these RnntDecodingStreams can do decoding in parallel. - - Args: - src_streams: - A list of RnntDecodingStream object to be combined. - config: - A configuration object which contains decoding parameters like - `vocab-size`, `decoder_history_len`, `beam`, `max_states`, - `max_contexts` etc. - - Returns: - Return a RnntDecodingStreams object. - """ def __init__( self, src_streams: List[RnntDecodingStream], config: RnntDecodingConfig ) -> None: + """ + Combines multiple RnntDecodingStream objects to create a + RnntDecodingStreams object, then all these RnntDecodingStreams can do + decoding in parallel. + + Args: + src_streams: + A list of RnntDecodingStream object to be combined. + config: + A configuration object which contains decoding parameters like + `vocab-size`, `decoder_history_len`, `beam`, `max_states`, + `max_contexts` etc. + + Returns: + Return a RnntDecodingStreams object. + """ assert len(src_streams) > 0 self.num_streams = len(src_streams) self.src_streams = src_streams @@ -84,78 +85,83 @@ def __init__( streams = [x.stream for x in self.src_streams] self.streams = _k2.RnntDecodingStreams(streams, config) - '''Return a string representation of this object - - For visualization and debug only. - ''' def __str__(self) -> str: + """Return a string representation of this object + + For visualization and debug only. + """ s = f"num_streams : {self.num_streams}\n" for i in range(self.num_streams): s += f"stream[{i}] : {self.src_streams[i]}" return s - """ - This function must be called prior to evaluating the joiner network - for a particular frame. It tells the calling code which contexts - it must evaluate the joiner network for. - - Returns: - Return a two elements tuple containing a RaggedShape and a tensor. - shape: - A RaggedShape with 2 axes, representing [stream][context]. - contexts: - A tensor of shape [tot_contexts][decoder_history_len], where - tot_contexts == shape->TotSize(1) and decoder_history_len comes from - the config, it represents the number of symbols in the context of the - decode network (assumed to be finite). It contains the token ids - into the vocabulary(i.e. `0 <= value < vocab_size`). - """ def get_contexts(self) -> Tuple[RaggedShape, Tensor]: + """ + This function must be called prior to evaluating the joiner network + for a particular frame. It tells the calling code which contexts + it must evaluate the joiner network for. + + Returns: + Return a two elements tuple containing a RaggedShape and a tensor. + + shape: + A RaggedShape with 2 axes, representing [stream][context]. + + contexts: + A tensor of shape [tot_contexts][decoder_history_len], where + tot_contexts == shape->TotSize(1) and decoder_history_len comes from + the config, it represents the number of symbols in the context of + the decode network (assumed to be finite). It contains the token ids + into the vocabulary(i.e. `0 <= value < vocab_size`). + """ return self.streams.get_contexts() - """ - Advance decoding streams by one frame. 
-
-    Args:
-      logprobs:
-        A tensor of shape [tot_contexts][num_symbols], containing log-probs of
-        symbols given the contexts output by `get_contexts()`. Will satisfy
-        logprobs.Dim0() == shape.TotSize(1).
-    """
     def advance(self, logprobs: Tensor) -> None:
+        """
+        Advance decoding streams by one frame.
+
+        Args:
+          logprobs:
+            A tensor of shape [tot_contexts][num_symbols], containing log-probs
+            of symbols given the contexts output by `get_contexts()`. It
+            satisfies `logprobs.Dim0() == shape.TotSize(1)`, where shape is
+            returned by `get_contexts()`.
+        """
         self.streams.advance(logprobs)
 
-    """
-    Terminate the decoding process of the current RnntDecodingStreams object.
-    It will update the decoding states and store the decoding results obtained
-    so far into each of the individual streams.
-
-    Note: We cannot decode with this object anymore after calling
-    terminate_and_flush_to_streams().
-    """
     def terminate_and_flush_to_streams(self) -> None:
+        """
+        Terminate the decoding process of the current RnntDecodingStreams
+        object. It will update the decoding states and store the decoding
+        results obtained so far into each of the individual streams.
+
+        Note:
+          We cannot decode with this object anymore after calling
+          terminate_and_flush_to_streams().
+        """
         self.streams.terminate_and_flush_to_streams()
 
-    """
-    Generate the lattice Fsa obtained so far.
+    def format_output(self, num_frames: List[int]) -> Fsa:
+        """
+        Generate the lattice Fsa obtained so far.
 
-    Note: The attributes of the generated lattice are a union of the attributes
+        Note:
+          The attributes of the generated lattice are a union of the attributes
          of all the decoding graphs. For example, a streams object contains three
          individual streams, each stream has its own decoding graph; graph[0]
          has attributes attr1, attr2; graph[1] has attributes attr1, attr3;
          graph[2] has attributes attr3, attr4; then the generated lattice has
          attributes attr1, attr2, attr3, attr4.
 
-    Args:
-      num_frames:
-        A List containing the number of frames we want to gather for each stream
-        (note: the frames we have ever received for the corresponding stream).
-        It MUST satisfy `len(num_frames) == self.num_streams`.
-    Returns:
-      Return the lattice Fsa with all the attributes propagated. The returned
-      Fsa has 3 axes with `fsa.dim0==self.num_streams`.
-    """
-    def format_output(self, num_frames: List[int]) -> Fsa:
+        Args:
+          num_frames:
+            A List containing the number of frames we want to gather for each
+            stream (note: the frames we have ever received for the
+            corresponding stream). It MUST satisfy
+            `len(num_frames) == self.num_streams`.
+        Returns:
+          Return the lattice Fsa with all the attributes propagated.
+          The returned Fsa has 3 axes with `fsa.dim0==self.num_streams`.
+ """ assert len(num_frames) == self.num_streams ragged_arcs, out_map = self.streams.format_output(num_frames) From 6833270cb228aba7bf9681fccd41e2b52f7d984c Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 16 Mar 2022 11:16:05 +0800 Subject: [PATCH 51/64] Release v1.14 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 386e412d5..649d57136 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}") project(k2 ${languages}) -set(K2_VERSION "1.13") +set(K2_VERSION "1.14") # ----------------- Supported build types for K2 project ----------------- set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) From 613e03d173aa0c0dbfb53e2907074bdd1d32e12e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 17 Mar 2022 16:42:33 +0800 Subject: [PATCH 52/64] Remove unused DiscountedCumSum. (#936) --- k2/csrc/tensor_ops.cu | 224 +-------------------- k2/csrc/tensor_ops.h | 21 -- k2/csrc/tensor_ops_test.cu | 62 ------ k2/python/csrc/torch.cu | 2 - k2/python/csrc/torch/CMakeLists.txt | 1 - k2/python/csrc/torch/discounted_cum_sum.cu | 72 ------- k2/python/csrc/torch/discounted_cum_sum.h | 30 --- 7 files changed, 3 insertions(+), 409 deletions(-) delete mode 100644 k2/python/csrc/torch/discounted_cum_sum.cu delete mode 100644 k2/python/csrc/torch/discounted_cum_sum.h diff --git a/k2/csrc/tensor_ops.cu b/k2/csrc/tensor_ops.cu index 7d3891692..94ab6c1c5 100644 --- a/k2/csrc/tensor_ops.cu +++ b/k2/csrc/tensor_ops.cu @@ -517,239 +517,21 @@ Tensor SimpleRaggedIndexSelect1D(Tensor &src, Ragged &indexes) { return ans; } -template -struct DiscountedCumSumElement { - Real y; // y is the partial sums of x values. Initially it is just a - // single x value. In general each x is multiplied by all - // previous gammas. - Real gamma; // gamma is the product of gammas along a range of elements -}; -template -struct CombineCumSumOp { - __device__ DiscountedCumSumElement operator() ( - DiscountedCumSumElement &a, - DiscountedCumSumElement &b) const { - return DiscountedCumSumElement{b.y + b.gamma * a.y, - a.gamma * b.gamma}; - } -}; - -// A stateful callback functor that maintains a running prefix to be applied -// during consecutive scan operations. -template -struct BlockPrefixCallbackOp { - using Elem = DiscountedCumSumElement; - Elem running_total; - // Constructor - __device__ BlockPrefixCallbackOp(): running_total{0.0, 0.0} { } - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide - // scan. - __device__ Elem operator()(Elem block_aggregate) { - Elem old_prefix = running_total; - running_total = CombineCumSumOp()(running_total, block_aggregate); - return old_prefix; - } -}; - -/* - Notes for DiscountedCumSum. - - It implements a discounted sum along a sequence. Suppose we have x_i, gamma_i and - y_i, for 0 <= i < T. Then we do: - y_0 = x_0 - y_i = x_i + y_{i-1} gamma_i - for 0 < i < T. (This is done as a generic inclusive-scan/inclusive-sum with a special - reduction op). - - See DiscountedCumSumElement and CombineCumSumOp for how we use a special operator to - do this as an inclusive-sum. - - The tensors involved must be 2-dimensional with dimensions (N, T) where N is - the batch size and T the time duration. - - Each thread-block is of (x,y,z) size (ThreadsPerBlock,1,1), and it processes N - items. 
-  It processes ThreadsPerBlock items at a time; and if T > ThreadsPerBlock it
-  simply loops to cover the remaining items.
-
-  The grid size (x,y,z) is (X,Y,1) where the X and Y together cover the "N"
-  (batch) dimension.  (We can't cover it just in the X dimension because of
-  limits on the size of each grid dimension).
-
-    @param [in] N  The batch size, i.e. number of separate sequences.  We
-                   expect that N <= gridDim.x * gridDim.y.
-    @param [in] T  The sequence length.  There is no constraint on the
-                   sequence length; the kernel deals with ThreadsPerBlock
-                   items at a time, and takes care of T > ThreadsPerBlock by
-                   looping.
-    @param [in] x  Pointer to the x input data, which is an array of shape
-                   (N,T)
-    @param [in] x_stride0  Stride along axis 0 of the `x` data
-    @param [in] gamma  Pointer to the gamma input data, which is an array of
-                   shape (N,T)
-    @param [in] gamma_stride0  Stride along axis 0 of the `gamma` data
-    @param [in] y  Pointer to the y output data, which is an array of shape
-                   (N,T)
-    @param [in] y_stride0  Stride along axis 0 of the `y` data
-    @param [in] stride1  Stride along axis 1 of the three arrays (this is
-                   expected to be identical, nonzero, and preferably -1 or 1).
-*/
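For reference, the recurrence that the kernel being removed computed, written as a plain Python loop over one row of the (N, T) problem (a sketch of the semantics only, matching the CPU implementation below):

    def discounted_cum_sum(x, gamma):
        y, cur = [], 0.0
        for xi, gi in zip(x, gamma):
            cur = xi + cur * gi  # y_i = x_i + y_{i-1} * gamma_i
            y.append(cur)
        return y

    assert discounted_cum_sum([1.0, 2.0, 3.0], [0.5, 0.5, 0.5]) == [1.0, 2.5, 4.25]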
-template <typename Real, int32_t ThreadsPerBlock>
-static __global__ void DiscountedCumSumKernel(int N, int T, const Real *x,
-                                              int x_stride0, const Real *gamma,
-                                              int gamma_stride0, Real *y,
-                                              int y_stride0, int stride1) {
-  int n_idx = blockIdx.y * gridDim.x + blockIdx.x;
-  if (n_idx >= N)
-    return;
-  x += x_stride0 * n_idx;
-  gamma += gamma_stride0 * n_idx;
-  y += y_stride0 * n_idx;
-
-  int thread_idx = threadIdx.x;
-  using Elem = DiscountedCumSumElement<Real>;
-
-  BlockPrefixCallbackOp<Real> prefix_callback;
-
-  typedef cub::BlockScan<Elem, ThreadsPerBlock> BlockScan;
-  // shared memory for BlockScan
-  __shared__ typename BlockScan::TempStorage temp_storage;
-
-  for (int base_t = 0; base_t < T; base_t += ThreadsPerBlock) {
-    Elem elem;
-
-    // Load x and gamma from memory.  These reads will be coalesced (which is
-    // the advantage of having each thread process one element at this stage;
-    // although we spend more time with raking reduction than we really
-    // need to).
-    if (base_t + thread_idx < T) {
-      elem.y = x[(base_t + thread_idx) * stride1];
-      elem.gamma = gamma[(base_t + thread_idx) * stride1];
-    }
-    CombineCumSumOp<Real> op;
-
-    // The last arg is a callback functor that provides us the aggregate of
-    // this block and which is expected to return the prefix element to be
-    // combined with this block's results.
-    BlockScan(temp_storage).InclusiveScan(elem, elem, op, prefix_callback);
-    __syncthreads();
-
-    if (base_t + thread_idx < T) y[(base_t + thread_idx) * stride1] = elem.y;
-  }
-}
-
-template <typename Real, int32_t ThreadsPerBlock = 128>
-void DiscountedCumSumCudaImpl(cudaStream_t stream,
-                              int N, int T,
-                              const Real *x, int x_stride0,
-                              const Real *gamma, int gamma_stride0,
-                              Real *y, int y_stride0, int stride1) {
-  int32_t tot_grid_size = N;
-  int32_t x_grid_size = (tot_grid_size < (1 << 20)
-                             ? std::min(tot_grid_size, (1 << 10))
-                             : 32768),
-          y_grid_size = NumBlocks(tot_grid_size, x_grid_size);
-
-  dim3 grid_dim(x_grid_size, y_grid_size, 1),
-      block_dim(ThreadsPerBlock, 1, 1);
-  K2_CUDA_SAFE_CALL(
-      DiscountedCumSumKernel<Real, ThreadsPerBlock>
-      <<<grid_dim, block_dim, 0, stream>>>(
-          N, T, x, x_stride0, gamma, gamma_stride0, y, y_stride0, stride1));
-}
-
-template <typename Real>
-static void DiscountedCumSumCpuImpl(int N, int T,
-                                    const Real *x, int x_stride0,
-                                    const Real *gamma, int gamma_stride0,
-                                    Real *y, int y_stride0,
-                                    int stride1) {
-  for (int32_t n = 0; n < N; n++,
-       x += x_stride0, gamma += gamma_stride0, y += y_stride0) {
-    Real cur_sum = 0.0;
-    for (int32_t t = 0; t < T; t++) {
-      cur_sum = x[t * stride1] + cur_sum * gamma[t * stride1];
-      y[t * stride1] = cur_sum;
-    }
-  }
-}
-
-void DiscountedCumSum(const Tensor &src, const Tensor &gamma, Tensor *dest) {
-  // check contexts compatible:
-  if (!(IsCompatible(src, gamma) && IsCompatible(src, *dest))) {
-    K2_LOG(FATAL) << "Tensors are on different devices";
-  }
-  if (!(src.NumAxes() == 2 && gamma.NumAxes() == 2 && dest->NumAxes() == 2)) {
-    K2_LOG(FATAL) << "Expected all num-axes to equal 2.";
-  }
-  if (!(src.SameDims(gamma) && src.SameDims(*dest))) {
-    K2_LOG(FATAL) << "Expected all args to have the same dim.";
-  }
-  if (!(src.Stride(1) == gamma.Stride(1) && src.Stride(1) == dest->Stride(1))) {
-    K2_LOG(FATAL) << "Expected all strides on dim 1 to be the same.";
-  }
-  if (!(src.GetDtype() == gamma.GetDtype() &&
-        src.GetDtype() == dest->GetDtype())) {
-    K2_LOG(FATAL) << "Expected all args to have the same dtype.";
-  }
-  int32_t N = src.Dim(0),
-          T = src.Dim(1),
-          src_stride0 = src.Stride(0),
-          gamma_stride0 = gamma.Stride(0),
-          dest_stride0 = dest->Stride(0),
-          stride1 = src.Stride(1);  // these are all the same.
-  ContextPtr c = src.Context();
-  if (src.GetDtype() == kFloatDtype) {
-    if (c->GetDeviceType() == kCuda) {
-      DiscountedCumSumCudaImpl<float>(c->GetCudaStream(), N, T,
-                                      src.Data<float>(), src_stride0,
-                                      gamma.Data<float>(), gamma_stride0,
-                                      dest->Data<float>(), dest_stride0,
-                                      stride1);
-    } else {
-      DiscountedCumSumCpuImpl<float>(N, T,
-                                     src.Data<float>(), src_stride0,
-                                     gamma.Data<float>(), gamma_stride0,
-                                     dest->Data<float>(), dest_stride0,
-                                     stride1);
-    }
-  } else if (src.GetDtype() == kDoubleDtype) {
-    if (c->GetDeviceType() == kCuda) {
-      DiscountedCumSumCudaImpl<double>(c->GetCudaStream(), N, T,
-                                       src.Data<double>(), src_stride0,
-                                       gamma.Data<double>(), gamma_stride0,
-                                       dest->Data<double>(), dest_stride0,
-                                       stride1);
-    } else {
-      DiscountedCumSumCpuImpl<double>(N, T,
-                                      src.Data<double>(), src_stride0,
-                                      gamma.Data<double>(), gamma_stride0,
-                                      dest->Data<double>(), dest_stride0,
-                                      stride1);
-    }
-  } else {
-    K2_LOG(FATAL)
-        << "This algorithm only instantiated for float and double; type is "
-        << TraitsOf(src.GetDtype()).Name();
-  }
-}
-
 Tensor Flip(Tensor &src, int32_t axis) {
   int32_t num_axes = src.NumAxes();
   K2_CHECK_GE(axis, -num_axes);
   K2_CHECK_LT(axis, num_axes);
-  if (axis < 0)
-    axis += num_axes;
+  if (axis < 0) axis += num_axes;
   int32_t old_dim = src.Dim(axis);
-  if (old_dim <= 1)
-    return src;  // No point copying it, it's a no-op.
+  if (old_dim <= 1) return src;  // No point copying it, it's a no-op.
TensorImplPtr src_impl = src.Impl(), ans_impl = std::make_shared(*src_impl); int32_t old_stride = ans_impl->shape.Stride(axis); ans_impl->shape.SetStride(axis, -old_stride); int64_t byte_offset = old_stride * static_cast(old_dim - 1) * - TraitsOf(ans_impl->dtype).NumBytes(); + TraitsOf(ans_impl->dtype).NumBytes(); ans_impl->byte_offset += byte_offset; return Tensor(ans_impl); } - } // namespace k2 diff --git a/k2/csrc/tensor_ops.h b/k2/csrc/tensor_ops.h index bf86e13f9..850e090d4 100644 --- a/k2/csrc/tensor_ops.h +++ b/k2/csrc/tensor_ops.h @@ -132,27 +132,6 @@ void IndexAdd(Tensor &src, Array1 &indexes, bool allow_minus_one, */ Tensor SimpleRaggedIndexSelect1D(Tensor &src, Ragged &indexes); -/* - This is a rather specialized op which is being included here for the convenience - of Snowfall developers, it actually has little to do with k2 in general and - is more closely related to this https://github.com/toshas/torch-discounted-cumsum - (our version has a different discounting factor per element). - - It implements a discounted sum along a sequence. Suppose we have x_i, gamma_i and - y_i, for 0 <= i < T. Then we do: - y_0 = x_0 - y_i = x_i + y_{i-1} gamma_i - for 0 < i < T. (This is done as a generic inclusive-sum with a special - reduction op). - - It supports only 2-d tensors, with the 2nd dimension interpreted as the - time dimension and the 1st dimension interpreted as the batch dimension. - - The strides on axis 1 (the 2nd axis) are expected to be identical, and - for efficiency of memory access it's best if the stride is 1 or -1. - */ -void DiscountedCumSum(const Tensor &src, const Tensor &gamma, Tensor *dest); - /* Flips a Tensor on axis `axis`, i.e. reversing the order of elements on that axis. Does this shallowly by modifying the metadata (caution: Torch diff --git a/k2/csrc/tensor_ops_test.cu b/k2/csrc/tensor_ops_test.cu index 3736f3fee..f57636ba8 100644 --- a/k2/csrc/tensor_ops_test.cu +++ b/k2/csrc/tensor_ops_test.cu @@ -347,66 +347,4 @@ TEST(Index, SimpleRaggedIndexSelect1D) { TestSimpleRaggedIndexSelect1D(); } - -template -void TestDiscountedCumSum() { - for (int32_t i = 0; i < 4; i++) { - int32_t M = RandInt(0, 1000), - T = RandInt(1, 2000); // TODO: increase. - while (M * T > 10000) { // don't want test to take too long. - M /= 2; - T /= 2; - } - - ContextPtr cuda_context = GetCudaContext(), - cpu_context = GetCpuContext(); - - Array2 x = RandUniformArray2(cuda_context, M, T, -2.0, 2.0); - Array2 gamma = RandUniformArray2(cuda_context, M, T, 0.0, 1.0); - Array2 y(cuda_context, M, T); - y = -10.0; - - bool flip = (i % 2 == 1); - - Array2 x_cpu = x.To(cpu_context), - gamma_cpu = gamma.To(cpu_context), - y_cpu(cpu_context, M, T); - - Tensor x_ten = x.ToTensor(), - gamma_ten = gamma.ToTensor(), - y_ten = y.ToTensor(); - - Tensor x_ten_cpu = x_cpu.ToTensor(), - gamma_ten_cpu = gamma_cpu.ToTensor(), - y_ten_cpu = y_cpu.ToTensor(); - - if (flip) { - x_ten = Flip(x_ten, 1); - gamma_ten = Flip(gamma_ten, 1); - y_ten = Flip(y_ten, 1); - x_ten_cpu = Flip(x_ten_cpu, 1); - gamma_ten_cpu = Flip(gamma_ten_cpu, 1); - y_ten_cpu = Flip(y_ten_cpu, 1); - } - - DiscountedCumSum(x_ten, gamma_ten, &y_ten); - DiscountedCumSum(x_ten_cpu, gamma_ten_cpu, &y_ten_cpu); - - Array2 y_cpu_copy = y.To(cpu_context); - - /*K2_LOG(INFO) << "x_cpu = " << x_cpu - << ", gamma_cpu = " << gamma_cpu - << ", y_cpu = " << y_cpu - << ", y = " << y_cpu_copy; */ - - // We are using the CPU and GPU versions to check each other. 
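The deleted test used the CPU and GPU paths to cross-check each other; when porting this op elsewhere an outside oracle is handy. Below is a minimal pure-Python sketch of the recurrence from the deleted header comment (y_0 = x_0, y_i = x_i + y_{i-1} * gamma_i); plain lists stand in for k2 tensors, so this is a reference only, not the k2 API:

    def discounted_cum_sum(x, gamma):
        # x, gamma: N rows of length T; returns y with
        # y[n][0] = x[n][0] and y[n][t] = x[n][t] + y[n][t-1] * gamma[n][t]
        y = []
        for xs, gs in zip(x, gamma):
            cur, row = 0.0, []
            for xt, gt in zip(xs, gs):
                cur = xt + cur * gt
                row.append(cur)
            y.append(row)
        return y

    assert discounted_cum_sum([[1.0, 1.0, 1.0]],
                              [[0.0, 0.5, 0.5]]) == [[1.0, 1.5, 1.75]]

The CUDA kernel above reaches the same result with a block-level inclusive scan over (y, gamma) pairs; the combine rule is presumably that (y1, g1) with (y2, g2) gives (y2 + y1 * g2, g1 * g2), though that operator's definition lies outside this hunk.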
- EXPECT_EQ(true, ApproxEqual(y_cpu, y_cpu_copy, (Real)0.01)); - } -} - -TEST(Tensor, DiscountedCumSum) { - TestDiscountedCumSum(); - TestDiscountedCumSum(); -} - } // namespace k2 diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index 38c954b99..65768dfd1 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -26,7 +26,6 @@ #include "k2/python/csrc/torch/arc.h" #include "k2/python/csrc/torch/array_ops.h" -#include "k2/python/csrc/torch/discounted_cum_sum.h" #include "k2/python/csrc/torch/fsa.h" #include "k2/python/csrc/torch/fsa_algo.h" #include "k2/python/csrc/torch/index_add.h" @@ -41,7 +40,6 @@ void PybindTorch(py::module &m) { PybindArc(m); PybindArrayOps(m); - PybindDiscountedCumSum(m); PybindFsa(m); PybindFsaAlgo(m); PybindIndexAdd(m); diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index a05678053..75cf8a4da 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -2,7 +2,6 @@ set(torch_srcs arc.cu array_ops.cu - discounted_cum_sum.cu fsa.cu fsa_algo.cu index_add.cu diff --git a/k2/python/csrc/torch/discounted_cum_sum.cu b/k2/python/csrc/torch/discounted_cum_sum.cu deleted file mode 100644 index 5606f5b96..000000000 --- a/k2/python/csrc/torch/discounted_cum_sum.cu +++ /dev/null @@ -1,72 +0,0 @@ -/** - * @brief wraps discounted_cum_sum code. - * - * @copyright - * Copyright 2021 Xiaomi Corp. (authors: Daniel Povey) - * - * @copyright - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "k2/csrc/context.h" -#include "k2/csrc/device_guard.h" -#include "k2/csrc/macros.h" -#include "k2/csrc/nvtx.h" -#include "k2/csrc/tensor_ops.h" -#include "k2/python/csrc/torch/discounted_cum_sum.h" -#include "k2/python/csrc/torch/torch_util.h" - -namespace k2 { - -static void DiscountedCumSumWrapper(torch::Tensor x, torch::Tensor gamma, - torch::Tensor y, bool flip = false) { - NVTX_RANGE(K2_FUNC); - DeviceGuard guard(GetContext(x)); - Tensor x_k2 = FromTorch(x, TensorTag{}); - Tensor gamma_k2 = FromTorch(gamma, TensorTag{}); - Tensor y_k2 = FromTorch(y, TensorTag{}); - if (flip) { - // We have to do this in C++ because Torch tensors don't support negative - // strides. - x_k2 = Flip(x_k2, 1); - gamma_k2 = Flip(gamma_k2, 1); - y_k2 = Flip(y_k2, 1); - } - DiscountedCumSum(x_k2, gamma_k2, &y_k2); -} - -} // namespace k2 - -void PybindDiscountedCumSum(py::module &m) { - // note it supports only 1-D and 2-D tensors. - m.def("discounted_cum_sum", &k2::DiscountedCumSumWrapper, py::arg("x"), - py::arg("gamma"), py::arg("y"), py::arg("flip") = false, - R"( - Args: - x: - A 2-D tensor with dtype `torch.float` or `torch.double` and x.stride(1) == 1. - gamma: - A tensor with the same shape and dtype as x, and gamma.stride(1) == 1 - y: - A tensor with the same shape and dtype as x, and y.stride(1) == 1. - This function outputs to here. It is allowed to be the same tensor - as x and/or gamma. 
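The `flip` argument in the wrapper above exists because Torch tensors reject negative strides, while k2's `Flip` (shown earlier) reverses an axis purely in metadata: negate the stride and advance the byte offset to the old last element. NumPy does allow negative strides, so the same bookkeeping can be observed directly; a small sketch with illustrative values (float32, C order):

    import numpy as np

    a = np.arange(6, dtype=np.float32).reshape(2, 3)
    flipped = a[:, ::-1]                 # a view; no data is copied
    assert flipped.strides == (12, -4)   # stride on axis 1 negated
    # the view's base pointer advanced by (old_dim - 1) * old_stride bytes,
    # matching `byte_offset += old_stride * (old_dim - 1) * NumBytes()`
    # in Flip() above
    assert flipped.ctypes.data - a.ctypes.data == (3 - 1) * 4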
- The shapes are interpreted as (N, T) with N as the batch size and T - a sequence or time dimensions. It implements: - y(n, 0) = x(n, 0) - y(n, t) = x(n, t) + y(n, t-1) * gamma(n, t) (for 0 Date: Thu, 17 Mar 2022 18:00:35 +0800 Subject: [PATCH 53/64] Fix compiler warnings. (#937) * Fix compiler warnings. --- k2/csrc/nbest.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 5ca250ccb..68d6f64bf 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -18,6 +18,8 @@ */ #include +#include + #include "k2/csrc/nbest.h" // This is not really a CUDA file but for build-system reasons I'm currently @@ -46,14 +48,12 @@ inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) { */ template static void RadixPass(const T* a, T* b, const T* r, T n, T K) { - T* c = new T[K + 1]; // counter array - for (T i = 0; i <= K; i++) c[i] = 0; // reset counters + std::vector c(K + 1, 0); // counter array for (T i = 0; i < n; i++) c[r[a[i]]]++; // count occurrences for (T i = 0, sum = 0; i <= K; i++) { // exclusive prefix sums T t = c[i]; c[i] = sum; sum += t; } for (T i = 0; i < n; i++) b[c[r[a[i]]]++] = a[i]; // sort - delete [] c; } // See documentation in nbest.h, where we use different names @@ -69,19 +69,20 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { return; } T n0 = (n + 2) / 3, n1 = (n+1) / 3, n2 = n / 3, n02 = n0 + n2; - T *R = new T[n02 + 3]; R[n02] = R[n02 + 1] = R[n02 + 2] = 0; - T *SA12 = new T[n02 + 3]; SA12[n02] = SA12[n02 + 1] = SA12[n02 + 2] = 0; - T *R0 = new T[n0]; - T *SA0 = new T[n0]; + std::vector R(n02 + 3); // entries are set to zero by default + std::vector SA12(n02 + 3); + std::vector R0(n0); + std::vector SA0(n0); + //******* Step 0: Construct sample ******** // generate positions of mod 1 and mod 2 suffixes // the "+(n0-n1)" adds a dummy mod 1 suffix if n%3 == 1 for (T i = 0, j = 0; i < n + (n0 - n1); i++) if (i % 3 != 0) R[j++] = i; //******* Step 1: Sort sample suffixes ******** // lsb radix sort the mod 1 and mod 2 triples - RadixPass(R, SA12, text + 2, n02, K); - RadixPass(SA12, R , text + 1, n02, K); - RadixPass(R, SA12, text, n02, K); + RadixPass(R.data(), SA12.data(), text + 2, n02, K); + RadixPass(SA12.data(), R.data() , text + 1, n02, K); + RadixPass(R.data(), SA12.data(), text, n02, K); // find lexicographic names of triples and // write them to correct places in R @@ -99,7 +100,7 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { } // recurse if names are not yet unique if (name < n02) { - CreateSuffixArray(R, n02, name, SA12); + CreateSuffixArray(R.data(), n02, name, SA12.data()); // store unique names in R using the suffix array for (T i = 0; i < n02; i++) R[SA12[i]] = i + 1; } else // generate the suffix array of R directly @@ -108,7 +109,7 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { // stably sort the mod 0 suffixes from SA12 by their first character for (T i = 0, j = 0; i < n02; i++) if (SA12[i] < n0) R0[j++] = 3 * SA12[i]; - RadixPass(R0, SA0, text, n0, K); + RadixPass(R0.data(), SA0.data(), text, n0, K); //******* Step 3: Merge ******** // merge sorted SA0 suffixes and sorted SA12 suffixes for (T p = 0, t = n0 - n1, k = 0; k < n; k++) { @@ -129,7 +130,6 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { SA[k] = (SA12[t] < n0 ? 
SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2); } } - delete [] R; delete [] SA12; delete [] SA0; delete [] R0; } // Instantiate template for int32_t and int16_t From 10b94236ca1b16db1bf3fefb4cda745f160202ab Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 19 Mar 2022 09:30:11 +0800 Subject: [PATCH 54/64] Minor fixes for RNN-T decoding. (#938) * Minor fixes for RNN-T decoding. --- k2/csrc/array_of_ragged_test.cu | 24 ++++---- k2/csrc/fsa_algo.h | 4 +- k2/csrc/nbest.cu | 10 ++-- k2/csrc/ragged_ops.h | 8 +-- k2/csrc/ragged_ops_inl.h | 2 +- k2/csrc/rnnt_decode.cu | 90 ++++++++++++++++------------- k2/csrc/rnnt_decode.h | 57 +++++++++--------- k2/python/csrc/torch/rnnt_decode.cu | 4 +- k2/python/k2/fsa_algo.py | 4 +- k2/python/k2/rnnt_decode.py | 44 +++++++------- 10 files changed, 131 insertions(+), 116 deletions(-) diff --git a/k2/csrc/array_of_ragged_test.cu b/k2/csrc/array_of_ragged_test.cu index 176be139b..69b482315 100644 --- a/k2/csrc/array_of_ragged_test.cu +++ b/k2/csrc/array_of_ragged_test.cu @@ -43,24 +43,24 @@ void TestArray1OfRaggedConstruct() { for (int32_t j = 1; j < num_axes; ++j) { const int32_t **row_splits = array_of_ragged.shape.RowSplits(j); const int32_t **row_ids = array_of_ragged.shape.RowIds(j); - Array1 excepted_row_splits(GetCpuContext(), num_srcs); - Array1 excepted_row_ids(GetCpuContext(), num_srcs); - int32_t **excepted_row_splits_data = excepted_row_splits.Data(); - int32_t **excepted_row_ids_data = excepted_row_ids.Data(); + Array1 expected_row_splits(GetCpuContext(), num_srcs); + Array1 expected_row_ids(GetCpuContext(), num_srcs); + int32_t **expected_row_splits_data = expected_row_splits.Data(); + int32_t **expected_row_ids_data = expected_row_ids.Data(); for (int32_t i = 0; i < num_srcs; ++i) { - excepted_row_splits_data[i] = raggeds[i].RowSplits(j).Data(); - excepted_row_ids_data[i] = raggeds[i].RowIds(j).Data(); + expected_row_splits_data[i] = raggeds[i].RowSplits(j).Data(); + expected_row_ids_data[i] = raggeds[i].RowIds(j).Data(); } - excepted_row_splits = excepted_row_splits.To(c); - excepted_row_ids = excepted_row_ids.To(c); - excepted_row_splits_data = excepted_row_splits.Data(); - excepted_row_ids_data = excepted_row_ids.Data(); + expected_row_splits = expected_row_splits.To(c); + expected_row_ids = expected_row_ids.To(c); + expected_row_splits_data = expected_row_splits.Data(); + expected_row_ids_data = expected_row_ids.Data(); Array1 flags(c, 2, 1); int32_t *flags_data = flags.Data(); K2_EVAL( c, num_srcs, lambda_check_pointer, (int32_t i) { - if (row_splits[i] != excepted_row_splits_data[i]) flags_data[0] = 0; - if (row_ids[i] != excepted_row_ids_data[i]) flags_data[1] = 0; + if (row_splits[i] != expected_row_splits_data[i]) flags_data[0] = 0; + if (row_ids[i] != expected_row_ids_data[i]) flags_data[1] = 0; }); K2_CHECK(Equal(flags, Array1(c, std::vector{1, 1}))); } diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 8f8779b0f..bd59f76c7 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -580,8 +580,8 @@ Fsa CtcTopo(const ContextPtr &c, int32_t max_token, bool modified, Array1 *aux_labels); /* - Creat a trivial graph which has only two states, on state 0, there are - `max-token + 1` self loops(i.e. a loop for each symbol, including blank), and + Create a trivial graph which has only two states. On state 0, there are + `max_token + 1` self loops(i.e. a loop for each symbol, including blank), and state 1 is the final state. 
@param [in] c The context with which we'll allocate memory for diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 68d6f64bf..30244f24c 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -48,7 +48,7 @@ inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) { */ template static void RadixPass(const T* a, T* b, const T* r, T n, T K) { - std::vector c(K + 1, 0); // counter array + std::vector c(K + 1, 0); // counter array for (T i = 0; i < n; i++) c[r[a[i]]]++; // count occurrences for (T i = 0, sum = 0; i <= K; i++) { // exclusive prefix sums T t = c[i]; c[i] = sum; sum += t; @@ -69,10 +69,10 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { return; } T n0 = (n + 2) / 3, n1 = (n+1) / 3, n2 = n / 3, n02 = n0 + n2; - std::vector R(n02 + 3); // entries are set to zero by default - std::vector SA12(n02 + 3); - std::vector R0(n0); - std::vector SA0(n0); + std::vector R(n02 + 3, 0); + std::vector SA12(n02 + 3, 0); + std::vector R0(n0, 0); + std::vector SA0(n0, 0); //******* Step 0: Construct sample ******** // generate positions of mod 1 and mod 2 suffixes diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 85268d287..53ba84ec3 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -189,11 +189,11 @@ RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src, /* Unstack a RaggedShape to a list of RaggedShapes, all the output RaggedShapes - have one less axis. + have one axis less. This function tries to do the opposite of Stack(), i.e. to generate an array out such that `Equal(src, Stack(axis, out->size(), out->data()))`. But notes that `Stack` needs a pointer of RaggedShape pointer, Unstack produces only a - pointer of RaggedShape, you should do some convertion before using Stack. + pointer of RaggedShape, you should do some conversion before using Stack. @param [in] src The shape to unstack. @param [in] axis The axis to be removed, all the elements of this axis will @@ -210,7 +210,7 @@ RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src, `src` telling where the elements of each split RaggedShapes come from. It has the same size of `out`, see notes below for the dimension of it. For Array1 in each of the - `split_map`, It satifies + `split_map`, It satisfies `split_map[i].Dim() == out[i].NumElements()`, and `0 <= split_map[i][j] < src.NumElements()`. `split_map` will be reallocated by this function. @@ -222,7 +222,7 @@ RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src, Note: The output RaggedShape may contain empty lists on axis `axis`, you can remove them by RemoveEmptyLists if needed. - Note: The number of output RaggedShape is decided by the size of sublist + Note: The number of output RaggedShape is determined by the size of sublist with max number of elements along axis `axis`, for `axis == 0`, it has only one sublist along `axis == 0`(i.e. the src itself), so the number of output RaggedShape will be equal to `src.Dim0()`. 
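Since `Unstack` is documented above as the inverse of `Stack` (up to empty lists), a plain-Python sketch of the `axis == 1` case may make the contract concrete. Nested lists stand in for RaggedShape here, so this is an analogy rather than the k2 API:

    src = [[[1, 2], [3]], [[4]]]          # axes: [dim0][axis1][element]

    def unstack_axis1(src):
        # out[i] gathers the i-th axis-1 sublist of every dim0 entry; the
        # number of outputs is the longest sublist, and short rows
        # contribute empty lists, matching the notes above
        n = max(len(row) for row in src)
        return [[row[i] if i < len(row) else [] for row in src]
                for i in range(n)]

    out = unstack_axis1(src)
    assert out == [[[1, 2], [4]], [[3], []]]  # each output has one axis less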
diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index 03b91eb12..9c894e6df 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -860,7 +860,7 @@ Renumbering PruneRaggedAxis1(Ragged &src, T beam, K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t idx01) { // idx01 is the index after sorting int32_t original_idx01 = order_map_data[idx01], - // SortSublists wouldn't chaneg idx0 & idx0x + // SortSublists wouldn't change idx0 and idx0x idx0 = row_ids1_data[original_idx01], idx0x = row_splits1_data[idx0], // idx1 is the index after sorting diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index 5fc1780f3..232c0454c 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "k2/csrc/fsa.h" @@ -78,7 +79,7 @@ RnntDecodingStreams::RnntDecodingStreams( void RnntDecodingStreams::TerminateAndFlushToStreams() { NVTX_RANGE(K2_FUNC); - // return directlly if already detached or no frames decoded. + // return directly if already detached or no frames decoded. if (!attached_ || prev_frames_.empty()) return; std::vector> states; std::vector> scores; @@ -87,15 +88,15 @@ void RnntDecodingStreams::TerminateAndFlushToStreams() { K2_CHECK_EQ(static_cast(states.size()), num_streams_); K2_CHECK_EQ(static_cast(scores.size()), num_streams_); - // detatch prev_frames_ + // detach prev_frames_ std::vector *> frames_ptr; for (size_t i = 0; i < prev_frames_.size(); ++i) { frames_ptr.emplace_back(prev_frames_[i].get()); } - // statck_frames has a shape of [t][stream][state][arc] + // stack_frames has a shape of [t][stream][state][arc] auto stack_frames = Stack(0, prev_frames_.size(), frames_ptr.data()); - // stack_frames now has a shape of [strams][state][arc] - // its Dim0(0) equals to `num_streams_ * prev_frames_.size()` + // stack_frames now has a shape of [stream][state][arc] + // its Dim0() equals to `num_streams_ * prev_frames_.size()` stack_frames = stack_frames.RemoveAxis(0); std::vector> frames; @@ -144,6 +145,10 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape, int32_t idx0 = shape_row_ids1_data[row], num_graph_states = num_graph_states_data[idx0], state_idx01x = states_row_splits2_data[row]; + // Note: Entries in the sublist [state] grouped by [context] share + // the same context, so we use the first entry to compute the + // context_state here. + // // state_value = context_state * num_graph_states + graph_state // We want to extract token ids from context_state below. // Think about that the vocab_size=10 & decoder_history_len=3, and we @@ -179,7 +184,7 @@ Ragged RnntDecodingStreams::PruneTwice(Ragged &incoming_scores, Ragged temp_scores = SubsetRagged(incoming_scores, states_prune, 2 /*axis*/, &arcs_new2old1); - // incoming_scores has a shape of [stream][context][state][arc] + // temp_scores has a shape of [stream][context][state][arc] // context_prune is a renumbering on the states context. 
Renumbering context_prune = PruneRagged(temp_scores, 1 /*axis*/, config_.beam, config_.max_contexts); @@ -205,7 +210,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() { *num_graph_states_data = num_graph_states_.Data(); const int64_t *states_values_data = states_.values.Data(); - const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); + const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); int32_t *num_arcs_data = num_arcs.Data(); K2_EVAL( @@ -232,7 +237,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() { } Renumbering RnntDecodingStreams::DoFisrtPassPruning( - RaggedShape &unpruned_arcs_shape, Array2 &logprobs) { + RaggedShape &unpruned_arcs_shape, const Array2 &logprobs) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4); @@ -247,7 +252,7 @@ Renumbering RnntDecodingStreams::DoFisrtPassPruning( Array1 max_scores_per_stream(c_, num_streams_); double minus_inf = -std::numeric_limits::infinity(); { - // scores_ has 3 axes: [stream][context][score] + // scores_ has 3 axes: [stream][context][state] Ragged scores_per_stream = scores_.RemoveAxis(1); MaxPerSublist(scores_per_stream, minus_inf, &max_scores_per_stream); } @@ -263,10 +268,10 @@ Renumbering RnntDecodingStreams::DoFisrtPassPruning( *uas_row_ids2_data = unpruned_arcs_shape.RowIds(2).Data(), *uas_row_ids1_data = unpruned_arcs_shape.RowIds(1).Data(), *num_graph_states_data = num_graph_states_.Data(); - const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); + const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); const int64_t *states_values_data = states_.values.Data(); - Arc **graphs_arcs_data = graphs_.values.Data(); + const Arc *const *graphs_arcs_data = graphs_.values.Data(); K2_EVAL( c_, unpruned_arcs_shape.NumElements(), lambda_pass1_pruning, @@ -300,7 +305,7 @@ Renumbering RnntDecodingStreams::DoFisrtPassPruning( return; } - // prune the arcs pointting final state. + // prune the arcs pointing to the final state. if (arc.label == -1) { pass1_keep_data[idx0123] = 0; return; @@ -323,7 +328,7 @@ Renumbering RnntDecodingStreams::DoFisrtPassPruning( RaggedShape RnntDecodingStreams::GroupStatesByContexts( Ragged &states) { NVTX_RANGE(K2_FUNC); - // states has a shape of [stream][state] + // states has a shape of [stream][arc] K2_CHECK_EQ(states.NumAxes(), 2); // state_boundaries and context_boundaries are Renumbering objects // that we use in a slightly different way from normal. @@ -408,11 +413,11 @@ RaggedShape RnntDecodingStreams::GroupStatesByContexts( (2) Do initial pruning(beam pruning with some special rules) to reduce the the number of arcs. (3) Figure out the dest-states and corresponding scores. - (4) Re-arange dest-states by contexts and states. - (5) Second pass pruning (prune on context axis and state axis). + (4) Rearrange dest-states by contexts and states. + (5) Second pass pruning (prune on state axis and context axis). (6) Update states_, scores_ and prev_frames_. */ -void RnntDecodingStreams::Advance(Array2 &logprobs) { +void RnntDecodingStreams::Advance(const Array2 &logprobs) { NVTX_RANGE(K2_FUNC); K2_CHECK(attached_) << "Streams terminated."; K2_CHECK_EQ(logprobs.Dim0(), states_.TotSize(1)); @@ -433,25 +438,27 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { SubsetRaggedShape(unpruned_arcs_shape, pass1_renumbering); // (3) Figure out the dest-states and corresponding scores. - // stream_arc_shape is pass1_arcs indexed [stream][arc]. 
- // We need to rearrange so it's by destination context and state, not source. + // stream_arc_shape is pass1_arcs indexed by [stream][arc]. + // We need to rearrange it so it's ordered by destination context and state, + // not source. RaggedShape stream_arc_shape = RemoveAxis(pass1_arcs_shape, 2); stream_arc_shape = RemoveAxis(stream_arc_shape, 1); - // arcs, indexed [stream][context][state][arc]. + // arcs, indexed by [stream][context][state][arc]. Ragged arcs(pass1_arcs_shape); - // dest-states of arcs, incexed [stream][arc] + // dest-states of arcs, indexed by [stream][arc] Ragged states(stream_arc_shape); - // final-scores after arcs, indexed [stream][arc] + // final-scores after arcs, indexed by [stream][arc] + // It contains the forward scores of dest-states. Ragged scores(stream_arc_shape); - // We will populate arcs, states and scores below, it computes - // the destination state for each arc and puts its in 'states', + // We will populate arcs, states and scores below; it computes + // the destination state for each arc and puts it in 'states', // and the after-the-arc scores for each arc and puts them in // 'scores'. int32_t cur_num_arcs = arcs.NumElements(); // This renumbering object will be used for renumbering the arcs after we - // fiishing the pruning. + // finishing the pruning. Renumbering renumber_arcs(c_, cur_num_arcs); char *renumber_arcs_keep_data = renumber_arcs.Keep().Data(); @@ -470,9 +477,9 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { *uas_row_ids2_data = unpruned_arcs_shape.RowIds(2).Data(), *uas_row_ids1_data = unpruned_arcs_shape.RowIds(1).Data(), *pass1_new2old_data = pass1_renumbering.New2Old().Data(); - const int32_t **graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); + const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); const auto logprobs_acc = logprobs.Accessor(); - Arc **graphs_arcs_data = graphs_.values.Data(); + const Arc *const *graphs_arcs_data = graphs_.values.Data(); K2_EVAL( c_, cur_num_arcs, lambda_populate_arcs_states_scores, (int32_t arc_idx) { @@ -487,7 +494,7 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { idx3 = idx0123 - idx012x, // `idx3 - 1` can be interpreted as // idx1 into the corresponding // decoding graph, minus 1 here - // because we add a implicit + // because we added an implicit // self-loop for each state, see // `ExpandArcs()`. idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01], @@ -531,7 +538,7 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { context_state = context_state * vocab_size + arc.label; } - // next state is the state current arc pointting to. + // next state is the state the current arc pointing to. int64_t state = context_state * num_graph_states + arc.dest_state; states_data[arc_idx] = state; @@ -545,7 +552,7 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { arcs_data[arc_idx] = ai; }); - // (4) Re-arange dest-states by contexts and states. + // (4) Rearrange dest-states by contexts and states. // sort states so that we can group states by context-state Array1 dest_state_sort_new2old(c, states.NumElements()); SortSublists(&states, &dest_state_sort_new2old); @@ -554,9 +561,11 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { scores.values = scores.values[dest_state_sort_new2old]; Ragged incoming_scores(incoming_arcs_shape, scores.values); + // Note: `arcs` is not sorted. `renumber_arcs` will be used later + // to map `pruned arcs` to `arcs`. // (5) Second pass pruning (prune on context axis and state axis). 
- // The scores has been re-arange by destination context and state. + // The scores has been rearranged by context and destination state. Array1 arcs_prune2_new2old; Ragged pruned_incoming_scores = PruneTwice(incoming_scores, &arcs_prune2_new2old); @@ -576,17 +585,17 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { // frame Ragged dest_state_scores(RemoveAxis(pruned_incoming_scores.shape, 3), dest_state_scores_values); - scores_ = dest_state_scores; + scores_ = std::move(dest_state_scores); // dest_states will be the `states` held by this object on the next frame. // sub-lists along last axis has same values, so we just pick the first one, // see `GroupStatesByContexts()` for more details. auto pruned_row_split3 = pruned_dest_states.RowSplits(3); Ragged dest_states( - dest_state_scores.shape, + scores_.shape, pruned_dest_states .values[pruned_row_split3.Arange(0, pruned_row_split3.Dim() - 1)]); - states_ = dest_states; + states_ = std::move(dest_states); // Update prev_frames_. // arcs_new2old is new2old map from indexes in `incoming_scores` or @@ -597,7 +606,7 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { // Renumber the original arcs, we create and initialize the renumbering object // when we create the arcs, see above. // arcs has a shape of [stream][context][state][arc] - int32_t *arcs_new2old_data = arcs_new2old.Data(); + const int32_t *arcs_new2old_data = arcs_new2old.Data(); K2_EVAL( c_, arcs_new2old.Dim(), lambda_renumber_arcs, (int32_t idx) { int32_t arc_idx0123 = arcs_new2old_data[idx]; @@ -650,7 +659,8 @@ void RnntDecodingStreams::Advance(Array2 &logprobs) { std::make_shared>(pruned_arcs.RemoveAxis(1))); } -void RnntDecodingStreams::GatherPrevFrames(std::vector &num_frames) { +void RnntDecodingStreams::GatherPrevFrames( + const std::vector &num_frames) { NVTX_RANGE(K2_FUNC); K2_CHECK(!attached_) << "Please call TerminateAndFlushToStreams() first."; K2_CHECK_EQ(num_streams_, static_cast(num_frames.size())); @@ -687,7 +697,7 @@ void RnntDecodingStreams::GatherPrevFrames(std::vector &num_frames) { } } -void RnntDecodingStreams::FormatOutput(std::vector &num_frames, +void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, FsaVec *ofsa, Array1 *out_map) { NVTX_RANGE(K2_FUNC); K2_CHECK(!attached_) @@ -722,7 +732,7 @@ void RnntDecodingStreams::FormatOutput(std::vector &num_frames, { // each of these have 3 axes. std::vector arcs_shapes(frames + 2); - for (int32_t t = 0; t < frames; t++) { + for (int32_t t = 0; t < frames; ++t) { arcs_shapes[t] = &(prev_frames_[t]->shape); arcs_data_ptrs_data[t] = prev_frames_[t]->values.Data(); } @@ -758,9 +768,9 @@ void RnntDecodingStreams::FormatOutput(std::vector &num_frames, K2_EVAL( c_, num_arcs, lambda_set_arcs, (int32_t oarc_idx0123) { - int32_t oarc_idx012 = oshape_row_ids3[oarc_idx0123], - oarc_idx01 = oshape_row_ids2[oarc_idx012], - oarc_idx0 = oshape_row_ids1[oarc_idx01], + int32_t oarc_idx012 = oshape_row_ids3[oarc_idx0123], // state + oarc_idx01 = oshape_row_ids2[oarc_idx012], // frame + oarc_idx0 = oshape_row_ids1[oarc_idx01], // stream oarc_idx0x = oshape_row_splits1[oarc_idx0], oarc_idx0xx = oshape_row_splits2[oarc_idx0x], oarc_idx1 = oarc_idx01 - oarc_idx0x, diff --git a/k2/csrc/rnnt_decode.h b/k2/csrc/rnnt_decode.h index 10f2f1bd6..3074907b6 100644 --- a/k2/csrc/rnnt_decode.h +++ b/k2/csrc/rnnt_decode.h @@ -39,7 +39,7 @@ namespace rnnt_decoding { that the trained model is compatible with decoding with this setting). 
- contexts are finite symbol left-contexts, of length - RnntDecodingConfig::decoder_history_len. conceptuatly they represent a list of + RnntDecodingConfig::decoder_history_len. Conceptually they represent a list of `decoder_history_len` symbols; they are represented numerically as, for example in the length-2-history case: symbol_{t-1} + symbol_{t-2} * vocab_size. @@ -56,7 +56,7 @@ struct RnntDecodingConfig { beam(beam), max_states(max_states), max_contexts(max_contexts) { - num_context_states = pow(vocab_size, decoder_history_len); + // num_context_states = pow(vocab_size, decoder_history_len); } // vocab_size is the largest-symbol plus one. @@ -74,7 +74,7 @@ struct RnntDecodingConfig { // equals to 10, we need 0 ~ 9 to distinguish each context state when // decoder_history_len is 1, and 0 ~ 99 (10 ^ 2 ids) for decoder_history_len // equals to 2, 0 ~ 999 (10 ^ 3 ids) for decoder_history_len equals to 3. - int32_t num_context_states; + // int32_t num_context_states; // `beam` imposes a limit on the score of a state, relative to the // best-scoring state on the same frame. E.g. 10. @@ -162,38 +162,38 @@ class RnntDecodingStreams { @param [in] logprobs Array of shape [tot_contexts][num_symbols], containing log-probs of symbols given the contexts output by `GetContexts()`. Will satisfy - logprobs.Dim0() == states.TotSize(1). + logprobs.Dim0() == states_.TotSize(1). */ - void Advance(Array2 &logprobs); + void Advance(const Array2 &logprobs); /* Generate the lattice. Note: The prev_frames_ only contains decoded by current object, in order to - generate the lattice we will fisrt gather all the previous frames from + generate the lattice we will first gather all the previous frames from individual streams. @param [in] num_frames A vector containing the number of frames we want to gather for each stream (note: the frames we have ever received). It MUST satisfy `num_frames.size() == num_streams_`, and - `num_frames[i] < srcs_[i].prev_frames.size()`. + `num_frames[i] <= srcs_[i].prev_frames.size()`. @param [out] ofsa The output lattice will write to here, its num_axes equals to 3, will be re-allocated. - @param [out] out_map It it a Array1 with Dim() equals to + @param [out] out_map It is an Array1 with Dim() equals to ofsa.NumElements() containing the idx01 into the graph of each individual streams, mapping current arc in ofsa to - original decoding graphs. It may contains -1 which means + original decoding graphs. It may contain -1 which means this arc is a "termination symbol". */ - void FormatOutput(std::vector &num_frames, FsaVec *ofsa, + void FormatOutput(const std::vector &num_frames, FsaVec *ofsa, Array1 *out_map); /* Terminate the decoding process of current RnntDecodingStreams object, it - will update the states & scores of each individual streams and split & - append the prev_frames_ in current object to the prev_frames of the - individual streams. + will update the states & scores of each individual stream and split & + append the prev_frames_ in current object to the `prev_frames` of the + individual stream. Note: We can not decode with this object anymore after calling TerminateAndFlushToStreams(). @@ -207,22 +207,22 @@ class RnntDecodingStreams { int32_t NumStreams() const { return num_streams_; } // Note: The following three functions should be private members, they are not - // expected to be called outsize this class. We make it public because of the + // expected to be called outside this class. 
We make it public because of the // extended lambda restrictions, see // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#extended-lambda-restrictions // for more details. /* Expand arcs according to states_. - `states_` has a shape of [stream][context][state], each of its values is a - combinatioin of context_state and graph_state, this is: + `states_` has a shape of [stream][context][state], each of its value is a + combination of context_state and graph_state, that is: `state = context_state * num_graph_states + graph_state`. The graph_state - is the idx0 of corresponding individual graph(has shape [state][arc]). + is the idx0 of corresponding individual graph(with shape [state][arc]). This function will expand each of these states into several arcs(i.e. the out-going arcs of state idx0), so that we can get a new shape of [stream][context][state][arc]. - Caution: This function intends to be used in `Advance()` only. + Caution: This function is intended to be used in `Advance()` only. @return Return the expected 4 axes shape (i.e.[stream][context][state][arc]). @@ -235,13 +235,14 @@ class RnntDecodingStreams { rule is: (1) keep all epsilon transitions to the next frame, to ensure there is no way we can have no states surviving. - (2) for all other arcs, keep the it if the forward scores after the + (2) for all other arcs, keep it if the forward scores after the arc would be >= the max_scores_per_stream entry for this stream minus the beam from the config. - Caution: This function intends to be used in `Advance()` only. + Caution: This function is intended to be used in `Advance()` only. - @param [in] unprund_arcs_shape The RaggedShape return by `ExpandArcs()`. + @param [in] unprund_arcs_shape The RaggedShape returned by + `ExpandArcs()`. @param [in] logprobs Array of shape [tot_contexts][num_symbols], containing log-probs of symbols given the contexts output by `GetContexts()`. Will satisfy @@ -251,7 +252,7 @@ class RnntDecodingStreams { @return Return the renumbering object indicating which arc will be kept. */ Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape, - Array2 &logprobs); + const Array2 &logprobs); /* Group states by contexts. @@ -265,7 +266,7 @@ class RnntDecodingStreams { we need a shape of [stream][context][state][arc], obviously the sub-lists along axis -1 contains same values. - Here is a example: suppose vocab_size=10, num_graph_states=10, + Here is an example: suppose vocab_size=10, num_graph_states=10, decoder_history_len=2, we have a states like: [ [ 112 120 123 125 345 345 ] [ 123 124 567 568 670 ] ] @@ -279,7 +280,7 @@ class RnntDecodingStreams { [ [ [ [ 112 ] ] [ [ 120 ] [ 123 ] [ 125 ] ] [ [ 345 345 ] ] ] [ [ [ 123 ] [ 124 ] ] [ [ 567 ] [ 568 ] ] [ [ 670 ] ] ] ] - Caution: This function intends to be used in `Advance()` only. + Caution: This function is intended to be used in `Advance()` only. @param [in] states A two axes ragged tensor with each sub-list **sorted**. @@ -292,7 +293,7 @@ class RnntDecodingStreams { private: /* Prune the incoming scores based on beam, max-states and max-contexts. - Actually the beam part is not realy necessary, as we already pruned + Actually the beam part is not really necessary, as we already pruned with the beam, but it doesn't cost anything extra. 
Args: incoming_scores [in] The ragged array of scores to be pruned, indexed @@ -309,7 +310,7 @@ class RnntDecodingStreams { Array1 *arcs_new2old); /* - Gather all previously decoded frames util now, we need all the previous + Gather all previously decoded frames until now, we need all the previous frames to generate lattice. Note: The prev_frames_ in current object only contains the frames from the @@ -320,9 +321,9 @@ class RnntDecodingStreams { @param [in] num_frames A vector containing the number of frames we want to gather for each stream. It MUST satisfy `num_frames.size() == num_streams_`, and - `num_frames[i] < srcs_[i].prev_frames.size()`. + `num_frames[i] <= srcs_[i].prev_frames.size()`. */ - void GatherPrevFrames(std::vector &num_frames); + void GatherPrevFrames(const std::vector &num_frames); ContextPtr c_; diff --git a/k2/python/csrc/torch/rnnt_decode.cu b/k2/python/csrc/torch/rnnt_decode.cu index 6678bc7b8..c84f21632 100644 --- a/k2/python/csrc/torch/rnnt_decode.cu +++ b/k2/python/csrc/torch/rnnt_decode.cu @@ -45,10 +45,10 @@ static void PybindRnntDecodingConfig(py::module &m) { Args: vocab_size: - It indicates how many symbols we are using, euqals the + It indicates how many symbols we are using, equals the largest-symbol plus one. decoder_history_len: - `decoder_history_len` is the number of symbols of history the + The number of symbols of history the decoder takes; will normally be one or two ("stateless decoder"), our RNN-T decoding setup does not support unlimited decoder context such as with LSTMs. diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 7b893782e..8e44f2e93 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1064,8 +1064,8 @@ def ctc_topo(max_token: int, def trivial_graph(max_token: int, device: Optional[Union[torch.device, str]] = None) -> k2.Fsa: ''' - Creat a trivial graph which has only two states, on state 0, there are - `max-token + 1` self loops(i.e. a loop for each symbol, including blank), + Create a trivial graph with only two states. On state 0, there are + `max_token + 1` self loops(i.e. a loop for each symbol, including blank), and state 1 is the final state. Args: diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py index 8d4814f59..85d56cd5d 100644 --- a/k2/python/k2/rnnt_decode.py +++ b/k2/python/k2/rnnt_decode.py @@ -30,10 +30,11 @@ class RnntDecodingStream(object): + def __init__(self, fsa: Fsa) -> None: """Create a new rnnt decoding stream. - Every sequence(wave data) need a decoding stream, this function is + Every sequence(wave data) needs a decoding stream, this function is expected to be called when a new sequence comes. We support different decoding graphs for different streams. @@ -43,7 +44,7 @@ def __init__(self, fsa: Fsa) -> None: Returns: A rnnt decoding stream object, which will be combined into - `RnntDecodingStreams` to do decoding together with other + :class:`RnntDecodingStreams` to do decoding together with other sequences in parallel. """ self.fsa = fsa @@ -59,9 +60,12 @@ def __str__(self) -> str: class RnntDecodingStreams(object): - def __init__( - self, src_streams: List[RnntDecodingStream], config: RnntDecodingConfig - ) -> None: + '''See https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py # noqa + for how this class is used in RNN-T decoding. 
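    A minimal sketch of the per-frame decoding loop, using only methods
    defined in this class (`get_logprobs` stands for the model-specific
    joiner computation and, like `src_streams`, `config` and `num_frames`,
    is a placeholder rather than part of k2):

        streams = RnntDecodingStreams(src_streams, config)
        for t in range(max(num_frames)):
            shape, contexts = streams.get_contexts()
            logprobs = get_logprobs(contexts)  # placeholder joiner call
            streams.advance(logprobs)
        streams.terminate_and_flush_to_streams()
        lattice = streams.format_output(num_frames)

    Each row of `contexts` holds `decoder_history_len` token ids, i.e. the
    base-`vocab_size` digits of the packed context states described in the
    C++ comments above.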
+ ''' + + def __init__(self, src_streams: List[RnntDecodingStream], + config: RnntDecodingConfig) -> None: """ Combines multiple RnntDecodingStream objects to create a RnntDecodingStreams object, then all these RnntDecodingStreams can do @@ -98,11 +102,11 @@ def __str__(self) -> str: def get_contexts(self) -> Tuple[RaggedShape, Tensor]: """ This function must be called prior to evaluating the joiner network - for a particular frame. It tells the calling code which contexts - it must evaluate the joiner network for. + for a particular frame. It tells the calling code for which contexts + it must evaluate the joiner network. Returns: - Return a two elements tuple containing a RaggedShape and a tensor. + Return a two-element tuple containing a RaggedShape and a tensor. shape: A RaggedShape with 2 axes, representing [stream][context]. @@ -111,8 +115,9 @@ def get_contexts(self) -> Tuple[RaggedShape, Tensor]: A tensor of shape [tot_contexts][decoder_history_len], where tot_contexts == shape->TotSize(1) and decoder_history_len comes from the config, it represents the number of symbols in the context of - the decode network (assumed to be finite). It contains the token ids - into the vocabulary(i.e. `0 <= value < vocab_size`). + the decoder network (assumed to be finite). It contains the token + ids into the vocabulary(i.e. `0 <= value < vocab_size`). + Its dtype is torch.int32. """ return self.streams.get_contexts() @@ -131,7 +136,7 @@ def advance(self, logprobs: Tensor) -> None: def terminate_and_flush_to_streams(self) -> None: """ - Terminate the decoding process of current RnntDecodingStreams objects. + Terminate the decoding process of current RnntDecodingStreams object. It will update the decoding states and store the decoding results currently got to each of the individual streams. @@ -147,7 +152,7 @@ def format_output(self, num_frames: List[int]) -> Fsa: Note: The attributes of the generated lattice is a union of the attributes - of all the decoding graphs. For example, a streams contains three + of all the decoding graphs. For example, if `self` contains three individual stream, each stream has its own decoding graphs, graph[0] has attributes attr1, attr2; graph[1] has attributes attr1, attr3; graph[2] has attributes attr3, attr4; then the generated lattice has @@ -205,17 +210,17 @@ def format_output(self, num_frames: List[int]) -> Fsa: value = getattr(src, name) if info["tensor_type"] == "Tensor": assert isinstance(value, Tensor) - new_value = index_select( - value, arc_map, default_value=filler - ) + new_value = index_select(value, + arc_map, + default_value=filler) else: assert isinstance(value, RaggedTensor) # Only integer types ragged attributes are supported now assert value.num_axes == 2 assert value.dtype == torch.int32 - new_value, _ = value.index( - arc_map, axis=0, need_value_indexes=False - ) + new_value, _ = value.index(arc_map, + axis=0, + need_value_indexes=False) else: if info["tensor_type"] == "Tensor": # fill with filler value @@ -231,8 +236,7 @@ def format_output(self, num_frames: List[int]) -> Fsa: (num_arcs, 0), dtype=info["dtype"], device=device, - ) - ) + )) values.append(new_value) if info["tensor_type"] == "Tensor": new_value = torch.cat(values) From 846c39c7953cbb71839d8f4fbafa4fcbe4d75c7a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Mar 2022 11:14:49 +0800 Subject: [PATCH 55/64] Removes arcs with label 0 from the TrivialGraph. 
(#939) --- k2/csrc/fsa_algo.cu | 10 ++++------ k2/csrc/fsa_algo.h | 6 +++--- k2/python/k2/fsa_algo.py | 13 +++++++------ 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 45736d639..4065713ff 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -820,9 +820,8 @@ Fsa TrivialGraph(const ContextPtr &c, int32_t max_token, Array1 *aux_labels) { NVTX_RANGE(K2_FUNC); K2_CHECK(aux_labels); - int32_t num_arcs = max_token + 2; - Array1 row_splits( - c, std::vector{0, max_token + 2, max_token + 2}); + int32_t num_arcs = max_token + 1; + Array1 row_splits(c, std::vector{0, num_arcs, num_arcs}); Array1 row_ids(c, num_arcs); Array1 values(c, num_arcs); *aux_labels = Array1(c, num_arcs); @@ -836,10 +835,9 @@ Fsa TrivialGraph(const ContextPtr &c, int32_t max_token, arc.score = 0; arc.src_state = 0; arc.dest_state = 0; - arc.label = idx; - int32_t aux_label = idx, row_id = 0; + arc.label = idx + 1; + int32_t aux_label = idx + 1, row_id = 0; if (idx == num_arcs - 1) { - row_id = 0; arc.dest_state = 1; arc.label = -1; aux_label = -1; diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index bd59f76c7..c32c894d4 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -581,20 +581,20 @@ Fsa CtcTopo(const ContextPtr &c, int32_t max_token, bool modified, /* Create a trivial graph which has only two states. On state 0, there are - `max_token + 1` self loops(i.e. a loop for each symbol, including blank), and + `max_token` self loops(i.e. a loop for each symbol from 1 to max_token), and state 1 is the final state. @param [in] c The context with which we'll allocate memory for the trivial graph. @param [in] max_token The maximum token ID (inclusive). We assume that token IDs are contiguous (from 1 to `max_token`). - 0 represents blank. @param [out] aux_labels The output labels of graph will write to this array, will be reallocated. The label and aux_label on each arc are equal - (i.e. aux_labels = Arange(0, max_token + 1); + (i.e. aux_labels = Arange(1, max_token + 1); @return Returns the expected trivial graph on the given device. + Note the returned graph does not contain arcs with label being 0. */ Fsa TrivialGraph(const ContextPtr &c, int32_t max_token, Array1 *aux_labels); diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 8e44f2e93..29a23318a 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1063,21 +1063,22 @@ def ctc_topo(max_token: int, def trivial_graph(max_token: int, device: Optional[Union[torch.device, str]] = None) -> k2.Fsa: - ''' - Create a trivial graph with only two states. On state 0, there are - `max_token + 1` self loops(i.e. a loop for each symbol, including blank), - and state 1 is the final state. + '''Create a trivial graph which has only two states. On state 0, there are + `max_token` self loops(i.e. a loop for each symbol from 1 to max_token), and + state 1 is the final state. Args: max_token: The maximum token ID (inclusive). We assume that token IDs - are contiguous (from 1 to `max_token`). 0 represents blank. + are contiguous (from 1 to `max_token`). device: Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a torch.device. If it is None, then the returned FSA is on CPU. - Returns: Returns the expected trivial graph on the given device. + Returns: + Returns the expected trivial graph on the given device. + Note: The returned graph does not contain arcs with label being 0. 
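    A small sketch of the expected structure after this change (the arc
    order follows the C++ construction above):

        fsa = trivial_graph(max_token=3)
        # state 0 has self-loops labeled 1, 2, 3 and a final arc labeled
        # -1; no self-loop carries label 0 any more
        assert fsa.labels.tolist() == [1, 2, 3, -1]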
''' ragged_arc, aux_labels = _k2.trivial_graph(max_token, device) fsa = Fsa(ragged_arc, aux_labels=aux_labels) From 0f65420ffbfcc35bd84d67a97c1ee26c8dbd4bed Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Mar 2022 16:37:03 +0800 Subject: [PATCH 56/64] Implement linear_fsa_with_self_loops. (#940) * Implement linear_fsa_with_self_loops. --- k2/python/k2/__init__.py | 1 + k2/python/k2/fsa_algo.py | 44 +++++++++++-- k2/python/tests/CMakeLists.txt | 1 + .../tests/linear_fsa_with_self_loops_test.py | 63 +++++++++++++++++++ 4 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 k2/python/tests/linear_fsa_with_self_loops_test.py diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 1091f4348..930affb18 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -41,6 +41,7 @@ from .fsa_algo import levenshtein_alignment from .fsa_algo import levenshtein_graph from .fsa_algo import linear_fsa +from .fsa_algo import linear_fsa_with_self_loops from .fsa_algo import linear_fst from .fsa_algo import prune_on_arc_post from .fsa_algo import random_paths diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 29a23318a..c3825aa54 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -68,6 +68,37 @@ def linear_fsa(labels: Union[List[int], List[List[int]], k2.RaggedTensor], return fsa +def linear_fsa_with_self_loops(fsas: k2.Fsa): + '''Create a linear FSA with epsilon self-loops by first removing epsilon + transitions from the input linear FSA. + + Args: + fsas: + An FSA or an FsaVec. It MUST be a linear FSA or a vector of linear FSAs. + Returns: + Return an FSA or FsaVec, where each FSA contains epsilon self-loops but + contains no epsilon transitions for arcs that are not self-loops. 
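    A small example, mirroring the unit test added below:

        src = linear_fsa([2, 0, 0, 0, 5, 8])
        dst = linear_fsa_with_self_loops(src)
        assert dst.labels.tolist() == [0, 2, 0, 5, 0, 8, 0, -1]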
+ ''' + if len(fsas.shape) == 2: + # A single FSA + device = fsas.device + shape0 = _k2.RaggedShape.regular_ragged_shape(dim0=1, + dim1=fsas.shape[0]) + shape = shape0.to(device).compose(fsas.arcs.shape()) + else: + shape = fsas.arcs.shape() + + shape = shape.remove_axis(1) # remove the state axis + + labels = k2.RaggedTensor(shape, fsas.labels.contiguous()) + labels = labels.remove_values_leq(0) + ans = add_epsilon_self_loops(linear_fsa(labels)) + + if len(fsas.shape) == 2: + ans = ans[0] + return ans + + def linear_fst(labels: Union[List[int], List[List[int]]], aux_labels: Union[List[int], List[List[int]]]) -> Fsa: '''Construct a linear FST from labels and its corresponding @@ -1192,16 +1223,18 @@ def levenshtein_alignment( hyps.rename_tensor_attribute_("aux_labels", "hyp_labels") - lattice = k2.intersect_device( - refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref) + lattice = k2.intersect_device(refs, + hyps, + b_to_a_map=hyp_to_ref_map, + sorted_match_a=sorted_match_ref) lattice = k2.remove_epsilon_self_loops(lattice) alignment = k2.shortest_path(lattice, use_double_scores=True).invert_() alignment.rename_tensor_attribute_("labels", "ref_labels") alignment.rename_tensor_attribute_("aux_labels", "labels") - alignment.scores -= getattr( - alignment, "__ins_del_score_offset_internal_attr_") + alignment.scores -= getattr(alignment, + "__ins_del_score_offset_internal_attr_") return alignment @@ -1223,5 +1256,6 @@ def union(fsas: Fsa) -> Fsa: need_arc_map = True ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map) - out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map) + out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, + arc_map) return out_fsa diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 57525979e..cde9a5382 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -52,6 +52,7 @@ set(py_test_files levenshtein_alignment_test.py levenshtein_graph_test.py linear_fsa_test.py + linear_fsa_with_self_loops_test.py linear_fst_test.py multi_gpu_test.py mutual_information_test.py diff --git a/k2/python/tests/linear_fsa_with_self_loops_test.py b/k2/python/tests/linear_fsa_with_self_loops_test.py new file mode 100644 index 000000000..1e331bbbc --- /dev/null +++ b/k2/python/tests/linear_fsa_with_self_loops_test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# To run this single test, use +# +# ctest --verbose -R linear_fsa_self_loops_test_py + +import torch +import k2 +import unittest + + +class TestLinearFsa(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + + def test_single_fsa(self): + for device in self.devices: + labels = [2, 0, 0, 0, 5, 8] + src = k2.linear_fsa(labels, device) + dst = k2.linear_fsa_with_self_loops(src) + assert src.device == dst.device + expected_labels = [0, 2, 0, 5, 0, 8, 0, -1] + assert dst.labels.tolist() == expected_labels + + def test_multiple_fsa(self): + for device in self.devices: + labels = [[2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0], [1, 2], + [0, 0, 0, 3, 0, 2]] + src = k2.linear_fsa(labels, device) + dst = k2.linear_fsa_with_self_loops(src) + assert src.device == dst.device + expected_labels0 = [0, 2, 0, 5, 0, 8, 0, -1] + expected_labels1 = [0, 1, 0, 2, 0, -1] + expected_labels2 = [0, 3, 0, 2, 0, -1] + expected_labels = expected_labels0 + expected_labels1 + expected_labels2 + assert dst.labels.tolist() == expected_labels + + +if __name__ == '__main__': + unittest.main() From a830c607c4f40455087c3a410c56c927c83b2719 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 31 Mar 2022 07:01:09 +0800 Subject: [PATCH 57/64] Fix the pruning with max-states (#941) --- k2/csrc/rnnt_decode.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index 232c0454c..db5e732dd 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -175,11 +175,17 @@ Ragged RnntDecodingStreams::PruneTwice(Ragged &incoming_scores, // problems created by empty lists, although it's perhaps not an optimal way // to prune. - // incoming_scores has a shape of [stream][context][state][arc] + // incoming_scores has a shape of [stream][context][state][arc], we are + // pruning with max-states per stream, so that contexts axis should be + // removed. reduced_incoming_scores has a shape of [stream][state][arc]. + auto reduced_incoming_scores = incoming_scores.RemoveAxis(1); // states_prune is a renumbering on the states axis. - Renumbering states_prune = PruneRagged(incoming_scores, 2 /*axis*/, + Renumbering states_prune = PruneRagged(reduced_incoming_scores, 1 /*axis*/, config_.beam, config_.max_states); + // The new2old indexes in states_prune are global indexes along axis state, + // so we can extract the surviving elements from `incoming_scores` along + // state axis. 
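The distinction the fix above draws, states per stream versus states per context, can be counted directly. A toy sketch with nested lists standing in for the ragged scores (the numbers are arbitrary):

    scores = [                      # axes: [stream][context][state][arc]
        [[[0.9], [0.2]], [[0.8]]],  # stream 0: contexts with 2 and 1 states
    ]
    per_context = [len(ctx) for stream in scores for ctx in stream]
    per_stream = [sum(len(ctx) for ctx in stream) for stream in scores]
    assert per_context == [2, 1] and per_stream == [3]
    # max_states must bound per_stream (3 here), not per_context, which is
    # why the context axis is removed before PruneRagged in the fix above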
Array1 arcs_new2old1; Ragged temp_scores = SubsetRagged(incoming_scores, states_prune, 2 /*axis*/, &arcs_new2old1); From 8c28c864f2e0c616f0651e2f7ab353375cb013cb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 3 Apr 2022 15:51:35 +0800 Subject: [PATCH 58/64] Rnnt allow different encoder/decoder dims (#945) * Allow different encoder and decoder dim in rnnt_pruning * Bug fixes --- k2/python/k2/rnnt_loss.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 5918d7b9e..fa030a0a1 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -592,9 +592,9 @@ def do_rnnt_pruning( Args: am: - The encoder output, with shape (B, T, C) + The encoder output, with shape (B, T, encoder_dim) lm: - The prediction network output, with shape (B, S + 1, C) + The prediction network output, with shape (B, S + 1, decoder_dim) ranges: A tensor containing the symbol indexes for each frame that we want to keep. Its shape is (B, T, s_range), see the docs in @@ -603,26 +603,28 @@ def do_rnnt_pruning( Returns: Return the pruned am and lm with shape (B, T, s_range, C) """ - # am (B, T, C) - # lm (B, S + 1, C) + # am (B, T, encoder_dm) + # lm (B, S + 1, decoder_dim) # ranges (B, T, s_range) assert ranges.shape[0] == am.shape[0] assert ranges.shape[0] == lm.shape[0] assert am.shape[1] == ranges.shape[1] (B, T, s_range) = ranges.shape - (B, S1, C) = lm.shape + (B, S1, decoder_dim) = lm.shape + encoder_dim = am.shape[-1] + assert am.shape == (B, T, encoder_dim) S = S1 - 1 - # (B, T, s_range, C) - am_pruning = am.unsqueeze(2).expand((B, T, s_range, C)) + # (B, T, s_range, encoder_dim) + am_pruned = am.unsqueeze(2).expand((B, T, s_range, encoder_dim)) - # (B, T, s_range, C) - lm_pruning = torch.gather( - lm.unsqueeze(1).expand((B, T, S + 1, C)), + # (B, T, s_range, decoder_dim) + lm_pruned = torch.gather( + lm.unsqueeze(1).expand((B, T, S + 1, decoder_dim)), dim=2, - index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)), + index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, decoder_dim)), ) - return am_pruning, lm_pruning + return am_pruned, lm_pruned def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor): From d9778653fab7a3a121463174c1816a270a3d14f2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 6 Apr 2022 19:12:13 +0800 Subject: [PATCH 59/64] Supporting building k2 on Windows (#946) --- .github/workflows/build-conda-cpu-macos.yml | 117 ++++++++ ...nda_cpu.yml => build-conda-cpu-ubuntu.yml} | 89 ++---- ...-conda.yml => build-conda-cpu-windows.yml} | 42 +-- .github/workflows/build-conda-cuda-ubuntu.yml | 127 ++++++++ .github/workflows/build-cpu-macos.yml | 123 ++++++++ .../{build-cpu.yml => build-cpu-ubuntu.yml} | 99 +++---- .../{windows.yml => build-cpu-windows.yml} | 69 +++-- .github/workflows/build-cuda-ubuntu.yml | 147 ++++++++++ .github/workflows/build.yml | 274 ------------------ .github/workflows/build_conda.yml | 259 ----------------- .github/workflows/nightly-cpu-macos.yml | 111 +++++++ ...nightly-cpu.yml => nightly-cpu-ubuntu.yml} | 88 ++---- ...ly-windows.yml => nightly-cpu-windows.yml} | 48 +-- .../{nightly.yml => nightly-cuda-ubuntu.yml} | 8 +- .github/workflows/run-tests.yml | 8 +- .../{wheel-cpu.yml => wheel-cpu-macos.yml} | 8 +- ...l-cpu-stable.yml => wheel-cpu-windows.yml} | 29 +- ...wheel-stable.yml => wheel-cuda-ubuntu.yml} | 8 +- .github/workflows/wheel.yml | 99 ------- CMakeLists.txt | 90 ++++-- cmake/moderngpu.cmake | 6 +- 
docs/source/installation/conda.rst | 2 +- docs/source/installation/for_developers.rst | 4 + docs/source/installation/from_source.rst | 6 +- docs/source/installation/images/README.md | 2 +- .../images/torch_ge_1.6.0-green.svg | 1 + docs/source/installation/index.rst | 6 +- docs/source/installation/pip.rst | 2 +- k2/csrc/CMakeLists.txt | 24 +- k2/csrc/benchmark/CMakeLists.txt | 1 + k2/csrc/fsa.h | 2 +- k2/csrc/host/CMakeLists.txt | 19 +- k2/csrc/log.h | 6 + k2/csrc/log_test.cu | 4 + k2/csrc/macros_test.cu | 4 +- k2/csrc/ragged_ops.cu | 58 +++- k2/csrc/rand_test.cu | 2 +- k2/csrc/rm_epsilon.cu | 8 +- k2/csrc/rnnt_decode.cu | 6 +- k2/csrc/tensor_ops.cu | 57 ++-- k2/csrc/tensor_ops_test.cu | 6 +- k2/csrc/test_utils.h | 7 +- k2/csrc/version.h.in | 6 +- k2/python/csrc/CMakeLists.txt | 13 +- k2/python/csrc/torch.h | 32 -- k2/python/csrc/torch/fsa.cu | 4 +- k2/python/csrc/torch/fsa_algo.cu | 85 ++---- k2/python/csrc/torch/ragged_ops.cu | 7 +- k2/python/csrc/torch/v2/any.cu | 64 ++-- k2/python/csrc/torch/v2/ragged_shape.cu | 10 +- k2/python/host/k2host/fsa.py | 4 +- k2/python/k2/rnnt_decode.py | 2 +- k2/python/k2/rnnt_loss.py | 4 +- .../tests/linear_fsa_with_self_loops_test.py | 2 +- k2/python/tests/mutual_information_test.py | 6 +- .../github_actions/generate_build_matrix.py | 111 +++++++ 56 files changed, 1261 insertions(+), 1165 deletions(-) create mode 100644 .github/workflows/build-conda-cpu-macos.yml rename .github/workflows/{build_conda_cpu.yml => build-conda-cpu-ubuntu.yml} (52%) rename .github/workflows/{windows-conda.yml => build-conda-cpu-windows.yml} (77%) create mode 100644 .github/workflows/build-conda-cuda-ubuntu.yml create mode 100644 .github/workflows/build-cpu-macos.yml rename .github/workflows/{build-cpu.yml => build-cpu-ubuntu.yml} (55%) rename .github/workflows/{windows.yml => build-cpu-windows.yml} (72%) create mode 100644 .github/workflows/build-cuda-ubuntu.yml delete mode 100644 .github/workflows/build.yml delete mode 100644 .github/workflows/build_conda.yml create mode 100644 .github/workflows/nightly-cpu-macos.yml rename .github/workflows/{nightly-cpu.yml => nightly-cpu-ubuntu.yml} (55%) rename .github/workflows/{nightly-windows.yml => nightly-cpu-windows.yml} (76%) rename .github/workflows/{nightly.yml => nightly-cuda-ubuntu.yml} (96%) rename .github/workflows/{wheel-cpu.yml => wheel-cpu-macos.yml} (94%) rename .github/workflows/{wheel-cpu-stable.yml => wheel-cpu-windows.yml} (66%) rename .github/workflows/{wheel-stable.yml => wheel-cuda-ubuntu.yml} (97%) delete mode 100644 .github/workflows/wheel.yml create mode 100644 docs/source/installation/images/torch_ge_1.6.0-green.svg create mode 100755 scripts/github_actions/generate_build_matrix.py diff --git a/.github/workflows/build-conda-cpu-macos.yml b/.github/workflows/build-conda-cpu-macos.yml new file mode 100644 index 000000000..623d3a472 --- /dev/null +++ b/.github/workflows/build-conda-cpu-macos.yml @@ -0,0 +1,117 @@ +# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# refer to https://github.com/actions/starter-workflows/pull/47/files + + +# Note, we have to set +# +# export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.8/site-packages:$DYLD_LIBRARY_PATH +# +# before running `python3 -m k2.version` +# +# See https://github.com/openPMD/openPMD-api/issues/593#issuecomment-552690470 + + +name: build_conda_cpu_macos + +on: + push: + tags: + - '*' + +env: + K2_BUILD_TYPE: Release + +jobs: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build_conda_cpu_macos: + needs: generate_build_matrix + runs-on: macos-10.15 + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: ${{ matrix.python-version }} + activate-environment: k2 + + - name: Display Python version + shell: bash -l {0} + run: | + python3 -c "import sys; print(sys.version)" + which python3 + + - name: Install conda dependencies + shell: bash -l {0} + run: | + conda install -y -q anaconda-client + conda install -y -q conda-build + conda install -y -q -c pytorch pytorch=${{ matrix.torch }} cpuonly + + - name: Display conda info + shell: bash -l {0} + run: | + which conda + conda env list + conda info + + - name: Build k2 + shell: bash -l {0} + env: + K2_PYTHON_VERSION: ${{ matrix.python-version}} + K2_TORCH_VERSION: ${{ matrix.torch }} + K2_CONDA_TOKEN: ${{ secrets.K2_CONDA_TOKEN}} + K2_IS_GITHUB_ACTIONS: 1 + K2_IS_FOR_CONDA: 1 + run: | + export K2_BUILD_TYPE=$K2_BUILD_TYPE + ./scripts/build_conda_cpu.sh + + - name: Display generated files + run: | + ls -lh /usr/local/miniconda/envs/k2/conda-bld/osx-64 + + - name: Upload generated files + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }} + path: /usr/local/miniconda/envs/k2/conda-bld/osx-64/*.tar.bz2 diff --git a/.github/workflows/build_conda_cpu.yml b/.github/workflows/build-conda-cpu-ubuntu.yml similarity index 52% rename from .github/workflows/build_conda_cpu.yml rename to .github/workflows/build-conda-cpu-ubuntu.yml index fe3e552ab..72cf5b412 100644 --- a/.github/workflows/build_conda_cpu.yml +++ b/.github/workflows/build-conda-cpu-ubuntu.yml @@ -26,68 +26,41 @@ # See https://github.com/openPMD/openPMD-api/issues/593#issuecomment-552690470 -name: build_conda_cpu +name: build_conda_cpu_ubuntu on: push: - branches: - - conda-cpu + tags: + - '*' env: K2_BUILD_TYPE: Release jobs: - build_conda_cpu: - runs-on: ${{ matrix.os }} + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python 
scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build_conda_cpu_ubuntu: + needs: generate_build_matrix + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: - os: [ubuntu-18.04, macos-10.15] - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] - # from https://download.pytorch.org/whl/torch_stable.html - # - # PyTorch 1.11.x supports 3.7, 3.8, 3.9, 3.10 - # PyTorch 1.10, 1.9.x, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9 - # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 - # - # Other PyTorch versions are not tested - # - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] - exclude: - - python-version: "3.9" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: "3.9" - torch: "1.5.1" - - python-version: "3.9" - torch: "1.6.0" - - python-version: "3.9" - torch: "1.7.0" - - python-version: "3.10" # exclude Python 3.10 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - python-version: "3.10" - torch: "1.5.1" - - python-version: "3.10" - torch: "1.6.0" - - python-version: "3.10" - torch: "1.7.0" - - python-version: "3.10" - torch: "1.7.1" - - python-version: "3.10" - torch: "1.8.0" - - python-version: "3.10" - torch: "1.8.1" - - python-version: "3.10" - torch: "1.9.0" - - python-version: "3.10" - torch: "1.9.1" - - python-version: "3.10" - torch: "1.10.0" - - python-version: "3.10" - torch: "1.10.1" - - python-version: "3.10" - torch: "1.10.2" - - python-version: "3.6" # exclude Python 3.6 for [1.11.0] - torch: "1.11.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -134,25 +107,11 @@ jobs: ./scripts/build_conda_cpu.sh - name: Display generated files - if: startsWith(matrix.os, 'ubuntu') run: | ls -lh /usr/share/miniconda/envs/k2/conda-bld/linux-64 - - name: Display generated files - if: startsWith(matrix.os, 'macos') - run: | - ls -lh /usr/local/miniconda/envs/k2/conda-bld/osx-64 - - name: Upload generated files - if: startsWith(matrix.os, 'ubuntu') uses: actions/upload-artifact@v2 with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }} path: /usr/share/miniconda/envs/k2/conda-bld/linux-64/*.tar.bz2 - - - name: Upload generated files - if: startsWith(matrix.os, 'macos') - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }} - path: /usr/local/miniconda/envs/k2/conda-bld/osx-64/*.tar.bz2 diff --git a/.github/workflows/windows-conda.yml b/.github/workflows/build-conda-cpu-windows.yml similarity index 77% rename from .github/workflows/windows-conda.yml rename to .github/workflows/build-conda-cpu-windows.yml index 00bab9adc..551c13ce3 100644 --- a/.github/workflows/windows-conda.yml +++ b/.github/workflows/build-conda-cpu-windows.yml @@ -15,36 +15,42 @@ # limitations under the License. 
-name: build-windows-conda +name: build_conda_cpu_windows on: push: - branches: - - conda-win + tags: + - '*' env: BUILD_TYPE: Release jobs: - build-windows-conda: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build_conda_cpu_windows: # see https://github.com/actions/virtual-environments/blob/win19/20210525.0/images/win/Windows2019-Readme.md - runs-on: ${{ matrix.os }} + needs: generate_build_matrix + runs-on: windows-2019 strategy: fail-fast: false matrix: - os: [windows-2019] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0, - python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: 3.9 - torch: "1.5.1" - - python-version: 3.9 - torch: "1.6.0" - - python-version: 3.9 - torch: "1.7.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/build-conda-cuda-ubuntu.yml b/.github/workflows/build-conda-cuda-ubuntu.yml new file mode 100644 index 000000000..fa6dca28e --- /dev/null +++ b/.github/workflows/build-conda-cuda-ubuntu.yml @@ -0,0 +1,127 @@ +# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# refer to https://github.com/actions/starter-workflows/pull/47/files + +name: build_conda_cuda_ubuntu + +on: + push: + tags: + - '*' + +env: + K2_BUILD_TYPE: Release + +jobs: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py --enable-cuda) + echo "::set-output name=matrix::${MATRIX}" + + build_conda_cuda_ubuntu: + needs: generate_build_matrix + runs-on: ubuntu-18.04 + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Install CUDA Toolkit ${{ matrix.cuda }} + shell: bash -l {0} + env: + cuda: ${{ matrix.cuda }} + run: | + source ./scripts/github_actions/install_cuda.sh + echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV + echo "${CUDA_HOME}/bin" >> $GITHUB_PATH + echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV + + - name: Display NVCC version + shell: bash -l {0} + run: | + which nvcc + nvcc --version + + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: ${{ matrix.python-version }} + activate-environment: k2 + + - name: Display Python version + shell: bash -l {0} + run: | + python3 -c "import sys; print(sys.version)" + which python3 + + - name: Install conda dependencies + shell: bash -l {0} + run: | + conda install -y -q anaconda-client + conda install -y -q conda-build + conda install -y -q bs4 requests tqdm + conda install -y -q -c pytorch -c conda-forge pytorch=${{ matrix.torch }} cudatoolkit=${{ matrix.cuda }} + + - name: Display conda info + shell: bash -l {0} + run: | + which conda + conda env list + conda info + nproc + + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + + - name: Download cudnn 8.0 + shell: bash -l {0} + env: + cuda: ${{ matrix.cuda }} + run: | + ./scripts/github_actions/install_cudnn.sh + + - name: Build k2 + shell: bash -l {0} + env: + K2_CUDA_VERSION: ${{ matrix.cuda }} + K2_PYTHON_VERSION: ${{ matrix.python-version}} + K2_TORCH_VERSION: ${{ matrix.torch }} + K2_CONDA_TOKEN: ${{ secrets.K2_CONDA_TOKEN}} + K2_IS_GITHUB_ACTIONS: 1 + K2_IS_FOR_CONDA: 1 + run: | + export K2_BUILD_TYPE=$K2_BUILD_TYPE + ./scripts/build_conda.sh diff --git a/.github/workflows/build-cpu-macos.yml b/.github/workflows/build-cpu-macos.yml new file mode 100644 index 000000000..392683ba8 --- /dev/null +++ b/.github/workflows/build-cpu-macos.yml @@ -0,0 +1,123 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# refer to https://github.com/actions/starter-workflows/pull/47/files + +name: build-cpu-macos + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +env: + BUILD_TYPE: Release + +jobs: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: macos-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build-cpu-macos: + if: github.event.label.name == 'ready' || github.event_name == 'push' + needs: generate_build_matrix + runs-on: macos-10.15 + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - uses: szenius/set-timezone@v1.0 + with: + timezoneLinux: "Asia/Shanghai" + + - name: Display date and time + run: date + + - name: Display clang version + run: | + clang --version + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Install PyTorch ${{ matrix.torch }} + shell: bash + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq wheel twine dataclasses + python3 -m pip install -qq torch==${{ matrix.torch }} + + python3 -c "import torch; print('torch version:', torch.__version__)" + + - name: Build k2 + shell: bash + run: | + pwd + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. 
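+          # -DK2_WITH_CUDA=OFF configures a CPU-only build; $BUILD_TYPE
+          # expands to "Release" from the workflow-level env above.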
+ cat k2/csrc/version.h + cat CMakeCache.txt + + make VERBOSE=1 -j2 + + - name: Run tests + shell: bash + run: | + cd build + ctest --output-on-failure + + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" + export K2_MAKE_ARGS="-j2" + python3 setup.py bdist_wheel + ls -lh dist/ + ls -lh build/* + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-10.15-cpu + path: dist/*.whl diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu-ubuntu.yml similarity index 55% rename from .github/workflows/build-cpu.yml rename to .github/workflows/build-cpu-ubuntu.yml index 173e59ac8..3cd7ec443 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu-ubuntu.yml @@ -16,7 +16,7 @@ # refer to https://github.com/actions/starter-workflows/pull/47/files -name: build-cpu +name: build-cpu-ubuntu on: push: @@ -29,52 +29,31 @@ env: BUILD_TYPE: Release jobs: - build-cpu: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build-cpu-ubuntu: if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} + needs: generate_build_matrix + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: - os: [ubuntu-18.04, macos-10.15] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, 1.11.x - # Python 3.10 is for PyTorch 1.11.x - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] - exclude: - - python-version: "3.10" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - python-version: "3.10" - torch: "1.5.1" - - python-version: "3.10" - torch: "1.6.0" - - python-version: "3.10" - torch: "1.7.0" - - python-version: "3.10" - torch: "1.7.1" - - python-version: "3.10" - torch: "1.8.0" - - python-version: "3.10" - torch: "1.8.1" - - python-version: "3.10" - torch: "1.9.0" - - python-version: "3.10" - torch: "1.9.1" - - python-version: "3.10" - torch: "1.10.0" - - python-version: "3.10" - torch: "1.10.1" - - python-version: "3.10" - torch: "1.10.2" - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: 3.9 - torch: "1.5.1" - - python-version: 3.9 - torch: "1.6.0" - - python-version: 3.9 - torch: "1.7.0" - - python-version: 3.6 # exclude Python 3.6 for [1.11.0] - torch: "1.11.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -90,17 +69,11 @@ jobs: run: date - name: Install GCC 7 - if: startsWith(matrix.os, 'ubuntu') run: | sudo apt-get install -y gcc-7 g++-7 echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV - - name: Display clang version - if: startsWith(matrix.os, 'macos') - run: | - clang --version - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 
with: @@ -110,26 +83,15 @@ jobs: run: python -c "import sys; print(sys.version)" - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'ubuntu') shell: bash run: | python3 -m pip install -qq --upgrade pip - python3 -m pip install -qq wheel twine typing_extensions + python3 -m pip install -qq wheel twine typing_extensions dataclasses python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html python3 -c "import torch; print('torch version:', torch.__version__)" - - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'macos') - shell: bash - run: | - python3 -m pip install -qq --upgrade pip - python3 -m pip install -qq wheel twine - python3 -m pip install -qq torch==${{ matrix.torch }} - - python3 -c "import torch; print('torch version:', torch.__version__)" - - - name: Configure CMake + - name: Build k2 shell: bash run: | pwd @@ -137,8 +99,17 @@ jobs: cd build cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. cat k2/csrc/version.h + cat CMakeCache.txt - - name: Build k2 + make VERBOSE=1 -j2 + + - name: Run tests + shell: bash + run: | + cd build + ctest --output-on-failure + + - name: Build wheel shell: bash run: | export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" @@ -150,5 +121,5 @@ jobs: - name: Upload Wheel uses: actions/upload-artifact@v2 with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}-cpu + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu path: dist/*.whl diff --git a/.github/workflows/windows.yml b/.github/workflows/build-cpu-windows.yml similarity index 72% rename from .github/workflows/windows.yml rename to .github/workflows/build-cpu-windows.yml index 7890fb805..e622e9a37 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/build-cpu-windows.yml @@ -15,7 +15,7 @@ # limitations under the License. 
-name: build-windows +name: build-cpu-windows on: push: @@ -28,26 +28,32 @@ env: BUILD_TYPE: Release jobs: - build-windows: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + build-cpu-windows: # see https://github.com/actions/virtual-environments/blob/win19/20210525.0/images/win/Windows2019-Readme.md if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} + needs: generate_build_matrix + runs-on: windows-2019 strategy: fail-fast: false matrix: - os: [windows-2019] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0 - python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: 3.9 - torch: "1.5.1" - - python-version: 3.9 - torch: "1.6.0" - - python-version: 3.9 - torch: "1.7.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v2 @@ -68,8 +74,8 @@ jobs: - name: Install PyTorch ${{ matrix.torch }} run: | - pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html - pip3 install -qq wheel twine dataclasses numpy typing_extensions + pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy + pip3 install -qq wheel twine dataclasses typing_extensions python3 -m torch.utils.collect_env @@ -85,18 +91,17 @@ jobs: cd build_release cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. ls -lh + cat k2/csrc/version.h + cat CMakeCache.txt - name: Build k2 - run: | - cd build_release - cmake --build . --target _k2 --config Release - - - name: Display generated files shell: bash run: | cd build_release - ls -lh bin/*/* + cmake --build . --target _k2 --config Release -- -m + cmake --build . --target ALL_BUILD --config Release ls -lh lib/*/* + ls -lh bin/*/* - name: Build wheel shell: bash @@ -106,15 +111,15 @@ jobs: ls -lh dist/ pip install ./dist/*.whl - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}-cpu - path: dist/*.whl - - - name: Run C++ tests + - name: Run tests + shell: bash run: | cd build_release - cmake --build . --target ALL_BUILD --config Release # disable python tests for k2host ctest -C Release --output-on-failure -E host + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-2019-cpu + path: dist/*.whl diff --git a/.github/workflows/build-cuda-ubuntu.yml b/.github/workflows/build-cuda-ubuntu.yml new file mode 100644 index 000000000..adddf92e1 --- /dev/null +++ b/.github/workflows/build-cuda-ubuntu.yml @@ -0,0 +1,147 @@ +# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# refer to https://github.com/actions/starter-workflows/pull/47/files + +name: build-cuda-ubuntu + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +env: + BUILD_TYPE: Release + +jobs: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py --enable-cuda + MATRIX=$(python scripts/github_actions/generate_build_matrix.py --enable-cuda --test-only-latest-torch) + echo "::set-output name=matrix::${MATRIX}" + + build-cuda-ubuntu: + if: github.event.label.name == 'ready' || github.event_name == 'push' + needs: generate_build_matrix + runs-on: ubuntu-18.04 + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - uses: szenius/set-timezone@v1.0 + with: + timezoneLinux: "Asia/Shanghai" + + - name: Display date and time + run: date + + - name: Install CUDA Toolkit ${{ matrix.cuda }} + env: + cuda: ${{ matrix.cuda }} + run: | + source ./scripts/github_actions/install_cuda.sh + echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV + echo "${CUDA_HOME}/bin" >> $GITHUB_PATH + echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV + shell: bash + + - name: Display NVCC version + run: | + which nvcc + nvcc --version + + - name: Install GCC 7 + run: | + sudo apt-get install -y gcc-7 g++-7 + echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV + echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV + echo "CUDAHOSTCXX=/usr/bin/g++-7" >> $GITHUB_ENV + + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Install PyTorch ${{ matrix.torch }} + env: + cuda: ${{ matrix.cuda }} + torch: ${{ matrix.torch }} + shell: bash + run: | + python3 -m pip install -q --upgrade pip + python3 -m pip install -q wheel twine typing_extensions + python3 -m pip install -q bs4 requests tqdm + + ./scripts/github_actions/install_torch.sh + python3 -c "import torch; print('torch version:', torch.__version__)" + + - name: Download cudnn 8.0 + env: + cuda: ${{ matrix.cuda }} + run: | + ./scripts/github_actions/install_cudnn.sh + + - name: Configure CMake + shell: bash + run: | + pwd + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. 
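+          # No -DK2_WITH_CUDA=OFF here, so CMake configures the CUDA build
+          # against the toolkit and cuDNN installed in the earlier steps.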
+ cat k2/csrc/version.h + cat CMakeCache.txt + + - name: Build k2 + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE" + export K2_MAKE_ARGS="-j2" + python3 setup.py bdist_wheel + ls -lh dist/ + ls -lh build/* + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: gcc-7-cuda-${{ matrix.cuda }}-torch-${{ matrix.torch }}-python-${{ matrix.python-version }} + path: dist/*.whl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index 374c8bccf..000000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# refer to https://github.com/actions/starter-workflows/pull/47/files - -name: build - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -env: - BUILD_TYPE: Release - -jobs: - build: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-18.04] - # from https://download.pytorch.org/whl/torch_stable.html - # Note: There are no torch versions for CUDA 11.2 - # - # 1.11.x supports: cuda10.2 (default), 11.3, 11.5 - # 1.10.x supports: cuda10.2 (default), 11.1, 11.3 - # 1.9.x supports: cuda10.2 (default), 11.1 - # PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1 - # PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0 - # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default) - # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default) - # Other PyTorch versions are not tested - # CUDA 10.1 is for 1.5.x, 1.6.0, 1.7.x, 1.8.x - # CUDA 11.1 is for torch 1.8.x, 1.9.x, 1.10.x - # CUDA 11.3 is for torch 1.10, 1.11.x - # CUDA 11.5 is for torch 1.11.x - cuda: ["10.1", "10.2", "11.0", "11.1", "11.3", "11.5"] - gcc: ["7"] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] - # - # torch 1.11.x does not support Python 3.6 - # From torch 1.11.x, it supports Python 3.10 - # Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10.x, 11.x - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] - exclude: - - cuda: "11.5" # exclude 11.5 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - cuda: "11.5" - torch: "1.5.1" - - cuda: "11.5" - torch: "1.6.0" - - cuda: "11.5" - torch: "1.7.0" - - cuda: "11.5" - torch: "1.7.1" - - cuda: "11.5" - torch: "1.8.0" - - cuda: "11.5" - torch: "1.8.1" - - cuda: "11.5" - torch: "1.9.0" - - cuda: "11.5" - torch: "1.9.1" - - cuda: "11.5" - torch: "1.10.0" - - cuda: "11.5" - torch: "1.10.1" - - cuda: "11.5" - torch: "1.10.2" - - cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] - torch: "1.5.0" - - cuda: "11.3" - torch: "1.5.1" - - cuda: "11.3" - torch: "1.6.0" - - cuda: "11.3" - torch: "1.7.0" - - cuda: "11.3" - 
torch: "1.7.1" - - cuda: "11.3" - torch: "1.8.0" - - cuda: "11.3" - torch: "1.8.1" - - cuda: "11.3" - torch: "1.9.0" - - cuda: "11.3" - torch: "1.9.1" - - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] - torch: "1.5.0" - - cuda: "11.0" - torch: "1.5.1" - - cuda: "11.0" - torch: "1.6.0" - - cuda: "11.0" - torch: "1.8.0" - - cuda: "11.0" - torch: "1.8.1" - - cuda: "11.0" - torch: "1.9.0" - - cuda: "11.0" - torch: "1.9.1" - - cuda: "11.0" - torch: "1.10.0" - - cuda: "11.0" - torch: "1.10.1" - - cuda: "11.0" - torch: "1.10.2" - - cuda: "11.0" - torch: "1.11.0" - - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.11.0] - torch: "1.5.0" - - cuda: "11.1" - torch: "1.5.1" - - cuda: "11.1" - torch: "1.6.0" - - cuda: "11.1" - torch: "1.7.0" - - cuda: "11.1" - torch: "1.7.1" - - cuda: "11.1" - torch: "1.11.0" - - cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10.0, 10.1, 10.2, 1.11.0] - torch: "1.9.0" - - cuda: "10.1" - torch: "1.9.1" - - cuda: "10.1" - torch: "1.10.0" - - cuda: "10.1" - torch: "1.10.1" - - cuda: "10.1" - torch: "1.10.2" - - cuda: "10.1" - torch: "1.11.0" - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: 3.9 - torch: "1.5.1" - - python-version: 3.9 - torch: "1.6.0" - - python-version: 3.9 - torch: "1.7.0" - - python-version: "3.10" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - python-version: "3.10" - torch: "1.5.1" - - python-version: "3.10" - torch: "1.6.0" - - python-version: "3.10" - torch: "1.7.0" - - python-version: "3.10" - torch: "1.7.1" - - python-version: "3.10" - torch: "1.8.0" - - python-version: "3.10" - torch: "1.8.1" - - python-version: "3.10" - torch: "1.9.0" - - python-version: "3.10" - torch: "1.9.1" - - python-version: "3.10" - torch: "1.10.0" - - python-version: "3.10" - torch: "1.10.1" - - python-version: "3.10" - torch: "1.10.2" - - python-version: "3.6" # exclude Python 3.6 for [1.11.0] - torch: "1.11.0" - - steps: - # refer to https://github.com/actions/checkout - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - uses: szenius/set-timezone@v1.0 - with: - timezoneLinux: "Asia/Shanghai" - - - name: Display date and time - run: date - - - name: Install CUDA Toolkit ${{ matrix.cuda }} - env: - cuda: ${{ matrix.cuda }} - run: | - source ./scripts/github_actions/install_cuda.sh - echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV - echo "${CUDA_HOME}/bin" >> $GITHUB_PATH - echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV - shell: bash - - - name: Display NVCC version - run: | - which nvcc - nvcc --version - - - name: Install GCC ${{ matrix.gcc }} - run: | - sudo apt-get install -y gcc-${{ matrix.gcc }} g++-${{ matrix.gcc }} - echo "CC=/usr/bin/gcc-${{ matrix.gcc }}" >> $GITHUB_ENV - echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV - echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV - - - name: Install git lfs - run: | - sudo apt-get install -y git-lfs - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Install PyTorch ${{ matrix.torch }} - env: - cuda: ${{ matrix.cuda }} - torch: ${{ matrix.torch }} - shell: bash - run: | - python3 -m pip install -q --upgrade pip - 
python3 -m pip install -q wheel twine typing_extensions - python3 -m pip install -q bs4 requests tqdm - - ./scripts/github_actions/install_torch.sh - python3 -c "import torch; print('torch version:', torch.__version__)" - - - name: Download cudnn 8.0 - env: - cuda: ${{ matrix.cuda }} - run: | - ./scripts/github_actions/install_cudnn.sh - - - name: Configure CMake - shell: bash - run: | - pwd - mkdir build - cd build - cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. - cat k2/csrc/version.h - - - name: Build k2 - shell: bash - run: | - export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE" - export K2_MAKE_ARGS="-j2" - python3 setup.py bdist_wheel - ls -lh dist/ - ls -lh build/* - - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: gcc-${{ matrix.gcc }}-cuda-${{ matrix.cuda }}-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }} - path: dist/*.whl diff --git a/.github/workflows/build_conda.yml b/.github/workflows/build_conda.yml deleted file mode 100644 index b8107d0c6..000000000 --- a/.github/workflows/build_conda.yml +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# refer to https://github.com/actions/starter-workflows/pull/47/files - -name: build_conda_cuda - -on: - push: - branches: - - conda-cuda - -env: - K2_BUILD_TYPE: Release - -jobs: - build_conda_cuda: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-18.04] - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] - cuda: ["10.1", "10.2", "11.0", "11.1", "11.3", "11.5"] - # from https://download.pytorch.org/whl/torch_stable.html - # Note: There are no torch versions for CUDA 11.2 - # - # 1.11.x supports: cuda10.2 (default), 11.3, 11.5 - # PyTorch 1.10.x supports: 10.2 (default), 11.1, 11.3 - # PyTorch 1.9.x supports: 10.2 (default), 11.1 - # PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1 - # PyTorch 1.8.0 supports: cuda 10.1, 10.2 (default), 11.1 - # PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0, 9.2 (not included in this setup) - # PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) - # PyTorch 1.5.x supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup) - # - # PyTorch 1.11.x supports Python 3.10 - # PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, and 1.11.x support 3.6, 3.7, 3.8, 3.9 - # PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8 - # - # Other PyTorch versions are not tested - # - # torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1"] - # 1.5.x is removed because there are compilation errors. 
- # See - # https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true - # and - # https://github.com/NVIDIA/apex/issues/805 - torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] - exclude: - - cuda: "11.5" # exclude cuda 11.5 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - cuda: "11.5" - torch: "1.5.1" - - cuda: "11.5" - torch: "1.6.0" - - cuda: "11.5" - torch: "1.7.0" - - cuda: "11.5" - torch: "1.7.1" - - cuda: "11.5" - torch: "1.8.0" - - cuda: "11.5" - torch: "1.8.1" - - cuda: "11.5" - torch: "1.9.0" - - cuda: "11.5" - torch: "1.9.1" - - cuda: "11.5" - torch: "1.10.0" - - cuda: "11.5" - torch: "1.10.1" - - cuda: "11.5" - torch: "1.10.2" - - cuda: "11.3" # exclude cuda 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1] - torch: "1.5.0" - - cuda: "11.3" - torch: "1.5.1" - - cuda: "11.3" - torch: "1.6.0" - - cuda: "11.3" - torch: "1.7.0" - - cuda: "11.3" - torch: "1.7.1" - - cuda: "11.3" - torch: "1.8.0" - - cuda: "11.3" - torch: "1.8.1" - - cuda: "11.3" - torch: "1.9.0" - - cuda: "11.3" - torch: "1.9.1" - # - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] - # torch: "1.5.0" - # - cuda: "11.0" - # torch: "1.5.1" - - cuda: "11.0" - torch: "1.6.0" - - cuda: "11.0" - torch: "1.8.0" - - cuda: "11.0" - torch: "1.8.1" - - cuda: "11.0" - torch: "1.9.0" - - cuda: "11.0" - torch: "1.9.1" - - cuda: "11.0" - torch: "1.10.0" - - cuda: "11.0" - torch: "1.10.1" - - cuda: "11.0" - torch: "1.10.2" - - cuda: "11.0" - torch: "1.11.0" - # - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.11.0] - # torch: "1.5.0" - # - cuda: "11.1" - # torch: "1.5.1" - - cuda: "11.1" - torch: "1.6.0" - - cuda: "11.1" - torch: "1.7.0" - - cuda: "11.1" - torch: "1.7.1" - - cuda: "11.1" - torch: "1.11.0" - - cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0] - torch: "1.9.0" - - cuda: "10.1" - torch: "1.9.1" - - cuda: "10.1" - torch: "1.10.0" - - cuda: "10.1" - torch: "1.10.1" - - cuda: "10.1" - torch: "1.10.2" - - cuda: "10.1" - torch: "1.11.0" - - python-version: "3.9" # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: "3.9" - torch: "1.5.1" - - python-version: "3.9" - torch: "1.6.0" - - python-version: "3.9" - torch: "1.7.0" - - python-version: "3.10" # exclude Python 3.10 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.5.0" - - python-version: "3.10" - torch: "1.5.1" - - python-version: "3.10" - torch: "1.6.0" - - python-version: "3.10" - torch: "1.7.0" - - python-version: "3.10" - torch: "1.7.1" - - python-version: "3.10" - torch: "1.8.0" - - python-version: "3.10" - torch: "1.8.1" - - python-version: "3.10" - torch: "1.9.0" - - python-version: "3.10" - torch: "1.9.1" - - python-version: "3.10" - torch: "1.10.0" - - python-version: "3.10" - torch: "1.10.1" - - python-version: "3.10" - torch: "1.10.2" - - python-version: "3.6" # exclude Python 3.6 for [1.11.0] - torch: "1.11.0" - - steps: - # refer to https://github.com/actions/checkout - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Install CUDA Toolkit ${{ matrix.cuda }} - shell: bash -l {0} - env: - cuda: ${{ matrix.cuda }} - run: | - source ./scripts/github_actions/install_cuda.sh - echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV - echo "${CUDA_HOME}/bin" >> $GITHUB_PATH - echo 
"LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV - - - name: Display NVCC version - shell: bash -l {0} - run: | - which nvcc - nvcc --version - - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: ${{ matrix.python-version }} - activate-environment: k2 - - - name: Display Python version - shell: bash -l {0} - run: | - python3 -c "import sys; print(sys.version)" - which python3 - - - name: Install conda dependencies - shell: bash -l {0} - run: | - conda install -y -q anaconda-client - conda install -y -q conda-build - conda install -y -q bs4 requests tqdm - conda install -y -q -c pytorch -c conda-forge pytorch=${{ matrix.torch }} cudatoolkit=${{ matrix.cuda }} - - - name: Display conda info - shell: bash -l {0} - run: | - which conda - conda env list - conda info - nproc - - - name: Install git lfs - run: | - sudo apt-get install -y git-lfs - - - name: Download cudnn 8.0 - shell: bash -l {0} - env: - cuda: ${{ matrix.cuda }} - run: | - ./scripts/github_actions/install_cudnn.sh - - - name: Build k2 - shell: bash -l {0} - env: - K2_CUDA_VERSION: ${{ matrix.cuda }} - K2_PYTHON_VERSION: ${{ matrix.python-version}} - K2_TORCH_VERSION: ${{ matrix.torch }} - K2_CONDA_TOKEN: ${{ secrets.K2_CONDA_TOKEN}} - K2_IS_GITHUB_ACTIONS: 1 - K2_IS_FOR_CONDA: 1 - run: | - export K2_BUILD_TYPE=$K2_BUILD_TYPE - ./scripts/build_conda.sh diff --git a/.github/workflows/nightly-cpu-macos.yml b/.github/workflows/nightly-cpu-macos.yml new file mode 100644 index 000000000..0b6d773cc --- /dev/null +++ b/.github/workflows/nightly-cpu-macos.yml @@ -0,0 +1,111 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: nightly_cpu_macos + +on: + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 14:00 UTC time every day + - cron: "0 14 * * *" + +env: + BUILD_TYPE: Release + +jobs: + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + nightly_cpu_macos: + needs: generate_build_matrix + runs-on: macos-10.15 + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Display date and time + run: date + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Display GCC version + run: | + gcc --version + + - name: Display clang version + run: | + clang --version + + - name: Install PyTorch ${{ matrix.torch }} + shell: bash + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq wheel twine + python3 -m pip install -qq torch==${{ matrix.torch }} + python3 -m pip install --upgrade numpy + + - name: Build pip packages + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" + export K2_MAKE_ARGS="-j2" + python3 setup.py bdist_wheel + ls -lh dist/ + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-10.15 + path: dist/*.whl + + - name: Copy wheels to k2-fsa.org + run: | + user=${{ secrets.K2_USERNAME }} + server=${{ secrets.K2_HOST }} + port=${{ secrets.K2_PORT }} + echo "${{ secrets.K2_KEY }}" > id_rsa && chmod 600 id_rsa + scp -P $port -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i id_rsa dist/*.whl $user@$server:~/nightly/whl + rm id_rsa diff --git a/.github/workflows/nightly-cpu.yml b/.github/workflows/nightly-cpu-ubuntu.yml similarity index 55% rename from .github/workflows/nightly-cpu.yml rename to .github/workflows/nightly-cpu-ubuntu.yml index 8fdc6d0a6..b371af8d5 100644 --- a/.github/workflows/nightly-cpu.yml +++ b/.github/workflows/nightly-cpu-ubuntu.yml @@ -14,12 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: nightly-cpu +name: nightly_cpu_ubuntu on: - push: - branches: - - nightly-cpu schedule: # minute (0-59) # hour (0-23) @@ -33,54 +30,30 @@ env: BUILD_TYPE: Release jobs: - nightly-cpu: - runs-on: ${{ matrix.os }} + generate_build_matrix: + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + nightly_cpu_ubuntu: + needs: generate_build_matrix + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: - os: [ubuntu-18.04, macos-10.15] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x, 1.11.x - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] - torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2", "1.11.0"] - exclude: - - python-version: "3.9" # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.4.0" - - python-version: "3.9" - torch: "1.5.0" - - python-version: "3.9" - torch: "1.5.1" - - python-version: "3.9" - torch: "1.6.0" - - python-version: "3.9" - torch: "1.7.0" - - python-version: "3.10" # exclude Python 3.10 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2] - torch: "1.4.0" - - python-version: "3.10" - torch: "1.5.0" - - python-version: "3.10" - torch: "1.5.1" - - python-version: "3.10" - torch: "1.6.0" - - python-version: "3.10" - torch: "1.7.0" - - python-version: "3.10" - torch: "1.7.1" - - python-version: "3.10" - torch: "1.8.0" - - python-version: "3.10" - torch: "1.8.1" - - python-version: "3.10" - torch: "1.9.0" - - python-version: "3.10" - torch: "1.9.1" - - python-version: "3.10" - torch: "1.10.0" - - python-version: "3.10" - torch: "1.10.1" - - python-version: "3.10" - torch: "1.10.2" - - python-version: "3.6" # exclude Python 3.6 for [1.11.0] - torch: "1.11.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v2 @@ -109,13 +82,7 @@ jobs: run: | gcc --version - - name: Display clang version - if: startsWith(matrix.os, 'macos') - run: | - clang --version - - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'ubuntu') shell: bash run: | python3 -m pip install -qq --upgrade pip @@ -125,15 +92,6 @@ jobs: python3 -c "import torch; print('torch version:', torch.__version__)" - - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'macos') - shell: bash - run: | - python3 -m pip install -qq --upgrade pip - python3 -m pip install -qq wheel twine - python3 -m pip install -qq torch==${{ matrix.torch }} - python3 -m pip install --upgrade numpy - - name: Build pip packages shell: bash run: | @@ -145,7 +103,7 @@ jobs: - name: Upload Wheel uses: actions/upload-artifact@v2 with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }} + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04 path: dist/*.whl - name: Copy wheels to k2-fsa.org diff --git a/.github/workflows/nightly-windows.yml b/.github/workflows/nightly-cpu-windows.yml similarity index 76% rename from .github/workflows/nightly-windows.yml rename to .github/workflows/nightly-cpu-windows.yml index 42fa8b7bf..f23e6c801 
100644 --- a/.github/workflows/nightly-windows.yml +++ b/.github/workflows/nightly-cpu-windows.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: nightly-windows +name: nightly_cpu_windows on: schedule: @@ -30,24 +30,29 @@ env: BUILD_TYPE: Release jobs: - nightly-windows: - runs-on: ${{ matrix.os }} + generate_build_matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python scripts/github_actions/generate_build_matrix.py + MATRIX=$(python scripts/github_actions/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + nightly_cpu_windows: + needs: generate_build_matrix + runs-on: windows-2019 strategy: fail-fast: false matrix: - os: [windows-2019] - # Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0 - python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"] - exclude: - - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0] - torch: "1.5.0" - - python-version: 3.9 - torch: "1.5.1" - - python-version: 3.9 - torch: "1.6.0" - - python-version: 3.9 - torch: "1.7.0" + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v2 @@ -68,7 +73,7 @@ jobs: - name: Install PyTorch ${{ matrix.torch }} run: | - pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy pip3 install -qq wheel twine dataclasses typing_extensions python3 -m torch.utils.collect_env @@ -85,17 +90,20 @@ jobs: cd build_release cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. ls -lh + cat k2/csrc/version.h + cat CMakeCache.txt - name: Build k2 run: | cd build_release - cmake --build . --target _k2 --config Release + cmake --build . 
--target _k2 --config Release -- -m + ls -lh lib/*/* + ls -lh bin/*/* - name: Display generated files shell: bash run: | cd build_release - ls -lh bin/*/* ls -lh lib/*/* - name: Build wheel @@ -108,7 +116,7 @@ jobs: - name: Upload Wheel uses: actions/upload-artifact@v2 with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}-cpu + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-cpu path: dist/*.whl - name: Copy wheels to k2-fsa.org diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly-cuda-ubuntu.yml similarity index 96% rename from .github/workflows/nightly.yml rename to .github/workflows/nightly-cuda-ubuntu.yml index 0cc6cb1be..7af8d657d 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly-cuda-ubuntu.yml @@ -1,9 +1,6 @@ -name: nightly +name: nightly-cuda-ubuntu on: - push: - branches: - - nightly schedule: # minute (0-59) # hour (0-23) @@ -18,11 +15,10 @@ env: jobs: nightly: - runs-on: ${{ matrix.os }} + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: - os: [ubuntu-18.04] cuda: ["10.1", "10.2", "11.0"] gcc: ["7"] torch: ["1.7.1"] diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index dbe73a0c8..2a56f6a43 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -33,10 +33,10 @@ jobs: fail-fast: false matrix: os: [ubuntu-18.04] - cuda: ["11.1"] - gcc: ["5"] - torch: ["1.9.0"] - python-version: [3.9] + cuda: ["10.2"] + gcc: ["7"] + torch: ["1.11.0"] + python-version: ["3.10"] build_type: ["Release", "Debug"] steps: diff --git a/.github/workflows/wheel-cpu.yml b/.github/workflows/wheel-cpu-macos.yml similarity index 94% rename from .github/workflows/wheel-cpu.yml rename to .github/workflows/wheel-cpu-macos.yml index 007688c8a..74bdfb496 100644 --- a/.github/workflows/wheel-cpu.yml +++ b/.github/workflows/wheel-cpu-macos.yml @@ -1,11 +1,11 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -name: Publish to PyPI macOS +name: Publish to PyPI - macOS CPU on: push: - branches: - - wheel + tags: + - '*' env: BUILD_TYPE: Release @@ -34,7 +34,6 @@ jobs: run: python -c "import sys; print(sys.version)" - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'macos') shell: bash run: | python3 -m pip install -qq --upgrade pip @@ -46,6 +45,7 @@ jobs: shell: bash env: K2_IS_FOR_PYPI: 1 + K2_IS_STABLE: 1 run: | tag=$(python3 -c "import sys; print(''.join(sys.version[:3].split('.')))") export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE" diff --git a/.github/workflows/wheel-cpu-stable.yml b/.github/workflows/wheel-cpu-windows.yml similarity index 66% rename from .github/workflows/wheel-cpu-stable.yml rename to .github/workflows/wheel-cpu-windows.yml index a87ee808b..40ce800d6 100644 --- a/.github/workflows/wheel-cpu-stable.yml +++ b/.github/workflows/wheel-cpu-windows.yml @@ -1,22 +1,23 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -name: Publish to PyPI macOS - stable +name: Publish to PyPI - Windows CPU on: push: - branches: - - wheel-stable + tags: + - '*' env: BUILD_TYPE: Release jobs: - PyPI-macos-cpu: + PyPI-windows-cpu: + if: ${{ false }} # Disable it at present. 
Users can install it from https://k2-fsa.org/nightly/index.html runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: [macos-10.15] + os: [windows-2019] torch: ["1.7.1"] python-version: [3.6, 3.7, 3.8] @@ -25,6 +26,10 @@ jobs: with: fetch-depth: 0 + # see https://github.com/microsoft/setup-msbuild + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v1.0.2 + - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -34,12 +39,11 @@ jobs: run: python -c "import sys; print(sys.version)" - name: Install PyTorch ${{ matrix.torch }} - if: startsWith(matrix.os, 'macos') - shell: bash run: | - python3 -m pip install -qq --upgrade pip - python3 -m pip install -q wheel twine typing_extensions - python3 -m pip install -qq torch==${{ matrix.torch }} + pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy + pip3 install -qq wheel twine dataclasses typing_extensions + + python3 -m torch.utils.collect_env - name: Build pip packages @@ -49,15 +53,14 @@ jobs: K2_IS_STABLE: 1 run: | tag=$(python3 -c "import sys; print(''.join(sys.version[:3].split('.')))") - export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE" - export K2_MAKE_ARGS="-j2" + export K2_CMAKE_ARGS="-DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=$BUILD_TYPE" python3 setup.py bdist_wheel --python-tag=py${tag} ls -lh dist/ - name: Upload Wheel uses: actions/upload-artifact@v2 with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}-cpu + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-cpu path: dist/*.whl - name: Publish wheels to PyPI diff --git a/.github/workflows/wheel-stable.yml b/.github/workflows/wheel-cuda-ubuntu.yml similarity index 97% rename from .github/workflows/wheel-stable.yml rename to .github/workflows/wheel-cuda-ubuntu.yml index f142c2910..7888d9fa0 100644 --- a/.github/workflows/wheel-stable.yml +++ b/.github/workflows/wheel-cuda-ubuntu.yml @@ -1,17 +1,17 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -name: Publish to PyPI - stable +name: Publish to PyPI - Ubuntu CUDA on: push: - branches: - - wheel-stable + tags: + - '*' env: BUILD_TYPE: Release jobs: - PyPI: + PyPI_CUDA_Ubuntu: runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml deleted file mode 100644 index 74c46595a..000000000 --- a/.github/workflows/wheel.yml +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -name: Publish to PyPI - -on: - push: - branches: - - wheel - -env: - BUILD_TYPE: Release - -jobs: - PyPI: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-18.04] - cuda: ["10.1"] - gcc: ["5"] - torch: ["1.7.1"] - python-version: [3.6, 3.7, 3.8] - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Install CUDA Toolkit ${{ matrix.cuda }} - env: - cuda: ${{ matrix.cuda }} - run: | - source ./scripts/github_actions/install_cuda.sh - echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV - echo "${CUDA_HOME}/bin" >> $GITHUB_PATH - echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV - shell: bash - - - name: Display NVCC version - run: | - which nvcc - nvcc 
--version - - - name: Install GCC ${{ matrix.gcc }} - run: | - sudo apt-get install -y gcc-${{ matrix.gcc }} g++-${{ matrix.gcc }} - echo "CC=/usr/bin/gcc-${{ matrix.gcc }}" >> $GITHUB_ENV - echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV - echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV - - - name: Install PyTorch ${{ matrix.torch }} - env: - cuda: ${{ matrix.cuda }} - torch: ${{ matrix.torch }} - shell: bash - run: | - python3 -m pip install --upgrade pip - python3 -m pip install wheel twine typing_extensions - python3 -m pip install bs4 requests tqdm - - ./scripts/github_actions/install_torch.sh - python3 -c "import torch; print('torch version:', torch.__version__)" - - - name: Install git lfs - run: | - sudo apt-get install -y git-lfs - - - name: Download cudnn 8.0 - env: - cuda: ${{ matrix.cuda }} - run: | - ./scripts/github_actions/install_cudnn.sh - - - name: Build pip packages - shell: bash - env: - K2_IS_FOR_PYPI: 1 - run: | - tag=$(python3 -c "import sys; print(''.join(sys.version[:3].split('.')))") - export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE" - export K2_MAKE_ARGS="-j2" - python3 setup.py bdist_wheel --python-tag=py${tag} - ls -lh dist/ - - - name: Publish wheels to PyPI - env: - TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} - TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} - run: | - twine upload dist/k2-*.whl diff --git a/CMakeLists.txt b/CMakeLists.txt index 649d57136..7adbede8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,11 +73,12 @@ option(BUILD_SHARED_LIBS "Whether to build shared or static lib" ON) option(K2_USE_PYTORCH "Whether to build with PyTorch" ON) option(K2_ENABLE_BENCHMARK "Whether to enable benchmark" ON) option(K2_WITH_CUDA "Whether to build k2 with CUDA" ${_K2_WITH_CUDA}) +option(K2_ENABLE_NVTX "Whether to build k2 with the NVTX library" ON) -# If K2_WITH_CUDA is ON, then K2_ENABLE_NVTX has a default value ON -# If K2_WITH_CUDA is OFF, then K2_ENABLE_NVTX is set to OFF -include(CMakeDependentOption) -cmake_dependent_option(K2_ENABLE_NVTX "Whether to build with the NVTX library" ON K2_WITH_CUDA OFF) +if(NOT K2_WITH_CUDA) + message(STATUS "Set K2_ENABLE_NVTX to OFF since K2_WITH_CUDA is OFF") + set(K2_ENABLE_NVTX OFF CACHE BOOL "" FORCE) +endif() if(NOT K2_USE_PYTORCH) message(FATAL_ERROR "\ @@ -210,7 +211,16 @@ if(K2_WITH_CUDA) # https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/ set(K2_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75) if(CUDA_VERSION VERSION_GREATER "11.0") - list(APPEND K2_COMPUTE_ARCH_CANDIDATES 80 86) + list(APPEND K2_COMPUTE_ARCH_CANDIDATES 80 86) + if(WIN32) + # To fix the following warning from PyTorch: + # c10/util/TypeCast.h(39): warning : calling a constexpr __host__ function from a + # __host__ __device__ function is not allowed. 
The experimental flag '--expt-relaxed-constexpr' + # can be used to allow this + string(APPEND CMAKE_CUDA_FLAGS " --expt-relaxed-constexpr ") + endif() + + string(APPEND CMAKE_CUDA_FLAGS " -Wno-deprecated-gpu-targets ") endif() message(STATUS "K2_COMPUTE_ARCH_CANDIDATES ${K2_COMPUTE_ARCH_CANDIDATES}") @@ -260,6 +270,10 @@ if(K2_WITH_CUDA) add_definitions(-DK2_WITH_CUDA) endif() +if(WIN32) + add_definitions(-DNOMINMAX) # Otherwise, std::max() and std::min() won't work +endif() + if(K2_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0) # CUB is included in CUDA toolkit 11.0 and above include(cub) @@ -271,39 +285,75 @@ endif() include(googletest) -if(K2_WITH_CUDA) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --compiler-options -Wall --compiler-options -Wno-unknown-pragmas --compiler-options -Wno-strict-overflow") +if(K2_WITH_CUDA AND NOT WIN32) + string(APPEND CMAKE_CUDA_FLAGS " --compiler-options -Wall ") + string(APPEND CMAKE_CUDA_FLAGS " --compiler-options -Wno-strict-overflow ") + string(APPEND CMAKE_CUDA_FLAGS " --compiler-options -Wno-unknown-pragmas ") message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") endif() -if(NOT K2_WITH_CUDA AND NOT WIN32) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable") -endif() if(NOT WIN32) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-strict-overflow") + string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-variable ") + string(APPEND CMAKE_CXX_FLAGS " -Wno-strict-overflow ") endif() if(WIN32) # disable various warnings for MSVC # NOTE: Most of the warnings are from PyTorch C++ APIs + # 4005: macro redefinition + # 4018: signed/unsigned mismatch + # 4067: unexpected tokens following preprocessor directive # 4068: unknown pragma "unroll" - # 4996: "getenv": This function is unsafe - # 4224: conversion from 'int64_t' to 'int32_t', possible loss of data # 4099: type name first seen using 'class' now seen using 'struct' + # 4101: 'identifier' : unreferenced local variable + # 4190: 'identifier1' has C-linkage specified, but returns UDT 'identifier2' which is incompatible with C + # 4224: conversion from 'int64_t' to 'int32_t', possible loss of data + # 4244: conversion from 'const M' to 'const FloatType' + # 4251: 'type' : class 'type1' needs to have dll-interface to be used by clients of class 'type2' # 4267: conversion from 'size_t' to 'I', possible loss of data + # 4275: non - DLL-interface class 'class_1' used as base for DLL-interface class 'class_2' # 4305: truncation from 'int' to 'bool' - # 4244: conversion from 'const M' to 'const FloatType' - # 4624: destructor was implicitly defined as deleted + # 4522: 'class' : multiple assignment operators specified # 4551: function call missing argument list - # 4067: unexpected tokens following preprocessor directive - # 4819: The file contains a character that cannot be presented in the current code page. - # 4005: macro redefinition + # 4624: destructor was implicitly defined as deleted + # 4700: uninitialized local variable 'device' used # 4722: destructor never returns - # 4018: signed/unsigned mismatch - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4996 /wd4224 /wd4099 /wd4267 /wd4305 /wd4244 /wd4624 /wd4551 /wd4067 /wd4819 /wd4005 /wd4722 /wd4018") + # 4819: The file contains a character that cannot be presented in the current code page. 
+ # 4838: conversion from 'type_1' to 'type_2' requires a narrowing conversion + # 4996: "getenv": This function is unsafe + set(disabled_warnings + /wd4005 + /wd4018 + /wd4067 + /wd4068 + /wd4099 + /wd4101 + /wd4190 + /wd4224 + /wd4251 + /wd4244 + /wd4267 + /wd4275 + /wd4305 + /wd4522 + /wd4551 + /wd4624 + /wd4700 + /wd4722 + /wd4819 + /wd4838 + /wd4996 + ) + message(STATUS "Disabled warnings: ${disabled_warnings}") + foreach(w IN LISTS disabled_warnings) + string(APPEND CMAKE_CXX_FLAGS " ${w} ") + string(APPEND CMAKE_CUDA_FLAGS " --compiler-options ${w} ") + endforeach() + string(APPEND CMAKE_CXX_FLAGS " /bigobj ") endif() message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") add_subdirectory(k2) diff --git a/cmake/moderngpu.cmake b/cmake/moderngpu.cmake index efae0a211..a7ef9b291 100644 --- a/cmake/moderngpu.cmake +++ b/cmake/moderngpu.cmake @@ -20,9 +20,9 @@ function(download_moderngpu) include(FetchContent) - # this is the latest commit of modern gpu as of 2020-09-26 - set(moderngpu_URL "https://github.com/moderngpu/moderngpu/archive/2b3985541c8e88a133769598c406c33ddde9d0a5.zip") - set(moderngpu_HASH "SHA256=191546af18cd5fb858ecb561316f3af67537ab16f610fc8f1a5febbffc27755a") + # this is the latest commit of modern gpu as of 2022-04-03 + set(moderngpu_URL "https://github.com/moderngpu/moderngpu/archive/8ec9ac0de8672de7217d014917eedec5317f75f3.zip") + set(moderngpu_HASH "SHA256=1c20ffbb81d6f7bbe6107aaa5ee6d37392677c8a5fc7894935149c3ef0a3c2fb") FetchContent_Declare(moderngpu URL ${moderngpu_URL} diff --git a/docs/source/installation/conda.rst b/docs/source/installation/conda.rst index d4ef221e2..bf351263a 100644 --- a/docs/source/installation/conda.rst +++ b/docs/source/installation/conda.rst @@ -63,7 +63,7 @@ Supported versions .. |conda_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions - |conda_python_versions| diff --git a/docs/source/installation/for_developers.rst b/docs/source/installation/for_developers.rst index 5529191f8..ece6d9167 100644 --- a/docs/source/installation/for_developers.rst +++ b/docs/source/installation/for_developers.rst @@ -1,6 +1,10 @@ For developers ============== +.. hint:: + + It supports Linux (CPU + CUDA), macOS (CPU), and Windows (CPU + CUDA). + This page is for developers and advanced users. It describes how to build k2 and run tests. diff --git a/docs/source/installation/from_source.rst b/docs/source/installation/from_source.rst index 5aac6d406..97a74b793 100644 --- a/docs/source/installation/from_source.rst +++ b/docs/source/installation/from_source.rst @@ -3,6 +3,10 @@ Install from source =================== +.. hint:: + + It supports Linux (CPU + CUDA), macOS (CPU), and Windows (CPU + CUDA). + The following versions of Python, CUDA, and PyTorch are known to work. - |source_python_versions| @@ -15,7 +19,7 @@ The following versions of Python, CUDA, and PyTorch are known to work. .. |source_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. 
|source_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions Before compiling k2, some preparation work has to be done: diff --git a/docs/source/installation/images/README.md b/docs/source/installation/images/README.md index a63890a06..aab295a20 100644 --- a/docs/source/installation/images/README.md +++ b/docs/source/installation/images/README.md @@ -5,7 +5,7 @@ - python_ge_3.6-blue.svg - cuda_ge_10.1-orange.svg -- pytorch_ge_1.5.0-green.svg +- pytorch_ge_1.6.0-green.svg - pypi_python-3.6_3.7_3.8-blue.svg - pypi_cuda-10.1-orange.svg diff --git a/docs/source/installation/images/torch_ge_1.6.0-green.svg b/docs/source/installation/images/torch_ge_1.6.0-green.svg new file mode 100644 index 000000000..d3ece9a17 --- /dev/null +++ b/docs/source/installation/images/torch_ge_1.6.0-green.svg @@ -0,0 +1 @@ +torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 5b6537ba2..4049aecb0 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -50,7 +50,7 @@ below: .. |conda_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions .. |pip_python_versions| image:: ./images/python_ge_3.6-blue.svg @@ -59,7 +59,7 @@ below: .. |pip_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |pip_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. |pip_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions .. |pypi_python_versions| image:: ./images/pypi_python-3.6_3.7_3.8-blue.svg @@ -77,7 +77,7 @@ below: .. |source_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions Reporting issues diff --git a/docs/source/installation/pip.rst b/docs/source/installation/pip.rst index b756263b7..f145e16b6 100644 --- a/docs/source/installation/pip.rst +++ b/docs/source/installation/pip.rst @@ -7,7 +7,7 @@ Install using pip (k2-fsa.org) .. |pip_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg :alt: Supported cuda versions -.. |pip_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg +.. 
|pip_pytorch_versions| image:: ./images/pytorch_ge_1.6.0-green.svg :alt: Supported pytorch versions You can find a list of nightly pre-built diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index f48b895d5..8248be6a9 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -38,6 +38,12 @@ add_library(k2_nvtx INTERFACE) target_include_directories(k2_nvtx INTERFACE ${CMAKE_SOURCE_DIR}) if(K2_ENABLE_NVTX) target_compile_definitions(k2_nvtx INTERFACE K2_ENABLE_NVTX=1) + if(WIN32) + target_include_directories(k2_nvtx INTERFACE + ${CUDA_TOOLKIT_ROOT_DIR}/include/nvtx3 + "C:/Program Files/NVIDIA Corporation/NvToolsExt/include" + ) + endif() endif() add_subdirectory(host) @@ -115,7 +121,23 @@ target_link_libraries(context PUBLIC fsa) target_link_libraries(context PUBLIC k2_log) target_link_libraries(context PUBLIC k2_nvtx) if(K2_USE_PYTORCH) - target_link_libraries(context PUBLIC ${TORCH_LIBRARIES}) + if(NOT WIN32) + target_link_libraries(context PUBLIC ${TORCH_LIBRARIES}) + else() + # see https://discuss.pytorch.org/t/nvcc-fatal-a-single-input-file-is-required-for-a-non-link-phase-when-an-outputfile-is-specified/142843/6 + # Depending on ${TORCH_LIBRARIES} will introduce a compile time option "/bigobj", + # which causes the error in the above link. + # + # It would be ideal to remove /bigobj so that we can use ${TORCH_LIBRARIES}. + # To make life simpler, we use the following approach. + # + message(STATUS "TORCH_DIR: ${TORCH_DIR}") # TORCH_DIR is defined in cmake/torch.cmake + # target_link_libraries(context PUBLIC D:/software/anaconda3/envs/py38/Lib/site-packages/torch/lib/*.lib) + target_link_libraries(context PUBLIC ${TORCH_DIR}/lib/*.lib) + target_include_directories(context PUBLIC ${TORCH_DIR}/include) + target_include_directories(context PUBLIC ${TORCH_DIR}/include/torch/csrc/api/include) + endif() + if(UNIX AND NOT APPLE) # It causes errors on macOS target_link_libraries(context PUBLIC ${TORCH_DIR}/lib/libtorch_python.so) diff --git a/k2/csrc/benchmark/CMakeLists.txt b/k2/csrc/benchmark/CMakeLists.txt index cabb612cb..57570c926 100644 --- a/k2/csrc/benchmark/CMakeLists.txt +++ b/k2/csrc/benchmark/CMakeLists.txt @@ -1,6 +1,7 @@ function(k2_add_benchmark source) get_filename_component(name ${source} NAME_WE) add_executable(${name} ${source}) + set_target_properties(${name} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) target_link_libraries(${name} PRIVATE benchmark) endfunction() diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index c72a31d13..66ae4b626 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -34,7 +34,7 @@ struct Arc { int32_t label; float score; - __host__ __device__ __forceinline__ Arc() = default; + Arc() = default; __host__ __device__ __forceinline__ Arc(int32_t src_state, int32_t dest_state, int32_t label, float score) : src_state(src_state), diff --git a/k2/csrc/host/CMakeLists.txt b/k2/csrc/host/CMakeLists.txt index 208f1651c..4d183de8e 100644 --- a/k2/csrc/host/CMakeLists.txt +++ b/k2/csrc/host/CMakeLists.txt @@ -26,11 +26,20 @@ target_link_libraries(fsa PUBLIC k2_log) target_link_libraries(fsa PUBLIC k2_nvtx) target_include_directories(fsa PUBLIC ${CUDA_TOOLKIT_INCLUDE}) if(K2_ENABLE_NVTX) - target_link_libraries(fsa - PUBLIC - -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 # for /usr/local/cuda - -L${CUDA_TOOLKIT_ROOT_DIR}/lib # for conda - nvToolsExt) + if(NOT WIN32) + target_link_libraries(fsa + PUBLIC + -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 # for /usr/local/cuda + -L${CUDA_TOOLKIT_ROOT_DIR}/lib # for conda + nvToolsExt) + else() + target_link_directories(fsa PUBLIC + 
${CUDA_TOOLKIT_ROOT_DIR}/lib64 # for /usr/local/cuda + ${CUDA_TOOLKIT_ROOT_DIR}/lib # for conda + "C:/Program Files/NVIDIA Corporation/NvToolsExt/lib/x64/" + ) + target_link_libraries(fsa PUBLIC NvToolsExt64_1) + endif() endif() #---------------------------- Test K2 host sources ---------------------------- diff --git a/k2/csrc/log.h b/k2/csrc/log.h index ef4dfd095..5018c5275 100644 --- a/k2/csrc/log.h +++ b/k2/csrc/log.h @@ -166,7 +166,13 @@ class Logger { // this is usually caused by one of the K2_CHECK macros and the detailed // error messages should have already been printed by the macro, so we // use an arbitrary string here. +#ifndef _MSC_VER __assert_fail(kErrMsg, filename_, line_num_, func_name_); +#else + (void)kErrMsg; + assert(0); +#endif // _MSC_VER + #else std::string stack_trace = GetStackTrace(); if (!stack_trace.empty()) { diff --git a/k2/csrc/log_test.cu b/k2/csrc/log_test.cu index bd168bc97..8b86dafc3 100644 --- a/k2/csrc/log_test.cu +++ b/k2/csrc/log_test.cu @@ -28,7 +28,11 @@ TEST(Log, Cpu) { K2_LOG(DEBUG) << "Debug message"; K2_LOG(INFO) << "Info message"; K2_LOG(WARNING) << "Warning message"; +#ifndef _MSC_VER + // It fails on Windows with the following error: + // k2/csrc/log_test.cu(31): error : expected a ")" K2_LOG(ERROR) << "Error message"; +#endif K2_DLOG(INFO) << "This is printed only in debug mode"; diff --git a/k2/csrc/macros_test.cu b/k2/csrc/macros_test.cu index 6963cc45e..ef68e574e 100644 --- a/k2/csrc/macros_test.cu +++ b/k2/csrc/macros_test.cu @@ -27,7 +27,7 @@ namespace k2 { -static void TestEval() { +/*static*/ void TestEval() { for (auto &c : {GetCpuContext(), GetCudaContext()}) { Array1 array = Range(c, 3, 0); int32_t *array_data = array.Data(); @@ -46,7 +46,7 @@ static void TestEval() { } } -static void TestEval2() { +/*static*/ void TestEval2() { for (auto &c : {GetCpuContext(), GetCudaContext()}) { Array1 array1 = Range(c, 6, 0); Array2 array(array1, 2, 3); diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 78bfed8ab..1a919a02a 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -421,8 +421,12 @@ inline void GetOldAndNewOffsets(RaggedShape &src, ExclusiveSum(*new_offsets, new_offsets); } -static RaggedShape IndexAxis0(RaggedShape &src, const Array1 &new2old, - Array1 *elem_indexes /*=nullptr*/) { +// Don't make it static to fix the following error on Windows. +// Error : On Windows, the enclosing parent function ("IndexAxis0") for an +// extended __host__ __device__ lambda cannot have internal or no linkage +/*static*/ RaggedShape IndexAxis0(RaggedShape &src, + const Array1 &new2old, + Array1 *elem_indexes /*=nullptr*/) { NVTX_RANGE(K2_FUNC); ContextPtr &c = src.Context(); K2_CHECK(IsCompatible(src, new2old)); @@ -679,8 +683,8 @@ void GetRowInfoMulti(int32_t num_srcs, RaggedShape **src, *row_ids = row_ids_ptrs.To(ctx); } -static RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src, - Array1 *merge_map /* == nullptr*/) { +/*static*/ RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src, + Array1 *merge_map /* == nullptr*/) { NVTX_RANGE(K2_FUNC); if (num_srcs == 1) { if (merge_map) @@ -1128,7 +1132,7 @@ RaggedShape Stack(int32_t axis, int32_t num_srcs, RaggedShape **src, RaggedShape, 1,2,4 to construct the second output RaggedShape, 6 and a empty list to construct the third output RaggedShape. 
*/ -static void SelectAxis0(RaggedShape &src, const Ragged &indexes, +/*static*/ void SelectAxis0(RaggedShape &src, const Ragged &indexes, std::vector *out, std::vector> *split_map) { NVTX_RANGE(K2_FUNC); ContextPtr &c = src.Context(); @@ -1475,8 +1479,8 @@ Ragged GetCountsPartitioned(Ragged &src, return Ragged(ans_ragged_shape, counts); } -static Array1 GetTransposeReorderingCpu(Ragged &src, - int32_t num_cols) { +/*static*/ Array1 GetTransposeReorderingCpu(Ragged &src, + int32_t num_cols) { NVTX_RANGE(K2_FUNC); std::vector> column_indexes(num_cols); // [column][row] const int32_t *values_data = src.values.Data(); @@ -1496,8 +1500,9 @@ static Array1 GetTransposeReorderingCpu(Ragged &src, return ans; } -static Array1 GetTransposeReorderingThreeAxesCuda(Ragged &src, - int32_t num_cols) { +#ifndef _MSC_VER +/*static*/ Array1 GetTransposeReorderingThreeAxesCuda( + Ragged &src, int32_t num_cols) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(src.NumAxes(), 3); ContextPtr &context = src.Context(); @@ -1541,6 +1546,7 @@ static Array1 GetTransposeReorderingThreeAxesCuda(Ragged &src, lambda_comp, *mgpu_context)); return ans; } +#endif /* @@ -1565,6 +1571,37 @@ Array1 GetTransposeReordering(Ragged &src, int32_t num_cols) { if (device_type == kCpu) return GetTransposeReorderingCpu(src, num_cols); K2_CHECK_EQ(device_type, kCuda); + +#ifdef _MSC_VER + // See https://github.com/k2-fsa/k2/pull/753 + // and + // https://github.com/k2-fsa/k2/pull/571 + int32_t num_buckets = num_cols; + int32_t num_elements = src.values.Dim(); + int32_t log_buckets = static_cast(ceilf(log2f(num_buckets))); + + Array1 ans = Range(context, num_elements, 0); + + cudaStream_t stream = context->GetCudaStream(); + + size_t temp_storage_bytes = 0; + K2_CUDA_SAFE_CALL(cub::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, src.values.Data(), + static_cast(nullptr), ans.Data(), ans.Data(), num_elements, 0, + log_buckets, stream)); + + Array1 d_temp_storage( + context, temp_storage_bytes + num_elements * sizeof(int32_t)); + + K2_CUDA_SAFE_CALL(cub::DeviceRadixSort::SortPairs( + d_temp_storage.Data() + sizeof(int32_t) * num_elements, + temp_storage_bytes, src.values.Data(), + reinterpret_cast(d_temp_storage.Data()), ans.Data(), + ans.Data(), num_elements, 0, log_buckets, stream)); + + return ans; + +#else (void)GetTransposeReorderingThreeAxesCuda; // remove compiler warnings #if __CUDACC_VER_MAJOR__ > 10 || \ @@ -1599,7 +1636,7 @@ Array1 GetTransposeReordering(Ragged &src, int32_t num_cols) { // CheckGetTransposeReordering(src, ans); return ans; -#else +#else // __CUDACC_VER_MAJOR__ if (src.NumAxes() == 3) { Array1 ans = GetTransposeReorderingThreeAxesCuda(src, num_cols); // CheckGetTransposeReordering(src, ans); @@ -1638,6 +1675,7 @@ Array1 GetTransposeReordering(Ragged &src, int32_t num_cols) { // CheckGetTransposeReordering(src, ans); return ans; #endif +#endif // _MSC_VER } RaggedShape ChangeSublistSize(const RaggedShape &src, int32_t size_delta) { diff --git a/k2/csrc/rand_test.cu b/k2/csrc/rand_test.cu index 1370c8cc6..49d456dff 100644 --- a/k2/csrc/rand_test.cu +++ b/k2/csrc/rand_test.cu @@ -107,7 +107,7 @@ TEST(RandInt, CUDA) { } template -static void TestBounds(T low, T high) { +/*static*/ void TestBounds(T low, T high) { int32_t dim = 100000; ContextPtr cpu = GetCpuContext(); ContextPtr cuda = GetCudaContext(); diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index c20775c1a..e6cd5fe28 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -69,7 +69,7 @@ namespace k2 { @param [out] 
epsilon_closure_mapped_arc_map The arc map from `epsilon_closure_mapped` to `src`. */ -static void GetEpsilonClosureMapped( +/*static*/ void GetEpsilonClosureMapped( FsaVec &epsilon_fsa_closure, const Array1 &epsilon_closure_state_map, Ragged &epsilon_closure_arc_map, FsaVec &non_epsilon_fsa, @@ -139,7 +139,7 @@ static void GetEpsilonClosureMapped( foll_shape.RowSplits(1)[i] is the number of following arcs it is combined with. */ -static void DecideCombineWithFollowingOrPreceding( +/*static*/ void DecideCombineWithFollowingOrPreceding( FsaVec &epsilon_closure_mapped, FsaVec &non_epsilon_fsa, Renumbering *epsilon_prec_renumbering, RaggedShape *foll_shape) { NVTX_RANGE(K2_FUNC); @@ -237,7 +237,7 @@ static void DecideCombineWithFollowingOrPreceding( @param [out] combined_foll_arc_map The arc map of `combined_foll`, from arcs idx012 in `combined_foll` to the original Fsa. */ -static void CombineWithFollowingNonEpsilonArcs( +/*static*/ void CombineWithFollowingNonEpsilonArcs( FsaVec &epsilon_closure_mapped, Ragged &epsilon_closure_mapped_arc_map, FsaVec &non_epsilon_fsa, const Array1 &non_epsilon_arc_map, RaggedShape &foll_shape, @@ -341,7 +341,7 @@ static void CombineWithFollowingNonEpsilonArcs( `epsilon_closure_prec_arc_map`, user will get the complete arc map info for `combined_prec`. */ -static void CombineWithPrecedingNonEpsilonArcs( +/*static*/ void CombineWithPrecedingNonEpsilonArcs( FsaVec &epsilon_closure_prec, Ragged &epsilon_closure_prec_arc_map, FsaVec &non_epsilon_fsa, FsaVec *combined_prec, Ragged *epsilon_closure_prec_arc_map_prec, diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index db5e732dd..d5fe89432 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -159,8 +159,8 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape, int64_t state_value = states_values_data[state_idx01x], context_state = state_value / num_graph_states, exp = decoder_history_len - col, - state = context_state % (int64_t)pow(vocab_size, exp); - state = state / (int64_t)pow(vocab_size, exp - 1); + state = context_state % (int64_t)powf(vocab_size, exp); + state = state / (int64_t)powf(vocab_size, exp - 1); contexts_acc(row, col) = state; }); } @@ -540,7 +540,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { // can be done with `358 % 10^2`, then we append 6 to 58, that can be // done with `58 * 10 + 6`. context_state = this_context_state % - (int64_t)pow(vocab_size, decoder_history_len - 1); + (int64_t)powf(vocab_size, decoder_history_len - 1); context_state = context_state * vocab_size + arc.label; } diff --git a/k2/csrc/tensor_ops.cu b/k2/csrc/tensor_ops.cu index 94ab6c1c5..481107566 100644 --- a/k2/csrc/tensor_ops.cu +++ b/k2/csrc/tensor_ops.cu @@ -20,10 +20,11 @@ namespace k2 { template -static void CopyTensorElements2d(ContextPtr c, int32_t dim0, int32_t dim1, - const T *src_data, int32_t src_stride0, - int32_t src_stride1, T *dest_data, - int32_t dest_stride0, int32_t dest_stride1) { +/*static*/ void CopyTensorElements2d(ContextPtr c, int32_t dim0, int32_t dim1, + const T *src_data, int32_t src_stride0, + int32_t src_stride1, T *dest_data, + int32_t dest_stride0, + int32_t dest_stride1) { NVTX_RANGE(K2_FUNC); DeviceType d = c->GetDeviceType(); if (d == kCpu) { @@ -132,10 +133,11 @@ Tensor Cast(Tensor src, Dtype new_dtype) { // See the documentation of `Index`. 
template -static void Index1DImpl(ContextPtr context, const T *src_data, - int32_t src_stride, int32_t src_dim, - const int32_t *indexes_data, bool allow_minus_one, - int32_t ans_dim, T *ans_data, double default_value) { +/*static*/ void Index1DImpl(ContextPtr context, const T *src_data, + int32_t src_stride, int32_t src_dim, + const int32_t *indexes_data, bool allow_minus_one, + int32_t ans_dim, T *ans_data, + double default_value) { if (std::is_integral::value) { K2_CHECK_EQ(static_cast(default_value), default_value); } @@ -166,10 +168,11 @@ static void Index1DImpl(ContextPtr context, const T *src_data, // See the documentation of `Index`. template -static void Index2DImpl(ContextPtr context, const T *src_data, - int32_t src_stride, int32_t src_dim0, int32_t src_dim1, - const int32_t *indexes_data, bool allow_minus_one, - int32_t ans_dim, int32_t ans_stride, T *ans_data) { +/*static*/ void Index2DImpl(ContextPtr context, const T *src_data, + int32_t src_stride, int32_t src_dim0, + int32_t src_dim1, const int32_t *indexes_data, + bool allow_minus_one, int32_t ans_dim, + int32_t ans_stride, T *ans_data) { NVTX_RANGE(K2_FUNC); if (allow_minus_one) { if (context->GetDeviceType() == kCpu) { @@ -299,11 +302,11 @@ Tensor Index(Tensor &src, Array1 &indexes, bool allow_minus_one, } template -static void IndexAdd1DImpl(ContextPtr context, const T *src_data, - int32_t src_dim, int32_t src_stride, - const int32_t *indexes_data, bool allow_minus_one, - int32_t dest_dim, int32_t dest_stride, - T *dest_data) { +/*static*/ void IndexAdd1DImpl(ContextPtr context, const T *src_data, + int32_t src_dim, int32_t src_stride, + const int32_t *indexes_data, + bool allow_minus_one, int32_t dest_dim, + int32_t dest_stride, T *dest_data) { NVTX_RANGE(K2_FUNC); if (allow_minus_one) { K2_EVAL( @@ -330,12 +333,13 @@ static void IndexAdd1DImpl(ContextPtr context, const T *src_data, } template -static void IndexAdd2DImpl(ContextPtr context, const T *src_data, - int32_t src_dim0, int32_t src_dim1, - int32_t src_stride0, int32_t src_stride1, - const int32_t *indexes_data, bool allow_minus_one, - int32_t dest_dim, int32_t dest_stride0, - int32_t dest_stride1, T *dest_data) { +/*static*/ void IndexAdd2DImpl(ContextPtr context, const T *src_data, + int32_t src_dim0, int32_t src_dim1, + int32_t src_stride0, int32_t src_stride1, + const int32_t *indexes_data, + bool allow_minus_one, int32_t dest_dim, + int32_t dest_stride0, int32_t dest_stride1, + T *dest_data) { NVTX_RANGE(K2_FUNC); if (allow_minus_one) { K2_EVAL2( @@ -437,10 +441,9 @@ void IndexAdd(Tensor &src, Array1 &indexes, bool allow_minus_one, } template -static void SimpleRaggedIndexSelect1DImpl(ContextPtr context, const T *src_data, - int32_t src_stride, int32_t src_dim, - Ragged &indexes, - int32_t ans_dim, T *ans_data) { +/*static*/ void SimpleRaggedIndexSelect1DImpl( + ContextPtr context, const T *src_data, int32_t src_stride, int32_t src_dim, + Ragged &indexes, int32_t ans_dim, T *ans_data) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(indexes.NumAxes(), 2); int32_t indexes_dim0 = indexes.Dim0(), diff --git a/k2/csrc/tensor_ops_test.cu b/k2/csrc/tensor_ops_test.cu index f57636ba8..0aa7498c4 100644 --- a/k2/csrc/tensor_ops_test.cu +++ b/k2/csrc/tensor_ops_test.cu @@ -36,7 +36,7 @@ namespace k2 { @return Returns a 1-D tensor with the given `dim` and `stride`. 
*/ template -static Tensor GenerateRandTensor1D(ContextPtr context, int32_t dim, +/*static*/ Tensor GenerateRandTensor1D(ContextPtr context, int32_t dim, int32_t stride) { K2_CHECK_GT(stride, 0); @@ -69,7 +69,7 @@ static Tensor GenerateRandTensor1D(ContextPtr context, int32_t dim, `stride`. */ template -static Tensor GenerateRandTensor2D(ContextPtr context, int32_t num_rows, +/*static*/ Tensor GenerateRandTensor2D(ContextPtr context, int32_t num_rows, int32_t num_cols, int32_t stride) { int32_t num_tensor_elements = num_rows * num_cols; K2_CHECK_GT(num_cols, 0); @@ -301,7 +301,7 @@ TEST(IndexAdd, IndexAdd2D) { } template -static void TestSimpleRaggedIndexSelect1D() { +/*static*/ void TestSimpleRaggedIndexSelect1D() { // test with simple case should be good enough for (auto &context : {GetCpuContext(), GetCudaContext()}) { // create src diff --git a/k2/csrc/test_utils.h b/k2/csrc/test_utils.h index 27774d5e8..05a816033 100644 --- a/k2/csrc/test_utils.h +++ b/k2/csrc/test_utils.h @@ -20,15 +20,16 @@ #ifndef K2_CSRC_TEST_UTILS_H_ #define K2_CSRC_TEST_UTILS_H_ -#include #include #include #include #include +#include "gtest/gtest.h" #include "k2/csrc/array.h" #include "k2/csrc/fsa.h" +#include "k2/csrc/log.h" namespace k2 { @@ -103,9 +104,9 @@ inline void ExpectEqual(const std::vector &expected, // check if `array` and `target` have the same values template void CheckArrayData(const Array1 &array, const Array1 &target, - T abs_error = (T)0.001) { + T abs_error = T(0.001)) { if (array.Dim() != target.Dim()) { - K2_LOG(ERROR) << "Dims mismatch " << array.Dim() << " vs. " << target.Dim(); + K2_LOG(FATAL) << "Dims mismatch " << array.Dim() << " vs. " << target.Dim(); } int32_t dim = array.Dim(); ContextPtr cpu = GetCpuContext(); diff --git a/k2/csrc/version.h.in b/k2/csrc/version.h.in index a894f979d..cfffaccf8 100644 --- a/k2/csrc/version.h.in +++ b/k2/csrc/version.h.in @@ -46,7 +46,7 @@ static constexpr const char *kPythonVersion = "@PYTHON_VERSION_MAJOR@.@PYTHON_VE static constexpr const char *kBuildType = "@CMAKE_BUILD_TYPE@"; // The operating system that is used to build k2, e.g., Ubuntu 16.04 LTS -static constexpr const char *kOS = "@K2_OS@"; +static constexpr const char *kOS = R"os(@K2_OS@)os"; // e.g., 3.18.0 static constexpr const char *kCMakeVersion = "@CMAKE_VERSION@"; @@ -55,10 +55,10 @@ static constexpr const char *kCMakeVersion = "@CMAKE_VERSION@"; static constexpr const char *kGCCVersion = "@CMAKE_CXX_COMPILER_VERSION@"; // CUDA flags used to compile k2 -static constexpr const char *kCMakeCudaFlags = "@CMAKE_CUDA_FLAGS@"; +static constexpr const char *kCMakeCudaFlags = R"cuda_flags(@CMAKE_CUDA_FLAGS@)cuda_flags"; // CXX flags used to compile k2 -static constexpr const char *kCMakeCxxFlags = "@CMAKE_CXX_FLAGS@"; +static constexpr const char *kCMakeCxxFlags = R"cxx_flags(@CMAKE_CXX_FLAGS@)cxx_flags"; // Which PyTorch version k2 is using, e.g., 1.6.0+cu101 static constexpr const char *kTorchVersion = "@TORCH_VERSION@"; diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index 68563a754..c9084af4d 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -19,8 +19,19 @@ if(NOT K2_WITH_CUDA) transform(OUTPUT_VARIABLE k2_srcs SRCS ${k2_srcs}) endif() -pybind11_add_module(_k2 ${k2_srcs} SHARED) +if(WIN32) + # It throws the following error on Windows + # nvcc fatal : A single input file is required for a non-link phase when an outputfile is specified + # because there is an option "/bigobj" in pybind11::windows_extra that cannot be recognized by 
NVCC. + # + # We clear it below. + set_property(TARGET pybind11::windows_extras PROPERTY INTERFACE_COMPILE_OPTIONS "") +endif() + + +pybind11_add_module(_k2 ${k2_srcs}) target_link_libraries(_k2 PRIVATE context) target_link_libraries(_k2 PRIVATE fsa) target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(_k2 PRIVATE ${CMAKE_BINARY_DIR}) +set_target_properties(_k2 PROPERTIES CUDA_SEPARABLE_COMPILATION ON) diff --git a/k2/python/csrc/torch.h b/k2/python/csrc/torch.h index a0f742196..e11827598 100644 --- a/k2/python/csrc/torch.h +++ b/k2/python/csrc/torch.h @@ -29,38 +29,6 @@ namespace pybind11 { namespace detail { -#if K2_TORCH_VERSION_MAJOR < 1 || \ - (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR < 9) -// Only for torch version < 1.9.0 - -// See https://github.com/pytorch/pytorch/pull/57292 - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(torch::Device, _("torch::Device")); - - // PYBIND11_TYPE_CASTER defines a member field called value. Since - // torch::Device cannot be default-initialized, we provide this constructor to - // explicitly initialize that field. The value doesn't matter as it will be - // overwritten after a successful call to load. - type_caster() : value(torch::kCPU) {} - - bool load(handle src, bool) { - PyObject *obj = src.ptr(); - if (THPDevice_Check(obj)) { - value = reinterpret_cast(obj)->device; - return true; - } - return false; - } - - static handle cast(const torch::Device &src, return_value_policy /* policy */, - handle /* parent */) { - return handle(THPDevice_New(src)); - } -}; -#endif template <> struct type_caster { diff --git a/k2/python/csrc/torch/fsa.cu b/k2/python/csrc/torch/fsa.cu index 372b853f0..7571667c5 100644 --- a/k2/python/csrc/torch/fsa.cu +++ b/k2/python/csrc/torch/fsa.cu @@ -491,7 +491,7 @@ static void PybindBackpropGetArcPost(py::module &m, const char *name) { @return It returns the gradient of scores of all arcs. */ template -static torch::Tensor GetTotScoresTropicalBackward( +/*static*/ torch::Tensor GetTotScoresTropicalBackward( FsaVec &fsas, const RaggedAny &best_path_arc_indexes, torch::Tensor tot_scores_grad) { DeviceGuard guard(fsas.Context()); @@ -542,7 +542,7 @@ static torch::Tensor GetTotScoresTropicalBackward( @return It returns the gradient of scores of all arcs. */ template -static torch::Tensor GetTotScoresLogBackward(FsaVec &fsas, +/*static*/ torch::Tensor GetTotScoresLogBackward(FsaVec &fsas, torch::Tensor arc_post, torch::Tensor tot_scores_grad) { DeviceGuard guard(fsas.Context()); diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index f4016695d..aa945c70d 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -59,7 +59,7 @@ static void PybindTopSort(py::module &m) { static void PybindLinearFsa(py::module &m) { m.def( "linear_fsa", - [](RaggedAny &labels, torch::optional = {}) -> FsaVec { + [](RaggedAny &labels, py::object = py::none()) -> FsaVec { DeviceGuard guard(labels.any.Context()); return LinearFsas(labels.any.Specialize()); }, @@ -68,48 +68,26 @@ static void PybindLinearFsa(py::module &m) { m.def( "linear_fsa", [](const std::vector &labels, - torch::optional device = {}) -> Fsa { - ContextPtr context = - GetContext(device.value_or(torch::Device(torch::kCPU))); + py::object device = py::str("cpu")) -> Fsa { + std::string device_str = device.is_none() ? 
"cpu" : py::str(device); + ContextPtr context = GetContext(torch::Device(device_str)); DeviceGuard guard(context); Array1 array(context, labels); return LinearFsa(array); // }, - py::arg("labels"), py::arg("device") = py::none()); - - m.def( - "linear_fsa", - [](const std::vector &labels, - torch::optional device = {}) -> Fsa { - ContextPtr context = GetContext(torch::Device(device.value_or("cpu"))); - DeviceGuard guard(context); - Array1 array(context, labels); - return LinearFsa(array); // - }, - py::arg("labels"), py::arg("device") = py::none()); + py::arg("labels"), py::arg("device") = py::str("cpu")); m.def( "linear_fsa", [](const std::vector> &labels, - torch::optional device = {}) -> FsaVec { - ContextPtr context = - GetContext(device.value_or(torch::Device(torch::kCPU))); + py::object device = py::str("cpu")) -> FsaVec { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + ContextPtr context = GetContext(torch::Device(device_str)); DeviceGuard guard(context); Ragged ragged = CreateRagged2(labels).To(context); return LinearFsas(ragged); }, - py::arg("labels"), py::arg("device") = py::none()); - - m.def( - "linear_fsa", - [](const std::vector> &labels, - torch::optional device = {}) -> FsaVec { - ContextPtr context = GetContext(torch::Device(device.value_or("cpu"))); - DeviceGuard guard(context); - Ragged ragged = CreateRagged2(labels).To(context); - return LinearFsas(ragged); - }, - py::arg("labels"), py::arg("device") = py::none()); + py::arg("labels"), py::arg("device") = py::str("cpu")); } static void PybindIntersect(py::module &m) { @@ -481,7 +459,7 @@ static void PybindRemoveEpsilonSelfLoops(py::module &m) { py::arg("src"), py::arg("need_arc_map") = true); } -static void PybindExpandArcs(py::module &m) { +/*static*/ void PybindExpandArcs(py::module &m) { // See doc-string below. m.def( "expand_arcs", @@ -718,59 +696,34 @@ static void PybindCtcGraph(py::module &m) { static void PybindCtcTopo(py::module &m) { m.def( "ctc_topo", - [](int32_t max_token, torch::optional device = {}, + [](int32_t max_token, py::object device = py::str("cpu"), bool modified = false) -> std::pair { - ContextPtr context = GetContext(device.value_or(torch::Device("cpu"))); + std::string device_str = device.is_none() ? 
"cpu" : py::str(device); + ContextPtr context = GetContext(torch::Device(device_str)); DeviceGuard guard(context); Array1 aux_labels; Fsa fsa = CtcTopo(context, max_token, modified, &aux_labels); torch::Tensor tensor = ToTorch(aux_labels); return std::make_pair(fsa, tensor); }, - py::arg("max_token"), py::arg("device") = py::none(), - py::arg("modified") = false); - - m.def( - "ctc_topo", - [](int32_t max_token, torch::optional device = {}, - bool modified = false) -> std::pair { - ContextPtr context = GetContext(torch::Device(device.value_or("cpu"))); - DeviceGuard guard(context); - Array1 aux_labels; - Fsa fsa = CtcTopo(context, max_token, modified, &aux_labels); - torch::Tensor tensor = ToTorch(aux_labels); - return std::make_pair(fsa, tensor); - }, - py::arg("max_token"), py::arg("device") = py::none(), + py::arg("max_token"), py::arg("device") = py::str("cpu"), py::arg("modified") = false); } static void PybindTrivialGraph(py::module &m) { m.def( "trivial_graph", - [](int32_t max_token, torch::optional device = {}) - -> std::pair { - ContextPtr context = GetContext(device.value_or(torch::Device("cpu"))); - DeviceGuard guard(context); - Array1 aux_labels; - Fsa fsa = TrivialGraph(context, max_token, &aux_labels); - torch::Tensor tensor = ToTorch(aux_labels); - return std::make_pair(fsa, tensor); - }, - py::arg("max_token"), py::arg("device") = py::none()); - - m.def( - "trivial_graph", - [](int32_t max_token, torch::optional device = {}) - -> std::pair { - ContextPtr context = GetContext(torch::Device(device.value_or("cpu"))); + [](int32_t max_token, + py::object device = py::str("cpu")) -> std::pair { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + ContextPtr context = GetContext(torch::Device(device_str)); DeviceGuard guard(context); Array1 aux_labels; Fsa fsa = TrivialGraph(context, max_token, &aux_labels); torch::Tensor tensor = ToTorch(aux_labels); return std::make_pair(fsa, tensor); }, - py::arg("max_token"), py::arg("device") = py::none()); + py::arg("max_token"), py::arg("device") = py::str("cpu")); } static void PybindLevenshteinGraph(py::module &m) { diff --git a/k2/python/csrc/torch/ragged_ops.cu b/k2/python/csrc/torch/ragged_ops.cu index ec6018e13..d2d431b17 100644 --- a/k2/python/csrc/torch/ragged_ops.cu +++ b/k2/python/csrc/torch/ragged_ops.cu @@ -152,8 +152,9 @@ static void PybindNormalizePerSublist(py::module &m, const char *name) { (out.NumElements(),). 
*/ template -static torch::Tensor NormalizePerSublistBackward(Ragged &out, bool use_log, - torch::Tensor out_grad) { +/*static*/ torch::Tensor NormalizePerSublistBackward(Ragged &out, + bool use_log, + torch::Tensor out_grad) { NVTX_RANGE(K2_FUNC); DeviceGuard guard(out.Context()); K2_CHECK_EQ(out_grad.dim(), 1) @@ -397,7 +398,7 @@ void PybindRaggedOps(py::module &m) { PybindArgMaxPerSublist(m); PybindCat(m); PybindCat(m); - PybindCat(m); + PybindCat(m); PybindCreateRagged2(m); PybindCreateRagged2(m); PybindGetLayer(m); diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 0c9f07b4a..74f163d23 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -40,24 +40,32 @@ void PybindRaggedAny(py::module &m) { // k2.ragged.Tensor methods //-------------------------------------------------- - any.def(py::init(), py::arg("data"), - py::arg("dtype") = py::none(), - py::arg("device") = torch::Device(torch::kCPU), - kRaggedAnyInitDataDeviceDoc); + any.def(py::init([](py::list data, py::object dtype = py::none(), + py::object device = + py::str("cpu")) -> std::unique_ptr { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + return std::make_unique(data, dtype, + torch::Device(device_str)); + }), + py::arg("data"), py::arg("dtype") = py::none(), + py::arg("device") = py::str("cpu"), kRaggedAnyInitDataDeviceDoc); any.def(py::init(), py::arg("data"), py::arg("dtype") = py::none(), py::arg("device") = "cpu", kRaggedAnyInitDataDeviceDoc); - any.def(py::init(), + any.def(py::init([](const std::string &s, py::object dtype = py::none(), + py::object device = + py::str("cpu")) -> std::unique_ptr { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + return std::make_unique(s, dtype, device_str); + }), py::arg("s"), py::arg("dtype") = py::none(), - py::arg("device") = torch::Device(torch::kCPU), - kRaggedAnyInitStrDeviceDoc); + py::arg("device") = py::str("cpu"), kRaggedAnyInitStrDeviceDoc); any.def(py::init(), py::arg("s"), py::arg("dtype") = py::none(), - py::arg("device") = torch::Device(torch::kCPU), - kRaggedAnyInitStrDeviceDoc); + py::arg("device") = "cpu", kRaggedAnyInitStrDeviceDoc); any.def(py::init(), py::arg("shape"), py::arg("value"), kRaggedInitFromShapeAndTensorDoc); @@ -110,7 +118,7 @@ void PybindRaggedAny(py::module &m) { any.def( "__getitem__", [](RaggedAny &self, const py::slice &slice) -> RaggedAny { - py::ssize_t start = 0, stop = 0, step = 0, slicelength = 0; + py::size_t start = 0, stop = 0, step = 0, slicelength = 0; if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength)) throw py::error_already_set(); int32_t istart = static_cast(start); @@ -168,10 +176,13 @@ void PybindRaggedAny(py::module &m) { }, py::arg("src"), py::arg("indexes"), kRaggedAnyIndexAndSumDoc); - any.def("to", - static_cast( - &RaggedAny::To), - py::arg("device"), kRaggedAnyToDeviceDoc); + any.def( + "to", + [](RaggedAny &self, py::object device) -> RaggedAny { + std::string device_str = device.is_none() ? 
"cpu" : py::str(device); + return self.To(torch::Device(device_str)); + }, + py::arg("device"), kRaggedAnyToDeviceDoc); any.def("to", static_cast( @@ -243,7 +254,8 @@ void PybindRaggedAny(py::module &m) { [](const RaggedAny &self) -> py::tuple { DeviceGuard guard(self.any.Context()); K2_CHECK(self.any.NumAxes() == 2 || self.any.NumAxes() == 3) - << "Only support Ragged with NumAxes() == 2 or 3 for now, given " + << "Only support Ragged with NumAxes() == 2 or 3 for now, " + "given " << self.any.NumAxes(); Array1 row_splits1 = self.any.RowSplits(1); Dtype t = self.any.GetDtype(); @@ -380,10 +392,8 @@ void PybindRaggedAny(py::module &m) { torch::Device device(device_type, self.any.Context()->GetDeviceId()); - PyObject *ptr = THPDevice_New(device); - - // takes ownership - return py::reinterpret_steal(ptr); + auto torch_device = py::module::import("torch").attr("device"); + return torch_device(device.str()); }, kRaggedAnyDeviceDoc); @@ -443,12 +453,12 @@ void PybindRaggedAny(py::module &m) { m.def( "create_ragged_tensor", [](py::list data, py::object dtype = py::none(), - torch::Device device = torch::kCPU) -> RaggedAny { - return RaggedAny(data, dtype, device); + py::object device = py::str("cpu")) -> RaggedAny { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + return RaggedAny(data, dtype, torch::Device(device_str)); }, py::arg("data"), py::arg("dtype") = py::none(), - py::arg("device") = torch::Device(torch::kCPU), - kCreateRaggedTensorDataDoc); + py::arg("device") = py::str("cpu"), kCreateRaggedTensorDataDoc); m.def( "create_ragged_tensor", @@ -462,12 +472,12 @@ void PybindRaggedAny(py::module &m) { m.def( "create_ragged_tensor", [](const std::string &s, py::object dtype = py::none(), - torch::Device device = torch::kCPU) -> RaggedAny { - return RaggedAny(s, dtype, device); + py::object device = py::str("cpu")) -> RaggedAny { + std::string device_str = device.is_none() ? "cpu" : py::str(device); + return RaggedAny(s, dtype, torch::Device(device_str)); }, py::arg("s"), py::arg("dtype") = py::none(), - py::arg("device") = torch::Device(torch::kCPU), - kCreateRaggedTensorStrDoc); + py::arg("device") = py::str("cpu"), kCreateRaggedTensorStrDoc); m.def( "create_ragged_tensor", diff --git a/k2/python/csrc/torch/v2/ragged_shape.cu b/k2/python/csrc/torch/v2/ragged_shape.cu index cb3bc8c13..f989800de 100644 --- a/k2/python/csrc/torch/v2/ragged_shape.cu +++ b/k2/python/csrc/torch/v2/ragged_shape.cu @@ -66,7 +66,9 @@ void PybindRaggedShape(py::module &m) { shape.def( "to", - [](const RaggedShape &self, torch::Device device) -> RaggedShape { + [](const RaggedShape &self, py::object _device) -> RaggedShape { + std::string device_str = _device.is_none() ? 
"cpu" : py::str(_device); + torch::Device device = torch::Device(device_str); DeviceGuard guard(self.Context()); if (device.type() == torch::kCPU) return self.To(GetCpuContext()); @@ -166,10 +168,8 @@ void PybindRaggedShape(py::module &m) { torch::Device device(device_type, self.Context()->GetDeviceId()); - PyObject *ptr = THPDevice_New(device); - - // takes ownership - return py::reinterpret_steal(ptr); + auto torch_device = py::module::import("torch").attr("device"); + return torch_device(device.str()); }, kRaggedShapeDeviceDoc); diff --git a/k2/python/host/k2host/fsa.py b/k2/python/host/k2host/fsa.py index 59196a422..ecb8a5b0e 100644 --- a/k2/python/host/k2host/fsa.py +++ b/k2/python/host/k2host/fsa.py @@ -30,9 +30,9 @@ def __init__(self, src_state: int, dest_state: int, label: int, super().__init__(src_state, dest_state, label, weight) def to_tensor(self): - # TODO(fangjun): weight will be truncted to an int. + # TODO(fangjun): weight will be truncated to an int. return torch.tensor( - [self.src_state, self.dest_state, self.label, self.weight], + [self.src_state, self.dest_state, self.label, int(self.weight)], dtype=torch.int32) @staticmethod diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py index 85d56cd5d..7e43d9f82 100644 --- a/k2/python/k2/rnnt_decode.py +++ b/k2/python/k2/rnnt_decode.py @@ -179,7 +179,7 @@ def format_output(self, num_frames: List[int]) -> Fsa: src = self.src_streams[i].fsa for name, value in src.named_tensor_attr(include_scores=False): if name not in tensor_attr_info: - filler = 0.0 + filler = 0 if isinstance(value, Tensor): filler = float(src.get_filler(name)) dtype = value.dtype diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index fa030a0a1..67ad28a57 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -471,7 +471,7 @@ def _adjust_pruning_lower_bound( # make the transformed tensor to be non-decreasing s_begin = k2.monotonic_lower_bound(s_begin) # make start symbol to be zero. 
- s_begin = torch.where(s_begin < 0, 0, s_begin) + s_begin = torch.clamp(s_begin, min=0) # do the magic transformation again to recover s_begin s_begin = -( s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) @@ -568,7 +568,7 @@ def get_rnnt_prune_ranges( s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1 # handle the cases when `len(symbols) < s_range` - s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0) + s_begin_padding = torch.clamp(s_begin_padding, min=0) s_begin = torch.where(mask, s_begin, s_begin_padding) diff --git a/k2/python/tests/linear_fsa_with_self_loops_test.py b/k2/python/tests/linear_fsa_with_self_loops_test.py index 1e331bbbc..ec3654cb1 100644 --- a/k2/python/tests/linear_fsa_with_self_loops_test.py +++ b/k2/python/tests/linear_fsa_with_self_loops_test.py @@ -55,7 +55,7 @@ def test_multiple_fsa(self): expected_labels0 = [0, 2, 0, 5, 0, 8, 0, -1] expected_labels1 = [0, 1, 0, 2, 0, -1] expected_labels2 = [0, 3, 0, 2, 0, -1] - expected_labels = expected_labels0 + expected_labels1 + expected_labels2 + expected_labels = expected_labels0 + expected_labels1 + expected_labels2 # noqa assert dst.labels.tolist() == expected_labels diff --git a/k2/python/tests/mutual_information_test.py b/k2/python/tests/mutual_information_test.py index 11917f18f..cddd817e9 100644 --- a/k2/python/tests/mutual_information_test.py +++ b/k2/python/tests/mutual_information_test.py @@ -286,12 +286,12 @@ def get_boundary_row(): observed_delta = (delta_m * m_grad).sum().to("cpu") predicted_delta = (delta_px * px.grad).sum().to("cpu") - atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 - rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 + atol = 1.0e-01 + rtol = atol assert torch.allclose( observed_delta, predicted_delta, atol=atol, rtol=rtol - ) + ), (observed_delta, predicted_delta) delta_py = delta * torch.randn_like(py) m2 = k2.mutual_information_recursion( diff --git a/scripts/github_actions/generate_build_matrix.py b/scripts/github_actions/generate_build_matrix.py new file mode 100755 index 000000000..5899c19b0 --- /dev/null +++ b/scripts/github_actions/generate_build_matrix.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang)
+
+import argparse
+import json
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--enable-cuda",
+        action="store_true",
+        default=False,
+        help="True to enable CUDA",
+    )
+
+    parser.add_argument(
+        "--test-only-latest-torch",
+        action="store_true",
+        default=False,
+        help="""If True, we test only the latest PyTorch
+        to reduce CI running time.""",
+    )
+    return parser.parse_args()
+
+
+def generate_build_matrix(enable_cuda, test_only_latest_torch):
+    matrix = {
+        # there are issues in serializing ragged tensors in 1.5.0 and 1.5.1
+        # "1.5.0": {
+        #     "python-version": ["3.6", "3.7", "3.8"],
+        #     "cuda": ["10.1", "10.2"],
+        # },
+        # "1.5.1": {
+        #     "python-version": ["3.6", "3.7", "3.8"],
+        #     "cuda": ["10.1", "10.2"],
+        # },
+        "1.6.0": {
+            "python-version": ["3.6", "3.7", "3.8"],
+            "cuda": ["10.1", "10.2"],
+        },
+        "1.7.0": {
+            "python-version": ["3.6", "3.7", "3.8"],
+            "cuda": ["10.1", "10.2", "11.0"],
+        },
+        "1.7.1": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.1", "10.2", "11.0"],
+        },
+        "1.8.0": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.1", "10.2", "11.1"],
+        },
+        "1.8.1": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.1", "10.2", "11.1"],
+        },
+        "1.9.0": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.2", "11.1"],
+        },
+        "1.9.1": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.2", "11.1"],
+        },
+        "1.10.0": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.2", "11.1", "11.3"],
+        },
+        "1.10.1": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.2", "11.1", "11.3"],
+        },
+        "1.10.2": {
+            "python-version": ["3.6", "3.7", "3.8", "3.9"],
+            "cuda": ["10.2", "11.1", "11.3"],
+        },
+        "1.11.0": {
+            "python-version": ["3.7", "3.8", "3.9", "3.10"],
+            "cuda": ["10.2", "11.3", "11.5"],
+        },
+    }
+    if test_only_latest_torch:
+        latest = "1.11.0"
+        matrix = {latest: matrix[latest]}
+
+    ans = []
+    for torch, python_cuda in matrix.items():
+        python_versions = python_cuda["python-version"]
+        cuda_versions = python_cuda["cuda"]
+        if enable_cuda:
+            for p in python_versions:
+                for c in cuda_versions:
+                    ans.append({"torch": torch, "python-version": p, "cuda": c})
+        else:
+            for p in python_versions:
+                ans.append({"torch": torch, "python-version": p})
+
+    print(json.dumps({"include": ans}))
+
+
+def main():
+    args = get_args()
+    generate_build_matrix(
+        enable_cuda=args.enable_cuda,
+        test_only_latest_torch=args.test_only_latest_torch,
+    )
+
+
+if __name__ == "__main__":
+    main()

From a4d76d24808f274924a87b79209d599102f8014c Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 7 Apr 2022 14:10:49 +0800
Subject: [PATCH 60/64] Fix nightly windows CPU build (#948)

* Fix nightly building k2 for windows.

* Run nightly build only if there are new commits.
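
The gating works by adding an enable_nightly_build job to each nightly
workflow: it runs scripts/github_actions/run-nightly-build.py (created by
this patch; the file's body is not shown in this excerpt) and exposes the
result as an `enabled` output that every downstream job checks via
`needs.enable_nightly_build.outputs.enabled == 'true'`. A minimal sketch of
the gating logic such a script needs — assuming the intent is simply "print
true only when there are commits from roughly the last day", that git is on
PATH, and that full history is available (the workflows check out with
fetch-depth: 0); the helper name last_commit_time is illustrative, not the
script's actual API:

    #!/usr/bin/env python3
    # Sketch only: the real scripts/github_actions/run-nightly-build.py is
    # added by this patch, but its body is not shown in this excerpt.
    # Assumptions: git is on PATH, the checkout has full history
    # (fetch-depth: 0), and "new commits" means commits within 24 hours.

    import datetime
    import subprocess


    def last_commit_time():
        # "%ct" prints the committer date of HEAD as a Unix timestamp.
        ts = subprocess.check_output(
            ["git", "log", "-1", "--format=%ct"], text=True
        ).strip()
        return datetime.datetime.utcfromtimestamp(int(ts))


    def main():
        age = datetime.datetime.utcnow() - last_commit_time()
        # The enable_nightly_build job captures this string with
        # "::set-output name=enabled::..." and downstream jobs run only
        # when it equals "true".
        print("true" if age < datetime.timedelta(days=1) else "false")


    if __name__ == "__main__":
        main()

Printing a plain true/false string keeps the workflow side trivial: the job
forwards stdout into the `enabled` output via ::set-output, and a string
comparison in each downstream job's `if:` condition does the rest.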
--- .github/workflows/nightly-cpu-macos.yml | 22 +++++- .github/workflows/nightly-cpu-ubuntu.yml | 22 +++++- .github/workflows/nightly-cpu-windows.yml | 75 ++++++++++++++------- .github/workflows/nightly-cuda-ubuntu.yml | 22 +++++- .github/workflows/wheel-cpu-macos.yml | 1 + .github/workflows/wheel-cpu-windows.yml | 1 + .github/workflows/wheel-cuda-ubuntu.yml | 1 + scripts/github_actions/run-nightly-build.py | 35 ++++++++++ 8 files changed, 149 insertions(+), 30 deletions(-) create mode 100755 scripts/github_actions/run-nightly-build.py diff --git a/.github/workflows/nightly-cpu-macos.yml b/.github/workflows/nightly-cpu-macos.yml index 0b6d773cc..0ce5354cb 100644 --- a/.github/workflows/nightly-cpu-macos.yml +++ b/.github/workflows/nightly-cpu-macos.yml @@ -23,14 +23,31 @@ on: # day of the month (1-31) # month (1-12) # day of the week (0-6) - # nightly build at 14:00 UTC time every day - - cron: "0 14 * * *" + # nightly build at 23:50 UTC time every day + - cron: "50 23 * * *" env: BUILD_TYPE: Release jobs: + enable_nightly_build: + runs-on: ubuntu-latest + outputs: + enabled: ${{ steps.set-enabled.outputs.enabled }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Set enabled + id: set-enabled + run: | + enabled=$(python scripts/github_actions/run-nightly-build.py) + echo "enabled: $enabled" + echo "::set-output name=enabled::${enabled}" + generate_build_matrix: + needs: enable_nightly_build + if: needs.enable_nightly_build.outputs.enabled == 'true' # see https://github.com/pytorch/pytorch/pull/50633 runs-on: ubuntu-latest outputs: @@ -102,6 +119,7 @@ jobs: path: dist/*.whl - name: Copy wheels to k2-fsa.org + if: ${{ github.repository_owner == 'k2-fsa' }} run: | user=${{ secrets.K2_USERNAME }} server=${{ secrets.K2_HOST }} diff --git a/.github/workflows/nightly-cpu-ubuntu.yml b/.github/workflows/nightly-cpu-ubuntu.yml index b371af8d5..b47074272 100644 --- a/.github/workflows/nightly-cpu-ubuntu.yml +++ b/.github/workflows/nightly-cpu-ubuntu.yml @@ -23,14 +23,31 @@ on: # day of the month (1-31) # month (1-12) # day of the week (0-6) - # nightly build at 14:00 UTC time every day - - cron: "0 14 * * *" + # nightly build at 23:50 UTC time every day + - cron: "50 23 * * *" env: BUILD_TYPE: Release jobs: + enable_nightly_build: + runs-on: ubuntu-latest + outputs: + enabled: ${{ steps.set-enabled.outputs.enabled }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Set enabled + id: set-enabled + run: | + enabled=$(python scripts/github_actions/run-nightly-build.py) + echo "enabled: $enabled" + echo "::set-output name=enabled::${enabled}" + generate_build_matrix: + needs: enable_nightly_build + if: needs.enable_nightly_build.outputs.enabled == 'true' # see https://github.com/pytorch/pytorch/pull/50633 runs-on: ubuntu-latest outputs: @@ -107,6 +124,7 @@ jobs: path: dist/*.whl - name: Copy wheels to k2-fsa.org + if: ${{ github.repository_owner == 'k2-fsa' }} run: | user=${{ secrets.K2_USERNAME }} server=${{ secrets.K2_HOST }} diff --git a/.github/workflows/nightly-cpu-windows.yml b/.github/workflows/nightly-cpu-windows.yml index f23e6c801..7074494d2 100644 --- a/.github/workflows/nightly-cpu-windows.yml +++ b/.github/workflows/nightly-cpu-windows.yml @@ -23,14 +23,31 @@ on: # day of the month (1-31) # month (1-12) # day of the week (0-6) - # nightly build at 14:00 UTC time every day - - cron: "0 14 * * *" + # nightly build at 23:50 UTC time every day + - cron: "50 23 * * *" env: BUILD_TYPE: Release jobs: + enable_nightly_build: + runs-on: 
ubuntu-latest + outputs: + enabled: ${{ steps.set-enabled.outputs.enabled }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Set enabled + id: set-enabled + run: | + enabled=$(python scripts/github_actions/run-nightly-build.py) + echo "enabled: $enabled" + echo "::set-output name=enabled::${enabled}" + generate_build_matrix: + needs: enable_nightly_build + if: needs.enable_nightly_build.outputs.enabled == 'true' runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} @@ -83,6 +100,31 @@ jobs: cmake --version cmake --help + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Release" + python3 setup.py bdist_wheel + ls -lh dist/ + pip install ./dist/*.whl + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-cpu + path: dist/*.whl + + - name: Copy wheels to k2-fsa.org + if: ${{ github.repository_owner == 'k2-fsa' }} + shell: bash + run: | + user=${{ secrets.K2_USERNAME }} + server=${{ secrets.K2_HOST }} + port=${{ secrets.K2_PORT }} + echo "${{ secrets.K2_KEY }}" > id_rsa && chmod 600 id_rsa + scp -P $port -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i id_rsa dist/*.whl $user@$server:~/nightly/whl + rm id_rsa + - name: Configure CMake shell: bash run: | @@ -94,37 +136,22 @@ jobs: cat CMakeCache.txt - name: Build k2 + shell: bash run: | cd build_release cmake --build . --target _k2 --config Release -- -m - ls -lh lib/*/* - ls -lh bin/*/* + cmake --build . --target ALL_BUILD --config Release - name: Display generated files shell: bash run: | cd build_release ls -lh lib/*/* + ls -lh bin/*/* - - name: Build wheel - shell: bash - run: | - export K2_CMAKE_ARGS="-DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Release" - python3 setup.py bdist_wheel - ls -lh dist/ - - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-cpu - path: dist/*.whl - - - name: Copy wheels to k2-fsa.org + - name: Run tests shell: bash run: | - user=${{ secrets.K2_USERNAME }} - server=${{ secrets.K2_HOST }} - port=${{ secrets.K2_PORT }} - echo "${{ secrets.K2_KEY }}" > id_rsa && chmod 600 id_rsa - scp -P $port -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i id_rsa dist/*.whl $user@$server:~/nightly/whl - rm id_rsa + cd build_release + # disable python tests for k2host + ctest -C Release --output-on-failure -E host diff --git a/.github/workflows/nightly-cuda-ubuntu.yml b/.github/workflows/nightly-cuda-ubuntu.yml index 7af8d657d..e14bb9d6a 100644 --- a/.github/workflows/nightly-cuda-ubuntu.yml +++ b/.github/workflows/nightly-cuda-ubuntu.yml @@ -7,14 +7,31 @@ on: # day of the month (1-31) # month (1-12) # day of the week (0-6) - # nightly build at 14:00 UTC time every day - - cron: "0 14 * * *" + # nightly build at 23:50 UTC time every day + - cron: "50 23 * * *" env: BUILD_TYPE: Release jobs: + enable_nightly_build: + runs-on: ubuntu-latest + outputs: + enabled: ${{ steps.set-enabled.outputs.enabled }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Set enabled + id: set-enabled + run: | + enabled=$(python scripts/github_actions/run-nightly-build.py) + echo "enabled: $enabled" + echo "::set-output name=enabled::${enabled}" + nightly: + needs: enable_nightly_build + if: needs.enable_nightly_build.outputs.enabled == 'true' runs-on: ubuntu-18.04 strategy: fail-fast: false @@ -98,6 +115,7 @@ jobs: ls -lh 
dist/ - name: Copy wheels to k2-fsa.org + if: ${{ github.repository_owner == 'k2-fsa' }} uses: horochx/deploy-via-scp@v1.0.1 with: host: ${{ secrets.K2_HOST }} diff --git a/.github/workflows/wheel-cpu-macos.yml b/.github/workflows/wheel-cpu-macos.yml index 74bdfb496..eb44821e4 100644 --- a/.github/workflows/wheel-cpu-macos.yml +++ b/.github/workflows/wheel-cpu-macos.yml @@ -60,6 +60,7 @@ jobs: path: dist/*.whl - name: Publish wheels to PyPI + if: ${{ github.repository_owner == 'k2-fsa' }} env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} diff --git a/.github/workflows/wheel-cpu-windows.yml b/.github/workflows/wheel-cpu-windows.yml index 40ce800d6..9d84b51cc 100644 --- a/.github/workflows/wheel-cpu-windows.yml +++ b/.github/workflows/wheel-cpu-windows.yml @@ -64,6 +64,7 @@ jobs: path: dist/*.whl - name: Publish wheels to PyPI + if: ${{ github.repository_owner == 'k2-fsa' }} env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} diff --git a/.github/workflows/wheel-cuda-ubuntu.yml b/.github/workflows/wheel-cuda-ubuntu.yml index 7888d9fa0..c4449028c 100644 --- a/.github/workflows/wheel-cuda-ubuntu.yml +++ b/.github/workflows/wheel-cuda-ubuntu.yml @@ -93,6 +93,7 @@ jobs: ls -lh dist/ - name: Publish wheels to PyPI + if: ${{ github.repository_owner == 'k2-fsa' }} env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} diff --git a/scripts/github_actions/run-nightly-build.py b/scripts/github_actions/run-nightly-build.py new file mode 100755 index 000000000..1e002fba3 --- /dev/null +++ b/scripts/github_actions/run-nightly-build.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) + +import subprocess +from datetime import datetime, timedelta + + +def get_last_commit_date() -> datetime: + date = ( + subprocess.check_output( + [ + "git", + "log", + "-1", + "--format=%ad", + "--date=unix", + ] + ) + .decode("ascii") + .strip() + ) + return datetime.utcfromtimestamp(int(date)) + + +def main(): + last_commit_date_utc = get_last_commit_date() + now_utc = datetime.utcnow() + if last_commit_date_utc + timedelta(days=1) > now_utc: + print("true") + else: + print("false") + + +if __name__ == "__main__": + main() From 4fb6b88661cca73e5f66f03df16e5a1d0c4886f8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Apr 2022 18:29:32 +0800 Subject: [PATCH 61/64] Check the versions of PyTorch and CUDA at the import time. (#949) * Check the versions of PyTorch and CUDA at the import time. 
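A note on the semantics of the check below: for PyTorch, only the base
version (the part before a local suffix such as `+cu102` or `+cpu`) is
compared, while a separate comparison handles the CUDA version. A minimal
sketch of that logic, with hypothetical version strings:

    # Hypothetical values; the real ones are baked into torch_version.py
    # by CMake at build time (see the diff below).
    def same_base_version(built: str, running: str) -> bool:
        # "1.10.0+cu102" and "1.10.0+cpu" share the base version "1.10.0",
        # so only a genuine version mismatch raises the ImportError.
        return built.split("+")[0] == running.split("+")[0]

    assert same_base_version("1.10.0+cu102", "1.10.0+cpu")
    assert not same_base_version("1.10.0+cu102", "1.11.0+cu102")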
--- .flake8 | 1 + .gitignore | 1 + CMakeLists.txt | 4 ++++ k2/python/k2/__init__.py | 21 +++++++++++++++++++++ k2/python/k2/torch_version.py.in | 17 +++++++++++++++++ 5 files changed, 44 insertions(+) create mode 100644 k2/python/k2/torch_version.py.in diff --git a/.flake8 b/.flake8 index c0ad0c420..f5f3f5e04 100644 --- a/.flake8 +++ b/.flake8 @@ -14,6 +14,7 @@ exclude = get_version.py build, k2/python/host, + k2/python/k2/__init__.py, k2/python/k2/ctc_loss.py, docs diff --git a/.gitignore b/.gitignore index bfdf9eb31..39235688d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +k2/python/k2/torch_version.py # Build folder **/build* diff --git a/CMakeLists.txt b/CMakeLists.txt index 7adbede8b..906ebddb0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -264,6 +264,10 @@ if(K2_USE_PYTORCH) add_definitions(-DK2_USE_PYTORCH) add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H) include(torch) + configure_file( + ${CMAKE_SOURCE_DIR}/k2/python/k2/torch_version.py.in + ${CMAKE_SOURCE_DIR}/k2/python/k2/torch_version.py @ONLY + ) endif() if(K2_WITH_CUDA) diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 930affb18..54102705b 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -1,9 +1,29 @@ import torch # noqa +from .torch_version import k2_torch_cuda_version +from .torch_version import k2_torch_version + +if torch.__version__.split("+")[0] != k2_torch_version.split("+")[0]: + raise ImportError( + f"k2 was built using PyTorch {k2_torch_version}\n" + f"But you are using PyTorch {torch.__version__} to run it" + ) + +if ( + k2_torch_cuda_version != "" + and torch.version.cuda is not None + and torch.version.cuda != k2_torch_cuda_version +): + raise ImportError( + f"k2 was built using CUDA {k2_torch_cuda_version}\n" + f"But you are using CUDA {torch.version.cuda} to run it." + ) + try: from _k2 import DeterminizeWeightPushingType from _k2 import simple_ragged_index_select except ImportError as e: import sys + major_v, minor_v = sys.version_info[:2] raise ImportError( str(e) + "\nNote: If you're using anaconda and importing k2 on MacOS," @@ -18,6 +38,7 @@ from . import dense_fsa_vec from . import fsa from . import utils + # from .autograd import intersect_dense from .autograd import intersect_dense_pruned diff --git a/k2/python/k2/torch_version.py.in b/k2/python/k2/torch_version.py.in new file mode 100644 index 000000000..30e83abc0 --- /dev/null +++ b/k2/python/k2/torch_version.py.in @@ -0,0 +1,17 @@ +# Auto generated by the toplevel CMakeLists.txt. +# +# DO NOT EDIT. + +# The torch version used to build k2. We will check it against the torch version +# that is used to run k2. If they are not the same, `import k2` will throw. +# +# Some example values are: +# - 1.10.0+cu102 +# - 1.5.0+cpu +k2_torch_version = "@TORCH_VERSION@" + +# The CUDA version used to build k2. +# Note: It is an empty string if you used a CPU version of PyTorch to build k2 +# +# An example value is "10.2". +k2_torch_cuda_version = "@TORCH_CUDA_VERSION@" From 9ebd757bb383343142c736ef84073585f4ccf84c Mon Sep 17 00:00:00 2001 From: "Nickolay V. 
Shmyrev" Date: Tue, 12 Apr 2022 01:53:49 +0300 Subject: [PATCH 62/64] More straightforward message when CUDA support is missing (#950) --- k2/csrc/log.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k2/csrc/log.h b/k2/csrc/log.h index 5018c5275..04624d66f 100644 --- a/k2/csrc/log.h +++ b/k2/csrc/log.h @@ -374,7 +374,7 @@ inline K2_CUDA_HOSTDEV LogLevel GetCurrentLogLevel() { #define K2_CHECK_CUDA_ERROR(x) \ K2_CHECK_EQ(x, cudaSuccess) << " Error: " << cudaGetErrorString(x) << ". " #else -#define K2_CHECK_CUDA_ERROR(...) K2_LOG(FATAL) << "Don't call me" +#define K2_CHECK_CUDA_ERROR(...) K2_LOG(FATAL) << "K2 compiled without CUDA support" #endif // The parameter of `K2_CUDA_SAFE_CALL` should be cuda function call or kernel @@ -396,7 +396,7 @@ inline K2_CUDA_HOSTDEV LogLevel GetCurrentLogLevel() { // Use a separate K2_CUDA_SAFE_CALL() for CPU // because the kernel invocation syntax <<< >>> // is not valid C++ -#define K2_CUDA_SAFE_CALL(...) K2_LOG(FATAL) << "Don't call me!" +#define K2_CUDA_SAFE_CALL(...) K2_LOG(FATAL) << "K2 compiled without CUDA support" #endif // ------------------------------------------------------------ From 3b83183234d0f1d8391872630551c5af7c491ed2 Mon Sep 17 00:00:00 2001 From: LvHang Date: Tue, 12 Apr 2022 08:26:41 +0800 Subject: [PATCH 63/64] Implement ArrayOfRagged (#927) * Implement ArrayOfRagged * Fix issues and pass tests * fix style * change few statements of functions and move the definiation of template Array1OfRagged to header file * add offsets test code --- k2/csrc/array_of_ragged.cu | 104 ++++++++++++++++++++++++---- k2/csrc/array_of_ragged.h | 119 ++++++++++++++++++++++---------- k2/csrc/array_of_ragged_test.cu | 34 +++++++++ 3 files changed, 206 insertions(+), 51 deletions(-) diff --git a/k2/csrc/array_of_ragged.cu b/k2/csrc/array_of_ragged.cu index cd93434d9..11f8e8ea4 100644 --- a/k2/csrc/array_of_ragged.cu +++ b/k2/csrc/array_of_ragged.cu @@ -1,5 +1,7 @@ /** - * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) + * Copyright 2022 Xiaomi Corporation (authors: Daniel Povey, Wei Kang) + * 2022 ASLP@NWPU (authors: Hang Lyu) + * * See LICENSE for clarification regarding multiple authors * @@ -20,35 +22,107 @@ namespace k2 { -Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs) - : num_srcs_(num_srcs) { - K2_CHECK_GE(num_srcs, 1); - K2_CHECK(src); - num_axes_ = src[0].NumAxes(); - c_ = src[0].Context(); +Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *srcs, int32_t num_srcs) : + num_srcs_(num_srcs) { + K2_CHECK_GT(num_srcs, 0); + K2_CHECK(srcs); + + // Initialize context and num_axes_. + c_ = srcs[0].Context(); + num_axes_ = srcs[0].NumAxes(); + + // Check if they have same num-axes and compatible context. + for (int32_t i = 1; i < num_srcs_; ++i) { + K2_CHECK_EQ(num_axes_, srcs[i].NumAxes()); + K2_CHECK(c_->IsCompatible(*(srcs[i].Context()))); + } - row_splits_ = - Array2(GetCpuContext(), num_axes_ - 1, num_srcs_); + // Initialize row_splits__, row_ids_ and tot_sizes_. + // + // Notice: since the Data() function is a __host__ function, it cannot be + // called on GPU. It limits us to work on CPU so that the row_splits_ and + // row_ids_ are populated on CPU, although the operator() of Array2 is a + // __host__ and __device__ function. Bear in mind, we cannot access the + // GPU data on CPU. 
+ row_splits_ = Array2(GetCpuContext(), + num_axes_ - 1, num_srcs_); row_ids_ = Array2(GetCpuContext(), num_axes_ - 1, num_srcs_); + + // Notice: no matter the return value of TotSize() is from 'cached_tot_size' + // or the Back() function (i.e. operator[]) of array1, it it a CPU value. tot_sizes_ = Array1(GetCpuContext(), num_axes_, 0); auto row_splits_acc = row_splits_.Accessor(), row_ids_acc = row_ids_.Accessor(); + // Bear in mind, when axis == 0, the TotSize() is row_splits.Dim() - 1. + // When 0 < axis < NumAxes(), the TotSize() is row_splits.Back(). int32_t *tot_sizes_data = tot_sizes_.Data(); for (int32_t i = 0; i < num_srcs_; ++i) { - K2_CHECK_EQ(src[i].NumAxes(), num_axes_); - K2_CHECK(c_->IsCompatible(*(src[i].Context()))); for (int32_t j = 1; j < num_axes_; ++j) { - row_splits_acc(j - 1, i) = src[i].RowSplits(j).Data(); - row_ids_acc(j - 1, i) = src[i].RowIds(j).Data(); - tot_sizes_data[j] += src[i].TotSize(j); + row_splits_acc(j - 1, i) = srcs[i].RowSplits(j).Data(); + row_ids_acc(j - 1, i) = srcs[i].RowIds(j).Data(); + tot_sizes_data[j] += srcs[i].TotSize(j); } - tot_sizes_data[0] += src[i].TotSize(0); + tot_sizes_data[0] += srcs[i].TotSize(0); } row_splits_ = row_splits_.To(c_); row_ids_ = row_ids_.To(c_); + tot_sizes_ = tot_sizes_.To(c_); + + + // Initialize meat_row_splits_ + // We populate this on CPU and transfer to GPU. + meta_row_splits_ = Array2(GetCpuContext(), num_axes_, num_srcs_ + 1); + offsets_ = Array2(GetCpuContext(), num_axes_ + 1, num_srcs_ + 1); + + auto meta_row_splits_acc = meta_row_splits_.Accessor(), + offsets_acc = offsets_.Accessor(); + + // Initialize the 1st row of offsets_, which contains 0,1,2,... + for (int32_t col = 0; col <= num_srcs_; ++col) { + offsets_acc(0, col) = col; + } + // Initialize the 1st col of meta_row_splits_ and offsets_ + for (int32_t row = 0; row < num_axes_; ++row) { + meta_row_splits_acc(row, 0) = 0; + offsets_acc(row + 1, 0) = 0; + } + + // The meta_row_splits_ is the cumulative sum of the tot-sizes of the + // individual arrays. 
+ for (int32_t i = 0; i < num_axes_; ++i) { + for (int32_t j = 1; j <= num_srcs_; ++j) { + meta_row_splits_acc(i, j) = meta_row_splits_acc(i, j - 1) + + srcs[j - 1].TotSize(i); + offsets_acc(i + 1, j) = meta_row_splits_acc(i, j); + } + } + + // Initialize meta_row_ids_ + // Elements are in [0, NumSrcs() - 1] + meta_row_ids_.resize(num_axes_); + + for (int32_t axis = 0; axis < num_axes_; ++axis) { + // The length equals to TotSize(axis) + meta_row_ids_.at(axis) = Array1( + GetCpuContext(), meta_row_splits_acc(axis, num_srcs_)); + int32_t *meta_row_ids_data = meta_row_ids_[axis].Data(); + + int32_t cur_row_start = meta_row_splits_acc(axis, 0); + for (int32_t src = 0; src < num_srcs_; ++src) { + int32_t next_row_start = meta_row_splits_acc(axis, src + 1); + for (; cur_row_start < next_row_start; ++cur_row_start) { + meta_row_ids_data[cur_row_start] = src; + } + } + meta_row_ids_[axis] = meta_row_ids_[axis].To(c_); + } + + meta_row_splits_ = meta_row_splits_.To(c_); + offsets_ = offsets_.To(c_); } + } // namespace k2 diff --git a/k2/csrc/array_of_ragged.h b/k2/csrc/array_of_ragged.h index 31349cf91..facc02dc0 100644 --- a/k2/csrc/array_of_ragged.h +++ b/k2/csrc/array_of_ragged.h @@ -1,5 +1,6 @@ /** * Copyright 2022 Xiaomi Corporation (authors: Daniel Povey, Wei Kang) + * 2022 ASLP@NWPU (authors: Hang Lyu) * * See LICENSE for clarification regarding multiple authors * @@ -24,31 +25,48 @@ #include #include "k2/csrc/array.h" +#include "k2/csrc/array_ops.h" #include "k2/csrc/context.h" #include "k2/csrc/log.h" -#include "k2/csrc/ragged_ops.h" +#include "k2/csrc/ragged.h" namespace k2 { + +/* + Array1OfRagged is a 1-dimensional array of Ragged. + It is intended for situations where you want to do some operations on + arrays of ragged arrays, without explicitly concatenating them (e.g. to + save time). This is a fairly low-level interface, intended to + be used mostly by CUDA/C++ implementation code. It is a convenience + wrapper that saves you the trouble of creating arrays of pointers. + */ + + /* Array1OfRaggedShape is a convenience function that gives you easy access to pointers-of-pointers for an array of ragged shapes. */ class Array1OfRaggedShape { public: + // Default constructor. + Array1OfRaggedShape() = default; + /* Constructor. Args: - srcs: pointers to the source shapes, a CPU pointer - num_srcs: the number of source shapes. All shapes must have the - same NumAxes() and must be on the same device. + srcs: pointers to the source shapes, a CPU pointer + num_srcs: the number of source shapes. All shapes must have the + same NumAxes() and must be on the same device. TODO: we'll likely, later, add optional args which dictate which of the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should enable us to save kernels by combining certain operations across the axes. + */ - Array1OfRaggedShape(RaggedShape *srcs, int32_t num_srcs); - Array1OfRaggedShape() = default; + Array1OfRaggedShape(RaggedShape *srcs, + int32_t num_srcs); + int32_t NumSrcs() const { return num_srcs_; } int32_t NumAxes() const { return num_axes_; } @@ -63,23 +81,35 @@ class Array1OfRaggedShape { // Returns device-accessible vector of row-splits for a particular // axis, indexed by 0 <= src < num_srcs. 
const int32_t **RowSplits(int32_t axis) {
-    return row_splits_.Row(axis - 1).Data();
+    K2_CHECK_LT(static_cast(axis),
+                static_cast(num_axes_));
+    return row_splits_.Row(axis - 1).Data();
   }

   // Returns device-accessible array of row-ids for the individual shapes
   // indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this
   // Array2 is [NumAxes() - 1][NumSrcs()].
-  const Array2 *RowIds() const { return &row_ids_; }
+  const Array2 *RowIds() const { return &row_ids_; }
+
   // Returns device-accessible vector of row-ids for a particular
   // axis, indexed by 0 <= src < num_srcs.
-  const int32_t **RowIds(int32_t axis) { return row_ids_.Row(axis - 1).Data(); }
+  const int32_t **RowIds(int32_t axis) {
+    K2_CHECK_LT(static_cast(axis),
+                static_cast(num_axes_));
+    return row_ids_.Row(axis - 1).Data();
+  }
+
   /*
     Return the total size on this axis, which is the sum of the TotSize() of
     the individual shapes. Requires 0 <= axis < NumAxes() and
     for axis=0 the returned value is the same as Dim0().
   */
-  int32_t TotSize(int32_t axis) const { return tot_sizes_[axis]; }
+  int32_t TotSize(int32_t axis) const {
+    K2_CHECK_LT(static_cast(axis),
+                static_cast(num_axes_));
+    return tot_sizes_[axis];
+  }

   // equivalent to TotSize(0).
   int32_t Dim0() const { return TotSize(0); }
@@ -88,7 +118,7 @@ class Array1OfRaggedShape {
      along the src axis, of the tot-sizes of the individual arrays.
      This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src];
      caution, the indexing is different from RowSplits(), there is no offset.
-     Also, the meta_row_splits0 is a thing, unlike with regular row-splits
+     Also, the meta_row_splits_ is a thing, unlike with regular row-splits
      which start from 1.

      Caution: the lengths of the arrays pointed to by the elements of this
@@ -99,38 +129,47 @@
     to GPU, this will be faster than invoking an extra kernel in normal
     cases when the NumSrcs() is small. [Also: see GetRowInfoMulti()].
   */
-  // TODO: implement it...
-  Array2 MetaRowSplits();
+  const Array2 &MetaRowSplits() const { return meta_row_splits_; }

   // could POSSIBLY add this so this code could be used in functions like
   // Stack(). would be like MetaRowSplits but with an extra 1st row containing
   // 0,1,2,... We could perhaps create it with 1 extra initial row so this is
   // always convenient to output.
-  // TODO: implement it...
-  Array2 Offsets();
+  const Array2 &Offsets() const { return offsets_; }

   /*
-    Returns the meta-row-splits for a particular axis, with 0 <= axis <
-    NumAxes(); this is the cumulative sum of the TotSize(axis) for all of the
-    sources, with MetaRowSplits(axis).Dim() == NumSrcs() + 1.
+    Returns the meta-row-splits for a particular axis, with
+    0 <= axis < NumAxes();
+    this is the cumulative sum of the TotSize(axis) for all of the sources,
+    with MetaRowSplits(axis).Dim() == NumSrcs() + 1.

-    Note: in ragged_ops.cu we refer to this as composed_row_splits
+    Note: in ragged_ops.cu we refer to this as composed_row_splits
   */
-  // TODO: implement it...
-  Array1 MetaRowSplits(int32_t axis);
+  Array1 MetaRowSplits(int32_t axis) {
+    K2_CHECK_LT(static_cast(axis),
+                static_cast(num_axes_));
+    return meta_row_splits_.Row(axis);
+  }

   /*
     Return the device-accessible meta-row-ids, which are the row-ids
     corresponding to MetaRowSplits(); this tells us, for indexes into the
-    appended/concatenated array, which source array they belong to, i.e.
-    elements are in [0,NumSrcs()-1].
+    appended/concatenated array, which source array they belong to,
+    i.e. elements are in [0,NumSrcs()-1].
This cannot be an Array2 because unlike the MetaRowSplits(), all the row-ids arrays are of different lengths. Note: in ragged_ops.cu we refer to this as composed_row_ids. */ - // TODO: implement it... - Array1 MetaRowIds(); + Array1 MetaRowIds() { + Array1 ans(GetCpuContext(), num_axes_); + const int32_t* *ans_data = ans.Data(); + for (int32_t i = 0; i < num_axes_; ++i) { + ans_data[i] = meta_row_ids_[i].Data(); + } + ans = ans.To(c_); + return ans; + } /* Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes(); @@ -140,18 +179,28 @@ class Array1OfRaggedShape { would tell us which source an idx012 with value 100 into axis 2 of concatenated array would come from. */ - // TODO: implement it... - Array1 MetaRowIds(int32_t axis); + const Array1 &MetaRowIds(int32_t axis) const { + K2_CHECK_LT(static_cast(axis), + static_cast(num_axes_)); + return meta_row_ids_[axis]; + } private: ContextPtr c_; int32_t num_srcs_; int32_t num_axes_; + Array2 row_splits_; // shape [num_axes_ - 1][num_srcs_] Array2 row_ids_; // shape [num_axes_ - 1][num_srcs_] - Array1 tot_sizes_; // dim num_axes_, this is on CPU + Array1 tot_sizes_; // dim num_axes_ + + Array2 meta_row_splits_; // shape [num_axes_][num_srcs_ + 1] + Array2 offsets_; // shape [num_axes_][num_srcs_ + 1] + std::vector > meta_row_ids_; // dim num_axes_ }; + + /* Array1OfRagged is a 1-dimensional array of Ragged. It is intended for situations where you want to do some operations on @@ -171,17 +220,14 @@ struct Array1OfRagged { int32_t NumSrcs() const { return values.Dim(); } ContextPtr &Context() { return shape.Context(); } + // Default constructor will not leave this a valid Array1OfRagged object, + // you shouldn't do anything with it. Both members will be initialized with + // default constructors. Array1OfRagged() = default; - /* - Constructor. - Args: - srcs: pointers to the source ragged tensors, a CPU pointer - num_srcs: the number of source ragged tensors. All ragged tensors must - have the same NumAxes() and must be on the same device. - */ + // The 'srcs' should have the same number of axes. 
Array1OfRagged(Ragged *srcs, int32_t num_srcs) { - K2_CHECK_GE(num_srcs, 1); + K2_CHECK_GT(num_srcs, 0); K2_CHECK(srcs); values = Array1(GetCpuContext(), num_srcs); T **values_data = values.Data(); @@ -195,6 +241,7 @@ struct Array1OfRagged { } }; + } // namespace k2 #endif // K2_CSRC_ARRAY_OF_RAGGED_H_ diff --git a/k2/csrc/array_of_ragged_test.cu b/k2/csrc/array_of_ragged_test.cu index 69b482315..4cb48bdb6 100644 --- a/k2/csrc/array_of_ragged_test.cu +++ b/k2/csrc/array_of_ragged_test.cu @@ -43,6 +43,7 @@ void TestArray1OfRaggedConstruct() { for (int32_t j = 1; j < num_axes; ++j) { const int32_t **row_splits = array_of_ragged.shape.RowSplits(j); const int32_t **row_ids = array_of_ragged.shape.RowIds(j); + Array1 expected_row_splits(GetCpuContext(), num_srcs); Array1 expected_row_ids(GetCpuContext(), num_srcs); int32_t **expected_row_splits_data = expected_row_splits.Data(); @@ -55,6 +56,7 @@ void TestArray1OfRaggedConstruct() { expected_row_ids = expected_row_ids.To(c); expected_row_splits_data = expected_row_splits.Data(); expected_row_ids_data = expected_row_ids.Data(); + Array1 flags(c, 2, 1); int32_t *flags_data = flags.Data(); K2_EVAL( @@ -67,6 +69,38 @@ void TestArray1OfRaggedConstruct() { for (int32_t i = 0; i < num_srcs; ++i) { K2_CHECK_EQ(array_of_ragged.values[i], raggeds[i].values.Data()); } + + for (int32_t j = 0; j < num_axes; ++j) { + Array1 meta_row_splits(array_of_ragged.shape.MetaRowSplits(j)); + Array1 meta_row_ids(array_of_ragged.shape.MetaRowIds(j)); + Array1 offsets( + array_of_ragged.shape.Offsets().RowArange(j + 1, j + 2).Row(0)); + + Array1 expected_meta_row_splits(GetCpuContext(), num_srcs + 1); + int32_t *expected_meta_row_splits_data = expected_meta_row_splits.Data(); + for (int32_t i = 0; i < num_srcs; ++i) { + expected_meta_row_splits_data[i] = raggeds[i].TotSize(j); + } + ExclusiveSum(expected_meta_row_splits, &expected_meta_row_splits); + expected_meta_row_splits = expected_meta_row_splits.To(c); + Array1 expected_meta_row_ids(c, + array_of_ragged.shape.TotSize(j)); + RowSplitsToRowIds(expected_meta_row_splits, &expected_meta_row_ids); + + K2_CHECK(Equal(meta_row_splits, expected_meta_row_splits)); + K2_CHECK(Equal(meta_row_ids, expected_meta_row_ids)); + K2_CHECK(Equal(offsets, expected_meta_row_splits)); + } + + Array1 expected_offsets_1st_row(GetCpuContext(), num_srcs + 1); + int32_t *expected_offsets_1st_row_data = expected_offsets_1st_row.Data(); + for (int32_t i = 0; i <= num_srcs; ++i) { + expected_offsets_1st_row_data[i] = i; + } + expected_offsets_1st_row = expected_offsets_1st_row.To(c); + Array1 offsets_1st_row( + array_of_ragged.shape.Offsets().RowArange(0, 1).Row(0)); + K2_CHECK(Equal(offsets_1st_row, expected_offsets_1st_row)); } } From 1b29f0a946f50186aaa82df46a59f492ade9692b Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 13 Apr 2022 08:46:49 +0800 Subject: [PATCH 64/64] Fix precision (#951) * Fix precision * Using different pow version for windows and *nix * Use int64_t pow * Minor fixes --- k2/csrc/math.h | 24 +++++++++++++++++++----- k2/csrc/rnnt_decode.cu | 6 +++--- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/k2/csrc/math.h b/k2/csrc/math.h index 65b6f8e91..3ebc8b406 100644 --- a/k2/csrc/math.h +++ b/k2/csrc/math.h @@ -27,6 +27,20 @@ namespace k2 { +// Currently, only used in k2/csrc/rnnt_decode.cu +// See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842 +__host__ __device__ __forceinline__ int64_t Pow(int64_t base, + int64_t exponent) { + K2_CHECK_GE(exponent, 0); + int64_t exp = 0; + int64_t result = 1; 
+ while (exp < exponent) { + result *= base; + exp++; + } + return result; +} + /* Returns index of highest bit set, in range -1..30. HighestBitSet(0) = -1, @@ -106,29 +120,29 @@ int32_t RandIntGeometric(int32_t min, int32_t max); type, but for types float and double it "fixes" the broken behavior of the C++ standard w.r.t. infinity allowing infinities to be parsed. */ -template struct InputFixer { +template +struct InputFixer { T t; // cast operator operator T() const { return t; } }; - namespace internal { template Real FixedRead(std::istream &is); } template -inline std::istream &operator >>(std::istream &is, InputFixer &f) { +inline std::istream &operator>>(std::istream &is, InputFixer &f) { return is >> f.t; } template <> -inline std::istream &operator >>(std::istream &is, InputFixer &f) { +inline std::istream &operator>>(std::istream &is, InputFixer &f) { f.t = internal::FixedRead(is); return is; } template <> -inline std::istream &operator >>(std::istream &is, InputFixer &f) { +inline std::istream &operator>>(std::istream &is, InputFixer &f) { f.t = internal::FixedRead(is); return is; } diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index d5fe89432..e86b2f7d5 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -159,8 +159,8 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape, int64_t state_value = states_values_data[state_idx01x], context_state = state_value / num_graph_states, exp = decoder_history_len - col, - state = context_state % (int64_t)powf(vocab_size, exp); - state = state / (int64_t)powf(vocab_size, exp - 1); + state = context_state % Pow(vocab_size, exp); + state = state / Pow(vocab_size, exp - 1); contexts_acc(row, col) = state; }); } @@ -540,7 +540,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { // can be done with `358 % 10^2`, then we append 6 to 58, that can be // done with `58 * 10 + 6`. context_state = this_context_state % - (int64_t)powf(vocab_size, decoder_history_len - 1); + Pow(vocab_size, decoder_history_len - 1); context_state = context_state * vocab_size + arc.label; }
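A note on why the integer Pow() above matters: powf() computes in single
precision, and a float32 significand has only 24 bits, so large integer
powers (such as vocab_size raised to the decoder history length when
packing context states) get rounded to a neighboring representable value
before the cast back to int64_t. A small sketch of the failure mode
(assumes NumPy is available; the vocabulary size and history length are
hypothetical values, not taken from the patch):

    import numpy as np

    def int_pow(base: int, exponent: int) -> int:
        # Mirrors the exact integer Pow() added to k2/csrc/math.h above.
        result = 1
        for _ in range(exponent):
            result *= base
        return result

    vocab_size, history_len = 500, 4
    exact = int_pow(vocab_size, history_len)  # 62_500_000_000
    rounded = int(np.float32(vocab_size) ** np.float32(history_len))
    # 62_500_000_000 needs more than 24 significand bits, so the float32
    # result is a nearby multiple of 4096 rather than the exact power.
    print(exact, rounded, exact == rounded)  # ... False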