diff --git a/dotnet-examples/offline-decode-files/Program.cs b/dotnet-examples/offline-decode-files/Program.cs index 9cda985e..da90ccda 100644 --- a/dotnet-examples/offline-decode-files/Program.cs +++ b/dotnet-examples/offline-decode-files/Program.cs @@ -40,6 +40,12 @@ class OfflineDecodeFiles [Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")] public string WhisperDecoder { get; set; } + [Option("whisper-language", Required = false, Default = "", HelpText = "Language of the input file. Can be empty")] + public string WhisperLanguage{ get; set; } + + [Option("whisper-task", Required = false, Default = "transcribe", HelpText = "transcribe or translate")] + public string WhisperTask{ get; set; } + [Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")] public string TdnnModel { get; set; } @@ -193,6 +199,8 @@ to download pre-trained Tdnn models. { config.ModelConfig.Whisper.Encoder = options.WhisperEncoder; config.ModelConfig.Whisper.Decoder = options.WhisperDecoder; + config.ModelConfig.Whisper.Language = options.WhisperLanguage; + config.ModelConfig.Whisper.Task = options.WhisperTask; } else if (!String.IsNullOrEmpty(options.TdnnModel)) { diff --git a/go-api-examples/non-streaming-decode-files/main.go b/go-api-examples/non-streaming-decode-files/main.go index 651b06e1..d4f076c5 100644 --- a/go-api-examples/non-streaming-decode-files/main.go +++ b/go-api-examples/non-streaming-decode-files/main.go @@ -29,6 +29,8 @@ func main() { flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model") flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model") + flag.StringVar(&config.ModelConfig.Whisper.Language, "whisper-language", "", "Language of the input wave. You can leave it empty ") + flag.StringVar(&config.ModelConfig.Whisper.Task, "whisper-task", "transcribe", "transcribe or translate") flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model") diff --git a/nodejs-examples/test-offline-nemo-ctc.js b/nodejs-examples/test-offline-nemo-ctc.js index 46fb869a..c657fd23 100644 --- a/nodejs-examples/test-offline-nemo-ctc.js +++ b/nodejs-examples/test-offline-nemo-ctc.js @@ -27,6 +27,8 @@ function createOfflineRecognizer() { whisper: { encoder: '', decoder: '', + language: '', + task: '', }, tdnn: { model: '', diff --git a/nodejs-examples/test-offline-paraformer.js b/nodejs-examples/test-offline-paraformer.js index a7d6b63e..175b227e 100644 --- a/nodejs-examples/test-offline-paraformer.js +++ b/nodejs-examples/test-offline-paraformer.js @@ -27,6 +27,8 @@ function createOfflineRecognizer() { whisper: { encoder: '', decoder: '', + language: '', + task: '', }, tdnn: { model: '', diff --git a/nodejs-examples/test-offline-transducer.js b/nodejs-examples/test-offline-transducer.js index 46bdf23d..289c01dc 100644 --- a/nodejs-examples/test-offline-transducer.js +++ b/nodejs-examples/test-offline-transducer.js @@ -30,6 +30,8 @@ function createOfflineRecognizer() { whisper: { encoder: '', decoder: '', + language: '', + task: '', }, tdnn: { model: '', diff --git a/nodejs-examples/test-offline-whisper.js b/nodejs-examples/test-offline-whisper.js index 1012ce15..28b101ae 100644 --- a/nodejs-examples/test-offline-whisper.js +++ b/nodejs-examples/test-offline-whisper.js @@ -27,6 +27,8 @@ function createOfflineRecognizer() { whisper: { encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx', decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', + language: '', + task: 'transcribe', }, tdnn: { model: '', diff --git a/scripts/dotnet/offline.cs b/scripts/dotnet/offline.cs index 3c8c0e05..4ef2a4a1 100644 --- a/scripts/dotnet/offline.cs +++ b/scripts/dotnet/offline.cs @@ -279,12 +279,20 @@ namespace SherpaOnnx { Encoder = ""; Decoder = ""; + Language = ""; + Task = "transcribe"; } [MarshalAs(UnmanagedType.LPStr)] public string Encoder; [MarshalAs(UnmanagedType.LPStr)] public string Decoder; + + [MarshalAs(UnmanagedType.LPStr)] + public string Language; + + [MarshalAs(UnmanagedType.LPStr)] + public string Task; } [StructLayout(LayoutKind.Sequential)] diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 9a869e25..c5037241 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct { } type OfflineWhisperModelConfig struct { - Encoder string - Decoder string + Encoder string + Decoder string + Language string + Task string } type OfflineTdnnModelConfig struct { @@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) + c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language) + defer C.free(unsafe.Pointer(c.model_config.whisper.language)) + + c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) + defer C.free(unsafe.Pointer(c.model_config.whisper.task)) + c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 965240cb..9ef5ad25 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -11,13 +11,13 @@ #include "sherpa-onnx/csrc/circular-buffer.h" #include "sherpa-onnx/csrc/display.h" +#include "sherpa-onnx/csrc/keyword-spotter.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-tts.h" #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/voice-activity-detector.h" #include "sherpa-onnx/csrc/wave-writer.h" -#include "sherpa-onnx/csrc/keyword-spotter.h" struct SherpaOnnxOnlineRecognizer { std::unique_ptr impl; @@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.model_config.whisper.language = SHERPA_ONNX_OR(config->model_config.whisper.language, ""); + recognizer_config.model_config.whisper.task = + SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe"); + recognizer_config.model_config.tdnn.model = SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); @@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter { std::unique_ptr impl; }; -SherpaOnnxKeywordSpotter* CreateKeywordSpotter( - const SherpaOnnxKeywordSpotterConfig* config) { +SherpaOnnxKeywordSpotter *CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig *config) { sherpa_onnx::KeywordSpotterConfig spotter_config; spotter_config.feat_config.sampling_rate = @@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( spotter_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); - spotter_config.max_active_paths = - SHERPA_ONNX_OR(config->max_active_paths, 4); + spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); spotter_config.num_trailing_blanks = - SHERPA_ONNX_OR(config->num_trailing_blanks , 1); + SHERPA_ONNX_OR(config->num_trailing_blanks, 1); - spotter_config.keywords_score = - SHERPA_ONNX_OR(config->keywords_score, 1.0); + spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0); spotter_config.keywords_threshold = SHERPA_ONNX_OR(config->keywords_threshold, 0.25); - spotter_config.keywords_file = - SHERPA_ONNX_OR(config->keywords_file, ""); + spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); if (config->model_config.debug) { SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); @@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( return nullptr; } - SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; + SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter; - spotter->impl = - std::make_unique(spotter_config); + spotter->impl = std::make_unique(spotter_config); return spotter; } -void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) { delete spotter; } -SherpaOnnxOnlineStream* CreateKeywordStream( - const SherpaOnnxKeywordSpotter* spotter) { - SherpaOnnxOnlineStream* stream = +SherpaOnnxOnlineStream *CreateKeywordStream( + const SherpaOnnxKeywordSpotter *spotter) { + SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); return stream; } -int32_t IsKeywordStreamReady( - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { +int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, + SherpaOnnxOnlineStream *stream) { return spotter->impl->IsReady(stream->impl.get()); } -void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, - SherpaOnnxOnlineStream* stream) { +void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, + SherpaOnnxOnlineStream *stream) { return spotter->impl->DecodeStream(stream->impl.get()); } -void DecodeMultipleKeywordStreams( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, - int32_t n) { - std::vector ss(n); +void DecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, + SherpaOnnxOnlineStream **streams, int32_t n) { + std::vector ss(n); for (int32_t i = 0; i != n; ++i) { ss[i] = streams[i]->impl.get(); } @@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams( const SherpaOnnxKeywordResult *GetKeywordResult( SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { - const sherpa_onnx::KeywordResult& result = + const sherpa_onnx::KeywordResult &result = spotter->impl->GetResult(stream->impl.get()); const auto &keyword = result.keyword; diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index b02a3420..e8e59ae6 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { const char *encoder; const char *decoder; const char *language; + const char *task; } SherpaOnnxOfflineWhisperModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { @@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { /// For Chinese, it consists of Chinese words without spaces. /// Example 1: "hello world" /// Example 2: "你好世界" - const char* keyword; + const char *keyword; /// Decoded results at the token level. /// For instance, for BPE-based models it consists of a list of BPE tokens. - const char* tokens; + const char *tokens; - const char* const* tokens_arr; + const char *const *tokens_arr; int32_t count; /// timestamps.size() == tokens.size() /// timestamps[i] records the time in seconds when tokens[i] is decoded. - float* timestamps; + float *timestamps; /// Starting time of this segment. /// When an endpoint is detected, it will change @@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { * "start_time": x, * } */ - const char* json; + const char *json; } SherpaOnnxKeywordResult; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { @@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { int32_t num_trailing_blanks; float keywords_score; float keywords_threshold; - const char* keywords_file; + const char *keywords_file; } SherpaOnnxKeywordSpotterConfig; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter @@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter /// @param config Config for the keyword spotter. /// @return Return a pointer to the spotter. The user has to invoke /// DestroyKeywordSpotter() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter( - const SherpaOnnxKeywordSpotterConfig* config); +SHERPA_ONNX_API SherpaOnnxKeywordSpotter *CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig *config); /// Free a pointer returned by CreateKeywordSpotter() /// /// @param p A pointer returned by CreateKeywordSpotter() -SHERPA_ONNX_API void DestroyKeywordSpotter( - SherpaOnnxKeywordSpotter* spotter); +SHERPA_ONNX_API void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter); /// Create an online stream for accepting wave samples. /// /// @param spotter A pointer returned by CreateKeywordSpotter() /// @return Return a pointer to an OnlineStream. The user has to invoke /// DestroyOnlineStream() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream( - const SherpaOnnxKeywordSpotter* spotter); +SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateKeywordStream( + const SherpaOnnxKeywordSpotter *spotter); /// Return 1 if there are enough number of feature frames for decoding. /// Return 0 otherwise. /// /// @param spotter A pointer returned by CreateKeywordSpotter /// @param stream A pointer returned by CreateKeywordStream -SHERPA_ONNX_API int32_t IsKeywordStreamReady( - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); +SHERPA_ONNX_API int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, + SherpaOnnxOnlineStream *stream); /// Call this function to run the neural network model and decoding. // /// Precondition for this function: IsKeywordStreamReady() MUST return 1. -SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, - SherpaOnnxOnlineStream* stream); +SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, + SherpaOnnxOnlineStream *stream); /// This function is similar to DecodeKeywordStream(). It decodes multiple /// OnlineStream in parallel. @@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( /// Destroy the pointer returned by GetKeywordResult(). /// /// @param r A pointer returned by GetKeywordResult() -SHERPA_ONNX_API void DestroyKeywordResult( - const SherpaOnnxKeywordResult *r); +SHERPA_ONNX_API void DestroyKeywordResult(const SherpaOnnxKeywordResult *r); // ============================================================ // For VAD diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index c41a193d..c55e72f5 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl { inputs.push_back(std::move(length_scale_tensor)); inputs.push_back(std::move(noise_scale_w_tensor)); - if (input_names_.size() == 6 && input_names_.back() == "sid") { + if (input_names_.size() == 6 && + (input_names_.back() == "sid" || input_names_.back() == "speaker")) { inputs.push_back(std::move(sid_tensor)); } diff --git a/sherpa-onnx/csrc/transducer-keyword-decoder.cc b/sherpa-onnx/csrc/transducer-keyword-decoder.cc index ef8314ed..f31348ea 100644 --- a/sherpa-onnx/csrc/transducer-keyword-decoder.cc +++ b/sherpa-onnx/csrc/transducer-keyword-decoder.cc @@ -2,14 +2,16 @@ // // Copyright (c) 2023-2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" + #include #include +#include #include #include #include "sherpa-onnx/csrc/log.h" #include "sherpa-onnx/csrc/onnx-utils.h" -#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" namespace sherpa_onnx { diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 8a5ea907..3742b9a7 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -242,17 +242,17 @@ class SherpaOnnxRecognizer { /// the given hotWords appended to the default hotwords. func reset(hotwords: String? = nil) { guard let words = hotwords, !words.isEmpty else { - Reset(recognizer, stream) - return + Reset(recognizer, stream) + return } - + words.withCString { cString in - let newStream = CreateOnlineStreamWithHotwords(recognizer, cString) - // lock while release and replace stream - objc_sync_enter(self) - DestroyOnlineStream(stream) - stream = newStream - objc_sync_exit(self) + let newStream = CreateOnlineStreamWithHotwords(recognizer, cString) + // lock while release and replace stream + objc_sync_enter(self) + DestroyOnlineStream(stream) + stream = newStream + objc_sync_exit(self) } } @@ -300,11 +300,15 @@ func sherpaOnnxOfflineNemoEncDecCtcModelConfig( func sherpaOnnxOfflineWhisperModelConfig( encoder: String = "", - decoder: String = "" + decoder: String = "", + language: String = "", + task: String = "transcribe" ) -> SherpaOnnxOfflineWhisperModelConfig { return SherpaOnnxOfflineWhisperModelConfig( encoder: toCPointer(encoder), - decoder: toCPointer(decoder) + decoder: toCPointer(decoder), + language: toCPointer(language), + task: toCPointer(task) ) } diff --git a/wasm/asr/sherpa-onnx-asr.js b/wasm/asr/sherpa-onnx-asr.js index 97b19783..55c8f2d9 100644 --- a/wasm/asr/sherpa-onnx-asr.js +++ b/wasm/asr/sherpa-onnx-asr.js @@ -393,11 +393,13 @@ function initSherpaOnnxOfflineNemoEncDecCtcModelConfig(config, Module) { function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1; const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1; + const languageLen = Module.lengthBytesUTF8(config.language) + 1; + const taskLen = Module.lengthBytesUTF8(config.task) + 1; - const n = encoderLen + decoderLen; + const n = encoderLen + decoderLen + languageLen + taskLen; const buffer = Module._malloc(n); - const len = 2 * 4; // 2 pointers + const len = 4 * 4; // 4 pointers const ptr = Module._malloc(len); let offset = 0; @@ -405,12 +407,25 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { offset += encoderLen; Module.stringToUTF8(config.decoder, buffer + offset, decoderLen); + offset += decoderLen; + + Module.stringToUTF8(config.language, buffer + offset, languageLen); + offset += languageLen; + + Module.stringToUTF8(config.task, buffer + offset, taskLen); offset = 0; Module.setValue(ptr, buffer + offset, 'i8*'); offset += encoderLen; Module.setValue(ptr + 4, buffer + offset, 'i8*'); + offset += decoderLen; + + Module.setValue(ptr + 8, buffer + offset, 'i8*'); + offset += languageLen; + + Module.setValue(ptr + 12, buffer + offset, 'i8*'); + offset += taskLen; return { buffer: buffer, ptr: ptr, len: len, diff --git a/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc b/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc index edbf250e..22f770ae 100644 --- a/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc +++ b/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); -static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 2 * 4, ""); +static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, ""); static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); @@ -77,6 +77,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { fprintf(stdout, "----------offline whisper model config----------\n"); fprintf(stdout, "encoder: %s\n", whisper->encoder); fprintf(stdout, "decoder: %s\n", whisper->decoder); + fprintf(stdout, "language: %s\n", whisper->language); + fprintf(stdout, "task: %s\n", whisper->task); fprintf(stdout, "----------offline tdnn model config----------\n"); fprintf(stdout, "model: %s\n", tdnn->model);