Reduce model initialization time for online speech recognition (#215)
* Reduce model initialization time for online speech recognition * Fixed Styling --------- Co-authored-by: w11wo <wilsowong961@gmail.com>
This commit is contained in:
@@ -77,6 +77,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
const OnlineTransducerModelConfig &config) {
|
||||
if (!config.model_type.empty()) {
|
||||
const auto &model_type = config.model_type;
|
||||
if (model_type == "conformer") {
|
||||
return std::make_unique<OnlineConformerTransducerModel>(config);
|
||||
} else if (model_type == "lstm") {
|
||||
return std::make_unique<OnlineLstmTransducerModel>(config);
|
||||
} else if (model_type == "zipformer") {
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||
} else if (model_type == "zipformer2") {
|
||||
return std::make_unique<OnlineZipformer2TransducerModel>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
|
||||
{
|
||||
@@ -140,6 +156,23 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput(
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
|
||||
if (!config.model_type.empty()) {
|
||||
const auto &model_type = config.model_type;
|
||||
if (model_type == "conformer") {
|
||||
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
|
||||
} else if (model_type == "lstm") {
|
||||
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
|
||||
} else if (model_type == "zipformer") {
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||
} else if (model_type == "zipformer2") {
|
||||
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
auto buffer = ReadFile(mgr, config.encoder_filename);
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user