Add C++ runtime for parakeet-tdt-0.6b-v2. (#2181)
This commit is contained in:
@@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void PostInit() {
|
||||
int32_t feat_dim = model_->FeatureDim();
|
||||
|
||||
if (feat_dim > 0) {
|
||||
config_.feat_config.feature_dim = feat_dim;
|
||||
}
|
||||
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
|
||||
@@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
|
||||
bool IsGigaAM() const { return is_giga_am_; }
|
||||
|
||||
int32_t FeatureDim() const { return feat_dim_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
@@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
|
||||
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(feat_dim_, "feat_dim", -1);
|
||||
|
||||
if (normalize_type_ == "NA") {
|
||||
normalize_type_ = "";
|
||||
@@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
int32_t pred_rnn_layers_ = -1;
|
||||
int32_t pred_hidden_ = -1;
|
||||
int32_t is_giga_am_ = 0;
|
||||
|
||||
// giga am uses 64
|
||||
// parakeet-tdt-0.6b-v2 uses 128
|
||||
// others use 80
|
||||
int32_t feat_dim_ = -1; // -1 means to use default values.
|
||||
};
|
||||
|
||||
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||
@@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||
|
||||
bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); }
|
||||
|
||||
int32_t OfflineTransducerNeMoModel::FeatureDim() const {
|
||||
return impl_->FeatureDim();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
|
||||
@@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel {
|
||||
|
||||
bool IsGigaAM() const;
|
||||
|
||||
int32_t FeatureDim() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
Reference in New Issue
Block a user