Fix context graph (#292)
This commit is contained in:
@@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) {
|
|||||||
auto context_graph = ContextGraph(contexts, 1);
|
auto context_graph = ContextGraph(contexts, 1);
|
||||||
|
|
||||||
auto queries = std::map<std::string, float>{
|
auto queries = std::map<std::string, float>{
|
||||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
|
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||||
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||||
|
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||||
|
|
||||||
for (const auto &iter : queries) {
|
for (const auto &iter : queries) {
|
||||||
float total_scores = 0;
|
float total_scores = 0;
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ void ContextGraph::Build(
|
|||||||
bool is_end = j == token_ids[i].size() - 1;
|
bool is_end = j == token_ids[i].size() - 1;
|
||||||
node->next[token] = std::make_unique<ContextState>(
|
node->next[token] = std::make_unique<ContextState>(
|
||||||
token, context_score_, node->node_score + context_score_,
|
token, context_score_, node->node_score + context_score_,
|
||||||
is_end ? 0 : node->local_node_score + context_score_, is_end);
|
is_end ? node->node_score + context_score_ : 0, is_end);
|
||||||
}
|
}
|
||||||
node = node->next[token].get();
|
node = node->next[token].get();
|
||||||
}
|
}
|
||||||
@@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
|||||||
if (1 == state->next.count(token)) {
|
if (1 == state->next.count(token)) {
|
||||||
node = state->next.at(token).get();
|
node = state->next.at(token).get();
|
||||||
score = node->token_score;
|
score = node->token_score;
|
||||||
if (state->is_end) score += state->node_score;
|
|
||||||
} else {
|
} else {
|
||||||
node = state->fail;
|
node = state->fail;
|
||||||
while (0 == node->next.count(token)) {
|
while (0 == node->next.count(token)) {
|
||||||
@@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
|||||||
if (1 == node->next.count(token)) {
|
if (1 == node->next.count(token)) {
|
||||||
node = node->next.at(token).get();
|
node = node->next.at(token).get();
|
||||||
}
|
}
|
||||||
score = node->node_score - state->local_node_score;
|
score = node->node_score - state->node_score;
|
||||||
}
|
}
|
||||||
SHERPA_ONNX_CHECK(nullptr != node);
|
SHERPA_ONNX_CHECK(nullptr != node);
|
||||||
float matched_score = 0;
|
return std::make_pair(score + node->output_score, node);
|
||||||
auto output = node->output;
|
|
||||||
while (nullptr != output) {
|
|
||||||
matched_score += output->node_score;
|
|
||||||
output = output->output;
|
|
||||||
}
|
|
||||||
return std::make_pair(score + matched_score, node);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||||
const ContextState *state) const {
|
const ContextState *state) const {
|
||||||
float score = -state->node_score;
|
float score = -state->node_score;
|
||||||
if (state->is_end) {
|
|
||||||
score = 0;
|
|
||||||
}
|
|
||||||
return std::make_pair(score, root_.get());
|
return std::make_pair(score, root_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
kv.second->output = output;
|
kv.second->output = output;
|
||||||
|
kv.second->output_score += output == nullptr ? 0 : output->output_score;
|
||||||
node_queue.push(kv.second.get());
|
node_queue.push(kv.second.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ struct ContextState {
|
|||||||
int32_t token;
|
int32_t token;
|
||||||
float token_score;
|
float token_score;
|
||||||
float node_score;
|
float node_score;
|
||||||
float local_node_score;
|
float output_score;
|
||||||
bool is_end;
|
bool is_end;
|
||||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||||
const ContextState *fail = nullptr;
|
const ContextState *fail = nullptr;
|
||||||
@@ -29,11 +29,11 @@ struct ContextState {
|
|||||||
|
|
||||||
ContextState() = default;
|
ContextState() = default;
|
||||||
ContextState(int32_t token, float token_score, float node_score,
|
ContextState(int32_t token, float token_score, float node_score,
|
||||||
float local_node_score, bool is_end)
|
float output_score, bool is_end)
|
||||||
: token(token),
|
: token(token),
|
||||||
token_score(token_score),
|
token_score(token_score),
|
||||||
node_score(node_score),
|
node_score(node_score),
|
||||||
local_node_score(local_node_score),
|
output_score(output_score),
|
||||||
is_end(is_end) {}
|
is_end(is_end) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user