Add CTC HLG decoding for JNI (#810)
This commit is contained in:
@@ -5,6 +5,7 @@ public class OfflineModelConfig {
|
||||
private final OfflineTransducerModelConfig transducer;
|
||||
private final OfflineParaformerModelConfig paraformer;
|
||||
private final OfflineWhisperModelConfig whisper;
|
||||
private final OfflineNemoEncDecCtcModelConfig nemo;
|
||||
private final String tokens;
|
||||
private final int numThreads;
|
||||
private final boolean debug;
|
||||
@@ -16,6 +17,7 @@ public class OfflineModelConfig {
|
||||
this.transducer = builder.transducer;
|
||||
this.paraformer = builder.paraformer;
|
||||
this.whisper = builder.whisper;
|
||||
this.nemo = builder.nemo;
|
||||
this.tokens = builder.tokens;
|
||||
this.numThreads = builder.numThreads;
|
||||
this.debug = builder.debug;
|
||||
@@ -64,6 +66,7 @@ public class OfflineModelConfig {
|
||||
private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build();
|
||||
private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
|
||||
private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
|
||||
private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
|
||||
private String tokens = "";
|
||||
private int numThreads = 1;
|
||||
private boolean debug = true;
|
||||
@@ -84,6 +87,11 @@ public class OfflineModelConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNemo(OfflineNemoEncDecCtcModelConfig nemo) {
|
||||
this.nemo = nemo;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setWhisper(OfflineWhisperModelConfig whisper) {
|
||||
this.whisper = whisper;
|
||||
return this;
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OfflineNemoEncDecCtcModelConfig {
|
||||
private final String model;
|
||||
|
||||
private OfflineNemoEncDecCtcModelConfig(Builder builder) {
|
||||
this.model = builder.model;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String model = "";
|
||||
|
||||
public OfflineNemoEncDecCtcModelConfig build() {
|
||||
return new OfflineNemoEncDecCtcModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineCtcFstDecoderConfig {
|
||||
private final String graph;
|
||||
private final int maxActive;
|
||||
|
||||
private OnlineCtcFstDecoderConfig(Builder builder) {
|
||||
this.graph = builder.graph;
|
||||
this.maxActive = builder.maxActive;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getGraph() {
|
||||
return graph;
|
||||
}
|
||||
|
||||
public float getMaxActive() {
|
||||
return maxActive;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String graph = "";
|
||||
private int maxActive = 3000;
|
||||
|
||||
public OnlineCtcFstDecoderConfig build() {
|
||||
return new OnlineCtcFstDecoderConfig(this);
|
||||
}
|
||||
|
||||
public Builder setGraph(String model) {
|
||||
this.graph = graph;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMaxActive(int maxActive) {
|
||||
this.maxActive = maxActive;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ public class OnlineRecognizerConfig {
|
||||
private final FeatureConfig featConfig;
|
||||
private final OnlineModelConfig modelConfig;
|
||||
private final OnlineLMConfig lmConfig;
|
||||
|
||||
private final OnlineCtcFstDecoderConfig ctcFstDecoderConfig;
|
||||
private final EndpointConfig endpointConfig;
|
||||
private final boolean enableEndpoint;
|
||||
private final String decodingMethod;
|
||||
@@ -17,6 +19,7 @@ public class OnlineRecognizerConfig {
|
||||
this.featConfig = builder.featConfig;
|
||||
this.modelConfig = builder.modelConfig;
|
||||
this.lmConfig = builder.lmConfig;
|
||||
this.ctcFstDecoderConfig = builder.ctcFstDecoderConfig;
|
||||
this.endpointConfig = builder.endpointConfig;
|
||||
this.enableEndpoint = builder.enableEndpoint;
|
||||
this.decodingMethod = builder.decodingMethod;
|
||||
@@ -37,6 +40,7 @@ public class OnlineRecognizerConfig {
|
||||
private FeatureConfig featConfig = FeatureConfig.builder().build();
|
||||
private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
|
||||
private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build();
|
||||
private OnlineCtcFstDecoderConfig ctcFstDecoderConfig = OnlineCtcFstDecoderConfig.builder().build();
|
||||
private EndpointConfig endpointConfig = EndpointConfig.builder().build();
|
||||
private boolean enableEndpoint = true;
|
||||
private String decodingMethod = "greedy_search";
|
||||
@@ -63,6 +67,11 @@ public class OnlineRecognizerConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setCtcFstDecoderConfig(OnlineCtcFstDecoderConfig ctcFstDecoderConfig) {
|
||||
this.ctcFstDecoderConfig = ctcFstDecoderConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEndpointConfig(EndpointConfig endpointConfig) {
|
||||
this.endpointConfig = endpointConfig;
|
||||
return this;
|
||||
|
||||
Reference in New Issue
Block a user