diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc
index 6e0ab6d7..16431c9b 100644
--- a/sherpa-onnx/csrc/online-model-config.cc
+++ b/sherpa-onnx/csrc/online-model-config.cc
@@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("num-threads", &num_threads,
                "Number of threads to run the neural network");
 
+  po->Register("warm-up", &warm_up,
+               "Number of warm-up iterations to run the onnxruntime "
+               "session with before real decoding starts");
+
   po->Register("debug", &debug,
                "true to print model information while loading it.");
 
@@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
+  os << "warm_up=" << warm_up << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
   os << "provider=\"" << provider << "\", ";
   os << "model_type=\"" << model_type << "\")";
diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h
index bedabf11..d9616867 100644
--- a/sherpa-onnx/csrc/online-model-config.h
+++ b/sherpa-onnx/csrc/online-model-config.h
@@ -20,6 +20,8 @@ struct OnlineModelConfig {
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
   std::string tokens;
   int32_t num_threads = 1;
+  // Number of warm-up iterations to run before real decoding starts (0 = off).
+  int32_t warm_up = 0;
   bool debug = false;
   std::string provider = "cpu";
 
@@ -38,14 +40,17 @@ struct OnlineModelConfig {
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
-                    const std::string &tokens, int32_t num_threads, bool debug,
-                    const std::string &provider, const std::string &model_type)
+                    const std::string &tokens, int32_t num_threads,
+                    int32_t warm_up, bool debug,
+                    const std::string &provider,
+                    const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
         tokens(tokens),
         num_threads(num_threads),
+        warm_up(warm_up),
         debug(debug),
         provider(provider),
         model_type(model_type) {}
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h
index db07ffa5..72efedec 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-impl.h
@@ -37,6 +37,12 @@ class OnlineRecognizerImpl {
 
   virtual bool IsReady(OnlineStream *s) const = 0;
 
+  // TODO: extend warm-up support to other model types; aborts for now.
+  virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+    SHERPA_ONNX_LOGE("Only zipformer2 model supports Warm up for now.");
+    exit(-1);
+  }
+
   virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
 
   virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
index 0fa3acac..add0b85d 100644
--- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -32,6 +32,7 @@
 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/utils.h"
 
 namespace sherpa_onnx {
 
@@ -183,6 +184,44 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
            s->NumFramesReady();
   }
 
+  // Warm up the recognizer: run `warmup` rounds of the encoder + decoder on
+  // zero-valued features with batch size `mbs` (max batch size) so that
+  // onnxruntime applies its optimizations and allocates memory up front.
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override {
+    auto max_batch_size = mbs;
+    if (warmup <= 0 || warmup > 100) {
+      return;
+    }
+    int32_t chunk_size = model_->ChunkSize();
+    // NOTE(review): feature dim is hard-coded to 80 here; confirm it matches
+    // config_.feat_config.feature_dim for all supported zipformer2 models.
+    int32_t feature_dim = 80;
+    std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
+    std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
+
+    for (int32_t i = 0; i != max_batch_size; ++i) {
+      states_vec[i] = model_->GetEncoderInitStates();
+      results[i] = decoder_->GetEmptyResult();
+    }
+
+    for (int32_t i = 0; i != warmup; ++i) {
+      auto states = model_->StackStates(states_vec);
+      Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                              features_vec.size(),
+                                              x_shape.data(), x_shape.size());
+      auto x_copy = Clone(model_->Allocator(), &x);
+      auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                     std::move(x_copy));
+      decoder_->Decode(std::move(pair.first), &results);
+    }
+  }
+
   void DecodeStreams(OnlineStream **ss, int32_t n) const override {
     int32_t chunk_size = model_->ChunkSize();
     int32_t chunk_shift = model_->ChunkShift();
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index 5d344565..8bd0c16a 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
 
+void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+  if (warmup > 0) {  // a non-positive count means "no warm-up requested"
+    impl_->WarmpUpRecognizer(warmup, mbs);
+  }
+}
+
 void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
   impl_->DecodeStreams(ss, n);
 }
diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h
index e7f1b38d..c1d2e9a7 100644
--- a/sherpa-onnx/csrc/online-recognizer.h
+++ b/sherpa-onnx/csrc/online-recognizer.h
@@ -162,6 +162,15 @@ class OnlineRecognizer {
     DecodeStreams(ss, 1);
   }
 
+  /**
+   * Warms up the onnxruntime sessions by applying optimizations and
+   * allocating memory before real decoding starts.
+   *
+   * @param warmup Number of warm-up iterations.
+   * @param mbs Max batch size to use for the warm-up runs.
+   */
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
+
   /** Decode multiple streams in parallel
    *
    * @param ss Pointer array containing streams to be decoded.
diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc
index d02a4913..11651f07 100644
--- a/sherpa-onnx/csrc/online-websocket-server-impl.cc
+++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc
@@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
   c->eof = true;
 }
 
+void OnlineWebsocketDecoder::Warmup() const {
+  recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up,
+                                 config_.max_batch_size);
+}
+
 void OnlineWebsocketDecoder::Run() {
   timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
 
@@ -242,6 +247,25 @@ void OnlineWebsocketServer::Run(uint16_t port) {
   server_.set_reuse_addr(true);
   server_.listen(asio::ip::tcp::v4(), port);
   server_.start_accept();
+  // Take a const reference: the recognizer config is expensive to copy.
+  const auto &recognizer_config = config_.decoder_config.recognizer_config;
+  int32_t warm_up = recognizer_config.model_config.warm_up;
+  const std::string &model_type = recognizer_config.model_config.model_type;
+  if (0 < warm_up && warm_up < 100) {
+    if (model_type == "zipformer2") {
+      decoder_.Warmup();
+      SHERPA_ONNX_LOGE("Warm up completed : %d times.", warm_up);
+    } else {
+      SHERPA_ONNX_LOGE("Only Zipformer2 has warmup support for now.");
+      SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
+      exit(-1);
+    }
+  } else if (warm_up == 0) {
+    SHERPA_ONNX_LOGE("Starting without warmup!");
+  } else {
+    SHERPA_ONNX_LOGE("Invalid warm-up value! Expected 0 < warm_up < 100.");
+    exit(-1);
+  }
   decoder_.Run();
 }
diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.h b/sherpa-onnx/csrc/online-websocket-server-impl.h
index 9716c5c7..4e0582db 100644
--- a/sherpa-onnx/csrc/online-websocket-server-impl.h
+++ b/sherpa-onnx/csrc/online-websocket-server-impl.h
@@ -85,6 +85,9 @@ class OnlineWebsocketDecoder {
   // signal that there will be no more audio samples for a stream
   void InputFinished(std::shared_ptr<Connection> c);
 
+  // Run the configured number of recognizer warm-up iterations.
+  void Warmup() const;
+
   void Run();
 
  private:
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index 9a847351..2b4a8776 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &, const std::string &,
-                    int32_t, bool, const std::string &, const std::string &>(),
+                    const OnlineZipformer2CtcModelConfig &,
+                    const std::string &, int32_t, int32_t,
+                    bool, const std::string &, const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu", py::arg("model_type") = "")
+          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("debug") = false, py::arg("provider") = "cpu",
+          py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)