Skip to content

Commit

Permalink
Bunch of changes... (k2-fsa#99)
Browse files Browse the repository at this point in the history
* Various progress..

* Further implementation  progress...

* Further progress...

* Various progress (in the middle of something...

* Partial work...
  • Loading branch information
danpovey authored Sep 1, 2020
1 parent d6d947e commit 25d73ef
Show file tree
Hide file tree
Showing 18 changed files with 2,096 additions and 189 deletions.
35 changes: 30 additions & 5 deletions k2/csrc/cuda/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,39 @@ class Renumbering {
Array1<char> &Keep(); // dim is NumOldElems(). 0 if not kept, 1 if kept
// (user will write to here).

Array1<int32_t> &New2Old(); // dim is NumNewElems().

Array1<int32_t> &Old2New(); // dim is NumOldElems(). exclusive-sum of Keep().
/* Return a mapping from new index to old index. This is created on
   demand (must only be called after the Keep() array has been populated).
     @param include_final_value  If true, the result contains one extra
                  final element whose value is NumOldElems(), following the
                  entries for the NumNewElems() new elements of the
                  renumbering.  If false, that last element is omitted.
                  NOTE(review): presumably the dimension is therefore
                  NumNewElems() + 1 when true and NumNewElems() when false
                  -- confirm against the implementation.
     @return  Returns an array mapping the new indexes to the old
                  (pre-renumbering) indexes.
*/
Array1<int32_t> New2Old(bool include_final_value = true);


/* Return a mapping from old index to new index (this is the exclusive-sum of
   `Keep()`). This is created on demand (must only be called after the Keep()
   array has been populated).
     @param include_final_value  If true, the result includes one extra final
                  element after the entries for the NumOldElems() old
                  elements.  If false, that last element is omitted.
                  NOTE(review): since this is an exclusive-sum of Keep(), the
                  final element is presumably the number of kept elements
                  (NumNewElems()) and the dimension NumOldElems() + 1 -- the
                  previous wording here looked copied from New2Old(); confirm
                  against the implementation.
     @return  Returns an array mapping the old (pre-renumbering) indexes to
                  the new indexes.
*/
Array1<int32_t> Old2New(bool include_final_value = true);

private:
Array1<char> kept;
Array1<int32_t> new2old;
Array1<int32_t> old2new;
Array1<char> keep_;
Array1<int32_t> new2old_;
Array1<int32_t> old2new_;

};

Expand Down
78 changes: 71 additions & 7 deletions k2/csrc/cuda/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace k2 {
template <typename T>
class Array1 {
public:
int32_t Dim() const { return size_; } // dimension of only axis (axis 0)
int32_t Dim() const { return dim_; } // dimension of only axis (axis 0)

// Returns pointer to 1st elem. Could be a GPU or CPU pointer,
// depending on the context.
Expand Down Expand Up @@ -46,6 +46,19 @@ class Array1 {

// Creates an array of `size` elements associated with context `ctx` by
// delegating to Init().  This constructor itself does not assign any element
// values.
Array1(ContextPtr ctx, int32_t size) { Init(ctx, size); }


// Creates an array that is not valid, e.g. you cannot call Context() on it.
// The dimension and byte offset are zero and `region_` is null.
Array1(): dim_(0), byte_offset_(0), region_(nullptr) { }

// Creates an array of `size` elements in context `ctx` and sets every
// element to `elem`.
Array1(ContextPtr ctx, int32_t size, T elem) {
  Init(ctx, size);
  T *dst = Data();
  // `dst` and `elem` are captured by value.
  auto fill_with_elem = [=] __host__ __device__ (int32_t idx) -> void {
    dst[idx] = elem;
  };
  Eval(ctx, dim_, fill_with_elem);
}

/* Return sub-part of this array
@param [in] start First element to cover, 0 <= start < size()
@param [in] size Number of elements to include, 0 < size < size()-start
Expand Down Expand Up @@ -96,6 +109,16 @@ class Array1 {
pointer. */
T operator[](int32_t i);


/* Assigns the scalar `t` to every element of this array. */
void operator = (const T t) {
  T *dst = Data();
  // `dst` and `t` are captured by value.
  auto assign_scalar = [=] __host__ __device__ (int32_t idx) -> void {
    dst[idx] = t;
  };
  Eval(Context(), dim_, assign_scalar);
}

Array1 operator[](const Array1<int32_t> &indexes) {
ContextPtr c = Context();
assert(c->IsCompatible(*indexes.Context()));
Expand All @@ -111,22 +134,50 @@ class Array1 {
return ans;
}

// constructor from CPU array (transfers to GPU if necessary)
Array1(ContextPtr ctx, const std::vector<T> &src);

Array1(const Array1 &other) = default;
private:
int32_t size_;
int32_t dim_;
int32_t byte_offset_;
RegionPtr region_; // Region that `data` is a part of. Device
// type is stored here. For an Array1 with
// zero size (e.g. created using empty
// constructor), will point to an empty
// Region.
RegionPtr region_; // Region that `data` is a part of. Device type is stored
// here. Will be NULL if Array1 was created with default
// constructor (invalid array!) but may still be non-NULL
// if dim_ == 0; this allows it to keep track of the
// context.

// Intended to set this array up to hold `size` elements (allocation etc.);
// the body is still a placeholder here.
// NOTE(review): the constructors in this class call Init(ctx, size) with a
// ContextPtr, while this signature takes a DeviceType -- confirm which one is
// intended.
void Init(DeviceType d, int32_t size) {
// .. takes care of allocation etc.
}
};

// Lightweight mutable accessor for 2-D data laid out with a fixed element
// stride between rows: element (i, j) is data[i * elem_stride0 + j].
// No bounds checking is performed.
// Could possibly introduce a debug mode to this that would do bounds checking.
template <typename T>
struct Array2Accessor {
  T *data;               // pointer to element (0, 0)
  int32_t elem_stride0;  // offset, in elements, from row i to row i + 1

  // Returns a mutable reference to element (i, j).  Declared const because it
  // does not modify the accessor itself (only what `data` points to); this
  // lets it be called on const accessors, including copies captured by value
  // in non-mutable lambdas.
  __host__ __device__ T &operator () (int32_t i, int32_t j) const {
    return data[i * elem_stride0 + j];
  }

  // Wraps `data`, whose rows are `elem_stride0` elements apart.
  __host__ __device__ Array2Accessor(T *data, int32_t elem_stride0):
      data(data), elem_stride0(elem_stride0) { }
  __host__ __device__ Array2Accessor(const Array2Accessor &other) = default;
};

// Read-only counterpart of Array2Accessor: element (i, j) is
// data[i * elem_stride0 + j], returned by value.  No bounds checking is
// performed.
template <typename T>
struct ConstArray2Accessor {
  const T *data;         // pointer to element (0, 0)
  int32_t elem_stride0;  // offset, in elements, from row i to row i + 1

  // Returns a copy of element (i, j).  Const so it can be called on const
  // accessors, including copies captured by value in non-mutable lambdas.
  __host__ __device__ T operator () (int32_t i, int32_t j) const {
    return data[i * elem_stride0 + j];
  }

  // Wraps `data`, whose rows are `elem_stride0` elements apart.
  __host__ __device__ ConstArray2Accessor(const T *data, int32_t elem_stride0):
      data(data), elem_stride0(elem_stride0) { }
  __host__ __device__ ConstArray2Accessor(const ConstArray2Accessor &other) = default;
};


/*
Array2 is a 2-dimensional array (== matrix), that is contiguous in the
2nd dimension, i.e. a row-major matrix.
Expand Down Expand Up @@ -178,6 +229,16 @@ class Array2 {
byte_offset_);
}

// Returns a mutable accessor built from this array's Data() pointer and
// elem_stride0_.
// Note: array1 doesn't need an accessor because its Data() pointer functions
// as one already.
Array2Accessor<T> Accessor() {
  Array2Accessor<T> accessor(Data(), elem_stride0_);
  return accessor;
}

// Returns a read-only accessor built from this array's Data() pointer and
// elem_stride0_.  Constructs a ConstArray2Accessor so the returned object
// matches the declared return type.
ConstArray2Accessor<T> Accessor() const {
  return ConstArray2Accessor<T>(Data(), elem_stride0_);
}

/* Construct from Tensor. Required to have 2 axes; will copy if the tensor
did not have stride on 2nd axis == sizeof(T)
@param [in] t Input tensor, must have 2 axes and dtype == T
Expand Down Expand Up @@ -205,6 +266,9 @@ class Array2 {
// Region.
};




} // namespace k2

#endif // K2_CSRC_CUDA_ARRAY_H_
14 changes: 7 additions & 7 deletions k2/csrc/cuda/compose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ class MultiGraphDenseIntersect {
}

void FormatOutput(FsaVec *ofsa,
Array<int32_t> *arc_map_a,
Array<int32_t> *arc_map_b) {
Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b) {

Context c_cpu = c_->CpuContext();
int32_t T = a_fsas.MaxSize1();
Expand Down Expand Up @@ -173,11 +173,11 @@ class MultiGraphDenseIntersect {


int32_t tot_arcs_pruned = oshape_pruned_.TotSize3();
arc_map_a = Array<int32_t>(c_, tot_arcs_pruned);
arc_map_b = Array<int32_t>(c_, tot_arcs_pruned);
arc_map_a = Array1<int32_t>(c_, tot_arcs_pruned);
arc_map_b = Array1<int32_t>(c_, tot_arcs_pruned);
int32_t *arc_map_a_data = arc_map_a.Data(),
*arc_map_b_data = arc_map_b.Data();
Array<Arc> arcs_out(c_, tot_arcs_pruned);
Array1<Arc> arcs_out(c_, tot_arcs_pruned);
Arc *arcs_out_data = arcs.Data();
const Arc *a_fsas_arcs = a_fsas_.values.Data();
int32_t b_fsas_num_cols = b_fsas_.scores.Dim1();
Expand Down Expand Up @@ -373,7 +373,7 @@ class MultiGraphDenseIntersect {
return num_arcs_x1x;
};
// `num_arcs` gives the num-arcs for each state in `states`.
Array<int32_t> num_arcs(c_, states.values.Dim(), num_arcs_lambda);
Array1<int32_t> num_arcs(c_, states.values.Dim(), num_arcs_lambda);

// initialize shape of array that will hold arcs leaving the active states.
// Its shape is [fsa_index][state][arc]; the top two levels are shared with
Expand Down Expand Up @@ -662,7 +662,7 @@ class MultiGraphDenseIntersect {
arc. Indexing is [frame_state_index][arc_index], where frame_state_index
and arc_index are respectively idx01 and ind2 w.r.t. frames_[t]->arcs. */
Ragged<int32_t> arc_backward_prob(cur_frame->arcs.RemoveAxis(0),
Array<int32_t>(c_, num_arcs));
Array1<int32_t>(c_, num_arcs));
int32_t *arc_backward_prob_data = arc_backward_prob.values.Data();


Expand Down
8 changes: 4 additions & 4 deletions k2/csrc/cuda/compose.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace k2 {

// Note: b is FsaVec<Arc>.
void Intersect(const DenseFsa &a, const FsaVec &b, Fsa *c,
Array<int32_t> *arc_map_a = nullptr,
Array<int32_t> *arc_map_b = nullptr);
Array1<int32_t> *arc_map_a = nullptr,
Array1<int32_t> *arc_map_b = nullptr);



Expand All @@ -29,8 +29,8 @@ void IntersectDensePruned(Array3<Arc> &a_fsas,
float beam,
int32_t max_states,
FsaVec *ofsa,
Array<int> *arc_map_a,
Array<int> *arc_map_b);
Array1<int> *arc_map_a,
Array1<int> *arc_map_b);

} // namespace k2

Expand Down
Loading

0 comments on commit 25d73ef

Please sign in to comment.