Implement context biasing with a Aho Corasick automata (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for effciency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -54,7 +54,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
shell: bash shell: bash
run: | run: |
python3 -m pip install --upgrade pip numpy python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
- name: Install sherpa-onnx - name: Install sherpa-onnx
shell: bash shell: bash

View File

@@ -43,9 +43,10 @@ import argparse
import time import time
import wave import wave
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import List, Tuple
import numpy as np import numpy as np
import sentencepiece as spm
import sherpa_onnx import sherpa_onnx
@@ -60,6 +61,47 @@ def get_args():
help="Path to tokens.txt", help="Path to tokens.txt",
) )
parser.add_argument(
"--bpe-model",
type=str,
default="",
help="""
Path to bpe.model,
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
""",
)
parser.add_argument(
"--context-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
""",
)
parser.add_argument( parser.add_argument(
"--encoder", "--encoder",
default="", default="",
@@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
) )
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
""" """
Args: Args:
@@ -182,10 +242,17 @@ def main():
args = get_args() args = get_args()
assert_file_exists(args.tokens) assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads assert args.num_threads > 0, args.num_threads
contexts_list = []
if args.encoder: if args.encoder:
assert len(args.paraformer) == 0, args.paraformer assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.nemo_ctc) == 0, args.nemo_ctc
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
assert_file_exists(args.encoder) assert_file_exists(args.encoder)
assert_file_exists(args.decoder) assert_file_exists(args.decoder)
assert_file_exists(args.joiner) assert_file_exists(args.joiner)
@@ -199,6 +266,7 @@ def main():
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
feature_dim=args.feature_dim, feature_dim=args.feature_dim,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
context_score=args.context_score,
debug=args.debug, debug=args.debug,
) )
elif args.paraformer: elif args.paraformer:
@@ -238,8 +306,12 @@ def main():
samples, sample_rate = read_wave(wave_filename) samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate duration = len(samples) / sample_rate
total_duration += duration total_duration += duration
if contexts_list:
s = recognizer.create_stream() assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
s = recognizer.create_stream(contexts_list=contexts_list)
else:
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples) s.accept_waveform(sample_rate, samples)
streams.append(s) streams.append(s)

View File

@@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
install_requires = [ install_requires = [
"numpy", "numpy",
"sentencepiece==0.1.96",
] ]

View File

@@ -12,6 +12,7 @@ endif()
set(sources set(sources
cat.cc cat.cc
context-graph.cc
endpoint.cc endpoint.cc
features.cc features.cc
file-utils.cc file-utils.cc
@@ -248,6 +249,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TESTS) if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs set(sherpa_onnx_test_srcs
cat-test.cc cat-test.cc
context-graph-test.cc
packed-sequence-test.cc packed-sequence-test.cc
pad-sequence-test.cc pad-sequence-test.cc
slice-test.cc slice-test.cc

View File

@@ -0,0 +1,43 @@
// sherpa-onnx/csrc/context-graph-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <map>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace sherpa_onnx {
TEST(ContextGraph, TestBasic) {
std::vector<std::string> contexts_str(
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
std::vector<std::vector<int32_t>> contexts;
for (int32_t i = 0; i < contexts_str.size(); ++i) {
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
}
auto context_graph = ContextGraph(contexts, 1);
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
for (const auto &iter : queries) {
float total_scores = 0;
auto state = context_graph.Root();
for (auto q : iter.first) {
auto res = context_graph.ForwardOneStep(state, q);
total_scores += res.first;
state = res.second;
}
auto res = context_graph.Finalize(state);
EXPECT_EQ(res.second->token, -1);
total_scores += res.first;
EXPECT_EQ(total_scores, iter.second);
}
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,105 @@
// sherpa-onnx/csrc/context-graph.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <cassert>
#include <queue>
#include <utility>
namespace sherpa_onnx {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? 0 : node->local_node_score + context_score_, is_end);
}
node = node->next[token].get();
}
}
FillFailOutput();
}
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
if (state->is_end) score += state->node_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
node = node->fail;
if (-1 == node->token) break; // root
}
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->local_node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
float matched_score = 0;
auto output = node->output;
while (nullptr != output) {
matched_score += output->node_score;
output = output->output;
}
return std::make_pair(score + matched_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
kv.second->fail = root_.get();
node_queue.push(kv.second.get());
}
while (!node_queue.empty()) {
auto current_node = node_queue.front();
node_queue.pop();
for (auto &kv : current_node->next) {
auto fail = current_node->fail;
if (1 == fail->next.count(kv.first)) {
fail = fail->next.at(kv.first).get();
} else {
fail = fail->fail;
while (0 == fail->next.count(kv.first)) {
fail = fail->fail;
if (-1 == fail->token) break;
}
if (1 == fail->next.count(kv.first))
fail = fail->next.at(kv.first).get();
}
kv.second->fail = fail;
// fill the output arc
auto output = fail;
while (!output->is_end) {
output = output->fail;
if (-1 == output->token) {
output = nullptr;
break;
}
}
kv.second->output = output;
node_queue.push(kv.second.get());
}
}
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,66 @@
// sherpa-onnx/csrc/context-graph.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
namespace sherpa_onnx {
class ContextGraph;
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
struct ContextState {
int32_t token;
float token_score;
float node_score;
float local_node_score;
bool is_end;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;
ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float local_node_score, bool is_end)
: token(token),
token_score(token_score),
node_score(node_score),
local_node_score(local_node_score),
is_end(is_end) {}
};
class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score)
: context_score_(context_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
root_->fail = root_.get();
Build(token_ids);
}
std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;
const ContextState *Root() const { return root_.get(); }
private:
float context_score_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void FillFailOutput() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_

View File

@@ -14,6 +14,7 @@
#include <vector> #include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -39,11 +40,18 @@ struct Hypothesis {
// the nn lm states // the nn lm states
std::vector<CopyableOrtValue> nn_lm_states; std::vector<CopyableOrtValue> nn_lm_states;
const ContextState *context_state;
// TODO(fangjun): Make it configurable
// the minimum of tokens in a chunk for streaming RNN LM
int32_t lm_rescore_min_chunk = 2; // a const
int32_t num_trailing_blanks = 0; int32_t num_trailing_blanks = 0;
Hypothesis() = default; Hypothesis() = default;
Hypothesis(const std::vector<int64_t> &ys, double log_prob) Hypothesis(const std::vector<int64_t> &ys, double log_prob,
: ys(ys), log_prob(log_prob) {} const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}
double TotalLogProb() const { return log_prob + lm_log_prob; } double TotalLogProb() const { return log_prob + lm_log_prob; }

View File

@@ -6,7 +6,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
#include <memory> #include <memory>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h" #include "sherpa-onnx/csrc/offline-stream.h"
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
virtual ~OfflineRecognizerImpl() = default; virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;

View File

@@ -10,6 +10,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
} }
} }
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(context_list, config_.context_score);
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}
std::unique_ptr<OfflineStream> CreateStream() const override { std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config); return std::make_unique<OfflineStream>(config_.feat_config);
} }
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
-23.025850929940457f); -23.025850929940457f);
auto t = model_->RunEncoder(std::move(x), std::move(x_length)); auto t = model_->RunEncoder(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); auto results =
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
int32_t frame_shift_ms = 10; int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) { for (int32_t i = 0; i != n; ++i) {

View File

@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths, po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search"); "Used only when decoding_method is modified_beam_search");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
} }
bool OfflineRecognizerConfig::Validate() const { bool OfflineRecognizerConfig::Validate() const {
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "model_config=" << model_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", "; os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", "; os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ")"; os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ")";
return os.str(); return os.str();
} }
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
OfflineRecognizer::~OfflineRecognizer() = default; OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
return impl_->CreateStream(); return impl_->CreateStream();
} }

View File

@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
std::string decoding_method = "greedy_search"; std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4; int32_t max_active_paths = 4;
float context_score = 1.5;
// only greedy_search is implemented // only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search // TODO(fangjun): Implement modified_beam_search
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config, const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config, const OfflineLMConfig &lm_config,
const std::string &decoding_method, const std::string &decoding_method,
int32_t max_active_paths) int32_t max_active_paths, float context_score)
: feat_config(feat_config), : feat_config(feat_config),
model_config(model_config), model_config(model_config),
lm_config(lm_config), lm_config(lm_config),
decoding_method(decoding_method), decoding_method(decoding_method),
max_active_paths(max_active_paths) {} max_active_paths(max_active_paths),
context_score(context_score) {}
void Register(ParseOptions *po); void Register(ParseOptions *po);
bool Validate() const; bool Validate() const;
@@ -58,6 +60,10 @@ class OfflineRecognizer {
/// Create a stream for decoding. /// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const; std::unique_ptr<OfflineStream> CreateStream() const;
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/** Decode a single stream /** Decode a single stream
* *
* @param s The stream to decode. * @param s The stream to decode.

View File

@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
class OfflineStream::Impl { class OfflineStream::Impl {
public: public:
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) { explicit Impl(const OfflineFeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = 0; opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false; opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate; opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
const OfflineRecognitionResult &GetResult() const { return r_; } const OfflineRecognitionResult &GetResult() const { return r_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private: private:
void NemoNormalizeFeatures(float *p, int32_t num_frames, void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const { int32_t feature_dim) const {
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
std::unique_ptr<knf::OnlineFbank> fbank_; std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_; knf::FbankOptions opts_;
OfflineRecognitionResult r_; OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
}; };
OfflineStream::OfflineStream( OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/) const OfflineFeatureExtractorConfig &config /*= {}*/,
: impl_(std::make_unique<Impl>(config)) {} ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::~OfflineStream() = default; OfflineStream::~OfflineStream() = default;
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
impl_->SetResult(r); impl_->SetResult(r);
} }
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
const OfflineRecognitionResult &OfflineStream::GetResult() const { const OfflineRecognitionResult &OfflineStream::GetResult() const {
return impl_->GetResult(); return impl_->GetResult();
} }

View File

@@ -10,6 +10,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
class OfflineStream { class OfflineStream {
public: public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}); explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
~OfflineStream(); ~OfflineStream();
/** /**
@@ -96,6 +98,9 @@ class OfflineStream {
/** Get the recognition result of this stream */ /** Get the recognition result of this stream */
const OfflineRecognitionResult &GetResult() const; const OfflineRecognitionResult &GetResult() const;
/** Get the ContextGraph of this stream */
const ContextGraphPtr &GetContextGraph() const;
private: private:
class Impl; class Impl;
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;

View File

@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-stream.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
* @return Return a vector of size `N` containing the decoded results. * @return Return a vector of size `N` containing the decoded results.
*/ */
virtual std::vector<OfflineTransducerDecoderResult> Decode( virtual std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0; Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -16,7 +16,9 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult> std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
Ort::Value encoder_out_length) { Ort::Value encoder_out_length,
OfflineStream **ss /*= nullptr*/,
int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence( PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length); model_->Allocator(), &encoder_out, &encoder_out_length);

View File

@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
: model_(model) {} : model_(model) {}
std::vector<OfflineTransducerDecoderResult> Decode( std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override; Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private: private:
OfflineTransducerModel *model_; // Not owned OfflineTransducerModel *model_; // Not owned

View File

@@ -8,7 +8,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h" #include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h" #include "sherpa-onnx/csrc/slice.h"
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult> std::vector<OfflineTransducerDecoderResult>
OfflineTransducerModifiedBeamSearchDecoder::Decode( OfflineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) { Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence( PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length); model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size = int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
int32_t vocab_size = model_->VocabSize(); int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize(); int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, 0); std::vector<int64_t> blanks(context_size, 0);
Hypotheses blank_hyp({{blanks, 0}});
std::deque<Hypotheses> finalized; std::deque<Hypotheses> finalized;
std::vector<Hypotheses> cur(batch_size, blank_hyp); std::vector<Hypotheses> cur;
std::vector<Hypothesis> prev; std::vector<Hypothesis> prev;
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
for (int32_t i = 0; i < batch_size; ++i) {
const ContextState *context_state;
if (ss != nullptr) {
context_graphs[i] =
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
if (context_graphs[i] != nullptr)
context_state = context_graphs[i]->Root();
}
Hypotheses blank_hyp({{blanks, 0, context_state}});
cur.emplace_back(std::move(blank_hyp));
}
int32_t start = 0; int32_t start = 0;
int32_t t = 0; int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) { for (auto n : packed_encoder_out.batch_sizes) {
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
int32_t new_token = k % vocab_size; int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index]; Hypothesis new_hyp = prev[hyp_index];
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) { if (new_token != 0) {
// blank id is fixed to 0 // blank id is fixed to 0
new_hyp.ys.push_back(new_token); new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t); new_hyp.timestamps.push_back(t);
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
} }
new_hyp.log_prob = p_logprob[k]; new_hyp.log_prob = p_logprob[k] + context_score;
hyps.Add(std::move(new_hyp)); hyps.Add(std::move(new_hyp));
} // for (auto k : topk) } // for (auto k : topk)
p_logprob += (end - start) * vocab_size; p_logprob += (end - start) * vocab_size;
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(h)); cur.push_back(std::move(h));
} }
// Finalize context biasing matching..
for (int32_t i = 0; i < cur.size(); ++i) {
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->Finalize(iter->second.context_state);
iter->second.log_prob += context_res.first;
iter->second.context_state = context_res.second;
}
}
}
if (lm_) { if (lm_) {
// use LM for rescoring // use LM for rescoring
lm_->ComputeLMScore(lm_scale_, context_size, &cur); lm_->ComputeLMScore(lm_scale_, context_size, &cur);

View File

@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
lm_scale_(lm_scale) {} lm_scale_(lm_scale) {}
std::vector<OfflineTransducerDecoderResult> Decode( std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override; Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private: private:
OfflineTransducerModel *model_; // Not owned OfflineTransducerModel *model_; // Not owned

View File

@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineRecognizerConfig") py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &, .def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &, const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t>(), const std::string &, int32_t, float>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(), py::arg("lm_config") = OfflineLMConfig(),
py::arg("decoding_method") = "greedy_search", py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4) py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }
@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
using PyClass = OfflineRecognizer; using PyClass = OfflineRecognizer;
py::class_<PyClass>(*m, "OfflineRecognizer") py::class_<PyClass>(*m, "OfflineRecognizer")
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", &PyClass::CreateStream) .def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
},
py::arg("contexts_list"))
.def("decode_stream", &PyClass::DecodeStream) .def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams", .def("decode_streams",
[](PyClass &self, std::vector<OfflineStream *> ss) { [](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size()); self.DecodeStreams(ss.data(), ss.size());
}); });
} }

View File

@@ -1,5 +1,12 @@
from typing import Dict, List, Optional
from _sherpa_onnx import Display from _sherpa_onnx import Display
from .online_recognizer import OnlineRecognizer from .online_recognizer import OnlineRecognizer
from .online_recognizer import OnlineStream from .online_recognizer import OnlineStream
from .offline_recognizer import OfflineRecognizer from .offline_recognizer import OfflineRecognizer
from .utils import encode_contexts

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 by manyeyes # Copyright (c) 2023 by manyeyes
from pathlib import Path from pathlib import Path
from typing import List from typing import List, Optional
from _sherpa_onnx import ( from _sherpa_onnx import (
OfflineFeatureExtractorConfig, OfflineFeatureExtractorConfig,
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
sample_rate: int = 16000, sample_rate: int = 16000,
feature_dim: int = 80, feature_dim: int = 80,
decoding_method: str = "greedy_search", decoding_method: str = "greedy_search",
context_score: float = 1.5,
debug: bool = False, debug: bool = False,
provider: str = "cpu", provider: str = "cpu",
): ):
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
feat_config=feat_config, feat_config=feat_config,
model_config=model_config, model_config=model_config,
decoding_method=decoding_method, decoding_method=decoding_method,
context_score=context_score,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
return self return self
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
return self return self
def create_stream(self): def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
return self.recognizer.create_stream() if contexts_list is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
def decode_stream(self, s: OfflineStream): def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s) self.recognizer.decode_stream(s)

View File

@@ -0,0 +1,74 @@
from typing import Dict, List, Optional
def encode_contexts(
modeling_unit: str,
contexts: List[str],
sp: Optional["SentencePieceProcessor"] = None,
tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
"""
Encode the given contexts (a list of string) to a list of a list of token ids.
Args:
modeling_unit:
The valid values are bpe, char, bpe+char.
Note: char here means characters in CJK languages, not English like languages.
contexts:
The given contexts list (a list of string).
sp:
An instance of SentencePieceProcessor.
tokens_table:
The tokens_table containing the tokens and the corresponding ids.
Returns:
Return the contexts_list, it is a list of a list of token ids.
"""
contexts_list = []
if "bpe" in modeling_unit:
assert sp is not None
if "char" in modeling_unit:
assert tokens_table is not None
assert len(tokens_table) > 0, len(tokens_table)
if "char" == modeling_unit:
for context in contexts:
assert ' ' not in context
ids = [
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
for txt in context
]
contexts_list.append(ids)
elif "bpe" == modeling_unit:
contexts_list = sp.encode(contexts, out_type=int)
else:
assert modeling_unit == "bpe+char", modeling_unit
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
for context in contexts:
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(context.upper())
mix_chars = [w for w in chars if len(w.strip()) > 0]
ids = []
for ch_or_w in mix_chars:
# ch_or_w is a single CJK charater(i.e., "你"), do nothing.
if pattern.fullmatch(ch_or_w) is not None:
ids.append(
tokens_table[ch_or_w]
if ch_or_w in tokens_table
else tokens_table["<unk>"]
)
# ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
# encode ch_or_w using bpe_model.
else:
for p in sp.encode_as_pieces(ch_or_w):
ids.append(
tokens_table[p]
if p in tokens_table
else tokens_table["<unk>"]
)
contexts_list.append(ids)
return contexts_list