Refactor online recognizer (#250)

* Refactor online recognizer.

Make it easier to support other streaming models.

Note that this is a breaking change for the Python API:
existing calls to `sherpa_onnx.OnlineRecognizer()` should be
replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
This commit is contained in:
Fangjun Kuang
2023-08-09 20:27:31 +08:00
committed by GitHub
parent 6061318e3f
commit 79c2ce5dd4
40 changed files with 670 additions and 480 deletions

View File

@@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.encoder_filename =
recognizer_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.encoder, "");
recognizer_config.model_config.decoder_filename =
recognizer_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.decoder, "");
recognizer_config.model_config.joiner_filename =
recognizer_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.joiner, "");
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
@@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
auto count = result.tokens.size();
if (count > 0) {
size_t total_length = 0;
for (const auto& token : result.tokens) {
for (const auto &token : result.tokens) {
// +1 for the null character at the end of each token
total_length += token.size() + 1;
}
@@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
r->timestamps = new float[r->count];
char **tokens_temp = new char*[r->count];
char **tokens_temp = new char *[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
tokens_temp[i] = const_cast<char*>(r->tokens) + pos;
tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character