Skip to content

Commit

Permalink
Tensor ops implementation (k2-fsa#126)
Browse files Browse the repository at this point in the history
* Progress on ragged array implementation

* Progress on ragged array implementation

* Fix some syntax problems

* More ragged array progress

* Move to Fangjun's logger, clean up; ragged array progress

* Code to generate random ragged matrices, and print them.

* Changing Tensor to pimpl, and other progress

* Some more implementation of ragged, tensor.

* Implement more things...
  • Loading branch information
danpovey authored Sep 7, 2020
1 parent ffd21a4 commit 8d979dc
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 65 deletions.
6 changes: 4 additions & 2 deletions k2/csrc/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
namespace k2 {

// Table of per-dtype traits, indexed by the Dtype enum in dtype.h.
// NOTE: the order here must match the order of the Dtype enum
// (kFloatDtype, kDoubleDtype, kInt8Dtype, kInt16Dtype, ...).
DtypeTraits g_dtype_traits_array[] = {
    {kFloatBase, 4, "float"},  {kFloatBase, 8, "double"},
    {kIntBase, 1, "int8"},     {kIntBase, 2, "int16"},
    {kIntBase, 4, "int32"},    {kIntBase, 8, "int64"},
    {kUintBase, 4, "uint32"},  {kUintBase, 8, "uint64"}};

const Dtype DtypeOf<float>::dtype;
const Dtype DtypeOf<double>::dtype;
const Dtype DtypeOf<int8_t>::dtype;
const Dtype DtypeOf<int16_t>::dtype;
const Dtype DtypeOf<int32_t>::dtype;
const Dtype DtypeOf<int64_t>::dtype;
const Dtype DtypeOf<uint32_t>::dtype;
Expand Down
40 changes: 37 additions & 3 deletions k2/csrc/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define K2_CSRC_DTYPE_H_

#include <cstdint>
#include "k2/csrc/log.h"

namespace k2 {

Expand All @@ -28,12 +29,13 @@ class DtypeTraits {
int NumBytes() { return num_bytes_; }
BaseType GetBaseType() { return static_cast<BaseType>(base_type_); }

DtypeTraits(BaseType base_type, int num_bytes, int num_scalars = 1,
int misc = 0)
DtypeTraits(BaseType base_type, int num_bytes, const char *name,
int num_scalars = 1, int misc = 0)
: base_type_(static_cast<char>(base_type)),
num_scalars_(num_scalars),
misc_(misc),
num_bytes_(num_bytes) {}
num_bytes_(num_bytes),
name_(name) {}

private:
// We may add more
Expand All @@ -46,6 +48,7 @@ class DtypeTraits {
// scalar element is given by bytes_per_elem / num_scalars;
// we do it this way so that the stride in bytes is easily
// extractable.
const char *name_; // name, e.g. "float", "int8", "int32"
};

// We initialize this in dtype.cc
Expand All @@ -56,6 +59,7 @@ enum Dtype {
kFloatDtype,
kDoubleDtype,
kInt8Dtype,
kInt16Dtype,
kInt32Dtype,
kInt64Dtype,
kUint32Dtype,
Expand Down Expand Up @@ -84,6 +88,11 @@ struct DtypeOf<int8_t> {
static const Dtype dtype = kInt8Dtype;
};

template <>
struct DtypeOf<int16_t> {
static const Dtype dtype = kInt16Dtype;
};

template <>
struct DtypeOf<int32_t> {
static const Dtype dtype = kInt32Dtype;
Expand All @@ -104,5 +113,30 @@ struct DtypeOf<uint64_t> {
static const Dtype dtype = kUint64Dtype;
};

/*
  Evaluates Expr with TypeName bound (via `using`) to the C++ type
  corresponding to DtypeValue, for every supported dtype.  E.g.
     FOR_ALL_DTYPES(t.GetDtype(), T, SomeFuncCall<T>(a, b, c...));
  Fatal error if DtypeValue is not one of the known Dtype enum values.
*/
#define FOR_ALL_DTYPES(DtypeValue, TypeName, Expr)                           \
  do {                                                                       \
    switch (DtypeValue) {                                                    \
      case kFloatDtype: { using TypeName = float; Expr; break; }             \
      case kDoubleDtype: { using TypeName = double; Expr; break; }           \
      case kInt8Dtype: { using TypeName = int8_t; Expr; break; }             \
      case kInt16Dtype: { using TypeName = int16_t; Expr; break; }           \
      case kInt32Dtype: { using TypeName = int32_t; Expr; break; }           \
      case kInt64Dtype: { using TypeName = int64_t; Expr; break; }           \
      case kUint32Dtype: { using TypeName = uint32_t; Expr; break; }         \
      case kUint64Dtype: { using TypeName = uint64_t; Expr; break; }         \
      default:                                                               \
        K2_LOG(FATAL) << "Dtype " << static_cast<int>(DtypeValue)            \
                      << " not covered in switch statement; dtype not "      \
                         "supported here?";                                  \
    }                                                                        \
  } while (0)





} // namespace k2
#endif // K2_CSRC_DTYPE_H_
3 changes: 1 addition & 2 deletions k2/csrc/ragged.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ class RaggedShape {

// Convert to possibly different context.
RaggedShape<T> To(ContextPtr ctx);

private:
// TODO: could probably do away with the std::vector and have a max size and a
// fixed length array (more efficient)
Expand Down Expand Up @@ -241,7 +240,7 @@ class RaggedShapeIndexIterator {
and these are just the shapes of arrays..).
See also the version of Stack for class Ragged.
*/
RaggedShape Stack(int32_t src_size, const RaggedShape **src, int32_t axis);
RaggedShape Stack(int32_t axis, int32_t src_size, const RaggedShape **src);

/*
Insert a new axis at position `axis`, with 0 <= axis <= src.NumAxes(), for
Expand Down
21 changes: 10 additions & 11 deletions k2/csrc/ragged_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,25 @@
namespace k2 {

// Stacks `num_srcs` Ragged<T> arrays into a single Ragged<T> with one extra
// axis inserted at position `axis`.  Only axis == 0 is currently supported.
//
//  @param [in] num_srcs  Number of source arrays; must be > 0.
//  @param [in] src       Pointer to an array of `num_srcs` Ragged<T>.
//  @param [in] axis      Axis at which to stack; currently must be 0.
//  @return               The stacked ragged array.
template <typename T>
Ragged<T> Stack(int32_t num_srcs, const Ragged<T> *src, int32_t axis) {
  K2_CHECK_GT(num_srcs, 0);  // can later relax this, maybe
  std::vector<const RaggedShape *> src_shapes(num_srcs);
  std::vector<const Array1<T> *> src_values(num_srcs);
  for (int32_t i = 0; i < num_srcs; i++) {
    // `src` is an array of Ragged<T> objects, so member access uses `.`.
    src_shapes[i] = &(src[i].shape);
    src_values[i] = &(src[i].values);
  }
  // The RaggedShape overload is declared Stack(axis, src_size, src) in
  // ragged.h, so pass `axis` first.
  RaggedShape ans_shape = Stack(axis, num_srcs, src_shapes.data());
  Array1<T> ans_values;
  if (axis == 0) {
    // NOTE(review): assumes Append(count, const Array1<T> **) — confirm
    // against its declaration.
    ans_values = Append(num_srcs, src_values.data());
  } else {
    K2_LOG(FATAL) << "Axis != 0 not currently supported in Stack().";
  }
  return Ragged<T>(ans_shape, ans_values);
}



// Recursive function that prints (part of) a ragged shape.
// 0 <= begin_pos <= end_pos < shape.TotSize(axis).
template <typename T>
Expand Down
32 changes: 16 additions & 16 deletions k2/csrc/tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@

namespace k2 {

Shape::Shape(const std::vector<int32_t> &dims) : ndim_(dims.size()) {
CHECK_LT(ndim_, kMaxDim);
Shape::Shape(const std::vector<int32_t> &dims) : num_axes_(dims.size()) {
CHECK_LT(num_axes_, kMaxDim);

std::copy(dims.begin(), dims.end(), dims_);

// compute strides_
if (ndim_ > 0) {
strides_[ndim_ - 1] = 1;
if (num_axes_ > 0) {
strides_[num_axes_ - 1] = 1;
}
for (int32_t i = ndim_ - 2; i >= 0; --i) {
for (int32_t i = num_axes_ - 2; i >= 0; --i) {
strides_[i] = strides_[i + 1] * dims_[i + 1];
}

Expand All @@ -39,9 +39,9 @@ Shape::Shape(const std::vector<int32_t> &dims) : ndim_(dims.size()) {

Shape::Shape(const std::vector<int32_t> &dims,
const std::vector<int32_t> strides)
: ndim_(dims.size()) {
CHECK_LT(ndim_, kMaxDim);
CHECK_EQ(strides.size(), ndim_);
: num_axes_(dims.size()) {
CHECK_LT(num_axes_, kMaxDim);
CHECK_EQ(strides.size(), num_axes_);
std::copy(dims.begin(), dims.end(), dims_);
std::copy(strides.begin(), strides.end(), strides_);
num_element_ = ComputeNumElement();
Expand All @@ -50,30 +50,30 @@ Shape::Shape(const std::vector<int32_t> &dims,
}

// Returns the total number of elements, i.e. the product of all dims_.
// A Shape with zero axes is treated as having no elements.
int32_t Shape::ComputeNumElement() {
  if (num_axes_ == 0) {
    return 0;
  }
  int32_t elements = 1;
  for (int32_t i = 0; i < num_axes_; ++i) {
    elements *= dims_[i];
  }
  return elements;
}

// Returns the number of elements the underlying storage must span given the
// strides: 1 + sum_i (dims_[i] - 1) * strides_[i].  Zero for a 0-axis Shape.
int32_t Shape::ComputeStorageSize() {
  if (num_axes_ == 0) {
    return 0;
  }
  int32_t size = 1;
  for (int32_t i = 0; i < num_axes_; ++i) {
    size += (dims_[i] - 1) * strides_[i];
  }
  return size;
}

bool Shape::CheckContiguous() {
int32_t z = 1;
for (int32_t i = ndim_ - 1; i >= 0; --i) {
for (int32_t i = num_axes_ - 1; i >= 0; --i) {
CHECK_GE(strides_[i], z);
if (dims_[i] != 1) {
if (strides_[i] != z) return false;
Expand Down Expand Up @@ -102,11 +102,11 @@ Tensor::Tensor(Dtype type, const Shape &shape, RegionPtr region,
}

TensorPtr Tensor::Index(int32_t axis, int32_t index) const {
CHECK_LT(axis, shape_.Ndim());
CHECK_LT(axis, shape_.NumAxes());
CHECK_LT(index, shape_.Dim(axis));
std::vector<int32_t> dims(shape_.Dims(), shape_.Dims() + shape_.Ndim());
std::vector<int32_t> dims(shape_.Dims(), shape_.Dims() + shape_.NumAxes());
std::vector<int32_t> strides(shape_.Strides(),
shape_.Strides() + shape_.Ndim());
shape_.Strides() + shape_.NumAxes());
dims.erase(dims.begin() + axis);
strides.erase(strides.begin() + axis);
Shape shape(dims, strides);
Expand Down
46 changes: 38 additions & 8 deletions k2/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@
namespace k2 {
class Shape {
public:
int32_t Ndim() const { return ndim_; }
int32_t NumAxes() const { return num_axes_; }

const int32_t *Dims() const { return dims_; }

const int32_t *Strides() const { return strides_; }

int32_t Dim(int32_t i) const {
CHECK_LT(static_cast<uint32_t>(i), static_cast<uint32_t>(ndim_));
CHECK_LT(static_cast<uint32_t>(i), static_cast<uint32_t>(num_axes_));
return dims_[i];
}

int32_t Stride(int32_t i) const {
CHECK_LT(static_cast<uint32_t>(i), static_cast<uint32_t>(ndim_));
CHECK_LT(static_cast<uint32_t>(i), static_cast<uint32_t>(num_axes_));
return strides_[i];
}

Expand All @@ -47,7 +47,7 @@ class Shape {
// Returns true if the two shapes have the same dims (but not necessarily strides).
bool SameDims(Shape &other);

Shape() : ndim_(0), num_element_(0), is_contiguous_(true) {}
Shape() : num_axes_(0), num_element_(0), is_contiguous_(true) {}

explicit Shape(const std::vector<int32_t> &dims);

Expand All @@ -59,12 +59,12 @@ class Shape {
private:
static const int32_t kMaxDim = 4; // Will increase this as needed

int32_t ndim_; // Must be >= 0
int32_t num_axes_; // Must be >= 0
int32_t num_element_;
int32_t storage_size_;
bool is_contiguous_;

// elements of dims_ and strides_ >= ndim_ are currently not set;
// elements of dims_ and strides_ >= num_axes_ are currently not set;
// in future we may change this.
int32_t dims_[kMaxDim];
int32_t strides_[kMaxDim]; // Strides in elements
Expand Down Expand Up @@ -130,7 +130,7 @@ class Tensor {

// Return the result of indexing one of the axes, which will result in a
// Tensor with one fewer axis.
TensorPtr Index(int32_t axis, int32_t index) const;
Tensor Index(int32_t axis, int32_t index) const;

Dtype GetDtype() const { return impl_->dtype; }
const Shape &GetShape() const { return impl_->shape; }
Expand All @@ -139,7 +139,37 @@ class Tensor {

// Forward some funtions from the shape. Will forward more later.
inline bool SameDim(const Tensor &other) const { return other->impl_.shape.SameDim(shape); }
inline bool Ndim() const { return impl_->shape.Ndim(); }
inline bool NumAxes() const { return impl_->shape.NumAxes(); }
inline int32_t Dim(int32_t i) { return impl_->shape.Dim(i); }


/*
Convert to possibly-different context, may require CPU/GPU transfer.
The returned value may share the same underlying `data` memory as *this.
This should work even for tensors with empty data.
If dim_ == 0 and region_ is NULL, this will return a direct copy of *this
(i.e. with region_ also NULL)
If dim == 0 and region_ is non-NULL, it will return a copy of *this with an
empty region with the supplied context (if different from current region's
context).
Note: the answer will always be contiguous, i.e. there is a possibility that
it will have a different memory layout than the input. [Internally it will
call `Contiguous()`.
*/
Tensor To(ContextPtr ctx);


Dtype GetDtype() const { return impl_->dtype; }
const Shape &GetShape() const { return impl_->shape; }
int32_t ByteOffset() const { return impl_->bytes_offset; }
std::shared_ptr<Region> &GetRegion() { return impl_->data; }

// Forward some funtions from the shape. Will forward more later.
inline bool SameDim(const Tensor &other) const { return other->impl_.shape.SameDim(shape); }
inline bool NumAxes() const { return impl_->shape.NumAxes(); }
inline int32_t Dim(int32_t i) { return impl_->shape.Dim(i); }


Expand Down
Loading

0 comments on commit 8d979dc

Please sign in to comment.