Reduce model initialization time for online speech recognition (#215)

* Reduce model initialization time for online speech recognition

* Fixed Styling

---------

Co-authored-by: w11wo <wilsowong961@gmail.com>
This commit is contained in:
Wilson Wongso
2023-07-14 20:20:10 +07:00
committed by GitHub
parent fe0630fe1f
commit 5a6b55c5a7
7 changed files with 69 additions and 8 deletions

View File

@@ -77,6 +77,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const OnlineTransducerModelConfig &config) {
if (!config.model_type.empty()) {
const auto &model_type = config.model_type;
if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(config);
} else if (model_type == "lstm") {
return std::make_unique<OnlineLstmTransducerModel>(config);
} else if (model_type == "zipformer") {
return std::make_unique<OnlineZipformerTransducerModel>(config);
} else if (model_type == "zipformer2") {
return std::make_unique<OnlineZipformer2TransducerModel>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
ModelType model_type = ModelType::kUnkown;
{
@@ -140,6 +156,23 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput(
#if __ANDROID_API__ >= 9
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
if (!config.model_type.empty()) {
const auto &model_type = config.model_type;
if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
} else if (model_type == "lstm") {
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
} else if (model_type == "zipformer") {
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
} else if (model_type == "zipformer2") {
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
auto buffer = ReadFile(mgr, config.encoder_filename);
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);