Support streaming zipformer CTC (#496)
* Support streaming zipformer CTC
* Test online zipformer2 CTC
* Update doc of sherpa-onnx.cc
* Add Python APIs for streaming zipformer2 CTC
* Add Python API examples for streaming zipformer2 CTC
* Swift API for streaming zipformer2 CTC
* NodeJS API for streaming zipformer2 CTC
* Kotlin API for streaming zipformer2 CTC
* Golang API for streaming zipformer2 CTC
* C# API for streaming zipformer2 CTC
* Release v1.9.6
This commit is contained in:
@@ -96,8 +96,67 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
// Decodes one chunk for each of the `n` streams in `ss`.
//
// If there is exactly one stream, or the model reports that it does not
// support batch processing, every stream is handed to DecodeStream() one at
// a time. Otherwise one chunk of features from every stream is packed into a
// single (n, chunk_length, feat_dim) tensor, the per-stream states are
// stacked, the model is run once, and the decoded CTC results and the
// unstacked output states are written back to the corresponding streams.
//
// @param ss  Array of `n` stream pointers.
// @param n   Number of entries in `ss`. Nothing is done when n <= 0.
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  // Guard against an empty batch: the batched path below reads ss[0].
  if (n <= 0) {
    return;
  }

  if (n == 1 || !model_->SupportBatchProcessing()) {
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
    return;
  }

  // batch processing
  int32_t chunk_length = model_->ChunkLength();
  int32_t chunk_shift = model_->ChunkShift();

  // NOTE(review): assumes all streams share the feature dimension of ss[0]
  // -- confirm against callers.
  int32_t feat_dim = ss[0]->FeatureDim();

  std::vector<OnlineCtcDecoderResult> results(n);
  std::vector<float> features_vec(n * chunk_length * feat_dim);
  std::vector<std::vector<Ort::Value>> states_vec(n);

  for (int32_t i = 0; i != n; ++i) {
    // Do not call DecodeStream(ss[i]) here. The stream is decoded by the
    // batched forward pass below; decoding it individually as well would
    // advance its processed-frame counter, result and states twice.
    const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
    std::vector<float> features =
        ss[i]->GetFrames(num_processed_frames, chunk_length);

    // Question: should num_processed_frames include chunk_shift?
    ss[i]->GetNumProcessedFrames() += chunk_shift;

    // Row i of the batch starts at offset i * chunk_length * feat_dim.
    // NOTE(review): presumably GetFrames() returns exactly
    // chunk_length * feat_dim floats -- confirm.
    std::copy(features.begin(), features.end(),
              features_vec.data() + i * chunk_length * feat_dim);

    results[i] = std::move(ss[i]->GetCtcResult());
    states_vec[i] = std::move(ss[i]->GetStates());
  }

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 3> x_shape{n, chunk_length, feat_dim};

  // The tensor is created over features_vec's buffer, so features_vec must
  // stay alive until Forward() has returned.
  Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
                                          features_vec.size(), x_shape.data(),
                                          x_shape.size());

  auto states = model_->StackStates(std::move(states_vec));
  int32_t num_states = states.size();
  auto out = model_->Forward(std::move(x), std::move(states));

  // out[0] is passed to the decoder; out[1..num_states] are taken as the
  // batched states for the next chunk.
  std::vector<Ort::Value> out_states;
  out_states.reserve(num_states);

  for (int32_t k = 1; k != num_states + 1; ++k) {
    out_states.push_back(std::move(out[k]));
  }

  std::vector<std::vector<Ort::Value>> next_states =
      model_->UnStackStates(std::move(out_states));

  decoder_->Decode(std::move(out[0]), &results);

  // Write the updated result and states back to each stream.
  for (int32_t k = 0; k != n; ++k) {
    ss[k]->SetCtcResult(results[k]);
    ss[k]->SetStates(std::move(next_states[k]));
  }
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user