Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
2
.github/workflows/run-python-test.yaml
vendored
2
.github/workflows/run-python-test.yaml
vendored
@@ -54,7 +54,7 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip numpy
|
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
|
||||||
|
|
||||||
- name: Install sherpa-onnx
|
- name: Install sherpa-onnx
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -43,9 +43,10 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import sentencepiece as spm
|
||||||
import sherpa_onnx
|
import sherpa_onnx
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +61,47 @@ def get_args():
|
|||||||
help="Path to tokens.txt",
|
help="Path to tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
Path to bpe.model,
|
||||||
|
Used only when --decoding-method=modified_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modeling-unit",
|
||||||
|
type=str,
|
||||||
|
default="char",
|
||||||
|
help="""
|
||||||
|
The type of modeling unit.
|
||||||
|
Valid values are bpe, bpe+char, char.
|
||||||
|
Note: the char here means characters in CJK languages.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--contexts",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The context list, it is a string containing some words/phrases separated
|
||||||
|
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="""
|
||||||
|
The context score of each token for biasing word/phrase. Used only if
|
||||||
|
--contexts is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder",
|
"--encoder",
|
||||||
default="",
|
default="",
|
||||||
@@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
||||||
|
sp = None
|
||||||
|
if "bpe" in args.modeling_unit:
|
||||||
|
assert_file_exists(args.bpe_model)
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model)
|
||||||
|
tokens = {}
|
||||||
|
with open(args.tokens, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
toks = line.strip().split()
|
||||||
|
assert len(toks) == 2, len(toks)
|
||||||
|
assert toks[0] not in tokens, f"Duplicate token: {toks} "
|
||||||
|
tokens[toks[0]] = int(toks[1])
|
||||||
|
return sherpa_onnx.encode_contexts(
|
||||||
|
modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -182,10 +242,17 @@ def main():
|
|||||||
args = get_args()
|
args = get_args()
|
||||||
assert_file_exists(args.tokens)
|
assert_file_exists(args.tokens)
|
||||||
assert args.num_threads > 0, args.num_threads
|
assert args.num_threads > 0, args.num_threads
|
||||||
|
|
||||||
|
contexts_list = []
|
||||||
if args.encoder:
|
if args.encoder:
|
||||||
assert len(args.paraformer) == 0, args.paraformer
|
assert len(args.paraformer) == 0, args.paraformer
|
||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
|
||||||
|
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||||
|
if contexts:
|
||||||
|
print(f"Contexts list: {contexts}")
|
||||||
|
contexts_list = encode_contexts(args, contexts)
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
assert_file_exists(args.joiner)
|
assert_file_exists(args.joiner)
|
||||||
@@ -199,6 +266,7 @@ def main():
|
|||||||
sample_rate=args.sample_rate,
|
sample_rate=args.sample_rate,
|
||||||
feature_dim=args.feature_dim,
|
feature_dim=args.feature_dim,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
|
context_score=args.context_score,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
elif args.paraformer:
|
elif args.paraformer:
|
||||||
@@ -238,8 +306,12 @@ def main():
|
|||||||
samples, sample_rate = read_wave(wave_filename)
|
samples, sample_rate = read_wave(wave_filename)
|
||||||
duration = len(samples) / sample_rate
|
duration = len(samples) / sample_rate
|
||||||
total_duration += duration
|
total_duration += duration
|
||||||
|
if contexts_list:
|
||||||
s = recognizer.create_stream()
|
assert len(args.paraformer) == 0, args.paraformer
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
s = recognizer.create_stream(contexts_list=contexts_list)
|
||||||
|
else:
|
||||||
|
s = recognizer.create_stream()
|
||||||
s.accept_waveform(sample_rate, samples)
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
streams.append(s)
|
streams.append(s)
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
|
|||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"sentencepiece==0.1.96",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ endif()
|
|||||||
|
|
||||||
set(sources
|
set(sources
|
||||||
cat.cc
|
cat.cc
|
||||||
|
context-graph.cc
|
||||||
endpoint.cc
|
endpoint.cc
|
||||||
features.cc
|
features.cc
|
||||||
file-utils.cc
|
file-utils.cc
|
||||||
@@ -248,6 +249,7 @@ endif()
|
|||||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||||
set(sherpa_onnx_test_srcs
|
set(sherpa_onnx_test_srcs
|
||||||
cat-test.cc
|
cat-test.cc
|
||||||
|
context-graph-test.cc
|
||||||
packed-sequence-test.cc
|
packed-sequence-test.cc
|
||||||
pad-sequence-test.cc
|
pad-sequence-test.cc
|
||||||
slice-test.cc
|
slice-test.cc
|
||||||
|
|||||||
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
// sherpa-onnx/csrc/context-graph-test.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
TEST(ContextGraph, TestBasic) {
|
||||||
|
std::vector<std::string> contexts_str(
|
||||||
|
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||||
|
std::vector<std::vector<int32_t>> contexts;
|
||||||
|
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||||
|
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||||
|
}
|
||||||
|
auto context_graph = ContextGraph(contexts, 1);
|
||||||
|
|
||||||
|
auto queries = std::map<std::string, float>{
|
||||||
|
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
|
||||||
|
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||||
|
|
||||||
|
for (const auto &iter : queries) {
|
||||||
|
float total_scores = 0;
|
||||||
|
auto state = context_graph.Root();
|
||||||
|
for (auto q : iter.first) {
|
||||||
|
auto res = context_graph.ForwardOneStep(state, q);
|
||||||
|
total_scores += res.first;
|
||||||
|
state = res.second;
|
||||||
|
}
|
||||||
|
auto res = context_graph.Finalize(state);
|
||||||
|
EXPECT_EQ(res.second->token, -1);
|
||||||
|
total_scores += res.first;
|
||||||
|
EXPECT_EQ(total_scores, iter.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
105
sherpa-onnx/csrc/context-graph.cc
Normal file
105
sherpa-onnx/csrc/context-graph.cc
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
// sherpa-onnx/csrc/context-graph.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <queue>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
void ContextGraph::Build(
|
||||||
|
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||||
|
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||||
|
auto node = root_.get();
|
||||||
|
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||||
|
int32_t token = token_ids[i][j];
|
||||||
|
if (0 == node->next.count(token)) {
|
||||||
|
bool is_end = j == token_ids[i].size() - 1;
|
||||||
|
node->next[token] = std::make_unique<ContextState>(
|
||||||
|
token, context_score_, node->node_score + context_score_,
|
||||||
|
is_end ? 0 : node->local_node_score + context_score_, is_end);
|
||||||
|
}
|
||||||
|
node = node->next[token].get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FillFailOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||||
|
const ContextState *state, int32_t token) const {
|
||||||
|
const ContextState *node;
|
||||||
|
float score;
|
||||||
|
if (1 == state->next.count(token)) {
|
||||||
|
node = state->next.at(token).get();
|
||||||
|
score = node->token_score;
|
||||||
|
if (state->is_end) score += state->node_score;
|
||||||
|
} else {
|
||||||
|
node = state->fail;
|
||||||
|
while (0 == node->next.count(token)) {
|
||||||
|
node = node->fail;
|
||||||
|
if (-1 == node->token) break; // root
|
||||||
|
}
|
||||||
|
if (1 == node->next.count(token)) {
|
||||||
|
node = node->next.at(token).get();
|
||||||
|
}
|
||||||
|
score = node->node_score - state->local_node_score;
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_CHECK(nullptr != node);
|
||||||
|
float matched_score = 0;
|
||||||
|
auto output = node->output;
|
||||||
|
while (nullptr != output) {
|
||||||
|
matched_score += output->node_score;
|
||||||
|
output = output->output;
|
||||||
|
}
|
||||||
|
return std::make_pair(score + matched_score, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||||
|
const ContextState *state) const {
|
||||||
|
float score = -state->node_score;
|
||||||
|
if (state->is_end) {
|
||||||
|
score = 0;
|
||||||
|
}
|
||||||
|
return std::make_pair(score, root_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContextGraph::FillFailOutput() const {
|
||||||
|
std::queue<const ContextState *> node_queue;
|
||||||
|
for (auto &kv : root_->next) {
|
||||||
|
kv.second->fail = root_.get();
|
||||||
|
node_queue.push(kv.second.get());
|
||||||
|
}
|
||||||
|
while (!node_queue.empty()) {
|
||||||
|
auto current_node = node_queue.front();
|
||||||
|
node_queue.pop();
|
||||||
|
for (auto &kv : current_node->next) {
|
||||||
|
auto fail = current_node->fail;
|
||||||
|
if (1 == fail->next.count(kv.first)) {
|
||||||
|
fail = fail->next.at(kv.first).get();
|
||||||
|
} else {
|
||||||
|
fail = fail->fail;
|
||||||
|
while (0 == fail->next.count(kv.first)) {
|
||||||
|
fail = fail->fail;
|
||||||
|
if (-1 == fail->token) break;
|
||||||
|
}
|
||||||
|
if (1 == fail->next.count(kv.first))
|
||||||
|
fail = fail->next.at(kv.first).get();
|
||||||
|
}
|
||||||
|
kv.second->fail = fail;
|
||||||
|
// fill the output arc
|
||||||
|
auto output = fail;
|
||||||
|
while (!output->is_end) {
|
||||||
|
output = output->fail;
|
||||||
|
if (-1 == output->token) {
|
||||||
|
output = nullptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kv.second->output = output;
|
||||||
|
node_queue.push(kv.second.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace sherpa_onnx
|
||||||
66
sherpa-onnx/csrc/context-graph.h
Normal file
66
sherpa-onnx/csrc/context-graph.h
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
// sherpa-onnx/csrc/context-graph.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/log.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class ContextGraph;
|
||||||
|
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
|
||||||
|
|
||||||
|
struct ContextState {
|
||||||
|
int32_t token;
|
||||||
|
float token_score;
|
||||||
|
float node_score;
|
||||||
|
float local_node_score;
|
||||||
|
bool is_end;
|
||||||
|
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||||
|
const ContextState *fail = nullptr;
|
||||||
|
const ContextState *output = nullptr;
|
||||||
|
|
||||||
|
ContextState() = default;
|
||||||
|
ContextState(int32_t token, float token_score, float node_score,
|
||||||
|
float local_node_score, bool is_end)
|
||||||
|
: token(token),
|
||||||
|
token_score(token_score),
|
||||||
|
node_score(node_score),
|
||||||
|
local_node_score(local_node_score),
|
||||||
|
is_end(is_end) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ContextGraph {
|
||||||
|
public:
|
||||||
|
ContextGraph() = default;
|
||||||
|
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||||
|
float context_score)
|
||||||
|
: context_score_(context_score) {
|
||||||
|
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
||||||
|
root_->fail = root_.get();
|
||||||
|
Build(token_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<float, const ContextState *> ForwardOneStep(
|
||||||
|
const ContextState *state, int32_t token_id) const;
|
||||||
|
std::pair<float, const ContextState *> Finalize(
|
||||||
|
const ContextState *state) const;
|
||||||
|
|
||||||
|
const ContextState *Root() const { return root_.get(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
float context_score_;
|
||||||
|
std::unique_ptr<ContextState> root_;
|
||||||
|
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
||||||
|
void FillFailOutput() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||||
@@ -14,6 +14,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/math.h"
|
#include "sherpa-onnx/csrc/math.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
@@ -39,11 +40,18 @@ struct Hypothesis {
|
|||||||
// the nn lm states
|
// the nn lm states
|
||||||
std::vector<CopyableOrtValue> nn_lm_states;
|
std::vector<CopyableOrtValue> nn_lm_states;
|
||||||
|
|
||||||
|
const ContextState *context_state;
|
||||||
|
|
||||||
|
// TODO(fangjun): Make it configurable
|
||||||
|
// the minimum of tokens in a chunk for streaming RNN LM
|
||||||
|
int32_t lm_rescore_min_chunk = 2; // a const
|
||||||
|
|
||||||
int32_t num_trailing_blanks = 0;
|
int32_t num_trailing_blanks = 0;
|
||||||
|
|
||||||
Hypothesis() = default;
|
Hypothesis() = default;
|
||||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
|
||||||
: ys(ys), log_prob(log_prob) {}
|
const ContextState *context_state = nullptr)
|
||||||
|
: ys(ys), log_prob(log_prob), context_state(context_state) {}
|
||||||
|
|
||||||
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,9 @@
|
|||||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||||
|
|
||||||
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
|
|||||||
|
|
||||||
virtual ~OfflineRecognizerImpl() = default;
|
virtual ~OfflineRecognizerImpl() = default;
|
||||||
|
|
||||||
|
virtual std::unique_ptr<OfflineStream> CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||||
|
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||||
|
|
||||||
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||||
|
// We create context_graph at this level, because we might have default
|
||||||
|
// context_graph(will be added later if needed) that belongs to the whole
|
||||||
|
// model rather than each stream.
|
||||||
|
auto context_graph =
|
||||||
|
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||||
|
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||||
}
|
}
|
||||||
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
-23.025850929940457f);
|
-23.025850929940457f);
|
||||||
|
|
||||||
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
auto results =
|
||||||
|
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
|
||||||
|
|
||||||
int32_t frame_shift_ms = 10;
|
int32_t frame_shift_ms = 10;
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
po->Register("max-active-paths", &max_active_paths,
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
"Used only when decoding_method is modified_beam_search");
|
"Used only when decoding_method is modified_beam_search");
|
||||||
|
po->Register("context-score", &context_score,
|
||||||
|
"The bonus score for each token in context word/phrase. "
|
||||||
|
"Used only when decoding_method is modified_beam_search");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineRecognizerConfig::Validate() const {
|
bool OfflineRecognizerConfig::Validate() const {
|
||||||
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
|||||||
os << "model_config=" << model_config.ToString() << ", ";
|
os << "model_config=" << model_config.ToString() << ", ";
|
||||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||||
os << "max_active_paths=" << max_active_paths << ")";
|
os << "max_active_paths=" << max_active_paths << ", ";
|
||||||
|
os << "context_score=" << context_score << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
|||||||
|
|
||||||
OfflineRecognizer::~OfflineRecognizer() = default;
|
OfflineRecognizer::~OfflineRecognizer() = default;
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||||
|
return impl_->CreateStream(context_list);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
||||||
return impl_->CreateStream();
|
return impl_->CreateStream();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
|
|||||||
|
|
||||||
std::string decoding_method = "greedy_search";
|
std::string decoding_method = "greedy_search";
|
||||||
int32_t max_active_paths = 4;
|
int32_t max_active_paths = 4;
|
||||||
|
float context_score = 1.5;
|
||||||
// only greedy_search is implemented
|
// only greedy_search is implemented
|
||||||
// TODO(fangjun): Implement modified_beam_search
|
// TODO(fangjun): Implement modified_beam_search
|
||||||
|
|
||||||
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
|
|||||||
const OfflineModelConfig &model_config,
|
const OfflineModelConfig &model_config,
|
||||||
const OfflineLMConfig &lm_config,
|
const OfflineLMConfig &lm_config,
|
||||||
const std::string &decoding_method,
|
const std::string &decoding_method,
|
||||||
int32_t max_active_paths)
|
int32_t max_active_paths, float context_score)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
decoding_method(decoding_method),
|
decoding_method(decoding_method),
|
||||||
max_active_paths(max_active_paths) {}
|
max_active_paths(max_active_paths),
|
||||||
|
context_score(context_score) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
@@ -58,6 +60,10 @@ class OfflineRecognizer {
|
|||||||
/// Create a stream for decoding.
|
/// Create a stream for decoding.
|
||||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||||
|
|
||||||
|
/// Create a stream for decoding.
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||||
|
|
||||||
/** Decode a single stream
|
/** Decode a single stream
|
||||||
*
|
*
|
||||||
* @param s The stream to decode.
|
* @param s The stream to decode.
|
||||||
|
|||||||
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
|
|||||||
|
|
||||||
class OfflineStream::Impl {
|
class OfflineStream::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
|
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
||||||
|
ContextGraphPtr context_graph)
|
||||||
|
: config_(config), context_graph_(context_graph) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = 0;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = false;
|
||||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||||
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
|
|||||||
|
|
||||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||||
|
|
||||||
|
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||||
int32_t feature_dim) const {
|
int32_t feature_dim) const {
|
||||||
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
|
|||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
OfflineRecognitionResult r_;
|
OfflineRecognitionResult r_;
|
||||||
|
ContextGraphPtr context_graph_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineStream::OfflineStream(
|
OfflineStream::OfflineStream(
|
||||||
const OfflineFeatureExtractorConfig &config /*= {}*/)
|
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
ContextGraphPtr context_graph /*= nullptr*/)
|
||||||
|
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||||
|
|
||||||
OfflineStream::~OfflineStream() = default;
|
OfflineStream::~OfflineStream() = default;
|
||||||
|
|
||||||
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
|
|||||||
impl_->SetResult(r);
|
impl_->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
|
||||||
|
return impl_->GetContextGraph();
|
||||||
|
}
|
||||||
|
|
||||||
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
||||||
return impl_->GetResult();
|
return impl_->GetResult();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
|
|||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
public:
|
public:
|
||||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
|
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||||
|
ContextGraphPtr context_graph = nullptr);
|
||||||
~OfflineStream();
|
~OfflineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -96,6 +98,9 @@ class OfflineStream {
|
|||||||
/** Get the recognition result of this stream */
|
/** Get the recognition result of this stream */
|
||||||
const OfflineRecognitionResult &GetResult() const;
|
const OfflineRecognitionResult &GetResult() const;
|
||||||
|
|
||||||
|
/** Get the ContextGraph of this stream */
|
||||||
|
const ContextGraphPtr &GetContextGraph() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
|
|||||||
* @return Return a vector of size `N` containing the decoded results.
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
*/
|
*/
|
||||||
virtual std::vector<OfflineTransducerDecoderResult> Decode(
|
virtual std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult>
|
std::vector<OfflineTransducerDecoderResult>
|
||||||
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
||||||
Ort::Value encoder_out_length) {
|
Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss /*= nullptr*/,
|
||||||
|
int32_t n /*= 0*/) {
|
||||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
|||||||
: model_(model) {}
|
: model_(model) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineTransducerModel *model_; // Not owned
|
OfflineTransducerModel *model_; // Not owned
|
||||||
|
|||||||
@@ -8,7 +8,9 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
|
#include "sherpa-onnx/csrc/log.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||||
#include "sherpa-onnx/csrc/slice.h"
|
#include "sherpa-onnx/csrc/slice.h"
|
||||||
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult>
|
std::vector<OfflineTransducerDecoderResult>
|
||||||
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
|
||||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||||
|
|
||||||
int32_t batch_size =
|
int32_t batch_size =
|
||||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||||
|
|
||||||
|
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
|
||||||
|
|
||||||
int32_t vocab_size = model_->VocabSize();
|
int32_t vocab_size = model_->VocabSize();
|
||||||
int32_t context_size = model_->ContextSize();
|
int32_t context_size = model_->ContextSize();
|
||||||
|
|
||||||
std::vector<int64_t> blanks(context_size, 0);
|
std::vector<int64_t> blanks(context_size, 0);
|
||||||
Hypotheses blank_hyp({{blanks, 0}});
|
|
||||||
|
|
||||||
std::deque<Hypotheses> finalized;
|
std::deque<Hypotheses> finalized;
|
||||||
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
std::vector<Hypotheses> cur;
|
||||||
std::vector<Hypothesis> prev;
|
std::vector<Hypothesis> prev;
|
||||||
|
|
||||||
|
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < batch_size; ++i) {
|
||||||
|
const ContextState *context_state;
|
||||||
|
if (ss != nullptr) {
|
||||||
|
context_graphs[i] =
|
||||||
|
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
|
||||||
|
if (context_graphs[i] != nullptr)
|
||||||
|
context_state = context_graphs[i]->Root();
|
||||||
|
}
|
||||||
|
Hypotheses blank_hyp({{blanks, 0, context_state}});
|
||||||
|
cur.emplace_back(std::move(blank_hyp));
|
||||||
|
}
|
||||||
|
|
||||||
int32_t start = 0;
|
int32_t start = 0;
|
||||||
int32_t t = 0;
|
int32_t t = 0;
|
||||||
for (auto n : packed_encoder_out.batch_sizes) {
|
for (auto n : packed_encoder_out.batch_sizes) {
|
||||||
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
int32_t new_token = k % vocab_size;
|
int32_t new_token = k % vocab_size;
|
||||||
Hypothesis new_hyp = prev[hyp_index];
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
|
|
||||||
|
float context_score = 0;
|
||||||
|
auto context_state = new_hyp.context_state;
|
||||||
if (new_token != 0) {
|
if (new_token != 0) {
|
||||||
// blank id is fixed to 0
|
// blank id is fixed to 0
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
new_hyp.timestamps.push_back(t);
|
new_hyp.timestamps.push_back(t);
|
||||||
|
if (context_graphs[i] != nullptr) {
|
||||||
|
auto context_res =
|
||||||
|
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||||
|
context_score = context_res.first;
|
||||||
|
new_hyp.context_state = context_res.second;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
new_hyp.log_prob = p_logprob[k];
|
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||||
hyps.Add(std::move(new_hyp));
|
hyps.Add(std::move(new_hyp));
|
||||||
} // for (auto k : topk)
|
} // for (auto k : topk)
|
||||||
p_logprob += (end - start) * vocab_size;
|
p_logprob += (end - start) * vocab_size;
|
||||||
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
cur.push_back(std::move(h));
|
cur.push_back(std::move(h));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finalize context biasing matching..
|
||||||
|
for (int32_t i = 0; i < cur.size(); ++i) {
|
||||||
|
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
|
||||||
|
if (context_graphs[i] != nullptr) {
|
||||||
|
auto context_res =
|
||||||
|
context_graphs[i]->Finalize(iter->second.context_state);
|
||||||
|
iter->second.log_prob += context_res.first;
|
||||||
|
iter->second.context_state = context_res.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (lm_) {
|
if (lm_) {
|
||||||
// use LM for rescoring
|
// use LM for rescoring
|
||||||
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
|||||||
lm_scale_(lm_scale) {}
|
lm_scale_(lm_scale) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineTransducerModel *model_; // Not owned
|
OfflineTransducerModel *model_; // Not owned
|
||||||
|
|||||||
@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||||
const std::string &, int32_t>(),
|
const std::string &, int32_t, float>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OfflineLMConfig(),
|
py::arg("lm_config") = OfflineLMConfig(),
|
||||||
py::arg("decoding_method") = "greedy_search",
|
py::arg("decoding_method") = "greedy_search",
|
||||||
py::arg("max_active_paths") = 4)
|
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||||
|
.def_readwrite("context_score", &PyClass::context_score)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
|
|||||||
using PyClass = OfflineRecognizer;
|
using PyClass = OfflineRecognizer;
|
||||||
py::class_<PyClass>(*m, "OfflineRecognizer")
|
py::class_<PyClass>(*m, "OfflineRecognizer")
|
||||||
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
|
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
|
||||||
.def("create_stream", &PyClass::CreateStream)
|
.def("create_stream",
|
||||||
|
[](const PyClass &self) { return self.CreateStream(); })
|
||||||
|
.def(
|
||||||
|
"create_stream",
|
||||||
|
[](PyClass &self,
|
||||||
|
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||||
|
return self.CreateStream(contexts_list);
|
||||||
|
},
|
||||||
|
py::arg("contexts_list"))
|
||||||
.def("decode_stream", &PyClass::DecodeStream)
|
.def("decode_stream", &PyClass::DecodeStream)
|
||||||
.def("decode_streams",
|
.def("decode_streams",
|
||||||
[](PyClass &self, std::vector<OfflineStream *> ss) {
|
[](const PyClass &self, std::vector<OfflineStream *> ss) {
|
||||||
self.DecodeStreams(ss.data(), ss.size());
|
self.DecodeStreams(ss.data(), ss.size());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from _sherpa_onnx import Display
|
from _sherpa_onnx import Display
|
||||||
|
|
||||||
from .online_recognizer import OnlineRecognizer
|
from .online_recognizer import OnlineRecognizer
|
||||||
from .online_recognizer import OnlineStream
|
from .online_recognizer import OnlineStream
|
||||||
from .offline_recognizer import OfflineRecognizer
|
from .offline_recognizer import OfflineRecognizer
|
||||||
|
|
||||||
|
from .utils import encode_contexts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2023 by manyeyes
|
# Copyright (c) 2023 by manyeyes
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
OfflineFeatureExtractorConfig,
|
OfflineFeatureExtractorConfig,
|
||||||
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
|
|||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
|
context_score: float = 1.5,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
|
|||||||
feat_config=feat_config,
|
feat_config=feat_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
|
context_score=context_score,
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
return self
|
return self
|
||||||
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
|
|||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def create_stream(self):
|
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||||
return self.recognizer.create_stream()
|
if contexts_list is None:
|
||||||
|
return self.recognizer.create_stream()
|
||||||
|
else:
|
||||||
|
return self.recognizer.create_stream(contexts_list)
|
||||||
|
|
||||||
def decode_stream(self, s: OfflineStream):
|
def decode_stream(self, s: OfflineStream):
|
||||||
self.recognizer.decode_stream(s)
|
self.recognizer.decode_stream(s)
|
||||||
|
|||||||
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def encode_contexts(
|
||||||
|
modeling_unit: str,
|
||||||
|
contexts: List[str],
|
||||||
|
sp: Optional["SentencePieceProcessor"] = None,
|
||||||
|
tokens_table: Optional[Dict[str, int]] = None,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Encode the given contexts (a list of string) to a list of a list of token ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modeling_unit:
|
||||||
|
The valid values are bpe, char, bpe+char.
|
||||||
|
Note: char here means characters in CJK languages, not English like languages.
|
||||||
|
contexts:
|
||||||
|
The given contexts list (a list of string).
|
||||||
|
sp:
|
||||||
|
An instance of SentencePieceProcessor.
|
||||||
|
tokens_table:
|
||||||
|
The tokens_table containing the tokens and the corresponding ids.
|
||||||
|
Returns:
|
||||||
|
Return the contexts_list, it is a list of a list of token ids.
|
||||||
|
"""
|
||||||
|
contexts_list = []
|
||||||
|
if "bpe" in modeling_unit:
|
||||||
|
assert sp is not None
|
||||||
|
if "char" in modeling_unit:
|
||||||
|
assert tokens_table is not None
|
||||||
|
assert len(tokens_table) > 0, len(tokens_table)
|
||||||
|
|
||||||
|
if "char" == modeling_unit:
|
||||||
|
for context in contexts:
|
||||||
|
assert ' ' not in context
|
||||||
|
ids = [
|
||||||
|
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
|
||||||
|
for txt in context
|
||||||
|
]
|
||||||
|
contexts_list.append(ids)
|
||||||
|
elif "bpe" == modeling_unit:
|
||||||
|
contexts_list = sp.encode(contexts, out_type=int)
|
||||||
|
else:
|
||||||
|
assert modeling_unit == "bpe+char", modeling_unit
|
||||||
|
|
||||||
|
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
|
||||||
|
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||||
|
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||||
|
for context in contexts:
|
||||||
|
# Example:
|
||||||
|
# txt = "你好 ITS'S OKAY 的"
|
||||||
|
# chars = ["你", "好", " ITS'S OKAY ", "的"]
|
||||||
|
chars = pattern.split(context.upper())
|
||||||
|
mix_chars = [w for w in chars if len(w.strip()) > 0]
|
||||||
|
ids = []
|
||||||
|
for ch_or_w in mix_chars:
|
||||||
|
# ch_or_w is a single CJK charater(i.e., "你"), do nothing.
|
||||||
|
if pattern.fullmatch(ch_or_w) is not None:
|
||||||
|
ids.append(
|
||||||
|
tokens_table[ch_or_w]
|
||||||
|
if ch_or_w in tokens_table
|
||||||
|
else tokens_table["<unk>"]
|
||||||
|
)
|
||||||
|
# ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
|
||||||
|
# encode ch_or_w using bpe_model.
|
||||||
|
else:
|
||||||
|
for p in sp.encode_as_pieces(ch_or_w):
|
||||||
|
ids.append(
|
||||||
|
tokens_table[p]
|
||||||
|
if p in tokens_table
|
||||||
|
else tokens_table["<unk>"]
|
||||||
|
)
|
||||||
|
contexts_list.append(ids)
|
||||||
|
return contexts_list
|
||||||
Reference in New Issue
Block a user