Refactor online recognizer (#250)
* Refactor online recognizer. Make it easier to support other streaming models. Note that it is a breaking change for the Python API. `sherpa_onnx.OnlineRecognizer()` used before should be replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
This commit is contained in:
@@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
|
||||
recognizer_config.feat_config.feature_dim =
|
||||
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
|
||||
|
||||
recognizer_config.model_config.encoder_filename =
|
||||
recognizer_config.model_config.transducer.encoder =
|
||||
SHERPA_ONNX_OR(config->model_config.encoder, "");
|
||||
recognizer_config.model_config.decoder_filename =
|
||||
recognizer_config.model_config.transducer.decoder =
|
||||
SHERPA_ONNX_OR(config->model_config.decoder, "");
|
||||
recognizer_config.model_config.joiner_filename =
|
||||
recognizer_config.model_config.transducer.joiner =
|
||||
SHERPA_ONNX_OR(config->model_config.joiner, "");
|
||||
recognizer_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
@@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
|
||||
auto count = result.tokens.size();
|
||||
if (count > 0) {
|
||||
size_t total_length = 0;
|
||||
for (const auto& token : result.tokens) {
|
||||
for (const auto &token : result.tokens) {
|
||||
// +1 for the null character at the end of each token
|
||||
total_length += token.size() + 1;
|
||||
}
|
||||
@@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
|
||||
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
|
||||
total_length);
|
||||
r->timestamps = new float[r->count];
|
||||
char **tokens_temp = new char*[r->count];
|
||||
char **tokens_temp = new char *[r->count];
|
||||
int32_t pos = 0;
|
||||
for (int32_t i = 0; i < r->count; ++i) {
|
||||
tokens_temp[i] = const_cast<char*>(r->tokens) + pos;
|
||||
tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
|
||||
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
|
||||
result.tokens[i].c_str(), result.tokens[i].size());
|
||||
// +1 to move past the null character
|
||||
|
||||
Reference in New Issue
Block a user