Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -20,9 +20,10 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import sentencepiece as spm
|
||||||
import sherpa_onnx
|
import sherpa_onnx
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +70,59 @@ def get_args():
|
|||||||
help="Valid values are greedy_search and modified_beam_search",
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-active-paths",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is modified_beam_search.
|
||||||
|
It specifies number of active paths to keep during decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
|
||||||
|
Used only when --decoding-method=modified_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modeling-unit",
|
||||||
|
type=str,
|
||||||
|
default="char",
|
||||||
|
help="""
|
||||||
|
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
|
||||||
|
Valid values are bpe, bpe+char, char.
|
||||||
|
Note: the char here means characters in CJK languages.
|
||||||
|
Used only when --decoding-method=modified_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--contexts",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The context list, it is a string containing some words/phrases separated
|
||||||
|
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
|
||||||
|
Used only when --decoding-method=modified_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="""
|
||||||
|
The context score of each token for biasing word/phrase. Used only if
|
||||||
|
--contexts is given.
|
||||||
|
Used only when --decoding-method=modified_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"sound_files",
|
"sound_files",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
|||||||
return samples_float32, f.getframerate()
|
return samples_float32, f.getframerate()
|
||||||
|
|
||||||
|
|
||||||
|
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
||||||
|
sp = None
|
||||||
|
if "bpe" in args.modeling_unit:
|
||||||
|
assert_file_exists(args.bpe_model)
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model)
|
||||||
|
tokens = {}
|
||||||
|
with open(args.tokens, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
toks = line.strip().split()
|
||||||
|
assert len(toks) == 2, len(toks)
|
||||||
|
assert toks[0] not in tokens, f"Duplicate token: {toks} "
|
||||||
|
tokens[toks[0]] = int(toks[1])
|
||||||
|
return sherpa_onnx.encode_contexts(
|
||||||
|
modeling_unit=args.modeling_unit,
|
||||||
|
contexts=contexts,
|
||||||
|
sp=sp,
|
||||||
|
tokens_table=tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
@@ -132,11 +207,20 @@ def main():
|
|||||||
sample_rate=16000,
|
sample_rate=16000,
|
||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
|
max_active_paths=args.max_active_paths,
|
||||||
|
context_score=args.context_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Started!")
|
print("Started!")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
contexts_list = []
|
||||||
|
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||||
|
if contexts:
|
||||||
|
print(f"Contexts list: {contexts}")
|
||||||
|
contexts_list = encode_contexts(args, contexts)
|
||||||
|
|
||||||
|
|
||||||
streams = []
|
streams = []
|
||||||
total_duration = 0
|
total_duration = 0
|
||||||
for wave_filename in args.sound_files:
|
for wave_filename in args.sound_files:
|
||||||
@@ -145,7 +229,11 @@ def main():
|
|||||||
duration = len(samples) / sample_rate
|
duration = len(samples) / sample_rate
|
||||||
total_duration += duration
|
total_duration += duration
|
||||||
|
|
||||||
s = recognizer.create_stream()
|
if contexts_list:
|
||||||
|
s = recognizer.create_stream(contexts_list=contexts_list)
|
||||||
|
else:
|
||||||
|
s = recognizer.create_stream()
|
||||||
|
|
||||||
s.accept_waveform(sample_rate, samples)
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
"True to enable endpoint detection. False to disable it.");
|
"True to enable endpoint detection. False to disable it.");
|
||||||
po->Register("max-active-paths", &max_active_paths,
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
"beam size used in modified beam search.");
|
"beam size used in modified beam search.");
|
||||||
|
po->Register("context-score", &context_score,
|
||||||
|
"The bonus score for each token in context word/phrase. "
|
||||||
|
"Used only when decoding_method is modified_beam_search");
|
||||||
po->Register("decoding-method", &decoding_method,
|
po->Register("decoding-method", &decoding_method,
|
||||||
"decoding method,"
|
"decoding method,"
|
||||||
"now support greedy_search and modified_beam_search.");
|
"now support greedy_search and modified_beam_search.");
|
||||||
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
|||||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||||
os << "max_active_paths=" << max_active_paths << ", ";
|
os << "max_active_paths=" << max_active_paths << ", ";
|
||||||
|
os << "context_score=" << context_score << ", ";
|
||||||
os << "decoding_method=\"" << decoding_method << "\")";
|
os << "decoding_method=\"" << decoding_method << "\")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void InitOnlineStream(OnlineStream *stream) const {
|
||||||
|
auto r = decoder_->GetEmptyResult();
|
||||||
|
|
||||||
|
if (config_.decoding_method == "modified_beam_search" &&
|
||||||
|
nullptr != stream->GetContextGraph()) {
|
||||||
|
// r.hyps has only one element.
|
||||||
|
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
||||||
|
it->second.context_state = stream->GetContextGraph()->Root();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream->SetResult(r);
|
||||||
|
stream->SetStates(model_->GetEncoderInitStates());
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||||
stream->SetResult(decoder_->GetEmptyResult());
|
InitOnlineStream(stream.get());
|
||||||
stream->SetStates(model_->GetEncoderInitStates());
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &contexts) const {
|
||||||
|
// We create context_graph at this level, because we might have default
|
||||||
|
// context_graph(will be added later if needed) that belongs to the whole
|
||||||
|
// model rather than each stream.
|
||||||
|
auto context_graph =
|
||||||
|
std::make_shared<ContextGraph>(contexts, config_.context_score);
|
||||||
|
auto stream =
|
||||||
|
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||||
|
InitOnlineStream(stream.get());
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
|
|||||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||||
std::vector<int64_t> all_processed_frames(n);
|
std::vector<int64_t> all_processed_frames(n);
|
||||||
|
bool has_context_graph = false;
|
||||||
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
if (!has_context_graph && ss[i]->GetContextGraph())
|
||||||
|
has_context_graph = true;
|
||||||
|
|
||||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||||
std::vector<float> features =
|
std::vector<float> features =
|
||||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||||
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
|
|||||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||||
std::move(processed_frames));
|
std::move(processed_frames));
|
||||||
|
|
||||||
decoder_->Decode(std::move(pair.first), &results);
|
if (has_context_graph) {
|
||||||
|
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||||
|
} else {
|
||||||
|
decoder_->Decode(std::move(pair.first), &results);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> next_states =
|
std::vector<std::vector<Ort::Value>> next_states =
|
||||||
model_->UnStackStates(pair.second);
|
model_->UnStackStates(pair.second);
|
||||||
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
|||||||
return impl_->CreateStream();
|
return impl_->CreateStream();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||||
|
return impl_->CreateStream(context_list);
|
||||||
|
}
|
||||||
|
|
||||||
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||||
return impl_->IsReady(s);
|
return impl_->IsReady(s);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
|
|||||||
std::string decoding_method = "greedy_search";
|
std::string decoding_method = "greedy_search";
|
||||||
// now support modified_beam_search and greedy_search
|
// now support modified_beam_search and greedy_search
|
||||||
|
|
||||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
// used only for modified_beam_search
|
||||||
|
int32_t max_active_paths = 4;
|
||||||
|
/// used only for modified_beam_search
|
||||||
|
float context_score = 1.5;
|
||||||
|
|
||||||
OnlineRecognizerConfig() = default;
|
OnlineRecognizerConfig() = default;
|
||||||
|
|
||||||
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
|
|||||||
const EndpointConfig &endpoint_config,
|
const EndpointConfig &endpoint_config,
|
||||||
bool enable_endpoint,
|
bool enable_endpoint,
|
||||||
const std::string &decoding_method,
|
const std::string &decoding_method,
|
||||||
int32_t max_active_paths)
|
int32_t max_active_paths, float context_score)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
endpoint_config(endpoint_config),
|
endpoint_config(endpoint_config),
|
||||||
enable_endpoint(enable_endpoint),
|
enable_endpoint(enable_endpoint),
|
||||||
decoding_method(decoding_method),
|
decoding_method(decoding_method),
|
||||||
max_active_paths(max_active_paths) {}
|
max_active_paths(max_active_paths),
|
||||||
|
context_score(context_score) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
@@ -112,6 +116,10 @@ class OnlineRecognizer {
|
|||||||
/// Create a stream for decoding.
|
/// Create a stream for decoding.
|
||||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||||
|
|
||||||
|
// Create a stream with context phrases
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream(
|
||||||
|
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return true if the given stream has enough frames for decoding.
|
* Return true if the given stream has enough frames for decoding.
|
||||||
* Return false otherwise
|
* Return false otherwise
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OnlineStream::Impl {
|
class OnlineStream::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const FeatureExtractorConfig &config)
|
explicit Impl(const FeatureExtractorConfig &config,
|
||||||
: feat_extractor_(config) {}
|
ContextGraphPtr context_graph)
|
||||||
|
: feat_extractor_(config), context_graph_(context_graph) {}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||||
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
std::vector<Ort::Value> &GetStates() { return states_; }
|
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||||
|
|
||||||
|
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FeatureExtractor feat_extractor_;
|
FeatureExtractor feat_extractor_;
|
||||||
|
/// For contextual-biasing
|
||||||
|
ContextGraphPtr context_graph_;
|
||||||
int32_t num_processed_frames_ = 0; // before subsampling
|
int32_t num_processed_frames_ = 0; // before subsampling
|
||||||
int32_t start_frame_index_ = 0; // never reset
|
int32_t start_frame_index_ = 0; // never reset
|
||||||
OnlineTransducerDecoderResult result_;
|
OnlineTransducerDecoderResult result_;
|
||||||
std::vector<Ort::Value> states_;
|
std::vector<Ort::Value> states_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
ContextGraphPtr context_graph /*= nullptr */)
|
||||||
|
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||||
|
|
||||||
OnlineStream::~OnlineStream() = default;
|
OnlineStream::~OnlineStream() = default;
|
||||||
|
|
||||||
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
|
|||||||
return impl_->GetStates();
|
return impl_->GetStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||||
|
return impl_->GetContextGraph();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
|
||||||
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OnlineStream {
|
class OnlineStream {
|
||||||
public:
|
public:
|
||||||
explicit OnlineStream(const FeatureExtractorConfig &config = {});
|
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||||
|
ContextGraphPtr context_graph = nullptr);
|
||||||
~OnlineStream();
|
~OnlineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -71,6 +73,13 @@ class OnlineStream {
|
|||||||
void SetStates(std::vector<Ort::Value> states);
|
void SetStates(std::vector<Ort::Value> states);
|
||||||
std::vector<Ort::Value> &GetStates();
|
std::vector<Ort::Value> &GetStates();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the context graph corresponding to this stream.
|
||||||
|
*
|
||||||
|
* @return Return the context graph for this stream.
|
||||||
|
*/
|
||||||
|
const ContextGraphPtr &GetContextGraph() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
|
|||||||
OnlineTransducerDecoderResult &&other);
|
OnlineTransducerDecoderResult &&other);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class OnlineStream;
|
||||||
class OnlineTransducerDecoder {
|
class OnlineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
virtual ~OnlineTransducerDecoder() = default;
|
virtual ~OnlineTransducerDecoder() = default;
|
||||||
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
|
|||||||
virtual void Decode(Ort::Value encoder_out,
|
virtual void Decode(Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||||
|
|
||||||
|
/** Run transducer beam search given the output from the encoder model.
|
||||||
|
*
|
||||||
|
* Note: Currently this interface is for contextual-biasing feature which
|
||||||
|
* needs a ContextGraph owned by the OnlineStream.
|
||||||
|
*
|
||||||
|
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||||
|
* @param ss A list of OnlineStreams.
|
||||||
|
* @param result It is modified in-place.
|
||||||
|
*
|
||||||
|
* @note There is no need to pass encoder_out_length here since for the
|
||||||
|
* online decoding case, each utterance has the same number of frames
|
||||||
|
* and there are no paddings.
|
||||||
|
*/
|
||||||
|
virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
// used for endpointing. We need to keep decoder_out after reset
|
// used for endpointing. We need to keep decoder_out after reset
|
||||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/log.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
|||||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||||
Ort::Value encoder_out,
|
Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
|
Decode(std::move(encoder_out), nullptr, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||||
|
Ort::Value encoder_out, OnlineStream **ss,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
std::vector<int64_t> encoder_out_shape =
|
std::vector<int64_t> encoder_out_shape =
|
||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||||
|
|
||||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||||
int32_t vocab_size = model_->VocabSize();
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
|
||||||
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
|
|
||||||
Hypothesis new_hyp = prev[hyp_index];
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||||
|
float context_score = 0;
|
||||||
|
auto context_state = new_hyp.context_state;
|
||||||
|
|
||||||
if (new_token != 0) {
|
if (new_token != 0) {
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
new_hyp.timestamps.push_back(t + frame_offset);
|
new_hyp.timestamps.push_back(t + frame_offset);
|
||||||
new_hyp.num_trailing_blanks = 0;
|
new_hyp.num_trailing_blanks = 0;
|
||||||
|
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||||
|
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||||
|
context_state, new_token);
|
||||||
|
context_score = context_res.first;
|
||||||
|
new_hyp.context_state = context_res.second;
|
||||||
|
}
|
||||||
if (lm_) {
|
if (lm_) {
|
||||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
++new_hyp.num_trailing_blanks;
|
++new_hyp.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
new_hyp.log_prob =
|
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||||
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
|
prev_lm_log_prob; // log_prob only includes the
|
||||||
// score of the transducer
|
// score of the transducer
|
||||||
hyps.Add(std::move(new_hyp));
|
hyps.Add(std::move(new_hyp));
|
||||||
} // for (auto k : topk)
|
} // for (auto k : topk)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-lm.h"
|
#include "sherpa-onnx/csrc/online-lm.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
void Decode(Ort::Value encoder_out,
|
void Decode(Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||||
|
|
||||||
|
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||||
|
|
||||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||||
.def(py::init<const FeatureExtractorConfig &,
|
.def(py::init<const FeatureExtractorConfig &,
|
||||||
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
|
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
|
||||||
const EndpointConfig &, bool, const std::string &,
|
const EndpointConfig &, bool, const std::string &, int32_t,
|
||||||
int32_t>(),
|
float>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||||
py::arg("max_active_paths"))
|
py::arg("max_active_paths"), py::arg("context_score"))
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||||
|
.def_readwrite("context_score", &PyClass::context_score)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
|
|||||||
using PyClass = OnlineRecognizer;
|
using PyClass = OnlineRecognizer;
|
||||||
py::class_<PyClass>(*m, "OnlineRecognizer")
|
py::class_<PyClass>(*m, "OnlineRecognizer")
|
||||||
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
|
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
|
||||||
.def("create_stream", &PyClass::CreateStream)
|
.def("create_stream",
|
||||||
|
[](const PyClass &self) { return self.CreateStream(); })
|
||||||
|
.def(
|
||||||
|
"create_stream",
|
||||||
|
[](PyClass &self,
|
||||||
|
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||||
|
return self.CreateStream(contexts_list);
|
||||||
|
},
|
||||||
|
py::arg("contexts_list"))
|
||||||
.def("is_ready", &PyClass::IsReady)
|
.def("is_ready", &PyClass::IsReady)
|
||||||
.def("decode_stream", &PyClass::DecodeStream)
|
.def("decode_stream", &PyClass::DecodeStream)
|
||||||
.def("decode_streams",
|
.def("decode_streams",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2023 Xiaomi Corporation
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
EndpointConfig,
|
EndpointConfig,
|
||||||
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
|
|||||||
rule3_min_utterance_length: float = 20.0,
|
rule3_min_utterance_length: float = 20.0,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
max_active_paths: int = 4,
|
max_active_paths: int = 4,
|
||||||
|
context_score: float = 1.5,
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
|
|||||||
enable_endpoint=enable_endpoint_detection,
|
enable_endpoint=enable_endpoint_detection,
|
||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
max_active_paths=max_active_paths,
|
max_active_paths=max_active_paths,
|
||||||
|
context_score=context_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
|
|
||||||
def create_stream(self):
|
def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
|
||||||
return self.recognizer.create_stream()
|
if contexts_list is None:
|
||||||
|
return self.recognizer.create_stream()
|
||||||
|
else:
|
||||||
|
return self.recognizer.create_stream(contexts_list)
|
||||||
|
|
||||||
def decode_stream(self, s: OnlineStream):
|
def decode_stream(self, s: OnlineStream):
|
||||||
self.recognizer.decode_stream(s)
|
self.recognizer.decode_stream(s)
|
||||||
|
|||||||
Reference in New Issue
Block a user