Add Java and Kotlin API for sense voice (#1164)

This commit is contained in:
Fangjun Kuang
2024-07-22 14:08:40 +08:00
committed by GitHub
parent ac8223bd8a
commit dd300b1de5
16 changed files with 601 additions and 2 deletions

View File

@@ -27,6 +27,7 @@ java_files += OfflineTransducerModelConfig.java
java_files += OfflineParaformerModelConfig.java
java_files += OfflineWhisperModelConfig.java
java_files += OfflineNemoEncDecCtcModelConfig.java
java_files += OfflineSenseVoiceModelConfig.java
java_files += OfflineModelConfig.java
java_files += OfflineRecognizerConfig.java
java_files += OfflineRecognizerResult.java

View File

@@ -7,6 +7,7 @@ public class OfflineModelConfig {
private final OfflineParaformerModelConfig paraformer;
private final OfflineWhisperModelConfig whisper;
private final OfflineNemoEncDecCtcModelConfig nemo;
private final OfflineSenseVoiceModelConfig senseVoice;
private final String teleSpeech;
private final String tokens;
private final int numThreads;
@@ -22,6 +23,7 @@ public class OfflineModelConfig {
this.paraformer = builder.paraformer;
this.whisper = builder.whisper;
this.nemo = builder.nemo;
this.senseVoice = builder.senseVoice;
this.teleSpeech = builder.teleSpeech;
this.tokens = builder.tokens;
this.numThreads = builder.numThreads;
@@ -48,6 +50,10 @@ public class OfflineModelConfig {
return whisper;
}
public OfflineSenseVoiceModelConfig getSenseVoice() {
return senseVoice;
}
public String getTokens() {
return tokens;
}
@@ -85,6 +91,7 @@ public class OfflineModelConfig {
private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
private OfflineSenseVoiceModelConfig senseVoice = OfflineSenseVoiceModelConfig.builder().build();
private String teleSpeech = "";
private String tokens = "";
private int numThreads = 1;
@@ -113,7 +120,6 @@ public class OfflineModelConfig {
return this;
}
public Builder setTeleSpeech(String teleSpeech) {
this.teleSpeech = teleSpeech;
return this;
@@ -124,6 +130,11 @@ public class OfflineModelConfig {
return this;
}
public Builder setSenseVoice(OfflineSenseVoiceModelConfig senseVoice) {
this.senseVoice = senseVoice;
return this;
}
public Builder setTokens(String tokens) {
this.tokens = tokens;
return this;

View File

@@ -0,0 +1,56 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineSenseVoiceModelConfig {
private final String model;
private final String language;
private final boolean useInverseTextNormalization;
private OfflineSenseVoiceModelConfig(Builder builder) {
this.model = builder.model;
this.language = builder.language;
this.useInverseTextNormalization = builder.useInverseTextNormalization;
}
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public String getLanguage() {
return language;
}
public boolean getUseInverseTextNormalization() {
return useInverseTextNormalization;
}
public static class Builder {
private String model = "";
private String language = "";
private boolean useInverseTextNormalization = true;
public OfflineSenseVoiceModelConfig build() {
return new OfflineSenseVoiceModelConfig(this);
}
public Builder setModel(String model) {
this.model = model;
return this;
}
public Builder setLanguage(String language) {
this.language = language;
return this;
}
public Builder setInverseTextNormalization(boolean useInverseTextNormalization) {
this.useInverseTextNormalization = useInverseTextNormalization;
return this;
}
}
}

View File

@@ -171,6 +171,31 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
ans.model_config.whisper.tail_paddings =
env->GetIntField(whisper_config, fid);
// sense voice
fid = env->GetFieldID(model_config_cls, "senseVoice",
"Lcom/k2fsa/sherpa/onnx/OfflineSenseVoiceModelConfig;");
jobject sense_voice_config = env->GetObjectField(model_config, fid);
jclass sense_voice_config_cls = env->GetObjectClass(sense_voice_config);
fid = env->GetFieldID(sense_voice_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(sense_voice_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.sense_voice.model = p;
env->ReleaseStringUTFChars(s, p);
fid =
env->GetFieldID(sense_voice_config_cls, "language", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(sense_voice_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.sense_voice.language = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(sense_voice_config_cls, "useInverseTextNormalization",
"Z");
ans.model_config.sense_voice.use_itn =
env->GetBooleanField(sense_voice_config, fid);
// nemo
fid = env->GetFieldID(
model_config_cls, "nemo",
"Lcom/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig;");

View File

@@ -30,11 +30,18 @@ data class OfflineWhisperModelConfig(
var tailPaddings: Int = 1000, // Padding added at the end of the samples
)
data class OfflineSenseVoiceModelConfig(
var model: String = "",
var language: String = "",
var useInverseTextNormalization: Boolean = true,
)
data class OfflineModelConfig(
var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(),
var senseVoice: OfflineSenseVoiceModelConfig = OfflineSenseVoiceModelConfig(),
var teleSpeech: String = "",
var numThreads: Int = 1,
var debug: Boolean = false,
@@ -321,6 +328,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
modelType = "paraformer",
)
}
15 -> {
val modelDir = "sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17"
return OfflineModelConfig(
senseVoice = OfflineSenseVoiceModelConfig(
model = "$modelDir/model.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
)
}
}
return null
}