Skip to content

Commit

Permalink
SubsetRagged & PruneRagged (#919)
Browse files Browse the repository at this point in the history
* Extend interface of SubsampleRagged.

* Add interface for pruning ragged tensor.

* Draft of new RNN-T decoding method

* Implements SubsampleRaggedShape

* Implements PruneRagged

* Rename subsample-> subset

* Minor fixes

* Fix comments

Co-authored-by: Daniel Povey <[email protected]>
  • Loading branch information
pkufool and danpovey authored Feb 20, 2022
1 parent 56edc82 commit 854b792
Show file tree
Hide file tree
Showing 8 changed files with 580 additions and 187 deletions.
31 changes: 26 additions & 5 deletions k2/csrc/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,48 @@ class Renumbering {
(pre-renumbering) indexes. Its dimension is the number of
new indexes (i.e. the number of 1 in keep_), but internally
it has one extra element which contains the number of old
elements, so it's OK to read one past the end. (We may
later make it possible to access the array with the one-larger
dimension).
elements, so it's OK to read one past the end.
*/
Array1<int32_t> &New2Old() {
NVTX_RANGE(K2_FUNC);
if (!new2old_.IsValid()) ComputeNew2Old();
return new2old_;
}

/* Return a mapping from new index to old index, with one extra element
containing the total number of kept elements if extra_element == true.
If Keep() can be interpreted as a tails vector, i.e. with 1 at the end
of sub-lists of elements, then New2Old(true) would corresponds to a
row-splits array and Old2New(false) would correspond to a row-ids
array.
*/
Array1<int32_t> New2Old(bool extra_element) {
Array1<int32_t> &new2old_part = New2Old();
if (!extra_element) {
return new2old_part;
} else {
// This is a little perverse, using low-level interfaces to increase the
// dimension of the array; but we know it does have one more element.
// Because we normally use New2Old() with no arg (equivalent to false),
// the overloaded version of this function returns a reference for
// efficiency.
return Array1<int32_t>(new2old_part.Dim() + 1,
new2old_part.GetRegion(), 0);
}
}

/* Return a mapping from old index to new index. This is created on demand
(must only be called after the Keep() array has been populated).
@param [in] extra_element If true, will return the array of size
NumOldElems() + 1, which includes one more element;
otherwise it will return an array of size NumOldElems().
@return Returns an array mapping the old indexes to the new indexes.
This array is just the exclusive sum of Keep().
It gives the mapping for indexes that are kept; element
i is kept if `Old2New()[i+1] > Old2New()[i]`.
@return Returns an array mapping the old indexes to the new indexes.
*/
Array1<int32_t> Old2New(bool extra_element = false) {
NVTX_RANGE(K2_FUNC);
Expand Down
9 changes: 9 additions & 0 deletions k2/csrc/algorithms_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ TEST(AlgorithmsTest, TestRenumbering) {
Array1<int32_t> new2old = numbering.New2Old();
EXPECT_EQ(new2old.Dim(), 0);
EXPECT_EQ(numbering.NumNewElems(), 0);
new2old = numbering.New2Old(true);
EXPECT_EQ(new2old.Dim(), 1);
EXPECT_EQ(new2old.Back(), 0);
}

{
Expand All @@ -67,6 +70,9 @@ TEST(AlgorithmsTest, TestRenumbering) {
Array1<int32_t> new2old = numbering.New2Old();
EXPECT_EQ(new2old.Dim(), 0);
EXPECT_EQ(numbering.NumNewElems(), 0);
new2old = numbering.New2Old(true);
EXPECT_EQ(new2old.Dim(), 1);
EXPECT_EQ(new2old.Back(), 5);
}

{
Expand All @@ -93,6 +99,9 @@ TEST(AlgorithmsTest, TestRenumbering) {
std::vector<int32_t> cpu_new2old(new2old.Data(),
new2old.Data() + new2old.Dim());
EXPECT_THAT(cpu_new2old, ::testing::ElementsAre(0, 2, 3, 6));
new2old = numbering.New2Old(true);
EXPECT_EQ(new2old.Dim(), 5);
EXPECT_EQ(new2old.Back(), 7);
}
}
}
Expand Down
25 changes: 10 additions & 15 deletions k2/csrc/ragged_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis,
if (axis == 0) {
return IndexAxis0(src, indexes, elem_indexes);
} else if (axis == src.NumAxes() - 1) {
// This code is related to SubsampleRaggedShape(). `indexes` corresponds
// This code is related to SubsetRaggedShape(). `indexes` corresponds
// to `new2old`.
Array1<int32_t> last_row_ids = src.RowIds(num_axes - 1)[indexes];
#ifndef NDEBUG
Expand Down Expand Up @@ -1944,21 +1944,16 @@ Ragged<int32_t> AddPrefixToRagged(Ragged<int32_t> &src,
return Ragged<int32_t>(dst_shape, dst_values);
}

RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering) {
RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering,
int32_t axis, Array1<int32_t> *elems_new2old) {
NVTX_RANGE(K2_FUNC);
K2_CHECK_EQ(renumbering.NumOldElems(), src.NumElements());

// Make sure final row-ids are populated.
src.RowIds(src.NumAxes() - 1);
std::vector<RaggedShapeLayer> axes = src.Layers();
axes.back().row_ids = axes.back().row_ids[renumbering.New2Old()];
axes.back().row_splits = renumbering.Old2New()[axes.back().row_splits];
axes.back().cached_tot_size = axes.back().row_ids.Dim();
return RaggedShape(axes);
axis = axis < 0 ? src.NumAxes() + axis : axis;
K2_CHECK_EQ(renumbering.NumOldElems(), src.TotSize(axis));
return Index(src, axis, renumbering.New2Old(), elems_new2old);
}

RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &r_before_last,
Renumbering &r_last) {
RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &r_before_last,
Renumbering &r_last) {
NVTX_RANGE(K2_FUNC);
K2_CHECK_EQ(r_before_last.NumOldElems(), src.TotSize(src.NumAxes() - 2));
K2_CHECK_EQ(r_last.NumOldElems(), src.NumElements());
Expand Down Expand Up @@ -2103,7 +2098,7 @@ RaggedShape RemoveEmptyLists(RaggedShape &src_shape, int32_t axis,
Renumbering r_temp;
if (!renumbering_out) renumbering_out = &r_temp;
bottom_shape = RemoveEmptyListsAxis0(bottom_shape, renumbering_out);
top_shape = SubsampleRaggedShape(top_shape, *renumbering_out);
top_shape = SubsetRaggedShape(top_shape, *renumbering_out);
return ComposeRaggedShapes(top_shape, bottom_shape);
}

Expand All @@ -2117,7 +2112,7 @@ RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, int32_t axis,
DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape);

bottom_shape = RenumberAxis0Simple(bottom_shape, renumbering);
top_shape = SubsampleRaggedShape(top_shape, renumbering);
top_shape = SubsetRaggedShape(top_shape, renumbering);
return ComposeRaggedShapes(top_shape, bottom_shape);
}

