Allow modify model config at decode time for ASR (#1124)
This commit is contained in:
@@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream {
|
|||||||
: impl(std::move(p)) {}
|
: impl(std::move(p)) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
||||||
|
const SherpaOnnxOfflineRecognizerConfig *config);
|
||||||
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||||
const SherpaOnnxOfflineRecognizerConfig *config) {
|
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||||
|
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
||||||
|
convertConfig(config);
|
||||||
|
|
||||||
|
if (!recognizer_config.Validate()) {
|
||||||
|
SHERPA_ONNX_LOGE("Errors in config");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;
|
||||||
|
|
||||||
|
recognizer->impl =
|
||||||
|
std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);
|
||||||
|
|
||||||
|
return recognizer;
|
||||||
|
}
|
||||||
|
sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
||||||
|
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||||
sherpa_onnx::OfflineRecognizerConfig recognizer_config;
|
sherpa_onnx::OfflineRecognizerConfig recognizer_config;
|
||||||
|
|
||||||
recognizer_config.feat_config.sampling_rate =
|
recognizer_config.feat_config.sampling_rate =
|
||||||
@@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
|||||||
SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str());
|
SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!recognizer_config.Validate()) {
|
return recognizer_config;
|
||||||
SHERPA_ONNX_LOGE("Errors in config");
|
}
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;
|
void SherpaOnnxOfflineRecognizerSetConfig(
|
||||||
|
const SherpaOnnxOfflineRecognizer *recognizer,
|
||||||
recognizer->impl =
|
const SherpaOnnxOfflineRecognizerConfig *config){
|
||||||
std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);
|
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
||||||
|
convertConfig(config);
|
||||||
return recognizer;
|
recognizer->impl->SetConfig(recognizer_config);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
|
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
|
||||||
@@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
|||||||
pText[text.size()] = 0;
|
pText[text.size()] = 0;
|
||||||
r->text = pText;
|
r->text = pText;
|
||||||
|
|
||||||
|
//lang
|
||||||
|
const auto &lang = result.lang;
|
||||||
|
char *c_lang = new char[lang.size() + 1];
|
||||||
|
std::copy(lang.begin(), lang.end(), c_lang);
|
||||||
|
c_lang[lang.size()] = '\0';
|
||||||
|
r->lang = c_lang;
|
||||||
|
|
||||||
// copy json
|
// copy json
|
||||||
std::string json = result.AsJsonString();
|
std::string json = result.AsJsonString();
|
||||||
char *pJson = new char[json.size() + 1];
|
char *pJson = new char[json.size() + 1];
|
||||||
@@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult(
|
|||||||
delete[] r->tokens;
|
delete[] r->tokens;
|
||||||
delete[] r->tokens_arr;
|
delete[] r->tokens_arr;
|
||||||
delete[] r->json;
|
delete[] r->json;
|
||||||
|
delete[] r->lang;
|
||||||
delete r;
|
delete r;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream;
|
|||||||
SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||||
const SherpaOnnxOfflineRecognizerConfig *config);
|
const SherpaOnnxOfflineRecognizerConfig *config);
|
||||||
|
|
||||||
|
/// @param config Config for the recognizer.
|
||||||
|
SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig(
|
||||||
|
const SherpaOnnxOfflineRecognizer *recognizer,
|
||||||
|
const SherpaOnnxOfflineRecognizerConfig *config);
|
||||||
|
|
||||||
/// Free a pointer returned by CreateOfflineRecognizer()
|
/// Free a pointer returned by CreateOfflineRecognizer()
|
||||||
///
|
///
|
||||||
/// @param p A pointer returned by CreateOfflineRecognizer()
|
/// @param p A pointer returned by CreateOfflineRecognizer()
|
||||||
@@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
|||||||
* }
|
* }
|
||||||
*/
|
*/
|
||||||
const char *json;
|
const char *json;
|
||||||
|
|
||||||
|
//return recognized language
|
||||||
|
const char *lang;
|
||||||
|
|
||||||
} SherpaOnnxOfflineRecognizerResult;
|
} SherpaOnnxOfflineRecognizerResult;
|
||||||
|
|
||||||
/// Get the result of the offline stream.
|
/// Get the result of the offline stream.
|
||||||
|
|||||||
@@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Decode a single stream.
|
// Decode a single stream.
|
||||||
// Some models do not support batch size > 1, e.g., WeNet CTC models.
|
// Some models do not support batch size > 1, e.g., WeNet CTC models.
|
||||||
|
|||||||
@@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
|||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
|
||||||
|
config_ = config;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ class OfflineRecognizerImpl {
|
|||||||
|
|
||||||
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
||||||
|
|
||||||
|
virtual void SetConfig(const OfflineRecognizerConfig &config);
|
||||||
|
|
||||||
|
virtual OfflineRecognizerConfig GetConfig() const = 0;
|
||||||
|
|
||||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
||||||
int32_t lfr_window_size = model_->LfrWindowSize();
|
int32_t lfr_window_size = model_->LfrWindowSize();
|
||||||
|
|||||||
@@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void InitHotwords() {
|
void InitHotwords() {
|
||||||
// each line in hotwords_file contains space-separated words
|
// each line in hotwords_file contains space-separated words
|
||||||
|
|
||||||
|
|||||||
@@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void PostInit() {
|
void PostInit() {
|
||||||
config_.feat_config.nemo_normalize_type =
|
config_.feat_config.nemo_normalize_type =
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.text = text;
|
r.text = text;
|
||||||
|
r.lang = src.lang;
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
@@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetConfig(const OfflineRecognizerConfig &config) override {
|
||||||
|
config_.model_config.whisper = config.model_config.whisper;
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void DecodeStream(OfflineStream *s) const {
|
void DecodeStream(OfflineStream *s) const {
|
||||||
|
decoder_->SetConfig(config_.model_config.whisper);
|
||||||
|
|
||||||
int32_t max_num_frames = 3000;
|
int32_t max_num_frames = 3000;
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|||||||
@@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
|
|||||||
impl_->DecodeStreams(ss, n);
|
impl_->DecodeStreams(ss, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
|
||||||
|
impl_->SetConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
|
||||||
|
return impl_->GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -119,6 +119,15 @@ class OfflineRecognizer {
|
|||||||
*/
|
*/
|
||||||
void DecodeStreams(OfflineStream **ss, int32_t n) const;
|
void DecodeStreams(OfflineStream **ss, int32_t n) const;
|
||||||
|
|
||||||
|
/** Onnxruntime Session objects are not affected by this method.
|
||||||
|
* The exact behavior can be defined by a specific recognizer impl.
|
||||||
|
* For instance, for the whisper recognizer, you can retrieve the language and task from
|
||||||
|
* the config and ignore any remaining fields in `config`.
|
||||||
|
*/
|
||||||
|
void SetConfig(const OfflineRecognizerConfig &config);
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<OfflineRecognizerImpl> impl_;
|
std::unique_ptr<OfflineRecognizerImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ struct OfflineRecognitionResult {
|
|||||||
// For instance, for BPE-based models it consists of a list of BPE tokens.
|
// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||||
std::vector<std::string> tokens;
|
std::vector<std::string> tokens;
|
||||||
|
|
||||||
|
std::string lang;
|
||||||
|
|
||||||
/// timestamps.size() == tokens.size()
|
/// timestamps.size() == tokens.size()
|
||||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||||
std::vector<float> timestamps;
|
std::vector<float> timestamps;
|
||||||
|
|||||||
@@ -6,14 +6,17 @@
|
|||||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
struct OfflineWhisperDecoderResult {
|
struct OfflineWhisperDecoderResult {
|
||||||
/// The decoded token IDs
|
/// The decoded token IDs
|
||||||
std::vector<int32_t> tokens;
|
std::vector<int32_t> tokens;
|
||||||
|
std::string lang;
|
||||||
};
|
};
|
||||||
|
|
||||||
class OfflineWhisperDecoder {
|
class OfflineWhisperDecoder {
|
||||||
@@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
|
|||||||
*/
|
*/
|
||||||
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
||||||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||||
|
|
||||||
|
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -12,6 +12,10 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
|
||||||
|
config_ = config;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult>
|
std::vector<OfflineWhisperDecoderResult>
|
||||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) {
|
Ort::Value cross_v) {
|
||||||
@@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult> ans(1);
|
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||||
|
|
||||||
|
const auto &id2lang = model_->GetID2Lang();
|
||||||
|
if (id2lang.count(initial_tokens[1])) {
|
||||||
|
ans[0].lang = id2lang.at(initial_tokens[1]);
|
||||||
|
} else {
|
||||||
|
ans[0].lang = "";
|
||||||
|
}
|
||||||
|
|
||||||
ans[0].tokens = std::move(predicted_tokens);
|
ans[0].tokens = std::move(predicted_tokens);
|
||||||
|
|
||||||
return ans;
|
return ans;
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
|||||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) override;
|
Ort::Value cross_v) override;
|
||||||
|
|
||||||
|
void SetConfig(const OfflineWhisperModelConfig &config) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineWhisperModelConfig config_;
|
OfflineWhisperModelConfig config_;
|
||||||
OfflineWhisperModel *model_; // not owned
|
OfflineWhisperModel *model_; // not owned
|
||||||
|
|||||||
Reference in New Issue
Block a user