Support batch greedy search decoding (#30)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -10,6 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value OnlineLstmTransducerModel::StackStates(
|
||||
const std::vector<Ort::Value> &states) const {
|
||||
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
||||
static_cast<int>(__LINE__));
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
int64_t a;
|
||||
std::array<int64_t, 3> x_shape{1, 1, 1};
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
|
||||
return x;
|
||||
std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
|
||||
std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
|
||||
std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
|
||||
rnn_hidden_size_};
|
||||
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
|
||||
float *dst_h = h.GetTensorMutableData<float>();
|
||||
float *dst_c = c.GetTensorMutableData<float>();
|
||||
|
||||
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *src_h =
|
||||
states[i][0].GetTensorData<float>() + layer * d_model_;
|
||||
|
||||
const float *src_c =
|
||||
states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
|
||||
|
||||
std::copy(src_h, src_h + d_model_, dst_h);
|
||||
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||
|
||||
dst_h += d_model_;
|
||||
dst_c += rnn_hidden_size_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(h));
|
||||
ans.push_back(std::move(c));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
|
||||
Ort::Value states) const {
|
||||
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
||||
static_cast<int>(__LINE__));
|
||||
return {};
|
||||
std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const {
|
||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||
|
||||
// allocate space
|
||||
std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
|
||||
std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
ans[i].push_back(std::move(h));
|
||||
ans[i].push_back(std::move(c));
|
||||
}
|
||||
|
||||
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *src_h = states[0].GetTensorData<float>() +
|
||||
layer * batch_size * d_model_ + i * d_model_;
|
||||
const float *src_c = states[1].GetTensorData<float>() +
|
||||
layer * batch_size * rnn_hidden_size_ +
|
||||
i * rnn_hidden_size_;
|
||||
|
||||
float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
|
||||
float *dst_c =
|
||||
ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
|
||||
|
||||
std::copy(src_h, src_h + d_model_, dst_h);
|
||||
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||
}
|
||||
|
||||
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
|
||||
const std::vector<int64_t> &hyp) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) {
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
std::array<int64_t, 2> shape{batch_size, context_size_};
|
||||
Ort::Value decoder_input =
|
||||
Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
std::array<int64_t, 2> shape{1, context_size_};
|
||||
for (const auto &r : results) {
|
||||
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
|
||||
const int64_t *end = r.tokens.data() + r.tokens.size();
|
||||
std::copy(begin, end, p);
|
||||
p += context_size_;
|
||||
}
|
||||
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info,
|
||||
const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
|
||||
context_size_, shape.data(), shape.size());
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
|
||||
|
||||
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
public:
|
||||
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||
|
||||
Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
|
||||
std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||
|
||||
std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const override;
|
||||
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> &states) override;
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) override;
|
||||
|
||||
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
||||
|
||||
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
int32_t ChunkShift() const override { return decode_chunk_len_; }
|
||||
|
||||
int32_t VocabSize() const override { return vocab_size_; }
|
||||
OrtAllocator *Allocator() override { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
private:
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||
if (n != 1) {
|
||||
fprintf(stderr, "only n == 1 is implemented\n");
|
||||
exit(-1);
|
||||
}
|
||||
OnlineStream *s = ss[0];
|
||||
assert(IsReady(s));
|
||||
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
int32_t feature_dim = s->FeatureDim();
|
||||
int32_t feature_dim = ss[0]->FeatureDim();
|
||||
|
||||
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
||||
std::vector<OnlineTransducerDecoderResult> results(n);
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
|
||||
|
||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
std::copy(features.begin(), features.end(),
|
||||
features_vec.data() + i * chunk_size * feature_dim);
|
||||
|
||||
results[i] = std::move(ss[i]->GetResult());
|
||||
states_vec[i] = std::move(ss[i]->GetStates());
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<float> features =
|
||||
s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
|
||||
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
|
||||
|
||||
s->GetNumProcessedFrames() += chunk_shift;
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||
x_shape.data(), x_shape.size());
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), s->GetStates());
|
||||
|
||||
s->SetStates(std::move(pair.second));
|
||||
std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
|
||||
auto pair = model_->RunEncoder(std::move(x), states);
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
s->SetResult(results[0]);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
ss[i]->SetResult(results[i]);
|
||||
ss[i]->SetStates(std::move(next_states[i]));
|
||||
}
|
||||
}
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
||||
|
||||
@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
||||
encoder_out_dim, shape.data(), shape.size());
|
||||
}
|
||||
|
||||
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
int32_t n) {
|
||||
if (n == 1) {
|
||||
return std::move(*cur_encoder_out);
|
||||
}
|
||||
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
const float *src = cur_encoder_out->GetTensorData<float>();
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
std::copy(src, src + cur_encoder_out_shape[1], dst);
|
||||
dst += cur_encoder_out_shape[1];
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
@@ -66,33 +90,33 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (result->size() != 1) {
|
||||
fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto &hyp = (*result)[0].tokens;
|
||||
|
||||
int32_t num_frames = encoder_out_shape[1];
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
||||
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
if (y != 0) {
|
||||
hyp.push_back(y);
|
||||
decoder_input = model_->BuildDecoderInput(hyp);
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
if (y != 0) {
|
||||
emitted = true;
|
||||
(*result)[i].tokens.push_back(y);
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
decoder_input = model_->BuildDecoderInput(*result);
|
||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineTransducerDecoderResult;
|
||||
|
||||
class OnlineTransducerModel {
|
||||
public:
|
||||
virtual ~OnlineTransducerModel() = default;
|
||||
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
|
||||
* @param states states[i] contains the state for the i-th utterance.
|
||||
* @return Return a single value representing the batched state.
|
||||
*/
|
||||
virtual Ort::Value StackStates(
|
||||
const std::vector<Ort::Value> &states) const = 0;
|
||||
virtual std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const = 0;
|
||||
|
||||
/** Unstack a batch state into a list of individual states.
|
||||
*
|
||||
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
|
||||
* @param states A batched state.
|
||||
* @return ans[i] contains the state for the i-th utterance.
|
||||
*/
|
||||
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
|
||||
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const = 0;
|
||||
|
||||
/** Get the initial encoder states.
|
||||
*
|
||||
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||
|
||||
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
|
||||
virtual Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
|
||||
virtual int32_t VocabSize() const = 0;
|
||||
|
||||
virtual int32_t SubsamplingFactor() const { return 4; }
|
||||
virtual OrtAllocator *Allocator() = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user