Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aux labels plus notes on Python interface #29

Merged
merged 7 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions k2/csrc/determinize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
// k2/csrc/determinize.cc

// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey [email protected], Haowen Qiu [email protected])

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa_algo.h"

#include <utility>
#include <vector>

namespace k2 {


struct DetStateElement {
// Element of the doubly linked list whose start/end are
// members 'head' and 'tail' of DetState.
// We can trace back the `parent` links, which will take
// us backward along a path in the original FSA.
DetStateElement *parent = nullptr;
int32_t arc_index; // Index of most recent arc in path to the dest-state.
// This data-structure represents a path through the FSA,
// with this arc being the most recent arc on that path.
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
// (copied here for convenience).

double weight; // Weight from reference state to this state, along
// the path taken by following the 'parent' links
// (the path would have `seq_len` arcs in it).
// Note: by "this state" we mean the destination-state of
// the arc at `arc_index`.

// `prev` and `next` form the doubly linked list of DetStateElement
DetStateElement *prev = nullptr;
DetStateElement *next = nullptr;

// This comparator function compares the weights, but is careful in case of
// ties to ensure deterministic behavior.
bool operator < (const DetStateElement &other) const {
if (weight < other.weight) return true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the return value if weight == other.weight ??

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't realize I had included this code, it is not finished. I have to get back to this.

else if (weight > other.weight) return false;
// TODO.
}

};





/*
Conceptually a determinized state in weighted FSA determinization would normally
be a weighted subset of states in the input FSA, with the weights normalized
somehow (e.g. subtracting the sum of the weights).

Two determinized states are equal if the states and weights are the same. To
ensure differentiability, our assumption is that in general no two arcs in the
input FSA have identical weights. We argue that two determinized states can
always be represented as a base-state and a symbol sequence. Imagine that we
follow arcs with that symbol sequence from the base-state, and then in case we
reach the same states in the different ways we always select the best path
from the base-state. That process gives us a set of states and weights. We
argue that this representation is unique. (If not, it won't matter actually;
it will just give us an output that's less minimal than it could be).


*/
struct DetState {
// `base_state` is a state in the input FSA.
int32_t base_state;
// seq_len is the length of symbol sequence that we follow from state `base_state`.
// The sequence of symbols can be found by tracing back one of the DetStateElements
// in the doubly linked list (it doesn't matter which you pick, the result will be the
// same.
int32_t seq_len;

bool normalized { false };

DetState *parent; // Maybe not needed!

DetStateElement *head;
DetStateElement *tail;

double forward_backward_weight;

/*
Normalizes this DetState and sets forward_backward_weight.

By 'normalize' what we mean is the following:

- Remove duplicates.

If the DLL of DetStateElements contains duplicate elements (i.e.
elements whose paths end in the same state) it removes whichever has the
smallest weight. (Remember, a determinized state is, conceptually, a
weighted subset of elements; we are implementing determinization in a
tropical-like semiring where we take the best weight.

In case of ties on the weights, we carefully re-examine the paths to
make sure that the tie was not due to numerical roundoffi; and if it
was still a tie, we disambiguate using a lexical order on state
sequences. The reason it's important to have deterministic behavior in
case of ties on weights, is that a failure here could lead to
situations where we didn't advance the base state where we could,
leading the number of determinized states to be larger than it could
be.

- Advance the base state if possible. Each DetState can be represented
as a base state and a sequence of symbols from that base state, but
if some initial subsequence of that symbol sequence takes us to
a unique state then we say the DetState is not normalized. In that
case we need to advance the base state and reduced `seq_len`.
If this happens, then the arc sequence which takes us to the new
base state will be output to `leftover_arcs`. When this is done,
the 'weight' components of the DetStateElement members also need
to be adjusted to remove the weight contribution from those arcs.

The forward_backward_weight is the weight on the best path through the
output determinized FSA that will include this DetState. It will determine
the order of expansion of DetStates and also whether the states are
expanded at all (if the pruning beam `beam` is finite).
forward_backward_weight is the sum of the forward weight of the base state,
plus (the greatest over the DetStateElements, of its `weight` element,
plus the backward weight in the input FSA of the state that corresponds
to it).


worked outobtained from

*/
void Normalize(std::vector<int32_t> *leftover_arcs);
};


void DetState::Normalize(std::vector<int32_t> *input_arcs) {

}


class DetStateMap {
public:
/*
Outputs the output state-id corresponding to a specific DetState structure.
This does not store any pointers to the DetState or its contents, so
you can delete the DetState without affecting this object's ability to map
an equivalent DetState to the same state-id.

@param [in] a The DetState that we're looking up
@param [out] state_id The state-index in the output FSA
corresponding to this DetState (will
be freshly allocated if an equivalent of
this DetState did not already exist.
@return Returns true if this was a NEWLY CREATED state,
false otherwise.
*/
bool GetOutputState(const DetState &a, int32_t *state_id) {
std::pair<uint64_t, uint64_t> compact;
DetStateToCompact(a, &compact);
auto p = map_.insert({compact, cur_output_state));
bool inserted = p.second;
if (inserted) {
*state_id = cur_output_state_++;
return true;
} else {
*state_id = p.first->second;
return false;
}
}

int32_t size() const { return cur_output_state_; }

private:

int32_t cur_output_state_ { 0 };
std::unordered_map<std::pair<uint64_t, uint64_t>, int32_t, DetStateVectorHasher> map_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DetStateHasher?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, might do that. This code needs to be finished.


/* Turns DetState into a compact form of 128 bits. Technically there
could be collisions, which would be fatal for the algorithm, but this
is one of those lifetime-of-the-universe type of things (kind of like
the theoretical potential for git hash collision) that we ignore.

The normalized form

*/
void DetStateToCompact(const DetState &d,
std::pair<uint64_t, uint64_t> *vec) {
assert(d.normalized);

uint64_t a = d.base_state + 17489 * d.seq_len,
b = d.base_state * 103979 + d.seq_len;

// We choose an arbitrary DetStateElement (the first one in the list) to
// read the symbol sequence from; the symbol sequence will be the same no
// matter which element we choose to trace back.
DetStateElement *elem = d.head;
int32_t seq_len = d.seq_len;
for (int32_t i = 0; i < seq_len; ++i) {
a = elem->symbol + 102299 * a;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an error inside the for loop?

elem is never updated.


It should be ++i, not i++.

b = elem->symbol + 102983 * b;
elem = elem->parent
}
vec->first = a;
vec->second = b;
}

struct DetStateHasher {
size_t operator () (const std::pair<uint64_t, uint64_t> &p) const {
return p.first;
}
};



};



void DeterminizeMax(const WfsaWithFbWeights &a,
float beam,
Fsa *b,
std::vector<std::vector<int32_t> > *arc_map) {
// TODO: use glog stuff.
assert(IsValid(a) && IsEpsilonFree(a) && IsTopSortedAndAcyclic(a));
if (a.arc_indexes.empty()) {
b->Clear();
return;
}
float cutoff = a.backward_state_weights[0] - beam;
// TODO.

}


} // namespace k2
48 changes: 36 additions & 12 deletions k2/csrc/fsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@

namespace k2 {

using Label = int32_t;
using StateId = int32_t;
using Weight = float;

enum {
kFinalSymbol = -1, // final-costs are represented as arcs with
// kFinalSymbol as their label, to the final
Expand All @@ -30,9 +26,9 @@ enum {
};

struct Arc {
StateId src_state;
StateId dest_state;
Label label; // 'label' as in a finite state acceptor.
int32_t src_state;
int32_t dest_state;
int32_t label; // 'label' as in a finite state acceptor.
// For FSTs, the other label will be present in the
// aux_label array. Which of the two represents the input
// vs. the output can be decided by the user; in general,
Expand Down Expand Up @@ -112,8 +108,8 @@ struct Fsa {
arc_indexes.push_back(index);
}

StateId NumStates() const {
return !arc_indexes.empty() ? (static_cast<StateId>(arc_indexes.size()) - 1)
int32_t NumStates() const {
return !arc_indexes.empty() ? (static_cast<int32_t>(arc_indexes.size()) - 1)
: 0;
}
};
Expand All @@ -134,7 +130,7 @@ struct Fsa {
weights[t,n].
*/
struct DenseFsa {
Weight *weights; // Would typically be a log-prob or unnormalized log-prob
float *weights; // Would typically be a log-prob or unnormalized log-prob
int32_t T; // The number of time steps == rows in the matrix `weights`;
// this FSA has T + 2 states, see explanation above.
int32_t num_symbols; // The number of symbols == columns in the matrix
Expand All @@ -148,15 +144,43 @@ struct DenseFsa {
CAUTION: we may later enforce that stride == num_symbols, in order to
be able to know the layout of a phantom matrix of arcs. (?)
*/
DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride);
DenseFsa(float *data, int32_t T, int32_t num_symbols, int32_t stride);
};

struct Fst {
Fsa core;
std::vector<int32_t> aux_label;
};

using StatePair = std::pair<StateId, StateId>;
/*
This demonstrates an interface for a deterministic FSA or FST; it's similar
to Kaldi's DeterministicOnDemandFst class. It can be used for things like
language models. Actually we'll template on types like this. There is no
need to actually inherit from this class. */
class DeterministicGenericFsa {
public:
int32_t Start();


bool LookupArc(int32_t cur_state,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are there no const modifiers for Get ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is an interface for a possibly dynamic object.

int32_t label,
int32_t *arc_index);


float GetWeightForArc(int32_t arc_index);

int32_t Getint32_tForArc(int32_t arc_index);

int32_t GetPrevStateForArc(int32_t arc_index);

int32_t GetNextStateForArc(int32_t arc_index);

// Specific subclasses of this may have additional functions, e.g.
int32_t GetOlabelForArc(int32_t arc_index);

};


using FsaVec = std::vector<Fsa>;
using FstVec = std::vector<Fst>;
using DenseFsaVec = std::vector<DenseFsa>;
Expand Down
12 changes: 7 additions & 5 deletions k2/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ struct DfsState {
int32_t arc_end; // end of the arc index of the visiting node
};

using StatePair = std::pair<int32_t, int32_t>;

inline int32_t InsertIntersectionState(
const k2::StatePair &new_state, int32_t *state_index_c,
std::queue<k2::StatePair> *qstates,
std::unordered_map<k2::StatePair, int32_t, k2::PairHash> *state_pair_map) {
const StatePair &new_state, int32_t *state_index_c,
std::queue<StatePair> *qstates,
std::unordered_map<StatePair, int32_t, k2::PairHash> *state_pair_map) {
auto result = state_pair_map->insert({new_state, *state_index_c + 1});
if (result.second) {
// we have not visited `new_state` before.
Expand Down Expand Up @@ -411,8 +413,8 @@ void ArcSort(const Fsa &a, Fsa *b,
const auto arc_begin_iter = a.arcs.begin();
const auto index_begin_iter = indexes.begin();
// we will not process the final state as it has no arcs leaving it.
StateId final_state = a.NumStates() - 1;
for (StateId state = 0; state < final_state; ++state) {
int32_t final_state = a.NumStates() - 1;
for (int32_t state = 0; state < final_state; ++state) {
int32_t begin = a.arc_indexes[state];
// as non-empty fsa `a` contains at least two states,
// we can always access `state + 1` validly.
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ bool TopSort(const Fsa& a, Fsa* b, std::vector<int32_t>* state_map = nullptr);

*/
void Determinize(const Fsa &a, Fsa *b,
std::vector<std::vector<StateId>> *state_map);
std::vector<std::vector<int32_t>> *state_map);

} // namespace k2

Expand Down
9 changes: 4 additions & 5 deletions k2/csrc/fsa_renderer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ std::string GenerateEpilogue() { return "}"; }

using k2::Arc;
using k2::Fsa;
using k2::Label;
using k2::StateId;


std::string ProcessState(const Fsa &fsa, int32_t state) {
std::ostringstream os;
Expand All @@ -44,9 +43,9 @@ std::string ProcessState(const Fsa &fsa, int32_t state) {

for (; begin != end; ++begin) {
const auto &arc = fsa.arcs[begin];
StateId src = arc.src_state;
StateId dest = arc.dest_state;
Label label = arc.label;
int32_t src = arc.src_state;
int32_t dest = arc.dest_state;
int32_t label = arc.label;
os << " " << src << " -> " << dest << " [label = \"" << label
<< "\", fontsize = 14];"
<< "\n";
Expand Down
Loading