Merged PR 12959: minor fixes from my old ONNX code
These are minor comments/fixes I found while doing my ONNX prototype; it would be good to get them out of the way.
frankseide committed May 16, 2020
1 parent 1d2a137 commit 9842e7b
Showing 6 changed files with 37 additions and 28 deletions.
4 changes: 4 additions & 0 deletions src/common/io.cpp
@@ -136,6 +136,10 @@ void saveItemsNpz(const std::string& fileName, const std::vector<Item>& items) {
     type = cnpy::map_type(typeid(double));
   else if(item.type == Type::int8)
     type = cnpy::map_type(typeid(char));
+  else if(item.type == Type::int32)
+    type = cnpy::map_type(typeid(int32_t));
+  else if (item.type == Type::uint32)
+    type = cnpy::map_type(typeid(uint32_t));
   else
     ABORT("Other types not supported yet");

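For context, the change above extends the dtype dispatch in saveItemsNpz to 32-bit integer tensors. A minimal sketch of the same dispatch pattern (the TypeTag enum and dtypeFor helper below are hypothetical stand-ins for Marian's Type enum; only cnpy::map_type is the real cnpy API):

#include "cnpy.h"  // cnpy::map_type
#include <cstdint>
#include <typeinfo>

// Hypothetical stand-in for Marian's Type enum, for illustration only.
enum class TypeTag { float32, float64, int8, int32, uint32 };

// Map a tensor element type to the NumPy dtype character cnpy expects.
char dtypeFor(TypeTag t) {
  switch(t) {
    case TypeTag::float32: return cnpy::map_type(typeid(float));
    case TypeTag::float64: return cnpy::map_type(typeid(double));
    case TypeTag::int8:    return cnpy::map_type(typeid(char));
    case TypeTag::int32:   return cnpy::map_type(typeid(int32_t));  // added above
    case TypeTag::uint32:  return cnpy::map_type(typeid(uint32_t)); // added above
  }
  return '?'; // unreachable; the real code aborts on unsupported types
}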
10 changes: 4 additions & 6 deletions src/common/shape.h
@@ -64,15 +64,13 @@ struct Shape {
   inline int& dim(int i) {
     if(i >= 0) {
       ABORT_IF(i >= (int)size(),
-               "Index {} is out of bounds, shape has {} dimension",
-               i,
-               size());
+               "Index {} is out of bounds, shape {} has {} dimension",
+               i, std::string(*this), size());
       return shape_[i];
     } else {
       ABORT_IF((int)size() + i < 0,
-               "Negative index {} is out of bounds, shape has {} dimension",
-               i,
-               size());
+               "Negative index {} is out of bounds, shape {} has {} dimension",
+               i, std::string(*this), size());
       return shape_[size() + i];
     }
   }
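The accessor above implements NumPy-style negative indexing: dim(-1) addresses the last axis, and the improved messages now print the offending shape itself. A self-contained sketch of the indexing rule, under that reading (normalizeAxis is an illustrative helper, not Marian's API):

#include <cassert>
#include <vector>

// Normalize a possibly-negative axis index against a shape's rank:
// non-negative i counts from the front, negative i from the back.
int normalizeAxis(int i, int rank) {
  return i >= 0 ? i : rank + i; // e.g. rank 4, i = -1 -> 3
}

int main() {
  std::vector<int> dims = {8, 16, 1, 512};   // a rank-4 shape
  assert(dims[normalizeAxis(-1, 4)] == 512); // last axis
  assert(dims[normalizeAxis(-4, 4)] == 8);   // first axis, still in bounds
  assert(dims[normalizeAxis(2, 4)]  == 1);
  return 0;
}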
2 changes: 1 addition & 1 deletion src/data/corpus_base.h
100755 → 100644
@@ -299,7 +299,7 @@ class CorpusBatch : public Batch {
   size_t wordsTrg() const override { return subBatches_.back()->batchWords(); };

   /**
-   * @brief The width of the target mini-batch. Num words + padded?
+   * @brief The target width (=max length) of the mini-batch.
    */
   size_t widthTrg() const override { return subBatches_.back()->batchWidth(); };

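A hedged illustration of what the corrected comment says (assuming batchWidth() is the padded maximum sentence length, distinct from the word count): for a target sub-batch with sentence lengths {3, 5, 2}, the width is 5.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustration only, not Marian's SubBatch: the width of a batch is the
// length of its longest sentence, i.e. the padded time dimension.
size_t batchWidth(const std::vector<std::vector<int>>& sentences) {
  size_t width = 0;
  for(const auto& s : sentences)
    width = std::max(width, s.size());
  return width; // {3, 5, 2} -> 5
}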
42 changes: 24 additions & 18 deletions src/graph/node_operators_binary.h
100755 → 100644
@@ -35,14 +35,14 @@ class DotNodeOp : public NaryNodeOp {

     auto shapeB = b->shape();
     if(transB) {
-      shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
+      shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]); // @TODO: why not use negative indices?
       shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
     }

     Shape outShape = shapeA;
     outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
     ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
-             "Matrix product requires inner dimensions to match");
+             "Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
     return outShape;
   }
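The improved message now prints both operand shapes and their transpose flags. A worked example of the shape inference it guards (dotShape is an illustrative helper, not Marian's API, and only the transB path shown in the hunk is modeled): A of shape [64, 512] times B of shape [1024, 512] with transB = true treats B as [512, 1024], the inner dimensions (512) match, and the result is [64, 1024].

#include <cassert>
#include <utility>
#include <vector>

// Sketch of DotNodeOp's output-shape rule: optionally swap B's last two
// dims, require A's last dim to equal B's second-to-last, then take A's
// shape with its last dim replaced by B's last dim.
std::vector<int> dotShape(std::vector<int> a, std::vector<int> b, bool transB) {
  if(transB)
    std::swap(b[b.size() - 2], b[b.size() - 1]);
  assert(a.back() == b[b.size() - 2] && "inner dimensions must match");
  a.back() = b.back();
  return a;
}

// dotShape({64, 512}, {1024, 512}, /*transB=*/true) -> {64, 1024}

The same check appears in AffineNodeOp and DotBatchedNodeOp below, so all three error messages were extended identically.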

@@ -187,7 +187,7 @@ class AffineNodeOp : public NaryNodeOp {
     Shape outShape = shapeA;
     outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
     ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
-             "Matrix product requires inner dimensions to match");
+             "Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
     return outShape;
   }

@@ -351,7 +351,7 @@ class DotBatchedNodeOp : public NaryNodeOp {
     Shape outShape = shapeA;
     outShape.set(-1, shapeB[-1]);
     ABORT_IF(shapeA[-1] != shapeB[-2],
-             "Batched matrix product requires inner dimensions to match");
+             "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
     return outShape;
   }

@@ -674,12 +674,12 @@ struct GatherNodeOp : public NaryNodeOp {

   NodeOps forwardOps() override {
     return {NodeOp(
-      Select(val_, child(0)->val(), child(1)->val(), axis_))};
+        Select(val_, child(0)->val(), child(1)->val(), axis_))};
   }

   NodeOps backwardOps() override {
     return {NodeOp(
-      Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
+        Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
   }

   Shape newShape(Expr a, int axis, Expr indices) {
@@ -722,6 +722,7 @@ struct GatherNodeOp : public NaryNodeOp {
     return true;
   }

+private:
   int axis_;
 };
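Select and Insert form the usual gather/scatter pair: the forward pass picks values by index along axis_, and the backward pass routes output gradients back into the picked positions (assuming Insert accumulates, which matters because indices may repeat). A 1-D sketch of the pairing, independent of Marian's kernels:

#include <cstddef>
#include <vector>

// Forward (gather): out[i] = in[idx[i]]
std::vector<float> gather(const std::vector<float>& in,
                          const std::vector<int>& idx) {
  std::vector<float> out(idx.size());
  for(size_t i = 0; i < idx.size(); ++i)
    out[i] = in[idx[i]];
  return out;
}

// Backward (scatter-add): gradIn[idx[i]] += gradOut[i]
// The += is essential: repeated indices must sum their gradients.
void scatterAdd(std::vector<float>& gradIn,
                const std::vector<float>& gradOut,
                const std::vector<int>& idx) {
  for(size_t i = 0; i < idx.size(); ++i)
    gradIn[idx[i]] += gradOut[i];
}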

@@ -817,7 +818,7 @@ struct MultNodeOp : public ElementBinaryNodeOp {
       NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))};
   }

-  const std::string type() override { return "×"; }
+  const std::string type() override { return "*"; }
 };

struct DivNodeOp : public ElementBinaryNodeOp {
@@ -842,7 +843,7 @@ struct DivNodeOp : public ElementBinaryNodeOp {
              child(1)->val()))};
   }

-  const std::string type() override { return "÷"; }
+  const std::string type() override { return "/"; }
 };

// struct PowNodeOp : public ElementBinaryNodeOp {
@@ -1047,19 +1048,19 @@ struct ConcatenateNodeOp : public NaryNodeOp {
     ABORT_IF(nodes.empty(), "No child nodes given");

     Shape shape = nodes[0]->shape();
-    ax_ = shape.axis(ax);
+    axis_ = shape.axis(ax);

     int sum = 0;
     auto checkShape = shape;
     for(auto child : nodes) {
-      checkShape.set(ax_, child->shape()[ax_]); // don't abort on different sizes on axis dim.
+      checkShape.set(axis_, child->shape()[axis_]); // don't abort on different sizes on axis dim.
       ABORT_IF(checkShape != child->shape(),
                "Child shapes {} and {} cannot be concatenated along axis {}",
                shape, child->shape(), ax);

-      sum += child->shape()[ax_];
+      sum += child->shape()[axis_];
     }
-    shape.set(ax_, sum);
+    shape.set(axis_, sum);

     return shape;
   }
@@ -1068,7 +1069,7 @@ struct ConcatenateNodeOp : public NaryNodeOp {
     std::vector<Tensor> concatenees;
     for(size_t i = 0; i < children_.size(); ++i)
       concatenees.push_back(child(i)->val());
-    Concatenate(val_, concatenees, ax_);
+    Concatenate(val_, concatenees, axis_);
   }

void backward() override {
@@ -1078,12 +1079,12 @@ struct ConcatenateNodeOp : public NaryNodeOp {
       childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
       deconcatenees.push_back(childPtr->grad());
     }
-    Deconcatenate(deconcatenees, adj_, ax_);
+    Deconcatenate(deconcatenees, adj_, axis_);
   }

   virtual size_t hash() override {
     size_t seed = NaryNodeOp::hash();
-    util::hash_combine(seed, ax_);
+    util::hash_combine(seed, axis_);
     return seed;
   }

@@ -1093,20 +1094,24 @@ struct ConcatenateNodeOp : public NaryNodeOp {
     auto cnode = std::dynamic_pointer_cast<ConcatenateNodeOp>(node);
     if(!cnode)
       return false;
-    if(ax_ != cnode->ax_)
+    if(axis_ != cnode->axis_)
       return false;
     return true;
   }

   const std::string type() override { return "concat"; }

-  int ax_;
+private:
+  int axis_;
 };

+// layer norm along last axis
 struct LayerNormalizationOp : public NaryNodeOp {
 public:
   LayerNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9)
-      : NaryNodeOp(nodes), eps_(eps) {}
+      : NaryNodeOp(nodes), eps_(eps) {
+    // @TODO: dimension check
+  }

   NodeOps forwardOps() override {
     return {NodeOp(
@@ -1117,6 +1122,7 @@ struct LayerNormalizationOp : public NaryNodeOp {
           eps_))};
   }

+  // @BUGBUG: backward has not been tested for broadcasting gamma/beta
   NodeOps backwardOps() override {
     return {NodeOp(
         LayerNormalizationGrad(
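For reference, layer normalization along the last axis computes y = gamma * (x - mean) / sqrt(var + eps) + beta per row, as commonly defined. A scalar-gamma/beta sketch (the real op takes tensor-valued gamma/beta, whose broadcasting in backward is exactly what the @BUGBUG above flags as untested):

#include <cmath>
#include <vector>

// Normalize one row in place: subtract the row mean, divide by the row
// standard deviation (stabilized by eps), then scale and shift.
void layerNormRow(std::vector<float>& x, float gamma, float beta,
                  float eps = 1e-9f) {
  float mean = 0.f, var = 0.f;
  for(float v : x) mean += v;
  mean /= x.size();
  for(float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  for(float& v : x)
    v = gamma * (v - mean) / std::sqrt(var + eps) + beta;
}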
3 changes: 2 additions & 1 deletion src/graph/node_operators_unary.h
100755 → 100644
@@ -682,7 +682,7 @@ struct NegNodeOp : public UnaryNodeOp {
     return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
   }

-  const std::string type() override { return "-"; }
+  const std::string type() override { return "negate"; }
 };

struct TransposeNodeOp : public UnaryNodeOp {
@@ -1002,6 +1002,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
     return true;
   }

+private:
   Shape shift_;    // shift offsets in each dimension
   float padValue_; // what value to shift in
 };
4 changes: 2 additions & 2 deletions src/models/decoder.h
100755 → 100644
@@ -47,9 +47,9 @@ class DecoderBase : public EncoderDecoderLayerBase {

     ABORT_IF(shortlist_, "How did a shortlist make it into training?");

-    auto yShifted = shift(y, {1, 0, 0});
+    auto yDelayed = shift(y, {1, 0, 0}); // insert zero at front; first word gets predicted from a target embedding of 0

-    state->setTargetHistoryEmbeddings(yShifted);
+    state->setTargetHistoryEmbeddings(yDelayed);
     state->setTargetMask(yMask);

     const Words& data = subBatch->data();
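The rename documents what the shift does for teacher forcing: per the added comment, shift(y, {1, 0, 0}) delays the target embeddings by one step along the time axis, padding with zero, so the prediction at step t conditions only on words before t and the first word is predicted from a zero embedding. A 1-D sketch of the delay under that reading (delayByOne is illustrative, not Marian's shift):

#include <cstddef>
#include <vector>

// Shift a sequence right by one along time, filling the front with zero
// (the pad value). delayByOne({e1, e2, e3}) -> {0, e1, e2}.
std::vector<float> delayByOne(const std::vector<float>& y) {
  std::vector<float> out(y.size(), 0.0f); // first step sees a zero "embedding"
  for(size_t t = 1; t < y.size(); ++t)
    out[t] = y[t - 1];
  return out;
}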
