Add tail_paddings to Whisper C API. (#886)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.9.24")
|
set(SHERPA_ONNX_VERSION "1.9.25")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ function createOfflineRecognizer() {
|
|||||||
decoder: '',
|
decoder: '',
|
||||||
language: '',
|
language: '',
|
||||||
task: '',
|
task: '',
|
||||||
|
tailPaddings: -1,
|
||||||
},
|
},
|
||||||
tdnn: {
|
tdnn: {
|
||||||
model: '',
|
model: '',
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ function createOfflineRecognizer() {
|
|||||||
decoder: '',
|
decoder: '',
|
||||||
language: '',
|
language: '',
|
||||||
task: '',
|
task: '',
|
||||||
|
tailPaddings: -1,
|
||||||
},
|
},
|
||||||
tdnn: {
|
tdnn: {
|
||||||
model: '',
|
model: '',
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ function createOfflineRecognizer() {
|
|||||||
decoder: '',
|
decoder: '',
|
||||||
language: '',
|
language: '',
|
||||||
task: '',
|
task: '',
|
||||||
|
tailPaddings: -1,
|
||||||
},
|
},
|
||||||
tdnn: {
|
tdnn: {
|
||||||
model: '',
|
model: '',
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ function createOfflineRecognizer() {
|
|||||||
decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
|
decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
|
||||||
language: '',
|
language: '',
|
||||||
task: 'transcribe',
|
task: 'transcribe',
|
||||||
|
tailPaddings: -1,
|
||||||
},
|
},
|
||||||
tdnn: {
|
tdnn: {
|
||||||
model: '',
|
model: '',
|
||||||
|
|||||||
@@ -301,6 +301,7 @@ namespace SherpaOnnx
|
|||||||
Decoder = "";
|
Decoder = "";
|
||||||
Language = "";
|
Language = "";
|
||||||
Task = "transcribe";
|
Task = "transcribe";
|
||||||
|
TailPaddings = -1;
|
||||||
}
|
}
|
||||||
[MarshalAs(UnmanagedType.LPStr)]
|
[MarshalAs(UnmanagedType.LPStr)]
|
||||||
public string Encoder;
|
public string Encoder;
|
||||||
@@ -313,6 +314,8 @@ namespace SherpaOnnx
|
|||||||
|
|
||||||
[MarshalAs(UnmanagedType.LPStr)]
|
[MarshalAs(UnmanagedType.LPStr)]
|
||||||
public string Task;
|
public string Task;
|
||||||
|
|
||||||
|
public int TailPaddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
[StructLayout(LayoutKind.Sequential)]
|
[StructLayout(LayoutKind.Sequential)]
|
||||||
|
|||||||
@@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OfflineWhisperModelConfig struct {
|
type OfflineWhisperModelConfig struct {
|
||||||
Encoder string
|
Encoder string
|
||||||
Decoder string
|
Decoder string
|
||||||
Language string
|
Language string
|
||||||
Task string
|
Task string
|
||||||
|
TailPaddings int
|
||||||
}
|
}
|
||||||
|
|
||||||
type OfflineTdnnModelConfig struct {
|
type OfflineTdnnModelConfig struct {
|
||||||
@@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
|
|||||||
c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
|
c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
|
||||||
defer C.free(unsafe.Pointer(c.model_config.whisper.task))
|
defer C.free(unsafe.Pointer(c.model_config.whisper.task))
|
||||||
|
|
||||||
|
c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings)
|
||||||
|
|
||||||
c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
|
c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
|
||||||
defer C.free(unsafe.Pointer(c.model_config.tdnn.model))
|
defer C.free(unsafe.Pointer(c.model_config.tdnn.model))
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig(
|
|||||||
SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
|
SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
|
||||||
SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
|
SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
|
||||||
SHERPA_ONNX_ASSIGN_ATTR_STR(language, language);
|
SHERPA_ONNX_ASSIGN_ATTR_STR(language, language);
|
||||||
SHERPA_ONNX_ASSIGN_ATTR_STR(task, languagek);
|
SHERPA_ONNX_ASSIGN_ATTR_STR(task, task);
|
||||||
|
SHERPA_ONNX_ASSIGN_ATTR_INT32(tail_paddings, tailPaddings);
|
||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
|||||||
recognizer_config.model_config.whisper.task = "transcribe";
|
recognizer_config.model_config.whisper.task = "transcribe";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recognizer_config.model_config.whisper.tail_paddings =
|
||||||
|
SHERPA_ONNX_OR(config->model_config.whisper.tail_paddings, -1);
|
||||||
|
|
||||||
recognizer_config.model_config.tdnn.model =
|
recognizer_config.model_config.tdnn.model =
|
||||||
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
|
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
|
||||||
|
|
||||||
|
|||||||
@@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
|
|||||||
const char *decoder;
|
const char *decoder;
|
||||||
const char *language;
|
const char *language;
|
||||||
const char *task;
|
const char *task;
|
||||||
|
int32_t tail_paddings;
|
||||||
} SherpaOnnxOfflineWhisperModelConfig;
|
} SherpaOnnxOfflineWhisperModelConfig;
|
||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
|
||||||
|
|||||||
@@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig(
|
|||||||
encoder: String = "",
|
encoder: String = "",
|
||||||
decoder: String = "",
|
decoder: String = "",
|
||||||
language: String = "",
|
language: String = "",
|
||||||
task: String = "transcribe"
|
task: String = "transcribe",
|
||||||
|
tailPaddings: Int = -1
|
||||||
) -> SherpaOnnxOfflineWhisperModelConfig {
|
) -> SherpaOnnxOfflineWhisperModelConfig {
|
||||||
return SherpaOnnxOfflineWhisperModelConfig(
|
return SherpaOnnxOfflineWhisperModelConfig(
|
||||||
encoder: toCPointer(encoder),
|
encoder: toCPointer(encoder),
|
||||||
decoder: toCPointer(decoder),
|
decoder: toCPointer(decoder),
|
||||||
language: toCPointer(language),
|
language: toCPointer(language),
|
||||||
task: toCPointer(task)
|
task: toCPointer(task),
|
||||||
|
tail_paddings: Int32(tailPaddings)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) {
|
|||||||
Module.setValue(ptr + 12, buffer + offset, 'i8*');
|
Module.setValue(ptr + 12, buffer + offset, 'i8*');
|
||||||
offset += taskLen;
|
offset += taskLen;
|
||||||
|
|
||||||
|
Module.setValue(ptr + 16, config.tailPaddings || -1, 'i32');
|
||||||
|
|
||||||
return {
|
return {
|
||||||
buffer: buffer, ptr: ptr, len: len,
|
buffer: buffer, ptr: ptr, len: len,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, "");
|
|||||||
static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, "");
|
static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, "");
|
||||||
|
|
||||||
static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, "");
|
static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, "");
|
||||||
static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, "");
|
static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 5 * 4, "");
|
||||||
static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, "");
|
static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, "");
|
||||||
static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, "");
|
static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, "");
|
||||||
|
|
||||||
@@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
|
|||||||
fprintf(stdout, "decoder: %s\n", whisper->decoder);
|
fprintf(stdout, "decoder: %s\n", whisper->decoder);
|
||||||
fprintf(stdout, "language: %s\n", whisper->language);
|
fprintf(stdout, "language: %s\n", whisper->language);
|
||||||
fprintf(stdout, "task: %s\n", whisper->task);
|
fprintf(stdout, "task: %s\n", whisper->task);
|
||||||
|
fprintf(stdout, "tail_paddings: %d\n", whisper->tail_paddings);
|
||||||
|
|
||||||
fprintf(stdout, "----------offline tdnn model config----------\n");
|
fprintf(stdout, "----------offline tdnn model config----------\n");
|
||||||
fprintf(stdout, "model: %s\n", tdnn->model);
|
fprintf(stdout, "model: %s\n", tdnn->model);
|
||||||
|
|||||||
Reference in New Issue
Block a user