Fix nemo streaming transducer greedy search (#944)
This commit is contained in:
@@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
|
||||
: config_(config),
|
||||
@@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> states) {
|
||||
Ort::Value &cache_last_channel = states[0];
|
||||
Ort::Value &cache_last_time = states[1];
|
||||
Ort::Value &cache_last_channel_len = states[2];
|
||||
@@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
std::move(features), View(&length), std::move(cache_last_channel),
|
||||
std::move(cache_last_time), std::move(cache_last_channel_len)};
|
||||
|
||||
auto out =
|
||||
encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||
auto out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||
// out[0]: logit
|
||||
// out[1] logit_length
|
||||
// out[2:] states_next
|
||||
@@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
Ort::Value targets, std::vector<Ort::Value> states) {
|
||||
|
||||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
Ort::MemoryInfo memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
// Create the tensor with a single int32_t value of 1
|
||||
int32_t length_value = 1;
|
||||
std::vector<int64_t> length_shape = {1};
|
||||
auto shape = targets.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||
|
||||
std::vector<int64_t> length_shape = {batch_size};
|
||||
std::vector<int32_t> length_value(batch_size, 1);
|
||||
|
||||
Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
|
||||
memory_info, &length_value, 1, length_shape.data(), length_shape.size()
|
||||
);
|
||||
|
||||
memory_info, length_value.data(), batch_size, length_shape.data(),
|
||||
length_shape.size());
|
||||
|
||||
std::vector<Ort::Value> decoder_inputs;
|
||||
decoder_inputs.reserve(2 + states.size());
|
||||
|
||||
@@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
|
||||
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
|
||||
std::move(decoder_out)};
|
||||
auto logit =
|
||||
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
|
||||
joiner_input.size(), joiner_output_names_ptr_.data(),
|
||||
joiner_output_names_ptr_.size());
|
||||
auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
|
||||
joiner_input.data(), joiner_input.size(),
|
||||
joiner_output_names_ptr_.data(),
|
||||
joiner_output_names_ptr_.size());
|
||||
|
||||
return std::move(logit[0]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
std::vector<Ort::Value> GetDecoderInitStates() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(View(&lstm0_));
|
||||
ans.push_back(View(&lstm1_));
|
||||
|
||||
Fill<float>(&s0, 0);
|
||||
|
||||
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
|
||||
Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||
s1_shape.size());
|
||||
|
||||
Fill<float>(&s1, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
states.reserve(2);
|
||||
states.push_back(std::move(s0));
|
||||
states.push_back(std::move(s1));
|
||||
|
||||
return states;
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t ChunkSize() const { return window_size_; }
|
||||
@@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
int32_t ChunkShift() const { return chunk_shift_; }
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
@@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
// - cache_last_channel
|
||||
// - cache_last_time_
|
||||
// - cache_last_channel_len
|
||||
std::vector<Ort::Value> GetInitStates() {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(3);
|
||||
ans.push_back(View(&cache_last_channel_));
|
||||
@@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Ort::Value> StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
if (batch_size == 1) {
|
||||
return std::move(states[0]);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
|
||||
// stack cache_last_channel
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
|
||||
// there are 3 states to be stacked
|
||||
for (int32_t i = 0; i != 3; ++i) {
|
||||
buf.clear();
|
||||
buf.reserve(batch_size);
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
assert(states[b].size() == 3);
|
||||
buf.push_back(&states[b][i]);
|
||||
}
|
||||
|
||||
Ort::Value c{nullptr};
|
||||
if (i == 2) {
|
||||
c = Cat<int64_t>(allocator_, buf, 0);
|
||||
} else {
|
||||
c = Cat(allocator_, buf, 0);
|
||||
}
|
||||
|
||||
ans.push_back(std::move(c));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
assert(states.size() == 3);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
|
||||
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = shape[0];
|
||||
ans.resize(batch_size);
|
||||
|
||||
if (batch_size == 1) {
|
||||
ans[0] = std::move(states);
|
||||
return ans;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != 3; ++i) {
|
||||
std::vector<Ort::Value> v;
|
||||
if (i == 2) {
|
||||
v = Unbind<int64_t>(allocator_, &states[i], 0);
|
||||
} else {
|
||||
v = Unbind(allocator_, &states[i], 0);
|
||||
}
|
||||
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
ans[b].push_back(std::move(v[b]));
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
@@ -276,10 +332,10 @@ private:
|
||||
normalize_type_ = "";
|
||||
}
|
||||
|
||||
InitStates();
|
||||
InitEncoderStates();
|
||||
}
|
||||
|
||||
void InitStates() {
|
||||
|
||||
void InitEncoderStates() {
|
||||
std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
|
||||
cache_last_channel_dim2_,
|
||||
cache_last_channel_dim3_};
|
||||
@@ -313,7 +369,25 @@ private:
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
&decoder_output_names_ptr_);
|
||||
|
||||
InitDecoderStates();
|
||||
}
|
||||
|
||||
void InitDecoderStates() {
|
||||
int32_t batch_size = 1;
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
lstm0_ = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
|
||||
Fill<float>(&lstm0_, 0);
|
||||
|
||||
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
|
||||
lstm1_ = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||
s1_shape.size());
|
||||
|
||||
Fill<float>(&lstm1_, 0);
|
||||
}
|
||||
|
||||
void InitJoiner(void *model_data, size_t model_data_length) {
|
||||
@@ -324,7 +398,7 @@ private:
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||
&joiner_output_names_ptr_);
|
||||
&joiner_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -363,6 +437,7 @@ private:
|
||||
int32_t pred_rnn_layers_ = -1;
|
||||
int32_t pred_hidden_ = -1;
|
||||
|
||||
// encoder states
|
||||
int32_t cache_last_channel_dim1_;
|
||||
int32_t cache_last_channel_dim2_;
|
||||
int32_t cache_last_channel_dim3_;
|
||||
@@ -370,9 +445,14 @@ private:
|
||||
int32_t cache_last_time_dim2_;
|
||||
int32_t cache_last_time_dim3_;
|
||||
|
||||
// init encoder states
|
||||
Ort::Value cache_last_channel_{nullptr};
|
||||
Ort::Value cache_last_time_{nullptr};
|
||||
Ort::Value cache_last_channel_len_{nullptr};
|
||||
|
||||
// init decoder states
|
||||
Ort::Value lstm0_{nullptr};
|
||||
Ort::Value lstm1_{nullptr};
|
||||
};
|
||||
|
||||
OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
|
||||
@@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
|
||||
|
||||
OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
|
||||
|
||||
std::vector<Ort::Value>
|
||||
OnlineTransducerNeMoModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->RunEncoder(std::move(features), std::move(states));
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states) const {
|
||||
return impl_->RunEncoder(std::move(features), std::move(states));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
@@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
|
||||
return impl_->RunDecoder(std::move(targets), std::move(states));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
|
||||
int32_t batch_size) const {
|
||||
return impl_->GetDecoderInitStates(batch_size);
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates()
|
||||
const {
|
||||
return impl_->GetDecoderInitStates();
|
||||
}
|
||||
|
||||
Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||
@@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkSize() const {
|
||||
return impl_->ChunkSize();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkSize() const {
|
||||
return impl_->ChunkSize();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
int32_t OnlineTransducerNeMoModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
|
||||
int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
@@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||
return impl_->FeatureNormalizationMethod();
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
|
||||
return impl_->GetInitStates();
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetEncoderInitStates()
|
||||
const {
|
||||
return impl_->GetEncoderInitStates();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
std::vector<Ort::Value> OnlineTransducerNeMoModel::StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
return impl_->StackStates(std::move(states));
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> OnlineTransducerNeMoModel::UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->UnStackStates(std::move(states));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user