Adding warm up for Zipformer2 (#766)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
This commit is contained in:
@@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("warm-up", &warm_up,
|
||||
"Number of warm-up to run the onnxruntime"
|
||||
"Valid vales are: zipformer2");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
@@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
|
||||
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "warm_up=" << warm_up << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
|
||||
@@ -20,6 +20,7 @@ struct OnlineModelConfig {
|
||||
OnlineZipformer2CtcModelConfig zipformer2_ctc;
|
||||
std::string tokens;
|
||||
int32_t num_threads = 1;
|
||||
int32_t warm_up = 0;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
@@ -38,14 +39,17 @@ struct OnlineModelConfig {
|
||||
const OnlineParaformerModelConfig ¶former,
|
||||
const OnlineWenetCtcModelConfig &wenet_ctc,
|
||||
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
int32_t warm_up, bool debug,
|
||||
const std::string &provider,
|
||||
const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
wenet_ctc(wenet_ctc),
|
||||
zipformer2_ctc(zipformer2_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
warm_up(warm_up),
|
||||
debug(debug),
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
|
||||
@@ -37,6 +37,12 @@ class OnlineRecognizerImpl {
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||
// ToDo extending to other models
|
||||
SHERPA_ONNX_LOGE("Only zipformer2 model supports Warm up for now.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||
|
||||
virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -183,6 +184,41 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
// Warmping up engine with wp: warm_up count and max-batch-size
|
||||
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||
auto max_batch_size = mbs;
|
||||
if (warmup <= 0 || warmup > 100) {
|
||||
return;
|
||||
}
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
int32_t feature_dim = 80;
|
||||
std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
|
||||
std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
|
||||
|
||||
for (int32_t i = 0; i != max_batch_size; ++i) {
|
||||
states_vec[i] = model_->GetEncoderInitStates();
|
||||
results[i] = decoder_->GetEmptyResult();
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != warmup; ++i) {
|
||||
auto states = model_->StackStates(states_vec);
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
auto x_copy = Clone(model_->Allocator(), &x);
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(x_copy));
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
@@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||
if (warmup > 0) {
|
||||
impl_->WarmpUpRecognizer(warmup, mbs);
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
@@ -162,6 +162,15 @@ class OnlineRecognizer {
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Warmups up onnxruntime sessions by apply optimization and
|
||||
* allocating memory prior
|
||||
*
|
||||
* @param warmup Number of warmups.
|
||||
* @param mbs : max-batch-size Max batch size for the models
|
||||
*/
|
||||
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
|
||||
|
||||
/** Decode multiple streams in parallel
|
||||
*
|
||||
* @param ss Pointer array containing streams to be decoded.
|
||||
|
||||
@@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
|
||||
c->eof = true;
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::Warmup() const {
|
||||
recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up,
|
||||
config_.max_batch_size);
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::Run() {
|
||||
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
|
||||
|
||||
@@ -242,6 +247,24 @@ void OnlineWebsocketServer::Run(uint16_t port) {
|
||||
server_.set_reuse_addr(true);
|
||||
server_.listen(asio::ip::tcp::v4(), port);
|
||||
server_.start_accept();
|
||||
auto recognizer_config = config_.decoder_config.recognizer_config;
|
||||
int32_t warm_up = recognizer_config.model_config.warm_up;
|
||||
const std::string &model_type = recognizer_config.model_config.model_type;
|
||||
if (0 < warm_up && warm_up < 100) {
|
||||
if (model_type == "zipformer2") {
|
||||
decoder_.Warmup();
|
||||
SHERPA_ONNX_LOGE("Warm up completed : %d times.", warm_up);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only Zipformer2 has warmup support for now.");
|
||||
SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
|
||||
exit(0);
|
||||
}
|
||||
} else if (warm_up == 0) {
|
||||
SHERPA_ONNX_LOGE("Starting without warmup!");
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Invalid Warm up Value!. Expected 0 < warm_up < 100");
|
||||
exit(0);
|
||||
}
|
||||
decoder_.Run();
|
||||
}
|
||||
|
||||
|
||||
@@ -85,6 +85,8 @@ class OnlineWebsocketDecoder {
|
||||
// signal that there will be no more audio samples for a stream
|
||||
void InputFinished(std::shared_ptr<Connection> c);
|
||||
|
||||
void Warmup() const;
|
||||
|
||||
void Run();
|
||||
|
||||
private:
|
||||
|
||||
@@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
|
||||
.def(py::init<const OnlineTransducerModelConfig &,
|
||||
const OnlineParaformerModelConfig &,
|
||||
const OnlineWenetCtcModelConfig &,
|
||||
const OnlineZipformer2CtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
const OnlineZipformer2CtcModelConfig &,
|
||||
const std::string &, int32_t, int32_t,
|
||||
bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
|
||||
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu",
|
||||
py::arg("model_type") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||
|
||||
Reference in New Issue
Block a user