Ebranchformer (#1951)

* adding ebranchformer encoder

* extend surfaced FeatureExtractorConfig

- so ebranchformer feature extraction can be configured from Python
- the GlobCmvn is not needed, as it is a module in the OnnxEncoder

* clean the code

* Integrating remarks from Fangjun
This commit is contained in:
Karel Vesely
2025-03-04 12:41:09 +01:00
committed by GitHub
parent 209eaaae1d
commit 7740dbfb96
8 changed files with 609 additions and 5 deletions

View File

@@ -11,15 +11,21 @@ namespace sherpa_onnx {
// Expose FeatureExtractorConfig to Python as "FeatureExtractorConfig".
//
// Every constructor argument has a default, so Python callers may pass any
// prefix of the argument list; the legacy 5-argument call sites
// (sampling_rate, feature_dim, low_freq, high_freq, dither) keep working,
// which makes the previously-registered 5-argument py::init overload
// redundant — it has been removed.
static void PybindFeatureExtractorConfig(py::module *m) {
  using PyClass = FeatureExtractorConfig;
  py::class_<PyClass>(*m, "FeatureExtractorConfig")
      .def(py::init<int32_t, int32_t, float, float, float, bool, bool>(),
           py::arg("sampling_rate") = 16000,
           py::arg("feature_dim") = 80,
           py::arg("low_freq") = 20.0f,
           // NOTE(review): negative high_freq is presumably the kaldi
           // convention (offset from the Nyquist frequency) — confirm.
           py::arg("high_freq") = -400.0f,
           py::arg("dither") = 0.0f,
           // normalize_samples=true expects samples in [-1, 1];
           // false expects kaldi-style +/- 32k integer-range samples.
           py::arg("normalize_samples") = true,
           py::arg("snip_edges") = false)
      .def_readwrite("sampling_rate", &PyClass::sampling_rate)
      .def_readwrite("feature_dim", &PyClass::feature_dim)
      .def_readwrite("low_freq", &PyClass::low_freq)
      .def_readwrite("high_freq", &PyClass::high_freq)
      .def_readwrite("dither", &PyClass::dither)
      .def_readwrite("normalize_samples", &PyClass::normalize_samples)
      .def_readwrite("snip_edges", &PyClass::snip_edges)
      .def("__str__", &PyClass::ToString);
}

View File

@@ -22,6 +22,23 @@ Args:
to the range [-1, 1].
)";
// Help text attached to OnlineStream.get_frames in the Python bindings.
constexpr const char *kGetFramesUsage = R"(
Get n frames starting from the given frame index.
(hint: intended for debugging, for comparing FBANK features across pipelines)
Args:
frame_index:
The starting frame index
n:
Number of frames to get.
Return:
Return a 2-D tensor of shape (n, feature_dim),
which is flattened into a 1-D vector (flattened in row major).
Unflatten in python with:
`features = np.reshape(arr, (n, feature_dim))`
)";
void PybindOnlineStream(py::module *m) {
using PyClass = OnlineStream;
py::class_<PyClass>(*m, "OnlineStream")
@@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) {
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
py::call_guard<py::gil_scoped_release>())
.def("input_finished", &PyClass::InputFinished,
py::call_guard<py::gil_scoped_release>())
.def("get_frames", &PyClass::GetFrames,
py::arg("frame_index"), py::arg("n"), kGetFramesUsage,
py::call_guard<py::gil_scoped_release>());
}

View File

@@ -50,6 +50,8 @@ class OnlineRecognizer(object):
low_freq: float = 20.0,
high_freq: float = -400.0,
dither: float = 0.0,
normalize_samples: bool = True,
snip_edges: bool = False,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
@@ -118,6 +120,15 @@ class OnlineRecognizer(object):
By default the audio samples are in range [-1,+1],
so dithering constant 0.00003 is a good value,
equivalent to the default 1.0 from kaldi
normalize_samples:
True for +/- 1.0 range of audio samples (default, zipformer features),
False for +/- 32k range of audio samples (ebranchformer features).
snip_edges:
Handling of the end of the audio signal in kaldi feature extraction.
If true, end effects will be handled by outputting only frames that
completely fit in the file, and the number of frames depends on the
frame-length. If false, the number of frames depends only on the
frame-shift, and we reflect the data at the ends.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
@@ -248,6 +259,8 @@ class OnlineRecognizer(object):
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
normalize_samples=normalize_samples,
snip_edges=snip_edges,
feature_dim=feature_dim,
low_freq=low_freq,
high_freq=high_freq,