Skip to content

Commit

Permalink
Added support for split and merging
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellegedin committed Jul 27, 2016
1 parent 617c5af commit b04b96f
Show file tree
Hide file tree
Showing 12 changed files with 533 additions and 30 deletions.
94 changes: 65 additions & 29 deletions kuromoji-core/src/main/java/com/atilika/kuromoji/TokenizerBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.atilika.kuromoji.dict.UserDictionary;
import com.atilika.kuromoji.fst.FST;
import com.atilika.kuromoji.util.ResourceResolver;
import com.atilika.kuromoji.viterbi.MultiSearchMerger;
import com.atilika.kuromoji.viterbi.MultiSearchResult;
import com.atilika.kuromoji.viterbi.TokenFactory;
import com.atilika.kuromoji.viterbi.ViterbiBuilder;
Expand Down Expand Up @@ -117,6 +118,7 @@ public List<? extends TokenBase> tokenize(String text) {
}

public <T extends TokenBase> List<List<T>> multiTokenize(String text, int maxCount, int costSlack) {

return createMultiTokenList(text, maxCount, costSlack);
}

Expand Down Expand Up @@ -178,7 +180,65 @@ protected <T extends TokenBase> List<T> createTokenList(String text) {
* @return list of Token, not null
*/
protected <T extends TokenBase> List<List<T>> createMultiTokenList(String text, int maxCount, int costSlack) {
return createMultiTokenList(0, text, maxCount, costSlack);

if (!split) {
return convertMultiSearchResultToList(createMultiSearchResult(text, maxCount, costSlack));
}

List<Integer> splitPositions = getSplitPositions(text);

if (splitPositions.size() == 0) {
return convertMultiSearchResultToList(createMultiSearchResult(text, maxCount, costSlack));
}

List<MultiSearchResult> results = new ArrayList<>();
int offset = 0;

for (int position : splitPositions) {
results.add(createMultiSearchResult(text.substring(offset, position + 1), maxCount, costSlack));
offset = position + 1;
}

if (offset < text.length()) {
results.add(createMultiSearchResult(text.substring(offset), maxCount, costSlack));
}

System.out.println("Merging...");

MultiSearchMerger merger = new MultiSearchMerger(maxCount, costSlack);
MultiSearchResult mergedResult = merger.merge(results);

System.out.println("Done");

return convertMultiSearchResultToList(mergedResult);
}

private <T extends TokenBase> List<List<T>> convertMultiSearchResultToList(MultiSearchResult multiSearchResult) {
List<List<T>> result = new ArrayList<>();

List<List<ViterbiNode>> paths = multiSearchResult.getTokenizedResultsList();

for (List<ViterbiNode> path : paths) {
ArrayList<T> tokens = new ArrayList<>();
for (ViterbiNode node : path) {
int wordId = node.getWordId();
if (node.getType() == ViterbiNode.Type.KNOWN && wordId == -1) { // Do not include BOS/EOS
continue;
}
@SuppressWarnings("unchecked")
T token = (T) tokenFactory.createToken(
wordId,
node.getSurface(),
node.getType(),
node.getStartIndex(),
dictionaryMap.get(node.getType())
);
tokens.add(token);
}
result.add(tokens);
}

return result;
}

/**
Expand Down Expand Up @@ -289,40 +349,16 @@ private <T extends TokenBase> List<T> createTokenList(int offset, String text) {
/**
* Tokenize input sentence. Up to maxCount different paths of cost at most OPT + costSlack are returned ordered in ascending order by cost, where OPT is the optimal solution.
*
* @param offset offset of sentence in original input text
* @param text sentence to tokenize
* @param maxCount maximum number of paths
* @param costSlack maximum cost slack of a path
* @return list of Token
* @return instance of MultiSearchResult containing the tokenizations
*/
private <T extends TokenBase> List<List<T>> createMultiTokenList(int offset, String text, int maxCount, int costSlack) {
List<List<T>> result = new ArrayList<>();

private MultiSearchResult createMultiSearchResult(String text, int maxCount, int costSlack) {
System.out.println("Searching text of size: " + text.length());

This comment has been minimized.

Copy link
@cmoen

cmoen Jul 27, 2016

Member

Let's remove this println()

ViterbiLattice lattice = viterbiBuilder.build(text);
MultiSearchResult multiSearchResult = viterbiSearcher.searchMultiple(lattice, maxCount, costSlack);
List<List<ViterbiNode>> paths = multiSearchResult.getTokenizedResultsList();

for (List<ViterbiNode> path : paths) {
ArrayList<T> tokens = new ArrayList<>();
for (ViterbiNode node : path) {
int wordId = node.getWordId();
if (node.getType() == ViterbiNode.Type.KNOWN && wordId == -1) { // Do not include BOS/EOS
continue;
}
@SuppressWarnings("unchecked")
T token = (T) tokenFactory.createToken(
wordId,
node.getSurface(),
node.getType(),
offset + node.getStartIndex(),
dictionaryMap.get(node.getType())
);
tokens.add(token);
}
result.add(tokens);
}

return result;
return multiSearchResult;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package com.atilika.kuromoji.viterbi;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

public class MultiSearchMerger {

private int baseCost;
private List<Integer> suffixCostLowerBounds;
private int maxCount;
private int costSlack;

public MultiSearchMerger(int maxCount, int costSlack) {
this.maxCount = maxCount;
this.costSlack = costSlack;
}

public MultiSearchResult merge(List<MultiSearchResult> results) {
if (results.isEmpty()) {
return new MultiSearchResult();
}

suffixCostLowerBounds = new ArrayList<>();
for (int i = 0; i < results.size(); i++) {
suffixCostLowerBounds.add(0);
}
suffixCostLowerBounds.add(suffixCostLowerBounds.size() - 1, results.get(results.size() - 1).getCost(0));
for (int i = results.size() - 2; i >= 0; i--) {
suffixCostLowerBounds.add(i, results.get(i).getCost(0) + suffixCostLowerBounds.get(i + 1));
}
baseCost = suffixCostLowerBounds.get(0);

MultiSearchResult ret = new MultiSearchResult();
List<MergeBuilder> builders = new ArrayList<>();
for (int i = 0; i < results.get(0).size(); i++) {
if (getCostLowerBound(results.get(0).getCost(i), 0) > baseCost + costSlack || i == maxCount) {
break;
}

MergeBuilder newBuilder = new MergeBuilder(results);
newBuilder.add(i);
builders.add(newBuilder);
}

for (int i = 1; i < results.size(); i++) {
builders = mergeStep(builders, results, i);
}

for (MergeBuilder builder : builders) {
ret.add(builder.buildList(), builder.getCost());
}

return ret;
}

private List<MergeBuilder> mergeStep(List<MergeBuilder> builders, List<MultiSearchResult> results, int currentIndex) {
MultiSearchResult nextResult = results.get(currentIndex);
PriorityQueue<MergePair> pairHeap = new PriorityQueue<>();
List<MergeBuilder> ret = new ArrayList<>();

if (builders.isEmpty() || nextResult.size() == 0) {
return ret;
}

pairHeap.add(new MergePair(0, 0, builders.get(0).getCost() + nextResult.getCost(0)));

Set<Integer> visited = new HashSet<>();

while (ret.size() < maxCount && pairHeap.size() > 0) {
MergePair top = pairHeap.poll();

if (getCostLowerBound(top.getCost(), currentIndex) > baseCost + costSlack) {
break;
}

int i = top.getLeftIndex(), j = top.getRightIndex();

MergeBuilder nextBuilder = new MergeBuilder(results, builders.get(i).getIndices());
nextBuilder.add(j);
ret.add(nextBuilder);

if (i + 1 < builders.size()) {
MergePair newMergePair = new MergePair(i + 1, j, builders.get(i + 1).getCost() + nextResult.getCost(j));
int positionValue = getPositionValue(i + 1, j);
if (!visited.contains(positionValue)) {
pairHeap.add(newMergePair);
visited.add(positionValue);
}
}
if (j + 1 < nextResult.size()) {
MergePair newMergePair = new MergePair(i, j + 1, builders.get(i).getCost() + nextResult.getCost(j + 1));
int positionValue = getPositionValue(i, j + 1);
if (!visited.contains(positionValue)) {
pairHeap.add(newMergePair);
visited.add(positionValue);
}
}
}

return ret;
}

private int getPositionValue(int i, int j) {
return (maxCount + 1) * i + j;
}

private int getCostLowerBound(int currentCost, int index) {
if (index + 1 < suffixCostLowerBounds.size()) {
return currentCost + suffixCostLowerBounds.get(index + 1);
}
return currentCost;
}

private class MergeBuilder implements Comparable<MergeBuilder> {
private int cost;
private List<Integer> indices;
private List<MultiSearchResult> results;

public MergeBuilder(List<MultiSearchResult> results) {
this.results = results;
cost = 0;
indices = new ArrayList<>();
}

public MergeBuilder(List<MultiSearchResult> results, List<Integer> indices) {
this(results);
for (Integer index : indices) {
add(index);
}
}

public List<ViterbiNode> buildList() {
List<ViterbiNode> ret = new ArrayList<>();
for (int i = 0; i < indices.size(); i++) {
ret.addAll(results.get(i).getTokenizedResult(indices.get(i)));
}
return ret;
}

public void add(int index) {
indices.add(index);
cost += results.get(indices.size() - 1).getCost(index);
}

public int getCost() {
return cost;
}

public List<Integer> getIndices() {
return indices;
}

public int compareTo(MergeBuilder o) {
return cost - o.cost;
}
}

private class MergePair implements Comparable<MergePair> {
private int leftIndex;
private int rightIndex;
private int cost;

public MergePair(int leftIndex, int rightIndex, int cost) {
this.leftIndex = leftIndex;
this.rightIndex = rightIndex;
this.cost = cost;
}

public int getLeftIndex() {
return leftIndex;
}

public int getRightIndex() {
return rightIndex;
}

public int getCost() {
return cost;
}

public int compareTo(MergePair o) {
return cost - o.getCost();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class ViterbiSearcher {

private final TokenizerBase.Mode mode;

private MultiSearcher multiSearcher;

public ViterbiSearcher(TokenizerBase.Mode mode,
ConnectionCosts costs,
UnknownDictionary unknownDictionary,
Expand All @@ -51,6 +53,7 @@ public ViterbiSearcher(TokenizerBase.Mode mode,
this.mode = mode;
this.costs = costs;
this.unknownDictionary = unknownDictionary;
multiSearcher = new MultiSearcher(costs, mode, this);
}

/**
Expand All @@ -76,7 +79,6 @@ public List<ViterbiNode> search(ViterbiLattice lattice) {
*/
public MultiSearchResult searchMultiple(ViterbiLattice lattice, int maxCount, int costSlack) {
calculatePathCosts(lattice);
MultiSearcher multiSearcher = new MultiSearcher(costs, mode, this);
MultiSearchResult result = multiSearcher.getShortestPaths(lattice, maxCount, costSlack);
return result;
}
Expand Down
10 changes: 10 additions & 0 deletions kuromoji-core/src/test/java/com/atilika/kuromoji/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ public static void assertCanTokenizeString(String input, TokenizerBase tokenizer
}
}

public static void assertCanMultiTokenizeString(String input, int maxCount, int costSlack, TokenizerBase tokenizer) {
List<List<TokenBase>> tokens = tokenizer.multiTokenize(input, maxCount, costSlack);

if (input.length() > 0) {
assertFalse(tokens.isEmpty());
} else {
assertTrue(tokens.isEmpty());
}
}

public static void assertTokenizedStreamEquals(InputStream tokenizedInput,
InputStream untokenizedInput,
TokenizerBase tokenizer) throws IOException {
Expand Down
Loading

0 comments on commit b04b96f

Please sign in to comment.