Support paraformer on iOS (#265)

* Fix C API to support streaming paraformer

* Fix Swift API

* Support paraformer in iOS
This commit is contained in:
Fangjun Kuang
2023-08-14 14:38:41 +08:00
committed by GitHub
parent 35526e26e1
commit a8bdb4b38a
12 changed files with 204 additions and 86 deletions

View File

@@ -39,11 +39,17 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.encoder, "");
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
recognizer_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.decoder, "");
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
recognizer_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.joiner, "");
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
recognizer_config.model_config.paraformer.encoder =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
recognizer_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
@@ -128,6 +134,8 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
const auto &text = result.text;
auto r = new SherpaOnnxOnlineRecognizerResult;
memset(r, 0, sizeof(SherpaOnnxOnlineRecognizerResult));
// copy text
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
@@ -153,7 +161,6 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
r->timestamps = new float[r->count];
char **tokens_temp = new char *[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
@@ -162,10 +169,17 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
r->timestamps[i] = result.timestamps[i];
}
r->tokens_arr = tokens_temp;
if (!result.timestamps.empty()) {
r->timestamps = new float[r->count];
std::copy(result.timestamps.begin(), result.timestamps.end(),
r->timestamps);
} else {
r->timestamps = nullptr;
}
r->tokens_arr = tokens_temp;
} else {
r->count = 0;
r->timestamps = nullptr;