Add Java API for spoken language identification with whisper multilingual models (#817)

This commit is contained in:
Fangjun Kuang
2024-04-26 19:05:39 +08:00
committed by GitHub
parent f2d074aea9
commit db25986240
12 changed files with 406 additions and 11 deletions

View File

@@ -0,0 +1,59 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
public class SpokenLanguageIdentification {
static {
System.loadLibrary("sherpa-onnx-jni");
}
private final Map<String, String> localeMap;
private long ptr = 0; // this is the asr engine ptrss
public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) {
ptr = newFromFile(config);
String[] languages = Locale.getISOLanguages();
localeMap = new HashMap<String, String>(languages.length);
for (String language : languages) {
Locale locale = new Locale(language);
localeMap.put(language, locale.getDisplayName());
}
}
public String compute(OfflineStream stream) {
String lang = compute(ptr, stream.getPtr());
return localeMap.getOrDefault(lang, lang);
}
public OfflineStream createStream() {
long p = createStream(ptr);
return new OfflineStream(p);
}
@Override
protected void finalize() throws Throwable {
release();
}
// You'd better call it manually if it is not used anymore
public void release() {
if (this.ptr == 0) {
return;
}
delete(this.ptr);
this.ptr = 0;
}
private native void delete(long ptr);
private native long newFromFile(SpokenLanguageIdentificationConfig config);
private native long createStream(long ptr);
private native String compute(long ptr, long streamPtr);
}

View File

@@ -0,0 +1,56 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class SpokenLanguageIdentificationConfig {
private final SpokenLanguageIdentificationWhisperConfig whisper;
private final int numThreads;
private final boolean debug;
private final String provider;
private SpokenLanguageIdentificationConfig(Builder builder) {
this.whisper = builder.whisper;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
this.provider = builder.provider;
}
public static Builder builder() {
return new Builder();
}
public SpokenLanguageIdentificationWhisperConfig getWhisper() {
return whisper;
}
public static class Builder {
private SpokenLanguageIdentificationWhisperConfig whisper = SpokenLanguageIdentificationWhisperConfig.builder().build();
private int numThreads = 1;
private boolean debug = true;
private String provider = "cpu";
public SpokenLanguageIdentificationConfig build() {
return new SpokenLanguageIdentificationConfig(this);
}
public Builder setWhisper(SpokenLanguageIdentificationWhisperConfig whisper) {
this.whisper = whisper;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
public Builder setDebug(boolean debug) {
this.debug = debug;
return this;
}
public Builder setProvider(String provider) {
this.provider = provider;
return this;
}
}
}

View File

@@ -0,0 +1,56 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class SpokenLanguageIdentificationWhisperConfig {
private final String encoder;
private final String decoder;
private final int tailPaddings;
private SpokenLanguageIdentificationWhisperConfig(Builder builder) {
this.encoder = builder.encoder;
this.decoder = builder.decoder;
this.tailPaddings = builder.tailPaddings;
}
public static Builder builder() {
return new Builder();
}
public String getEncoder() {
return encoder;
}
public String getDecoder() {
return decoder;
}
public int getTailPaddings() {
return tailPaddings;
}
public static class Builder {
private String encoder = "";
private String decoder = "";
private int tailPaddings = 1000; // number of frames to pad
public SpokenLanguageIdentificationWhisperConfig build() {
return new SpokenLanguageIdentificationWhisperConfig(this);
}
public Builder setEncoder(String encoder) {
this.encoder = encoder;
return this;
}
public Builder setDecoder(String decoder) {
this.decoder = decoder;
return this;
}
public Builder setTailPaddings(int tailPaddings) {
this.tailPaddings = tailPaddings;
return this;
}
}
}