Add Java and Kotlin API for sense voice (#1164)
This commit is contained in:
@@ -27,6 +27,7 @@ java_files += OfflineTransducerModelConfig.java
|
||||
java_files += OfflineParaformerModelConfig.java
|
||||
java_files += OfflineWhisperModelConfig.java
|
||||
java_files += OfflineNemoEncDecCtcModelConfig.java
|
||||
java_files += OfflineSenseVoiceModelConfig.java
|
||||
java_files += OfflineModelConfig.java
|
||||
java_files += OfflineRecognizerConfig.java
|
||||
java_files += OfflineRecognizerResult.java
|
||||
|
||||
@@ -7,6 +7,7 @@ public class OfflineModelConfig {
|
||||
private final OfflineParaformerModelConfig paraformer;
|
||||
private final OfflineWhisperModelConfig whisper;
|
||||
private final OfflineNemoEncDecCtcModelConfig nemo;
|
||||
private final OfflineSenseVoiceModelConfig senseVoice;
|
||||
private final String teleSpeech;
|
||||
private final String tokens;
|
||||
private final int numThreads;
|
||||
@@ -22,6 +23,7 @@ public class OfflineModelConfig {
|
||||
this.paraformer = builder.paraformer;
|
||||
this.whisper = builder.whisper;
|
||||
this.nemo = builder.nemo;
|
||||
this.senseVoice = builder.senseVoice;
|
||||
this.teleSpeech = builder.teleSpeech;
|
||||
this.tokens = builder.tokens;
|
||||
this.numThreads = builder.numThreads;
|
||||
@@ -48,6 +50,10 @@ public class OfflineModelConfig {
|
||||
return whisper;
|
||||
}
|
||||
|
||||
/**
 * Returns the SenseVoice model configuration stored in this config.
 *
 * @return the value of the {@code senseVoice} field
 */
public OfflineSenseVoiceModelConfig getSenseVoice() {
|
||||
return senseVoice;
|
||||
}
|
||||
|
||||
/**
 * Returns the {@code tokens} string stored in this config.
 *
 * @return the value of the {@code tokens} field
 */
public String getTokens() {
|
||||
return tokens;
|
||||
}
|
||||
@@ -85,6 +91,7 @@ public class OfflineModelConfig {
|
||||
private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
|
||||
private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
|
||||
private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
|
||||
private OfflineSenseVoiceModelConfig senseVoice = OfflineSenseVoiceModelConfig.builder().build();
|
||||
private String teleSpeech = "";
|
||||
private String tokens = "";
|
||||
private int numThreads = 1;
|
||||
@@ -113,7 +120,6 @@ public class OfflineModelConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
public Builder setTeleSpeech(String teleSpeech) {
|
||||
this.teleSpeech = teleSpeech;
|
||||
return this;
|
||||
@@ -124,6 +130,11 @@ public class OfflineModelConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
 * Sets the SenseVoice model configuration to use when building.
 *
 * @param senseVoice the SenseVoice configuration; stored as-is, no validation is done here
 * @return this builder, to allow call chaining
 */
public Builder setSenseVoice(OfflineSenseVoiceModelConfig senseVoice) {
|
||||
this.senseVoice = senseVoice;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTokens(String tokens) {
|
||||
this.tokens = tokens;
|
||||
return this;
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
/**
 * Immutable configuration for the SenseVoice offline model.
 *
 * <p>Instances are created through {@link #builder()}. The native layer reads the
 * private fields of this class by name via JNI ({@code model}, {@code language},
 * {@code useInverseTextNormalization}), so the field names and types must not be
 * changed without updating the native code as well.
 *
 * <p>String fields are never {@code null}: a {@code null} passed to the builder is
 * stored as the empty string, which is also the builder default. This keeps the
 * native side from receiving a null {@code jstring}.
 */
public class OfflineSenseVoiceModelConfig {
    // NOTE: these names are looked up from native code; keep them in sync with the JNI layer.
    private final String model;
    private final String language;
    private final boolean useInverseTextNormalization;

    private OfflineSenseVoiceModelConfig(Builder builder) {
        // Normalise null to "" so the native code never calls GetStringUTFChars on null.
        this.model = emptyIfNull(builder.model);
        this.language = emptyIfNull(builder.language);
        this.useInverseTextNormalization = builder.useInverseTextNormalization;
    }

    /** Returns {@code value}, or the empty string when {@code value} is {@code null}. */
    private static String emptyIfNull(String value) {
        return value == null ? "" : value;
    }

    /**
     * Creates a builder pre-populated with the defaults: empty model, empty language,
     * and inverse text normalization enabled.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns the configured model string.
     *
     * @return the model string; never {@code null}, empty when not set
     */
    public String getModel() {
        return model;
    }

    /**
     * Returns the configured language string.
     *
     * @return the language string; never {@code null}, empty when not set
     */
    public String getLanguage() {
        return language;
    }

    /**
     * Returns whether inverse text normalization is enabled.
     *
     * @return {@code true} when enabled (the default)
     */
    public boolean getUseInverseTextNormalization() {
        return useInverseTextNormalization;
    }

    /** Fluent builder for {@link OfflineSenseVoiceModelConfig}. */
    public static class Builder {
        private String model = "";
        private String language = "";
        private boolean useInverseTextNormalization = true;

        /**
         * Builds an immutable configuration from the current builder state.
         *
         * @return a new configuration
         */
        public OfflineSenseVoiceModelConfig build() {
            return new OfflineSenseVoiceModelConfig(this);
        }

        /**
         * Sets the model string.
         *
         * @param model the model string; {@code null} is treated as empty
         * @return this builder
         */
        public Builder setModel(String model) {
            this.model = model;
            return this;
        }

        /**
         * Sets the language string.
         *
         * @param language the language string; {@code null} is treated as empty
         * @return this builder
         */
        public Builder setLanguage(String language) {
            this.language = language;
            return this;
        }

        /**
         * Enables or disables inverse text normalization.
         *
         * @param useInverseTextNormalization {@code true} to enable
         * @return this builder
         */
        public Builder setInverseTextNormalization(boolean useInverseTextNormalization) {
            this.useInverseTextNormalization = useInverseTextNormalization;
            return this;
        }
    }
}
|
||||
@@ -171,6 +171,31 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
ans.model_config.whisper.tail_paddings =
|
||||
env->GetIntField(whisper_config, fid);
|
||||
|
||||
// sense voice
|
||||
fid = env->GetFieldID(model_config_cls, "senseVoice",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineSenseVoiceModelConfig;");
|
||||
jobject sense_voice_config = env->GetObjectField(model_config, fid);
|
||||
jclass sense_voice_config_cls = env->GetObjectClass(sense_voice_config);
|
||||
|
||||
fid = env->GetFieldID(sense_voice_config_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(sense_voice_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.sense_voice.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid =
|
||||
env->GetFieldID(sense_voice_config_cls, "language", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(sense_voice_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.sense_voice.language = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(sense_voice_config_cls, "useInverseTextNormalization",
|
||||
"Z");
|
||||
ans.model_config.sense_voice.use_itn =
|
||||
env->GetBooleanField(sense_voice_config, fid);
|
||||
|
||||
// nemo
|
||||
fid = env->GetFieldID(
|
||||
model_config_cls, "nemo",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig;");
|
||||
|
||||
@@ -30,11 +30,18 @@ data class OfflineWhisperModelConfig(
|
||||
var tailPaddings: Int = 1000, // Padding added at the end of the samples
|
||||
)
|
||||
|
||||
/**
 * Configuration for the SenseVoice offline model.
 *
 * All properties are mutable and have defaults, so the class can be constructed
 * with no arguments.
 *
 * NOTE(review): the property names appear to mirror the fields that the native
 * layer looks up by name ("model", "language", "useInverseTextNormalization");
 * confirm against the JNI code before renaming any of them.
 */
data class OfflineSenseVoiceModelConfig(
|
||||
var model: String = "", // Model file, e.g. "<modelDir>/model.int8.onnx"; empty by default
|
||||
var language: String = "", // Language string; empty by default
|
||||
var useInverseTextNormalization: Boolean = true, // Inverse text normalization is on by default
|
||||
)
|
||||
|
||||
data class OfflineModelConfig(
|
||||
var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
|
||||
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
|
||||
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
|
||||
var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(),
|
||||
var senseVoice: OfflineSenseVoiceModelConfig = OfflineSenseVoiceModelConfig(),
|
||||
var teleSpeech: String = "",
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
@@ -321,6 +328,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
|
||||
modelType = "paraformer",
|
||||
)
|
||||
}
|
||||
|
||||
15 -> {
|
||||
val modelDir = "sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17"
|
||||
return OfflineModelConfig(
|
||||
senseVoice = OfflineSenseVoiceModelConfig(
|
||||
model = "$modelDir/model.int8.onnx",
|
||||
),
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
)
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user