Support multilingual whisper models (#274)
This commit is contained in:
@@ -14,10 +14,14 @@ namespace sherpa_onnx {
|
||||
void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||
using PyClass = OfflineWhisperModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &>(),
|
||||
py::arg("encoder"), py::arg("decoder"))
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
||||
py::arg("task"))
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def_readwrite("language", &PyClass::language)
|
||||
.def_readwrite("task", &PyClass::task)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
tokens: str,
|
||||
language: str = "en",
|
||||
task: str = "transcribe",
|
||||
num_threads: int = 1,
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
|
||||
|
||||
symbol integer_id
|
||||
|
||||
language:
|
||||
The spoken language in the audio file. Example values: en, de, zh,
|
||||
jp, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
for all possible values. Note that for non-multilingual models, the
|
||||
only valid value is 'en'.
|
||||
task:
|
||||
Valid values are: transcribe, translate. Note that for
|
||||
non-multilingual models, the only valid value is 'transcribe'.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
decoding_method:
|
||||
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
|
||||
whisper=OfflineWhisperModelConfig(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
language=language,
|
||||
task=task,
|
||||
),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
|
||||
Reference in New Issue
Block a user