Add C++ runtime for SenseVoice models (#1148)
This commit is contained in:
25
.github/scripts/test-offline-ctc.sh
vendored
25
.github/scripts/test-offline-ctc.sh
vendored
@@ -15,7 +15,30 @@ echo "PATH: $PATH"
|
|||||||
|
|
||||||
which $EXE
|
which $EXE
|
||||||
|
|
||||||
if false; then
|
log "------------------------------------------------------------"
|
||||||
|
log "Run SenseVoice models"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
repo=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
|
||||||
|
|
||||||
|
for m in model.onnx model.int8.onnx; do
|
||||||
|
for w in zh en yue ja ko; do
|
||||||
|
for use_itn in 0 1; do
|
||||||
|
echo "$m $w $use_itn"
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--sense-voice-model=$repo/$m \
|
||||||
|
--sense-voice-use-itn=$use_itn \
|
||||||
|
$repo/test_wavs/$w.wav
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
if true; then
|
||||||
# It has problems with onnxruntime 1.18
|
# It has problems with onnxruntime 1.18
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
log "Run Wenet models"
|
log "Run Wenet models"
|
||||||
|
|||||||
12
.github/scripts/test-python.sh
vendored
12
.github/scripts/test-python.sh
vendored
@@ -10,6 +10,18 @@ log() {
|
|||||||
|
|
||||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
|
||||||
|
log "test offline SenseVoice CTC"
|
||||||
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
name=$(basename $url)
|
||||||
|
repo=$(basename -s .tar.bz2 $name)
|
||||||
|
|
||||||
|
curl -SL -O $url
|
||||||
|
tar xvf $name
|
||||||
|
rm $name
|
||||||
|
ls -lh $repo
|
||||||
|
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
log "test offline TeleSpeech CTC"
|
log "test offline TeleSpeech CTC"
|
||||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
|
||||||
name=$(basename $url)
|
name=$(basename $url)
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ jobs:
|
|||||||
echo "pwd: $PWD"
|
echo "pwd: $PWD"
|
||||||
ls -lh ../scripts/sense-voice
|
ls -lh ../scripts/sense-voice
|
||||||
|
|
||||||
rm -rf ./
|
rm -rf ./*
|
||||||
|
|
||||||
cp -v ../scripts/sense-voice/*.onnx .
|
cp -v ../scripts/sense-voice/*.onnx .
|
||||||
cp -v ../scripts/sense-voice/tokens.txt .
|
cp -v ../scripts/sense-voice/tokens.txt .
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -111,3 +111,4 @@ sherpa-onnx-telespeech-ctc-*
|
|||||||
*.fst
|
*.fst
|
||||||
.ccache
|
.ccache
|
||||||
lib*.a
|
lib*.a
|
||||||
|
sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
## 1.10.17
|
||||||
|
|
||||||
|
* Support SenseVoice CTC models.
|
||||||
|
|
||||||
## 1.10.16
|
## 1.10.16
|
||||||
|
|
||||||
* Support zh-en TTS model from MeloTTS.
|
* Support zh-en TTS model from MeloTTS.
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ project(sherpa-onnx)
|
|||||||
# ./nodejs-addon-examples
|
# ./nodejs-addon-examples
|
||||||
# ./dart-api-examples/
|
# ./dart-api-examples/
|
||||||
# ./CHANGELOG.md
|
# ./CHANGELOG.md
|
||||||
set(SHERPA_ONNX_VERSION "1.10.16")
|
set(SHERPA_ONNX_VERSION "1.10.17")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
67
python-api-examples/offline-sense-voice-ctc-decode-files.py
Normal file
67
python-api-examples/offline-sense-voice-ctc-decode-files.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming SenseVoice CTC model from
|
||||||
|
https://github.com/FunAudioLLM/SenseVoice
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
For instance,
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
|
||||||
|
model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
|
||||||
|
tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
|
||||||
|
test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav"
|
||||||
|
|
||||||
|
if not Path(model).is_file() or not Path(test_wav).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||||
|
model=model,
|
||||||
|
tokens=tokens,
|
||||||
|
use_itn=True,
|
||||||
|
debug=True,
|
||||||
|
),
|
||||||
|
test_wav,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
recognizer, wave_filename = create_recognizer()
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
|
||||||
|
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||||
|
# sample_rate does not need to be 16000 Hz
|
||||||
|
|
||||||
|
stream = recognizer.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate, audio)
|
||||||
|
recognizer.decode_stream(stream)
|
||||||
|
print(wave_filename)
|
||||||
|
print(stream.result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -162,7 +162,9 @@ def main():
|
|||||||
"neg_mean": neg_mean,
|
"neg_mean": neg_mean,
|
||||||
"inv_stddev": inv_stddev,
|
"inv_stddev": inv_stddev,
|
||||||
"model_type": "sense_voice_ctc",
|
"model_type": "sense_voice_ctc",
|
||||||
"version": "1",
|
# version 1: Use QInt8
|
||||||
|
# version 2: Use QUInt8
|
||||||
|
"version": "2",
|
||||||
"model_author": "iic",
|
"model_author": "iic",
|
||||||
"maintainer": "k2-fsa",
|
"maintainer": "k2-fsa",
|
||||||
"vocab_size": vocab_size,
|
"vocab_size": vocab_size,
|
||||||
@@ -185,7 +187,10 @@ def main():
|
|||||||
model_input=filename,
|
model_input=filename,
|
||||||
model_output=filename_int8,
|
model_output=filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul"],
|
||||||
weight_type=QuantType.QInt8,
|
# Note that we have to use QUInt8 here.
|
||||||
|
#
|
||||||
|
# When QInt8 is used, C++ onnxruntime produces incorrect results
|
||||||
|
weight_type=QuantType.QUInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -310,6 +310,7 @@ struct SherpaOnnxOfflineStream {
|
|||||||
|
|
||||||
static sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
static sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
||||||
const SherpaOnnxOfflineRecognizerConfig *config);
|
const SherpaOnnxOfflineRecognizerConfig *config);
|
||||||
|
|
||||||
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||||
const SherpaOnnxOfflineRecognizerConfig *config) {
|
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||||
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
||||||
@@ -391,6 +392,15 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
|||||||
recognizer_config.model_config.telespeech_ctc =
|
recognizer_config.model_config.telespeech_ctc =
|
||||||
SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");
|
SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");
|
||||||
|
|
||||||
|
recognizer_config.model_config.sense_voice.model =
|
||||||
|
SHERPA_ONNX_OR(config->model_config.sense_voice.model, "");
|
||||||
|
|
||||||
|
recognizer_config.model_config.sense_voice.language =
|
||||||
|
SHERPA_ONNX_OR(config->model_config.sense_voice.language, "");
|
||||||
|
|
||||||
|
recognizer_config.model_config.sense_voice.use_itn =
|
||||||
|
config->model_config.sense_voice.use_itn;
|
||||||
|
|
||||||
recognizer_config.lm_config.model =
|
recognizer_config.lm_config.model =
|
||||||
SHERPA_ONNX_OR(config->lm_config.model, "");
|
SHERPA_ONNX_OR(config->lm_config.model, "");
|
||||||
recognizer_config.lm_config.scale =
|
recognizer_config.lm_config.scale =
|
||||||
|
|||||||
@@ -379,6 +379,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig {
|
|||||||
float scale;
|
float scale;
|
||||||
} SherpaOnnxOfflineLMConfig;
|
} SherpaOnnxOfflineLMConfig;
|
||||||
|
|
||||||
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSenseVoiceModelConfig {
|
||||||
|
const char *model;
|
||||||
|
const char *language;
|
||||||
|
int32_t use_itn;
|
||||||
|
} SherpaOnnxOfflineSenseVoiceModelConfig;
|
||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
|
||||||
SherpaOnnxOfflineTransducerModelConfig transducer;
|
SherpaOnnxOfflineTransducerModelConfig transducer;
|
||||||
SherpaOnnxOfflineParaformerModelConfig paraformer;
|
SherpaOnnxOfflineParaformerModelConfig paraformer;
|
||||||
@@ -398,6 +404,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
|
|||||||
const char *modeling_unit;
|
const char *modeling_unit;
|
||||||
const char *bpe_vocab;
|
const char *bpe_vocab;
|
||||||
const char *telespeech_ctc;
|
const char *telespeech_ctc;
|
||||||
|
SherpaOnnxOfflineSenseVoiceModelConfig sense_voice;
|
||||||
} SherpaOnnxOfflineModelConfig;
|
} SherpaOnnxOfflineModelConfig;
|
||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ set(sources
|
|||||||
offline-recognizer-impl.cc
|
offline-recognizer-impl.cc
|
||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
offline-rnn-lm.cc
|
offline-rnn-lm.cc
|
||||||
|
offline-sense-voice-model-config.cc
|
||||||
|
offline-sense-voice-model.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
offline-tdnn-ctc-model.cc
|
offline-tdnn-ctc-model.cc
|
||||||
offline-tdnn-model-config.cc
|
offline-tdnn-model-config.cc
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
|
// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2024 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||||
const OfflineModelConfig &config) {
|
const OfflineModelConfig &config) {
|
||||||
|
// TODO(fangjun): Refactor it. We don't need to use model_type here
|
||||||
ModelType model_type = ModelType::kUnknown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
std::string filename;
|
std::string filename;
|
||||||
@@ -148,6 +149,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
|
|
||||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||||
AAssetManager *mgr, const OfflineModelConfig &config) {
|
AAssetManager *mgr, const OfflineModelConfig &config) {
|
||||||
|
// TODO(fangjun): Refactor it. We don't need to use model_type here
|
||||||
ModelType model_type = ModelType::kUnknown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
std::string filename;
|
std::string filename;
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
tdnn.Register(po);
|
tdnn.Register(po);
|
||||||
zipformer_ctc.Register(po);
|
zipformer_ctc.Register(po);
|
||||||
wenet_ctc.Register(po);
|
wenet_ctc.Register(po);
|
||||||
|
sense_voice.Register(po);
|
||||||
|
|
||||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||||
"Path to model.onnx for telespeech ctc");
|
"Path to model.onnx for telespeech ctc");
|
||||||
@@ -94,15 +95,21 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return wenet_ctc.Validate();
|
return wenet_ctc.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!sense_voice.model.empty()) {
|
||||||
|
return sense_voice.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||||
telespeech_ctc.c_str());
|
telespeech_ctc.c_str());
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
if (!transducer.encoder_filename.empty()) {
|
||||||
|
return transducer.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string OfflineModelConfig::ToString() const {
|
std::string OfflineModelConfig::ToString() const {
|
||||||
@@ -116,6 +123,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||||
|
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
||||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||||
@@ -24,6 +25,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineTdnnModelConfig tdnn;
|
OfflineTdnnModelConfig tdnn;
|
||||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||||
OfflineWenetCtcModelConfig wenet_ctc;
|
OfflineWenetCtcModelConfig wenet_ctc;
|
||||||
|
OfflineSenseVoiceModelConfig sense_voice;
|
||||||
std::string telespeech_ctc;
|
std::string telespeech_ctc;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
@@ -53,6 +55,7 @@ struct OfflineModelConfig {
|
|||||||
const OfflineTdnnModelConfig &tdnn,
|
const OfflineTdnnModelConfig &tdnn,
|
||||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||||
|
const OfflineSenseVoiceModelConfig &sense_voice,
|
||||||
const std::string &telespeech_ctc,
|
const std::string &telespeech_ctc,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type,
|
const std::string &provider, const std::string &model_type,
|
||||||
@@ -65,6 +68,7 @@ struct OfflineModelConfig {
|
|||||||
tdnn(tdnn),
|
tdnn(tdnn),
|
||||||
zipformer_ctc(zipformer_ctc),
|
zipformer_ctc(zipformer_ctc),
|
||||||
wenet_ctc(wenet_ctc),
|
wenet_ctc(wenet_ctc),
|
||||||
|
sense_voice(sense_voice),
|
||||||
telespeech_ctc(telespeech_ctc),
|
telespeech_ctc(telespeech_ctc),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
|
|||||||
@@ -212,10 +212,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OfflineRecognizerConfig GetConfig() const override {
|
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||||
return config_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Decode a single stream.
|
// Decode a single stream.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||||
@@ -31,6 +32,28 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||||
const OfflineRecognizerConfig &config) {
|
const OfflineRecognizerConfig &config) {
|
||||||
|
if (!config.model_config.sense_voice.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.paraformer.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.nemo_ctc.model.empty() ||
|
||||||
|
!config.model_config.zipformer_ctc.model.empty() ||
|
||||||
|
!config.model_config.tdnn.model.empty() ||
|
||||||
|
!config.model_config.wenet_ctc.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.whisper.encoder.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||||
|
// following models:
|
||||||
|
// 1. transducer and nemo_transducer
|
||||||
if (!config.model_config.model_type.empty()) {
|
if (!config.model_config.model_type.empty()) {
|
||||||
const auto &model_type = config.model_config.model_type;
|
const auto &model_type = config.model_config.model_type;
|
||||||
if (model_type == "transducer") {
|
if (model_type == "transducer") {
|
||||||
@@ -180,6 +203,28 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||||
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
|
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
|
||||||
|
if (!config.model_config.sense_voice.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.paraformer.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.nemo_ctc.model.empty() ||
|
||||||
|
!config.model_config.zipformer_ctc.model.empty() ||
|
||||||
|
!config.model_config.tdnn.model.empty() ||
|
||||||
|
!config.model_type.wenet_ctc.model.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.whisper.encoder.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||||
|
// following models:
|
||||||
|
// 1. transducer and nemo_transducer
|
||||||
if (!config.model_config.model_type.empty()) {
|
if (!config.model_config.model_type.empty()) {
|
||||||
const auto &model_type = config.model_config.model_type;
|
const auto &model_type = config.model_config.model_type;
|
||||||
if (model_type == "transducer") {
|
if (model_type == "transducer") {
|
||||||
|
|||||||
@@ -102,9 +102,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Paraformer models assume input samples are in the range
|
InitFeatConfig();
|
||||||
// [-32768, 32767], so we set normalize_samples to false
|
|
||||||
config_.feat_config.normalize_samples = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
@@ -124,9 +122,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Paraformer models assume input samples are in the range
|
InitFeatConfig();
|
||||||
// [-32768, 32767], so we set normalize_samples to false
|
|
||||||
config_.feat_config.normalize_samples = false;
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -211,11 +207,18 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OfflineRecognizerConfig GetConfig() const override {
|
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||||
return config_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void InitFeatConfig() {
|
||||||
|
// Paraformer models assume input samples are in the range
|
||||||
|
// [-32768, 32767], so we set normalize_samples to false
|
||||||
|
config_.feat_config.normalize_samples = false;
|
||||||
|
config_.feat_config.window_type = "hamming";
|
||||||
|
config_.feat_config.high_freq = 0;
|
||||||
|
config_.feat_config.snip_edges = true;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
||||||
int32_t lfr_window_size = model_->LfrWindowSize();
|
int32_t lfr_window_size = model_->LfrWindowSize();
|
||||||
int32_t lfr_window_shift = model_->LfrWindowShift();
|
int32_t lfr_window_shift = model_->LfrWindowShift();
|
||||||
|
|||||||
363
sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h
Normal file
363
sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static OfflineRecognitionResult ConvertSenseVoiceResult(
|
||||||
|
const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
|
||||||
|
int32_t frame_shift_ms, int32_t subsampling_factor) {
|
||||||
|
OfflineRecognitionResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
r.timestamps.reserve(src.timestamps.size());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
for (int32_t i = 4; i < src.tokens.size(); ++i) {
|
||||||
|
auto sym = sym_table[src.tokens[i]];
|
||||||
|
text.append(sym);
|
||||||
|
|
||||||
|
r.tokens.push_back(std::move(sym));
|
||||||
|
}
|
||||||
|
r.text = std::move(text);
|
||||||
|
|
||||||
|
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||||
|
|
||||||
|
for (int32_t i = 4; i < src.timestamps.size(); ++i) {
|
||||||
|
float time = frame_shift_s * (src.timestamps[i] - 4);
|
||||||
|
r.timestamps.push_back(time);
|
||||||
|
}
|
||||||
|
|
||||||
|
r.words = std::move(src.words);
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OfflineRecognizerSenseVoiceImpl(
|
||||||
|
const OfflineRecognizerConfig &config)
|
||||||
|
: OfflineRecognizerImpl(config),
|
||||||
|
config_(config),
|
||||||
|
symbol_table_(config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineSenseVoiceModel>(config.model_config)) {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
if (config.decoding_method == "greedy_search") {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
InitFeatConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineRecognizerSenseVoiceImpl(AAssetManager *mgr,
|
||||||
|
const OfflineRecognizerConfig &config)
|
||||||
|
: OfflineRecognizerImpl(mgr, config),
|
||||||
|
config_(config),
|
||||||
|
symbol_table_(mgr, config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineSenseVoiceModel>(mgr,
|
||||||
|
config.model_config)) {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
if (config.decoding_method == "greedy_search") {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
InitFeatConfig();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
if (n == 1) {
|
||||||
|
DecodeOneStream(ss[0]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
// 1. Apply LFR
|
||||||
|
// 2. Apply CMVN
|
||||||
|
//
|
||||||
|
// Please refer to
|
||||||
|
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf
|
||||||
|
// for what LFR means
|
||||||
|
//
|
||||||
|
// "Lower Frame Rate Neural Network Acoustic Models"
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::vector<Ort::Value> features;
|
||||||
|
features.reserve(n);
|
||||||
|
|
||||||
|
int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> features_vec(n);
|
||||||
|
std::vector<int32_t> features_length_vec(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
std::vector<float> f = ss[i]->GetFrames();
|
||||||
|
|
||||||
|
f = ApplyLFR(f);
|
||||||
|
ApplyCMVN(&f);
|
||||||
|
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
features_vec[i] = std::move(f);
|
||||||
|
|
||||||
|
features_length_vec[i] = num_frames;
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
features.push_back(std::move(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const Ort::Value *> features_pointer(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
features_pointer[i] = &features[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int64_t, 1> features_length_shape = {n};
|
||||||
|
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_length_vec.data(), n,
|
||||||
|
features_length_shape.data(), features_length_shape.size());
|
||||||
|
|
||||||
|
// Caution(fangjun): We cannot pad it with log(eps),
|
||||||
|
// i.e., -23.025850929940457f
|
||||||
|
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||||
|
|
||||||
|
int32_t language = 0;
|
||||||
|
if (config_.model_config.sense_voice.language.empty()) {
|
||||||
|
language = 0;
|
||||||
|
} else if (meta_data.lang2id.count(
|
||||||
|
config_.model_config.sense_voice.language)) {
|
||||||
|
language =
|
||||||
|
meta_data.lang2id.at(config_.model_config.sense_voice.language);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
|
||||||
|
config_.model_config.sense_voice.language.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int32_t> language_array(n);
|
||||||
|
std::fill(language_array.begin(), language_array.end(), language);
|
||||||
|
|
||||||
|
std::vector<int32_t> text_norm_array(n);
|
||||||
|
std::fill(text_norm_array.begin(), text_norm_array.end(),
|
||||||
|
config_.model_config.sense_voice.use_itn
|
||||||
|
? meta_data.with_itn_id
|
||||||
|
: meta_data.without_itn_id);
|
||||||
|
|
||||||
|
Ort::Value language_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, language_array.data(), n, features_length_shape.data(),
|
||||||
|
features_length_shape.size());
|
||||||
|
|
||||||
|
Ort::Value text_norm_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, text_norm_array.data(), n, features_length_shape.data(),
|
||||||
|
features_length_shape.size());
|
||||||
|
|
||||||
|
Ort::Value logits{nullptr};
|
||||||
|
try {
|
||||||
|
logits = model_->Forward(std::move(x), std::move(x_length),
|
||||||
|
std::move(language_tensor),
|
||||||
|
std::move(text_norm_tensor));
|
||||||
|
} catch (const Ort::Exception &ex) {
|
||||||
|
SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
|
||||||
|
ex.what());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// decoder_->Decode() requires that logits_length is of dtype int64
|
||||||
|
std::vector<int64_t> features_length_vec_64;
|
||||||
|
features_length_vec_64.reserve(n);
|
||||||
|
for (auto i : features_length_vec) {
|
||||||
|
i += 4;
|
||||||
|
features_length_vec_64.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value logits_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_length_vec_64.data(), n,
|
||||||
|
features_length_shape.data(), features_length_shape.size());
|
||||||
|
|
||||||
|
auto results =
|
||||||
|
decoder_->Decode(std::move(logits), std::move(logits_length));
|
||||||
|
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
int32_t subsampling_factor = meta_data.window_shift;
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
|
||||||
|
frame_shift_ms, subsampling_factor);
|
||||||
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
ss[i]->SetResult(r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void DecodeOneStream(OfflineStream *s) const {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
f = ApplyLFR(f);
|
||||||
|
ApplyCMVN(&f);
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
|
||||||
|
int64_t scale_shape = 1;
|
||||||
|
|
||||||
|
Ort::Value x_length =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &num_frames, 1, &scale_shape, 1);
|
||||||
|
|
||||||
|
int32_t language = 0;
|
||||||
|
if (config_.model_config.sense_voice.language.empty()) {
|
||||||
|
language = 0;
|
||||||
|
} else if (meta_data.lang2id.count(
|
||||||
|
config_.model_config.sense_voice.language)) {
|
||||||
|
language =
|
||||||
|
meta_data.lang2id.at(config_.model_config.sense_voice.language);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
|
||||||
|
config_.model_config.sense_voice.language.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t text_norm = config_.model_config.sense_voice.use_itn
|
||||||
|
? meta_data.with_itn_id
|
||||||
|
: meta_data.without_itn_id;
|
||||||
|
|
||||||
|
Ort::Value language_tensor =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &language, 1, &scale_shape, 1);
|
||||||
|
|
||||||
|
Ort::Value text_norm_tensor =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &text_norm, 1, &scale_shape, 1);
|
||||||
|
|
||||||
|
Ort::Value logits{nullptr};
|
||||||
|
try {
|
||||||
|
logits = model_->Forward(std::move(x), std::move(x_length),
|
||||||
|
std::move(language_tensor),
|
||||||
|
std::move(text_norm_tensor));
|
||||||
|
} catch (const Ort::Exception &ex) {
|
||||||
|
SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
|
||||||
|
ex.what());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t new_num_frames = num_frames + 4;
|
||||||
|
Ort::Value logits_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &new_num_frames, 1, &scale_shape, 1);
|
||||||
|
|
||||||
|
auto results =
|
||||||
|
decoder_->Decode(std::move(logits), std::move(logits_length));
|
||||||
|
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
int32_t subsampling_factor = meta_data.window_shift;
|
||||||
|
auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms,
|
||||||
|
subsampling_factor);
|
||||||
|
|
||||||
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
s->SetResult(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitFeatConfig() {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
config_.feat_config.normalize_samples = meta_data.normalize_samples;
|
||||||
|
config_.feat_config.window_type = "hamming";
|
||||||
|
config_.feat_config.high_freq = 0;
|
||||||
|
config_.feat_config.snip_edges = true;
|
||||||
|
}
|
||||||
|
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
int32_t lfr_window_size = meta_data.window_size;
|
||||||
|
int32_t lfr_window_shift = meta_data.window_shift;
|
||||||
|
int32_t in_feat_dim = config_.feat_config.feature_dim;
|
||||||
|
|
||||||
|
int32_t in_num_frames = in.size() / in_feat_dim;
|
||||||
|
int32_t out_num_frames =
|
||||||
|
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
|
||||||
|
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
|
||||||
|
|
||||||
|
std::vector<float> out(out_num_frames * out_feat_dim);
|
||||||
|
|
||||||
|
const float *p_in = in.data();
|
||||||
|
float *p_out = out.data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != out_num_frames; ++i) {
|
||||||
|
std::copy(p_in, p_in + out_feat_dim, p_out);
|
||||||
|
|
||||||
|
p_out += out_feat_dim;
|
||||||
|
p_in += lfr_window_shift * in_feat_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ApplyCMVN(std::vector<float> *v) const {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
const std::vector<float> &neg_mean = meta_data.neg_mean;
|
||||||
|
const std::vector<float> &inv_stddev = meta_data.inv_stddev;
|
||||||
|
|
||||||
|
int32_t dim = neg_mean.size();
|
||||||
|
int32_t num_frames = v->size() / dim;
|
||||||
|
|
||||||
|
float *p = v->data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_frames; ++i) {
|
||||||
|
for (int32_t k = 0; k != dim; ++k) {
|
||||||
|
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
p += dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig config_;
|
||||||
|
SymbolTable symbol_table_;
|
||||||
|
std::unique_ptr<OfflineSenseVoiceModel> model_;
|
||||||
|
std::unique_ptr<OfflineCtcDecoder> decoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
|
||||||
55
sherpa-onnx/csrc/offline-sense-voice-model-config.cc
Normal file
55
sherpa-onnx/csrc/offline-sense-voice-model-config.cc
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-sense-voice-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("sense-voice-model", &model,
|
||||||
|
"Path to model.onnx of SenseVoice.");
|
||||||
|
po->Register(
|
||||||
|
"sense-voice-language", &language,
|
||||||
|
"Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used");
|
||||||
|
po->Register(
|
||||||
|
"sense-voice-use-itn", &use_itn,
|
||||||
|
"True to enable inverse text normalization. False to disable it.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineSenseVoiceModelConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!language.empty()) {
|
||||||
|
if (language != "auto" && language != "zh" && language != "en" &&
|
||||||
|
language != "ja" && language != "ko" && language != "yue") {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, "
|
||||||
|
"ja, ko, yue. Or you can leave it empty to use 'auto'",
|
||||||
|
language.c_str());
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineSenseVoiceModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineSenseVoiceModelConfig(";
|
||||||
|
os << "model=\"" << model << "\", ";
|
||||||
|
os << "language=\"" << language << "\", ";
|
||||||
|
os << "use_itn=" << (use_itn ? "True" : "False") << ")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
39
sherpa-onnx/csrc/offline-sense-voice-model-config.h
Normal file
39
sherpa-onnx/csrc/offline-sense-voice-model-config.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-sense-voice-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineSenseVoiceModelConfig {
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
// "" or "auto" to let the model recognize the language
|
||||||
|
// valid values:
|
||||||
|
// zh, en, ja, ko, yue, auto
|
||||||
|
std::string language = "auto";
|
||||||
|
|
||||||
|
// true to use inverse text normalization
|
||||||
|
// false to not use inverse text normalization
|
||||||
|
bool use_itn = false;
|
||||||
|
|
||||||
|
OfflineSenseVoiceModelConfig() = default;
|
||||||
|
explicit OfflineSenseVoiceModelConfig(const std::string &model,
|
||||||
|
const std::string &language,
|
||||||
|
bool use_itn)
|
||||||
|
: model(model), language(language), use_itn(use_itn) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
50
sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h
Normal file
50
sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineSenseVoiceModelMetaData {
|
||||||
|
// ID for using inverse text normalization
|
||||||
|
int32_t with_itn_id;
|
||||||
|
|
||||||
|
// ID for not using inverse text normalization
|
||||||
|
int32_t without_itn_id;
|
||||||
|
|
||||||
|
int32_t window_size; // lfr_m
|
||||||
|
int32_t window_shift; // lfr_n
|
||||||
|
int32_t vocab_size;
|
||||||
|
|
||||||
|
int32_t subsampling_factor = 1;
|
||||||
|
|
||||||
|
// Usually 0 for SenseVoice models.
|
||||||
|
// 0 means samples are scaled to [-32768, 32767] before are sent to the
|
||||||
|
// feature extractor
|
||||||
|
int32_t normalize_samples = 0;
|
||||||
|
|
||||||
|
int32_t blank_id = 0;
|
||||||
|
|
||||||
|
// possible values:
|
||||||
|
// zh, en, ja, ko, yue, auto
|
||||||
|
// where
|
||||||
|
// zh is Chinese (Mandarin)
|
||||||
|
// en is English
|
||||||
|
// ja is Japanese
|
||||||
|
// ko is Korean
|
||||||
|
// yue is Cantonese
|
||||||
|
// auto is to let the model recognize the language
|
||||||
|
std::unordered_map<std::string, int32_t> lang2id;
|
||||||
|
|
||||||
|
std::vector<float> neg_mean;
|
||||||
|
std::vector<float> inv_stddev;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
|
||||||
156
sherpa-onnx/csrc/offline-sense-voice-model.cc
Normal file
156
sherpa-onnx/csrc/offline-sense-voice-model.cc
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-sense-voice-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineSenseVoiceModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
auto buf = ReadFile(config_.sense_voice.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
auto buf = ReadFile(mgr, config_.sense_voice.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Ort::Value Forward(Ort::Value features, Ort::Value features_length,
|
||||||
|
Ort::Value language, Ort::Value text_norm) {
|
||||||
|
std::array<Ort::Value, 4> inputs = {
|
||||||
|
std::move(features),
|
||||||
|
std::move(features_length),
|
||||||
|
std::move(language),
|
||||||
|
std::move(text_norm),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto ans =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
return std::move(ans[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
|
||||||
|
return meta_data_;
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples,
|
||||||
|
"normalize_samples");
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn");
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn");
|
||||||
|
|
||||||
|
int32_t lang_auto = 0;
|
||||||
|
int32_t lang_zh = 0;
|
||||||
|
int32_t lang_en = 0;
|
||||||
|
int32_t lang_ja = 0;
|
||||||
|
int32_t lang_ko = 0;
|
||||||
|
int32_t lang_yue = 0;
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue");
|
||||||
|
|
||||||
|
meta_data_.lang2id = {
|
||||||
|
{"auto", lang_auto}, {"zh", lang_zh}, {"ja", lang_ja},
|
||||||
|
{"ko", lang_ko}, {"yue", lang_yue},
|
||||||
|
};
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
OfflineSenseVoiceModelMetaData meta_data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineSenseVoiceModel::OfflineSenseVoiceModel(AAssetManager *mgr,
|
||||||
|
const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default;
|
||||||
|
|
||||||
|
Ort::Value OfflineSenseVoiceModel::Forward(Ort::Value features,
|
||||||
|
Ort::Value features_length,
|
||||||
|
Ort::Value language,
|
||||||
|
Ort::Value text_norm) const {
|
||||||
|
return impl_->Forward(std::move(features), std::move(features_length),
|
||||||
|
std::move(language), std::move(text_norm));
|
||||||
|
}
|
||||||
|
|
||||||
|
const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::GetModelMetadata()
|
||||||
|
const {
|
||||||
|
return impl_->GetModelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineSenseVoiceModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
61
sherpa-onnx/csrc/offline-sense-voice-model.h
Normal file
61
sherpa-onnx/csrc/offline-sense-voice-model.h
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-sense-voice-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineSenseVoiceModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineSenseVoiceModel(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineSenseVoiceModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~OfflineSenseVoiceModel();
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int32_t.
|
||||||
|
* @param language A 1-D tensor of shape (N,) with dtype int32_t
|
||||||
|
* @param text_norm A 1-D tensor of shape (N,) with dtype int32_t
|
||||||
|
*
|
||||||
|
* @return Return logits of shape (N, T, C) with dtype float
|
||||||
|
*
|
||||||
|
* Note: The subsampling factor is 1 for SenseVoice, so there is
|
||||||
|
* no need to output logits_length.
|
||||||
|
*/
|
||||||
|
Ort::Value Forward(Ort::Value features, Ort::Value features_length,
|
||||||
|
Ort::Value language, Ort::Value text_norm) const;
|
||||||
|
|
||||||
|
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
|
||||||
@@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <functional>
|
||||||
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
@@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) {
|
||||||
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
auto size = static_cast<int32_t>(std::accumulate(
|
||||||
|
shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
|
||||||
|
if (n != -1 && n < size && n > 0) {
|
||||||
|
size = n;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float *p = v->GetTensorData<float>();
|
||||||
|
|
||||||
|
return std::accumulate(p, p + size, 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) {
|
||||||
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
auto size = static_cast<int32_t>(std::accumulate(
|
||||||
|
shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
|
||||||
|
|
||||||
|
if (n != -1 && n < size && n > 0) {
|
||||||
|
size = n;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sum = ComputeSum(v, n);
|
||||||
|
return sum / size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PrintShape(const Ort::Value *v) {
|
||||||
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
std::ostringstream os;
|
||||||
|
for (auto i : shape) {
|
||||||
|
os << i << ", ";
|
||||||
|
}
|
||||||
|
os << "\n";
|
||||||
|
fprintf(stderr, "%s", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T /*= float*/>
|
template <typename T /*= float*/>
|
||||||
void Print1D(Ort::Value *v) {
|
void Print1D(const Ort::Value *v) {
|
||||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
const T *d = v->GetTensorData<T>();
|
const T *d = v->GetTensorData<T>();
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
|
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
|
||||||
os << *d << " ";
|
os << d[i] << " ";
|
||||||
}
|
}
|
||||||
os << "\n";
|
os << "\n";
|
||||||
fprintf(stderr, "%s\n", os.str().c_str());
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
template void Print1D<int64_t>(Ort::Value *v);
|
template void Print1D<int64_t>(const Ort::Value *v);
|
||||||
template void Print1D<float>(Ort::Value *v);
|
template void Print1D<int32_t>(const Ort::Value *v);
|
||||||
|
template void Print1D<float>(const Ort::Value *v);
|
||||||
|
|
||||||
template <typename T /*= float*/>
|
template <typename T /*= float*/>
|
||||||
void Print2D(Ort::Value *v) {
|
void Print2D(const Ort::Value *v) {
|
||||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
const T *d = v->GetTensorData<T>();
|
const T *d = v->GetTensorData<T>();
|
||||||
|
|
||||||
@@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) {
|
|||||||
fprintf(stderr, "%s\n", os.str().c_str());
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
template void Print2D<int64_t>(Ort::Value *v);
|
template void Print2D<int64_t>(const Ort::Value *v);
|
||||||
template void Print2D<float>(Ort::Value *v);
|
template void Print2D<float>(const Ort::Value *v);
|
||||||
|
|
||||||
void Print3D(Ort::Value *v) {
|
void Print3D(const Ort::Value *v) {
|
||||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
const float *d = v->GetTensorData<float>();
|
const float *d = v->GetTensorData<float>();
|
||||||
|
|
||||||
@@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) {
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Print4D(Ort::Value *v) {
|
void Print4D(const Ort::Value *v) {
|
||||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
const float *d = v->GetTensorData<float>();
|
const float *d = v->GetTensorData<float>();
|
||||||
|
|
||||||
|
|||||||
@@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
|
|||||||
// Return a shallow copy
|
// Return a shallow copy
|
||||||
Ort::Value View(Ort::Value *v);
|
Ort::Value View(Ort::Value *v);
|
||||||
|
|
||||||
|
float ComputeSum(const Ort::Value *v, int32_t n = -1);
|
||||||
|
float ComputeMean(const Ort::Value *v, int32_t n = -1);
|
||||||
|
|
||||||
// Print a 1-D tensor to stderr
|
// Print a 1-D tensor to stderr
|
||||||
template <typename T = float>
|
template <typename T = float>
|
||||||
void Print1D(Ort::Value *v);
|
void Print1D(const Ort::Value *v);
|
||||||
|
|
||||||
// Print a 2-D tensor to stderr
|
// Print a 2-D tensor to stderr
|
||||||
template <typename T = float>
|
template <typename T = float>
|
||||||
void Print2D(Ort::Value *v);
|
void Print2D(const Ort::Value *v);
|
||||||
|
|
||||||
// Print a 3-D tensor to stderr
|
// Print a 3-D tensor to stderr
|
||||||
void Print3D(Ort::Value *v);
|
void Print3D(const Ort::Value *v);
|
||||||
|
|
||||||
// Print a 4-D tensor to stderr
|
// Print a 4-D tensor to stderr
|
||||||
void Print4D(Ort::Value *v);
|
void Print4D(const Ort::Value *v);
|
||||||
|
|
||||||
|
void PrintShape(const Ort::Value *v);
|
||||||
|
|
||||||
template <typename T = float>
|
template <typename T = float>
|
||||||
void Fill(Ort::Value *tensor, T value) {
|
void Fill(Ort::Value *tensor, T value) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ set(srcs
|
|||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
offline-punctuation.cc
|
offline-punctuation.cc
|
||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
|
offline-sense-voice-model-config.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
offline-tdnn-model-config.cc
|
offline-tdnn-model-config.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
||||||
@@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineTdnnModelConfig(m);
|
PybindOfflineTdnnModelConfig(m);
|
||||||
PybindOfflineZipformerCtcModelConfig(m);
|
PybindOfflineZipformerCtcModelConfig(m);
|
||||||
PybindOfflineWenetCtcModelConfig(m);
|
PybindOfflineWenetCtcModelConfig(m);
|
||||||
|
PybindOfflineSenseVoiceModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
@@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
const OfflineNemoEncDecCtcModelConfig &,
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
||||||
const OfflineZipformerCtcModelConfig &,
|
const OfflineZipformerCtcModelConfig &,
|
||||||
const OfflineWenetCtcModelConfig &, const std::string &,
|
const OfflineWenetCtcModelConfig &,
|
||||||
|
const OfflineSenseVoiceModelConfig &, const std::string &,
|
||||||
const std::string &, int32_t, bool, const std::string &,
|
const std::string &, int32_t, bool, const std::string &,
|
||||||
const std::string &, const std::string &, const std::string &>(),
|
const std::string &, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
@@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
||||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||||
|
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
||||||
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
||||||
py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("num_threads"), py::arg("debug") = false,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||||
@@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
.def_readwrite("tdnn", &PyClass::tdnn)
|
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||||
|
.def_readwrite("sense_voice", &PyClass::sense_voice)
|
||||||
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ namespace sherpa_onnx {
|
|||||||
void PybindOfflineParaformerModelConfig(py::module *m) {
|
void PybindOfflineParaformerModelConfig(py::module *m) {
|
||||||
using PyClass = OfflineParaformerModelConfig;
|
using PyClass = OfflineParaformerModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
|
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
|
||||||
|
.def(py::init<>())
|
||||||
.def(py::init<const std::string &>(), py::arg("model"))
|
.def(py::init<const std::string &>(), py::arg("model"))
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
|
|||||||
26
sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
Normal file
26
sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineSenseVoiceModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineSenseVoiceModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineSenseVoiceModelConfig")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def(py::init<const std::string &, const std::string &, bool>(),
|
||||||
|
py::arg("model"), py::arg("language"), py::arg("use_itn"))
|
||||||
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
.def_readwrite("language", &PyClass::language)
|
||||||
|
.def_readwrite("use_itn", &PyClass::use_itn)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineSenseVoiceModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
|
||||||
@@ -10,6 +10,7 @@ from _sherpa_onnx import (
|
|||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
OfflineNemoEncDecCtcModelConfig,
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
OfflineParaformerModelConfig,
|
||||||
|
OfflineSenseVoiceModelConfig,
|
||||||
)
|
)
|
||||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
@@ -173,6 +174,88 @@ class OfflineRecognizer(object):
|
|||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_sense_voice(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
tokens: str,
|
||||||
|
num_threads: int = 1,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
feature_dim: int = 80,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
|
language: str = "",
|
||||||
|
use_itn: bool = False,
|
||||||
|
rule_fsts: str = "",
|
||||||
|
rule_fars: str = "",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
|
||||||
|
to download pre-trained models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
model:
|
||||||
|
Path to ``model.onnx``.
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
sample_rate:
|
||||||
|
Sample rate of the training data used to train the model.
|
||||||
|
feature_dim:
|
||||||
|
Dimension of the feature used to train the model.
|
||||||
|
decoding_method:
|
||||||
|
Valid values are greedy_search.
|
||||||
|
debug:
|
||||||
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
|
language:
|
||||||
|
If not empty, then valid values are: auto, zh, en, ja, ko, yue
|
||||||
|
use_itn:
|
||||||
|
True to enable inverse text normalization; False to disable it.
|
||||||
|
rule_fsts:
|
||||||
|
If not empty, it specifies fsts for inverse text normalization.
|
||||||
|
If there are multiple fsts, they are separated by a comma.
|
||||||
|
rule_fars:
|
||||||
|
If not empty, it specifies fst archives for inverse text normalization.
|
||||||
|
If there are multiple archives, they are separated by a comma.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
model_config = OfflineModelConfig(
|
||||||
|
sense_voice=OfflineSenseVoiceModelConfig(
|
||||||
|
model=model,
|
||||||
|
language=language,
|
||||||
|
use_itn=use_itn,
|
||||||
|
),
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
feat_config = FeatureExtractorConfig(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
feature_dim=feature_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OfflineRecognizerConfig(
|
||||||
|
feat_config=feat_config,
|
||||||
|
model_config=model_config,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
rule_fsts=rule_fsts,
|
||||||
|
rule_fars=rule_fars,
|
||||||
|
)
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
self.config = recognizer_config
|
||||||
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_paraformer(
|
def from_paraformer(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sherpaOnnxOfflineSenseVoiceModelConfig(
|
||||||
|
model: String = "",
|
||||||
|
language: String = "",
|
||||||
|
useInverseTextNormalization: Bool = false
|
||||||
|
) -> SherpaOnnxOfflineSenseVoiceModelConfig {
|
||||||
|
return SherpaOnnxOfflineSenseVoiceModelConfig(
|
||||||
|
model: toCPointer(model),
|
||||||
|
language: toCPointer(language),
|
||||||
|
use_itn: useInverseTextNormalization ? 1 : 0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func sherpaOnnxOfflineLMConfig(
|
func sherpaOnnxOfflineLMConfig(
|
||||||
model: String = "",
|
model: String = "",
|
||||||
scale: Float = 1.0
|
scale: Float = 1.0
|
||||||
@@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig(
|
|||||||
modelType: String = "",
|
modelType: String = "",
|
||||||
modelingUnit: String = "cjkchar",
|
modelingUnit: String = "cjkchar",
|
||||||
bpeVocab: String = "",
|
bpeVocab: String = "",
|
||||||
teleSpeechCtc: String = ""
|
teleSpeechCtc: String = "",
|
||||||
|
senseVoice: SherpaOnnxOfflineSenseVoiceModelConfig = sherpaOnnxOfflineSenseVoiceModelConfig()
|
||||||
) -> SherpaOnnxOfflineModelConfig {
|
) -> SherpaOnnxOfflineModelConfig {
|
||||||
return SherpaOnnxOfflineModelConfig(
|
return SherpaOnnxOfflineModelConfig(
|
||||||
transducer: transducer,
|
transducer: transducer,
|
||||||
@@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig(
|
|||||||
model_type: toCPointer(modelType),
|
model_type: toCPointer(modelType),
|
||||||
modeling_unit: toCPointer(modelingUnit),
|
modeling_unit: toCPointer(modelingUnit),
|
||||||
bpe_vocab: toCPointer(bpeVocab),
|
bpe_vocab: toCPointer(bpeVocab),
|
||||||
telespeech_ctc: toCPointer(teleSpeechCtc)
|
telespeech_ctc: toCPointer(teleSpeechCtc),
|
||||||
|
sense_voice: senseVoice
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ func run() {
|
|||||||
var modelConfig: SherpaOnnxOfflineModelConfig
|
var modelConfig: SherpaOnnxOfflineModelConfig
|
||||||
var modelType = "whisper"
|
var modelType = "whisper"
|
||||||
// modelType = "paraformer"
|
// modelType = "paraformer"
|
||||||
|
// modelType = "sense_voice"
|
||||||
|
|
||||||
if modelType == "whisper" {
|
if modelType == "whisper" {
|
||||||
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
||||||
@@ -47,6 +48,19 @@ func run() {
|
|||||||
debug: 0,
|
debug: 0,
|
||||||
modelType: "paraformer"
|
modelType: "paraformer"
|
||||||
)
|
)
|
||||||
|
} else if modelType == "sense_voice" {
|
||||||
|
let model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
|
||||||
|
let tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
|
||||||
|
let senseVoiceConfig = sherpaOnnxOfflineSenseVoiceModelConfig(
|
||||||
|
model: model,
|
||||||
|
useInverseTextNormalization: true
|
||||||
|
)
|
||||||
|
|
||||||
|
modelConfig = sherpaOnnxOfflineModelConfig(
|
||||||
|
tokens: tokens,
|
||||||
|
debug: 0,
|
||||||
|
senseVoice: senseVoiceConfig
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
print("Please specify a supported modelType \(modelType)")
|
print("Please specify a supported modelType \(modelType)")
|
||||||
return
|
return
|
||||||
@@ -63,7 +77,10 @@ func run() {
|
|||||||
|
|
||||||
recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
||||||
|
|
||||||
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
var filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
||||||
|
if modelType == "sense_voice" {
|
||||||
|
filePath = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
|
||||||
|
}
|
||||||
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
|
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
|
||||||
let audioFile = try! AVAudioFile(forReading: fileURL as URL)
|
let audioFile = try! AVAudioFile(forReading: fileURL as URL)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user