Allow modify model config at decode time for ASR (#1124)

This commit is contained in:
ivan provalov
2024-07-13 07:30:47 -07:00
committed by GitHub
parent ab71c3976d
commit de04b3b9bf
15 changed files with 121 additions and 13 deletions

View File

@@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}
// Returns a copy of the recognizer's current configuration.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.

View File

@@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
return text;
}
// Default implementation: replace the stored configuration wholesale.
// Onnxruntime Session objects are not affected; only config_ changes.
// Subclasses (e.g., the whisper impl) may override to apply just a
// subset of the fields.
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
config_ = config;
}
} // namespace sherpa_onnx

View File

@@ -48,6 +48,10 @@ class OfflineRecognizerImpl {
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
virtual void SetConfig(const OfflineRecognizerConfig &config);
virtual OfflineRecognizerConfig GetConfig() const = 0;
std::string ApplyInverseTextNormalization(std::string text) const;
private:

View File

@@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
}
}
// Returns a copy of the recognizer's current configuration.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
private:
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_->LfrWindowSize();

View File

@@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
// Returns a copy of the recognizer's current configuration.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
void InitHotwords() {
// each line in hotwords_file contains space-separated words

View File

@@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
}
}
// Returns a copy of the recognizer's current configuration.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
private:
void PostInit() {
config_.feat_config.nemo_normalize_type =

View File

@@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
}
r.text = text;
r.lang = src.lang;
return r;
}
@@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}
// Apply only the whisper-specific model options (e.g., language, task)
// from `config`; every other field of `config` is deliberately ignored.
// The decoder picks the updated options up at decode time via
// decoder_->SetConfig(config_.model_config.whisper).
void SetConfig(const OfflineRecognizerConfig &config) override {
config_.model_config.whisper = config.model_config.whisper;
}
// Returns a copy of the recognizer's current configuration.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
private:
void DecodeStream(OfflineStream *s) const {
decoder_->SetConfig(config_.model_config.whisper);
int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

View File

@@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
// Forwards to the underlying implementation. Which fields of `config`
// actually take effect is impl-specific (the whisper impl, for example,
// applies only the whisper model options).
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
impl_->SetConfig(config);
}
// Returns a copy of the implementation's current configuration.
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
return impl_->GetConfig();
}
} // namespace sherpa_onnx

View File

@@ -119,6 +119,15 @@ class OfflineRecognizer {
*/
void DecodeStreams(OfflineStream **ss, int32_t n) const;
/** Onnxruntime Session objects are not affected by this method.
* The exact behavior can be defined by a specific recognizer impl.
* For instance, for the whisper recognizer, you can retrieve the language and task from
* the config and ignore any remaining fields in `config`.
*/
void SetConfig(const OfflineRecognizerConfig &config);
OfflineRecognizerConfig GetConfig() const;
private:
std::unique_ptr<OfflineRecognizerImpl> impl_;
};

View File

@@ -26,7 +26,9 @@ struct OfflineRecognitionResult {
// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// The detected/decoded language, e.g., for whisper multilingual models.
std::string lang;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

View File

@@ -6,14 +6,17 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
struct OfflineWhisperDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
std::string lang;
};
class OfflineWhisperDecoder {
@@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
};
} // namespace sherpa_onnx

View File

@@ -12,6 +12,10 @@
namespace sherpa_onnx {
// Store the whisper model options to be used by subsequent Decode()
// calls; called by the recognizer before each decode so that config
// changes (e.g., language/task) take effect at decode time.
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
config_ = config;
}
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
@@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std::vector<OfflineWhisperDecoderResult> ans(1);
const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(initial_tokens[1])) {
ans[0].lang = id2lang.at(initial_tokens[1]);
} else {
ans[0].lang = "";
}
ans[0].tokens = std::move(predicted_tokens);
return ans;

View File

@@ -8,7 +8,6 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace sherpa_onnx {
@@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
void SetConfig(const OfflineWhisperModelConfig &config) override;
private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned