Skip to content

Commit

Permalink
Merge pull request #6 from csukuangfj/fangjun-fsa-util
Browse files Browse the repository at this point in the history
implement GetEnteringArcs.
  • Loading branch information
danpovey authored Apr 24, 2020
2 parents 474e352 + 0b25697 commit f75a615
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 88 deletions.
1 change: 1 addition & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Language: Cpp
Cpp11BracedListStyle: true
Standard: Cpp11
DerivePointerAlignment: false
PointerAlignment: Right
---
20 changes: 20 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ add_library(properties properties.cc)
target_include_directories(properties PUBLIC ${CMAKE_SOURCE_DIR})
target_compile_features(properties PUBLIC cxx_std_11)

add_library(fsa_util fsa_util.cc)
target_include_directories(fsa_util PUBLIC ${CMAKE_SOURCE_DIR})
target_compile_features(fsa_util PUBLIC cxx_std_11)

add_executable(properties_test properties_test.cc)

target_link_libraries(properties_test
Expand All @@ -15,3 +19,19 @@ add_test(NAME Test.properties_test
COMMAND
$<TARGET_FILE:properties_test>
)

add_executable(fsa_util_test fsa_util_test.cc)

target_link_libraries(fsa_util_test
PRIVATE
fsa_util
gtest
gtest_main
)

add_test(NAME Test.fsa_util_test
COMMAND
$<TARGET_FILE:fsa_util_test>
)

# TODO(fangjun): write some helper functions to create targets.
11 changes: 5 additions & 6 deletions k2/csrc/fsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define K2_CSRC_FSA_H_

#include <cstdint>
#include <utility>
#include <vector>

namespace k2 {
Expand Down Expand Up @@ -48,7 +49,7 @@ struct Arc {
};

struct ArcLabelCompare {
bool operator()(const Arc& a, const Arc& b) const {
bool operator()(const Arc &a, const Arc &b) const {
return a.label < b.label;
}
};
Expand Down Expand Up @@ -92,11 +93,9 @@ struct Fsa {
more state). For 0 <= t < T, we have an arc with symbol n on it for
each 0 <= n < N, from state t to state t+1, with weight equal to
weights[t,n].
*/
struct DenseFsa {
Weight* weights; // Would typically be a log-prob or unnormalized log-prob
Weight *weights; // Would typically be a log-prob or unnormalized log-prob
int32_t T; // The number of time steps == rows in the matrix `weights`;
// this FSA has T + 2 states, see explanation above.
int32_t num_symbols; // The number of symbols == columns in the matrix
Expand All @@ -110,7 +109,7 @@ struct DenseFsa {
CAUTION: we may later enforce that stride == num_symbols, in order to
be able to know the layout of a phantom matrix of arcs. (?)
*/
DenseFsa(Weight* data, int32_t T, int32_t num_symbols, int32_t stride);
DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride);
};

/*
Expand All @@ -120,7 +119,7 @@ struct DenseFsa {
*/
struct VecOfVec {
std::vector<Range> ranges;
std::vector<int32_t> values;
std::vector<std::pair<Label, StateId>> values;
};

struct Fst {
Expand Down
52 changes: 26 additions & 26 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace k2 {
so the output will be topologically sorted if the input
was.
*/
void ConnectCore(const Fsa& fsa, std::vector<int32>* state_map);
void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map);

/*
Removes states that are not accessible (from the start state) or are not
Expand All @@ -44,13 +44,13 @@ void ConnectCore(const Fsa& fsa, std::vector<int32>* state_map);
Notes:
- If `a` admitted a topological sorting, b will be topologically
sorted. TODO: maybe just leave in the same order as a??
sorted. TODO(Dan): maybe just leave in the same order as a??
- If `a` was deterministic, `b` will be deterministic; same for
epsilon free, obviously.
- `b` will be arc-sorted (arcs sorted by label)
- `b` will (obviously) be connected
*/
void Connect(const Fsa& a, Fsa* b, std::vector<int32>* arc_map = nullptr);
void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);

/**
Output an Fsa that is equivalent to the input but which has no epsilons.
Expand All @@ -61,10 +61,10 @@ void Connect(const Fsa& a, Fsa* b, std::vector<int32>* arc_map = nullptr);
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
the arc-indexes in `a` that contributed to that arc
(e.g. its cost would be a sum of their costs).
TODO: make it a VecOfVec, maybe?
TODO(Dan): make it a VecOfVec, maybe?
*/
void RmEpsilons(const Fsa& a, Fsa* b,
std::vector<std::vector>* arc_map = nullptr);
void RmEpsilons(const Fsa &a, Fsa *b,
std::vector<std::vector> *arc_map = nullptr);

/**
Pruned version of RmEpsilons, which also uses a pruning beam.
Expand All @@ -77,12 +77,12 @@ void RmEpsilons(const Fsa& a, Fsa* b,
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
the arc-indexes in `a` that contributed to that arc
(e.g. its cost would be a sum of their costs).
TODO: make it a VecOfVec, maybe?
TODO(Dan): make it a VecOfVec, maybe?
*/
void RmEpsilonsPruned(const Fsa& a, const float* a_state_forward_costs,
const float* a_state_backward_costs,
const float* a_arc_costs, float cutoff, Fsa* b,
std::vector<std::vector>* arc_map = nullptr);
void RmEpsilonsPruned(const Fsa &a, const float *a_state_forward_costs,
const float *a_state_backward_costs,
const float *a_arc_costs, float cutoff, Fsa *b,
std::vector<std::vector> *arc_map = nullptr);

/*
Compute the intersection of two FSAs; this is the equivalent of composition
Expand All @@ -104,16 +104,16 @@ void RmEpsilonsPruned(const Fsa& a, const float* a_state_forward_costs,
size c->arcs.size(), saying for each arc in
`c` what the source arc in `b` was.
*/
void Intersect(const Fsa& a, const Fsa& b, Fsa* c,
std::vector<int32>* arc_map_a = nullptr,
std::vector<int32>* arc_map_b = nullptr);
void Intersect(const Fsa &a, const Fsa &b, Fsa *c,
std::vector<int32_t> *arc_map_a = nullptr,
std::vector<int32_t> *arc_map_b = nullptr);

/*
Version of Intersect where `a` is dense?
*/
void Intersect(const DenseFsa& a, const Fsa& b, Fsa* c,
std::vector<int32>* arc_map_a = nullptr,
std::vector<int32>* arc_map_b = nullptr);
void Intersect(const DenseFsa &a, const Fsa &b, Fsa *c,
std::vector<int32_t> *arc_map_a = nullptr,
std::vector<int32_t> *arc_map_b = nullptr);

/*
Version of Intersect where `a` is dense, pruned with pruning beam `beam`.
Expand All @@ -124,9 +124,9 @@ void Intersect(const DenseFsa& a, const Fsa& b, Fsa* c,
This is the same as time-synchronous Viterbi beam pruning.
*/
void IntersectPruned(const DenseFsa& a, const Fsa& b, float beam, Fsa* c,
std::vector<int32>* arc_map_a = nullptr,
std::vector<int32>* arc_map_b = nullptr);
void IntersectPruned(const DenseFsa &a, const Fsa &b, float beam, Fsa *c,
std::vector<int32_t> *arc_map_a = nullptr,
std::vector<int32_t> *arc_map_b = nullptr);

/**
Intersection of two weighted FSA's: the same as Intersect(), but it prunes
Expand All @@ -152,13 +152,13 @@ void IntersectPruned(const DenseFsa& a, const Fsa& b, float beam, Fsa* c,
@param [out] state_map_b Maps from arc-index in c to the corresponding
arc-index in b
*/
void IntersectPruned2(const Fsa& a, const float* a_cost, const Fsa& b,
const float* b_cost, float cutoff, Fsa* c,
std::vector<int32>* state_map_a,
std::vector<int32>* state_map_b);
void IntersectPruned2(const Fsa &a, const float *a_cost, const Fsa &b,
const float *b_cost, float cutoff, Fsa *c,
std::vector<int32_t> *state_map_a,
std::vector<int32_t> *state_map_b);

void RandomPath(const Fsa& a, const float* a_cost, Fsa* b,
std::vector<int32>* state_map = nullptr);
void RandomPath(const Fsa &a, const float *a_cost, Fsa *b,
std::vector<int32_t> *state_map = nullptr);

} // namespace k2

Expand Down
43 changes: 43 additions & 0 deletions k2/csrc/fsa_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// k2/csrc/fsa_util.cc

// Copyright (c) 2020 Fangjun Kuang ([email protected])

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa_util.h"

#include <utility>
#include <vector>

namespace k2 {

void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs) {
// CHECK(CheckProperties(fsa, KTopSorted));

int num_states = fsa.NumStates();
std::vector<std::vector<std::pair<Label, StateId>>> vec(num_states);
int num_arcs = 0;
for (const auto &arc : fsa.arcs) {
auto src_state = arc.src_state;
auto dest_state = arc.dest_state;
auto label = arc.label;
vec[dest_state].emplace_back(label, src_state);
++num_arcs;
}

auto &ranges = entering_arcs->ranges;
auto &values = entering_arcs->values;
ranges.reserve(num_states);
values.reserve(num_arcs);

int32_t start = 0;
int32_t end = 0;
for (const auto &label_state : vec) {
values.insert(values.end(), label_state.begin(), label_state.end());
start = end;
end += static_cast<int32_t>(label_state.size());
ranges.push_back({start, end});
}
}

} // namespace k2
2 changes: 1 addition & 1 deletion k2/csrc/fsa_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace k2 {
Requires that `fsa` be valid and top-sorted, i.e.
CheckProperties(fsa, KTopSorted) == true.
*/
void GetEnteringArcs(const Fsa& fsa, VecOfVec* entering_arcs);
void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs);

} // namespace k2

