Fix context graph (#292)

This commit is contained in:
Wei Kang
2023-08-28 19:39:22 +08:00
committed by GitHub
parent 49ec7e8f57
commit 2b0152d2a2
3 changed files with 10 additions and 18 deletions

View File

@@ -19,7 +19,7 @@ void ContextGraph::Build(
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? 0 : node->local_node_score + context_score_, is_end);
is_end ? node->node_score + context_score_ : 0, is_end);
}
node = node->next[token].get();
}
@@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
if (state->is_end) score += state->node_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
@@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->local_node_score;
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
float matched_score = 0;
auto output = node->output;
while (nullptr != output) {
matched_score += output->node_score;
output = output->output;
}
return std::make_pair(score + matched_score, node);
return std::make_pair(score + node->output_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}
@@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
}
}
kv.second->output = output;
kv.second->output_score += output == nullptr ? 0 : output->output_score;
node_queue.push(kv.second.get());
}
}