Skip to content

Commit

Permalink
feat(dictionary): predict word
Browse files (browse the repository at this point in the history)
  • Loading branch information
lotem committed Mar 3, 2024
1 parent 8b7f6b7 commit 95cb5fe
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
19 changes: 12 additions & 7 deletions src/rime/dict/dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ bool compare_chunk_by_head_element(const Chunk& a, const Chunk& b) {
size_t match_extra_code(const table::Code* extra_code,
size_t depth,
const SyllableGraph& syll_graph,
size_t current_pos) {
size_t current_pos,
bool predict_word) {
if (!extra_code || depth >= extra_code->size)
return current_pos; // success
if (current_pos >= syll_graph.interpreted_length)
return 0; // failure (possibly success for completion in the future)
if (current_pos >= syll_graph.interpreted_length) {
return predict_word ? syll_graph.interpreted_length // word completion
: 0; // failure
}
auto index = syll_graph.indices.find(current_pos);
if (index == syll_graph.indices.end())
return 0;
Expand All @@ -74,8 +77,8 @@ size_t match_extra_code(const table::Code* extra_code,
return 0;
size_t best_match = 0;
for (const SpellingProperties* props : spellings->second) {
size_t match_end_pos =
match_extra_code(extra_code, depth + 1, syll_graph, props->end_pos);
size_t match_end_pos = match_extra_code(extra_code, depth + 1, syll_graph,
props->end_pos, predict_word);
if (!match_end_pos)
continue;
if (match_end_pos > best_match)
Expand Down Expand Up @@ -199,6 +202,7 @@ static void lookup_table(Table* table,
DictEntryCollector* collector,
const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word,
double initial_credibility) {
TableQueryResult result;
if (!table->Query(syllable_graph, start_pos, &result)) {
Expand All @@ -212,7 +216,7 @@ static void lookup_table(Table* table,
if (a.extra_code()) {
do {
size_t actual_end_pos = dictionary::match_extra_code(
a.extra_code(), 0, syllable_graph, end_pos);
a.extra_code(), 0, syllable_graph, end_pos, predict_word);
if (actual_end_pos == 0)
continue;
(*collector)[actual_end_pos].AddChunk(
Expand All @@ -227,6 +231,7 @@ static void lookup_table(Table* table,

an<DictEntryCollector> Dictionary::Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word,
double initial_credibility) {
if (!loaded())
return nullptr;
Expand All @@ -235,7 +240,7 @@ an<DictEntryCollector> Dictionary::Lookup(const SyllableGraph& syllable_graph,
if (!table->IsOpen())
continue;
lookup_table(table.get(), collector.get(), syllable_graph, start_pos,
initial_credibility);
predict_word, initial_credibility);
}
if (collector->empty())
return nullptr;
Expand Down
1 change: 1 addition & 0 deletions src/rime/dict/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Dictionary : public Class<Dictionary, const Ticket&> {

// Looks up dictionary entries for |syllable_graph| beginning at |start_pos|.
// Returns nullptr if the dictionary is not loaded or no entry was collected.
// predict_word: when true, an entry whose extra code still has syllables left
//   once the interpreted input is used up is accepted as a word completion
//   ending at syllable_graph.interpreted_length; when false (the default) such
//   an entry is rejected.
// initial_credibility: forwarded unchanged to the per-table lookup.
RIME_API an<DictEntryCollector> Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word = false,
double initial_credibility = 0.0);
// if predictive is true, do an expand search with limit,
// otherwise do an exact match.
Expand Down
11 changes: 8 additions & 3 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,12 @@ class ScriptTranslation : public Translation {
Corrector* corrector,
Poet* poet,
const string& input,
size_t start)
size_t start,
size_t end_of_input)
: translator_(translator),
poet_(poet),
start_(start),
end_of_input_(end_of_input),
syllabifier_(
New<ScriptSyllabifier>(translator, corrector, input, start)),
enable_correction_(corrector) {
Expand All @@ -137,6 +139,7 @@ class ScriptTranslation : public Translation {
ScriptTranslator* translator_;
Poet* poet_;
size_t start_;
size_t end_of_input_;
an<ScriptSyllabifier> syllabifier_;

an<DictEntryCollector> phrase_;
Expand Down Expand Up @@ -189,9 +192,10 @@ an<Translation> ScriptTranslator::Query(const string& input,
bool enable_user_dict =
user_dict_ && user_dict_->loaded() && !IsUserDictDisabledFor(input);

size_t end_of_input = engine_->context()->input().length();
// the translator should survive translations it creates
auto result = New<ScriptTranslation>(this, corrector_.get(), poet_.get(),
input, segment.start);
input, segment.start, end_of_input);
if (!result || !result->Evaluate(
dict_.get(), enable_user_dict ? user_dict_.get() : NULL)) {
return nullptr;
Expand Down Expand Up @@ -343,8 +347,9 @@ string ScriptSyllabifier::GetOriginalSpelling(const Phrase& cand) const {
bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
size_t consumed = syllabifier_->BuildSyllableGraph(*dict->prism());
const auto& syllable_graph = syllabifier_->syllable_graph();
bool predict_word = start_ + consumed == end_of_input_;

phrase_ = dict->Lookup(syllable_graph, 0);
phrase_ = dict->Lookup(syllable_graph, 0, predict_word);
if (user_dict) {
user_phrase_ = user_dict->Lookup(syllable_graph, 0);
}
Expand Down

0 comments on commit 95cb5fe

Please sign in to comment.