Add timestamps and tokens for .Net's online models. (#690)

This commit is contained in:
Fangjun Kuang
2024-03-23 18:51:56 +08:00
committed by GitHub
parent e6da2c5556
commit 1952772654
26 changed files with 135 additions and 73 deletions

View File

@@ -162,15 +162,17 @@ const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
memset(r, 0, sizeof(SherpaOnnxOnlineRecognizerResult));
// copy text
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;
char *pText = new char[text.size() + 1];
std::copy(text.begin(), text.end(), pText);
pText[text.size()] = 0;
r->text = pText;
// copy json
const auto &json = result.AsJsonString();
r->json = new char[json.size() + 1];
std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
const_cast<char *>(r->json)[json.size()] = 0;
char *pJson = new char[json.size() + 1];
std::copy(json.begin(), json.end(), pJson);
pJson[json.size()] = 0;
r->json = pJson;
// copy tokens
auto count = result.tokens.size();
@@ -183,15 +185,12 @@ const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
r->count = count;
// Each word ends with nullptr
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
char *tokens = new char[total_length]{};
char **tokens_temp = new char *[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
result.tokens[i].c_str(), result.tokens[i].size());
tokens_temp[i] = tokens + pos;
memcpy(tokens + pos, result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
}
@@ -205,6 +204,7 @@ const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
r->timestamps = nullptr;
}
r->tokens = tokens;
} else {
r->count = 0;
r->timestamps = nullptr;
@@ -391,9 +391,10 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
auto r = new SherpaOnnxOfflineRecognizerResult;
memset(r, 0, sizeof(SherpaOnnxOfflineRecognizerResult));
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;
char *pText = new char[text.size() + 1];
std::copy(text.begin(), text.end(), pText);
pText[text.size()] = 0;
r->text = pText;
if (!result.timestamps.empty()) {
r->timestamps = new float[result.timestamps.size()];
@@ -530,15 +531,17 @@ const SherpaOnnxKeywordResult *GetKeywordResult(
r->start_time = result.start_time;
// copy keyword
r->keyword = new char[keyword.size() + 1];
std::copy(keyword.begin(), keyword.end(), const_cast<char *>(r->keyword));
const_cast<char *>(r->keyword)[keyword.size()] = 0;
char *pKeyword = new char[keyword.size() + 1];
std::copy(keyword.begin(), keyword.end(), pKeyword);
pKeyword[keyword.size()] = 0;
r->keyword = pKeyword;
// copy json
const auto &json = result.AsJsonString();
r->json = new char[json.size() + 1];
std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
const_cast<char *>(r->json)[json.size()] = 0;
char *pJson = new char[json.size() + 1];
std::copy(json.begin(), json.end(), pJson);
pJson[json.size()] = 0;
r->json = pJson;
// copy tokens
auto count = result.tokens.size();
@@ -551,18 +554,16 @@ const SherpaOnnxKeywordResult *GetKeywordResult(
r->count = count;
// Each word ends with nullptr
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
char *pTokens = new char[total_length]{};
char **tokens_temp = new char *[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
result.tokens[i].c_str(), result.tokens[i].size());
tokens_temp[i] = pTokens + pos;
memcpy(pTokens + pos, result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
}
r->tokens = pTokens;
r->tokens_arr = tokens_temp;
if (!result.timestamps.empty()) {

View File

@@ -145,6 +145,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
const char *const *tokens_arr;
// Pointer to continuous memory which holds timestamps
//
// Caution: If timestamp information is not available, this pointer is NULL.
// Please check whether it is NULL before you access it; otherwise, you would
// get segmentation fault.
float *timestamps;
// The number of tokens/timestamps in above pointer

View File

@@ -105,7 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
int32_t feature_dim_ = 0;
int32_t feature_dim_ = 80;
};
} // namespace sherpa_onnx