Fix nemo streaming transducer greedy search (#944)
This commit is contained in:
39
.github/scripts/test-online-transducer.sh
vendored
39
.github/scripts/test-online-transducer.sh
vendored
@@ -15,6 +15,45 @@ echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run NeMo transducer (English)"
|
||||
log "------------------------------------------------------------"
|
||||
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
|
||||
curl -SL -O $repo_url
|
||||
tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
|
||||
rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
|
||||
repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
|
||||
|
||||
log "Start testing ${repo_url}"
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
$wave
|
||||
done
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run LSTM transducer (English)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
@@ -196,7 +196,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p aarch64
|
||||
|
||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
|
||||
|
||||
@@ -187,7 +187,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p aarch64
|
||||
|
||||
cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
|
||||
|
||||
1
.github/workflows/android.yaml
vendored
1
.github/workflows/android.yaml
vendored
@@ -124,7 +124,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
|
||||
cp -v ../sherpa-onnx-*-android.tar.bz2 ./
|
||||
|
||||
|
||||
1
.github/workflows/arm-linux-gnueabihf.yaml
vendored
1
.github/workflows/arm-linux-gnueabihf.yaml
vendored
@@ -209,7 +209,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p arm32
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
|
||||
|
||||
1
.github/workflows/build-xcframework.yaml
vendored
1
.github/workflows/build-xcframework.yaml
vendored
@@ -138,7 +138,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./
|
||||
|
||||
|
||||
1
.github/workflows/riscv64-linux.yaml
vendored
1
.github/workflows/riscv64-linux.yaml
vendored
@@ -242,7 +242,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p riscv64
|
||||
|
||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
|
||||
|
||||
1
.github/workflows/windows-x64.yaml
vendored
1
.github/workflows/windows-x64.yaml
vendored
@@ -219,7 +219,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p win64
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./win64
|
||||
|
||||
1
.github/workflows/windows-x86.yaml
vendored
1
.github/workflows/windows-x86.yaml
vendored
@@ -221,7 +221,6 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
git lfs pull
|
||||
mkdir -p win32
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./win32
|
||||
|
||||
@@ -14,19 +14,18 @@ namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
const OnlineRecognizerConfig &config) {
|
||||
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
|
||||
|
||||
auto decoder_model = ReadFile(config.model_config.transducer.decoder);
|
||||
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(
|
||||
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
size_t node_count = sess->GetOutputCount();
|
||||
|
||||
|
||||
if (node_count == 1) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
|
||||
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
|
||||
}
|
||||
}
|
||||
@@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
|
||||
|
||||
auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
|
||||
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(
|
||||
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
size_t node_count = sess->GetOutputCount();
|
||||
|
||||
|
||||
if (node_count == 1) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
|
||||
} else {
|
||||
|
||||
@@ -35,18 +35,15 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms, int32_t subsampling_factor,
|
||||
int32_t segment, int32_t frames_since_start) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
if (i == -1) continue;
|
||||
auto sym = sym_table[i];
|
||||
|
||||
r.text.append(sym);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
@@ -32,23 +33,20 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ./online-recognizer-transducer-impl.h
|
||||
// static may or may not be here? TODDOs
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment,
|
||||
int32_t frames_since_start);
|
||||
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms, int32_t subsampling_factor,
|
||||
int32_t segment, int32_t frames_since_start);
|
||||
|
||||
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
public:
|
||||
public:
|
||||
explicit OnlineRecognizerTransducerNeMoImpl(
|
||||
const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config),
|
||||
model_(std::make_unique<OnlineTransducerNeMoModel>(
|
||||
config.model_config)) {
|
||||
model_(
|
||||
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
model_.get(), config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
config.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetStates(model_->GetInitStates());
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
|
||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||
decoder_->StripLeadingBlanks(&decoder_result);
|
||||
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 8;
|
||||
return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
int32_t subsampling_factor = model_->SubsamplingFactor();
|
||||
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
|
||||
subsampling_factor, s->GetCurrentSegment(),
|
||||
s->GetNumFramesSinceStart());
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const override {
|
||||
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
// frame shift is 10 milliseconds
|
||||
float frame_shift_in_seconds = 0.01;
|
||||
|
||||
// subsampling factor is 8
|
||||
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
|
||||
int32_t trailing_silence_frames =
|
||||
s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();
|
||||
|
||||
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
|
||||
frame_shift_in_seconds);
|
||||
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
// segment is incremented only when the last
|
||||
// result is not empty
|
||||
const auto &r = s->GetResult();
|
||||
if (!r.tokens.empty() && r.tokens.back() != 0) {
|
||||
if (!r.tokens.empty()) {
|
||||
s->GetCurrentSegment() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// we keep the decoder_out
|
||||
decoder_->UpdateDecoderOut(&s->GetResult());
|
||||
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
||||
s->SetResult({});
|
||||
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
s->SetResult(r);
|
||||
s->GetResult().decoder_out = std::move(decoder_out);
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
s->SetNeMoDecoderStates(model_->GetDecoderInitStates());
|
||||
|
||||
// Note: We only update counters. The underlying audio samples
|
||||
// are not discarded.
|
||||
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
|
||||
int32_t feature_dim = ss[0]->FeatureDim();
|
||||
|
||||
std::vector<OnlineTransducerDecoderResult> result(n);
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> encoder_states(n);
|
||||
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
std::copy(features.begin(), features.end(),
|
||||
features_vec.data() + i * chunk_size * feature_dim);
|
||||
|
||||
result[i] = std::move(ss[i]->GetResult());
|
||||
encoder_states[i] = std::move(ss[i]->GetStates());
|
||||
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
// Batch size is 1
|
||||
auto states = std::move(encoder_states[0]);
|
||||
int32_t num_states = states.size(); // num_states = 3
|
||||
auto states = model_->StackStates(std::move(encoder_states));
|
||||
int32_t num_states = states.size(); // num_states = 3
|
||||
auto t = model_->RunEncoder(std::move(x), std::move(states));
|
||||
// t[0] encoder_out, float tensor, (batch_size, dim, T)
|
||||
// t[1] next states
|
||||
|
||||
|
||||
std::vector<Ort::Value> out_states;
|
||||
out_states.reserve(num_states);
|
||||
|
||||
|
||||
for (int32_t k = 1; k != num_states + 1; ++k) {
|
||||
out_states.push_back(std::move(t[k]));
|
||||
}
|
||||
|
||||
auto unstacked_states = model_->UnStackStates(std::move(out_states));
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
ss[i]->SetStates(std::move(unstacked_states[i]));
|
||||
}
|
||||
|
||||
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
|
||||
|
||||
// defined in online-transducer-greedy-search-nemo-decoder.h
|
||||
// get intial states of decoder.
|
||||
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
|
||||
|
||||
// Subsequent decoder states (for each chunks) are updated inside the Decode method.
|
||||
// This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
|
||||
decoder_states = decoder_->Decode(std::move(encoder_out),
|
||||
std::move(decoder_states),
|
||||
&result, ss, n);
|
||||
|
||||
ss[0]->SetResult(result[0]);
|
||||
|
||||
ss[0]->SetStates(std::move(out_states));
|
||||
decoder_->Decode(std::move(encoder_out), ss, n);
|
||||
}
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
// set encoder states
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
stream->SetResult(r);
|
||||
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
|
||||
// set decoder states
|
||||
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
symbol_table_.NumSymbols(), vocab_size);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
std::unique_ptr<OnlineTransducerNeMoModel> model_;
|
||||
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
|
||||
Endpoint endpoint_;
|
||||
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
|
||||
@@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||
return impl_->GetStates();
|
||||
}
|
||||
|
||||
void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
|
||||
void OnlineStream::SetNeMoDecoderStates(
|
||||
std::vector<Ort::Value> decoder_states) {
|
||||
return impl_->SetNeMoDecoderStates(std::move(decoder_states));
|
||||
}
|
||||
|
||||
|
||||
@@ -91,8 +91,8 @@ class OnlineStream {
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
|
||||
std::vector<Ort::Value> &GetNeMoDecoderStates();
|
||||
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
|
||||
std::vector<Ort::Value> &GetNeMoDecoderStates();
|
||||
|
||||
/**
|
||||
* Get the context graph corresponding to this stream.
|
||||
|
||||
@@ -10,103 +10,64 @@
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
|
||||
int32_t token, OrtAllocator *allocator) {
|
||||
static Ort::Value BuildDecoderInput(int32_t token, OrtAllocator *allocator) {
|
||||
std::array<int64_t, 2> shape{1, 1};
|
||||
|
||||
Ort::Value decoder_input =
|
||||
Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
|
||||
|
||||
std::array<int64_t, 1> length_shape{1};
|
||||
Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
|
||||
allocator, length_shape.data(), length_shape.size());
|
||||
|
||||
int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
|
||||
|
||||
int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
|
||||
|
||||
p[0] = token;
|
||||
|
||||
p_length[0] = 1;
|
||||
|
||||
return {std::move(decoder_input), std::move(decoder_input_length)};
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = 8;
|
||||
int32_t blank_id = 0; // always 0
|
||||
OnlineTransducerDecoderResult r;
|
||||
r.tokens.resize(context_size, -1);
|
||||
r.tokens.back() = blank_id;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
static void UpdateCachedDecoderOut(
|
||||
OrtAllocator *allocator, const Ort::Value *decoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
std::array<int64_t, 2> v_shape{1, shape[1]};
|
||||
|
||||
const float *src = decoder_out->GetTensorData<float>();
|
||||
for (auto &r : *result) {
|
||||
if (!r.decoder_out) {
|
||||
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
|
||||
v_shape.size());
|
||||
}
|
||||
|
||||
float *dst = r.decoder_out.GetTensorMutableData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
src += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> DecodeOne(
|
||||
const float *encoder_out, int32_t num_rows, int32_t num_cols,
|
||||
OnlineTransducerNeMoModel *model, float blank_penalty,
|
||||
std::vector<Ort::Value>& decoder_states,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
|
||||
static void DecodeOne(const float *encoder_out, int32_t num_rows,
|
||||
int32_t num_cols, OnlineTransducerNeMoModel *model,
|
||||
float blank_penalty, OnlineStream *s) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// OnlineTransducerDecoderResult result;
|
||||
int32_t vocab_size = model->VocabSize();
|
||||
int32_t blank_id = vocab_size - 1;
|
||||
|
||||
auto &r = (*result)[0];
|
||||
|
||||
auto &r = s->GetResult();
|
||||
|
||||
Ort::Value decoder_out{nullptr};
|
||||
|
||||
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
|
||||
// decoder_input_pair[0]: decoder_input
|
||||
// decoder_input_pair[1]: decoder_input_length (discarded)
|
||||
auto decoder_input = BuildDecoderInput(
|
||||
r.tokens.empty() ? blank_id : r.tokens.back(), model->Allocator());
|
||||
|
||||
std::vector<Ort::Value> &last_decoder_states = s->GetNeMoDecoderStates();
|
||||
|
||||
std::vector<Ort::Value> tmp_decoder_states;
|
||||
tmp_decoder_states.reserve(last_decoder_states.size());
|
||||
for (auto &v : last_decoder_states) {
|
||||
tmp_decoder_states.push_back(View(&v));
|
||||
}
|
||||
|
||||
// decoder_output_pair.second returns the next decoder state
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
|
||||
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||
std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN
|
||||
model->RunDecoder(std::move(decoder_input),
|
||||
std::move(tmp_decoder_states));
|
||||
|
||||
std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
|
||||
|
||||
decoder_states = std::move(decoder_output_pair.second);
|
||||
bool emitted = false;
|
||||
|
||||
// TODO: Inside this loop, I need to framewise decoding.
|
||||
for (int32_t t = 0; t != num_rows; ++t) {
|
||||
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
|
||||
encoder_shape.data(), encoder_shape.size());
|
||||
|
||||
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
|
||||
View(&decoder_output_pair.first));
|
||||
View(&decoder_output_pair.first));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
if (blank_penalty > 0) {
|
||||
@@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
SHERPA_ONNX_LOGE("y=%d", y);
|
||||
|
||||
if (y != blank_id) {
|
||||
emitted = true;
|
||||
r.tokens.push_back(y);
|
||||
r.timestamps.push_back(t + r.frame_offset);
|
||||
r.num_trailing_blanks = 0;
|
||||
|
||||
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
|
||||
decoder_input = BuildDecoderInput(y, model->Allocator());
|
||||
|
||||
// last decoder state becomes the current state for the first chunk
|
||||
decoder_output_pair =
|
||||
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||
std::move(decoder_states));
|
||||
|
||||
// Update the decoder states for the next chunk
|
||||
decoder_states = std::move(decoder_output_pair.second);
|
||||
decoder_output_pair = model->RunDecoder(
|
||||
std::move(decoder_input), std::move(decoder_output_pair.second));
|
||||
} else {
|
||||
++r.num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
|
||||
decoder_out = std::move(decoder_output_pair.first);
|
||||
// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
|
||||
|
||||
// Update frame_offset
|
||||
for (auto &r : *result) {
|
||||
r.frame_offset += num_rows;
|
||||
if (emitted) {
|
||||
s->SetNeMoDecoderStates(std::move(decoder_output_pair.second));
|
||||
}
|
||||
|
||||
return std::move(decoder_states);
|
||||
r.frame_offset += num_rows;
|
||||
}
|
||||
|
||||
|
||||
std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<Ort::Value> decoder_states,
|
||||
std::vector<OnlineTransducerDecoderResult> *result,
|
||||
OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {
|
||||
|
||||
void OnlineTransducerGreedySearchNeMoDecoder::Decode(Ort::Value encoder_out,
|
||||
OnlineStream **ss,
|
||||
int32_t n) const {
|
||||
auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1
|
||||
|
||||
if (shape[0] != result->size()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
|
||||
static_cast<int32_t>(shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
if (batch_size != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! encoder_out.size(0) %d, n: %d",
|
||||
static_cast<int32_t>(shape[0]), n);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1
|
||||
int32_t dim1 = static_cast<int32_t>(shape[1]); // 2
|
||||
int32_t dim2 = static_cast<int32_t>(shape[2]); // 512
|
||||
int32_t dim1 = static_cast<int32_t>(shape[1]); // T
|
||||
int32_t dim2 = static_cast<int32_t>(shape[2]); // encoder_out_dim
|
||||
|
||||
// Define and initialize encoder_out_length
|
||||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
int64_t length_value = 1;
|
||||
std::vector<int64_t> length_shape = {1};
|
||||
|
||||
Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>(
|
||||
memory_info, &length_value, 1, length_shape.data(), length_shape.size()
|
||||
);
|
||||
|
||||
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
|
||||
const float *p = encoder_out.GetTensorData<float>();
|
||||
|
||||
// std::vector<OnlineTransducerDecoderResult> ans(batch_size);
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *this_p = p + dim1 * dim2 * i;
|
||||
int32_t this_len = p_length[i];
|
||||
|
||||
// outputs the decoder state from last chunk.
|
||||
auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states, result);
|
||||
// ans[i] = decode_result_pair.first;
|
||||
decoder_states = std::move(last_decoder_states);
|
||||
DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]);
|
||||
}
|
||||
|
||||
return decoder_states;
|
||||
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -7,27 +7,22 @@
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream;
|
||||
|
||||
class OnlineTransducerGreedySearchNeMoDecoder {
|
||||
public:
|
||||
OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
|
||||
float blank_penalty)
|
||||
: model_(model),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
: model_(model), blank_penalty_(blank_penalty) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const;
|
||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}
|
||||
|
||||
std::vector<Ort::Value> Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<Ort::Value> decoder_states,
|
||||
std::vector<OnlineTransducerDecoderResult> *result,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0);
|
||||
// @param n number of elements in ss
|
||||
void Decode(Ort::Value encoder_out, OnlineStream **ss, int32_t n) const;
|
||||
|
||||
private:
|
||||
OnlineTransducerNeMoModel *model_; // Not owned
|
||||
@@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder {
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
|
||||
: config_(config),
|
||||
@@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> states) {
|
||||
Ort::Value &cache_last_channel = states[0];
|
||||
Ort::Value &cache_last_time = states[1];
|
||||
Ort::Value &cache_last_channel_len = states[2];
|
||||
@@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
std::move(features), View(&length), std::move(cache_last_channel),
|
||||
std::move(cache_last_time), std::move(cache_last_channel_len)};
|
||||
|
||||
auto out =
|
||||
encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||
auto out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||
// out[0]: logit
|
||||
// out[1] logit_length
|
||||
// out[2:] states_next
|
||||
@@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
Ort::Value targets, std::vector<Ort::Value> states) {
|
||||
|
||||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
Ort::MemoryInfo memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
// Create the tensor with a single int32_t value of 1
|
||||
int32_t length_value = 1;
|
||||
std::vector<int64_t> length_shape = {1};
|
||||
auto shape = targets.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||
|
||||
std::vector<int64_t> length_shape = {batch_size};
|
||||
std::vector<int32_t> length_value(batch_size, 1);
|
||||
|
||||
Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
|
||||
memory_info, &length_value, 1, length_shape.data(), length_shape.size()
|
||||
);
|
||||
|
||||
memory_info, length_value.data(), batch_size, length_shape.data(),
|
||||
length_shape.size());
|
||||
|
||||
std::vector<Ort::Value> decoder_inputs;
|
||||
decoder_inputs.reserve(2 + states.size());
|
||||
|
||||
@@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
|
||||
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
|
||||
std::move(decoder_out)};
|
||||
auto logit =
|
||||
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
|
||||
joiner_input.size(), joiner_output_names_ptr_.data(),
|
||||
joiner_output_names_ptr_.size());
|
||||
auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
|
||||
joiner_input.data(), joiner_input.size(),
|
||||
joiner_output_names_ptr_.data(),
|
||||
joiner_output_names_ptr_.size());
|
||||
|
||||
return std::move(logit[0]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
std::vector<Ort::Value> GetDecoderInitStates() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(View(&lstm0_));
|
||||
ans.push_back(View(&lstm1_));
|
||||
|
||||
Fill<float>(&s0, 0);
|
||||
|
||||
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
|
||||
Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||
s1_shape.size());
|
||||
|
||||
Fill<float>(&s1, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
states.reserve(2);
|
||||
states.push_back(std::move(s0));
|
||||
states.push_back(std::move(s1));
|
||||
|
||||
return states;
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t ChunkSize() const { return window_size_; }
|
||||
@@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
int32_t ChunkShift() const { return chunk_shift_; }
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
@@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
// - cache_last_channel
|
||||
// - cache_last_time_
|
||||
// - cache_last_channel_len
|
||||
std::vector<Ort::Value> GetInitStates() {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(3);
|
||||
ans.push_back(View(&cache_last_channel_));
|
||||
@@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Ort::Value> StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
if (batch_size == 1) {
|
||||
return std::move(states[0]);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
|
||||
// stack cache_last_channel
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
|
||||
// there are 3 states to be stacked
|
||||
for (int32_t i = 0; i != 3; ++i) {
|
||||
buf.clear();
|
||||
buf.reserve(batch_size);
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
assert(states[b].size() == 3);
|
||||
buf.push_back(&states[b][i]);
|
||||
}
|
||||
|
||||
Ort::Value c{nullptr};
|
||||
if (i == 2) {
|
||||
c = Cat<int64_t>(allocator_, buf, 0);
|
||||
} else {
|
||||
c = Cat(allocator_, buf, 0);
|
||||
}
|
||||
|
||||
ans.push_back(std::move(c));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
assert(states.size() == 3);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
|
||||
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = shape[0];
|
||||
ans.resize(batch_size);
|
||||
|
||||
if (batch_size == 1) {
|
||||
ans[0] = std::move(states);
|
||||
return ans;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != 3; ++i) {
|
||||
std::vector<Ort::Value> v;
|
||||
if (i == 2) {
|
||||
v = Unbind<int64_t>(allocator_, &states[i], 0);
|
||||
} else {
|
||||
v = Unbind(allocator_, &states[i], 0);
|
||||
}
|
||||
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
ans[b].push_back(std::move(v[b]));
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
@@ -276,10 +332,10 @@ private:
|
||||
normalize_type_ = "";
|
||||
}
|
||||
|
||||
InitStates();
|
||||
InitEncoderStates();
|
||||
}
|
||||
|
||||
void InitStates() {
|
||||
|
||||
void InitEncoderStates() {
|
||||
std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
|
||||
cache_last_channel_dim2_,
|
||||
cache_last_channel_dim3_};
|
||||
@@ -313,7 +369,25 @@ private:
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
&decoder_output_names_ptr_);
|
||||
|
||||
InitDecoderStates();
|
||||
}
|
||||
|
||||
void InitDecoderStates() {
|
||||
int32_t batch_size = 1;
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
lstm0_ = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
|
||||
Fill<float>(&lstm0_, 0);
|
||||
|
||||
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
|
||||
lstm1_ = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||
s1_shape.size());
|
||||
|
||||
Fill<float>(&lstm1_, 0);
|
||||
}
|
||||
|
||||
void InitJoiner(void *model_data, size_t model_data_length) {
|
||||
@@ -324,7 +398,7 @@ private:
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||
&joiner_output_names_ptr_);
|
||||
&joiner_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -363,6 +437,7 @@ private:
|
||||
int32_t pred_rnn_layers_ = -1;
|
||||
int32_t pred_hidden_ = -1;
|
||||
|
||||
// encoder states
|
||||
int32_t cache_last_channel_dim1_;
|
||||
int32_t cache_last_channel_dim2_;
|
||||
int32_t cache_last_channel_dim3_;
|
||||
@@ -370,9 +445,14 @@ private:
|
||||
int32_t cache_last_time_dim2_;
|
||||
int32_t cache_last_time_dim3_;
|
||||
|
||||
// init encoder states
|
||||
Ort::Value cache_last_channel_{nullptr};
|
||||
Ort::Value cache_last_time_{nullptr};
|
||||
Ort::Value cache_last_channel_len_{nullptr};
|
||||
|
||||
// init decoder states
|
||||
Ort::Value lstm0_{nullptr};
|
||||
Ort::Value lstm1_{nullptr};
|
||||
};
|
||||
|
||||
OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
|
||||
@@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
|
||||
|
||||
OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
|
||||
|
||||
std::vector<Ort::Value>
|
||||
OnlineTransducerNeMoModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->RunEncoder(std::move(features), std::move(states));
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states) const {
|
||||
return impl_->RunEncoder(std::move(features), std::move(states));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
@@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
|
||||
return impl_->RunDecoder(std::move(targets), std::move(states));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
|
||||
int32_t batch_size) const {
|
||||
return impl_->GetDecoderInitStates(batch_size);
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates()
|
||||
const {
|
||||
return impl_->GetDecoderInitStates();
|
||||
}
|
||||
|
||||
Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||
@@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkSize() const {
|
||||
return impl_->ChunkSize();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkSize() const {
|
||||
return impl_->ChunkSize();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
int32_t OnlineTransducerNeMoModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
@@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||
return impl_->FeatureNormalizationMethod();
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
|
||||
return impl_->GetInitStates();
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetEncoderInitStates()
|
||||
const {
|
||||
return impl_->GetEncoderInitStates();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
return impl_->StackStates(std::move(states));
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> OnlineTransducerNeMoModel::UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->UnStackStates(std::move(states));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel {
|
||||
OnlineTransducerNeMoModel(AAssetManager *mgr,
|
||||
const OnlineModelConfig &config);
|
||||
#endif
|
||||
|
||||
|
||||
~OnlineTransducerNeMoModel();
|
||||
// A list of 3 tensors:
|
||||
// A list of 3 tensors:
|
||||
// - cache_last_channel
|
||||
// - cache_last_time
|
||||
// - cache_last_channel_len
|
||||
std::vector<Ort::Value> GetInitStates() const;
|
||||
std::vector<Ort::Value> GetEncoderInitStates() const;
|
||||
|
||||
// stack encoder states
|
||||
std::vector<Ort::Value> StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const;
|
||||
|
||||
// unstack encoder states
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
std::vector<Ort::Value> states) const;
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param states It is from GetInitStates() or returned from this method.
|
||||
*
|
||||
* @param states It is from GetEncoderInitStates() or returned from this
|
||||
* method.
|
||||
*
|
||||
* @return Return a tuple containing:
|
||||
* - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)
|
||||
* - ans[1:]: contains next states
|
||||
* - ans[0]: encoder_out, a tensor of shape (N, encoder_out_dim, T')
|
||||
* - ans[1:]: contains next states
|
||||
*/
|
||||
std::vector<Ort::Value> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT
|
||||
@@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel {
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
Ort::Value targets, std::vector<Ort::Value> states) const;
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
|
||||
std::vector<Ort::Value> GetDecoderInitStates() const;
|
||||
|
||||
/** Run the joint network.
|
||||
*
|
||||
@@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel {
|
||||
* @param decoder_out Output of the decoder network.
|
||||
* @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
|
||||
*/
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out,
|
||||
Ort::Value decoder_out) const;
|
||||
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
|
||||
|
||||
/** We send this number of feature frames to the encoder at a time. */
|
||||
int32_t ChunkSize() const;
|
||||
@@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel {
|
||||
// for details
|
||||
std::string FeatureNormalizationMethod() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
Reference in New Issue
Block a user