Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
@@ -12,6 +12,7 @@ endif()
|
||||
|
||||
set(sources
|
||||
cat.cc
|
||||
context-graph.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
file-utils.cc
|
||||
@@ -248,6 +249,7 @@ endif()
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
set(sherpa_onnx_test_srcs
|
||||
cat-test.cc
|
||||
context-graph-test.cc
|
||||
packed-sequence-test.cc
|
||||
pad-sequence-test.cc
|
||||
slice-test.cc
|
||||
|
||||
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
@@ -0,0 +1,43 @@
|
||||
// sherpa-onnx/csrc/context-graph-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
|
||||
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
total_scores += res.first;
|
||||
EXPECT_EQ(total_scores, iter.second);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
105
sherpa-onnx/csrc/context-graph.cc
Normal file
105
sherpa-onnx/csrc/context-graph.cc
Normal file
@@ -0,0 +1,105 @@
|
||||
// sherpa-onnx/csrc/context-graph.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ContextGraph::Build(
|
||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||
auto node = root_.get();
|
||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||
int32_t token = token_ids[i][j];
|
||||
if (0 == node->next.count(token)) {
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? 0 : node->local_node_score + context_score_, is_end);
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
}
|
||||
FillFailOutput();
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
const ContextState *state, int32_t token) const {
|
||||
const ContextState *node;
|
||||
float score;
|
||||
if (1 == state->next.count(token)) {
|
||||
node = state->next.at(token).get();
|
||||
score = node->token_score;
|
||||
if (state->is_end) score += state->node_score;
|
||||
} else {
|
||||
node = state->fail;
|
||||
while (0 == node->next.count(token)) {
|
||||
node = node->fail;
|
||||
if (-1 == node->token) break; // root
|
||||
}
|
||||
if (1 == node->next.count(token)) {
|
||||
node = node->next.at(token).get();
|
||||
}
|
||||
score = node->node_score - state->local_node_score;
|
||||
}
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
float matched_score = 0;
|
||||
auto output = node->output;
|
||||
while (nullptr != output) {
|
||||
matched_score += output->node_score;
|
||||
output = output->output;
|
||||
}
|
||||
return std::make_pair(score + matched_score, node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
const ContextState *state) const {
|
||||
float score = -state->node_score;
|
||||
if (state->is_end) {
|
||||
score = 0;
|
||||
}
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
void ContextGraph::FillFailOutput() const {
|
||||
std::queue<const ContextState *> node_queue;
|
||||
for (auto &kv : root_->next) {
|
||||
kv.second->fail = root_.get();
|
||||
node_queue.push(kv.second.get());
|
||||
}
|
||||
while (!node_queue.empty()) {
|
||||
auto current_node = node_queue.front();
|
||||
node_queue.pop();
|
||||
for (auto &kv : current_node->next) {
|
||||
auto fail = current_node->fail;
|
||||
if (1 == fail->next.count(kv.first)) {
|
||||
fail = fail->next.at(kv.first).get();
|
||||
} else {
|
||||
fail = fail->fail;
|
||||
while (0 == fail->next.count(kv.first)) {
|
||||
fail = fail->fail;
|
||||
if (-1 == fail->token) break;
|
||||
}
|
||||
if (1 == fail->next.count(kv.first))
|
||||
fail = fail->next.at(kv.first).get();
|
||||
}
|
||||
kv.second->fail = fail;
|
||||
// fill the output arc
|
||||
auto output = fail;
|
||||
while (!output->is_end) {
|
||||
output = output->fail;
|
||||
if (-1 == output->token) {
|
||||
output = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
kv.second->output = output;
|
||||
node_queue.push(kv.second.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace sherpa_onnx
|
||||
66
sherpa-onnx/csrc/context-graph.h
Normal file
66
sherpa-onnx/csrc/context-graph.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// sherpa-onnx/csrc/context-graph.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class ContextGraph;
|
||||
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
|
||||
|
||||
struct ContextState {
|
||||
int32_t token;
|
||||
float token_score;
|
||||
float node_score;
|
||||
float local_node_score;
|
||||
bool is_end;
|
||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||
const ContextState *fail = nullptr;
|
||||
const ContextState *output = nullptr;
|
||||
|
||||
ContextState() = default;
|
||||
ContextState(int32_t token, float token_score, float node_score,
|
||||
float local_node_score, bool is_end)
|
||||
: token(token),
|
||||
token_score(token_score),
|
||||
node_score(node_score),
|
||||
local_node_score(local_node_score),
|
||||
is_end(is_end) {}
|
||||
};
|
||||
|
||||
class ContextGraph {
|
||||
public:
|
||||
ContextGraph() = default;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score)
|
||||
: context_score_(context_score) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
||||
root_->fail = root_.get();
|
||||
Build(token_ids);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id) const;
|
||||
std::pair<float, const ContextState *> Finalize(
|
||||
const ContextState *state) const;
|
||||
|
||||
const ContextState *Root() const { return root_.get(); }
|
||||
|
||||
private:
|
||||
float context_score_;
|
||||
std::unique_ptr<ContextState> root_;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
||||
void FillFailOutput() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -39,11 +40,18 @@ struct Hypothesis {
|
||||
// the nn lm states
|
||||
std::vector<CopyableOrtValue> nn_lm_states;
|
||||
|
||||
const ContextState *context_state;
|
||||
|
||||
// TODO(fangjun): Make it configurable
|
||||
// the minimum of tokens in a chunk for streaming RNN LM
|
||||
int32_t lm_rescore_min_chunk = 2; // a const
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
|
||||
const ContextState *context_state = nullptr)
|
||||
: ys(ys), log_prob(log_prob), context_state(context_state) {}
|
||||
|
||||
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
|
||||
|
||||
virtual ~OfflineRecognizerImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
-23.025850929940457f);
|
||||
|
||||
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
auto results =
|
||||
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
|
||||
@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("context-score", &context_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
}
|
||||
|
||||
bool OfflineRecognizerConfig::Validate() const {
|
||||
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ")";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
||||
|
||||
OfflineRecognizer::~OfflineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
int32_t max_active_paths = 4;
|
||||
float context_score = 1.5;
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
|
||||
const OfflineModelConfig &model_config,
|
||||
const OfflineLMConfig &lm_config,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
int32_t max_active_paths, float context_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -58,6 +60,10 @@ class OfflineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
|
||||
/** Decode a single stream
|
||||
*
|
||||
* @param s The stream to decode.
|
||||
|
||||
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
|
||||
|
||||
class OfflineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: config_(config), context_graph_(context_graph) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
|
||||
|
||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) const {
|
||||
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
@@ -96,6 +98,9 @@ class OfflineStream {
|
||||
/** Get the recognition result of this stream */
|
||||
const OfflineRecognitionResult &GetResult() const;
|
||||
|
||||
/** Get the ContextGraph of this stream */
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -16,7 +16,9 @@ namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
||||
Ort::Value encoder_out_length) {
|
||||
Ort::Value encoder_out_length,
|
||||
OfflineStream **ss /*= nullptr*/,
|
||||
int32_t n /*= 0*/) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
: model_(model) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
|
||||
@@ -8,7 +8,9 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||
#include "sherpa-onnx/csrc/slice.h"
|
||||
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
int32_t batch_size =
|
||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||
|
||||
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
std::vector<int64_t> blanks(context_size, 0);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
|
||||
std::deque<Hypotheses> finalized;
|
||||
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
||||
std::vector<Hypotheses> cur;
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
|
||||
|
||||
for (int32_t i = 0; i < batch_size; ++i) {
|
||||
const ContextState *context_state;
|
||||
if (ss != nullptr) {
|
||||
context_graphs[i] =
|
||||
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
|
||||
if (context_graphs[i] != nullptr)
|
||||
context_state = context_graphs[i]->Root();
|
||||
}
|
||||
Hypotheses blank_hyp({{blanks, 0, context_state}});
|
||||
cur.emplace_back(std::move(blank_hyp));
|
||||
}
|
||||
|
||||
int32_t start = 0;
|
||||
int32_t t = 0;
|
||||
for (auto n : packed_encoder_out.batch_sizes) {
|
||||
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
int32_t new_token = k % vocab_size;
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
if (new_token != 0) {
|
||||
// blank id is fixed to 0
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t);
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
|
||||
new_hyp.log_prob = p_logprob[k];
|
||||
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(h));
|
||||
}
|
||||
|
||||
// Finalize context biasing matching..
|
||||
for (int32_t i = 0; i < cur.size(); ++i) {
|
||||
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->Finalize(iter->second.context_state);
|
||||
iter->second.log_prob += context_res.first;
|
||||
iter->second.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
// use LM for rescoring
|
||||
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||
|
||||
@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
||||
lm_scale_(lm_scale) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
|
||||
Reference in New Issue
Block a user