diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
index f55c804a..20876791 100644
--- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -3,6 +3,7 @@
 // Copyright (c)  2023  Xiaomi Corporation

 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"

+#include <algorithm>
 #include <cmath>
 #include <string>
 #include <utility>
@@ -10,6 +11,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"

 #define SHERPA_ONNX_READ_META_DATA(dst, src_key)                        \
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
   }
 }

-Ort::Value OnlineLstmTransducerModel::StackStates(
-    const std::vector<Ort::Value> &states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-  int64_t a;
-  std::array<int64_t, 3> x_shape{1, 1, 1};
-  Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
-  return x;
+std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  int32_t batch_size = static_cast<int32_t>(states.size());
+
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
+                                 rnn_hidden_size_};
+
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  float *dst_h = h.GetTensorMutableData<float>();
+  float *dst_c = c.GetTensorMutableData<float>();
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h =
+          states[i][0].GetTensorData<float>() + layer * d_model_;
+
+      const float *src_c =
+          states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+
+      dst_h += d_model_;
+      dst_c += rnn_hidden_size_;
+    }
+  }
+
+  std::vector<Ort::Value> ans;
+
+  ans.reserve(2);
+  ans.push_back(std::move(h));
+  ans.push_back(std::move(c));
+
+  return ans;
 }

-std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
-    Ort::Value states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  return {};
+std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+
+  std::vector<std::vector<Ort::Value>> ans(batch_size);
+
+  // allocate space
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                   h_shape.size());
+    Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                   c_shape.size());
+    ans[i].push_back(std::move(h));
+    ans[i].push_back(std::move(c));
+  }
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h = states[0].GetTensorData<float>() +
+                           layer * batch_size * d_model_ + i * d_model_;
+      const float *src_c = states[1].GetTensorData<float>() +
+                           layer * batch_size * rnn_hidden_size_ +
+                           i * rnn_hidden_size_;
+
+      float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
+      float *dst_c =
+          ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+    }
+  }
+
+  return ans;
 }

 std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
 }

 Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
-    const std::vector<int64_t> &hyp) {
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  std::array<int64_t, 2> shape{batch_size, context_size_};
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();

-  std::array<int64_t, 2> shape{1, context_size_};
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size_;
+  }

-  return Ort::Value::CreateTensor(
-      memory_info,
-      const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
-      context_size_, shape.data(), shape.size());
+  return decoder_input;
 }

 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
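Note on the layout used by StackStates/UnStackStates above: h has shape (num_encoder_layers, batch, d_model) and c has shape (num_encoder_layers, batch, rnn_hidden_size), so stacking per-stream states is an interleave along the batch axis and unstacking is its inverse. A minimal sketch of the same copy pattern on plain vectors (the helper name and types are illustrative, not part of the patch):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Interleave per-stream buffers of shape (num_layers, 1, dim) into one
    // (num_layers, batch, dim) buffer, mirroring StackStates for one tensor.
    std::vector<float> StackAlongBatch(
        const std::vector<std::vector<float>> &src, int32_t num_layers,
        int32_t dim) {
      int32_t batch = static_cast<int32_t>(src.size());
      std::vector<float> dst(static_cast<size_t>(num_layers) * batch * dim);
      for (int32_t layer = 0; layer != num_layers; ++layer) {
        for (int32_t b = 0; b != batch; ++b) {
          const float *in = src[b].data() + layer * dim;        // (layer, 0, :)
          float *out = dst.data() + (layer * batch + b) * dim;  // (layer, b, :)
          std::copy(in, in + dim, out);
        }
      }
      return dst;
    }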
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h
index 6dc03d8e..5fc23260 100644
--- a/sherpa-onnx/csrc/online-lstm-transducer-model.h
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
  public:
  explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);

-  Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
+  std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const override;

-  std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const override;

   std::vector<Ort::Value> GetEncoderInitStates() override;

   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features, std::vector<Ort::Value> &states) override;

-  Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
+  Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results) override;

   Ort::Value RunDecoder(Ort::Value decoder_input) override;

@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   int32_t ChunkShift() const override { return decode_chunk_len_; }

   int32_t VocabSize() const override { return vocab_size_; }
+  OrtAllocator *Allocator() override { return allocator_; }

  private:
   void InitEncoder(const std::string &encoder_filename);
@@ -50,7 +54,6 @@
  private:
   Ort::Env env_;
   Ort::SessionOptions sess_opts_;
-  Ort::AllocatorWithDefaultOptions allocator_;

   std::unique_ptr<Ort::Session> encoder_sess_;
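For reference, the reworked BuildDecoderInput packs the last context_size_ tokens of every hypothesis row-major into a (batch_size, context_size_) int64 tensor. A minimal sketch of that packing on plain vectors (the helper name is illustrative; it assumes each hypothesis holds at least context_size tokens):

    #include <cstdint>
    #include <vector>

    // Pack the last `context_size` tokens of each hypothesis, row-major,
    // into a flat (batch, context_size) buffer.
    std::vector<int64_t> PackDecoderInput(
        const std::vector<std::vector<int64_t>> &tokens, int32_t context_size) {
      std::vector<int64_t> out;
      out.reserve(tokens.size() * context_size);
      for (const auto &t : tokens) {
        out.insert(out.end(), t.end() - context_size, t.end());
      }
      return out;
    }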
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index 3a7b42cf..2f0fdbf4 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -6,6 +6,7 @@

 #include <cassert>

+#include <algorithm>
 #include <memory>
 #include <utility>
 #include <vector>
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
   }

   void DecodeStreams(OnlineStream **ss, int32_t n) {
-    if (n != 1) {
-      fprintf(stderr, "only n == 1 is implemented\n");
-      exit(-1);
-    }
-    OnlineStream *s = ss[0];
-    assert(IsReady(s));
-
     int32_t chunk_size = model_->ChunkSize();
     int32_t chunk_shift = model_->ChunkShift();
-    int32_t feature_dim = s->FeatureDim();
+    int32_t feature_dim = ss[0]->FeatureDim();

-    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
+    std::vector<OnlineTransducerDecoderResult> results(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(n);
+
+    for (int32_t i = 0; i != n; ++i) {
+      std::vector<float> features =
+          ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
+
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      results[i] = std::move(ss[i]->GetResult());
+      states_vec[i] = std::move(ss[i]->GetStates());
+    }

     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

-    std::vector<float> features =
-        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

-    s->GetNumProcessedFrames() += chunk_shift;
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(), x_shape.data(),
+                                            x_shape.size());

-    Ort::Value x =
-        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
-                                 x_shape.data(), x_shape.size());
+    auto states = model_->StackStates(states_vec);

-    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
-
-    s->SetStates(std::move(pair.second));
-    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
+    auto pair = model_->RunEncoder(std::move(x), states);

     decoder_->Decode(std::move(pair.first), &results);
-    s->SetResult(results[0]);
+
+    std::vector<std::vector<Ort::Value>> next_states =
+        model_->UnStackStates(pair.second);
+
+    for (int32_t i = 0; i != n; ++i) {
+      ss[i]->SetResult(results[i]);
+      ss[i]->SetStates(std::move(next_states[i]));
+    }
   }

   OnlineRecognizerResult GetResult(OnlineStream *s) {
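The batched DecodeStreams above gathers each stream's (chunk_size, feature_dim) chunk into one contiguous buffer that backs the (n, chunk_size, feature_dim) input tensor. A minimal sketch of that gathering step (names are illustrative, not part of the patch):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Copy each stream's feature chunk into one contiguous batched buffer.
    std::vector<float> GatherChunks(
        const std::vector<std::vector<float>> &chunks, int32_t chunk_size,
        int32_t feature_dim) {
      int32_t n = static_cast<int32_t>(chunks.size());
      std::vector<float> out(static_cast<size_t>(n) * chunk_size * feature_dim);
      for (int32_t i = 0; i != n; ++i) {
        std::copy(chunks[i].begin(), chunks[i].end(),
                  out.data() + static_cast<size_t>(i) * chunk_size * feature_dim);
      }
      return out;
    }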
Given: %d", - static_cast(result->size())); - exit(-1); - } - - auto &hyp = (*result)[0].tokens; - - int32_t num_frames = encoder_out_shape[1]; + int32_t batch_size = static_cast(encoder_out_shape[0]); + int32_t num_frames = static_cast(encoder_out_shape[1]); int32_t vocab_size = model_->VocabSize(); - Ort::Value decoder_input = model_->BuildDecoderInput(hyp); + Ort::Value decoder_input = model_->BuildDecoderInput(*result); Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); for (int32_t t = 0; t != num_frames; ++t) { Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); + cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); const float *p_logit = logit.GetTensorData(); - auto y = static_cast(std::distance( - static_cast(p_logit), - std::max_element(static_cast(p_logit), - static_cast(p_logit) + vocab_size))); - if (y != 0) { - hyp.push_back(y); - decoder_input = model_->BuildDecoderInput(hyp); + bool emitted = false; + for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + if (y != 0) { + emitted = true; + (*result)[i].tokens.push_back(y); + } + } + if (emitted) { + decoder_input = model_->BuildDecoderInput(*result); decoder_out = model_->RunDecoder(std::move(decoder_input)); } } diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index dfcf9452..8f33b818 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -13,6 +13,8 @@ namespace sherpa_onnx { +class OnlineTransducerDecoderResult; + class OnlineTransducerModel { public: virtual ~OnlineTransducerModel() = default; @@ -27,8 +29,8 @@ class OnlineTransducerModel { * @param states states[i] contains the state for the i-th utterance. * @return Return a single value representing the batched state. */ - virtual Ort::Value StackStates( - const std::vector &states) const = 0; + virtual std::vector StackStates( + const std::vector> &states) const = 0; /** Unstack a batch state into a list of individual states. * @@ -37,7 +39,8 @@ class OnlineTransducerModel { * @param states A batched state. * @return ans[i] contains the state for the i-th utterance. */ - virtual std::vector UnStackStates(Ort::Value states) const = 0; + virtual std::vector> UnStackStates( + const std::vector &states) const = 0; /** Get the initial encoder states. * @@ -58,7 +61,8 @@ class OnlineTransducerModel { Ort::Value features, std::vector &states) = 0; // NOLINT - virtual Ort::Value BuildDecoderInput(const std::vector &hyp) = 0; + virtual Ort::Value BuildDecoderInput( + const std::vector &results) = 0; /** Run the decoder network. * @@ -111,6 +115,7 @@ class OnlineTransducerModel { virtual int32_t VocabSize() const = 0; virtual int32_t SubsamplingFactor() const { return 4; } + virtual OrtAllocator *Allocator() = 0; }; } // namespace sherpa_onnx