Expand Down
68 changes: 68 additions & 0 deletions k2/csrc/fsa_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// k2/csrc/fsa_util_test.cc

// Copyright (c) 2020 Fangjun Kuang ([email protected])

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa_util.h"

#include <utility>
#include <vector>

#include "gtest/gtest.h"

namespace k2 {

TEST(FsaUtil, GetEnteringArcs) {
std::vector<Arc> arcs = {
{0, 1, 2}, {0, 2, 1}, {1, 2, 0}, {1, 3, 5}, {2, 3, 6},
};
std::vector<Range> leaving_arcs = {
{0, 2}, {2, 4}, {4, 5}, {0, 0}, // the last state has no entering arcs
};

Fsa fsa;
fsa.leaving_arcs = std::move(leaving_arcs);
fsa.arcs = std::move(arcs);

VecOfVec entering_arcs;
GetEnteringArcs(fsa, &entering_arcs);

const auto &ranges = entering_arcs.ranges;
const auto &values = entering_arcs.values;
EXPECT_EQ(ranges.size(), 4u); // there are 4 states
EXPECT_EQ(values.size(), 5u); // there are 5 arcs

// state 0, no entering arcs
EXPECT_EQ(ranges[0].begin, ranges[0].end);

// state 1 has one entering arc from state 0 with label 2
EXPECT_EQ(ranges[1].begin, 0);
EXPECT_EQ(ranges[1].end, 1);
EXPECT_EQ(values[0].first, 2); // label is 2
EXPECT_EQ(values[0].second, 0); // state is 0

// state 2 has two entering arcs
// the first one: from state 0 with label 1
// the second one: from state 1 with label 0
EXPECT_EQ(ranges[2].begin, 1);
EXPECT_EQ(ranges[2].end, 3);
EXPECT_EQ(values[1].first, 1); // label is 1
EXPECT_EQ(values[1].second, 0); // state is 0

EXPECT_EQ(values[2].first, 0); // label is 0
EXPECT_EQ(values[2].second, 1); // state is 1

// state 3 has two entering arcs
// the first one: from state 1 with label 5
// the second one: from state 2 with label 6
EXPECT_EQ(ranges[3].begin, 3);
EXPECT_EQ(ranges[3].end, 5);
EXPECT_EQ(values[3].first, 5); // label is 5
EXPECT_EQ(values[3].second, 1); // state is 1

EXPECT_EQ(values[4].first, 6); // label is 6
EXPECT_EQ(values[4].second, 2); // state is 2
}

} // namespace k2
Loading

0 comments on commit f75a615

Please sign in to comment.