Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
2
.github/workflows/run-python-test.yaml
vendored
2
.github/workflows/run-python-test.yaml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
- name: Install Python dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip numpy
|
||||
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
|
||||
|
||||
- name: Install sherpa-onnx
|
||||
shell: bash
|
||||
|
||||
@@ -43,9 +43,10 @@ import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
@@ -60,6 +61,47 @@ def get_args():
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
Path to bpe.model,
|
||||
Used only when --decoding-method=modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="char",
|
||||
help="""
|
||||
The type of modeling unit.
|
||||
Valid values are bpe, bpe+char, char.
|
||||
Note: the char here means characters in CJK languages.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--contexts",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The context list, it is a string containing some words/phrases separated
|
||||
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="""
|
||||
The context score of each token for biasing word/phrase. Used only if
|
||||
--contexts is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="",
|
||||
@@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
|
||||
)
|
||||
|
||||
|
||||
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
||||
sp = None
|
||||
if "bpe" in args.modeling_unit:
|
||||
assert_file_exists(args.bpe_model)
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
tokens = {}
|
||||
with open(args.tokens, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
toks = line.strip().split()
|
||||
assert len(toks) == 2, len(toks)
|
||||
assert toks[0] not in tokens, f"Duplicate token: {toks} "
|
||||
tokens[toks[0]] = int(toks[1])
|
||||
return sherpa_onnx.encode_contexts(
|
||||
modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
@@ -182,10 +242,17 @@ def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
|
||||
contexts_list = []
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
|
||||
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||
if contexts:
|
||||
print(f"Contexts list: {contexts}")
|
||||
contexts_list = encode_contexts(args, contexts)
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
@@ -199,6 +266,7 @@ def main():
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
context_score=args.context_score,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
@@ -238,8 +306,12 @@ def main():
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
|
||||
s = recognizer.create_stream()
|
||||
if contexts_list:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
s = recognizer.create_stream(contexts_list=contexts_list)
|
||||
else:
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
streams.append(s)
|
||||
|
||||
1
setup.py
1
setup.py
@@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
|
||||
|
||||
install_requires = [
|
||||
"numpy",
|
||||
"sentencepiece==0.1.96",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ endif()
|
||||
|
||||
set(sources
|
||||
cat.cc
|
||||
context-graph.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
file-utils.cc
|
||||
@@ -248,6 +249,7 @@ endif()
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
set(sherpa_onnx_test_srcs
|
||||
cat-test.cc
|
||||
context-graph-test.cc
|
||||
packed-sequence-test.cc
|
||||
pad-sequence-test.cc
|
||||
slice-test.cc
|
||||
|
||||
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
@@ -0,0 +1,43 @@
|
||||
// sherpa-onnx/csrc/context-graph-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
|
||||
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
total_scores += res.first;
|
||||
EXPECT_EQ(total_scores, iter.second);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
105
sherpa-onnx/csrc/context-graph.cc
Normal file
105
sherpa-onnx/csrc/context-graph.cc
Normal file
@@ -0,0 +1,105 @@
|
||||
// sherpa-onnx/csrc/context-graph.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ContextGraph::Build(
|
||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||
auto node = root_.get();
|
||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||
int32_t token = token_ids[i][j];
|
||||
if (0 == node->next.count(token)) {
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? 0 : node->local_node_score + context_score_, is_end);
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
}
|
||||
FillFailOutput();
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
const ContextState *state, int32_t token) const {
|
||||
const ContextState *node;
|
||||
float score;
|
||||
if (1 == state->next.count(token)) {
|
||||
node = state->next.at(token).get();
|
||||
score = node->token_score;
|
||||
if (state->is_end) score += state->node_score;
|
||||
} else {
|
||||
node = state->fail;
|
||||
while (0 == node->next.count(token)) {
|
||||
node = node->fail;
|
||||
if (-1 == node->token) break; // root
|
||||
}
|
||||
if (1 == node->next.count(token)) {
|
||||
node = node->next.at(token).get();
|
||||
}
|
||||
score = node->node_score - state->local_node_score;
|
||||
}
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
float matched_score = 0;
|
||||
auto output = node->output;
|
||||
while (nullptr != output) {
|
||||
matched_score += output->node_score;
|
||||
output = output->output;
|
||||
}
|
||||
return std::make_pair(score + matched_score, node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
const ContextState *state) const {
|
||||
float score = -state->node_score;
|
||||
if (state->is_end) {
|
||||
score = 0;
|
||||
}
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
void ContextGraph::FillFailOutput() const {
|
||||
std::queue<const ContextState *> node_queue;
|
||||
for (auto &kv : root_->next) {
|
||||
kv.second->fail = root_.get();
|
||||
node_queue.push(kv.second.get());
|
||||
}
|
||||
while (!node_queue.empty()) {
|
||||
auto current_node = node_queue.front();
|
||||
node_queue.pop();
|
||||
for (auto &kv : current_node->next) {
|
||||
auto fail = current_node->fail;
|
||||
if (1 == fail->next.count(kv.first)) {
|
||||
fail = fail->next.at(kv.first).get();
|
||||
} else {
|
||||
fail = fail->fail;
|
||||
while (0 == fail->next.count(kv.first)) {
|
||||
fail = fail->fail;
|
||||
if (-1 == fail->token) break;
|
||||
}
|
||||
if (1 == fail->next.count(kv.first))
|
||||
fail = fail->next.at(kv.first).get();
|
||||
}
|
||||
kv.second->fail = fail;
|
||||
// fill the output arc
|
||||
auto output = fail;
|
||||
while (!output->is_end) {
|
||||
output = output->fail;
|
||||
if (-1 == output->token) {
|
||||
output = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
kv.second->output = output;
|
||||
node_queue.push(kv.second.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace sherpa_onnx
|
||||
66
sherpa-onnx/csrc/context-graph.h
Normal file
66
sherpa-onnx/csrc/context-graph.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// sherpa-onnx/csrc/context-graph.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class ContextGraph;
|
||||
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
|
||||
|
||||
struct ContextState {
|
||||
int32_t token;
|
||||
float token_score;
|
||||
float node_score;
|
||||
float local_node_score;
|
||||
bool is_end;
|
||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||
const ContextState *fail = nullptr;
|
||||
const ContextState *output = nullptr;
|
||||
|
||||
ContextState() = default;
|
||||
ContextState(int32_t token, float token_score, float node_score,
|
||||
float local_node_score, bool is_end)
|
||||
: token(token),
|
||||
token_score(token_score),
|
||||
node_score(node_score),
|
||||
local_node_score(local_node_score),
|
||||
is_end(is_end) {}
|
||||
};
|
||||
|
||||
class ContextGraph {
|
||||
public:
|
||||
ContextGraph() = default;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score)
|
||||
: context_score_(context_score) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
||||
root_->fail = root_.get();
|
||||
Build(token_ids);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id) const;
|
||||
std::pair<float, const ContextState *> Finalize(
|
||||
const ContextState *state) const;
|
||||
|
||||
const ContextState *Root() const { return root_.get(); }
|
||||
|
||||
private:
|
||||
float context_score_;
|
||||
std::unique_ptr<ContextState> root_;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
||||
void FillFailOutput() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -39,11 +40,18 @@ struct Hypothesis {
|
||||
// the nn lm states
|
||||
std::vector<CopyableOrtValue> nn_lm_states;
|
||||
|
||||
const ContextState *context_state;
|
||||
|
||||
// TODO(fangjun): Make it configurable
|
||||
// the minimum of tokens in a chunk for streaming RNN LM
|
||||
int32_t lm_rescore_min_chunk = 2; // a const
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
|
||||
const ContextState *context_state = nullptr)
|
||||
: ys(ys), log_prob(log_prob), context_state(context_state) {}
|
||||
|
||||
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
|
||||
|
||||
virtual ~OfflineRecognizerImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
-23.025850929940457f);
|
||||
|
||||
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
auto results =
|
||||
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
|
||||
@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("context-score", &context_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
}
|
||||
|
||||
bool OfflineRecognizerConfig::Validate() const {
|
||||
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ")";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
||||
|
||||
OfflineRecognizer::~OfflineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
int32_t max_active_paths = 4;
|
||||
float context_score = 1.5;
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
|
||||
const OfflineModelConfig &model_config,
|
||||
const OfflineLMConfig &lm_config,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
int32_t max_active_paths, float context_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -58,6 +60,10 @@ class OfflineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
|
||||
/** Decode a single stream
|
||||
*
|
||||
* @param s The stream to decode.
|
||||
|
||||
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
|
||||
|
||||
class OfflineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: config_(config), context_graph_(context_graph) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
|
||||
|
||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) const {
|
||||
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
@@ -96,6 +98,9 @@ class OfflineStream {
|
||||
/** Get the recognition result of this stream */
|
||||
const OfflineRecognitionResult &GetResult() const;
|
||||
|
||||
/** Get the ContextGraph of this stream */
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -16,7 +16,9 @@ namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
||||
Ort::Value encoder_out_length) {
|
||||
Ort::Value encoder_out_length,
|
||||
OfflineStream **ss /*= nullptr*/,
|
||||
int32_t n /*= 0*/) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
: model_(model) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
|
||||
@@ -8,7 +8,9 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||
#include "sherpa-onnx/csrc/slice.h"
|
||||
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
int32_t batch_size =
|
||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||
|
||||
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
std::vector<int64_t> blanks(context_size, 0);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
|
||||
std::deque<Hypotheses> finalized;
|
||||
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
||||
std::vector<Hypotheses> cur;
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
|
||||
|
||||
for (int32_t i = 0; i < batch_size; ++i) {
|
||||
const ContextState *context_state;
|
||||
if (ss != nullptr) {
|
||||
context_graphs[i] =
|
||||
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
|
||||
if (context_graphs[i] != nullptr)
|
||||
context_state = context_graphs[i]->Root();
|
||||
}
|
||||
Hypotheses blank_hyp({{blanks, 0, context_state}});
|
||||
cur.emplace_back(std::move(blank_hyp));
|
||||
}
|
||||
|
||||
int32_t start = 0;
|
||||
int32_t t = 0;
|
||||
for (auto n : packed_encoder_out.batch_sizes) {
|
||||
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
int32_t new_token = k % vocab_size;
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
if (new_token != 0) {
|
||||
// blank id is fixed to 0
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t);
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
|
||||
new_hyp.log_prob = p_logprob[k];
|
||||
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(h));
|
||||
}
|
||||
|
||||
// Finalize context biasing matching..
|
||||
for (int32_t i = 0; i < cur.size(); ++i) {
|
||||
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->Finalize(iter->second.context_state);
|
||||
iter->second.log_prob += context_res.first;
|
||||
iter->second.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
// use LM for rescoring
|
||||
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||
|
||||
@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
||||
lm_scale_(lm_scale) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
|
||||
@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const std::string &, int32_t>(),
|
||||
const std::string &, int32_t, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4)
|
||||
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("context_score", &PyClass::context_score)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
|
||||
using PyClass = OfflineRecognizer;
|
||||
py::class_<PyClass>(*m, "OfflineRecognizer")
|
||||
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
|
||||
.def("create_stream", &PyClass::CreateStream)
|
||||
.def("create_stream",
|
||||
[](const PyClass &self) { return self.CreateStream(); })
|
||||
.def(
|
||||
"create_stream",
|
||||
[](PyClass &self,
|
||||
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||
return self.CreateStream(contexts_list);
|
||||
},
|
||||
py::arg("contexts_list"))
|
||||
.def("decode_stream", &PyClass::DecodeStream)
|
||||
.def("decode_streams",
|
||||
[](PyClass &self, std::vector<OfflineStream *> ss) {
|
||||
[](const PyClass &self, std::vector<OfflineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from _sherpa_onnx import Display
|
||||
|
||||
from .online_recognizer import OnlineRecognizer
|
||||
from .online_recognizer import OnlineStream
|
||||
from .offline_recognizer import OfflineRecognizer
|
||||
|
||||
from .utils import encode_contexts
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 by manyeyes
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
OfflineFeatureExtractorConfig,
|
||||
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
|
||||
sample_rate: int = 16000,
|
||||
feature_dim: int = 80,
|
||||
decoding_method: str = "greedy_search",
|
||||
context_score: float = 1.5,
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
context_score=context_score,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
|
||||
def create_stream(self):
|
||||
return self.recognizer.create_stream()
|
||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||
if contexts_list is None:
|
||||
return self.recognizer.create_stream()
|
||||
else:
|
||||
return self.recognizer.create_stream(contexts_list)
|
||||
|
||||
def decode_stream(self, s: OfflineStream):
|
||||
self.recognizer.decode_stream(s)
|
||||
|
||||
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def encode_contexts(
|
||||
modeling_unit: str,
|
||||
contexts: List[str],
|
||||
sp: Optional["SentencePieceProcessor"] = None,
|
||||
tokens_table: Optional[Dict[str, int]] = None,
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Encode the given contexts (a list of string) to a list of a list of token ids.
|
||||
|
||||
Args:
|
||||
modeling_unit:
|
||||
The valid values are bpe, char, bpe+char.
|
||||
Note: char here means characters in CJK languages, not English like languages.
|
||||
contexts:
|
||||
The given contexts list (a list of string).
|
||||
sp:
|
||||
An instance of SentencePieceProcessor.
|
||||
tokens_table:
|
||||
The tokens_table containing the tokens and the corresponding ids.
|
||||
Returns:
|
||||
Return the contexts_list, it is a list of a list of token ids.
|
||||
"""
|
||||
contexts_list = []
|
||||
if "bpe" in modeling_unit:
|
||||
assert sp is not None
|
||||
if "char" in modeling_unit:
|
||||
assert tokens_table is not None
|
||||
assert len(tokens_table) > 0, len(tokens_table)
|
||||
|
||||
if "char" == modeling_unit:
|
||||
for context in contexts:
|
||||
assert ' ' not in context
|
||||
ids = [
|
||||
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
|
||||
for txt in context
|
||||
]
|
||||
contexts_list.append(ids)
|
||||
elif "bpe" == modeling_unit:
|
||||
contexts_list = sp.encode(contexts, out_type=int)
|
||||
else:
|
||||
assert modeling_unit == "bpe+char", modeling_unit
|
||||
|
||||
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
for context in contexts:
|
||||
# Example:
|
||||
# txt = "你好 ITS'S OKAY 的"
|
||||
# chars = ["你", "好", " ITS'S OKAY ", "的"]
|
||||
chars = pattern.split(context.upper())
|
||||
mix_chars = [w for w in chars if len(w.strip()) > 0]
|
||||
ids = []
|
||||
for ch_or_w in mix_chars:
|
||||
# ch_or_w is a single CJK charater(i.e., "你"), do nothing.
|
||||
if pattern.fullmatch(ch_or_w) is not None:
|
||||
ids.append(
|
||||
tokens_table[ch_or_w]
|
||||
if ch_or_w in tokens_table
|
||||
else tokens_table["<unk>"]
|
||||
)
|
||||
# ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
|
||||
# encode ch_or_w using bpe_model.
|
||||
else:
|
||||
for p in sp.encode_as_pieces(ch_or_w):
|
||||
ids.append(
|
||||
tokens_table[p]
|
||||
if p in tokens_table
|
||||
else tokens_table["<unk>"]
|
||||
)
|
||||
contexts_list.append(ids)
|
||||
return contexts_list
|
||||
Reference in New Issue
Block a user