Adding warm up for Zipformer2 (#766)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
This commit is contained in:
@@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("num-threads", &num_threads,
|
po->Register("num-threads", &num_threads,
|
||||||
"Number of threads to run the neural network");
|
"Number of threads to run the neural network");
|
||||||
|
|
||||||
|
po->Register("warm-up", &warm_up,
|
||||||
|
"Number of warm-up to run the onnxruntime"
|
||||||
|
"Valid vales are: zipformer2");
|
||||||
|
|
||||||
po->Register("debug", &debug,
|
po->Register("debug", &debug,
|
||||||
"true to print model information while loading it.");
|
"true to print model information while loading it.");
|
||||||
|
|
||||||
@@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
|
|||||||
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
|
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
|
os << "warm_up=" << warm_up << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
os << "provider=\"" << provider << "\", ";
|
os << "provider=\"" << provider << "\", ";
|
||||||
os << "model_type=\"" << model_type << "\")";
|
os << "model_type=\"" << model_type << "\")";
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ struct OnlineModelConfig {
|
|||||||
OnlineZipformer2CtcModelConfig zipformer2_ctc;
|
OnlineZipformer2CtcModelConfig zipformer2_ctc;
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 1;
|
int32_t num_threads = 1;
|
||||||
|
int32_t warm_up = 0;
|
||||||
bool debug = false;
|
bool debug = false;
|
||||||
std::string provider = "cpu";
|
std::string provider = "cpu";
|
||||||
|
|
||||||
@@ -38,14 +39,17 @@ struct OnlineModelConfig {
|
|||||||
const OnlineParaformerModelConfig ¶former,
|
const OnlineParaformerModelConfig ¶former,
|
||||||
const OnlineWenetCtcModelConfig &wenet_ctc,
|
const OnlineWenetCtcModelConfig &wenet_ctc,
|
||||||
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
|
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads,
|
||||||
const std::string &provider, const std::string &model_type)
|
int32_t warm_up, bool debug,
|
||||||
|
const std::string &provider,
|
||||||
|
const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
wenet_ctc(wenet_ctc),
|
wenet_ctc(wenet_ctc),
|
||||||
zipformer2_ctc(zipformer2_ctc),
|
zipformer2_ctc(zipformer2_ctc),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
|
warm_up(warm_up),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
provider(provider),
|
provider(provider),
|
||||||
model_type(model_type) {}
|
model_type(model_type) {}
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ class OnlineRecognizerImpl {
|
|||||||
|
|
||||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
|
virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||||
|
// ToDo extending to other models
|
||||||
|
SHERPA_ONNX_LOGE("Only zipformer2 model supports Warm up for now.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||||
|
|
||||||
virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
|
virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
#include "sherpa-onnx/csrc/utils.h"
|
#include "sherpa-onnx/csrc/utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -183,6 +184,41 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
s->NumFramesReady();
|
s->NumFramesReady();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Warmping up engine with wp: warm_up count and max-batch-size
|
||||||
|
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||||
|
auto max_batch_size = mbs;
|
||||||
|
if (warmup <= 0 || warmup > 100) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
int32_t feature_dim = 80;
|
||||||
|
std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
|
||||||
|
std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
|
||||||
|
std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != max_batch_size; ++i) {
|
||||||
|
states_vec[i] = model_->GetEncoderInitStates();
|
||||||
|
results[i] = decoder_->GetEmptyResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != warmup; ++i) {
|
||||||
|
auto states = model_->StackStates(states_vec);
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||||
|
features_vec.size(), x_shape.data(),
|
||||||
|
x_shape.size());
|
||||||
|
auto x_copy = Clone(model_->Allocator(), &x);
|
||||||
|
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||||
|
std::move(x_copy));
|
||||||
|
decoder_->Decode(std::move(pair.first), &results);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||||
int32_t chunk_size = model_->ChunkSize();
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
int32_t chunk_shift = model_->ChunkShift();
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|||||||
@@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
|||||||
return impl_->IsReady(s);
|
return impl_->IsReady(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
||||||
|
if (warmup > 0) {
|
||||||
|
impl_->WarmpUpRecognizer(warmup, mbs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||||
impl_->DecodeStreams(ss, n);
|
impl_->DecodeStreams(ss, n);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,6 +162,15 @@ class OnlineRecognizer {
|
|||||||
DecodeStreams(ss, 1);
|
DecodeStreams(ss, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Warmups up onnxruntime sessions by apply optimization and
|
||||||
|
* allocating memory prior
|
||||||
|
*
|
||||||
|
* @param warmup Number of warmups.
|
||||||
|
* @param mbs : max-batch-size Max batch size for the models
|
||||||
|
*/
|
||||||
|
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
|
||||||
|
|
||||||
/** Decode multiple streams in parallel
|
/** Decode multiple streams in parallel
|
||||||
*
|
*
|
||||||
* @param ss Pointer array containing streams to be decoded.
|
* @param ss Pointer array containing streams to be decoded.
|
||||||
|
|||||||
@@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
|
|||||||
c->eof = true;
|
c->eof = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnlineWebsocketDecoder::Warmup() const {
|
||||||
|
recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up,
|
||||||
|
config_.max_batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
void OnlineWebsocketDecoder::Run() {
|
void OnlineWebsocketDecoder::Run() {
|
||||||
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
|
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
|
||||||
|
|
||||||
@@ -242,6 +247,24 @@ void OnlineWebsocketServer::Run(uint16_t port) {
|
|||||||
server_.set_reuse_addr(true);
|
server_.set_reuse_addr(true);
|
||||||
server_.listen(asio::ip::tcp::v4(), port);
|
server_.listen(asio::ip::tcp::v4(), port);
|
||||||
server_.start_accept();
|
server_.start_accept();
|
||||||
|
auto recognizer_config = config_.decoder_config.recognizer_config;
|
||||||
|
int32_t warm_up = recognizer_config.model_config.warm_up;
|
||||||
|
const std::string &model_type = recognizer_config.model_config.model_type;
|
||||||
|
if (0 < warm_up && warm_up < 100) {
|
||||||
|
if (model_type == "zipformer2") {
|
||||||
|
decoder_.Warmup();
|
||||||
|
SHERPA_ONNX_LOGE("Warm up completed : %d times.", warm_up);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Only Zipformer2 has warmup support for now.");
|
||||||
|
SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
} else if (warm_up == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("Starting without warmup!");
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Invalid Warm up Value!. Expected 0 < warm_up < 100");
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
decoder_.Run();
|
decoder_.Run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -85,6 +85,8 @@ class OnlineWebsocketDecoder {
|
|||||||
// signal that there will be no more audio samples for a stream
|
// signal that there will be no more audio samples for a stream
|
||||||
void InputFinished(std::shared_ptr<Connection> c);
|
void InputFinished(std::shared_ptr<Connection> c);
|
||||||
|
|
||||||
|
void Warmup() const;
|
||||||
|
|
||||||
void Run();
|
void Run();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
|
|||||||
.def(py::init<const OnlineTransducerModelConfig &,
|
.def(py::init<const OnlineTransducerModelConfig &,
|
||||||
const OnlineParaformerModelConfig &,
|
const OnlineParaformerModelConfig &,
|
||||||
const OnlineWenetCtcModelConfig &,
|
const OnlineWenetCtcModelConfig &,
|
||||||
const OnlineZipformer2CtcModelConfig &, const std::string &,
|
const OnlineZipformer2CtcModelConfig &,
|
||||||
int32_t, bool, const std::string &, const std::string &>(),
|
const std::string &, int32_t, int32_t,
|
||||||
|
bool, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||||
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||||
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
|
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
|
||||||
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
|
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
|
||||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
py::arg("debug") = false, py::arg("provider") = "cpu",
|
||||||
|
py::arg("model_type") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||||
|
|||||||
Reference in New Issue
Block a user