Support zipformer CTC ASR with whisper features. (#2319)
This commit is contained in:
@@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
|
||||
// Number of frames to advance between consecutive chunks.
// Backed by decode_chunk_len_, which is filled from the model's
// "decode_chunk_len" metadata entry.
int32_t ChunkShift() const {
  return decode_chunk_len_;
}
|
||||
|
||||
// Whether this model was flagged as using whisper features.
// The flag defaults to false and is set to true only when the model's
// "feature" metadata entry equals "whisper".
bool UseWhisperFeature() const {
  return use_whisper_feature_;
}
|
||||
|
||||
// Returns the allocator_ pointer stored in this Impl.
OrtAllocator *Allocator() {
  return allocator_;
}
|
||||
|
||||
// Return a vector containing 3 tensors
|
||||
@@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||
|
||||
std::string feature_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
|
||||
if (feature_type == "whisper") {
|
||||
use_whisper_feature_ = true;
|
||||
}
|
||||
|
||||
{
|
||||
auto shape =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
|
||||
@@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
int32_t T_ = 0;
|
||||
int32_t decode_chunk_len_ = 0;
|
||||
int32_t vocab_size_ = 0;
|
||||
|
||||
// for models from
|
||||
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
|
||||
bool use_whisper_feature_ = false;
|
||||
};
|
||||
|
||||
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
|
||||
@@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
|
||||
// Public accessor that forwards to Impl::UseWhisperFeature() through impl_.
bool OnlineZipformer2CtcModel::UseWhisperFeature() const {
|
||||
return impl_->UseWhisperFeature();
|
||||
}
|
||||
|
||||
// Public accessor that forwards to Impl::Allocator() through impl_.
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user