Fix context graph (#292)
This commit is contained in:
@@ -19,7 +19,7 @@ void ContextGraph::Build(
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? 0 : node->local_node_score + context_score_, is_end);
|
||||
is_end ? node->node_score + context_score_ : 0, is_end);
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
@@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
if (1 == state->next.count(token)) {
|
||||
node = state->next.at(token).get();
|
||||
score = node->token_score;
|
||||
if (state->is_end) score += state->node_score;
|
||||
} else {
|
||||
node = state->fail;
|
||||
while (0 == node->next.count(token)) {
|
||||
@@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
if (1 == node->next.count(token)) {
|
||||
node = node->next.at(token).get();
|
||||
}
|
||||
score = node->node_score - state->local_node_score;
|
||||
score = node->node_score - state->node_score;
|
||||
}
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
float matched_score = 0;
|
||||
auto output = node->output;
|
||||
while (nullptr != output) {
|
||||
matched_score += output->node_score;
|
||||
output = output->output;
|
||||
}
|
||||
return std::make_pair(score + matched_score, node);
|
||||
return std::make_pair(score + node->output_score, node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
const ContextState *state) const {
|
||||
float score = -state->node_score;
|
||||
if (state->is_end) {
|
||||
score = 0;
|
||||
}
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
@@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
|
||||
}
|
||||
}
|
||||
kv.second->output = output;
|
||||
kv.second->output_score += output == nullptr ? 0 : output->output_score;
|
||||
node_queue.push(kv.second.get());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user