diff --git a/sherpa-onnx/csrc/context-graph-test.cc b/sherpa-onnx/csrc/context-graph-test.cc index 97d03443..0e7e9b5c 100644 --- a/sherpa-onnx/csrc/context-graph-test.cc +++ b/sherpa-onnx/csrc/context-graph-test.cc @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) { auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0; diff --git a/sherpa-onnx/csrc/context-graph.cc b/sherpa-onnx/csrc/context-graph.cc index bc3a1e3e..05ca04f0 100644 --- a/sherpa-onnx/csrc/context-graph.cc +++ b/sherpa-onnx/csrc/context-graph.cc @@ -19,7 +19,7 @@ void ContextGraph::Build( bool is_end = j == token_ids[i].size() - 1; node->next[token] = std::make_unique( token, context_score_, node->node_score + context_score_, - is_end ? 0 : node->local_node_score + context_score_, is_end); + is_end ? node->node_score + context_score_ : 0, is_end); } node = node->next[token].get(); } @@ -34,7 +34,6 @@ std::pair ContextGraph::ForwardOneStep( if (1 == state->next.count(token)) { node = state->next.at(token).get(); score = node->token_score; - if (state->is_end) score += state->node_score; } else { node = state->fail; while (0 == node->next.count(token)) { @@ -44,24 +43,15 @@ std::pair ContextGraph::ForwardOneStep( if (1 == node->next.count(token)) { node = node->next.at(token).get(); } - score = node->node_score - state->local_node_score; + score = node->node_score - state->node_score; } SHERPA_ONNX_CHECK(nullptr != node); - float matched_score = 0; - auto output = node->output; - while (nullptr != output) { - matched_score += output->node_score; - output = output->output; - } - return std::make_pair(score + matched_score, node); + return std::make_pair(score + node->output_score, node); } std::pair ContextGraph::Finalize( const ContextState *state) const { float score = -state->node_score; - if (state->is_end) { - score = 0; - } return std::make_pair(score, root_.get()); } @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const { } } kv.second->output = output; + kv.second->output_score += output == nullptr ? 0 : output->output_score; node_queue.push(kv.second.get()); } } diff --git a/sherpa-onnx/csrc/context-graph.h b/sherpa-onnx/csrc/context-graph.h index db16ce66..57010689 100644 --- a/sherpa-onnx/csrc/context-graph.h +++ b/sherpa-onnx/csrc/context-graph.h @@ -21,7 +21,7 @@ struct ContextState { int32_t token; float token_score; float node_score; - float local_node_score; + float output_score; bool is_end; std::unordered_map> next; const ContextState *fail = nullptr; @@ -29,11 +29,11 @@ struct ContextState { ContextState() = default; ContextState(int32_t token, float token_score, float node_score, - float local_node_score, bool is_end) + float output_score, bool is_end) : token(token), token_score(token_score), node_score(node_score), - local_node_score(local_node_score), + output_score(output_score), is_end(is_end) {} };