Support batch greedy search decoding (#30)
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -10,6 +11,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||||
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value OnlineLstmTransducerModel::StackStates(
|
std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
||||||
const std::vector<Ort::Value> &states) const {
|
const std::vector<std::vector<Ort::Value>> &states) const {
|
||||||
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||||
static_cast<int>(__LINE__));
|
|
||||||
auto memory_info =
|
std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||||
int64_t a;
|
h_shape.size());
|
||||||
std::array<int64_t, 3> x_shape{1, 1, 1};
|
|
||||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
|
std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
|
||||||
return x;
|
rnn_hidden_size_};
|
||||||
|
|
||||||
|
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||||
|
c_shape.size());
|
||||||
|
|
||||||
|
float *dst_h = h.GetTensorMutableData<float>();
|
||||||
|
float *dst_c = c.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
const float *src_h =
|
||||||
|
states[i][0].GetTensorData<float>() + layer * d_model_;
|
||||||
|
|
||||||
|
const float *src_c =
|
||||||
|
states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
|
||||||
|
|
||||||
|
std::copy(src_h, src_h + d_model_, dst_h);
|
||||||
|
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||||
|
|
||||||
|
dst_h += d_model_;
|
||||||
|
dst_c += rnn_hidden_size_;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
|
std::vector<Ort::Value> ans;
|
||||||
Ort::Value states) const {
|
|
||||||
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
ans.reserve(2);
|
||||||
static_cast<int>(__LINE__));
|
ans.push_back(std::move(h));
|
||||||
return {};
|
ans.push_back(std::move(c));
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
|
||||||
|
const std::vector<Ort::Value> &states) const {
|
||||||
|
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
|
|
||||||
|
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||||
|
|
||||||
|
// allocate space
|
||||||
|
std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
|
||||||
|
std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||||
|
h_shape.size());
|
||||||
|
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||||
|
c_shape.size());
|
||||||
|
ans[i].push_back(std::move(h));
|
||||||
|
ans[i].push_back(std::move(c));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
const float *src_h = states[0].GetTensorData<float>() +
|
||||||
|
layer * batch_size * d_model_ + i * d_model_;
|
||||||
|
const float *src_c = states[1].GetTensorData<float>() +
|
||||||
|
layer * batch_size * rnn_hidden_size_ +
|
||||||
|
i * rnn_hidden_size_;
|
||||||
|
|
||||||
|
float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
|
||||||
|
float *dst_c =
|
||||||
|
ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
|
||||||
|
|
||||||
|
std::copy(src_h, src_h + d_model_, dst_h);
|
||||||
|
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||||
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
|
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
|
||||||
const std::vector<int64_t> &hyp) {
|
const std::vector<OnlineTransducerDecoderResult> &results) {
|
||||||
auto memory_info =
|
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
std::array<int64_t, 2> shape{batch_size, context_size_};
|
||||||
|
Ort::Value decoder_input =
|
||||||
|
Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
|
||||||
|
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||||
|
|
||||||
std::array<int64_t, 2> shape{1, context_size_};
|
for (const auto &r : results) {
|
||||||
|
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
|
||||||
|
const int64_t *end = r.tokens.data() + r.tokens.size();
|
||||||
|
std::copy(begin, end, p);
|
||||||
|
p += context_size_;
|
||||||
|
}
|
||||||
|
|
||||||
return Ort::Value::CreateTensor(
|
return decoder_input;
|
||||||
memory_info,
|
|
||||||
const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
|
|
||||||
context_size_, shape.data(), shape.size());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
|
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
|
||||||
|
|||||||
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
|||||||
public:
|
public:
|
||||||
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
|
std::vector<Ort::Value> StackStates(
|
||||||
|
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||||
|
|
||||||
std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
|
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||||
|
const std::vector<Ort::Value> &states) const override;
|
||||||
|
|
||||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||||
Ort::Value features, std::vector<Ort::Value> &states) override;
|
Ort::Value features, std::vector<Ort::Value> &states) override;
|
||||||
|
|
||||||
Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
|
Ort::Value BuildDecoderInput(
|
||||||
|
const std::vector<OnlineTransducerDecoderResult> &results) override;
|
||||||
|
|
||||||
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
||||||
|
|
||||||
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
|||||||
int32_t ChunkShift() const override { return decode_chunk_len_; }
|
int32_t ChunkShift() const override { return decode_chunk_len_; }
|
||||||
|
|
||||||
int32_t VocabSize() const override { return vocab_size_; }
|
int32_t VocabSize() const override { return vocab_size_; }
|
||||||
|
OrtAllocator *Allocator() override { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitEncoder(const std::string &encoder_filename);
|
void InitEncoder(const std::string &encoder_filename);
|
||||||
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
|||||||
private:
|
private:
|
||||||
Ort::Env env_;
|
Ort::Env env_;
|
||||||
Ort::SessionOptions sess_opts_;
|
Ort::SessionOptions sess_opts_;
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||||
if (n != 1) {
|
|
||||||
fprintf(stderr, "only n == 1 is implemented\n");
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
OnlineStream *s = ss[0];
|
|
||||||
assert(IsReady(s));
|
|
||||||
|
|
||||||
int32_t chunk_size = model_->ChunkSize();
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
int32_t chunk_shift = model_->ChunkShift();
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|
||||||
int32_t feature_dim = s->FeatureDim();
|
int32_t feature_dim = ss[0]->FeatureDim();
|
||||||
|
|
||||||
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
std::vector<OnlineTransducerDecoderResult> results(n);
|
||||||
|
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||||
|
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
std::vector<float> features =
|
||||||
|
ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
|
||||||
|
|
||||||
|
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||||
|
|
||||||
|
std::copy(features.begin(), features.end(),
|
||||||
|
features_vec.data() + i * chunk_size * feature_dim);
|
||||||
|
|
||||||
|
results[i] = std::move(ss[i]->GetResult());
|
||||||
|
states_vec[i] = std::move(ss[i]->GetStates());
|
||||||
|
}
|
||||||
|
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
std::vector<float> features =
|
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
|
||||||
s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
|
|
||||||
|
|
||||||
s->GetNumProcessedFrames() += chunk_shift;
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||||
|
features_vec.size(), x_shape.data(),
|
||||||
|
x_shape.size());
|
||||||
|
|
||||||
Ort::Value x =
|
auto states = model_->StackStates(states_vec);
|
||||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
|
||||||
x_shape.data(), x_shape.size());
|
|
||||||
|
|
||||||
auto pair = model_->RunEncoder(std::move(x), s->GetStates());
|
auto pair = model_->RunEncoder(std::move(x), states);
|
||||||
|
|
||||||
s->SetStates(std::move(pair.second));
|
|
||||||
std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
|
|
||||||
|
|
||||||
decoder_->Decode(std::move(pair.first), &results);
|
decoder_->Decode(std::move(pair.first), &results);
|
||||||
s->SetResult(results[0]);
|
|
||||||
|
std::vector<std::vector<Ort::Value>> next_states =
|
||||||
|
model_->UnStackStates(pair.second);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
ss[i]->SetResult(results[i]);
|
||||||
|
ss[i]->SetStates(std::move(next_states[i]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
||||||
|
|||||||
@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
|||||||
encoder_out_dim, shape.data(), shape.size());
|
encoder_out_dim, shape.data(), shape.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||||
|
int32_t n) {
|
||||||
|
if (n == 1) {
|
||||||
|
return std::move(*cur_encoder_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> cur_encoder_out_shape =
|
||||||
|
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]};
|
||||||
|
|
||||||
|
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||||
|
ans_shape.size());
|
||||||
|
|
||||||
|
const float *src = cur_encoder_out->GetTensorData<float>();
|
||||||
|
float *dst = ans.GetTensorMutableData<float>();
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
std::copy(src, src + cur_encoder_out_shape[1], dst);
|
||||||
|
dst += cur_encoder_out_shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult
|
OnlineTransducerDecoderResult
|
||||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||||
int32_t context_size = model_->ContextSize();
|
int32_t context_size = model_->ContextSize();
|
||||||
@@ -66,33 +90,33 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result->size() != 1) {
|
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||||
fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
|
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||||
static_cast<int32_t>(result->size()));
|
|
||||||
exit(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &hyp = (*result)[0].tokens;
|
|
||||||
|
|
||||||
int32_t num_frames = encoder_out_shape[1];
|
|
||||||
int32_t vocab_size = model_->VocabSize();
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
|
||||||
Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
|
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
|
|
||||||
for (int32_t t = 0; t != num_frames; ++t) {
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
||||||
|
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
|
||||||
Ort::Value logit =
|
Ort::Value logit =
|
||||||
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
const float *p_logit = logit.GetTensorData<float>();
|
||||||
|
|
||||||
|
bool emitted = false;
|
||||||
|
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
if (y != 0) {
|
if (y != 0) {
|
||||||
hyp.push_back(y);
|
emitted = true;
|
||||||
decoder_input = model_->BuildDecoderInput(hyp);
|
(*result)[i].tokens.push_back(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (emitted) {
|
||||||
|
decoder_input = model_->BuildDecoderInput(*result);
|
||||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineTransducerDecoderResult;
|
||||||
|
|
||||||
class OnlineTransducerModel {
|
class OnlineTransducerModel {
|
||||||
public:
|
public:
|
||||||
virtual ~OnlineTransducerModel() = default;
|
virtual ~OnlineTransducerModel() = default;
|
||||||
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
|
|||||||
* @param states states[i] contains the state for the i-th utterance.
|
* @param states states[i] contains the state for the i-th utterance.
|
||||||
* @return Return a single value representing the batched state.
|
* @return Return a single value representing the batched state.
|
||||||
*/
|
*/
|
||||||
virtual Ort::Value StackStates(
|
virtual std::vector<Ort::Value> StackStates(
|
||||||
const std::vector<Ort::Value> &states) const = 0;
|
const std::vector<std::vector<Ort::Value>> &states) const = 0;
|
||||||
|
|
||||||
/** Unstack a batch state into a list of individual states.
|
/** Unstack a batch state into a list of individual states.
|
||||||
*
|
*
|
||||||
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
|
|||||||
* @param states A batched state.
|
* @param states A batched state.
|
||||||
* @return ans[i] contains the state for the i-th utterance.
|
* @return ans[i] contains the state for the i-th utterance.
|
||||||
*/
|
*/
|
||||||
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
|
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||||
|
const std::vector<Ort::Value> &states) const = 0;
|
||||||
|
|
||||||
/** Get the initial encoder states.
|
/** Get the initial encoder states.
|
||||||
*
|
*
|
||||||
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
|
|||||||
Ort::Value features,
|
Ort::Value features,
|
||||||
std::vector<Ort::Value> &states) = 0; // NOLINT
|
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||||
|
|
||||||
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
|
virtual Ort::Value BuildDecoderInput(
|
||||||
|
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||||
|
|
||||||
/** Run the decoder network.
|
/** Run the decoder network.
|
||||||
*
|
*
|
||||||
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
|
|||||||
virtual int32_t VocabSize() const = 0;
|
virtual int32_t VocabSize() const = 0;
|
||||||
|
|
||||||
virtual int32_t SubsamplingFactor() const { return 4; }
|
virtual int32_t SubsamplingFactor() const { return 4; }
|
||||||
|
virtual OrtAllocator *Allocator() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
Reference in New Issue
Block a user