Expand Down
120 changes: 101 additions & 19 deletions k2/csrc/ragged_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -759,14 +759,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false,
int32_t max_num_elements = 2000);

/*
Return ragged shape with only a subset of the bottom-level elements kept.
Require renumbering.NumOldElems() == src.NumElements(). Note: all
dimensions and tot-sizes preceding the final axis will remain the same, which
might give rise to empty lists.
Return ragged shape with only a subset of the elements or sub-lists
on the specified axis kept. (This is not regular sampling, it is
irregular subsampling with specified elements kept).
@param [in] src The ragged shape that we are subsampling
@param [in] renumbering The renumbering object that dictates
which elements of `src` we keep; we require
renumbering.NumOldElems() == src.TotSize(axis2)
where axis2 = (axis < 0 ? src.NumAxes() + axis : axis).
@param [in] axis The axis to subsample; if negative, will be
interpreted as an offset from src.NumAxes().
@param [out] elems_new2old If supplied, this function will
output to this location a new2old vector that
dictates how the elements of a ragged tensor
with shape `src` would be renumbered.
@return Returns the subsampled shape. All dimensions and tot-sizes
preceding the axis `axis` will remain the same, which might give
rise to empty lists on those axes; these can be removed if
necessary with RemoveEmptyLists().
Notice the other version of this function below.
*/
RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering);
RaggedShape SubsetRaggedShape(RaggedShape &src,
Renumbering &renumbering,
int32_t axis = -1,
Array1<int32_t> *elems_new2old = nullptr);

/*
Return ragged shape with only a subset of the elements on the last
Expand All @@ -777,9 +795,9 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering);
Note: all dimensions and tot-sizes preceding the last two axes will remain the
same, which might give rise to empty lists.
*/
RaggedShape SubsampleRaggedShape(RaggedShape &src,
Renumbering &renumbering_before_last,
Renumbering &renumbering_last);
RaggedShape SubsetRaggedShape(RaggedShape &src,
Renumbering &renumbering_before_last,
Renumbering &renumbering_last);

/*
Removes empty lists on a particular axis (not last axis) of a RaggedShape,
Expand Down Expand Up @@ -866,17 +884,82 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape,


/*
Return ragged array with only a subset of the bottom-level elements kept.
Require renumbering.NumOldElems() == src.NumElements(). Note: all
dimensions and tot-sizes preceding the final axis will remain the same, which
might give rise to empty lists.
Return ragged array with only a subset of the elements or sub-lists
on the specified axis kept. (This is not regular sampling, it is
irregular subsampling with specified elements kept).
@param [in] src The ragged shape that we are subsampling
@param [in] renumbering The renumbering object that dictates
which elements of `src` we keep; we require
renumbering.NumOldElems() == src.TotSize(axis2)
where axis2 = (axis < 0 ? src.NumAxes() - axis : axis).
@param [in] axis The axis to subsample; if negative, will be
interpreted as an offset from src.NumAxes().
@param [out] elems_new2old If supplied, this function will
output to this location a new2old array that
dictates how the elements of a ragged tensor
with shape `src` would be renumbered.
@return Returns the subsampled shape. All dimensions and tot-sizes
preceding the axis `axis` will remain the same, which might give
rise to empty lists on those axes; these can be removed if
necessary with RemoveEmptyLists().
*/
template <typename T>
Ragged<T> SubsampleRagged(Ragged<T> &src, Renumbering &renumbering) {
return Ragged<T>(SubsampleRaggedShape(src.shape, renumbering),
src.values[renumbering.New2Old()]);
Ragged<T> SubsetRagged(Ragged<T> &src, Renumbering &renumbering,
int32_t axis = -1,
Array1<int32_t> *elems_new2old = nullptr) {
Array1<int32_t> tmp;
if (elems_new2old == nullptr)
elems_new2old = &tmp;
RaggedShape shape = SubsetRaggedShape(src.shape, renumbering,
axis, elems_new2old);
return Ragged<T>(shape, src.values[*elems_new2old]);
}

/*
This function creates a Renumbering object that can be used to obtain subsets
of ragged arrays via SubsetRaggedShape(). It implements beam pruning as
used in pruned Viterbi search and similar algorithms, where there is both a
beam and a max-active (`max_elems`) constraint. T will probably be float or
double, interpreted as a "positive-is-better" sense, i.e. as scores.
@param [in] src The ragged object to be subsampled.
@param [in] axis The axis to be subsampled, must satisfy
0 <= axis < src.NumAxes(). The axis before `axis`, if axis > 0,
will be interpreted as a "batch" axis.
@param [in] beam The main pruning beam. The sub-lists of elements on axis
`axis` will be removed if their maximum element (or the element
itself, if axis + 1 == src.NumAxes()) is less than
this_best_elem - beam, where this_best_elem
is the maximum element taken over axis `axis-1` (or over the
entire array, if axis == 0). Think of axis `axis-1`, if
axis > 0, as the "batch" axis, and axis `axis` as the axis that we
actually remove elements or sub-lists on. Empty sub-lists on axis
`axis` will always be pruned, as their score would be treated
as -infinity.
@param [in] max_elems If max_elems > 0, it is the maximum number of
sub-lists or elements that are allowed within any sub-list
on axis `axis-1` (or the maximum number of top-level sub-lists
after subsampling, if axis == 0). We keep the best ones.
If max_elems <= 0, there is no such constraint.
@return Returns the renumbering object to be used to actually
prune/subsample the specified axis.
Example:
PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3)
would create a Renumbering object that would prune the
ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ]
PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0)
would create a Renumbering object that would prune the
ragged tensor to [ [0 -1 -2 -3] ]
*/
template <typename T>
Renumbering PruneRagged(Ragged<T> &src,
int32_t axis,
T beam,
int32_t max_elems);

/*
Stack a list of Ragged arrays to create a Ragged array with one more axis.
Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All
Expand Down Expand Up @@ -974,8 +1057,7 @@ void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out,
/*
Concatenate a list of Ragged<T> to form a single Ragged<T>.
@param [in] axis Axis to append them on. Currently
we only support axis == 0 or axis == 1.
@param [in] axis Axis to append them on.
Previous axes must
have the same shape, i.e. if axis == 1
then `src[i]->Dim0()` must all have the
Expand Down Expand Up @@ -1368,7 +1450,7 @@ Ragged<T> Merge(int32_t num_srcs, Ragged<T> **src,
/*
Returns a ragged tensor after removing all 'values' that were <= a provided
cutoff. Leaves all layers of the shape except for the last one unaffected.
Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] <=
Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] <=
cutoff).
*/
template <typename T>
Expand All @@ -1377,7 +1459,7 @@ Ragged<T> RemoveValuesLeq(Ragged<T> &src, T cutoff);
/*
Returns a ragged tensor after removing all 'values' that equal a provided
target. Leaves all layers of the shape except for the last one unaffected.
Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] ==
Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] ==
target).
*/
template <typename T>
Expand Down
Loading

0 comments on commit 854b792

Please sign in to comment.